From 69b8a692db60b4142196e665a70a5339d0e23c52 Mon Sep 17 00:00:00 2001 From: gqt <3217233537@qq.com> Date: Sun, 26 Apr 2026 20:24:26 +0800 Subject: [PATCH] Improve PPO diagnostics and recharge behavior --- agent_ppo/agent.py | 45 ++++- agent_ppo/conf/conf.py | 5 + agent_ppo/conf/monitor_builder.py | 40 ++++ agent_ppo/feature/preprocessor.py | 288 ++++++++++++++++++++++++--- agent_ppo/workflow/train_workflow.py | 110 +++++++++- train_test.py | 6 +- 6 files changed, 463 insertions(+), 31 deletions(-) diff --git a/agent_ppo/agent.py b/agent_ppo/agent.py index 8cd84d2..1dbbde8 100644 --- a/agent_ppo/agent.py +++ b/agent_ppo/agent.py @@ -76,6 +76,8 @@ class Agent(BaseAgent): """ action = act_data.action if is_stochastic else act_data.d_action self.last_action = int(action[0]) + if hasattr(self.preprocessor, "record_action"): + self.preprocessor.record_action(self.last_action) return self.last_action def predict(self, list_obs_data): @@ -110,7 +112,16 @@ class Agent(BaseAgent): """ try: obs_data, _ = self.observation_process(env_obs) - act_data = self.predict([obs_data])[0] + logits, value = self._run_model(obs_data.feature) + legal_arr = np.array(obs_data.legal_action, dtype=np.float32) + prob = self._legal_soft_max(logits, legal_arr) + action = self._tie_break_eval_action(prob, legal_arr) + act_data = ActData( + action=[action], + d_action=[action], + prob=list(prob), + value=value, + ) return self.action_process(act_data, is_stochastic=False) except Exception as err: if self.logger: @@ -127,6 +138,11 @@ class Agent(BaseAgent): """ return self.algorithm.learn(list_sample_data) + def estimate_value(self, obs_data): + """Estimate critic value for a processed observation.""" + _, value = self._run_model(obs_data.feature) + return np.asarray(value, dtype=np.float32).reshape(-1)[: Config.VALUE_NUM] + def save_model(self, path=None, id="1"): """Save model checkpoint. @@ -220,3 +236,30 @@ class Agent(BaseAgent): if use_max: return int(np.argmax(probs)) return int(np.random.choice(len(probs), p=probs)) + + def _tie_break_eval_action(self, probs, legal_action): + """Use a light heuristic only when evaluation probabilities are close.""" + probs = np.asarray(probs, dtype=np.float64) + legal = np.asarray(legal_action, dtype=np.float32) > 0.5 + if not np.any(legal): + legal = np.ones(Config.ACTION_NUM, dtype=bool) + legal_indices = np.flatnonzero(legal) + best_action = int(legal_indices[np.argmax(probs[legal_indices])]) + best_prob = float(probs[best_action]) + candidates = [ + int(action) + for action in legal_indices + if best_prob - float(probs[int(action)]) <= Config.EVAL_TIE_BREAK_PROB_GAP + ] + if len(candidates) <= 1: + return best_action + + scored = [] + for action in candidates: + heuristic = 0.0 + if hasattr(self.preprocessor, "evaluation_action_score"): + heuristic = self.preprocessor.evaluation_action_score(action) + combined = float(probs[action]) + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic + scored.append((combined, float(probs[action]), -action, action)) + scored.sort(reverse=True) + return int(scored[0][3]) diff --git a/agent_ppo/conf/conf.py b/agent_ppo/conf/conf.py index 64d5ac0..f8d4825 100644 --- a/agent_ppo/conf/conf.py +++ b/agent_ppo/conf/conf.py @@ -50,6 +50,11 @@ class Config: NORMALIZE_ADVANTAGE = True TARGET_KL = 0.04 + # Evaluation tie-break: when policy probabilities are close, prefer safer + # coverage/recharge actions with a lightweight heuristic. + EVAL_TIE_BREAK_PROB_GAP = 0.015 + EVAL_TIE_BREAK_SCORE_SCALE = 0.01 + LABEL_SIZE_LIST = [ACTION_NUM] LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy() diff --git a/agent_ppo/conf/monitor_builder.py b/agent_ppo/conf/monitor_builder.py index 9d8ece6..e94d283 100644 --- a/agent_ppo/conf/monitor_builder.py +++ b/agent_ppo/conf/monitor_builder.py @@ -125,6 +125,10 @@ def build_monitor(): metrics_name="recharge_escape_count", expr="avg(recharge_escape_count{})", ) + .add_metric( + metrics_name="recharge_steps", + expr="avg(recharge_steps{})", + ) .end_panel() .add_panel( name="NPC危险接近", @@ -172,6 +176,42 @@ def build_monitor(): expr="avg(remaining_charge{})", ) .end_panel() + .add_panel( + name="动作掩码健康", + name_en="mask_health", + type="line", + ) + .add_metric( + metrics_name="mask_final_avg", + expr="avg(mask_final_avg{})", + ) + .add_metric( + metrics_name="mask_one_action_steps", + expr="avg(mask_one_action_steps{})", + ) + .add_metric( + metrics_name="mask_two_or_less_action_steps", + expr="avg(mask_two_or_less_action_steps{})", + ) + .add_metric( + metrics_name="mask_zero_final_steps", + expr="avg(mask_zero_final_steps{})", + ) + .end_panel() + .add_panel( + name="回充动作掩码", + name_en="recharge_mask", + type="line", + ) + .add_metric( + metrics_name="mask_recharge_active", + expr="avg(mask_recharge_active{})", + ) + .add_metric( + metrics_name="mask_recharge_changed", + expr="avg(mask_recharge_changed{})", + ) + .end_panel() .end_group() .build() ) diff --git a/agent_ppo/feature/preprocessor.py b/agent_ppo/feature/preprocessor.py index 28f72d9..5c0d197 100644 --- a/agent_ppo/feature/preprocessor.py +++ b/agent_ppo/feature/preprocessor.py @@ -10,6 +10,7 @@ Feature preprocessor for Robot Vacuum. 清扫大作战特征预处理器。 """ +import os from collections import deque import numpy as np @@ -70,6 +71,7 @@ class Preprocessor: 对局开始时重置所有状态。 """ + self.map_id = -1 self.step_no = 0 self.battery = 600 self.battery_max = 600 @@ -118,6 +120,7 @@ class Preprocessor: self.frontier_action_delta = np.zeros(8, dtype=np.float32) self.charger_action_delta = np.zeros(8, dtype=np.float32) self.charger_route_known = False + self.charger_route_source = "none" # Nearest dirt path distance in the current local view. # 当前局部视野内最近污渍路径距离。 @@ -181,6 +184,36 @@ class Preprocessor: self.local_dirt_ratio = 0.0 self.local_obstacle_ratio = 0.0 + self.reward_profile = os.environ.get("ROBOT_VACUUM_REWARD_PROFILE", "current").strip().lower() or "current" + self._reset_diagnostics() + + def _reset_diagnostics(self): + """Reset episode-local diagnostic counters.""" + self.diag_mask_steps = 0 + self.diag_mask_count_sums = { + "raw": 0, + "blocked": 0, + "npc": 0, + "recharge": 0, + "escape": 0, + "leave": 0, + "final": 0, + } + self.diag_mask_changed_steps = { + "blocked": 0, + "npc": 0, + "recharge": 0, + "escape": 0, + "leave": 0, + } + self.diag_mask_active_steps = { + "recharge": 0, + "leave": 0, + } + self.diag_one_action_steps = 0 + self.diag_two_or_less_action_steps = 0 + self.diag_zero_final_steps = 0 + self.diag_action_hist = [0] * 8 def pb2struct(self, env_obs, last_action): """Parse and cache essential fields from observation dict. @@ -199,6 +232,11 @@ class Preprocessor: hero = _as_dict(hero) self.last_action = int(last_action) + map_id_value = extra_info.get("map_id", env_info.get("map_id", self.map_id)) + try: + self.map_id = int(map_id_value) + except (TypeError, ValueError): + pass self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no))) self.terminated = bool(env_obs.get("terminated", False)) self.truncated = bool(env_obs.get("truncated", False)) @@ -387,6 +425,7 @@ class Preprocessor: self.charger_energy_cost = self.nearest_charger_path_dist self.battery_margin = float(self.battery) - self.nearest_charger_path_dist self.charger_route_known = True + self.charger_route_source = "global" self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist) self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist) @@ -513,6 +552,7 @@ class Preprocessor: self.charger_safety_margin = 0.0 self.charger_rects = [] self.charger_route_known = False + self.charger_route_source = "none" best = None for organ in organs: @@ -550,9 +590,17 @@ class Preprocessor: self.nearest_charger_dist = float(dist) self.nearest_charger_range_dist = float(range_dist) path_dist = self._global_path_dist_to_charger(hx, hz) - self.charger_route_known = path_dist < self.INF_DIST - if not self.charger_route_known: + if path_dist < self.INF_DIST: + self.charger_route_known = True + self.charger_route_source = "global" + else: path_dist = self._local_path_dist_to_charger(hx, hz) + if path_dist < self.INF_DIST: + self.charger_route_known = True + self.charger_route_source = "local" + else: + self.charger_route_known = False + self.charger_route_source = "range" self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist) self.charger_energy_cost = self.nearest_charger_path_dist self.on_charger = range_dist <= 0.0 @@ -711,18 +759,26 @@ class Preprocessor: def _charger_safety_buffer(self): # One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise. - base = max(18.0, 0.12 * float(self.battery_max)) - distance_buffer = min(16.0, 0.18 * float(max(self.nearest_charger_range_dist, 0.0))) - obstacle_buffer = 12.0 * float(self.local_obstacle_ratio) - return float(np.clip(base + distance_buffer + obstacle_buffer, 18.0, 48.0)) + base = max(22.0, 0.14 * float(self.battery_max)) + distance_buffer = min(18.0, 0.20 * float(max(self.nearest_charger_range_dist, 0.0))) + obstacle_buffer = 14.0 * float(self.local_obstacle_ratio) + route_uncertainty_buffer = 10.0 if self.has_charger and not self.charger_route_known else 0.0 + return float(np.clip(base + distance_buffer + obstacle_buffer + route_uncertainty_buffer, 22.0, 58.0)) def _recharge_enter_margin(self): """Adaptive margin for entering recharge mode before the battery is barely enough.""" - base = max(5.0, 0.018 * float(self.battery_max)) - path_margin = min(12.0, 0.08 * float(max(self.nearest_charger_path_dist, 0.0))) - obstacle_margin = 12.0 * float(self.local_obstacle_ratio) + base = max(7.0, 0.025 * float(self.battery_max)) + path_margin = min(14.0, 0.10 * float(max(self.nearest_charger_path_dist, 0.0))) + obstacle_margin = 14.0 * float(self.local_obstacle_ratio) + route_uncertainty_margin = 8.0 if self.has_charger and not self.charger_route_known else 0.0 recovery_margin = min(8.0, 1.5 * float(self.recharge_no_progress_steps + self.fake_charger_steps)) - return float(np.clip(base + path_margin + obstacle_margin + recovery_margin, 4.0, 32.0)) + return float( + np.clip( + base + path_margin + obstacle_margin + route_uncertainty_margin + recovery_margin, + 6.0, + 42.0, + ) + ) def _recharge_leave_margin(self): """Adaptive safety margin required before leaving a charger.""" @@ -734,10 +790,12 @@ class Preprocessor: def _recharge_low_battery_ratio(self): """Adaptive low-battery ratio based on route length and local obstacle density.""" path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0) - ratio = 0.25 + min(0.08, 0.40 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio)) + ratio = 0.32 + min(0.09, 0.42 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio)) + if self.has_charger and not self.charger_route_known: + ratio += 0.04 if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0: ratio += 0.02 - return float(np.clip(ratio, 0.25, 0.40)) + return float(np.clip(ratio, 0.32, 0.46)) def _full_charge_leave_ratio(self): """Adaptive near-full threshold for leaving a charger.""" @@ -774,7 +832,10 @@ class Preprocessor: early_fail_risk = 1.0 - step_ratio path_pressure = float(max(self.charger_energy_cost, 0.0)) / max(float(self.battery_max), 1.0) risk = max(self._recharge_risk_score(), min(1.0, path_pressure)) - return float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0)) + penalty = float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0)) + if self.reward_profile == "battery_safe": + penalty *= 1.25 + return penalty def _min_charger_range_dist(self, x, z): if not self.charger_rects: @@ -1011,16 +1072,167 @@ class Preprocessor: 返回合法动作掩码(8D list)。 """ - legal = self._filter_blocked_actions(self._legal_act) - legal = self._filter_npc_danger_actions(legal) - safe_legal = list(legal) + raw_legal = [int(x) for x in self._legal_act] + blocked_legal = self._filter_blocked_actions(raw_legal) + npc_legal = self._filter_npc_danger_actions(blocked_legal) + safe_legal = list(npc_legal) + recharge_legal = None + escape_legal = None + leave_legal = None + legal = npc_legal if self.recharge_mode: - legal = self._filter_recharge_actions(legal) - legal = self._filter_recharge_escape_actions(legal, safe_legal) + recharge_legal = self._filter_recharge_actions(legal) + escape_legal = self._filter_recharge_escape_actions(recharge_legal, safe_legal) + legal = escape_legal elif self.on_charger and self.battery / max(self.battery_max, 1) >= self.full_charge_leave_ratio: - legal = self._filter_leave_charger_actions(legal) + leave_legal = self._filter_leave_charger_actions(legal) + legal = leave_legal + self._record_mask_diagnostics( + raw_legal=raw_legal, + blocked_legal=blocked_legal, + npc_legal=npc_legal, + recharge_legal=recharge_legal, + escape_legal=escape_legal, + leave_legal=leave_legal, + final_legal=legal, + ) return list(legal) + def record_action(self, action): + """Record the chosen action for episode diagnostics.""" + try: + action = int(action) + except (TypeError, ValueError): + return + if 0 <= action < len(self.diag_action_hist): + self.diag_action_hist[action] += 1 + + def _record_mask_diagnostics( + self, + raw_legal, + blocked_legal, + npc_legal, + recharge_legal, + escape_legal, + leave_legal, + final_legal, + ): + """Record action-mask counts without changing mask behavior.""" + self.diag_mask_steps += 1 + stages = { + "raw": raw_legal, + "blocked": blocked_legal, + "npc": npc_legal, + "recharge": recharge_legal if recharge_legal is not None else npc_legal, + "escape": escape_legal if escape_legal is not None else (recharge_legal if recharge_legal is not None else npc_legal), + "leave": leave_legal if leave_legal is not None else npc_legal, + "final": final_legal, + } + for name, mask in stages.items(): + self.diag_mask_count_sums[name] += self._mask_count(mask) + + if not self._same_mask(raw_legal, blocked_legal): + self.diag_mask_changed_steps["blocked"] += 1 + if not self._same_mask(blocked_legal, npc_legal): + self.diag_mask_changed_steps["npc"] += 1 + if recharge_legal is not None: + self.diag_mask_active_steps["recharge"] += 1 + if not self._same_mask(npc_legal, recharge_legal): + self.diag_mask_changed_steps["recharge"] += 1 + if escape_legal is not None and recharge_legal is not None: + if not self._same_mask(recharge_legal, escape_legal): + self.diag_mask_changed_steps["escape"] += 1 + if leave_legal is not None: + self.diag_mask_active_steps["leave"] += 1 + if not self._same_mask(npc_legal, leave_legal): + self.diag_mask_changed_steps["leave"] += 1 + + final_count = self._mask_count(final_legal) + if final_count <= 0: + self.diag_zero_final_steps += 1 + if final_count == 1: + self.diag_one_action_steps += 1 + if final_count <= 2: + self.diag_two_or_less_action_steps += 1 + + def _mask_count(self, mask): + return int(sum(1 for value in mask if int(value) > 0)) + + def _same_mask(self, left, right): + return [int(x) for x in left] == [int(x) for x in right] + + def get_diagnostic_summary(self): + """Return episode-level diagnostic counters for logging.""" + steps = max(self.diag_mask_steps, 1) + avg_mask_counts = { + name: self.diag_mask_count_sums[name] / steps for name in sorted(self.diag_mask_count_sums) + } + return { + "map_id": self.map_id, + "mask_steps": self.diag_mask_steps, + "avg_mask_counts": avg_mask_counts, + "mask_changed_steps": dict(self.diag_mask_changed_steps), + "mask_active_steps": dict(self.diag_mask_active_steps), + "one_action_steps": self.diag_one_action_steps, + "two_or_less_action_steps": self.diag_two_or_less_action_steps, + "zero_final_steps": self.diag_zero_final_steps, + "action_hist": list(self.diag_action_hist), + "known_ratio": self.known_ratio, + "known_dirty_ratio": self.known_dirty_ratio, + "frontier_ratio": self.frontier_ratio, + "local_dirt_ratio": self.local_dirt_ratio, + "local_obstacle_ratio": self.local_obstacle_ratio, + "global_dirty_path_dist": self.global_dirty_path_dist, + "frontier_path_dist": self.frontier_path_dist, + "charger_route_source": self.charger_route_source, + "reward_profile": self.reward_profile, + } + + def evaluation_action_score(self, action): + """Heuristic score used only to break close evaluation-policy ties.""" + if not (0 <= int(action) < len(self.ACTION_DIRS)): + return -1e6 + action = int(action) + dx, dz = self.ACTION_DIRS[action] + hx, hz = self.cur_pos + nx, nz = hx + dx, hz + dz + if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE): + return -1e6 + + score = 0.0 + cell = self._view_cell(dx, dz, default=1) + if cell == 0: + score -= 8.0 + elif cell == 2: + score += 3.0 + else: + score -= 0.10 + + visit_count = int(self.visit_count_map[nx, nz]) if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE else 0 + score += 0.35 if visit_count == 0 else -0.05 * min(visit_count, 10) + + if self.recharge_mode: + score += 2.2 * float(self.charger_action_delta[action]) + if self._charger_move_distance(nx, nz) < self._charger_move_distance(hx, hz): + score += 0.8 + else: + if self.global_dirty_path_dist < self.GRID_SIZE: + score += 1.8 * float(self.global_dirty_action_delta[action]) + elif self.frontier_path_dist < self.GRID_SIZE: + score += 1.4 * float(self.frontier_action_delta[action]) + + if self.low_battery and self.has_charger: + score += 1.2 * float(self.charger_action_delta[action]) + if self._is_charger_cell(nx, nz): + score += 0.8 if self.low_battery or self.recharge_mode else -0.2 + if self._is_npc_danger_cell(nx, nz, expanded=False): + score -= 6.0 + elif self._is_npc_danger_cell(nx, nz, expanded=True): + score -= 1.5 + if action == self.last_action and self.stuck_steps >= 1: + score -= 1.0 + return float(score) + def _filter_blocked_actions(self, legal_action): """Filter actions that are visibly blocked in the 21x21 view.""" legal = [int(x) for x in legal_action] @@ -1127,14 +1339,21 @@ class Preprocessor: recharge = [0] * 8 best_next_dist = min(item[0] for item in scored) ranked = sorted(scored, key=lambda item: (item[0], -item[1])) - for next_dist, _, _, action in ranked: - if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1: + max_recharge_actions = 4 if self.charger_route_known else 5 + dist_slack = 2.5 if self.charger_route_known else 4.0 + for next_dist, alignment, next_range_dist, action in ranked: + route_progress = next_dist <= current_move_dist + 0.1 + range_progress = next_range_dist <= current_range_dist + direction_progress = alignment > 0 + if next_dist <= best_next_dist + dist_slack and ( + route_progress or (not self.charger_route_known and (range_progress or direction_progress)) + ): recharge[action] = 1 - if sum(recharge) >= 3: + if sum(recharge) >= max_recharge_actions: break if not any(recharge): - for _, _, _, action in ranked[: min(3, len(ranked))]: + for _, _, _, action in ranked[: min(max_recharge_actions, len(ranked))]: recharge[action] = 1 return recharge if any(recharge) else list(legal_action) @@ -1228,10 +1447,15 @@ class Preprocessor: return feature, legal_action, reward def reward_process(self): + cleaning_multiplier, charge_multiplier, exploration_multiplier = self._reward_profile_scales() + # Cleaning reward / 清扫奖励 cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned) cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step - cleaning_scale = 0.2 if self.recharge_mode else 0.7 + battery_ratio = self.battery / max(self.battery_max, 1) + battery_pressure = self.has_charger and battery_ratio < self.recharge_low_battery_ratio + 0.06 + cleaning_scale = 0.2 if self.recharge_mode else (0.55 if battery_pressure else 0.7) + cleaning_scale *= cleaning_multiplier cleaning_reward = cleaning_scale * cleaned_cells # Step penalty / 时间惩罚 @@ -1240,7 +1464,6 @@ class Preprocessor: # Recharge guidance only activates when battery safety is the bottleneck. # 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。 charge_reward = 0.0 - battery_ratio = self.battery / max(self.battery_max, 1) prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1) useful_charge = self.charge_delta > 0 and ( self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.45 @@ -1257,12 +1480,16 @@ class Preprocessor: recharge_risk = self._recharge_risk_score() approach_scale = 0.07 + 0.06 * recharge_risk retreat_scale = 0.035 + 0.045 * recharge_risk + if not self.charger_route_known: + approach_scale += 0.02 + retreat_scale += 0.01 charge_reward += approach_scale * dist_delta if dist_delta > 0 else retreat_scale * dist_delta if self.charger_safety_margin < self.recharge_enter_margin: safety_shortage = self.recharge_enter_margin - self.charger_safety_margin charge_reward -= min(0.55, safety_shortage / max(self.battery_max, 1)) elif self.on_charger and battery_ratio > 0.65: charge_reward -= 0.08 + charge_reward *= charge_multiplier # Encourage covering new passable cells and mildly discourage loops. # 鼓励探索新格子,轻微惩罚反复绕圈。 @@ -1276,6 +1503,7 @@ class Preprocessor: elif self.frontier_path_dist < self.GRID_SIZE: frontier_progress = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0) exploration_reward += 0.005 * frontier_progress + exploration_reward *= exploration_multiplier # Collision/stuck signal: invalid moves waste both step and battery. # 撞墙/原地不动会浪费步数和电量。 @@ -1301,3 +1529,13 @@ class Preprocessor: + npc_penalty + step_penalty ) + + def _reward_profile_scales(self): + """Return multipliers for quick reward-shaping ablations.""" + if self.reward_profile == "lower_recharge": + return 1.0, 0.70, 1.0 + if self.reward_profile == "clean_explore": + return 1.15, 0.85, 1.50 + if self.reward_profile == "battery_safe": + return 0.95, 1.25, 0.90 + return 1.0, 1.0, 1.0 diff --git a/agent_ppo/workflow/train_workflow.py b/agent_ppo/workflow/train_workflow.py index 2bb2795..573b020 100644 --- a/agent_ppo/workflow/train_workflow.py +++ b/agent_ppo/workflow/train_workflow.py @@ -26,6 +26,8 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs): last_save_model_time = time.time() env = envs[0] agent = agents[0] + diag_max_episodes = _read_diag_max_episodes(logger) + diag_log_only = _read_bool_env("ROBOT_VACUUM_DIAG_LOG_ONLY") # Read and validate user configuration # 读取和校验用户配置 @@ -33,6 +35,7 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs): if usr_conf is None: logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml") return + _apply_diag_env_overrides(usr_conf, logger) episode_runner = EpisodeRunner( env=env, @@ -40,9 +43,11 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs): usr_conf=usr_conf, logger=logger, monitor=monitor, + diag_max_episodes=diag_max_episodes, + diag_log_only=diag_log_only, ) - while True: + while not episode_runner.stop_requested: for g_data in episode_runner.run_episodes(): agent.send_sample_data(g_data) g_data.clear() @@ -51,10 +56,56 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs): if now - last_save_model_time >= 1800: agent.save_model() last_save_model_time = now + if episode_runner.stop_requested: + break + + if episode_runner.stop_requested: + logger.info(f"diagnostic max episodes reached: {episode_runner.episode_cnt}") + + +def _read_diag_max_episodes(logger): + raw_value = os.environ.get("ROBOT_VACUUM_DIAG_MAX_EPISODES", "").strip() + if not raw_value: + return 0 + try: + value = int(raw_value) + except ValueError: + if logger: + logger.warning(f"ignore invalid ROBOT_VACUUM_DIAG_MAX_EPISODES={raw_value!r}") + return 0 + return max(value, 0) + + +def _read_positive_int_env(name, logger): + raw_value = os.environ.get(name, "").strip() + if not raw_value: + return 0 + try: + value = int(raw_value) + except ValueError: + if logger: + logger.warning(f"ignore invalid {name}={raw_value!r}") + return 0 + return max(value, 0) + + +def _read_bool_env(name): + return os.environ.get(name, "").strip().lower() in ("1", "true", "yes", "on") + + +def _apply_diag_env_overrides(usr_conf, logger): + diag_max_step = _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_STEP", logger) + if diag_max_step <= 0: + return + env_conf = usr_conf.setdefault("env_conf", {}) + old_max_step = env_conf.get("max_step") + env_conf["max_step"] = diag_max_step + if logger: + logger.info(f"diagnostic max_step override: {old_max_step} -> {diag_max_step}") class EpisodeRunner: - def __init__(self, env, agent, usr_conf, logger, monitor): + def __init__(self, env, agent, usr_conf, logger, monitor, diag_max_episodes=0, diag_log_only=False): self.env = env self.agent = agent self.usr_conf = usr_conf @@ -63,6 +114,9 @@ class EpisodeRunner: self.episode_cnt = 0 self.last_report_monitor_time = 0 self.last_get_training_metrics_time = 0 + self.diag_max_episodes = int(diag_max_episodes) + self.diag_log_only = bool(diag_log_only) + self.stop_requested = False def run_episodes(self): """Run a single episode and yield collected samples. @@ -70,6 +124,8 @@ class EpisodeRunner: 单局流程(generator),完成一局后 yield 整局样本。 """ while True: + if self.stop_requested: + return # Periodically get training metrics # 定期打印训练指标 now = time.time() @@ -188,6 +244,39 @@ class EpisodeRunner: f"result_code:{result_code} " f"result_message:{result_message}" ) + diag = fm.get_diagnostic_summary() + self.logger.info( + f"[DIAG] ep:{self.episode_cnt} map:{diag['map_id']} " + f"steps:{step} result:{result_str} " + f"profile:{diag['reward_profile']} route:{diag['charger_route_source']} " + f"score:{float(total_score):.1f} reward:{total_reward + final_reward:.3f} " + f"mask_avg(raw/block/npc/recharge/escape/leave/final):" + f"{diag['avg_mask_counts']['raw']:.2f}/" + f"{diag['avg_mask_counts']['blocked']:.2f}/" + f"{diag['avg_mask_counts']['npc']:.2f}/" + f"{diag['avg_mask_counts']['recharge']:.2f}/" + f"{diag['avg_mask_counts']['escape']:.2f}/" + f"{diag['avg_mask_counts']['leave']:.2f}/" + f"{diag['avg_mask_counts']['final']:.2f} " + f"mask_changed(block/npc/recharge/escape/leave):" + f"{diag['mask_changed_steps']['blocked']}/" + f"{diag['mask_changed_steps']['npc']}/" + f"{diag['mask_changed_steps']['recharge']}/" + f"{diag['mask_changed_steps']['escape']}/" + f"{diag['mask_changed_steps']['leave']} " + f"mask_active(recharge/leave):" + f"{diag['mask_active_steps']['recharge']}/" + f"{diag['mask_active_steps']['leave']} " + f"tight(one/<=2/zero):" + f"{diag['one_action_steps']}/" + f"{diag['two_or_less_action_steps']}/" + f"{diag['zero_final_steps']} " + f"actions:{diag['action_hist']} " + f"known:{diag['known_ratio']:.3f} dirty_known:{diag['known_dirty_ratio']:.3f} " + f"frontier:{diag['frontier_ratio']:.3f} " + f"path_dirty/frontier:{diag['global_dirty_path_dist']:.1f}/" + f"{diag['frontier_path_dist']:.1f}" + ) # Build sample frame # 构造样本帧 @@ -212,6 +301,9 @@ class EpisodeRunner: # Add terminal reward to last frame # 终局奖励叠加到最后一步 collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32) + if truncated and not terminated: + collector[-1].next_value = self.agent.estimate_value(_obs_data) + collector[-1].done = np.array([0.0], dtype=np.float32) # Monitor reporting / 监控上报 now = time.time() @@ -231,6 +323,13 @@ class EpisodeRunner: "battery_fail": float(fm.battery_fail), "charge_count": float(charge_count), "remaining_charge": float(remaining_charge), + "recharge_steps": float(fm.recharge_steps), + "mask_final_avg": float(diag["avg_mask_counts"]["final"]), + "mask_recharge_active": float(diag["mask_active_steps"]["recharge"]), + "mask_recharge_changed": float(diag["mask_changed_steps"]["recharge"]), + "mask_one_action_steps": float(diag["one_action_steps"]), + "mask_two_or_less_action_steps": float(diag["two_or_less_action_steps"]), + "mask_zero_final_steps": float(diag["zero_final_steps"]), } } ) @@ -239,6 +338,13 @@ class EpisodeRunner: # Compute GAE and yield samples # GAE 计算并 yield 样本 if collector: + if self.diag_max_episodes > 0 and self.episode_cnt >= self.diag_max_episodes: + self.stop_requested = True + if self.diag_log_only: + collector.clear() + if self.stop_requested: + return + break collector = sample_process(collector) yield collector break diff --git a/train_test.py b/train_test.py index 25402db..0fa486e 100644 --- a/train_test.py +++ b/train_test.py @@ -21,9 +21,9 @@ if __name__ == "__main__": algorithm_name=algorithm_name, algorithm_name_list=algorithm_name_list, env_vars={ - "replay_buffer_capacity": "10", - "preload_ratio": "0.2", + "replay_buffer_capacity": "8", + "preload_ratio": "0.1", "train_batch_size": "2", - "dump_model_freq": "1", + "dump_model_freq": "100", }, )