From 3d0a8122bbae71c7d084012e7e49e091f72dac8d Mon Sep 17 00:00:00 2001 From: gqt <3217233537@qq.com> Date: Sun, 26 Apr 2026 15:35:19 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DPPO=E8=AF=84=E4=BC=B0?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E8=BF=94=E5=9B=9ENone=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_ppo/agent.py | 32 +++++++++++++++++++++++++--- agent_ppo/feature/preprocessor.py | 23 ++++++++++++++------ agent_ppo/workflow/train_workflow.py | 8 +++++-- 3 files changed, 51 insertions(+), 12 deletions(-) diff --git a/agent_ppo/agent.py b/agent_ppo/agent.py index 56a2621..0509878 100644 --- a/agent_ppo/agent.py +++ b/agent_ppo/agent.py @@ -108,9 +108,17 @@ class Agent(BaseAgent): 评估时推理(贪心)。 """ - obs_data, _ = self.observation_process(env_obs) - act_data = self.predict([obs_data])[0] - return self.action_process(act_data, is_stochastic=False) + try: + obs_data, _ = self.observation_process(env_obs) + act_data = self.predict([obs_data])[0] + return self.action_process(act_data, is_stochastic=False) + except Exception as err: + if self.logger: + if hasattr(self.logger, "exception"): + self.logger.exception(f"exploit fallback action due to inference error: {err}") + else: + self.logger.error(f"exploit fallback action due to inference error: {err}") + return self._fallback_action(env_obs) def learn(self, list_sample_data): """Delegate to Algorithm for PPO update. @@ -146,6 +154,24 @@ class Agent(BaseAgent): if self.logger: self.logger.info(f"load model {model_file_path} successfully") + def _fallback_action(self, env_obs): + """Return a valid action instead of None during evaluation failures.""" + legal_action = [1] * Config.ACTION_NUM + if isinstance(env_obs, dict): + observation = env_obs.get("observation") + if isinstance(observation, dict): + raw_legal = observation.get("legal_action") or observation.get("legal_act") + if isinstance(raw_legal, (list, tuple)) and raw_legal: + legal_action = [int(x) for x in raw_legal[: Config.ACTION_NUM]] + if len(legal_action) < Config.ACTION_NUM: + legal_action.extend([0] * (Config.ACTION_NUM - len(legal_action))) + for action, is_legal in enumerate(legal_action): + if is_legal > 0: + self.last_action = action + return action + self.last_action = 0 + return 0 + def _run_model(self, feature): """Gradient-free forward pass, returns (logits_np, value_np). diff --git a/agent_ppo/feature/preprocessor.py b/agent_ppo/feature/preprocessor.py index 432574d..5d69fc7 100644 --- a/agent_ppo/feature/preprocessor.py +++ b/agent_ppo/feature/preprocessor.py @@ -31,6 +31,11 @@ def _signed_norm(v, v_max): return float(np.clip(float(v) / float(v_max), -1.0, 1.0)) +def _as_dict(value): + """Return a dict for optional nested observation fields.""" + return value if isinstance(value, dict) else {} + + class Preprocessor: """Feature preprocessor for Robot Vacuum. @@ -131,13 +136,16 @@ class Preprocessor: 从 env_obs 字典中提取并缓存所有需要的状态量。 """ - observation = env_obs["observation"] - frame_state = observation.get("frame_state", {}) - extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {}) - env_info = observation.get("env_info", {}) - hero = frame_state.get("heroes", {}) + env_obs = _as_dict(env_obs) + observation = _as_dict(env_obs.get("observation")) + frame_state = _as_dict(observation.get("frame_state")) + extra_info = _as_dict(env_obs.get("extra_info")) + extra_frame_state = _as_dict(extra_info.get("frame_state")) + env_info = _as_dict(observation.get("env_info")) + hero = frame_state.get("heroes") or {} if isinstance(hero, list): hero = hero[0] if hero else {} + hero = _as_dict(hero) self.last_action = int(last_action) self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no))) @@ -149,7 +157,7 @@ class Preprocessor: self.prev_low_battery = self.low_battery self.was_recharge_mode = self.recharge_mode self.prev_pos = self.cur_pos if self.has_position_history else None - hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]} + hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]}) self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1]))) self.has_position_history = True @@ -198,7 +206,8 @@ class Preprocessor: organs = frame_state.get("organs") or extra_frame_state.get("organs") or [] npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or [] - self.npcs = list(npcs) + organs = organs if isinstance(organs, (list, tuple)) else [] + self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else [] self._update_charger_state(hx, hz, organs) self._update_npc_state(hx, hz, self.npcs) self._update_recharge_mode() diff --git a/agent_ppo/workflow/train_workflow.py b/agent_ppo/workflow/train_workflow.py index dec00fa..f24dc7b 100644 --- a/agent_ppo/workflow/train_workflow.py +++ b/agent_ppo/workflow/train_workflow.py @@ -132,8 +132,12 @@ class EpisodeRunner: final_reward = 0.0 if done: fm = self.agent.preprocessor - env_info = env_obs["observation"]["env_info"] - extra_info = env_obs.get("extra_info", {}) + observation = env_obs.get("observation") or {} + observation = observation if isinstance(observation, dict) else {} + env_info = observation.get("env_info") or {} + env_info = env_info if isinstance(env_info, dict) else {} + extra_info = env_obs.get("extra_info") or {} + extra_info = extra_info if isinstance(extra_info, dict) else {} total_score = env_info.get("total_score", fm.total_score) remaining_charge = env_info.get("remaining_charge", fm.remaining_charge) charge_count = env_info.get("charge_count", fm.charge_count)