Optimize PPO short-run training
This commit is contained in:
@@ -17,6 +17,7 @@ import os
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from agent_ppo.conf.conf import Config
|
||||
|
||||
@@ -32,7 +33,6 @@ class Algorithm:
|
||||
|
||||
self.clip_param = Config.CLIP_PARAM
|
||||
self.vf_coef = Config.VF_COEF
|
||||
self.var_beta = Config.BETA_START
|
||||
self.label_size = Config.ACTION_NUM
|
||||
|
||||
self.train_step = 0
|
||||
@@ -52,23 +52,43 @@ class Algorithm:
|
||||
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
|
||||
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
|
||||
|
||||
self.model.set_train_mode()
|
||||
self.optimizer.zero_grad()
|
||||
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
|
||||
|
||||
rst_list = self.model(obs)
|
||||
self.model.set_train_mode()
|
||||
batch_size = obs.shape[0]
|
||||
mini_batch_size = min(Config.MINI_BATCH_SIZE, batch_size)
|
||||
stat_sum = {
|
||||
"total_loss": 0.0,
|
||||
"value_loss": 0.0,
|
||||
"policy_loss": 0.0,
|
||||
"entropy_loss": 0.0,
|
||||
"approx_kl": 0.0,
|
||||
"clip_fraction": 0.0,
|
||||
}
|
||||
stat_count = 0
|
||||
|
||||
for _ in range(Config.PPO_EPOCHS):
|
||||
indices = torch.randperm(batch_size, device=self.device)
|
||||
|
||||
for start in range(0, batch_size, mini_batch_size):
|
||||
mb_idx = indices[start : start + mini_batch_size]
|
||||
|
||||
rst_list = self.model(obs[mb_idx])
|
||||
logits, value_pred = rst_list[0], rst_list[1]
|
||||
|
||||
total_loss, info = self._compute_loss(
|
||||
logits=logits,
|
||||
value_pred=value_pred,
|
||||
legal_action=legal_action,
|
||||
old_action=act,
|
||||
old_prob=old_prob,
|
||||
old_value=old_value,
|
||||
reward_sum=reward_sum,
|
||||
advantage=advantage,
|
||||
legal_action=legal_action[mb_idx],
|
||||
old_action=act[mb_idx],
|
||||
old_prob=old_prob[mb_idx],
|
||||
old_value=old_value[mb_idx],
|
||||
reward_sum=reward_sum[mb_idx],
|
||||
advantage=advantage[mb_idx],
|
||||
)
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
total_loss.backward()
|
||||
|
||||
if Config.USE_GRAD_CLIP:
|
||||
@@ -77,7 +97,15 @@ class Algorithm:
|
||||
self.optimizer.step()
|
||||
self.train_step += 1
|
||||
|
||||
results = {"total_loss": total_loss.item()}
|
||||
for key in stat_sum:
|
||||
stat_sum[key] += info[key]
|
||||
stat_count += 1
|
||||
|
||||
if stat_count > 0 and stat_sum["approx_kl"] / stat_count > Config.TARGET_KL:
|
||||
break
|
||||
|
||||
info = {key: value / max(stat_count, 1) for key, value in stat_sum.items()}
|
||||
results = {"total_loss": info["total_loss"]}
|
||||
|
||||
# Periodic monitoring report
|
||||
# 定期上报监控
|
||||
@@ -87,11 +115,15 @@ class Algorithm:
|
||||
results["policy_loss"] = round(info["policy_loss"], 4)
|
||||
results["entropy_loss"] = round(info["entropy_loss"], 4)
|
||||
results["reward"] = round(reward.mean().item(), 4)
|
||||
results["approx_kl"] = round(info["approx_kl"], 4)
|
||||
results["clip_fraction"] = round(info["clip_fraction"], 4)
|
||||
|
||||
self.logger.info(
|
||||
f"policy_loss: {results['policy_loss']}, "
|
||||
f"value_loss: {results['value_loss']}, "
|
||||
f"entropy_loss: {results['entropy_loss']}"
|
||||
f"entropy_loss: {results['entropy_loss']}, "
|
||||
f"approx_kl: {results['approx_kl']}, "
|
||||
f"clip_fraction: {results['clip_fraction']}"
|
||||
)
|
||||
if self.monitor:
|
||||
self.monitor.put_data({os.getpid(): results})
|
||||
@@ -115,8 +147,8 @@ class Algorithm:
|
||||
value_loss = (
|
||||
0.5
|
||||
* torch.maximum(
|
||||
(tdret - vp) ** 2,
|
||||
(tdret - vp_clip) ** 2,
|
||||
F.smooth_l1_loss(vp, tdret, reduction="none"),
|
||||
F.smooth_l1_loss(vp_clip, tdret, reduction="none"),
|
||||
).mean()
|
||||
)
|
||||
|
||||
@@ -130,6 +162,9 @@ class Algorithm:
|
||||
old_action_prob = (one_hot * old_prob).sum(1, keepdim=True)
|
||||
|
||||
ratio = new_prob / old_action_prob.clamp(1e-9)
|
||||
log_ratio = torch.log(new_prob.clamp_min(1e-9)) - torch.log(old_action_prob.clamp_min(1e-9))
|
||||
approx_kl = (-log_ratio).mean()
|
||||
clip_fraction = ((ratio - 1.0).abs() > self.clip_param).float().mean()
|
||||
|
||||
adv = advantage.squeeze(-1) if advantage.dim() > 1 else advantage
|
||||
adv = adv.unsqueeze(-1)
|
||||
@@ -141,12 +176,16 @@ class Algorithm:
|
||||
|
||||
# Total loss
|
||||
# 总损失
|
||||
total_loss = self.vf_coef * value_loss + policy_loss - self.var_beta * entropy_loss
|
||||
entropy_beta = self._entropy_beta()
|
||||
total_loss = self.vf_coef * value_loss + policy_loss - entropy_beta * entropy_loss
|
||||
|
||||
return total_loss, {
|
||||
"total_loss": total_loss.item(),
|
||||
"value_loss": value_loss.item(),
|
||||
"policy_loss": policy_loss.item(),
|
||||
"entropy_loss": entropy_loss.item(),
|
||||
"approx_kl": approx_kl.item(),
|
||||
"clip_fraction": clip_fraction.item(),
|
||||
}
|
||||
|
||||
def _masked_softmax(self, logits, legal_action):
|
||||
@@ -154,8 +193,11 @@ class Algorithm:
|
||||
|
||||
对 logits 应用合法动作掩码后计算 softmax。
|
||||
"""
|
||||
label_max, _ = torch.max(logits * legal_action, dim=1, keepdim=True)
|
||||
logits = logits - label_max
|
||||
logits = logits * legal_action
|
||||
logits = logits + 1e5 * (legal_action - 1)
|
||||
return torch.nn.functional.softmax(logits, dim=1)
|
||||
legal_mask = legal_action > 0.5
|
||||
safe_logits = logits.masked_fill(~legal_mask, -1e9)
|
||||
return F.softmax(safe_logits, dim=1)
|
||||
|
||||
def _entropy_beta(self):
|
||||
"""Linearly decay entropy regularization for fast early exploration."""
|
||||
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
|
||||
return Config.BETA_START + progress * (Config.BETA_END - Config.BETA_START)
|
||||
|
||||
@@ -37,10 +37,16 @@ class Config:
|
||||
GAMMA = 0.99
|
||||
LAMDA = 0.95
|
||||
|
||||
INIT_LEARNING_RATE_START = 0.0003
|
||||
BETA_START = 0.001
|
||||
INIT_LEARNING_RATE_START = 0.00025
|
||||
BETA_START = 0.008
|
||||
BETA_END = 0.002
|
||||
BETA_DECAY_STEPS = 4000
|
||||
CLIP_PARAM = 0.2
|
||||
VF_COEF = 0.5
|
||||
PPO_EPOCHS = 3
|
||||
MINI_BATCH_SIZE = 256
|
||||
NORMALIZE_ADVANTAGE = True
|
||||
TARGET_KL = 0.04
|
||||
|
||||
LABEL_SIZE_LIST = [ACTION_NUM]
|
||||
LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy()
|
||||
|
||||
@@ -77,6 +77,26 @@ def build_monitor():
|
||||
expr="avg(entropy_loss{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="近似KL",
|
||||
name_en="approx_kl",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="approx_kl",
|
||||
expr="avg(approx_kl{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="裁剪比例",
|
||||
name_en="clip_fraction",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="clip_fraction",
|
||||
expr="avg(clip_fraction{})",
|
||||
)
|
||||
.end_panel()
|
||||
.end_group()
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -47,6 +47,11 @@ class Preprocessor:
|
||||
self.battery_max = 600
|
||||
|
||||
self.cur_pos = (0, 0)
|
||||
self.prev_pos = None
|
||||
self.has_position_history = False
|
||||
self.current_visit_count = 0
|
||||
self.is_new_cell = False
|
||||
self.last_action = -1
|
||||
|
||||
self.dirt_cleaned = 0
|
||||
self.last_dirt_cleaned = 0
|
||||
@@ -60,6 +65,7 @@ class Preprocessor:
|
||||
# 最近污渍距离
|
||||
self.nearest_dirt_dist = 200.0
|
||||
self.last_nearest_dirt_dist = 200.0
|
||||
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
|
||||
|
||||
self._view_map = np.zeros((21, 21), dtype=np.float32)
|
||||
self._legal_act = [1] * 8
|
||||
@@ -74,8 +80,20 @@ class Preprocessor:
|
||||
env_info = observation["env_info"]
|
||||
hero = frame_state["heroes"]
|
||||
|
||||
self.last_action = int(last_action)
|
||||
self.step_no = int(observation["step_no"])
|
||||
self.prev_pos = self.cur_pos if self.has_position_history else None
|
||||
self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))
|
||||
self.has_position_history = True
|
||||
|
||||
hx, hz = self.cur_pos
|
||||
if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
|
||||
self.current_visit_count = int(self.visit_count_map[hx, hz])
|
||||
self.is_new_cell = self.current_visit_count == 0
|
||||
self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max)
|
||||
else:
|
||||
self.current_visit_count = 0
|
||||
self.is_new_cell = False
|
||||
|
||||
# Battery / 电量
|
||||
self.battery = int(hero["battery"])
|
||||
@@ -238,9 +256,15 @@ class Preprocessor:
|
||||
local_view = self._get_local_view_feature() # 49D
|
||||
global_state = self._get_global_state_feature() # 12D
|
||||
legal_action = self.get_legal_action() # 8D
|
||||
legal_arr = np.array(legal_action, dtype=np.float32)
|
||||
|
||||
feature = np.concatenate([local_view, global_state, legal_arr]) # 69D
|
||||
last_action_feature = np.zeros(8, dtype=np.float32)
|
||||
if 0 <= last_action < 8:
|
||||
last_action_feature[last_action] = 1.0
|
||||
|
||||
# The legal action mask is passed separately to PPO. Reusing this 8D slot
|
||||
# for action history makes the 69D observation more informative without
|
||||
# breaking the framework's fixed tensor shape.
|
||||
feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D
|
||||
|
||||
reward = self.reward_process()
|
||||
|
||||
@@ -249,9 +273,26 @@ class Preprocessor:
|
||||
def reward_process(self):
|
||||
# Cleaning reward / 清扫奖励
|
||||
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
||||
cleaning_reward = 0.1 * cleaned_this_step
|
||||
cleaning_reward = 0.25 * cleaned_this_step
|
||||
|
||||
# Step penalty / 时间惩罚
|
||||
step_penalty = -0.001
|
||||
step_penalty = -0.002
|
||||
|
||||
return cleaning_reward + step_penalty
|
||||
# Dense guidance: prefer moving toward visible dirt.
|
||||
# 稠密引导:鼓励向视野内污渍靠近。
|
||||
approach_reward = 0.0
|
||||
if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0:
|
||||
dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
|
||||
approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta
|
||||
|
||||
# Encourage covering new passable cells and mildly discourage loops.
|
||||
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
||||
exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5)
|
||||
|
||||
# Collision/stuck signal: invalid moves waste both step and battery.
|
||||
# 撞墙/原地不动会浪费步数和电量。
|
||||
stuck_penalty = 0.0
|
||||
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
|
||||
stuck_penalty = -0.03
|
||||
|
||||
return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty
|
||||
|
||||
@@ -138,19 +138,22 @@ class EpisodeRunner:
|
||||
# Survived to max steps: higher cleaning ratio → more reward
|
||||
# 存活到最大步数:清扫比例越高奖励越多
|
||||
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
|
||||
final_reward = 5.0 + 5.0 * cleaning_ratio
|
||||
final_reward = 2.0 + 8.0 * cleaning_ratio
|
||||
result_str = "WIN"
|
||||
else:
|
||||
# Early termination (battery depleted or collision): small penalty
|
||||
# 提前结束(电量耗尽或碰撞):小惩罚
|
||||
final_reward = -2.0
|
||||
# Battery-depleted episodes are common with short runs; keep
|
||||
# cleaning progress as the dominant terminal signal.
|
||||
# 短训中电量耗尽较常见,终局奖励仍以清扫比例为主。
|
||||
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
|
||||
final_reward = -1.0 + 6.0 * cleaning_ratio
|
||||
result_str = "FAIL"
|
||||
|
||||
self.logger.info(
|
||||
f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
|
||||
f"result:{result_str} final_bonus:{final_reward:.2f} "
|
||||
f"total_reward:{total_reward:.3f} "
|
||||
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt}"
|
||||
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} "
|
||||
f"total_score:{total_score}"
|
||||
)
|
||||
|
||||
# Build sample frame
|
||||
|
||||
Reference in New Issue
Block a user