Improve PPO diagnostics and recharge behavior
This commit is contained in:
@@ -26,6 +26,8 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
last_save_model_time = time.time()
|
||||
env = envs[0]
|
||||
agent = agents[0]
|
||||
diag_max_episodes = _read_diag_max_episodes(logger)
|
||||
diag_log_only = _read_bool_env("ROBOT_VACUUM_DIAG_LOG_ONLY")
|
||||
|
||||
# Read and validate user configuration
|
||||
# 读取和校验用户配置
|
||||
@@ -33,6 +35,7 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
if usr_conf is None:
|
||||
logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
|
||||
return
|
||||
_apply_diag_env_overrides(usr_conf, logger)
|
||||
|
||||
episode_runner = EpisodeRunner(
|
||||
env=env,
|
||||
@@ -40,9 +43,11 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
usr_conf=usr_conf,
|
||||
logger=logger,
|
||||
monitor=monitor,
|
||||
diag_max_episodes=diag_max_episodes,
|
||||
diag_log_only=diag_log_only,
|
||||
)
|
||||
|
||||
while True:
|
||||
while not episode_runner.stop_requested:
|
||||
for g_data in episode_runner.run_episodes():
|
||||
agent.send_sample_data(g_data)
|
||||
g_data.clear()
|
||||
@@ -51,10 +56,56 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
if now - last_save_model_time >= 1800:
|
||||
agent.save_model()
|
||||
last_save_model_time = now
|
||||
if episode_runner.stop_requested:
|
||||
break
|
||||
|
||||
if episode_runner.stop_requested:
|
||||
logger.info(f"diagnostic max episodes reached: {episode_runner.episode_cnt}")
|
||||
|
||||
|
||||
def _read_diag_max_episodes(logger):
|
||||
raw_value = os.environ.get("ROBOT_VACUUM_DIAG_MAX_EPISODES", "").strip()
|
||||
if not raw_value:
|
||||
return 0
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except ValueError:
|
||||
if logger:
|
||||
logger.warning(f"ignore invalid ROBOT_VACUUM_DIAG_MAX_EPISODES={raw_value!r}")
|
||||
return 0
|
||||
return max(value, 0)
|
||||
|
||||
|
||||
def _read_positive_int_env(name, logger):
|
||||
raw_value = os.environ.get(name, "").strip()
|
||||
if not raw_value:
|
||||
return 0
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except ValueError:
|
||||
if logger:
|
||||
logger.warning(f"ignore invalid {name}={raw_value!r}")
|
||||
return 0
|
||||
return max(value, 0)
|
||||
|
||||
|
||||
def _read_bool_env(name):
|
||||
return os.environ.get(name, "").strip().lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
def _apply_diag_env_overrides(usr_conf, logger):
|
||||
diag_max_step = _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_STEP", logger)
|
||||
if diag_max_step <= 0:
|
||||
return
|
||||
env_conf = usr_conf.setdefault("env_conf", {})
|
||||
old_max_step = env_conf.get("max_step")
|
||||
env_conf["max_step"] = diag_max_step
|
||||
if logger:
|
||||
logger.info(f"diagnostic max_step override: {old_max_step} -> {diag_max_step}")
|
||||
|
||||
|
||||
class EpisodeRunner:
|
||||
def __init__(self, env, agent, usr_conf, logger, monitor):
|
||||
def __init__(self, env, agent, usr_conf, logger, monitor, diag_max_episodes=0, diag_log_only=False):
|
||||
self.env = env
|
||||
self.agent = agent
|
||||
self.usr_conf = usr_conf
|
||||
@@ -63,6 +114,9 @@ class EpisodeRunner:
|
||||
self.episode_cnt = 0
|
||||
self.last_report_monitor_time = 0
|
||||
self.last_get_training_metrics_time = 0
|
||||
self.diag_max_episodes = int(diag_max_episodes)
|
||||
self.diag_log_only = bool(diag_log_only)
|
||||
self.stop_requested = False
|
||||
|
||||
def run_episodes(self):
|
||||
"""Run a single episode and yield collected samples.
|
||||
@@ -70,6 +124,8 @@ class EpisodeRunner:
|
||||
单局流程(generator),完成一局后 yield 整局样本。
|
||||
"""
|
||||
while True:
|
||||
if self.stop_requested:
|
||||
return
|
||||
# Periodically get training metrics
|
||||
# 定期打印训练指标
|
||||
now = time.time()
|
||||
@@ -188,6 +244,39 @@ class EpisodeRunner:
|
||||
f"result_code:{result_code} "
|
||||
f"result_message:{result_message}"
|
||||
)
|
||||
diag = fm.get_diagnostic_summary()
|
||||
self.logger.info(
|
||||
f"[DIAG] ep:{self.episode_cnt} map:{diag['map_id']} "
|
||||
f"steps:{step} result:{result_str} "
|
||||
f"profile:{diag['reward_profile']} route:{diag['charger_route_source']} "
|
||||
f"score:{float(total_score):.1f} reward:{total_reward + final_reward:.3f} "
|
||||
f"mask_avg(raw/block/npc/recharge/escape/leave/final):"
|
||||
f"{diag['avg_mask_counts']['raw']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['blocked']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['npc']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['recharge']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['escape']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['leave']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['final']:.2f} "
|
||||
f"mask_changed(block/npc/recharge/escape/leave):"
|
||||
f"{diag['mask_changed_steps']['blocked']}/"
|
||||
f"{diag['mask_changed_steps']['npc']}/"
|
||||
f"{diag['mask_changed_steps']['recharge']}/"
|
||||
f"{diag['mask_changed_steps']['escape']}/"
|
||||
f"{diag['mask_changed_steps']['leave']} "
|
||||
f"mask_active(recharge/leave):"
|
||||
f"{diag['mask_active_steps']['recharge']}/"
|
||||
f"{diag['mask_active_steps']['leave']} "
|
||||
f"tight(one/<=2/zero):"
|
||||
f"{diag['one_action_steps']}/"
|
||||
f"{diag['two_or_less_action_steps']}/"
|
||||
f"{diag['zero_final_steps']} "
|
||||
f"actions:{diag['action_hist']} "
|
||||
f"known:{diag['known_ratio']:.3f} dirty_known:{diag['known_dirty_ratio']:.3f} "
|
||||
f"frontier:{diag['frontier_ratio']:.3f} "
|
||||
f"path_dirty/frontier:{diag['global_dirty_path_dist']:.1f}/"
|
||||
f"{diag['frontier_path_dist']:.1f}"
|
||||
)
|
||||
|
||||
# Build sample frame
|
||||
# 构造样本帧
|
||||
@@ -212,6 +301,9 @@ class EpisodeRunner:
|
||||
# Add terminal reward to last frame
|
||||
# 终局奖励叠加到最后一步
|
||||
collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32)
|
||||
if truncated and not terminated:
|
||||
collector[-1].next_value = self.agent.estimate_value(_obs_data)
|
||||
collector[-1].done = np.array([0.0], dtype=np.float32)
|
||||
|
||||
# Monitor reporting / 监控上报
|
||||
now = time.time()
|
||||
@@ -231,6 +323,13 @@ class EpisodeRunner:
|
||||
"battery_fail": float(fm.battery_fail),
|
||||
"charge_count": float(charge_count),
|
||||
"remaining_charge": float(remaining_charge),
|
||||
"recharge_steps": float(fm.recharge_steps),
|
||||
"mask_final_avg": float(diag["avg_mask_counts"]["final"]),
|
||||
"mask_recharge_active": float(diag["mask_active_steps"]["recharge"]),
|
||||
"mask_recharge_changed": float(diag["mask_changed_steps"]["recharge"]),
|
||||
"mask_one_action_steps": float(diag["one_action_steps"]),
|
||||
"mask_two_or_less_action_steps": float(diag["two_or_less_action_steps"]),
|
||||
"mask_zero_final_steps": float(diag["zero_final_steps"]),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -239,6 +338,13 @@ class EpisodeRunner:
|
||||
# Compute GAE and yield samples
|
||||
# GAE 计算并 yield 样本
|
||||
if collector:
|
||||
if self.diag_max_episodes > 0 and self.episode_cnt >= self.diag_max_episodes:
|
||||
self.stop_requested = True
|
||||
if self.diag_log_only:
|
||||
collector.clear()
|
||||
if self.stop_requested:
|
||||
return
|
||||
break
|
||||
collector = sample_process(collector)
|
||||
yield collector
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user