优化 PPO 清扫策略
This commit is contained in:
@@ -67,7 +67,19 @@ def _calc_gae(list_sample_data):
|
||||
gamma = Config.GAMMA
|
||||
lamda = Config.LAMDA
|
||||
for sample in reversed(list_sample_data):
|
||||
delta = -sample.value + sample.reward + gamma * sample.next_value
|
||||
gae = gae * gamma * lamda + delta
|
||||
value = _scalar(sample.value)
|
||||
reward = _scalar(sample.reward)
|
||||
next_value = _scalar(sample.next_value)
|
||||
nonterminal = 1.0 - _scalar(sample.done)
|
||||
delta = reward + gamma * next_value * nonterminal - value
|
||||
gae = delta + gamma * lamda * nonterminal * gae
|
||||
sample.advantage = gae
|
||||
sample.reward_sum = gae + sample.value
|
||||
sample.reward_sum = gae + value
|
||||
|
||||
|
||||
def _scalar(value):
|
||||
"""Read the first scalar from numpy/tensor/list values."""
|
||||
if hasattr(value, "detach"):
|
||||
value = value.detach().cpu().numpy()
|
||||
arr = np.asarray(value, dtype=np.float32).reshape(-1)
|
||||
return float(arr[0]) if arr.size else 0.0
|
||||
|
||||
Reference in New Issue
Block a user