Optimize PPO short-run training

This commit is contained in:
2026-04-26 12:46:00 +08:00
parent ca6234c941
commit eb3efa4df7
5 changed files with 153 additions and 41 deletions

View File

@@ -17,6 +17,7 @@ import os
import time
import torch
import torch.nn.functional as F
from agent_ppo.conf.conf import Config
@@ -32,7 +33,6 @@ class Algorithm:
self.clip_param = Config.CLIP_PARAM
self.vf_coef = Config.VF_COEF
self.var_beta = Config.BETA_START
self.label_size = Config.ACTION_NUM
self.train_step = 0
@@ -52,32 +52,60 @@ class Algorithm:
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
self.model.set_train_mode()
self.optimizer.zero_grad()
batch_size = obs.shape[0]
mini_batch_size = min(Config.MINI_BATCH_SIZE, batch_size)
stat_sum = {
"total_loss": 0.0,
"value_loss": 0.0,
"policy_loss": 0.0,
"entropy_loss": 0.0,
"approx_kl": 0.0,
"clip_fraction": 0.0,
}
stat_count = 0
rst_list = self.model(obs)
logits, value_pred = rst_list[0], rst_list[1]
for _ in range(Config.PPO_EPOCHS):
indices = torch.randperm(batch_size, device=self.device)
total_loss, info = self._compute_loss(
logits=logits,
value_pred=value_pred,
legal_action=legal_action,
old_action=act,
old_prob=old_prob,
old_value=old_value,
reward_sum=reward_sum,
advantage=advantage,
)
for start in range(0, batch_size, mini_batch_size):
mb_idx = indices[start : start + mini_batch_size]
total_loss.backward()
rst_list = self.model(obs[mb_idx])
logits, value_pred = rst_list[0], rst_list[1]
if Config.USE_GRAD_CLIP:
torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE)
total_loss, info = self._compute_loss(
logits=logits,
value_pred=value_pred,
legal_action=legal_action[mb_idx],
old_action=act[mb_idx],
old_prob=old_prob[mb_idx],
old_value=old_value[mb_idx],
reward_sum=reward_sum[mb_idx],
advantage=advantage[mb_idx],
)
self.optimizer.step()
self.train_step += 1
self.optimizer.zero_grad()
total_loss.backward()
results = {"total_loss": total_loss.item()}
if Config.USE_GRAD_CLIP:
torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE)
self.optimizer.step()
self.train_step += 1
for key in stat_sum:
stat_sum[key] += info[key]
stat_count += 1
if stat_count > 0 and stat_sum["approx_kl"] / stat_count > Config.TARGET_KL:
break
info = {key: value / max(stat_count, 1) for key, value in stat_sum.items()}
results = {"total_loss": info["total_loss"]}
# Periodic monitoring report
# 定期上报监控
@@ -87,11 +115,15 @@ class Algorithm:
results["policy_loss"] = round(info["policy_loss"], 4)
results["entropy_loss"] = round(info["entropy_loss"], 4)
results["reward"] = round(reward.mean().item(), 4)
results["approx_kl"] = round(info["approx_kl"], 4)
results["clip_fraction"] = round(info["clip_fraction"], 4)
self.logger.info(
f"policy_loss: {results['policy_loss']}, "
f"value_loss: {results['value_loss']}, "
f"entropy_loss: {results['entropy_loss']}"
f"entropy_loss: {results['entropy_loss']}, "
f"approx_kl: {results['approx_kl']}, "
f"clip_fraction: {results['clip_fraction']}"
)
if self.monitor:
self.monitor.put_data({os.getpid(): results})
@@ -115,8 +147,8 @@ class Algorithm:
value_loss = (
0.5
* torch.maximum(
(tdret - vp) ** 2,
(tdret - vp_clip) ** 2,
F.smooth_l1_loss(vp, tdret, reduction="none"),
F.smooth_l1_loss(vp_clip, tdret, reduction="none"),
).mean()
)
@@ -130,6 +162,9 @@ class Algorithm:
old_action_prob = (one_hot * old_prob).sum(1, keepdim=True)
ratio = new_prob / old_action_prob.clamp(1e-9)
log_ratio = torch.log(new_prob.clamp_min(1e-9)) - torch.log(old_action_prob.clamp_min(1e-9))
approx_kl = (-log_ratio).mean()
clip_fraction = ((ratio - 1.0).abs() > self.clip_param).float().mean()
adv = advantage.squeeze(-1) if advantage.dim() > 1 else advantage
adv = adv.unsqueeze(-1)
@@ -141,12 +176,16 @@ class Algorithm:
# Total loss
# 总损失
total_loss = self.vf_coef * value_loss + policy_loss - self.var_beta * entropy_loss
entropy_beta = self._entropy_beta()
total_loss = self.vf_coef * value_loss + policy_loss - entropy_beta * entropy_loss
return total_loss, {
"total_loss": total_loss.item(),
"value_loss": value_loss.item(),
"policy_loss": policy_loss.item(),
"entropy_loss": entropy_loss.item(),
"approx_kl": approx_kl.item(),
"clip_fraction": clip_fraction.item(),
}
def _masked_softmax(self, logits, legal_action):
@@ -154,8 +193,11 @@ class Algorithm:
对 logits 应用合法动作掩码后计算 softmax。
"""
label_max, _ = torch.max(logits * legal_action, dim=1, keepdim=True)
logits = logits - label_max
logits = logits * legal_action
logits = logits + 1e5 * (legal_action - 1)
return torch.nn.functional.softmax(logits, dim=1)
legal_mask = legal_action > 0.5
safe_logits = logits.masked_fill(~legal_mask, -1e9)
return F.softmax(safe_logits, dim=1)
def _entropy_beta(self):
"""Linearly decay entropy regularization for fast early exploration."""
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
return Config.BETA_START + progress * (Config.BETA_END - Config.BETA_START)