优化 PPO 清扫策略

This commit is contained in:
2026-04-26 17:29:03 +08:00
parent f04feb0cd9
commit f44e2483fc
6 changed files with 223 additions and 86 deletions

View File

@@ -67,7 +67,19 @@ def _calc_gae(list_sample_data):
gamma = Config.GAMMA
lamda = Config.LAMDA
for sample in reversed(list_sample_data):
delta = -sample.value + sample.reward + gamma * sample.next_value
gae = gae * gamma * lamda + delta
value = _scalar(sample.value)
reward = _scalar(sample.reward)
next_value = _scalar(sample.next_value)
nonterminal = 1.0 - _scalar(sample.done)
delta = reward + gamma * next_value * nonterminal - value
gae = delta + gamma * lamda * nonterminal * gae
sample.advantage = gae
sample.reward_sum = gae + sample.value
sample.reward_sum = gae + value
def _scalar(value):
    """Return the first element of ``value`` as a Python ``float``.

    Objects exposing a ``detach`` attribute (presumably torch tensors --
    confirm against callers) are first converted through
    ``detach().cpu().numpy()``. The input is then coerced to a flat
    float32 numpy array, so the result carries float32 precision.

    Args:
        value: A scalar, sequence, numpy array, or tensor-like object.

    Returns:
        The first element as a ``float``, or ``0.0`` when the input
        contains no elements.
    """
    if hasattr(value, "detach"):
        # Strip autograd tracking and move to host memory before numpy conversion.
        value = value.detach().cpu().numpy()

    flat = np.asarray(value, dtype=np.float32).ravel()
    if flat.size == 0:
        # Nothing to read; fall back to a neutral value.
        return 0.0
    return float(flat[0])