86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: UTF-8 -*-
|
||
###########################################################################
|
||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||
###########################################################################
|
||
"""
Author: Tencent AI Arena Authors

Data definition and GAE computation for Robot Vacuum.
清扫大作战数据类定义与 GAE 计算。
"""
|
||
|
||
import numpy as np
|
||
from common_python.utils.common_func import create_cls
|
||
from agent_ppo.conf.conf import Config
|
||
|
||
|
||
# ObsData: observation container with two fields, both declared as None:
#   feature      - feature vector
#   legal_action - legal-action mask
ObsData = create_cls("ObsData", feature=None, legal_action=None)
|
||
|
||
# ActData: per-step policy output, every field declared as None:
#   action   - sampled action
#   d_action - greedy action
#   prob     - action probabilities
#   value    - state value
ActData = create_cls(
    "ActData",
    action=None,
    d_action=None,
    prob=None,
    value=None,
)
|
||
|
||
# SampleData: one training sample. Field values given as ints are treated as
# dimensions by the framework, so each entry below declares the width of that
# field rather than a default value.
SampleData = create_cls(
    "SampleData",
    obs=Config.DIM_OF_OBSERVATION,  # feature vector
    legal_action=Config.ACTION_NUM,  # legal-action mask (8-D)
    act=1,  # index of the executed action
    reward=Config.VALUE_NUM,  # reward (1-D)
    reward_sum=Config.VALUE_NUM,  # GAE td-lambda return, written by _calc_gae
    done=1,  # read by _calc_gae as the terminal flag (1 - done masks bootstrap)
    value=Config.VALUE_NUM,  # value estimate (1-D)
    next_value=Config.VALUE_NUM,  # value of the following sample, see sample_process
    advantage=Config.VALUE_NUM,  # GAE advantage (1-D), written by _calc_gae
    prob=Config.ACTION_NUM,  # action probabilities (8-D)
)
|
||
|
||
|
||
def sample_process(list_sample_data):
    """Fill ``next_value`` on each sample and compute GAE targets in place.

    Every sample except the last receives the ``value`` of the sample that
    follows it as its ``next_value``. The last sample's ``next_value`` is not
    touched here and keeps whatever the caller stored in it. ``_calc_gae`` is
    then run over the whole sequence to write ``advantage`` and ``reward_sum``.

    Args:
        list_sample_data: Ordered sequence of samples exposing ``value``,
            ``next_value``, ``reward`` and ``done`` attributes. An empty
            sequence is accepted and returned unchanged.

    Returns:
        The same ``list_sample_data`` object that was passed in; its samples
        are mutated in place.
    """
    # Walk consecutive pairs without index arithmetic: `previous` trails
    # `current` by exactly one element.
    samples = iter(list_sample_data)
    previous = next(samples, None)
    for current in samples:
        previous.next_value = current.value
        previous = current

    _calc_gae(list_sample_data)
    return list_sample_data
|
||
|
||
|
||
def _calc_gae(list_sample_data):
    """Compute the GAE(lambda) advantage and return for every sample in place.

    The sequence is traversed from the last sample to the first using

        delta_t   = reward_t + gamma * next_value_t * (1 - done_t) - value_t
        gae_t     = delta_t + gamma * lambda * (1 - done_t) * gae_{t+1}
        return_t  = gae_t + value_t

    with ``gae`` starting at 0.0 beyond the final sample. ``gamma`` and
    ``lambda`` are read from ``Config.GAMMA`` and ``Config.LAMDA``.

    Args:
        list_sample_data: Reversible sequence of samples exposing ``value``,
            ``reward``, ``next_value`` and ``done``. Each of these is reduced
            to its first scalar element via ``_scalar``.

    Returns:
        None. ``advantage`` and ``reward_sum`` are overwritten on each sample
        with plain Python floats.
    """
    gamma = Config.GAMMA
    # Loop-invariant discount applied to the running advantage.
    decay = gamma * Config.LAMDA

    running_gae = 0.0
    for sample in reversed(list_sample_data):
        value = _scalar(sample.value)
        reward = _scalar(sample.reward)
        next_value = _scalar(sample.next_value)
        # 0.0 on terminal steps: cuts both the bootstrap term and the
        # advantage carried over from later samples.
        nonterminal = 1.0 - _scalar(sample.done)

        delta = reward + gamma * next_value * nonterminal - value
        running_gae = delta + decay * nonterminal * running_gae

        sample.advantage = running_gae
        sample.reward_sum = running_gae + value
|
||
|
||
|
||
def _scalar(value):
|
||
"""Read the first scalar from numpy/tensor/list values."""
|
||
if hasattr(value, "detach"):
|
||
value = value.detach().cpu().numpy()
|
||
arr = np.asarray(value, dtype=np.float32).reshape(-1)
|
||
return float(arr[0]) if arr.size else 0.0
|