266 lines
9.9 KiB
Python
266 lines
9.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: UTF-8 -*-
|
||
###########################################################################
|
||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||
###########################################################################
|
||
"""
|
||
Author: Tencent AI Arena Authors
|
||
|
||
Robot Vacuum Agent.
|
||
清扫大作战 Agent 主类。
|
||
"""
|
||
|
||
import torch
|
||
|
||
# Restrict PyTorch to a single intra-op thread and a single inter-op thread
# for this process. Both calls are issued right after `import torch`, before
# the remaining imports; the PyTorch documentation states that
# set_num_interop_threads() may only be called before any inter-op parallel
# work has started.
# NOTE(review): presumably this avoids CPU oversubscription when many agent
# processes share one host - confirm against the deployment setup.
torch.set_num_threads(1)
|
||
torch.set_num_interop_threads(1)
|
||
|
||
import numpy as np
|
||
|
||
from agent_ppo.algorithm.algorithm import Algorithm
|
||
from agent_ppo.conf.conf import Config
|
||
from agent_ppo.feature.definition import ActData, ObsData
|
||
from agent_ppo.feature.preprocessor import Preprocessor
|
||
from agent_ppo.model.model import Model
|
||
from kaiwudrl.interface.agent import BaseAgent
|
||
|
||
|
||
class Agent(BaseAgent):
    """PPO agent for the Robot Vacuum task.

    The agent owns the ``Model``, its Adam optimizer, the ``Algorithm`` that
    performs the training update and a ``Preprocessor`` that converts raw
    environment observations into a feature vector plus a legal-action mask.

    Attributes:
        device: Device passed to the model and used for inference tensors.
        model: The network, moved to ``device``.
        optimizer: Adam optimizer over the model parameters.
        logger: Optional logger; every use is guarded because it may be None.
        monitor: Passed through to ``Algorithm`` and ``BaseAgent``.
        algorithm: Object whose ``learn`` method performs the update.
        preprocessor: Per-episode feature extractor (recreated by ``reset``).
        last_action: Most recent action, ``-1`` before the first action.
        last_reward: Reward returned by the preprocessor for the most recent
            observation.
    """

    def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
        """Create the model, optimizer, algorithm and preprocessor.

        Args:
            agent_type: Forwarded unchanged to ``BaseAgent.__init__``.
            device: Device handed to ``Model`` and used by ``_run_model``.
            logger: Optional logger object; may be ``None``.
            monitor: Forwarded to ``Algorithm`` and ``BaseAgent.__init__``.
        """
        # Seed torch's global RNG before the model is constructed.
        torch.manual_seed(0)
        self.device = device
        self.model = Model(device).to(self.device)
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=Config.INIT_LEARNING_RATE_START,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        self.logger = logger
        self.monitor = monitor
        self.algorithm = Algorithm(self.model, self.optimizer, self.device, self.logger, self.monitor)
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

        # The base class is initialised last, after all attributes exist.
        super().__init__(agent_type, device, logger, monitor)

    def reset(self, env_obs):
        """Reset the per-episode state at the start of an episode.

        A fresh ``Preprocessor`` is created and the last action / reward are
        cleared.

        Args:
            env_obs: Initial environment observation; not used by this method.
        """
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

    def observation_process(self, env_obs):
        """Convert a raw observation into ``ObsData``.

        Args:
            env_obs: Raw environment observation passed to the preprocessor
                together with ``self.last_action``.

        Returns:
            A ``(obs_data, remain_info)`` pair where ``obs_data`` carries the
            feature list and the legal-action mask and ``remain_info`` is an
            empty dict.

        Side effects:
            Stores the reward returned by the preprocessor in
            ``self.last_reward``.
        """
        feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
        self.last_reward = reward

        obs_data = ObsData(
            feature=list(feature),
            legal_action=legal_action,
        )
        remain_info = {}
        return obs_data, remain_info

    def action_process(self, act_data, is_stochastic=True):
        """Extract the integer action from ``act_data`` and remember it.

        Args:
            act_data: Object with ``action`` and ``d_action`` sequences.
            is_stochastic: When true the first entry of ``act_data.action`` is
                used, otherwise the first entry of ``act_data.d_action``.

        Returns:
            The chosen action as ``int``; it is also stored in
            ``self.last_action``.
        """
        action = act_data.action if is_stochastic else act_data.d_action
        self.last_action = int(action[0])
        # The preprocessor hook is optional, so only call it when present.
        if hasattr(self.preprocessor, "record_action"):
            self.preprocessor.record_action(self.last_action)
        return self.last_action

    def predict(self, list_obs_data):
        """Run stochastic inference (used during training).

        Only the first element of ``list_obs_data`` is evaluated.

        Args:
            list_obs_data: Sequence whose first item exposes ``feature`` and
                ``legal_action``.

        Returns:
            A one-element list holding an ``ActData`` with the sampled action,
            the greedy action (``d_action``), the masked probabilities and the
            value output of the model.
        """
        obs_data = list_obs_data[0]
        feature = obs_data.feature
        legal_action = obs_data.legal_action

        logits, value = self._run_model(feature)

        legal_arr = np.array(legal_action, dtype=np.float32)
        prob = self._legal_soft_max(logits, legal_arr)
        action = self._legal_sample(prob, use_max=False)
        d_action = self._legal_sample(prob, use_max=True)

        return [
            ActData(
                action=[action],
                d_action=[d_action],
                prob=list(prob),
                value=value,
            )
        ]

    def exploit(self, env_obs):
        """Run greedy inference (used during evaluation).

        Args:
            env_obs: Raw environment observation.

        Returns:
            The selected action as ``int``. If anything in the inference
            pipeline raises, the error is logged (when a logger is set) and
            the result of ``_fallback_action`` is returned instead.
        """
        # The broad ``except`` is deliberate: this is the evaluation boundary
        # and a valid action must be returned even when inference fails.
        try:
            obs_data, _ = self.observation_process(env_obs)
            logits, value = self._run_model(obs_data.feature)
            legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
            prob = self._legal_soft_max(logits, legal_arr)
            action = self._tie_break_eval_action(prob, legal_arr)
            act_data = ActData(
                action=[action],
                d_action=[action],
                prob=list(prob),
                value=value,
            )
            return self.action_process(act_data, is_stochastic=False)
        except Exception as err:
            if self.logger:
                # Prefer ``exception`` (includes the traceback) when available.
                if hasattr(self.logger, "exception"):
                    self.logger.exception(f"exploit fallback action due to inference error: {err}")
                else:
                    self.logger.error(f"exploit fallback action due to inference error: {err}")
            return self._fallback_action(env_obs)

    def learn(self, list_sample_data):
        """Delegate the training update to ``self.algorithm.learn``.

        Args:
            list_sample_data: Samples forwarded unchanged to the algorithm.

        Returns:
            Whatever ``Algorithm.learn`` returns.
        """
        return self.algorithm.learn(list_sample_data)

    def estimate_value(self, obs_data):
        """Estimate the critic value for an already processed observation.

        Args:
            obs_data: Object exposing a ``feature`` attribute.

        Returns:
            A flat ``float32`` array with at most ``Config.VALUE_NUM`` entries.
        """
        _, value = self._run_model(obs_data.feature)
        return np.asarray(value, dtype=np.float32).reshape(-1)[: Config.VALUE_NUM]

    def save_model(self, path=None, id="1"):
        """Save the model's state dict to ``{path}/model.ckpt-{id}.pkl``.

        Args:
            path: Target directory.
            id: Checkpoint identifier embedded in the file name. The name
                shadows the builtin but is kept for interface compatibility.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        # Store CPU copies of the state-dict tensors.
        state_dict_cpu = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
        torch.save(state_dict_cpu, model_file_path)
        # The logger is optional (defaults to None), so guard it exactly like
        # load_model does; otherwise a successful save ends in AttributeError.
        if self.logger:
            self.logger.info(f"save model {model_file_path} successfully")

    def load_model(self, path=None, id="1"):
        """Load a state dict from ``{path}/model.ckpt-{id}.pkl``.

        A checkpoint that ``load_state_dict`` rejects with ``RuntimeError`` is
        skipped with a warning and the currently initialised weights are kept.

        Args:
            path: Directory containing the checkpoint.
            id: Checkpoint identifier embedded in the file name. The name
                shadows the builtin but is kept for interface compatibility.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted locations (consider ``weights_only=True`` where the
        # installed torch version supports it).
        state_dict = torch.load(model_file_path, map_location=self.device)
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as err:
            msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}"
            if self.logger:
                self.logger.warning(msg)
            return
        if self.logger:
            self.logger.info(f"load model {model_file_path} successfully")

    def _fallback_action(self, env_obs):
        """Return a valid action instead of ``None`` when inference failed.

        The legal-action mask is read from
        ``env_obs["observation"]["legal_action"]`` (or ``"legal_act"``) when
        available; otherwise every action is treated as legal. The first legal
        action is returned, or ``0`` when the mask marks nothing as legal.

        This method is called from an ``except`` handler, so it must not raise
        on malformed masks: NumPy arrays are accepted and entries that cannot
        be converted to ``int`` make the method fall back to "all legal".

        Args:
            env_obs: Raw environment observation (any type).

        Returns:
            The chosen action as ``int``; it is also stored in
            ``self.last_action``.
        """
        all_legal = [1] * Config.ACTION_NUM
        legal_action = all_legal

        raw_legal = None
        observation = env_obs.get("observation") if isinstance(env_obs, dict) else None
        if isinstance(observation, dict):
            for key in ("legal_action", "legal_act"):
                candidate = observation.get(key)
                # Arrays have no well-defined truth value, so normalise them
                # to a flat list before testing for emptiness.
                if isinstance(candidate, np.ndarray):
                    candidate = candidate.reshape(-1).tolist()
                if isinstance(candidate, (list, tuple)) and candidate:
                    raw_legal = candidate
                    break

        if raw_legal is not None:
            try:
                legal_action = [int(x) for x in raw_legal[: Config.ACTION_NUM]]
            except (TypeError, ValueError):
                # Unusable mask entries: keep the permissive default.
                legal_action = all_legal
            else:
                # Missing trailing entries are treated as illegal.
                shortfall = Config.ACTION_NUM - len(legal_action)
                if shortfall > 0:
                    legal_action.extend([0] * shortfall)

        for action, is_legal in enumerate(legal_action):
            if is_legal > 0:
                self.last_action = action
                return action
        self.last_action = 0
        return 0

    def _run_model(self, feature):
        """Run a gradient-free forward pass for one observation.

        Args:
            feature: Feature vector; it is reshaped to
                ``(1, Config.DIM_OF_OBSERVATION)``.

        Returns:
            ``(logits, value)`` as NumPy arrays, taken from the first batch
            entry of the first and second model outputs respectively.
        """
        self.model.set_eval_mode()
        obs_tensor = (
            torch.tensor(np.array([feature], dtype=np.float32)).view(1, Config.DIM_OF_OBSERVATION).to(self.device)
        )
        with torch.no_grad():
            rst = self.model(obs_tensor, inference=True)
        logits = rst[0].cpu().numpy()[0]
        value = rst[1].cpu().numpy()[0]
        return logits, value

    def _legal_soft_max(self, logits, legal_action):
        """Compute a softmax restricted to the legal actions.

        Args:
            logits: Raw action scores.
            legal_action: Mask; entries greater than 0.5 mark legal actions.
                When no entry is legal, all ``Config.ACTION_NUM`` actions are
                treated as legal.

        Returns:
            A ``float32`` probability array that is zero on illegal actions.
            If the masked softmax has a zero or non-finite sum, a uniform
            distribution over the legal actions is returned instead.
        """
        legal = np.asarray(legal_action, dtype=np.float32) > 0.5
        if not np.any(legal):
            legal = np.ones(Config.ACTION_NUM, dtype=bool)

        masked_logits = np.asarray(logits, dtype=np.float32).copy()
        masked_logits[~legal] = -1e9
        # Subtract the maximum before exponentiating for numerical stability.
        shifted = masked_logits - np.max(masked_logits)
        probs = np.exp(shifted) * legal.astype(np.float32)
        prob_sum = float(np.sum(probs))
        if prob_sum <= 0.0 or not np.isfinite(prob_sum):
            probs = legal.astype(np.float32)
            prob_sum = float(np.sum(probs))
        return probs / prob_sum

    def _legal_sample(self, probs, use_max=False):
        """Pick an action from a probability distribution.

        Args:
            probs: Action probabilities; they are renormalised in ``float64``.
                A zero or non-finite sum is replaced by a uniform distribution
                over ``Config.ACTION_NUM`` actions.
            use_max: When true return the argmax, otherwise sample with
                ``np.random.choice``.

        Returns:
            The selected action index as ``int``.
        """
        probs = np.asarray(probs, dtype=np.float64)
        prob_sum = float(np.sum(probs))
        if prob_sum <= 0.0 or not np.isfinite(prob_sum):
            probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
        else:
            probs = probs / prob_sum
        if use_max:
            return int(np.argmax(probs))
        return int(np.random.choice(len(probs), p=probs))

    def _tie_break_eval_action(self, probs, legal_action):
        """Choose the evaluation action, using a heuristic only for near ties.

        The most probable legal action wins unless other legal actions are
        within ``Config.EVAL_TIE_BREAK_PROB_GAP`` of it. In that case each
        candidate is scored as
        ``prob + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic`` where the
        heuristic comes from ``preprocessor.evaluation_action_score`` (0.0 if
        the preprocessor has no such method). Remaining ties are resolved by
        the higher probability and then by the lower action index.

        Args:
            probs: Action probabilities.
            legal_action: Mask; entries greater than 0.5 mark legal actions.
                When no entry is legal, all ``Config.ACTION_NUM`` actions are
                treated as legal.

        Returns:
            The selected action index as ``int``.
        """
        probs = np.asarray(probs, dtype=np.float64)
        legal = np.asarray(legal_action, dtype=np.float32) > 0.5
        if not np.any(legal):
            legal = np.ones(Config.ACTION_NUM, dtype=bool)
        legal_indices = np.flatnonzero(legal)
        best_action = int(legal_indices[np.argmax(probs[legal_indices])])
        best_prob = float(probs[best_action])
        candidates = [
            int(action)
            for action in legal_indices
            if best_prob - float(probs[int(action)]) <= Config.EVAL_TIE_BREAK_PROB_GAP
        ]
        if len(candidates) <= 1:
            return best_action

        # Resolve the optional heuristic once instead of once per candidate.
        score_fn = getattr(self.preprocessor, "evaluation_action_score", None)
        scored = []
        for action in candidates:
            heuristic = score_fn(action) if score_fn is not None else 0.0
            combined = float(probs[action]) + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic
            # ``-action`` makes the descending sort prefer the lower index.
            scored.append((combined, float(probs[action]), -action, action))
        scored.sort(reverse=True)
        return int(scored[0][3])
|