This repository has been archived on 2026-05-02. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
-----/agent_ppo/agent.py
2026-04-26 12:38:39 +08:00

176 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors
Robot Vacuum Agent.
清扫大作战 Agent 主类。
"""
import torch
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
import numpy as np
from agent_ppo.algorithm.algorithm import Algorithm
from agent_ppo.conf.conf import Config
from agent_ppo.feature.definition import ActData, ObsData
from agent_ppo.feature.preprocessor import Preprocessor
from agent_ppo.model.model import Model
from kaiwudrl.interface.agent import BaseAgent
class Agent(BaseAgent):
    """PPO agent for the robot vacuum task.

    Bundles the policy/value ``Model``, its Adam optimizer, the ``Algorithm``
    object that performs the PPO update, and a ``Preprocessor`` that turns raw
    environment observations into features. It also remembers the last action
    taken and the last reward reported by the preprocessor.
    """

    def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
        """Create the model, optimizer, algorithm and preprocessor.

        Args:
            agent_type: Passed through unchanged to ``BaseAgent.__init__``.
            device: Device the model is moved to and inference tensors are
                sent to; also used as ``map_location`` when loading.
            logger: Optional logger. When ``None``, checkpoint save/load
                simply skip their log line.
            monitor: Passed through to ``Algorithm`` and ``BaseAgent``.
        """
        # Seed torch's global RNG before the model is constructed.
        torch.manual_seed(0)
        self.device = device
        self.model = Model(device).to(self.device)
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=Config.INIT_LEARNING_RATE_START,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        self.logger = logger
        self.monitor = monitor
        self.algorithm = Algorithm(self.model, self.optimizer, self.device, self.logger, self.monitor)
        self.preprocessor = Preprocessor()
        # -1 marks "no action taken yet"; it is what feature_process receives
        # on the first step of an episode.
        self.last_action = -1
        self.last_reward = 0.0
        super().__init__(agent_type, device, logger, monitor)

    def reset(self, env_obs):
        """Reset the per-episode state at the start of an episode.

        A fresh ``Preprocessor`` is created and the last action/reward are
        returned to their initial values. ``env_obs`` is accepted for
        interface compatibility but is not read here.
        """
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

    def observation_process(self, env_obs):
        """Convert a raw ``env_obs`` into an ``ObsData``.

        The preprocessor is given the raw observation together with the
        previously chosen action and returns the feature vector (described as
        69-dimensional by the original note), the legal-action mask and a
        reward. The reward is stored on ``self.last_reward``.

        Returns:
            A ``(obs_data, remain_info)`` tuple; ``remain_info`` is always an
            empty dict.
        """
        feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
        self.last_reward = reward
        obs_data = ObsData(
            feature=list(feature),
            legal_action=legal_action,
        )
        remain_info = {}
        return obs_data, remain_info

    def action_process(self, act_data, is_stochastic=True):
        """Extract the integer action from ``act_data`` and remember it.

        Args:
            act_data: ``ActData`` produced by ``predict``.
            is_stochastic: If true use the sampled ``action``; otherwise use
                the greedy ``d_action``.

        Returns:
            The chosen action as an ``int`` (also stored in ``last_action``).
        """
        action = act_data.action if is_stochastic else act_data.d_action
        self.last_action = int(action[0])
        return self.last_action

    def predict(self, list_obs_data):
        """Run inference and return both a sampled and a greedy action.

        Only the first element of ``list_obs_data`` is used.

        Returns:
            A one-element list holding an ``ActData`` with the sampled action,
            the argmax action, the masked probability vector and the value
            estimate.
        """
        obs_data = list_obs_data[0]
        feature = obs_data.feature
        legal_action = obs_data.legal_action

        logits, value = self._run_model(feature)
        legal_arr = np.array(legal_action, dtype=np.float32)
        prob = self._legal_soft_max(logits, legal_arr)

        action = self._legal_sample(prob, use_max=False)
        d_action = self._legal_sample(prob, use_max=True)
        return [
            ActData(
                action=[action],
                d_action=[d_action],
                prob=list(prob),
                value=value,
            )
        ]

    def exploit(self, env_obs):
        """Greedy inference for evaluation.

        Processes the observation, calls ``predict`` and returns the greedy
        (argmax) action. Because it goes through ``predict``, a stochastic
        sample is still drawn internally and then ignored.
        """
        obs_data, _ = self.observation_process(env_obs)
        act_data = self.predict([obs_data])[0]
        return self.action_process(act_data, is_stochastic=False)

    def learn(self, list_sample_data):
        """Delegate the training update to ``self.algorithm`` and return its result."""
        return self.algorithm.learn(list_sample_data)

    def save_model(self, path=None, id="1"):
        """Save the model's state dict to ``{path}/model.ckpt-{id}.pkl``.

        Every tensor is cloned and moved to CPU before saving, so the file
        does not depend on the device the model currently lives on.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        state_dict_cpu = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
        torch.save(state_dict_cpu, model_file_path)
        # logger defaults to None in __init__; do not fail after a successful
        # save just because there is nowhere to log.
        if self.logger is not None:
            self.logger.info(f"save model {model_file_path} successfully")

    def load_model(self, path=None, id="1"):
        """Load the model's state dict from ``{path}/model.ckpt-{id}.pkl``.

        Tensors are mapped onto ``self.device`` while loading.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        self.model.load_state_dict(torch.load(model_file_path, map_location=self.device))
        # Same guard as save_model: logging is optional.
        if self.logger is not None:
            self.logger.info(f"load model {model_file_path} successfully")

    def _run_model(self, feature):
        """Run a gradient-free forward pass for a single observation.

        The feature is wrapped into a float32 batch of one row with
        ``Config.DIM_OF_OBSERVATION`` columns.

        Returns:
            ``(logits, value)`` as NumPy arrays, each taken from row 0 of the
            corresponding model output.
        """
        self.model.set_eval_mode()
        obs_tensor = (
            torch.tensor(np.array([feature], dtype=np.float32)).view(1, Config.DIM_OF_OBSERVATION).to(self.device)
        )
        with torch.no_grad():
            rst = self.model(obs_tensor, inference=True)
        logits = rst[0].cpu().numpy()[0]
        value = rst[1].cpu().numpy()[0]
        return logits, value

    def _legal_soft_max(self, logits, legal_action):
        """Softmax over ``logits`` restricted to legal actions.

        Args:
            logits: NumPy array of raw action scores.
            legal_action: NumPy array of the same shape, 1.0 for legal
                actions and 0.0 for illegal ones.

        Returns:
            Probabilities that are exactly 0 for illegal actions. The legal
            entries sum to slightly less than 1 because of the ``1.00001``
            divisor. If no action is legal the sum is 0 and the result is NaN.
        """
        # _w: penalty large enough to push illegal logits far below legal ones.
        # _e: small floor so every legal action keeps a non-zero probability.
        _w, _e = 1e20, 1e-5
        tmp = logits - _w * (1.0 - legal_action)
        # Subtract the maximum for numerical stability before exponentiating.
        tmp_max = np.max(tmp, keepdims=True)
        tmp = np.clip(tmp - tmp_max, -_w, 1)
        # Multiplying by the mask zeroes illegal entries (including their _e).
        tmp = (np.exp(tmp) + _e) * legal_action
        # The extra 1.00001 keeps the total just under 1 -- presumably so that
        # np.random.multinomial's sum check tolerates rounding; confirm.
        return tmp / (np.sum(tmp, keepdims=True) * 1.00001)

    def _legal_sample(self, probs, use_max=False):
        """Pick an action index from ``probs``.

        Args:
            probs: Probability vector over actions.
            use_max: If true return the argmax; otherwise draw one sample
                from the distribution with NumPy's global RNG.

        Returns:
            The selected action index as an ``int``.
        """
        if use_max:
            return int(np.argmax(probs))
        # A single multinomial draw is one-hot; argmax recovers its index.
        return int(np.argmax(np.random.multinomial(1, probs, size=1)))