优化 PPO 清扫策略

This commit is contained in:
2026-04-26 17:29:03 +08:00
parent f04feb0cd9
commit f44e2483fc
6 changed files with 223 additions and 86 deletions

View File

@@ -192,18 +192,31 @@ class Agent(BaseAgent):
合法动作掩码下的 softmax。
"""
_w, _e = 1e20, 1e-5
tmp = logits - _w * (1.0 - legal_action)
tmp_max = np.max(tmp, keepdims=True)
tmp = np.clip(tmp - tmp_max, -_w, 1)
tmp = (np.exp(tmp) + _e) * legal_action
return tmp / (np.sum(tmp, keepdims=True) * 1.00001)
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
if not np.any(legal):
legal = np.ones(Config.ACTION_NUM, dtype=bool)
masked_logits = np.asarray(logits, dtype=np.float32).copy()
masked_logits[~legal] = -1e9
shifted = masked_logits - np.max(masked_logits)
probs = np.exp(shifted) * legal.astype(np.float32)
prob_sum = float(np.sum(probs))
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
probs = legal.astype(np.float32)
prob_sum = float(np.sum(probs))
return probs / prob_sum
def _legal_sample(self, probs, use_max=False):
"""Sample action from probability distribution (argmax if use_max=True).
按概率分布采样动作use_max=True 时取 argmax
"""
probs = np.asarray(probs, dtype=np.float64)
prob_sum = float(np.sum(probs))
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
else:
probs = probs / prob_sum
if use_max:
return int(np.argmax(probs))
return int(np.argmax(np.random.multinomial(1, probs, size=1)))
return int(np.random.choice(len(probs), p=probs))