Improve PPO diagnostics and recharge behavior
This commit is contained in:
@@ -76,6 +76,8 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
action = act_data.action if is_stochastic else act_data.d_action
|
||||
self.last_action = int(action[0])
|
||||
if hasattr(self.preprocessor, "record_action"):
|
||||
self.preprocessor.record_action(self.last_action)
|
||||
return self.last_action
|
||||
|
||||
def predict(self, list_obs_data):
|
||||
@@ -110,7 +112,16 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
try:
|
||||
obs_data, _ = self.observation_process(env_obs)
|
||||
act_data = self.predict([obs_data])[0]
|
||||
logits, value = self._run_model(obs_data.feature)
|
||||
legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
|
||||
prob = self._legal_soft_max(logits, legal_arr)
|
||||
action = self._tie_break_eval_action(prob, legal_arr)
|
||||
act_data = ActData(
|
||||
action=[action],
|
||||
d_action=[action],
|
||||
prob=list(prob),
|
||||
value=value,
|
||||
)
|
||||
return self.action_process(act_data, is_stochastic=False)
|
||||
except Exception as err:
|
||||
if self.logger:
|
||||
@@ -127,6 +138,11 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
return self.algorithm.learn(list_sample_data)
|
||||
|
||||
def estimate_value(self, obs_data):
|
||||
"""Estimate critic value for a processed observation."""
|
||||
_, value = self._run_model(obs_data.feature)
|
||||
return np.asarray(value, dtype=np.float32).reshape(-1)[: Config.VALUE_NUM]
|
||||
|
||||
def save_model(self, path=None, id="1"):
|
||||
"""Save model checkpoint.
|
||||
|
||||
@@ -220,3 +236,30 @@ class Agent(BaseAgent):
|
||||
if use_max:
|
||||
return int(np.argmax(probs))
|
||||
return int(np.random.choice(len(probs), p=probs))
|
||||
|
||||
def _tie_break_eval_action(self, probs, legal_action):
|
||||
"""Use a light heuristic only when evaluation probabilities are close."""
|
||||
probs = np.asarray(probs, dtype=np.float64)
|
||||
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||||
if not np.any(legal):
|
||||
legal = np.ones(Config.ACTION_NUM, dtype=bool)
|
||||
legal_indices = np.flatnonzero(legal)
|
||||
best_action = int(legal_indices[np.argmax(probs[legal_indices])])
|
||||
best_prob = float(probs[best_action])
|
||||
candidates = [
|
||||
int(action)
|
||||
for action in legal_indices
|
||||
if best_prob - float(probs[int(action)]) <= Config.EVAL_TIE_BREAK_PROB_GAP
|
||||
]
|
||||
if len(candidates) <= 1:
|
||||
return best_action
|
||||
|
||||
scored = []
|
||||
for action in candidates:
|
||||
heuristic = 0.0
|
||||
if hasattr(self.preprocessor, "evaluation_action_score"):
|
||||
heuristic = self.preprocessor.evaluation_action_score(action)
|
||||
combined = float(probs[action]) + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic
|
||||
scored.append((combined, float(probs[action]), -action, action))
|
||||
scored.sort(reverse=True)
|
||||
return int(scored[0][3])
|
||||
|
||||
Reference in New Issue
Block a user