Initial robot vacuum code
This commit is contained in:
0
agent_ppo/__init__.py
Normal file
0
agent_ppo/__init__.py
Normal file
175
agent_ppo/agent.py
Normal file
175
agent_ppo/agent.py
Normal file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Robot Vacuum Agent.
|
||||
清扫大作战 Agent 主类。
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from agent_ppo.algorithm.algorithm import Algorithm
|
||||
from agent_ppo.conf.conf import Config
|
||||
from agent_ppo.feature.definition import ActData, ObsData
|
||||
from agent_ppo.feature.preprocessor import Preprocessor
|
||||
from agent_ppo.model.model import Model
|
||||
from kaiwudrl.interface.agent import BaseAgent
|
||||
|
||||
|
||||
class Agent(BaseAgent):
    """PPO agent for the Robot Vacuum task.

    Owns the policy/value ``Model``, its Adam optimizer, the PPO ``Algorithm``
    used for learning and a ``Preprocessor`` that turns raw environment
    observations into a feature vector, a legal-action mask and a reward.
    """

    def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
        """Build model, optimizer, algorithm and preprocessor.

        Args:
            agent_type: Forwarded unchanged to ``BaseAgent.__init__``.
            device: Torch device the model is moved to.
            logger: Optional logger; may be ``None``.
            monitor: Optional monitor handle forwarded to ``Algorithm``.
        """
        # Fixed seed so that weight initialisation is reproducible.
        torch.manual_seed(0)
        self.device = device
        self.model = Model(device).to(self.device)
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=Config.INIT_LEARNING_RATE_START,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        self.logger = logger
        self.monitor = monitor
        self.algorithm = Algorithm(self.model, self.optimizer, self.device, self.logger, self.monitor)
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

        super().__init__(agent_type, device, logger, monitor)

    def reset(self, env_obs):
        """Reset per-episode state.

        A fresh ``Preprocessor`` is created and the cached last action/reward
        are cleared. ``env_obs`` is accepted for interface compatibility and
        is not read here.
        """
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

    def observation_process(self, env_obs):
        """Convert a raw ``env_obs`` into ``ObsData``.

        Also stores the reward computed by the preprocessor in
        ``self.last_reward``.

        Returns:
            ``(obs_data, remain_info)`` where ``obs_data`` carries the feature
            list and the legal-action mask and ``remain_info`` is an empty dict.
        """
        feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
        self.last_reward = reward

        obs_data = ObsData(
            feature=list(feature),
            legal_action=legal_action,
        )
        remain_info = {}
        return obs_data, remain_info

    def action_process(self, act_data, is_stochastic=True):
        """Extract the integer action from ``ActData`` and remember it.

        Args:
            act_data: Result of ``predict``.
            is_stochastic: Use the sampled ``action`` when true, otherwise the
                greedy ``d_action``.

        Returns:
            The chosen action as ``int`` (also stored in ``self.last_action``).
        """
        action = act_data.action if is_stochastic else act_data.d_action
        self.last_action = int(action[0])
        return self.last_action

    def predict(self, list_obs_data):
        """Run inference for training (stochastic exploration).

        Only the first element of ``list_obs_data`` is used.

        Returns:
            A one-element list with an ``ActData`` holding the sampled action,
            the greedy action, the masked action probabilities and the value.
        """
        obs_data = list_obs_data[0]
        feature = obs_data.feature
        legal_action = obs_data.legal_action

        logits, value = self._run_model(feature)

        legal_arr = np.array(legal_action, dtype=np.float32)
        prob = self._legal_soft_max(logits, legal_arr)
        action = self._legal_sample(prob, use_max=False)
        d_action = self._legal_sample(prob, use_max=True)

        return [
            ActData(
                action=[action],
                d_action=[d_action],
                prob=list(prob),
                value=value,
            )
        ]

    def exploit(self, env_obs):
        """Run greedy inference for evaluation and return the action as ``int``."""
        obs_data, _ = self.observation_process(env_obs)
        act_data = self.predict([obs_data])[0]
        return self.action_process(act_data, is_stochastic=False)

    def learn(self, list_sample_data):
        """Delegate one PPO update to ``Algorithm.learn`` and return its result."""
        return self.algorithm.learn(list_sample_data)

    def save_model(self, path=None, id="1"):
        """Save the model ``state_dict`` (copied to CPU) to ``{path}/model.ckpt-{id}.pkl``."""
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        state_dict_cpu = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
        torch.save(state_dict_cpu, model_file_path)
        # logger defaults to None in __init__, so guard before use.
        if self.logger:
            self.logger.info(f"save model {model_file_path} successfully")

    def load_model(self, path=None, id="1"):
        """Load the model ``state_dict`` from ``{path}/model.ckpt-{id}.pkl``."""
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        self.model.load_state_dict(torch.load(model_file_path, map_location=self.device))
        if self.logger:
            self.logger.info(f"load model {model_file_path} successfully")

    def _run_model(self, feature):
        """Gradient-free forward pass for a single observation.

        Returns:
            ``(logits, value)`` as NumPy arrays with the batch axis removed.
        """
        self.model.set_eval_mode()
        obs_tensor = (
            torch.tensor(np.array([feature], dtype=np.float32)).view(1, Config.DIM_OF_OBSERVATION).to(self.device)
        )
        with torch.no_grad():
            rst = self.model(obs_tensor, inference=True)
        logits = rst[0].cpu().numpy()[0]
        value = rst[1].cpu().numpy()[0]
        return logits, value

    def _legal_soft_max(self, logits, legal_action):
        """Softmax over ``logits`` restricted to legal actions.

        Args:
            logits: 1-D array of action logits.
            legal_action: 1-D 0/1 mask of the same length.

        Returns:
            Probabilities that are zero for illegal actions. The result is
            scaled by 1/1.00001 so its sum stays just below 1, which keeps
            ``np.random.multinomial`` from rejecting it due to rounding.
            If no action is legal, a uniform distribution is returned; this
            matches what ``Algorithm._masked_softmax`` produces for an
            all-zero mask and avoids the 0/0 = NaN the formula would give.
        """
        legal_action = np.asarray(legal_action, dtype=np.float32)
        if not np.any(legal_action):
            return np.full(legal_action.shape, 1.0 / max(legal_action.size, 1), dtype=np.float32)

        _w, _e = 1e20, 1e-5
        # Push illegal logits far below every legal one before normalising.
        tmp = logits - _w * (1.0 - legal_action)
        tmp_max = np.max(tmp, keepdims=True)
        tmp = np.clip(tmp - tmp_max, -_w, 1)
        # The small epsilon keeps every legal action at non-zero probability.
        tmp = (np.exp(tmp) + _e) * legal_action
        return tmp / (np.sum(tmp, keepdims=True) * 1.00001)

    def _legal_sample(self, probs, use_max=False):
        """Pick an action index from ``probs``.

        Returns the argmax when ``use_max`` is true, otherwise a single
        multinomial draw.
        """
        if use_max:
            return int(np.argmax(probs))
        return int(np.argmax(np.random.multinomial(1, probs, size=1)))
|
||||
0
agent_ppo/algorithm/__init__.py
Normal file
0
agent_ppo/algorithm/__init__.py
Normal file
161
agent_ppo/algorithm/algorithm.py
Normal file
161
agent_ppo/algorithm/algorithm.py
Normal file
@@ -0,0 +1,161 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Standard PPO algorithm for Robot Vacuum.
|
||||
清扫大作战 PPO 算法。
|
||||
|
||||
Loss composition / 损失组成:
|
||||
total_loss = vf_coef * value_loss + policy_loss - beta * entropy_loss
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from agent_ppo.conf.conf import Config
|
||||
|
||||
|
||||
class Algorithm:
    """Standard PPO update (clipped policy loss, clipped value loss, entropy bonus)."""

    def __init__(self, model, optimizer, device=None, logger=None, monitor=None):
        """Store collaborators and read the PPO hyperparameters from ``Config``.

        Args:
            model: Network called as ``model(obs)`` returning ``[logits, value]``.
            optimizer: Torch optimizer over the model parameters.
            device: Device that sample tensors are moved to.
            logger: Optional logger; may be ``None``.
            monitor: Optional monitor exposing ``put_data``; may be ``None``.
        """
        self.model = model
        self.optimizer = optimizer
        # Flat parameter list, used for gradient-norm clipping.
        self.parameters = [p for pg in optimizer.param_groups for p in pg["params"]]
        self.device = device
        self.logger = logger
        self.monitor = monitor

        self.clip_param = Config.CLIP_PARAM
        self.vf_coef = Config.VF_COEF
        self.var_beta = Config.BETA_START
        self.label_size = Config.ACTION_NUM

        self.train_step = 0
        self.last_report_time = 0

    def learn(self, list_sample_data):
        """Perform one PPO gradient step on a batch of ``SampleData``.

        Every field read from the samples is combined with ``torch.stack``, so
        the fields are expected to already be tensors.

        Returns:
            A dict that always contains ``total_loss``. At most once every
            60 seconds it additionally contains the rounded ``value_loss``,
            ``policy_loss``, ``entropy_loss`` and mean ``reward``; on those
            calls the values are also logged and pushed to the monitor.
        """
        obs = torch.stack([s.obs for s in list_sample_data]).to(self.device)
        legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device)
        act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1)
        old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device)
        old_value = torch.stack([s.value for s in list_sample_data]).to(self.device)
        reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device)
        advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
        reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)

        self.model.set_train_mode()
        self.optimizer.zero_grad()

        rst_list = self.model(obs)
        logits, value_pred = rst_list[0], rst_list[1]

        total_loss, info = self._compute_loss(
            logits=logits,
            value_pred=value_pred,
            legal_action=legal_action,
            old_action=act,
            old_prob=old_prob,
            old_value=old_value,
            reward_sum=reward_sum,
            advantage=advantage,
        )

        total_loss.backward()

        if Config.USE_GRAD_CLIP:
            torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE)

        self.optimizer.step()
        self.train_step += 1

        results = {"total_loss": total_loss.item()}

        # Periodic monitoring report (rate-limited to one per 60 s).
        now = time.time()
        if now - self.last_report_time >= 60:
            results["value_loss"] = round(info["value_loss"], 4)
            results["policy_loss"] = round(info["policy_loss"], 4)
            results["entropy_loss"] = round(info["entropy_loss"], 4)
            results["reward"] = round(reward.mean().item(), 4)

            # logger defaults to None, so guard it the same way as monitor.
            if self.logger:
                self.logger.info(
                    f"policy_loss: {results['policy_loss']}, "
                    f"value_loss: {results['value_loss']}, "
                    f"entropy_loss: {results['entropy_loss']}"
                )
            if self.monitor:
                self.monitor.put_data({os.getpid(): results})

            self.last_report_time = now

        return results

    def _compute_loss(self, logits, value_pred, legal_action, old_action, old_prob, old_value, reward_sum, advantage):
        """Compute the PPO loss.

        ``total = vf_coef * value_loss + policy_loss - beta * entropy``.

        Returns:
            ``(total_loss, info)`` where ``info`` holds the three components
            as Python floats under ``value_loss``, ``policy_loss`` and
            ``entropy_loss``.
        """
        # Value loss (clipped): drop a trailing axis so all operands are 1-D.
        tdret = reward_sum.squeeze(-1) if reward_sum.dim() > 1 else reward_sum
        vp = value_pred.squeeze(-1) if value_pred.dim() > 1 else value_pred
        ov = old_value.squeeze(-1) if old_value.dim() > 1 else old_value

        vp_clip = ov + (vp - ov).clamp(-self.clip_param, self.clip_param)
        value_loss = (
            0.5
            * torch.maximum(
                (tdret - vp) ** 2,
                (tdret - vp_clip) ** 2,
            ).mean()
        )

        # Policy loss (PPO clip) and entropy of the masked distribution.
        prob_dist = self._masked_softmax(logits, legal_action)
        entropy_loss = (-(prob_dist * torch.log(prob_dist.clamp(1e-9, 1))).sum(1)).mean()

        # Select the probability of the action that was actually taken.
        one_hot = torch.nn.functional.one_hot(old_action[:, 0].long(), self.label_size).float()
        new_prob = (one_hot * prob_dist).sum(1, keepdim=True)
        old_action_prob = (one_hot * old_prob).sum(1, keepdim=True)

        # Clamp the denominator to avoid division by zero.
        ratio = new_prob / old_action_prob.clamp(1e-9)

        # Bring the advantage to shape (batch, 1) so it lines up with ratio.
        adv = advantage.squeeze(-1) if advantage.dim() > 1 else advantage
        adv = adv.unsqueeze(-1)

        policy_loss = torch.maximum(
            -ratio * adv,
            -ratio.clamp(1 - self.clip_param, 1 + self.clip_param) * adv,
        ).mean()

        total_loss = self.vf_coef * value_loss + policy_loss - self.var_beta * entropy_loss

        return total_loss, {
            "value_loss": value_loss.item(),
            "policy_loss": policy_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }

    def _masked_softmax(self, logits, legal_action):
        """Row-wise softmax over ``logits`` with illegal actions suppressed.

        Illegal entries are replaced by ``-1e5`` before the softmax, so they
        get (numerically) zero probability. A row whose mask is all zeros
        ends up uniform.
        """
        # Subtract a per-row maximum for numerical stability.
        label_max, _ = torch.max(logits * legal_action, dim=1, keepdim=True)
        logits = logits - label_max
        logits = logits * legal_action
        logits = logits + 1e5 * (legal_action - 1)
        return torch.nn.functional.softmax(logits, dim=1)
|
||||
0
agent_ppo/conf/__init__.py
Normal file
0
agent_ppo/conf/__init__.py
Normal file
49
agent_ppo/conf/conf.py
Normal file
49
agent_ppo/conf/conf.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Configuration for Robot Vacuum PPO agent.
|
||||
清扫大作战 PPO 配置。
|
||||
"""
|
||||
|
||||
|
||||
class Config:
    """Static configuration for the Robot Vacuum PPO agent."""

    # Feature layout (69 values in total):
    # 7x7 local view, 12 global-state values, 8-entry legal-action mask.
    FEATURES = [49, 12, 8]
    FEATURE_SPLIT_SHAPE = FEATURES
    FEATURE_LEN = sum(FEATURES)
    DIM_OF_OBSERVATION = FEATURE_LEN

    # Action space: 8 directional moves.
    ACTION_NUM = 8

    # Single-head value output.
    VALUE_NUM = 1

    # PPO hyperparameters.
    GAMMA = 0.99
    LAMDA = 0.95

    INIT_LEARNING_RATE_START = 0.0003
    BETA_START = 0.001
    CLIP_PARAM = 0.2
    VF_COEF = 0.5

    LABEL_SIZE_LIST = [ACTION_NUM]
    # Independent copy so the two lists can be changed separately.
    LEGAL_ACTION_SIZE_LIST = list(LABEL_SIZE_LIST)

    # Gradient-norm clipping.
    USE_GRAD_CLIP = True
    GRAD_CLIP_RANGE = 0.5
|
||||
83
agent_ppo/conf/monitor_builder.py
Normal file
83
agent_ppo/conf/monitor_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Monitor panel configuration builder for Robot Vacuum.
|
||||
清扫大作战监控面板配置构建器。
|
||||
"""
|
||||
|
||||
|
||||
from kaiwudrl.common.monitor.monitor_config_builder import MonitorConfigBuilder
|
||||
|
||||
|
||||
def build_monitor():
    """Build the monitoring-panel configuration for the custom metrics.

    One group ("algorithm") is created with five line panels. Each panel
    shows a single metric whose expression is ``avg(<metric>{})``.

    Returns:
        Whatever ``MonitorConfigBuilder.build()`` returns.
    """
    # (panel title, metric / English panel name), in display order.
    panels = (
        ("累积回报", "reward"),
        ("总损失", "total_loss"),
        ("价值损失", "value_loss"),
        ("策略损失", "policy_loss"),
        ("熵损失", "entropy_loss"),
    )

    monitor = MonitorConfigBuilder()
    chain = monitor.title("清扫大作战").add_group(
        group_name="算法指标",
        group_name_en="algorithm",
    )
    for panel_title, metric in panels:
        chain = (
            chain.add_panel(
                name=panel_title,
                name_en=metric,
                type="line",
            )
            .add_metric(
                metrics_name=metric,
                expr=f"avg({metric}{{}})",
            )
            .end_panel()
        )

    config_dict = chain.end_group().build()
    return config_dict
|
||||
26
agent_ppo/conf/train_env_conf.toml
Normal file
26
agent_ppo/conf/train_env_conf.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[env_conf]
|
||||
# Maps used for training. Customize by keeping only desired map IDs, e.g. [1, 2] for maps 1 and 2.
|
||||
# 训练使用的地图。可自定义选择期望用来训练的地图,如只期望使用1、2号地图训练数组内仅保留[1,2]即可。
|
||||
map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
|
||||
# Whether to randomly select maps. Boolean.
|
||||
# true = randomly pick one of the configured maps for each episode; false = use the configured maps in order.
|
||||
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张,false表示按顺序抽取地图训练。
|
||||
map_random = false
|
||||
|
||||
# Number of official robots. Range: 1~4 (integer).
|
||||
# Each episode, the configured number of official robots is spawned at random positions on the road.
|
||||
# 官方机器人数量。可配置范围为1~4(整数)。每局将按照配置数量在道路上随机生成官方机器人。
|
||||
robot_count = 4
|
||||
|
||||
# Number of chargers. Range: 1~4 (integer). When less than 4, spawn points are randomly chosen.
|
||||
# 充电桩数量。可配置范围为1~4(整数)。当配置小于4时,将从每张地图可生成充电桩的点位随机选择对应数量的点位生成。
|
||||
charger_count = 4
|
||||
|
||||
# Maximum steps. The task ends when the predicted steps in a single round reach the maximum. Range: 1~2000.
|
||||
# 最大步数。单局任务预测步数达到最大步数时,任务结束。可配置范围为1~2000。
|
||||
max_step = 1000
|
||||
|
||||
# Maximum battery. The battery level when fully charged. Range: 100~999.
|
||||
# 最大电量。满电状态下的电量。可配置范围100~999。
|
||||
battery_max = 200
|
||||
0
agent_ppo/feature/__init__.py
Normal file
0
agent_ppo/feature/__init__.py
Normal file
73
agent_ppo/feature/definition.py
Normal file
73
agent_ppo/feature/definition.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Data definition and GAE computation for Robot Vacuum.
|
||||
清扫大作战数据类定义与 GAE 计算。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from common_python.utils.common_func import create_cls
|
||||
from agent_ppo.conf.conf import Config
|
||||
|
||||
|
||||
# ObsData: feature vector plus legal-action mask.
ObsData = create_cls("ObsData", feature=None, legal_action=None)

# ActData: sampled action, greedy action (d_action), action probabilities
# and state value.
ActData = create_cls("ActData", action=None, d_action=None, prob=None, value=None)

# SampleData: one training sample. Per the original note, integer field
# values are treated as dimensions by the framework.
SampleData = create_cls(
    "SampleData",
    obs=Config.DIM_OF_OBSERVATION,  # 69D feature vector
    legal_action=Config.ACTION_NUM,  # 8D legal-action mask
    act=1,  # index of the executed action
    reward=Config.VALUE_NUM,  # 1D reward
    reward_sum=Config.VALUE_NUM,  # GAE TD(lambda) return
    done=1,
    value=Config.VALUE_NUM,  # 1D value estimate
    next_value=Config.VALUE_NUM,
    advantage=Config.VALUE_NUM,  # 1D GAE advantage
    prob=Config.ACTION_NUM,  # 8D action probabilities
)
|
||||
|
||||
|
||||
def sample_process(list_sample_data):
    """Fill ``next_value`` on each sample and compute GAE in place.

    Every sample except the last gets ``next_value`` set to the ``value`` of
    its successor; the last sample's ``next_value`` is left as it is.
    ``_calc_gae`` then fills ``advantage`` and ``reward_sum``.

    Returns:
        The same list object that was passed in.
    """
    previous = None
    for sample in list_sample_data:
        if previous is not None:
            previous.next_value = sample.value
        previous = sample

    _calc_gae(list_sample_data)
    return list_sample_data
|
||||
|
||||
|
||||
def _calc_gae(list_sample_data, gamma=None, lamda=None):
    """Compute GAE(lambda) advantages and returns in place.

    Walks the samples from last to first and sets, on every sample,
    ``advantage`` (the GAE estimate) and ``reward_sum``
    (``advantage + value``).

    Args:
        list_sample_data: Sequence of samples exposing ``value``, ``reward``
            and ``next_value``.
        gamma: Discount factor; defaults to ``Config.GAMMA``.
        lamda: GAE lambda; defaults to ``Config.LAMDA``.
    """
    if gamma is None:
        gamma = Config.GAMMA
    if lamda is None:
        lamda = Config.LAMDA

    gae = 0.0
    for sample in reversed(list_sample_data):
        # One-step TD error for this transition.
        delta = -sample.value + sample.reward + gamma * sample.next_value
        gae = gae * gamma * lamda + delta
        sample.advantage = gae
        sample.reward_sum = gae + sample.value
|
||||
257
agent_ppo/feature/preprocessor.py
Normal file
257
agent_ppo/feature/preprocessor.py
Normal file
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Feature preprocessor for Robot Vacuum.
|
||||
清扫大作战特征预处理器。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _norm(v, v_max, v_min=0.0):
    """Linearly map ``v`` from ``[v_min, v_max]`` onto ``[0, 1]``.

    ``v`` is clipped to the range first. A degenerate range
    (``v_max == v_min``) yields ``0.0``.
    """
    clipped = float(np.clip(v, v_min, v_max))
    return 0.0 if v_max == v_min else (clipped - v_min) / (v_max - v_min)
|
||||
|
||||
|
||||
class Preprocessor:
    """Feature preprocessor for Robot Vacuum.

    Caches the fields it needs from each environment observation and turns
    them into a 69D feature vector, a legal-action mask and a scalar reward.
    """

    GRID_SIZE = 128
    VIEW_HALF = 10  # Full local view radius (21x21)
    LOCAL_HALF = 3  # Cropped view radius (7x7)

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all internal state at episode start."""
        self.step_no = 0
        self.battery = 600
        self.battery_max = 600

        self.cur_pos = (0, 0)

        self.dirt_cleaned = 0
        self.last_dirt_cleaned = 0
        self.total_dirt = 1

        # Global passable map (0 = obstacle, 1 = passable), accumulated from
        # the local views. It is only written in this class; the ray
        # features below read the local view, not this map.
        self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)

        # Nearest dirt distance for the current and the previous step.
        self.nearest_dirt_dist = 200.0
        self.last_nearest_dirt_dist = 200.0

        self._view_map = np.zeros((21, 21), dtype=np.float32)
        self._legal_act = [1] * 8

    def pb2struct(self, env_obs, last_action):
        """Parse and cache the needed fields from the observation dict.

        ``last_action`` is accepted for interface compatibility and is not
        read here. When ``map_info`` is present, the local view is stored and
        merged into the global passable map.
        """
        observation = env_obs["observation"]
        frame_state = observation["frame_state"]
        env_info = observation["env_info"]
        hero = frame_state["heroes"]

        self.step_no = int(observation["step_no"])
        self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))

        # Battery (battery_max is kept >= 1 so it can be used as a divisor).
        self.battery = int(hero["battery"])
        self.battery_max = max(int(hero["battery_max"]), 1)

        # Cleaning progress; remember the previous count for the reward.
        self.last_dirt_cleaned = self.dirt_cleaned
        self.dirt_cleaned = int(hero["dirt_cleaned"])
        self.total_dirt = max(int(env_info["total_dirt"]), 1)

        # Legal actions; a missing or empty mask means "all legal".
        self._legal_act = [int(x) for x in (observation.get("legal_action") or [1] * 8)]

        # Local view map (21x21).
        map_info = observation.get("map_info")
        if map_info is not None:
            self._view_map = np.array(map_info, dtype=np.float32)
            hx, hz = self.cur_pos
            self._update_passable(hx, hz)

    def _update_passable(self, hx, hz):
        """Write the local view, centred on ``(hx, hz)``, into the global map.

        View cells equal to 0 are recorded as obstacles (0); every other
        value is recorded as passable (1). Cells that fall outside the global
        grid are ignored. The view is assumed to be square.
        """
        view = self._view_map
        vsize = view.shape[0]
        half = vsize // 2

        # Global coordinates of view cell [0, 0].
        x0, z0 = hx - half, hz - half
        # Overlap of the view window with the global grid.
        gx_lo, gx_hi = max(x0, 0), min(x0 + vsize, self.GRID_SIZE)
        gz_lo, gz_hi = max(z0, 0), min(z0 + vsize, self.GRID_SIZE)
        if gx_lo >= gx_hi or gz_lo >= gz_hi:
            return

        window = view[gx_lo - x0 : gx_hi - x0, gz_lo - z0 : gz_hi - z0]
        self.passable_map[gx_lo:gx_hi, gz_lo:gz_hi] = (window != 0).astype(np.int8)

    def _get_local_view_feature(self):
        """Local view feature (49D): centre 7x7 crop of the view, divided by 2."""
        center = self.VIEW_HALF
        h = self.LOCAL_HALF
        crop = self._view_map[center - h : center + h + 1, center - h : center + h + 1]
        return (crop / 2.0).flatten()

    def _ray_to_dirt(self, dx, dz, max_ray):
        """Return the step count to the first dirt cell along ``(dx, dz)``.

        The scan starts next to the current position and reads the local
        view. It stops at the edge of the global grid or of the view; in that
        case, or when no dirt (cell value 2) is met within ``max_ray`` steps,
        ``max_ray`` is returned. Obstacles do not stop the ray.
        """
        hx, hz = self.cur_pos
        view = self._view_map
        rows, cols = view.shape
        for step in range(1, max_ray + 1):
            gx, gz = hx + dx * step, hz + dz * step
            if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
                break
            vi, vj = self.VIEW_HALF + dx * step, self.VIEW_HALF + dz * step
            if not (0 <= vi < rows and 0 <= vj < cols):
                # Nothing beyond the view is visible, so no dirt can be found.
                break
            if int(view[vi, vj]) == 2:
                return step
        return max_ray

    def _get_global_state_feature(self):
        """Global state feature (12D).

        Also updates ``nearest_dirt_dist`` / ``last_nearest_dirt_dist``.

        Dimensions:
            [0]  step_norm          step number / 2000
            [1]  battery_ratio      battery / battery_max
            [2]  cleaning_progress  dirt_cleaned / total_dirt
            [3]  remaining_dirt     1 - cleaning_progress
            [4]  pos_x_norm         x / GRID_SIZE
            [5]  pos_z_norm         z / GRID_SIZE
            [6]  ray_N_dirt         steps to dirt along z-, / 30
            [7]  ray_E_dirt         steps to dirt along x+, / 30
            [8]  ray_S_dirt         steps to dirt along z+, / 30
            [9]  ray_W_dirt         steps to dirt along x-, / 30
            [10] nearest_dirt_norm  nearest dirt Euclidean distance / 180
            [11] dirt_delta         1 if nearer to dirt than last step, else 0
        """
        step_norm = _norm(self.step_no, 2000)
        battery_ratio = _norm(self.battery, self.battery_max)
        cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
        remaining_dirt = 1.0 - cleaning_progress

        hx, hz = self.cur_pos
        pos_x_norm = _norm(hx, self.GRID_SIZE)
        pos_z_norm = _norm(hz, self.GRID_SIZE)

        # 4-directional rays to the nearest dirt.
        ray_dirs = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # N E S W
        max_ray = 30
        ray_dirt = [_norm(self._ray_to_dirt(dx, dz, max_ray), max_ray) for dx, dz in ray_dirs]

        # Nearest dirt Euclidean distance, taken over the whole local view.
        self.last_nearest_dirt_dist = self.nearest_dirt_dist
        self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
        nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)

        dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0

        return np.array(
            [
                step_norm,
                battery_ratio,
                cleaning_progress,
                remaining_dirt,
                pos_x_norm,
                pos_z_norm,
                ray_dirt[0],
                ray_dirt[1],
                ray_dirt[2],
                ray_dirt[3],
                nearest_dirt_norm,
                dirt_delta,
            ],
            dtype=np.float32,
        )

    def _calc_nearest_dirt_dist(self):
        """Return the Euclidean distance from the view centre to the nearest dirt.

        Returns ``200.0`` when the view holds no dirt cell (value 2).
        """
        view = self._view_map
        if view is None:
            return 200.0
        dirt_coords = np.argwhere(view == 2)
        if len(dirt_coords) == 0:
            return 200.0
        center = self.VIEW_HALF
        dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2)
        return float(np.min(dists))

    def get_legal_action(self):
        """Return a copy of the cached legal-action mask (8D list)."""
        return list(self._legal_act)

    def feature_process(self, env_obs, last_action):
        """Generate the 69D feature vector, legal-action mask and reward.

        Returns:
            ``(feature, legal_action, reward)`` where ``feature`` is the
            concatenation of the 49D local view, the 12D global state and
            the 8D legal-action mask.
        """
        self.pb2struct(env_obs, last_action)

        local_view = self._get_local_view_feature()  # 49D
        global_state = self._get_global_state_feature()  # 12D
        legal_action = self.get_legal_action()  # 8D
        legal_arr = np.array(legal_action, dtype=np.float32)

        feature = np.concatenate([local_view, global_state, legal_arr])  # 69D

        reward = self.reward_process()

        return feature, legal_action, reward

    def reward_process(self):
        """Return the step reward: 0.1 per newly cleaned dirt minus 0.001."""
        # Cleaning reward; a decreasing counter never gives a negative reward.
        cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
        cleaning_reward = 0.1 * cleaned_this_step

        # Constant per-step time penalty.
        step_penalty = -0.001

        return cleaning_reward + step_penalty
|
||||
0
agent_ppo/model/__init__.py
Normal file
0
agent_ppo/model/__init__.py
Normal file
73
agent_ppo/model/model.py
Normal file
73
agent_ppo/model/model.py
Normal file
@@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Simple MLP policy network for Robot Vacuum.
|
||||
清扫大作战策略网络。
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from agent_ppo.conf.conf import Config
|
||||
|
||||
|
||||
def _make_fc(in_dim, out_dim, gain=1.41421):
    """Return an ``nn.Linear`` with orthogonal weights and a zero bias.

    Args:
        in_dim: Input feature count.
        out_dim: Output feature count.
        gain: Scaling factor passed to the orthogonal initialiser.
    """
    fc = nn.Linear(in_dim, out_dim)
    nn.init.orthogonal_(fc.weight, gain=gain)
    nn.init.constant_(fc.bias, 0.0)
    return fc
|
||||
|
||||
|
||||
class Model(nn.Module):
    """Dual-head MLP for Robot Vacuum.

    A shared two-layer ReLU backbone feeds an actor head (action logits) and
    a critic head (single state value).
    """

    def __init__(self, device=None):
        super().__init__()
        self.model_name = "robot_vacuum"
        self.device = device

        obs_dim = Config.DIM_OF_OBSERVATION  # 69
        act_num = Config.ACTION_NUM  # 8
        hidden_1, hidden_2 = 128, 64

        # Shared backbone.
        self.backbone = nn.Sequential(
            _make_fc(obs_dim, hidden_1),
            nn.ReLU(),
            _make_fc(hidden_1, hidden_2),
            nn.ReLU(),
        )

        # Actor head: action logits. The small gain keeps initial outputs
        # close to zero.
        self.actor_head = _make_fc(hidden_2, act_num, gain=0.01)

        # Critic head: single state value.
        self.critic_head = _make_fc(hidden_2, 1, gain=0.01)

    def forward(self, s, inference=False):
        """Forward pass.

        Args:
            s: Observation tensor; cast to ``float32`` before use.
            inference: Accepted for interface compatibility; not read here.

        Returns:
            ``[logits, value]``.
        """
        hidden = self.backbone(s.to(torch.float32))
        return [self.actor_head(hidden), self.critic_head(hidden)]

    def set_train_mode(self):
        """Switch the module to training mode."""
        self.train()

    def set_eval_mode(self):
        """Switch the module to evaluation mode."""
        self.eval()
|
||||
0
agent_ppo/workflow/__init__.py
Normal file
0
agent_ppo/workflow/__init__.py
Normal file
201
agent_ppo/workflow/train_workflow.py
Normal file
201
agent_ppo/workflow/train_workflow.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Training workflow for Robot Vacuum.
|
||||
清扫大作战训练工作流。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from agent_ppo.conf.conf import Config
|
||||
from agent_ppo.feature.definition import SampleData, sample_process
|
||||
from tools.metrics_utils import get_training_metrics
|
||||
from tools.train_env_conf_validate import read_usr_conf
|
||||
from common_python.utils.workflow_disaster_recovery import handle_disaster_recovery
|
||||
|
||||
|
||||
def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
    """Drive the training loop: collect episodes, ship samples, save models.

    Only the first environment and the first agent are used. The function
    returns early when the user configuration cannot be read; otherwise it
    loops forever.

    Args:
        envs: Sequence of environments; ``envs[0]`` is used.
        agents: Sequence of agents; ``agents[0]`` is used.
        logger: Logger exposing ``info``/``error``. When omitted, a standard
            library logger is used so the default no longer ends in an
            ``AttributeError`` on the first log call.
        monitor: Optional monitor object forwarded to ``EpisodeRunner``.
        *args: Ignored; accepted for interface compatibility.
        **kwargs: Ignored; accepted for interface compatibility.
    """
    if logger is None:
        # The signature advertises ``logger=None`` but every code path logs,
        # so fall back to a real logger instead of crashing.
        import logging

        logger = logging.getLogger(__name__)

    # Minimum number of seconds between two explicit model saves.
    save_interval_sec = 1800

    last_save_model_time = time.time()
    env = envs[0]
    agent = agents[0]

    # Read and validate user configuration.
    usr_conf = read_usr_conf("agent_ppo/conf/train_env_conf.toml", logger)
    if usr_conf is None:
        logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
        return

    episode_runner = EpisodeRunner(
        env=env,
        agent=agent,
        usr_conf=usr_conf,
        logger=logger,
        monitor=monitor,
    )

    while True:
        for g_data in episode_runner.run_episodes():
            agent.send_sample_data(g_data)
            # Drop the references once the samples have been handed over.
            g_data.clear()

            now = time.time()
            if now - last_save_model_time >= save_interval_sec:
                agent.save_model()
                last_save_model_time = now
|
||||
|
||||
|
||||
class EpisodeRunner:
    """Plays episodes with one agent in one environment and yields samples.

    Attributes:
        env: Environment exposing ``reset(usr_conf)`` and ``step(act)``.
        agent: Agent used for inference and feature processing.
        usr_conf: User configuration passed to ``env.reset``.
        logger: Logger exposing ``info``.
        monitor: Optional monitor exposing ``put_data``; falsy disables it.
        episode_cnt: Number of episodes started so far.
        last_report_monitor_time: Timestamp of the last monitor report.
        last_get_training_metrics_time: Timestamp of the last metrics poll.
    """

    # Minimum seconds between two training-metrics polls.
    METRICS_INTERVAL_SEC = 60
    # Minimum seconds between two monitor reports.
    MONITOR_REPORT_INTERVAL_SEC = 60
    # Terminal bonus for a truncated episode: base + scale * cleaning ratio.
    TRUNCATION_BASE_BONUS = 5.0
    CLEANING_BONUS_SCALE = 5.0
    # Terminal penalty for an episode that terminated without truncation.
    EARLY_TERMINATION_PENALTY = -2.0

    def __init__(self, env, agent, usr_conf, logger, monitor):
        self.env = env
        self.agent = agent
        self.usr_conf = usr_conf
        self.logger = logger
        self.monitor = monitor
        self.episode_cnt = 0
        self.last_report_monitor_time = 0
        self.last_get_training_metrics_time = 0

    @classmethod
    def _terminal_outcome(cls, truncated, dirt_cleaned, total_dirt):
        """Return ``(final_reward, result_str)`` for a finished episode."""
        if truncated:
            # Survived to max steps: higher cleaning ratio -> more reward.
            # ``max(..., 1)`` avoids dividing by zero when no dirt exists.
            cleaning_ratio = dirt_cleaned / max(total_dirt, 1)
            bonus = cls.TRUNCATION_BASE_BONUS + cls.CLEANING_BONUS_SCALE * cleaning_ratio
            return bonus, "WIN"
        # Early termination (battery depleted or collision): small penalty.
        return cls.EARLY_TERMINATION_PENALTY, "FAIL"

    def _log_training_metrics_if_due(self):
        """Poll and log training metrics at most once per interval."""
        now = time.time()
        if now - self.last_get_training_metrics_time >= self.METRICS_INTERVAL_SEC:
            training_metrics = get_training_metrics()
            self.last_get_training_metrics_time = now
            if training_metrics is not None:
                self.logger.info(f"training_metrics: {training_metrics}")

    def _report_monitor_if_due(self, episode_reward):
        """Push the episode reward to the monitor at most once per interval."""
        now = time.time()
        if now - self.last_report_monitor_time >= self.MONITOR_REPORT_INTERVAL_SEC and self.monitor:
            self.monitor.put_data(
                {
                    os.getpid(): {
                        "reward": episode_reward,
                        "episode_cnt": self.episode_cnt,
                    }
                }
            )
            self.last_report_monitor_time = now

    def run_episodes(self):
        """Play episodes forever, yielding each finished episode's samples.

        This generator never returns. For every episode that reaches
        ``terminated`` or ``truncated`` it yields the result of
        ``sample_process`` applied to the collected frames. An episode is
        skipped (nothing is yielded) when ``handle_disaster_recovery``
        reports a problem, either on reset or on any step; frames gathered
        so far for that episode are discarded.

        Yields:
            The processed sample list of one completed episode.
        """
        while True:
            self._log_training_metrics_if_due()

            # Reset environment.
            env_obs = self.env.reset(self.usr_conf)
            if handle_disaster_recovery(env_obs, self.logger):
                continue

            # Reset agent and load the latest model.
            self.agent.reset(env_obs)
            self.agent.load_model(id="latest")

            # Initial observation processing.
            obs_data, _ = self.agent.observation_process(env_obs)

            collector = []
            self.episode_cnt += 1
            done = False
            step = 0
            total_reward = 0.0

            self.logger.info(f"Episode {self.episode_cnt} start")

            while not done:
                # Agent inference.
                act_data = self.agent.predict([obs_data])[0]
                act = self.agent.action_process(act_data)

                # Environment step; the reward returned by the environment
                # is not used, the agent's own ``last_reward`` is.
                _, env_obs = self.env.step(act)
                if handle_disaster_recovery(env_obs, self.logger):
                    break

                terminated = env_obs["terminated"]
                truncated = env_obs["truncated"]
                frame_no = env_obs["frame_no"]
                step += 1
                done = terminated or truncated

                # Process the next observation. This must happen before
                # ``last_reward`` is read below.
                next_obs_data, _ = self.agent.observation_process(env_obs)
                next_obs_data.frame_no = frame_no

                reward_scalar = float(self.agent.last_reward)
                total_reward += reward_scalar

                # Terminal reward calculation.
                final_reward = 0.0
                if done:
                    fm = self.agent.preprocessor
                    final_reward, result_str = self._terminal_outcome(
                        truncated, fm.dirt_cleaned, fm.total_dirt
                    )
                    self.logger.info(
                        f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
                        f"result:{result_str} final_bonus:{final_reward:.2f} "
                        f"total_reward:{total_reward:.3f} "
                        f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt}"
                    )

                # Build the sample frame from the observation acted upon.
                frame = SampleData(
                    obs=np.array(obs_data.feature, dtype=np.float32),
                    legal_action=np.array(obs_data.legal_action, dtype=np.float32),
                    act=np.array(act_data.action),
                    reward=np.array([reward_scalar], dtype=np.float32),
                    done=np.array([float(done)]),
                    reward_sum=np.zeros(Config.VALUE_NUM, dtype=np.float32),
                    value=act_data.value.flatten()[: Config.VALUE_NUM],
                    next_value=np.zeros(Config.VALUE_NUM, dtype=np.float32),
                    advantage=np.zeros(Config.VALUE_NUM, dtype=np.float32),
                    prob=np.array(act_data.prob, dtype=np.float32),
                )
                collector.append(frame)

                if done:
                    # Add the terminal reward to the last frame.
                    collector[-1].reward = collector[-1].reward + np.array(
                        [final_reward], dtype=np.float32
                    )

                    # Monitor reporting.
                    self._report_monitor_if_due(total_reward + final_reward)

                    # ``sample_process`` fills the return/advantage fields;
                    # the collector is never empty here because a frame
                    # was appended just above.
                    yield sample_process(collector)
                    break

                # Advance state.
                obs_data = next_obs_data
|
||||
Reference in New Issue
Block a user