Optimize PPO short-run training

This commit is contained in:
2026-04-26 12:46:00 +08:00
parent ca6234c941
commit eb3efa4df7
5 changed files with 153 additions and 41 deletions

View File

@@ -17,6 +17,7 @@ import os
import time
import torch
import torch.nn.functional as F
from agent_ppo.conf.conf import Config
@@ -32,7 +33,6 @@ class Algorithm:
self.clip_param = Config.CLIP_PARAM
self.vf_coef = Config.VF_COEF
self.var_beta = Config.BETA_START
self.label_size = Config.ACTION_NUM
self.train_step = 0
@@ -52,32 +52,60 @@ class Algorithm:
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
self.model.set_train_mode()
self.optimizer.zero_grad()
batch_size = obs.shape[0]
mini_batch_size = min(Config.MINI_BATCH_SIZE, batch_size)
stat_sum = {
"total_loss": 0.0,
"value_loss": 0.0,
"policy_loss": 0.0,
"entropy_loss": 0.0,
"approx_kl": 0.0,
"clip_fraction": 0.0,
}
stat_count = 0
rst_list = self.model(obs)
logits, value_pred = rst_list[0], rst_list[1]
for _ in range(Config.PPO_EPOCHS):
indices = torch.randperm(batch_size, device=self.device)
total_loss, info = self._compute_loss(
logits=logits,
value_pred=value_pred,
legal_action=legal_action,
old_action=act,
old_prob=old_prob,
old_value=old_value,
reward_sum=reward_sum,
advantage=advantage,
)
for start in range(0, batch_size, mini_batch_size):
mb_idx = indices[start : start + mini_batch_size]
total_loss.backward()
rst_list = self.model(obs[mb_idx])
logits, value_pred = rst_list[0], rst_list[1]
if Config.USE_GRAD_CLIP:
torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE)
total_loss, info = self._compute_loss(
logits=logits,
value_pred=value_pred,
legal_action=legal_action[mb_idx],
old_action=act[mb_idx],
old_prob=old_prob[mb_idx],
old_value=old_value[mb_idx],
reward_sum=reward_sum[mb_idx],
advantage=advantage[mb_idx],
)
self.optimizer.step()
self.train_step += 1
self.optimizer.zero_grad()
total_loss.backward()
results = {"total_loss": total_loss.item()}
if Config.USE_GRAD_CLIP:
torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE)
self.optimizer.step()
self.train_step += 1
for key in stat_sum:
stat_sum[key] += info[key]
stat_count += 1
if stat_count > 0 and stat_sum["approx_kl"] / stat_count > Config.TARGET_KL:
break
info = {key: value / max(stat_count, 1) for key, value in stat_sum.items()}
results = {"total_loss": info["total_loss"]}
# Periodic monitoring report
# 定期上报监控
@@ -87,11 +115,15 @@ class Algorithm:
results["policy_loss"] = round(info["policy_loss"], 4)
results["entropy_loss"] = round(info["entropy_loss"], 4)
results["reward"] = round(reward.mean().item(), 4)
results["approx_kl"] = round(info["approx_kl"], 4)
results["clip_fraction"] = round(info["clip_fraction"], 4)
self.logger.info(
f"policy_loss: {results['policy_loss']}, "
f"value_loss: {results['value_loss']}, "
f"entropy_loss: {results['entropy_loss']}"
f"entropy_loss: {results['entropy_loss']}, "
f"approx_kl: {results['approx_kl']}, "
f"clip_fraction: {results['clip_fraction']}"
)
if self.monitor:
self.monitor.put_data({os.getpid(): results})
@@ -115,8 +147,8 @@ class Algorithm:
value_loss = (
0.5
* torch.maximum(
(tdret - vp) ** 2,
(tdret - vp_clip) ** 2,
F.smooth_l1_loss(vp, tdret, reduction="none"),
F.smooth_l1_loss(vp_clip, tdret, reduction="none"),
).mean()
)
@@ -130,6 +162,9 @@ class Algorithm:
old_action_prob = (one_hot * old_prob).sum(1, keepdim=True)
ratio = new_prob / old_action_prob.clamp(1e-9)
log_ratio = torch.log(new_prob.clamp_min(1e-9)) - torch.log(old_action_prob.clamp_min(1e-9))
approx_kl = (-log_ratio).mean()
clip_fraction = ((ratio - 1.0).abs() > self.clip_param).float().mean()
adv = advantage.squeeze(-1) if advantage.dim() > 1 else advantage
adv = adv.unsqueeze(-1)
@@ -141,12 +176,16 @@ class Algorithm:
# Total loss
# 总损失
total_loss = self.vf_coef * value_loss + policy_loss - self.var_beta * entropy_loss
entropy_beta = self._entropy_beta()
total_loss = self.vf_coef * value_loss + policy_loss - entropy_beta * entropy_loss
return total_loss, {
"total_loss": total_loss.item(),
"value_loss": value_loss.item(),
"policy_loss": policy_loss.item(),
"entropy_loss": entropy_loss.item(),
"approx_kl": approx_kl.item(),
"clip_fraction": clip_fraction.item(),
}
def _masked_softmax(self, logits, legal_action):
@@ -154,8 +193,11 @@ class Algorithm:
对 logits 应用合法动作掩码后计算 softmax。
"""
label_max, _ = torch.max(logits * legal_action, dim=1, keepdim=True)
logits = logits - label_max
logits = logits * legal_action
logits = logits + 1e5 * (legal_action - 1)
return torch.nn.functional.softmax(logits, dim=1)
legal_mask = legal_action > 0.5
safe_logits = logits.masked_fill(~legal_mask, -1e9)
return F.softmax(safe_logits, dim=1)
def _entropy_beta(self):
"""Linearly decay entropy regularization for fast early exploration."""
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
return Config.BETA_START + progress * (Config.BETA_END - Config.BETA_START)

View File

@@ -37,10 +37,16 @@ class Config:
GAMMA = 0.99
LAMDA = 0.95
INIT_LEARNING_RATE_START = 0.0003
BETA_START = 0.001
INIT_LEARNING_RATE_START = 0.00025
BETA_START = 0.008
BETA_END = 0.002
BETA_DECAY_STEPS = 4000
CLIP_PARAM = 0.2
VF_COEF = 0.5
PPO_EPOCHS = 3
MINI_BATCH_SIZE = 256
NORMALIZE_ADVANTAGE = True
TARGET_KL = 0.04
LABEL_SIZE_LIST = [ACTION_NUM]
LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy()

View File

@@ -77,6 +77,26 @@ def build_monitor():
expr="avg(entropy_loss{})",
)
.end_panel()
.add_panel(
name="近似KL",
name_en="approx_kl",
type="line",
)
.add_metric(
metrics_name="approx_kl",
expr="avg(approx_kl{})",
)
.end_panel()
.add_panel(
name="裁剪比例",
name_en="clip_fraction",
type="line",
)
.add_metric(
metrics_name="clip_fraction",
expr="avg(clip_fraction{})",
)
.end_panel()
.end_group()
.build()
)

View File

@@ -47,6 +47,11 @@ class Preprocessor:
self.battery_max = 600
self.cur_pos = (0, 0)
self.prev_pos = None
self.has_position_history = False
self.current_visit_count = 0
self.is_new_cell = False
self.last_action = -1
self.dirt_cleaned = 0
self.last_dirt_cleaned = 0
@@ -60,6 +65,7 @@ class Preprocessor:
# 最近污渍距离
self.nearest_dirt_dist = 200.0
self.last_nearest_dirt_dist = 200.0
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
self._view_map = np.zeros((21, 21), dtype=np.float32)
self._legal_act = [1] * 8
@@ -74,8 +80,20 @@ class Preprocessor:
env_info = observation["env_info"]
hero = frame_state["heroes"]
self.last_action = int(last_action)
self.step_no = int(observation["step_no"])
self.prev_pos = self.cur_pos if self.has_position_history else None
self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))
self.has_position_history = True
hx, hz = self.cur_pos
if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
self.current_visit_count = int(self.visit_count_map[hx, hz])
self.is_new_cell = self.current_visit_count == 0
self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max)
else:
self.current_visit_count = 0
self.is_new_cell = False
# Battery / 电量
self.battery = int(hero["battery"])
@@ -238,9 +256,15 @@ class Preprocessor:
local_view = self._get_local_view_feature() # 49D
global_state = self._get_global_state_feature() # 12D
legal_action = self.get_legal_action() # 8D
legal_arr = np.array(legal_action, dtype=np.float32)
feature = np.concatenate([local_view, global_state, legal_arr]) # 69D
last_action_feature = np.zeros(8, dtype=np.float32)
if 0 <= last_action < 8:
last_action_feature[last_action] = 1.0
# The legal action mask is passed separately to PPO. Reusing this 8D slot
# for action history makes the 69D observation more informative without
# breaking the framework's fixed tensor shape.
feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D
reward = self.reward_process()
@@ -249,9 +273,26 @@ class Preprocessor:
def reward_process(self):
# Cleaning reward / 清扫奖励
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
cleaning_reward = 0.1 * cleaned_this_step
cleaning_reward = 0.25 * cleaned_this_step
# Step penalty / 时间惩罚
step_penalty = -0.001
step_penalty = -0.002
return cleaning_reward + step_penalty
# Dense guidance: prefer moving toward visible dirt.
# 稠密引导:鼓励向视野内污渍靠近。
approach_reward = 0.0
if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0:
dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta
# Encourage covering new passable cells and mildly discourage loops.
# 鼓励探索新格子,轻微惩罚反复绕圈。
exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5)
# Collision/stuck signal: invalid moves waste both step and battery.
# 撞墙/原地不动会浪费步数和电量。
stuck_penalty = 0.0
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
stuck_penalty = -0.03
return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty

View File

@@ -138,19 +138,22 @@ class EpisodeRunner:
# Survived to max steps: higher cleaning ratio → more reward
# 存活到最大步数:清扫比例越高奖励越多
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
final_reward = 5.0 + 5.0 * cleaning_ratio
final_reward = 2.0 + 8.0 * cleaning_ratio
result_str = "WIN"
else:
# Early termination (battery depleted or collision): small penalty
# 提前结束(电量耗尽或碰撞):小惩罚
final_reward = -2.0
# Battery-depleted episodes are common with short runs; keep
# cleaning progress as the dominant terminal signal.
# 短训中电量耗尽较常见,终局奖励仍以清扫比例为主。
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
final_reward = -1.0 + 6.0 * cleaning_ratio
result_str = "FAIL"
self.logger.info(
f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
f"result:{result_str} final_bonus:{final_reward:.2f} "
f"total_reward:{total_reward:.3f} "
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt}"
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} "
f"total_score:{total_score}"
)
# Build sample frame