修复PPO评估推理返回None异常

This commit is contained in:
2026-04-26 15:35:19 +08:00
parent ba6cf2a797
commit 3d0a8122bb
3 changed files with 51 additions and 12 deletions

View File

@@ -108,9 +108,17 @@ class Agent(BaseAgent):
评估时推理(贪心)。
"""
obs_data, _ = self.observation_process(env_obs)
act_data = self.predict([obs_data])[0]
return self.action_process(act_data, is_stochastic=False)
try:
obs_data, _ = self.observation_process(env_obs)
act_data = self.predict([obs_data])[0]
return self.action_process(act_data, is_stochastic=False)
except Exception as err:
if self.logger:
if hasattr(self.logger, "exception"):
self.logger.exception(f"exploit fallback action due to inference error: {err}")
else:
self.logger.error(f"exploit fallback action due to inference error: {err}")
return self._fallback_action(env_obs)
def learn(self, list_sample_data):
"""Delegate to Algorithm for PPO update.
@@ -146,6 +154,24 @@ class Agent(BaseAgent):
if self.logger:
self.logger.info(f"load model {model_file_path} successfully")
def _fallback_action(self, env_obs):
"""Return a valid action instead of None during evaluation failures."""
legal_action = [1] * Config.ACTION_NUM
if isinstance(env_obs, dict):
observation = env_obs.get("observation")
if isinstance(observation, dict):
raw_legal = observation.get("legal_action") or observation.get("legal_act")
if isinstance(raw_legal, (list, tuple)) and raw_legal:
legal_action = [int(x) for x in raw_legal[: Config.ACTION_NUM]]
if len(legal_action) < Config.ACTION_NUM:
legal_action.extend([0] * (Config.ACTION_NUM - len(legal_action)))
for action, is_legal in enumerate(legal_action):
if is_legal > 0:
self.last_action = action
return action
self.last_action = 0
return 0
def _run_model(self, feature):
"""Gradient-free forward pass, returns (logits_np, value_np).