修复PPO评估推理返回None异常
This commit is contained in:
@@ -108,9 +108,17 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
评估时推理(贪心)。
|
评估时推理(贪心)。
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
obs_data, _ = self.observation_process(env_obs)
|
obs_data, _ = self.observation_process(env_obs)
|
||||||
act_data = self.predict([obs_data])[0]
|
act_data = self.predict([obs_data])[0]
|
||||||
return self.action_process(act_data, is_stochastic=False)
|
return self.action_process(act_data, is_stochastic=False)
|
||||||
|
except Exception as err:
|
||||||
|
if self.logger:
|
||||||
|
if hasattr(self.logger, "exception"):
|
||||||
|
self.logger.exception(f"exploit fallback action due to inference error: {err}")
|
||||||
|
else:
|
||||||
|
self.logger.error(f"exploit fallback action due to inference error: {err}")
|
||||||
|
return self._fallback_action(env_obs)
|
||||||
|
|
||||||
def learn(self, list_sample_data):
|
def learn(self, list_sample_data):
|
||||||
"""Delegate to Algorithm for PPO update.
|
"""Delegate to Algorithm for PPO update.
|
||||||
@@ -146,6 +154,24 @@ class Agent(BaseAgent):
|
|||||||
if self.logger:
|
if self.logger:
|
||||||
self.logger.info(f"load model {model_file_path} successfully")
|
self.logger.info(f"load model {model_file_path} successfully")
|
||||||
|
|
||||||
|
def _fallback_action(self, env_obs):
|
||||||
|
"""Return a valid action instead of None during evaluation failures."""
|
||||||
|
legal_action = [1] * Config.ACTION_NUM
|
||||||
|
if isinstance(env_obs, dict):
|
||||||
|
observation = env_obs.get("observation")
|
||||||
|
if isinstance(observation, dict):
|
||||||
|
raw_legal = observation.get("legal_action") or observation.get("legal_act")
|
||||||
|
if isinstance(raw_legal, (list, tuple)) and raw_legal:
|
||||||
|
legal_action = [int(x) for x in raw_legal[: Config.ACTION_NUM]]
|
||||||
|
if len(legal_action) < Config.ACTION_NUM:
|
||||||
|
legal_action.extend([0] * (Config.ACTION_NUM - len(legal_action)))
|
||||||
|
for action, is_legal in enumerate(legal_action):
|
||||||
|
if is_legal > 0:
|
||||||
|
self.last_action = action
|
||||||
|
return action
|
||||||
|
self.last_action = 0
|
||||||
|
return 0
|
||||||
|
|
||||||
def _run_model(self, feature):
|
def _run_model(self, feature):
|
||||||
"""Gradient-free forward pass, returns (logits_np, value_np).
|
"""Gradient-free forward pass, returns (logits_np, value_np).
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,11 @@ def _signed_norm(v, v_max):
|
|||||||
return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
|
return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
def _as_dict(value):
|
||||||
|
"""Return a dict for optional nested observation fields."""
|
||||||
|
return value if isinstance(value, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
class Preprocessor:
|
class Preprocessor:
|
||||||
"""Feature preprocessor for Robot Vacuum.
|
"""Feature preprocessor for Robot Vacuum.
|
||||||
|
|
||||||
@@ -131,13 +136,16 @@ class Preprocessor:
|
|||||||
|
|
||||||
从 env_obs 字典中提取并缓存所有需要的状态量。
|
从 env_obs 字典中提取并缓存所有需要的状态量。
|
||||||
"""
|
"""
|
||||||
observation = env_obs["observation"]
|
env_obs = _as_dict(env_obs)
|
||||||
frame_state = observation.get("frame_state", {})
|
observation = _as_dict(env_obs.get("observation"))
|
||||||
extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {})
|
frame_state = _as_dict(observation.get("frame_state"))
|
||||||
env_info = observation.get("env_info", {})
|
extra_info = _as_dict(env_obs.get("extra_info"))
|
||||||
hero = frame_state.get("heroes", {})
|
extra_frame_state = _as_dict(extra_info.get("frame_state"))
|
||||||
|
env_info = _as_dict(observation.get("env_info"))
|
||||||
|
hero = frame_state.get("heroes") or {}
|
||||||
if isinstance(hero, list):
|
if isinstance(hero, list):
|
||||||
hero = hero[0] if hero else {}
|
hero = hero[0] if hero else {}
|
||||||
|
hero = _as_dict(hero)
|
||||||
|
|
||||||
self.last_action = int(last_action)
|
self.last_action = int(last_action)
|
||||||
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
||||||
@@ -149,7 +157,7 @@ class Preprocessor:
|
|||||||
self.prev_low_battery = self.low_battery
|
self.prev_low_battery = self.low_battery
|
||||||
self.was_recharge_mode = self.recharge_mode
|
self.was_recharge_mode = self.recharge_mode
|
||||||
self.prev_pos = self.cur_pos if self.has_position_history else None
|
self.prev_pos = self.cur_pos if self.has_position_history else None
|
||||||
hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]}
|
hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]})
|
||||||
self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
|
self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
|
||||||
self.has_position_history = True
|
self.has_position_history = True
|
||||||
|
|
||||||
@@ -198,7 +206,8 @@ class Preprocessor:
|
|||||||
|
|
||||||
organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
|
organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
|
||||||
npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
|
npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
|
||||||
self.npcs = list(npcs)
|
organs = organs if isinstance(organs, (list, tuple)) else []
|
||||||
|
self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else []
|
||||||
self._update_charger_state(hx, hz, organs)
|
self._update_charger_state(hx, hz, organs)
|
||||||
self._update_npc_state(hx, hz, self.npcs)
|
self._update_npc_state(hx, hz, self.npcs)
|
||||||
self._update_recharge_mode()
|
self._update_recharge_mode()
|
||||||
|
|||||||
@@ -132,8 +132,12 @@ class EpisodeRunner:
|
|||||||
final_reward = 0.0
|
final_reward = 0.0
|
||||||
if done:
|
if done:
|
||||||
fm = self.agent.preprocessor
|
fm = self.agent.preprocessor
|
||||||
env_info = env_obs["observation"]["env_info"]
|
observation = env_obs.get("observation") or {}
|
||||||
extra_info = env_obs.get("extra_info", {})
|
observation = observation if isinstance(observation, dict) else {}
|
||||||
|
env_info = observation.get("env_info") or {}
|
||||||
|
env_info = env_info if isinstance(env_info, dict) else {}
|
||||||
|
extra_info = env_obs.get("extra_info") or {}
|
||||||
|
extra_info = extra_info if isinstance(extra_info, dict) else {}
|
||||||
total_score = env_info.get("total_score", fm.total_score)
|
total_score = env_info.get("total_score", fm.total_score)
|
||||||
remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
|
remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
|
||||||
charge_count = env_info.get("charge_count", fm.charge_count)
|
charge_count = env_info.get("charge_count", fm.charge_count)
|
||||||
|
|||||||
Reference in New Issue
Block a user