#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors

Training workflow for Robot Vacuum.
"""

import os
import time

import numpy as np

from agent_ppo.conf.conf import Config
from agent_ppo.feature.definition import SampleData, sample_process
from tools.metrics_utils import get_training_metrics
from tools.train_env_conf_validate import read_usr_conf
from common_python.utils.workflow_disaster_recovery import handle_disaster_recovery

# Minimum wall-clock seconds between periodic model checkpoints.
_MODEL_SAVE_INTERVAL_S = 1800
# Minimum wall-clock seconds between training-metric logs and between monitor reports.
_REPORT_INTERVAL_S = 60


def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
    """Drive episode collection and periodic checkpointing for PPO training.

    Only ``envs[0]`` and ``agents[0]`` are used. The samples of every finished
    episode are passed to ``agent.send_sample_data``; after each sent batch the
    model is saved if at least ``_MODEL_SAVE_INTERVAL_S`` seconds have passed
    since the last save.

    Diagnostic environment variables:
        ROBOT_VACUUM_DIAG_MAX_EPISODES: stop after this many episodes (0/unset = no cap).
        ROBOT_VACUUM_DIAG_LOG_ONLY: run and log episodes but discard their samples.
        ROBOT_VACUUM_DIAG_MAX_STEP: override ``env_conf.max_step`` in the user config.

    Args:
        envs: Sequence of environments; only the first is used.
        agents: Sequence of agents; only the first is used.
        logger: Logger. May be None while reading the config, but the episode
            runner logs through it unconditionally.
        monitor: Optional monitor that receives periodic per-episode metrics.
    """
    last_save_model_time = time.time()
    env = envs[0]
    agent = agents[0]
    diag_max_episodes = _read_diag_max_episodes(logger)
    diag_log_only = _read_bool_env("ROBOT_VACUUM_DIAG_LOG_ONLY")

    # Read and validate the user configuration.
    usr_conf = read_usr_conf("agent_ppo/conf/train_env_conf.toml", logger)
    if usr_conf is None:
        # logger defaults to None; do not crash while reporting the config error.
        if logger:
            logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
        return

    _apply_diag_env_overrides(usr_conf, logger)

    episode_runner = EpisodeRunner(
        env=env,
        agent=agent,
        usr_conf=usr_conf,
        logger=logger,
        monitor=monitor,
        diag_max_episodes=diag_max_episodes,
        diag_log_only=diag_log_only,
    )

    while not episode_runner.stop_requested:
        for g_data in episode_runner.run_episodes():
            agent.send_sample_data(g_data)
            g_data.clear()

            now = time.time()
            if now - last_save_model_time >= _MODEL_SAVE_INTERVAL_S:
                agent.save_model()
                last_save_model_time = now

            # Stop consuming the generator as soon as the diagnostic episode cap is hit.
            if episode_runner.stop_requested:
                break

    # The loop above only exits once stop_requested has been set.
    logger.info(f"diagnostic max episodes reached: {episode_runner.episode_cnt}")


def _read_positive_int_env(name, logger):
    """Read a non-negative integer from environment variable ``name``.

    Returns 0 when the variable is unset, blank, or not an integer; the
    non-integer case is logged as a warning when ``logger`` is given.
    Negative values are clamped to 0.
    """
    raw_value = os.environ.get(name, "").strip()
    if not raw_value:
        return 0
    try:
        value = int(raw_value)
    except ValueError:
        if logger:
            logger.warning(f"ignore invalid {name}={raw_value!r}")
        return 0
    return max(value, 0)


def _read_diag_max_episodes(logger):
    """Return the episode cap from ROBOT_VACUUM_DIAG_MAX_EPISODES (0 means no cap)."""
    return _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_EPISODES", logger)


def _read_bool_env(name):
    """Return True when env var ``name`` is 1/true/yes/on (case-insensitive, whitespace-trimmed)."""
    return os.environ.get(name, "").strip().lower() in ("1", "true", "yes", "on")


def _apply_diag_env_overrides(usr_conf, logger):
    """Override ``usr_conf["env_conf"]["max_step"]`` in place from ROBOT_VACUUM_DIAG_MAX_STEP.

    Does nothing when the variable is unset, invalid, or not positive.
    Creates ``usr_conf["env_conf"]`` if it is missing.
    """
    diag_max_step = _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_STEP", logger)
    if diag_max_step <= 0:
        return
    env_conf = usr_conf.setdefault("env_conf", {})
    old_max_step = env_conf.get("max_step")
    env_conf["max_step"] = diag_max_step
    if logger:
        logger.info(f"diagnostic max_step override: {old_max_step} -> {diag_max_step}")


def _as_dict(value):
    """Return ``value`` if it is a dict, otherwise a new empty dict."""
    return value if isinstance(value, dict) else {}


class EpisodeRunner:
    """Runs environment episodes and converts each finished one into training samples."""

    def __init__(self, env, agent, usr_conf, logger, monitor, diag_max_episodes=0, diag_log_only=False):
        self.env = env
        self.agent = agent
        self.usr_conf = usr_conf
        self.logger = logger
        self.monitor = monitor
        self.episode_cnt = 0
        self.last_report_monitor_time = 0
        self.last_get_training_metrics_time = 0
        # When > 0, request a stop once this many episodes have been started.
        self.diag_max_episodes = int(diag_max_episodes)
        # When True, finished episodes are logged but their samples are discarded.
        self.diag_log_only = bool(diag_log_only)
        # Set once the diagnostic episode cap is reached; the workflow polls it.
        self.stop_requested = False

    def run_episodes(self):
        """Run episodes indefinitely, yielding the processed samples of each finished one.

        Generator: each yielded value is the result of ``sample_process`` for one
        complete episode. Episodes interrupted by ``handle_disaster_recovery`` are
        dropped without yielding. Once ``stop_requested`` is set, the generator
        returns. When the cap is hit with samples enabled, it returns on the next
        resume after yielding that episode.
        """
        while True:
            if self.stop_requested:
                return

            self._maybe_log_training_metrics()

            # Reset the environment; on a recoverable failure, try a fresh reset.
            env_obs = self.env.reset(self.usr_conf)
            if handle_disaster_recovery(env_obs, self.logger):
                continue

            # Reset the agent and load the latest model.
            self.agent.reset(env_obs)
            self.agent.load_model(id="latest")

            obs_data, _ = self.agent.observation_process(env_obs)

            collector = []
            self.episode_cnt += 1
            done = False
            step = 0
            total_reward = 0.0
            self.logger.info(f"Episode {self.episode_cnt} start")

            while not done:
                # Agent inference.
                act_data = self.agent.predict([obs_data])[0]
                act = self.agent.action_process(act_data)

                # Environment step. The env reward is not used; the per-step
                # training reward is read from agent.last_reward below.
                _env_reward, env_obs = self.env.step(act)
                if handle_disaster_recovery(env_obs, self.logger):
                    # Abandon this episode; its partial samples are discarded.
                    break

                terminated = env_obs["terminated"]
                truncated = env_obs["truncated"]
                frame_no = env_obs["frame_no"]
                step += 1
                done = terminated or truncated

                # Feature processing for the next observation.
                _obs_data, _ = self.agent.observation_process(env_obs)
                _obs_data.frame_no = frame_no

                reward_scalar = float(self.agent.last_reward)
                total_reward += reward_scalar

                final_reward = 0.0
                if done:
                    fm = self.agent.preprocessor
                    outcome = self._episode_outcome(fm, env_obs, step, truncated)
                    final_reward = outcome["final_reward"]
                    diag = self._log_episode_end(fm, outcome, step, total_reward)

                collector.append(self._build_frame(obs_data, act_data, reward_scalar, done))

                if not done:
                    # Advance state.
                    obs_data = _obs_data
                    continue

                # Episode finished: fold the terminal bonus into the last frame.
                last_frame = collector[-1]
                last_frame.reward = last_frame.reward + np.array([final_reward], dtype=np.float32)
                if truncated and not terminated:
                    # Step-limit cut-off rather than a true terminal state:
                    # bootstrap from the value of the next observation.
                    last_frame.next_value = self.agent.estimate_value(_obs_data)
                    last_frame.done = np.array([0.0], dtype=np.float32)

                self._report_monitor(fm, outcome, diag, total_reward + final_reward)

                if self.diag_max_episodes > 0 and self.episode_cnt >= self.diag_max_episodes:
                    self.stop_requested = True
                if self.diag_log_only:
                    collector.clear()
                    if self.stop_requested:
                        return
                    break

                # Compute returns/advantages and hand the episode to the caller.
                yield sample_process(collector)
                break

    def _maybe_log_training_metrics(self):
        """Fetch and log training metrics at most once every ``_REPORT_INTERVAL_S`` seconds."""
        now = time.time()
        if now - self.last_get_training_metrics_time >= _REPORT_INTERVAL_S:
            training_metrics = get_training_metrics()
            self.last_get_training_metrics_time = now
            if training_metrics is not None:
                self.logger.info(f"training_metrics: {training_metrics}")

    @staticmethod
    def _episode_outcome(fm, env_obs, step, truncated):
        """Classify a finished episode and compute its terminal reward bonus.

        Env-reported values from ``env_obs["observation"]["env_info"]`` take
        precedence. The preprocessor ``fm`` (or ``step`` for ``finished_steps``)
        is the fallback. Missing or non-dict observation/env_info/extra_info
        entries are treated as empty.

        Returns:
            dict with total_score, remaining_charge, charge_count, finished_steps,
            result_message, result_code, final_reward and result_str.
        """
        env_info = _as_dict(_as_dict(env_obs.get("observation")).get("env_info"))
        extra_info = _as_dict(env_obs.get("extra_info"))

        total_score = env_info.get("total_score", fm.total_score)
        remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
        charge_count = env_info.get("charge_count", fm.charge_count)
        finished_steps = env_info.get("finished_steps", step)
        cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
        score_per_step = total_score / max(finished_steps, 1)

        if truncated:
            if score_per_step < 0.25:
                # Reached the step limit while scoring too little per step.
                final_reward = -3.0 + 6.0 * cleaning_ratio
                result_str = "STALL_TRUNCATED"
            else:
                # Survived to the step limit: a higher cleaning ratio earns more.
                final_reward = 2.0 + 8.0 * cleaning_ratio
                result_str = "WIN"
        elif fm.battery <= 0 or remaining_charge <= 0:
            final_reward = -fm.battery_fail_penalty() + 4.0 * cleaning_ratio
            result_str = "BATTERY_FAIL"
        elif fm.npc_danger or fm.npc_predicted_danger or fm.nearest_npc_dist <= 1:
            final_reward = -3.0 + 6.0 * cleaning_ratio
            result_str = "NPC_FAIL"
        else:
            final_reward = -2.0 + 6.0 * cleaning_ratio
            result_str = "FAIL"

        return {
            "total_score": total_score,
            "remaining_charge": remaining_charge,
            "charge_count": charge_count,
            "finished_steps": finished_steps,
            "result_message": extra_info.get("result_message", ""),
            "result_code": extra_info.get("result_code", ""),
            "final_reward": final_reward,
            "result_str": result_str,
        }

    def _log_episode_end(self, fm, outcome, step, total_reward):
        """Log the [GAMEOVER] and [DIAG] lines for a finished episode.

        ``total_reward`` is the sum of per-step rewards, excluding the terminal bonus.
        Returns the dict from ``fm.get_diagnostic_summary()``.
        """
        total_score = outcome["total_score"]
        final_reward = outcome["final_reward"]
        result_str = outcome["result_str"]
        self.logger.info(
            f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
            f"finished_steps:{outcome['finished_steps']} "
            f"result:{result_str} final_bonus:{final_reward:.2f} "
            f"total_reward:{total_reward:.3f} "
            f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} "
            f"total_score:{total_score} "
            f"remaining_charge:{outcome['remaining_charge']} "
            f"charge_count:{outcome['charge_count']} "
            f"recharge_steps:{fm.recharge_steps} "
            f"stuck_count:{fm.stuck_count} "
            f"max_stuck_steps:{fm.max_stuck_steps} "
            f"recharge_escape_count:{fm.recharge_escape_count} "
            f"npc_close_steps:{fm.npc_close_steps} "
            f"npc_danger_steps:{fm.npc_danger_steps} "
            f"npc_collision:{fm.npc_collision} "
            f"nearest_charger:{fm.nearest_charger_range_dist:.1f} "
            f"nearest_npc:{fm.nearest_npc_dist:.1f} "
            f"result_code:{outcome['result_code']} "
            f"result_message:{outcome['result_message']}"
        )

        diag = fm.get_diagnostic_summary()
        mask_avg = diag["avg_mask_counts"]
        mask_changed = diag["mask_changed_steps"]
        mask_active = diag["mask_active_steps"]
        self.logger.info(
            f"[DIAG] ep:{self.episode_cnt} map:{diag['map_id']} "
            f"steps:{step} result:{result_str} "
            f"profile:{diag['reward_profile']} route:{diag['charger_route_source']} "
            f"score:{float(total_score):.1f} reward:{total_reward + final_reward:.3f} "
            f"mask_avg(raw/block/npc/recharge/escape/leave/final):"
            f"{mask_avg['raw']:.2f}/"
            f"{mask_avg['blocked']:.2f}/"
            f"{mask_avg['npc']:.2f}/"
            f"{mask_avg['recharge']:.2f}/"
            f"{mask_avg['escape']:.2f}/"
            f"{mask_avg['leave']:.2f}/"
            f"{mask_avg['final']:.2f} "
            f"mask_changed(block/npc/recharge/escape/leave):"
            f"{mask_changed['blocked']}/"
            f"{mask_changed['npc']}/"
            f"{mask_changed['recharge']}/"
            f"{mask_changed['escape']}/"
            f"{mask_changed['leave']} "
            f"mask_active(recharge/leave):"
            f"{mask_active['recharge']}/"
            f"{mask_active['leave']} "
            f"tight(one/<=2/zero):"
            f"{diag['one_action_steps']}/"
            f"{diag['two_or_less_action_steps']}/"
            f"{diag['zero_final_steps']} "
            f"actions:{diag['action_hist']} "
            f"known:{diag['known_ratio']:.3f} dirty_known:{diag['known_dirty_ratio']:.3f} "
            f"frontier:{diag['frontier_ratio']:.3f} "
            f"path_dirty/frontier:{diag['global_dirty_path_dist']:.1f}/"
            f"{diag['frontier_path_dist']:.1f}"
        )
        return diag

    @staticmethod
    def _build_frame(obs_data, act_data, reward_scalar, done):
        """Build one SampleData frame.

        The return/advantage fields (reward_sum, next_value, advantage) are
        zero placeholders for ``sample_process`` to fill in.
        """
        return SampleData(
            obs=np.array(obs_data.feature, dtype=np.float32),
            legal_action=np.array(obs_data.legal_action, dtype=np.float32),
            act=np.array(act_data.action),
            reward=np.array([reward_scalar], dtype=np.float32),
            # float32 like every other field (the truncation path also writes float32).
            done=np.array([float(done)], dtype=np.float32),
            reward_sum=np.zeros(Config.VALUE_NUM, dtype=np.float32),
            value=act_data.value.flatten()[: Config.VALUE_NUM],
            next_value=np.zeros(Config.VALUE_NUM, dtype=np.float32),
            advantage=np.zeros(Config.VALUE_NUM, dtype=np.float32),
            prob=np.array(act_data.prob, dtype=np.float32),
        )

    def _report_monitor(self, fm, outcome, diag, episode_reward):
        """Send episode metrics to the monitor, at most once every ``_REPORT_INTERVAL_S`` seconds.

        ``episode_reward`` is the per-step reward sum plus the terminal bonus.
        Does nothing when no monitor is configured.
        """
        now = time.time()
        if now - self.last_report_monitor_time >= _REPORT_INTERVAL_S and self.monitor:
            self.monitor.put_data(
                {
                    os.getpid(): {
                        "reward": episode_reward,
                        "episode_cnt": self.episode_cnt,
                        "total_score": float(outcome["total_score"]),
                        "stuck_count": float(fm.stuck_count),
                        "max_stuck_steps": float(fm.max_stuck_steps),
                        "recharge_escape_count": float(fm.recharge_escape_count),
                        "npc_close_steps": float(fm.npc_close_steps),
                        "npc_danger_steps": float(fm.npc_danger_steps),
                        "npc_collision": float(fm.npc_collision),
                        "battery_fail": float(fm.battery_fail),
                        "charge_count": float(outcome["charge_count"]),
                        "remaining_charge": float(outcome["remaining_charge"]),
                        "recharge_steps": float(fm.recharge_steps),
                        "mask_final_avg": float(diag["avg_mask_counts"]["final"]),
                        "mask_recharge_active": float(diag["mask_active_steps"]["recharge"]),
                        "mask_recharge_changed": float(diag["mask_changed_steps"]["recharge"]),
                        "mask_one_action_steps": float(diag["one_action_steps"]),
                        "mask_two_or_less_action_steps": float(diag["two_or_less_action_steps"]),
                        "mask_zero_final_steps": float(diag["zero_final_steps"]),
                    }
                }
            )
            self.last_report_monitor_time = now