Improve PPO diagnostics and recharge behavior

This commit is contained in:
2026-04-26 20:24:26 +08:00
parent 5b6133db13
commit 69b8a692db
6 changed files with 463 additions and 31 deletions

View File

@@ -26,6 +26,8 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
last_save_model_time = time.time()
env = envs[0]
agent = agents[0]
diag_max_episodes = _read_diag_max_episodes(logger)
diag_log_only = _read_bool_env("ROBOT_VACUUM_DIAG_LOG_ONLY")
# Read and validate user configuration
# 读取和校验用户配置
@@ -33,6 +35,7 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
if usr_conf is None:
logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
return
_apply_diag_env_overrides(usr_conf, logger)
episode_runner = EpisodeRunner(
env=env,
@@ -40,9 +43,11 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
usr_conf=usr_conf,
logger=logger,
monitor=monitor,
diag_max_episodes=diag_max_episodes,
diag_log_only=diag_log_only,
)
while True:
while not episode_runner.stop_requested:
for g_data in episode_runner.run_episodes():
agent.send_sample_data(g_data)
g_data.clear()
@@ -51,10 +56,56 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
if now - last_save_model_time >= 1800:
agent.save_model()
last_save_model_time = now
if episode_runner.stop_requested:
break
if episode_runner.stop_requested:
logger.info(f"diagnostic max episodes reached: {episode_runner.episode_cnt}")
def _read_diag_max_episodes(logger):
    """Return the diagnostic episode cap from ROBOT_VACUUM_DIAG_MAX_EPISODES.

    Parsing is delegated to ``_read_positive_int_env`` so this variable is
    handled exactly like the other diagnostic integer overrides instead of
    through a second, hand-copied parser.

    Args:
        logger: Optional logger forwarded to the shared parser, which uses it
            to report values it rejects.

    Returns:
        The configured cap as a positive int, or 0 when the variable is
        unset, blank, or not usable. Callers treat 0 as "no cap".
    """
    return _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_EPISODES", logger)
def _read_positive_int_env(name, logger):
    """Read an integer override from the environment variable ``name``.

    Args:
        name: Name of the environment variable to read.
        logger: Optional logger. When truthy, a rejected value is reported
            through ``logger.warning``.

    Returns:
        The parsed integer when it is greater than zero; otherwise 0. Zero is
        returned for an unset or blank variable, an explicit ``0``, a value
        that is not an integer, and a negative integer.
    """
    raw_value = os.environ.get(name, "").strip()
    if not raw_value:
        return 0
    try:
        value = int(raw_value)
    except ValueError:
        value = None
    # A negative number is as unusable as a non-number; report both the same
    # way so a mistyped override is not silently discarded.
    if value is None or value < 0:
        if logger:
            logger.warning(f"ignore invalid {name}={raw_value!r}")
        return 0
    return value
def _read_bool_env(name):
    """Interpret the environment variable ``name`` as an on/off flag.

    The raw value is stripped of surrounding whitespace and lower-cased
    before comparison, so ``" YES "`` and ``"On"`` both count as enabled.

    Args:
        name: Name of the environment variable to read.

    Returns:
        True when the normalized value is one of "1", "true", "yes", "on";
        False for anything else, including an unset variable.
    """
    normalized = os.environ.get(name, "")
    normalized = normalized.strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _apply_diag_env_overrides(usr_conf, logger):
    """Apply the diagnostic ``max_step`` override to ``usr_conf`` in place.

    Reads ROBOT_VACUUM_DIAG_MAX_STEP through ``_read_positive_int_env``. When
    that yields a value greater than zero, it is written to
    ``usr_conf["env_conf"]["max_step"]``; the ``env_conf`` mapping is created
    if it is missing. Otherwise ``usr_conf`` is left untouched.

    Args:
        usr_conf: Mutable configuration mapping to update.
        logger: Optional logger; when truthy, the previous and new values are
            reported through ``logger.info``.

    Returns:
        None.
    """
    diag_max_step = _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_STEP", logger)
    if diag_max_step > 0:
        env_section = usr_conf.setdefault("env_conf", {})
        # Capture the prior setting before overwriting it so the log line can
        # show what was replaced.
        old_max_step = env_section.get("max_step")
        env_section["max_step"] = diag_max_step
        if logger:
            logger.info(f"diagnostic max_step override: {old_max_step} -> {diag_max_step}")
class EpisodeRunner:
def __init__(self, env, agent, usr_conf, logger, monitor):
def __init__(self, env, agent, usr_conf, logger, monitor, diag_max_episodes=0, diag_log_only=False):
self.env = env
self.agent = agent
self.usr_conf = usr_conf
@@ -63,6 +114,9 @@ class EpisodeRunner:
self.episode_cnt = 0
self.last_report_monitor_time = 0
self.last_get_training_metrics_time = 0
self.diag_max_episodes = int(diag_max_episodes)
self.diag_log_only = bool(diag_log_only)
self.stop_requested = False
def run_episodes(self):
"""Run a single episode and yield collected samples.
@@ -70,6 +124,8 @@ class EpisodeRunner:
单局流程generator完成一局后 yield 整局样本。
"""
while True:
if self.stop_requested:
return
# Periodically get training metrics
# 定期打印训练指标
now = time.time()
@@ -188,6 +244,39 @@ class EpisodeRunner:
f"result_code:{result_code} "
f"result_message:{result_message}"
)
diag = fm.get_diagnostic_summary()
self.logger.info(
f"[DIAG] ep:{self.episode_cnt} map:{diag['map_id']} "
f"steps:{step} result:{result_str} "
f"profile:{diag['reward_profile']} route:{diag['charger_route_source']} "
f"score:{float(total_score):.1f} reward:{total_reward + final_reward:.3f} "
f"mask_avg(raw/block/npc/recharge/escape/leave/final):"
f"{diag['avg_mask_counts']['raw']:.2f}/"
f"{diag['avg_mask_counts']['blocked']:.2f}/"
f"{diag['avg_mask_counts']['npc']:.2f}/"
f"{diag['avg_mask_counts']['recharge']:.2f}/"
f"{diag['avg_mask_counts']['escape']:.2f}/"
f"{diag['avg_mask_counts']['leave']:.2f}/"
f"{diag['avg_mask_counts']['final']:.2f} "
f"mask_changed(block/npc/recharge/escape/leave):"
f"{diag['mask_changed_steps']['blocked']}/"
f"{diag['mask_changed_steps']['npc']}/"
f"{diag['mask_changed_steps']['recharge']}/"
f"{diag['mask_changed_steps']['escape']}/"
f"{diag['mask_changed_steps']['leave']} "
f"mask_active(recharge/leave):"
f"{diag['mask_active_steps']['recharge']}/"
f"{diag['mask_active_steps']['leave']} "
f"tight(one/<=2/zero):"
f"{diag['one_action_steps']}/"
f"{diag['two_or_less_action_steps']}/"
f"{diag['zero_final_steps']} "
f"actions:{diag['action_hist']} "
f"known:{diag['known_ratio']:.3f} dirty_known:{diag['known_dirty_ratio']:.3f} "
f"frontier:{diag['frontier_ratio']:.3f} "
f"path_dirty/frontier:{diag['global_dirty_path_dist']:.1f}/"
f"{diag['frontier_path_dist']:.1f}"
)
# Build sample frame
# 构造样本帧
@@ -212,6 +301,9 @@ class EpisodeRunner:
# Add terminal reward to last frame
# 终局奖励叠加到最后一步
collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32)
if truncated and not terminated:
collector[-1].next_value = self.agent.estimate_value(_obs_data)
collector[-1].done = np.array([0.0], dtype=np.float32)
# Monitor reporting / 监控上报
now = time.time()
@@ -231,6 +323,13 @@ class EpisodeRunner:
"battery_fail": float(fm.battery_fail),
"charge_count": float(charge_count),
"remaining_charge": float(remaining_charge),
"recharge_steps": float(fm.recharge_steps),
"mask_final_avg": float(diag["avg_mask_counts"]["final"]),
"mask_recharge_active": float(diag["mask_active_steps"]["recharge"]),
"mask_recharge_changed": float(diag["mask_changed_steps"]["recharge"]),
"mask_one_action_steps": float(diag["one_action_steps"]),
"mask_two_or_less_action_steps": float(diag["two_or_less_action_steps"]),
"mask_zero_final_steps": float(diag["zero_final_steps"]),
}
}
)
@@ -239,6 +338,13 @@ class EpisodeRunner:
# Compute GAE and yield samples
# GAE 计算并 yield 样本
if collector:
if self.diag_max_episodes > 0 and self.episode_cnt >= self.diag_max_episodes:
self.stop_requested = True
if self.diag_log_only:
collector.clear()
if self.stop_requested:
return
break
collector = sample_process(collector)
yield collector
break