This repository has been archived on 2026-05-02. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
-----/agent_ppo/agent.py

270 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors
Robot Vacuum Agent.
清扫大作战 Agent 主类。
"""
import torch
# Limit PyTorch to one intra-op thread and one inter-op thread for this process.
# NOTE(review): presumably this avoids CPU oversubscription when many agent
# processes run side by side — confirm against the deployment setup.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
import numpy as np
from agent_ppo.algorithm.algorithm import Algorithm
from agent_ppo.conf.conf import Config
from agent_ppo.feature.definition import ActData, ObsData
from agent_ppo.feature.preprocessor import Preprocessor
from agent_ppo.model.model import Model
from kaiwudrl.interface.agent import BaseAgent
class Agent(BaseAgent):
    """PPO agent for the Robot Vacuum environment.

    Owns the policy/value model, its Adam optimizer, the PPO ``Algorithm``
    wrapper and a per-episode ``Preprocessor``. Provides stochastic inference
    for training (``predict``), greedy inference for evaluation (``exploit``),
    and checkpoint save/load helpers.
    """

    def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
        """Build the model, optimizer, algorithm and preprocessor.

        Args:
            agent_type: Forwarded unchanged to ``BaseAgent.__init__``.
            device: Device the model is moved to and on which inference runs.
            logger: Optional logger; may be ``None``.
            monitor: Optional monitor object passed on to ``Algorithm``.
        """
        # Fixed seed so that model initialisation is reproducible.
        torch.manual_seed(0)
        self.device = device
        self.model = Model(device).to(self.device)
        self.optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=Config.INIT_LEARNING_RATE_START,
            betas=(0.9, 0.999),
            eps=1e-8,
        )
        self.logger = logger
        self.monitor = monitor
        self.algorithm = Algorithm(self.model, self.optimizer, self.device, self.logger, self.monitor)
        self.preprocessor = Preprocessor()
        # -1 marks "no action taken yet" in the current episode.
        self.last_action = -1
        self.last_reward = 0.0
        super().__init__(agent_type, device, logger, monitor)

    def reset(self, env_obs):
        """Reset the per-episode state at the start of every episode.

        A fresh ``Preprocessor`` is created and the last action/reward are
        cleared. ``env_obs`` is accepted for interface compatibility and is
        not read here.
        """
        self.preprocessor = Preprocessor()
        self.last_action = -1
        self.last_reward = 0.0

    def observation_process(self, env_obs):
        """Convert a raw ``env_obs`` into ``ObsData`` (feature + legal action mask).

        The reward returned by the preprocessor is stored in
        ``self.last_reward`` as a side effect.

        Returns:
            A tuple ``(obs_data, remain_info)``; ``remain_info`` is always an
            empty dict.
        """
        feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
        self.last_reward = reward
        obs_data = ObsData(
            feature=list(feature),
            legal_action=legal_action,
        )
        remain_info = {}
        return obs_data, remain_info

    def action_process(self, act_data, is_stochastic=True):
        """Extract the integer action from ``ActData`` and update ``last_action``.

        Args:
            act_data: Holder of ``action`` (sampled) and ``d_action``
                (deterministic); the first element of the chosen one is used.
            is_stochastic: Select ``action`` when true, ``d_action`` otherwise.

        Returns:
            The selected action as an ``int``.
        """
        action = act_data.action if is_stochastic else act_data.d_action
        self.last_action = int(action[0])
        # Not every preprocessor implementation tracks the action history.
        if hasattr(self.preprocessor, "record_action"):
            self.preprocessor.record_action(self.last_action)
        return self.last_action

    def predict(self, list_obs_data):
        """Run stochastic inference for training (exploration).

        Only the first element of ``list_obs_data`` is used.

        Returns:
            A one-element list holding an ``ActData`` with the sampled action,
            the argmax action, the masked probabilities and the value estimate.
        """
        obs_data = list_obs_data[0]
        feature = obs_data.feature
        legal_action = obs_data.legal_action
        logits, value = self._run_model(feature)
        legal_arr = np.array(legal_action, dtype=np.float32)
        prob = self._legal_soft_max(logits, legal_arr)
        action = self._legal_sample(prob, use_max=False)
        d_action = self._legal_sample(prob, use_max=True)
        return [
            ActData(
                action=[action],
                d_action=[d_action],
                prob=list(prob),
                value=value,
            )
        ]

    def exploit(self, env_obs):
        """Run greedy inference for evaluation.

        The preprocessor may supply a planned action through its optional
        ``planned_eval_action`` hook; otherwise the most probable legal action
        is used, with a heuristic tie-break between near-equal candidates.

        Returns:
            The chosen action as an ``int``. Any exception raised during
            inference is logged and a fallback legal action is returned, so
            evaluation never receives ``None``.
        """
        try:
            obs_data, _ = self.observation_process(env_obs)
            logits, value = self._run_model(obs_data.feature)
            legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
            prob = self._legal_soft_max(logits, legal_arr)
            action = None
            if hasattr(self.preprocessor, "planned_eval_action"):
                action = self.preprocessor.planned_eval_action(prob, legal_arr)
            if action is None:
                action = self._tie_break_eval_action(prob, legal_arr)
            act_data = ActData(
                action=[action],
                d_action=[action],
                prob=list(prob),
                value=value,
            )
            return self.action_process(act_data, is_stochastic=False)
        except Exception as err:
            # Deliberate best-effort boundary: evaluation must keep running.
            if self.logger:
                if hasattr(self.logger, "exception"):
                    self.logger.exception(f"exploit fallback action due to inference error: {err}")
                else:
                    self.logger.error(f"exploit fallback action due to inference error: {err}")
            return self._fallback_action(env_obs)

    def learn(self, list_sample_data):
        """Delegate the PPO update to ``Algorithm.learn`` and return its result."""
        return self.algorithm.learn(list_sample_data)

    def estimate_value(self, obs_data):
        """Estimate the critic value for an already processed observation.

        Returns:
            A flat ``float32`` array truncated to ``Config.VALUE_NUM`` entries.
        """
        _, value = self._run_model(obs_data.feature)
        return np.asarray(value, dtype=np.float32).reshape(-1)[: Config.VALUE_NUM]

    def save_model(self, path=None, id="1"):
        """Save the model weights to ``{path}/model.ckpt-{id}.pkl``.

        Every tensor of the state dict is cloned to CPU before saving.

        Args:
            path: Directory the checkpoint is written to.
            id: Checkpoint identifier embedded in the file name.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        state_dict_cpu = {k: v.clone().cpu() for k, v in self.model.state_dict().items()}
        torch.save(state_dict_cpu, model_file_path)
        # The logger is optional (defaults to None), same as in load_model.
        if self.logger:
            self.logger.info(f"save model {model_file_path} successfully")

    def load_model(self, path=None, id="1"):
        """Load model weights from ``{path}/model.ckpt-{id}.pkl``.

        A checkpoint whose state dict is incompatible with the current model
        (``RuntimeError`` from ``load_state_dict``) is skipped with a warning
        and the already initialised weights are kept.

        Args:
            path: Directory the checkpoint is read from.
            id: Checkpoint identifier embedded in the file name.
        """
        model_file_path = f"{path}/model.ckpt-{id}.pkl"
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources (consider weights_only=True where supported).
        state_dict = torch.load(model_file_path, map_location=self.device)
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as err:
            msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}"
            if self.logger:
                self.logger.warning(msg)
            return
        if self.logger:
            self.logger.info(f"load model {model_file_path} successfully")

    def _fallback_action(self, env_obs):
        """Return a valid action instead of ``None`` during evaluation failures.

        Looks for a legal action mask under ``env_obs["observation"]`` (keys
        ``"legal_action"`` then ``"legal_act"``) and returns the first legal
        index. If no usable mask is found every action is treated as legal;
        if the mask marks nothing as legal, action ``0`` is returned.
        ``self.last_action`` is updated to the returned action.

        This runs inside an exception handler, so it is written to never raise
        on malformed observations.
        """
        action_num = Config.ACTION_NUM
        legal_action = [1] * action_num

        raw_legal = None
        if isinstance(env_obs, dict):
            observation = env_obs.get("observation")
            if isinstance(observation, dict):
                for key in ("legal_action", "legal_act"):
                    candidate = observation.get(key)
                    # numpy masks are flattened to a list; truth-testing an
                    # array directly would raise ValueError.
                    if isinstance(candidate, np.ndarray):
                        candidate = candidate.reshape(-1).tolist()
                    if isinstance(candidate, (list, tuple)) and candidate:
                        raw_legal = candidate
                        break

        if raw_legal is not None:
            try:
                parsed = [int(x) for x in raw_legal[:action_num]]
            except (TypeError, ValueError):
                # Malformed mask entries: keep the all-legal default.
                parsed = None
            if parsed is not None:
                # Pad short masks with "illegal" so indices stay in range.
                parsed.extend([0] * (action_num - len(parsed)))
                legal_action = parsed

        for action, is_legal in enumerate(legal_action):
            if is_legal > 0:
                self.last_action = action
                return action
        self.last_action = 0
        return 0

    def _run_model(self, feature):
        """Run a gradient-free forward pass.

        Args:
            feature: Observation feature vector; it is reshaped to
                ``(1, Config.DIM_OF_OBSERVATION)`` before the forward pass.

        Returns:
            ``(logits, value)`` as numpy arrays taken from the first batch row
            of the first two model outputs.
        """
        self.model.set_eval_mode()
        obs_tensor = (
            torch.tensor(np.array([feature], dtype=np.float32)).view(1, Config.DIM_OF_OBSERVATION).to(self.device)
        )
        with torch.no_grad():
            rst = self.model(obs_tensor, inference=True)
        logits = rst[0].cpu().numpy()[0]
        value = rst[1].cpu().numpy()[0]
        return logits, value

    def _legal_soft_max(self, logits, legal_action):
        """Compute a softmax restricted to legal actions.

        Entries of ``legal_action`` greater than 0.5 count as legal. If nothing
        is legal, all ``Config.ACTION_NUM`` actions are treated as legal. If
        the result is degenerate (zero or non-finite sum), a uniform
        distribution over the legal actions is returned instead.
        """
        legal = np.asarray(legal_action, dtype=np.float32) > 0.5
        if not np.any(legal):
            legal = np.ones(Config.ACTION_NUM, dtype=bool)
        masked_logits = np.asarray(logits, dtype=np.float32).copy()
        masked_logits[~legal] = -1e9
        # Subtract the maximum for numerical stability before exponentiating.
        shifted = masked_logits - np.max(masked_logits)
        # Multiply by the mask so illegal actions get exactly zero probability.
        probs = np.exp(shifted) * legal.astype(np.float32)
        prob_sum = float(np.sum(probs))
        if prob_sum <= 0.0 or not np.isfinite(prob_sum):
            probs = legal.astype(np.float32)
            prob_sum = float(np.sum(probs))
        return probs / prob_sum

    def _legal_sample(self, probs, use_max=False):
        """Pick an action from a probability distribution.

        Args:
            probs: Action probabilities; they are renormalised in float64. A
                zero or non-finite sum is replaced by a uniform distribution
                over ``Config.ACTION_NUM`` actions.
            use_max: Return the argmax when true, otherwise sample at random.

        Returns:
            The selected action index as an ``int``.
        """
        probs = np.asarray(probs, dtype=np.float64)
        prob_sum = float(np.sum(probs))
        if prob_sum <= 0.0 or not np.isfinite(prob_sum):
            probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
        else:
            probs = probs / prob_sum
        if use_max:
            return int(np.argmax(probs))
        return int(np.random.choice(len(probs), p=probs))

    def _tie_break_eval_action(self, probs, legal_action):
        """Use a light heuristic only when evaluation probabilities are close.

        The most probable legal action wins outright unless other legal
        actions lie within ``Config.EVAL_TIE_BREAK_PROB_GAP`` of it. In that
        case each candidate is scored as its probability plus
        ``Config.EVAL_TIE_BREAK_SCORE_SCALE`` times the preprocessor's optional
        ``evaluation_action_score``; remaining ties go to the higher
        probability and then to the lower action index.
        """
        probs = np.asarray(probs, dtype=np.float64)
        legal = np.asarray(legal_action, dtype=np.float32) > 0.5
        if not np.any(legal):
            legal = np.ones(Config.ACTION_NUM, dtype=bool)
        legal_indices = np.flatnonzero(legal)
        best_action = int(legal_indices[np.argmax(probs[legal_indices])])
        best_prob = float(probs[best_action])
        candidates = [
            int(action)
            for action in legal_indices
            if best_prob - float(probs[int(action)]) <= Config.EVAL_TIE_BREAK_PROB_GAP
        ]
        if len(candidates) <= 1:
            return best_action
        scored = []
        for action in candidates:
            heuristic = 0.0
            if hasattr(self.preprocessor, "evaluation_action_score"):
                heuristic = self.preprocessor.evaluation_action_score(action)
            combined = float(probs[action]) + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic
            # -action makes the descending sort prefer the lower index on ties.
            scored.append((combined, float(probs[action]), -action, action))
        scored.sort(reverse=True)
        return int(scored[0][3])