Initial robot vacuum code

This commit is contained in:
2026-04-26 12:38:39 +08:00
commit ca6234c941
38 changed files with 1673 additions and 0 deletions

175
agent_ppo/agent.py Normal file
View File

@@ -0,0 +1,175 @@
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors
Robot Vacuum Agent.
清扫大作战 Agent 主类。
"""
import torch
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
import numpy as np
from agent_ppo.algorithm.algorithm import Algorithm
from agent_ppo.conf.conf import Config
from agent_ppo.feature.definition import ActData, ObsData
from agent_ppo.feature.preprocessor import Preprocessor
from agent_ppo.model.model import Model
from kaiwudrl.interface.agent import BaseAgent
class Agent(BaseAgent):
def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
torch.manual_seed(0)
self.device = device
self.model = Model(device).to(self.device)
self.optimizer = torch.optim.Adam(
params=self.model.parameters(),
lr=Config.INIT_LEARNING_RATE_START,
betas=(0.9, 0.999),
eps=1e-8,
)
self.logger = logger
self.monitor = monitor
self.algorithm = Algorithm(self.model, self.optimizer, self.device, self.logger, self.monitor)
self.preprocessor = Preprocessor()
self.last_action = -1
self.last_reward = 0.0
super().__init__(agent_type, device, logger, monitor)
def reset(self, env_obs):
"""Reset per-episode state.
每局开始时重置 Agent 内部状态。
"""
self.preprocessor = Preprocessor()
self.last_action = -1
self.last_reward = 0.0
def observation_process(self, env_obs):
"""Convert raw env_obs to ObsData (69D feature + legal action mask).
将原始 env_obs 转换为 ObsData69D 特征 + 合法动作掩码)。
"""
feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
self.last_reward = reward
obs_data = ObsData(
feature=list(feature),
legal_action=legal_action,
)
remain_info = {}
return obs_data, remain_info
def action_process(self, act_data, is_stochastic=True):
"""Extract int action from ActData and update last_action.
从 ActData 中取出动作整数并更新 last_action。
"""
action = act_data.action if is_stochastic else act_data.d_action
self.last_action = int(action[0])
return self.last_action
def predict(self, list_obs_data):
"""Stochastic inference for training (exploration).
训练时推理(随机采样动作)。
"""
obs_data = list_obs_data[0]
feature = obs_data.feature
legal_action = obs_data.legal_action
logits, value = self._run_model(feature)
legal_arr = np.array(legal_action, dtype=np.float32)
prob = self._legal_soft_max(logits, legal_arr)
action = self._legal_sample(prob, use_max=False)
d_action = self._legal_sample(prob, use_max=True)
return [
ActData(
action=[action],
d_action=[d_action],
prob=list(prob),
value=value,
)
]
def exploit(self, env_obs):
"""Greedy inference for evaluation.
评估时推理(贪心)。
"""
obs_data, _ = self.observation_process(env_obs)
act_data = self.predict([obs_data])[0]
return self.action_process(act_data, is_stochastic=False)
def learn(self, list_sample_data):
"""Delegate to Algorithm for PPO update.
委托给 Algorithm 执行训练。
"""
return self.algorithm.learn(list_sample_data)
def save_model(self, path=None, id="1"):
"""Save model checkpoint.
保存模型检查点。
"""
model_file_path = f"{path}/model.ckpt-{id}.pkl"
state_dict_cpu = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
torch.save(state_dict_cpu, model_file_path)
self.logger.info(f"save model {model_file_path} successfully")
def load_model(self, path=None, id="1"):
"""Load model checkpoint.
加载模型检查点。
"""
model_file_path = f"{path}/model.ckpt-{id}.pkl"
self.model.load_state_dict(torch.load(model_file_path, map_location=self.device))
self.logger.info(f"load model {model_file_path} successfully")
def _run_model(self, feature):
"""Gradient-free forward pass, returns (logits_np, value_np).
无梯度推理,返回 (logits_np, value_np)。
"""
self.model.set_eval_mode()
obs_tensor = (
torch.tensor(np.array([feature], dtype=np.float32)).view(1, Config.DIM_OF_OBSERVATION).to(self.device)
)
with torch.no_grad():
rst = self.model(obs_tensor, inference=True)
logits = rst[0].cpu().numpy()[0]
value = rst[1].cpu().numpy()[0]
return logits, value
def _legal_soft_max(self, logits, legal_action):
"""Softmax with legal action masking.
合法动作掩码下的 softmax。
"""
_w, _e = 1e20, 1e-5
tmp = logits - _w * (1.0 - legal_action)
tmp_max = np.max(tmp, keepdims=True)
tmp = np.clip(tmp - tmp_max, -_w, 1)
tmp = (np.exp(tmp) + _e) * legal_action
return tmp / (np.sum(tmp, keepdims=True) * 1.00001)
def _legal_sample(self, probs, use_max=False):
"""Sample action from probability distribution (argmax if use_max=True).
按概率分布采样动作use_max=True 时取 argmax
"""
if use_max:
return int(np.argmax(probs))
return int(np.argmax(np.random.multinomial(1, probs, size=1)))