Initial robot vacuum code
This commit is contained in:
73
agent_ppo/feature/definition.py
Normal file
73
agent_ppo/feature/definition.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Data definition and GAE computation for Robot Vacuum.
|
||||
清扫大作战数据类定义与 GAE 计算。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from common_python.utils.common_func import create_cls
|
||||
from agent_ppo.conf.conf import Config
|
||||
|
||||
|
||||
# ObsData: feature vector + legal action mask
|
||||
# 观测数据:feature 为特征向量,legal_action 为合法动作掩码
|
||||
ObsData = create_cls("ObsData", feature=None, legal_action=None)
|
||||
|
||||
# ActData: sampled action, greedy action, action probabilities, state value
|
||||
# 动作数据:action 为采样动作,d_action 为贪心动作,prob 为动作概率,value 为状态价值
|
||||
ActData = create_cls(
|
||||
"ActData",
|
||||
action=None,
|
||||
d_action=None,
|
||||
prob=None,
|
||||
value=None,
|
||||
)
|
||||
|
||||
# SampleData: int values are treated as dimensions by the framework
|
||||
# 训练样本数据:字段值为 int 时框架自动按维度处理
|
||||
SampleData = create_cls(
|
||||
"SampleData",
|
||||
obs=Config.DIM_OF_OBSERVATION, # 69D feature vector / 特征向量
|
||||
legal_action=Config.ACTION_NUM, # 8D legal action mask / 合法动作掩码
|
||||
act=1, # action index / 执行的动作
|
||||
reward=Config.VALUE_NUM, # 1D reward / 奖励
|
||||
reward_sum=Config.VALUE_NUM, # GAE td-lambda return
|
||||
done=1,
|
||||
value=Config.VALUE_NUM, # 1D value estimate / 价值估计
|
||||
next_value=Config.VALUE_NUM,
|
||||
advantage=Config.VALUE_NUM, # 1D GAE advantage / GAE 优势
|
||||
prob=Config.ACTION_NUM, # 8D action probabilities / 动作概率
|
||||
)
|
||||
|
||||
|
||||
def sample_process(list_sample_data):
|
||||
"""Fill next_value and compute GAE advantage.
|
||||
|
||||
计算 GAE 并填充 next_value。
|
||||
"""
|
||||
for i in range(len(list_sample_data) - 1):
|
||||
list_sample_data[i].next_value = list_sample_data[i + 1].value
|
||||
|
||||
_calc_gae(list_sample_data)
|
||||
return list_sample_data
|
||||
|
||||
|
||||
def _calc_gae(list_sample_data):
|
||||
"""Compute advantage and cumulative return using GAE(λ).
|
||||
|
||||
使用 GAE(λ) 计算优势函数与累积回报。
|
||||
"""
|
||||
gae = 0.0
|
||||
gamma = Config.GAMMA
|
||||
lamda = Config.LAMDA
|
||||
for sample in reversed(list_sample_data):
|
||||
delta = -sample.value + sample.reward + gamma * sample.next_value
|
||||
gae = gae * gamma * lamda + delta
|
||||
sample.advantage = gae
|
||||
sample.reward_sum = gae + sample.value
|
||||
Reference in New Issue
Block a user