修复PPO评估推理返回None异常
This commit is contained in:
@@ -108,9 +108,17 @@ class Agent(BaseAgent):
|
||||
|
||||
评估时推理(贪心)。
|
||||
"""
|
||||
obs_data, _ = self.observation_process(env_obs)
|
||||
act_data = self.predict([obs_data])[0]
|
||||
return self.action_process(act_data, is_stochastic=False)
|
||||
try:
|
||||
obs_data, _ = self.observation_process(env_obs)
|
||||
act_data = self.predict([obs_data])[0]
|
||||
return self.action_process(act_data, is_stochastic=False)
|
||||
except Exception as err:
|
||||
if self.logger:
|
||||
if hasattr(self.logger, "exception"):
|
||||
self.logger.exception(f"exploit fallback action due to inference error: {err}")
|
||||
else:
|
||||
self.logger.error(f"exploit fallback action due to inference error: {err}")
|
||||
return self._fallback_action(env_obs)
|
||||
|
||||
def learn(self, list_sample_data):
|
||||
"""Delegate to Algorithm for PPO update.
|
||||
@@ -146,6 +154,24 @@ class Agent(BaseAgent):
|
||||
if self.logger:
|
||||
self.logger.info(f"load model {model_file_path} successfully")
|
||||
|
||||
def _fallback_action(self, env_obs):
|
||||
"""Return a valid action instead of None during evaluation failures."""
|
||||
legal_action = [1] * Config.ACTION_NUM
|
||||
if isinstance(env_obs, dict):
|
||||
observation = env_obs.get("observation")
|
||||
if isinstance(observation, dict):
|
||||
raw_legal = observation.get("legal_action") or observation.get("legal_act")
|
||||
if isinstance(raw_legal, (list, tuple)) and raw_legal:
|
||||
legal_action = [int(x) for x in raw_legal[: Config.ACTION_NUM]]
|
||||
if len(legal_action) < Config.ACTION_NUM:
|
||||
legal_action.extend([0] * (Config.ACTION_NUM - len(legal_action)))
|
||||
for action, is_legal in enumerate(legal_action):
|
||||
if is_legal > 0:
|
||||
self.last_action = action
|
||||
return action
|
||||
self.last_action = 0
|
||||
return 0
|
||||
|
||||
def _run_model(self, feature):
|
||||
"""Gradient-free forward pass, returns (logits_np, value_np).
|
||||
|
||||
|
||||
@@ -31,6 +31,11 @@ def _signed_norm(v, v_max):
|
||||
return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
|
||||
|
||||
|
||||
def _as_dict(value):
|
||||
"""Return a dict for optional nested observation fields."""
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
class Preprocessor:
|
||||
"""Feature preprocessor for Robot Vacuum.
|
||||
|
||||
@@ -131,13 +136,16 @@ class Preprocessor:
|
||||
|
||||
从 env_obs 字典中提取并缓存所有需要的状态量。
|
||||
"""
|
||||
observation = env_obs["observation"]
|
||||
frame_state = observation.get("frame_state", {})
|
||||
extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {})
|
||||
env_info = observation.get("env_info", {})
|
||||
hero = frame_state.get("heroes", {})
|
||||
env_obs = _as_dict(env_obs)
|
||||
observation = _as_dict(env_obs.get("observation"))
|
||||
frame_state = _as_dict(observation.get("frame_state"))
|
||||
extra_info = _as_dict(env_obs.get("extra_info"))
|
||||
extra_frame_state = _as_dict(extra_info.get("frame_state"))
|
||||
env_info = _as_dict(observation.get("env_info"))
|
||||
hero = frame_state.get("heroes") or {}
|
||||
if isinstance(hero, list):
|
||||
hero = hero[0] if hero else {}
|
||||
hero = _as_dict(hero)
|
||||
|
||||
self.last_action = int(last_action)
|
||||
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
||||
@@ -149,7 +157,7 @@ class Preprocessor:
|
||||
self.prev_low_battery = self.low_battery
|
||||
self.was_recharge_mode = self.recharge_mode
|
||||
self.prev_pos = self.cur_pos if self.has_position_history else None
|
||||
hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]}
|
||||
hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]})
|
||||
self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
|
||||
self.has_position_history = True
|
||||
|
||||
@@ -198,7 +206,8 @@ class Preprocessor:
|
||||
|
||||
organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
|
||||
npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
|
||||
self.npcs = list(npcs)
|
||||
organs = organs if isinstance(organs, (list, tuple)) else []
|
||||
self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else []
|
||||
self._update_charger_state(hx, hz, organs)
|
||||
self._update_npc_state(hx, hz, self.npcs)
|
||||
self._update_recharge_mode()
|
||||
|
||||
@@ -132,8 +132,12 @@ class EpisodeRunner:
|
||||
final_reward = 0.0
|
||||
if done:
|
||||
fm = self.agent.preprocessor
|
||||
env_info = env_obs["observation"]["env_info"]
|
||||
extra_info = env_obs.get("extra_info", {})
|
||||
observation = env_obs.get("observation") or {}
|
||||
observation = observation if isinstance(observation, dict) else {}
|
||||
env_info = observation.get("env_info") or {}
|
||||
env_info = env_info if isinstance(env_info, dict) else {}
|
||||
extra_info = env_obs.get("extra_info") or {}
|
||||
extra_info = extra_info if isinstance(extra_info, dict) else {}
|
||||
total_score = env_info.get("total_score", fm.total_score)
|
||||
remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
|
||||
charge_count = env_info.get("charge_count", fm.charge_count)
|
||||
|
||||
Reference in New Issue
Block a user