优化 PPO 清扫策略
This commit is contained in:
@@ -192,18 +192,31 @@ class Agent(BaseAgent):
|
||||
|
||||
合法动作掩码下的 softmax。
|
||||
"""
|
||||
_w, _e = 1e20, 1e-5
|
||||
tmp = logits - _w * (1.0 - legal_action)
|
||||
tmp_max = np.max(tmp, keepdims=True)
|
||||
tmp = np.clip(tmp - tmp_max, -_w, 1)
|
||||
tmp = (np.exp(tmp) + _e) * legal_action
|
||||
return tmp / (np.sum(tmp, keepdims=True) * 1.00001)
|
||||
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||||
if not np.any(legal):
|
||||
legal = np.ones(Config.ACTION_NUM, dtype=bool)
|
||||
|
||||
masked_logits = np.asarray(logits, dtype=np.float32).copy()
|
||||
masked_logits[~legal] = -1e9
|
||||
shifted = masked_logits - np.max(masked_logits)
|
||||
probs = np.exp(shifted) * legal.astype(np.float32)
|
||||
prob_sum = float(np.sum(probs))
|
||||
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
|
||||
probs = legal.astype(np.float32)
|
||||
prob_sum = float(np.sum(probs))
|
||||
return probs / prob_sum
|
||||
|
||||
def _legal_sample(self, probs, use_max=False):
|
||||
"""Sample action from probability distribution (argmax if use_max=True).
|
||||
|
||||
按概率分布采样动作(use_max=True 时取 argmax)。
|
||||
"""
|
||||
probs = np.asarray(probs, dtype=np.float64)
|
||||
prob_sum = float(np.sum(probs))
|
||||
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
|
||||
probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
|
||||
else:
|
||||
probs = probs / prob_sum
|
||||
if use_max:
|
||||
return int(np.argmax(probs))
|
||||
return int(np.argmax(np.random.multinomial(1, probs, size=1)))
|
||||
return int(np.random.choice(len(probs), p=probs))
|
||||
|
||||
Reference in New Issue
Block a user