修复PPO评估推理返回None异常
This commit is contained in:
@@ -108,9 +108,17 @@ class Agent(BaseAgent):
|
||||
|
||||
评估时推理(贪心)。
|
||||
"""
|
||||
obs_data, _ = self.observation_process(env_obs)
|
||||
act_data = self.predict([obs_data])[0]
|
||||
return self.action_process(act_data, is_stochastic=False)
|
||||
try:
|
||||
obs_data, _ = self.observation_process(env_obs)
|
||||
act_data = self.predict([obs_data])[0]
|
||||
return self.action_process(act_data, is_stochastic=False)
|
||||
except Exception as err:
|
||||
if self.logger:
|
||||
if hasattr(self.logger, "exception"):
|
||||
self.logger.exception(f"exploit fallback action due to inference error: {err}")
|
||||
else:
|
||||
self.logger.error(f"exploit fallback action due to inference error: {err}")
|
||||
return self._fallback_action(env_obs)
|
||||
|
||||
def learn(self, list_sample_data):
|
||||
"""Delegate to Algorithm for PPO update.
|
||||
@@ -146,6 +154,24 @@ class Agent(BaseAgent):
|
||||
if self.logger:
|
||||
self.logger.info(f"load model {model_file_path} successfully")
|
||||
|
||||
def _fallback_action(self, env_obs):
    """Return a valid action index instead of None when inference fails.

    The first action flagged legal in the observation's legality mask is
    chosen; when no usable mask can be found, every action is treated as
    legal, so action 0 is returned. The chosen action is also stored in
    ``self.last_action``.

    This helper is meant to be a last-resort safety net, so it must not
    raise on malformed observations: any mask that cannot be parsed is
    ignored in favour of the all-legal default.

    Args:
        env_obs: Raw environment observation. A mask is only looked up when
            this is a dict whose ``"observation"`` entry is itself a dict
            containing ``"legal_action"`` or ``"legal_act"``.

    Returns:
        int: Index of the first legal action, or 0 if none is marked legal.
    """
    action_num = Config.ACTION_NUM
    # Default: assume every action is legal.
    legal_action = [1] * action_num

    observation = env_obs.get("observation") if isinstance(env_obs, dict) else None
    raw_legal = None
    if isinstance(observation, dict):
        for key in ("legal_action", "legal_act"):
            candidate = observation.get(key)
            if candidate is None:
                continue
            # Array-like masks (e.g. NumPy arrays) cannot be truth-tested
            # with ``or``/``if`` without raising; convert them to plain
            # Python lists first.
            # NOTE(review): whether the env actually delivers arrays here is
            # not visible from this file - confirm against the caller.
            if hasattr(candidate, "tolist"):
                candidate = candidate.tolist()
            if isinstance(candidate, (list, tuple)) and candidate:
                raw_legal = candidate
                break

    if raw_legal is not None:
        try:
            parsed = [int(flag) for flag in raw_legal[:action_num]]
        except (TypeError, ValueError):
            # Malformed entries: keep the all-legal default rather than
            # letting the fallback path itself fail.
            parsed = None
        if parsed is not None:
            # Pad a short mask with "illegal" so indices stay in range.
            parsed.extend([0] * (action_num - len(parsed)))
            legal_action = parsed

    chosen = next(
        (index for index, is_legal in enumerate(legal_action) if is_legal > 0),
        0,
    )
    self.last_action = chosen
    return chosen
|
||||
|
||||
def _run_model(self, feature):
|
||||
"""Gradient-free forward pass, returns (logits_np, value_np).
|
||||
|
||||
|
||||
Reference in New Issue
Block a user