修复PPO评估推理返回None异常

This commit is contained in:
2026-04-26 15:35:19 +08:00
parent ba6cf2a797
commit 3d0a8122bb
3 changed files with 51 additions and 12 deletions

View File

@@ -108,9 +108,17 @@ class Agent(BaseAgent):
评估时推理(贪心)。 评估时推理(贪心)。
""" """
try:
obs_data, _ = self.observation_process(env_obs) obs_data, _ = self.observation_process(env_obs)
act_data = self.predict([obs_data])[0] act_data = self.predict([obs_data])[0]
return self.action_process(act_data, is_stochastic=False) return self.action_process(act_data, is_stochastic=False)
except Exception as err:
if self.logger:
if hasattr(self.logger, "exception"):
self.logger.exception(f"exploit fallback action due to inference error: {err}")
else:
self.logger.error(f"exploit fallback action due to inference error: {err}")
return self._fallback_action(env_obs)
def learn(self, list_sample_data): def learn(self, list_sample_data):
"""Delegate to Algorithm for PPO update. """Delegate to Algorithm for PPO update.
@@ -146,6 +154,24 @@ class Agent(BaseAgent):
if self.logger: if self.logger:
self.logger.info(f"load model {model_file_path} successfully") self.logger.info(f"load model {model_file_path} successfully")
def _fallback_action(self, env_obs):
"""Return a valid action instead of None during evaluation failures."""
legal_action = [1] * Config.ACTION_NUM
if isinstance(env_obs, dict):
observation = env_obs.get("observation")
if isinstance(observation, dict):
raw_legal = observation.get("legal_action") or observation.get("legal_act")
if isinstance(raw_legal, (list, tuple)) and raw_legal:
legal_action = [int(x) for x in raw_legal[: Config.ACTION_NUM]]
if len(legal_action) < Config.ACTION_NUM:
legal_action.extend([0] * (Config.ACTION_NUM - len(legal_action)))
for action, is_legal in enumerate(legal_action):
if is_legal > 0:
self.last_action = action
return action
self.last_action = 0
return 0
def _run_model(self, feature): def _run_model(self, feature):
"""Gradient-free forward pass, returns (logits_np, value_np). """Gradient-free forward pass, returns (logits_np, value_np).

View File

@@ -31,6 +31,11 @@ def _signed_norm(v, v_max):
return float(np.clip(float(v) / float(v_max), -1.0, 1.0)) return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
def _as_dict(value):
"""Return a dict for optional nested observation fields."""
return value if isinstance(value, dict) else {}
class Preprocessor: class Preprocessor:
"""Feature preprocessor for Robot Vacuum. """Feature preprocessor for Robot Vacuum.
@@ -131,13 +136,16 @@ class Preprocessor:
从 env_obs 字典中提取并缓存所有需要的状态量。 从 env_obs 字典中提取并缓存所有需要的状态量。
""" """
observation = env_obs["observation"] env_obs = _as_dict(env_obs)
frame_state = observation.get("frame_state", {}) observation = _as_dict(env_obs.get("observation"))
extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {}) frame_state = _as_dict(observation.get("frame_state"))
env_info = observation.get("env_info", {}) extra_info = _as_dict(env_obs.get("extra_info"))
hero = frame_state.get("heroes", {}) extra_frame_state = _as_dict(extra_info.get("frame_state"))
env_info = _as_dict(observation.get("env_info"))
hero = frame_state.get("heroes") or {}
if isinstance(hero, list): if isinstance(hero, list):
hero = hero[0] if hero else {} hero = hero[0] if hero else {}
hero = _as_dict(hero)
self.last_action = int(last_action) self.last_action = int(last_action)
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no))) self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
@@ -149,7 +157,7 @@ class Preprocessor:
self.prev_low_battery = self.low_battery self.prev_low_battery = self.low_battery
self.was_recharge_mode = self.recharge_mode self.was_recharge_mode = self.recharge_mode
self.prev_pos = self.cur_pos if self.has_position_history else None self.prev_pos = self.cur_pos if self.has_position_history else None
hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]} hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]})
self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1]))) self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
self.has_position_history = True self.has_position_history = True
@@ -198,7 +206,8 @@ class Preprocessor:
organs = frame_state.get("organs") or extra_frame_state.get("organs") or [] organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or [] npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
self.npcs = list(npcs) organs = organs if isinstance(organs, (list, tuple)) else []
self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else []
self._update_charger_state(hx, hz, organs) self._update_charger_state(hx, hz, organs)
self._update_npc_state(hx, hz, self.npcs) self._update_npc_state(hx, hz, self.npcs)
self._update_recharge_mode() self._update_recharge_mode()

View File

@@ -132,8 +132,12 @@ class EpisodeRunner:
final_reward = 0.0 final_reward = 0.0
if done: if done:
fm = self.agent.preprocessor fm = self.agent.preprocessor
env_info = env_obs["observation"]["env_info"] observation = env_obs.get("observation") or {}
extra_info = env_obs.get("extra_info", {}) observation = observation if isinstance(observation, dict) else {}
env_info = observation.get("env_info") or {}
env_info = env_info if isinstance(env_info, dict) else {}
extra_info = env_obs.get("extra_info") or {}
extra_info = extra_info if isinstance(extra_info, dict) else {}
total_score = env_info.get("total_score", fm.total_score) total_score = env_info.get("total_score", fm.total_score)
remaining_charge = env_info.get("remaining_charge", fm.remaining_charge) remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
charge_count = env_info.get("charge_count", fm.charge_count) charge_count = env_info.get("charge_count", fm.charge_count)