210 lines
7.4 KiB
Python
210 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: UTF-8 -*-
|
||
###########################################################################
|
||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||
###########################################################################
|
||
"""
|
||
Author: Tencent AI Arena Authors
|
||
|
||
Robot Vacuum Agent.
|
||
清扫大作战 Agent 主类。
|
||
"""
|
||
|
||
import torch
|
||
|
||
# Pin PyTorch to a single CPU thread for both intra-op and inter-op
# parallelism. This runs once at import time; PyTorch requires the inter-op
# setting to be applied before any inter-op parallel work starts.
for _configure_threads in (torch.set_num_threads, torch.set_num_interop_threads):
    _configure_threads(1)
del _configure_threads
|
||
|
||
import numpy as np
|
||
|
||
from agent_ppo.algorithm.algorithm import Algorithm
|
||
from agent_ppo.conf.conf import Config
|
||
from agent_ppo.feature.definition import ActData, ObsData
|
||
from agent_ppo.feature.preprocessor import Preprocessor
|
||
from agent_ppo.model.model import Model
|
||
from kaiwudrl.interface.agent import BaseAgent
|
||
|
||
|
||
class Agent(BaseAgent):
    """PPO agent for the robot-vacuum task.

    Wires together the model, Adam optimizer, training ``Algorithm`` and the
    feature ``Preprocessor``. It implements the BaseAgent hooks: observation and
    action processing, training/evaluation inference, learning and
    checkpointing.
    """

    def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
        """Build the model, optimizer, algorithm and preprocessor.

        Args:
            agent_type: Agent role, forwarded to ``BaseAgent``.
            device: Torch device the model is moved to. It is also used as
                ``map_location`` when loading checkpoints.
            logger: Optional logger (may be ``None``).
            monitor: Optional monitor, forwarded to ``Algorithm`` and ``BaseAgent``.
        """
        # Fixed seed so parameter initialisation is reproducible.
        torch.manual_seed(0)
        self.device = device
        self.model = Model(device).to(self.device)
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=Config.INIT_LEARNING_RATE_START,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        self.logger = logger
        self.monitor = monitor
        self.algorithm = Algorithm(self.model, self.optimizer, self.device, self.logger, self.monitor)
        self.preprocessor = Preprocessor()
        # -1 means no action has been taken yet; it is passed to feature_process.
        self.last_action = -1
        self.last_reward = 0.0

        super().__init__(agent_type, device, logger, monitor)

    def reset(self, env_obs):
        """Reset per-episode state at the start of an episode.

        ``env_obs`` is accepted for interface compatibility and is not used.
        """
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

    def observation_process(self, env_obs):
        """Convert a raw ``env_obs`` into ``ObsData`` (feature + legal-action mask).

        The reward computed by the preprocessor is cached in ``self.last_reward``.

        Returns:
            Tuple ``(obs_data, remain_info)``; ``remain_info`` is always ``{}``.
        """
        feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
        self.last_reward = reward

        obs_data = ObsData(
            feature=list(feature),
            legal_action=legal_action,
        )
        remain_info = {}
        return obs_data, remain_info

    def action_process(self, act_data, is_stochastic=True):
        """Extract the integer action from ``ActData`` and record it in ``last_action``.

        Args:
            act_data: ``ActData`` produced by ``predict``.
            is_stochastic: ``True`` selects the sampled ``action``; ``False``
                selects the greedy ``d_action``.
        """
        action = act_data.action if is_stochastic else act_data.d_action
        self.last_action = int(action[0])
        return self.last_action

    def predict(self, list_obs_data):
        """Stochastic inference used during training (exploration).

        Every element of ``list_obs_data`` is run through the model. One
        ``ActData`` is returned per element, in the same order. ``action`` is
        sampled from the masked policy and ``d_action`` is its argmax.
        """
        return [self._predict_one(obs_data) for obs_data in list_obs_data]

    def _predict_one(self, obs_data):
        """Run the model on a single ``ObsData`` and build its ``ActData``."""
        logits, value = self._run_model(obs_data.feature)

        legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
        prob = self._legal_soft_max(logits, legal_arr)
        action = self._legal_sample(prob, use_max=False)
        d_action = self._legal_sample(prob, use_max=True)

        return ActData(
            action=[action],
            d_action=[d_action],
            prob=list(prob),
            value=value,
        )

    def exploit(self, env_obs):
        """Greedy inference used during evaluation.

        Any exception during inference is logged (when a logger is set). A
        legal fallback action is then returned, so evaluation never receives
        ``None``.
        """
        try:
            obs_data, _ = self.observation_process(env_obs)
            act_data = self.predict([obs_data])[0]
            return self.action_process(act_data, is_stochastic=False)
        except Exception as err:  # deliberate best-effort boundary for evaluation
            if self.logger:
                if hasattr(self.logger, "exception"):
                    self.logger.exception(f"exploit fallback action due to inference error: {err}")
                else:
                    self.logger.error(f"exploit fallback action due to inference error: {err}")
            return self._fallback_action(env_obs)

    def learn(self, list_sample_data):
        """Delegate the PPO update to ``Algorithm.learn`` and return its result."""
        return self.algorithm.learn(list_sample_data)

    def save_model(self, path=None, id="1"):
        """Save the model ``state_dict`` to ``{path}/model.ckpt-{id}.pkl``.

        Every tensor is cloned and moved to CPU before saving.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        state_dict_cpu = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
        torch.save(state_dict_cpu, model_file_path)
        # The logger is optional; don't fail after the checkpoint was written.
        if self.logger:
            self.logger.info(f"save model {model_file_path} successfully")

    def load_model(self, path=None, id="1"):
        """Load ``{path}/model.ckpt-{id}.pkl`` into the model.

        Tensors are mapped onto ``self.device``. If ``load_state_dict`` raises
        ``RuntimeError`` (incompatible checkpoint), a warning is logged and the
        model keeps its current weights.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_file_path, map_location=self.device)
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as err:
            msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}"
            if self.logger:
                self.logger.warning(msg)
            return
        if self.logger:
            self.logger.info(f"load model {model_file_path} successfully")

    def _fallback_action(self, env_obs):
        """Return a valid action instead of None during evaluation failures.

        Picks the lowest-index action marked legal in the raw observation's
        mask. Without a usable mask, every action is treated as legal. Returns
        0 when the mask has no legal entry. The chosen action is also stored in
        ``self.last_action``.
        """
        legal_action = self._raw_legal_mask(env_obs)
        for action, is_legal in enumerate(legal_action):
            if is_legal > 0:
                self.last_action = action
                return action
        self.last_action = 0
        return 0

    @staticmethod
    def _raw_legal_mask(env_obs):
        """Extract a length-``Config.ACTION_NUM`` legal mask from raw env_obs (all-legal if absent/invalid).

        Reads ``env_obs["observation"]["legal_action"]``, falling back to
        ``["legal_act"]`` when the first is missing or empty. It never raises,
        because it runs inside the evaluation error handler.
        """
        default = [1] * Config.ACTION_NUM
        if not isinstance(env_obs, dict):
            return default
        observation = env_obs.get("observation")
        if not isinstance(observation, dict):
            return default
        for key in ("legal_action", "legal_act"):
            candidate = observation.get(key)
            # Convert numpy masks to lists first: truth-testing a multi-element
            # array raises ValueError, which would escape this fallback path.
            if isinstance(candidate, np.ndarray):
                candidate = candidate.ravel().tolist()
            if isinstance(candidate, (list, tuple)) and candidate:
                try:
                    mask = [int(x) for x in candidate[: Config.ACTION_NUM]]
                except (TypeError, ValueError):
                    return default
                # Pad short masks with "illegal" entries.
                mask.extend([0] * (Config.ACTION_NUM - len(mask)))
                return mask
        return default

    def _run_model(self, feature):
        """Run a gradient-free forward pass on one feature vector.

        The feature is reshaped to ``(1, Config.DIM_OF_OBSERVATION)``.

        Returns:
            ``(logits, value)`` as numpy arrays for that single sample, with the
            batch dimension removed.
        """
        self.model.set_eval_mode()
        obs_tensor = (
            torch.tensor(np.array([feature], dtype=np.float32)).view(1, Config.DIM_OF_OBSERVATION).to(self.device)
        )
        with torch.no_grad():
            rst = self.model(obs_tensor, inference=True)
        logits = rst[0].cpu().numpy()[0]
        value = rst[1].cpu().numpy()[0]
        return logits, value

    def _legal_soft_max(self, logits, legal_action):
        """Softmax over ``logits`` restricted to actions where ``legal_action`` is 1.

        Illegal entries get probability exactly 0. Each legal entry gets a small
        floor (``_e`` is added before normalising). The result sums to
        1/1.00001, slightly below 1. If no action is legal, the division is
        0/0 and the result is NaN.
        """
        _w, _e = 1e20, 1e-5
        # Push illegal logits far below any legal one.
        tmp = logits - _w * (1.0 - legal_action)
        # Subtract the max for numerical stability before exponentiating.
        tmp_max = np.max(tmp, keepdims=True)
        tmp = np.clip(tmp - tmp_max, -_w, 1)
        tmp = (np.exp(tmp) + _e) * legal_action
        return tmp / (np.sum(tmp, keepdims=True) * 1.00001)

    def _legal_sample(self, probs, use_max=False):
        """Pick an action index from ``probs`` (argmax if ``use_max=True``).

        ``probs`` from ``_legal_soft_max`` deliberately sums to slightly less
        than 1. ``np.random.multinomial`` hands any leftover mass to the *last*
        outcome, which could select a masked-out last action. Instead, the
        probabilities are renormalised in float64 and sampled with
        ``np.random.choice``, so zero-probability (illegal) actions are never
        drawn.
        """
        probs = np.asarray(probs, dtype=np.float64)
        if use_max:
            return int(np.argmax(probs))
        probs = probs / probs.sum()
        return int(np.random.choice(probs.shape[0], p=probs))
|