Improve PPO diagnostics and recharge behavior

This commit is contained in:
2026-04-26 20:24:26 +08:00
parent 5b6133db13
commit 69b8a692db
6 changed files with 463 additions and 31 deletions

View File

@@ -76,6 +76,8 @@ class Agent(BaseAgent):
"""
action = act_data.action if is_stochastic else act_data.d_action
self.last_action = int(action[0])
if hasattr(self.preprocessor, "record_action"):
self.preprocessor.record_action(self.last_action)
return self.last_action
def predict(self, list_obs_data):
@@ -110,7 +112,16 @@ class Agent(BaseAgent):
"""
try:
obs_data, _ = self.observation_process(env_obs)
act_data = self.predict([obs_data])[0]
logits, value = self._run_model(obs_data.feature)
legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
prob = self._legal_soft_max(logits, legal_arr)
action = self._tie_break_eval_action(prob, legal_arr)
act_data = ActData(
action=[action],
d_action=[action],
prob=list(prob),
value=value,
)
return self.action_process(act_data, is_stochastic=False)
except Exception as err:
if self.logger:
@@ -127,6 +138,11 @@ class Agent(BaseAgent):
"""
return self.algorithm.learn(list_sample_data)
def estimate_value(self, obs_data):
    """Return the critic output for an already-processed observation.

    The model is run on ``obs_data.feature``; the first element of the
    returned pair is ignored and the second is used as the value estimate.

    Args:
        obs_data: Processed observation exposing a ``feature`` attribute.

    Returns:
        A flat ``float32`` array holding at most ``Config.VALUE_NUM``
        leading entries of the value output.
    """
    _unused, critic_out = self._run_model(obs_data.feature)
    # Flatten first so the slice below always operates on a 1-D array,
    # whatever shape the model emitted.
    flat_values = np.asarray(critic_out, dtype=np.float32).reshape(-1)
    return flat_values[: Config.VALUE_NUM]
def save_model(self, path=None, id="1"):
"""Save model checkpoint.
@@ -220,3 +236,30 @@ class Agent(BaseAgent):
if use_max:
return int(np.argmax(probs))
return int(np.random.choice(len(probs), p=probs))
def _tie_break_eval_action(self, probs, legal_action):
    """Pick the evaluation action, consulting a heuristic only for near-ties.

    The most probable legal action wins outright unless other legal actions
    lie within ``Config.EVAL_TIE_BREAK_PROB_GAP`` of it. In that case each
    close candidate is ranked by its probability plus
    ``Config.EVAL_TIE_BREAK_SCORE_SCALE`` times the preprocessor's
    ``evaluation_action_score`` (0.0 when the preprocessor has no such
    method).

    Args:
        probs: Per-action probabilities.
        legal_action: Per-action legality flags; entries above 0.5 count
            as legal.

    Returns:
        The chosen action index as a plain ``int``.
    """
    prob_vec = np.asarray(probs, dtype=np.float64)
    mask = np.asarray(legal_action, dtype=np.float32) > 0.5
    if not np.any(mask):
        # Nothing is flagged legal: fall back to treating every action as
        # selectable.
        mask = np.ones(Config.ACTION_NUM, dtype=bool)

    allowed = np.flatnonzero(mask)
    top_action = int(allowed[np.argmax(prob_vec[allowed])])
    top_prob = float(prob_vec[top_action])

    # Collect every legal action whose probability is within the configured
    # gap of the best one (the best action itself always qualifies).
    near_best = []
    for raw_idx in allowed:
        idx = int(raw_idx)
        if top_prob - float(prob_vec[idx]) <= Config.EVAL_TIE_BREAK_PROB_GAP:
            near_best.append(idx)

    if len(near_best) <= 1:
        return top_action

    has_scorer = hasattr(self.preprocessor, "evaluation_action_score")
    ranked = []
    for idx in near_best:
        bonus = 0.0
        if has_scorer:
            bonus = self.preprocessor.evaluation_action_score(idx)
        prob = float(prob_vec[idx])
        blended = prob + Config.EVAL_TIE_BREAK_SCORE_SCALE * bonus
        # Ordering key: blended score, then raw probability, then the
        # lower action index (negated so it wins under a descending sort).
        ranked.append((blended, prob, -idx, idx))

    ranked.sort(reverse=True)
    winner = ranked[0]
    return int(winner[3])