#!/usr/bin/env python3 # -*- coding: UTF-8 -*- ########################################################################### # Copyright © 1998 - 2026 Tencent. All Rights Reserved. ########################################################################### """ Author: Tencent AI Arena Authors Standard PPO algorithm for Robot Vacuum. 清扫大作战 PPO 算法。 Loss composition / 损失组成: total_loss = vf_coef * value_loss + policy_loss - beta * entropy_loss """ import os import time import torch from agent_ppo.conf.conf import Config class Algorithm: def __init__(self, model, optimizer, device=None, logger=None, monitor=None): self.model = model self.optimizer = optimizer self.parameters = [p for pg in optimizer.param_groups for p in pg["params"]] self.device = device self.logger = logger self.monitor = monitor self.clip_param = Config.CLIP_PARAM self.vf_coef = Config.VF_COEF self.var_beta = Config.BETA_START self.label_size = Config.ACTION_NUM self.train_step = 0 self.last_report_time = 0 def learn(self, list_sample_data): """Training entry: perform one PPO gradient step on a batch of SampleData. 训练入口:接收一批 SampleData,执行一步梯度更新。 """ obs = torch.stack([s.obs for s in list_sample_data]).to(self.device) legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device) act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1) old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device) old_value = torch.stack([s.value for s in list_sample_data]).to(self.device) reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device) advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device) reward = torch.stack([s.reward for s in list_sample_data]).to(self.device) self.model.set_train_mode() self.optimizer.zero_grad() rst_list = self.model(obs) logits, value_pred = rst_list[0], rst_list[1] total_loss, info = self._compute_loss( logits=logits, value_pred=value_pred, legal_action=legal_action, old_action=act, old_prob=old_prob, old_value=old_value, reward_sum=reward_sum, advantage=advantage, ) total_loss.backward() if Config.USE_GRAD_CLIP: torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE) self.optimizer.step() self.train_step += 1 results = {"total_loss": total_loss.item()} # Periodic monitoring report # 定期上报监控 now = time.time() if now - self.last_report_time >= 60: results["value_loss"] = round(info["value_loss"], 4) results["policy_loss"] = round(info["policy_loss"], 4) results["entropy_loss"] = round(info["entropy_loss"], 4) results["reward"] = round(reward.mean().item(), 4) self.logger.info( f"policy_loss: {results['policy_loss']}, " f"value_loss: {results['value_loss']}, " f"entropy_loss: {results['entropy_loss']}" ) if self.monitor: self.monitor.put_data({os.getpid(): results}) self.last_report_time = now return results def _compute_loss(self, logits, value_pred, legal_action, old_action, old_prob, old_value, reward_sum, advantage): """Compute standard PPO loss (policy + value + entropy). 计算标准 PPO 三项损失。 """ # Value loss (clipped) # 价值损失(裁剪) tdret = reward_sum.squeeze(-1) if reward_sum.dim() > 1 else reward_sum vp = value_pred.squeeze(-1) if value_pred.dim() > 1 else value_pred ov = old_value.squeeze(-1) if old_value.dim() > 1 else old_value vp_clip = ov + (vp - ov).clamp(-self.clip_param, self.clip_param) value_loss = ( 0.5 * torch.maximum( (tdret - vp) ** 2, (tdret - vp_clip) ** 2, ).mean() ) # Policy loss (PPO clip) # 策略损失(PPO clip) prob_dist = self._masked_softmax(logits, legal_action) entropy_loss = (-(prob_dist * torch.log(prob_dist.clamp(1e-9, 1))).sum(1)).mean() one_hot = torch.nn.functional.one_hot(old_action[:, 0].long(), self.label_size).float() new_prob = (one_hot * prob_dist).sum(1, keepdim=True) old_action_prob = (one_hot * old_prob).sum(1, keepdim=True) ratio = new_prob / old_action_prob.clamp(1e-9) adv = advantage.squeeze(-1) if advantage.dim() > 1 else advantage adv = adv.unsqueeze(-1) policy_loss = torch.maximum( -ratio * adv, -ratio.clamp(1 - self.clip_param, 1 + self.clip_param) * adv, ).mean() # Total loss # 总损失 total_loss = self.vf_coef * value_loss + policy_loss - self.var_beta * entropy_loss return total_loss, { "value_loss": value_loss.item(), "policy_loss": policy_loss.item(), "entropy_loss": entropy_loss.item(), } def _masked_softmax(self, logits, legal_action): """Apply legal action mask to logits before computing softmax. 对 logits 应用合法动作掩码后计算 softmax。 """ label_max, _ = torch.max(logits * legal_action, dim=1, keepdim=True) logits = logits - label_max logits = logits * legal_action logits = logits + 1e5 * (legal_action - 1) return torch.nn.functional.softmax(logits, dim=1)