From eb3efa4df786cdb1edb98a0d00e29797ed14b031 Mon Sep 17 00:00:00 2001 From: gqt <3217233537@qq.com> Date: Sun, 26 Apr 2026 12:46:00 +0800 Subject: [PATCH] Optimize PPO short-run training --- agent_ppo/algorithm/algorithm.py | 100 +++++++++++++++++++-------- agent_ppo/conf/conf.py | 10 ++- agent_ppo/conf/monitor_builder.py | 20 ++++++ agent_ppo/feature/preprocessor.py | 51 ++++++++++++-- agent_ppo/workflow/train_workflow.py | 13 ++-- 5 files changed, 153 insertions(+), 41 deletions(-) diff --git a/agent_ppo/algorithm/algorithm.py b/agent_ppo/algorithm/algorithm.py index a6ae736..c1d97cc 100644 --- a/agent_ppo/algorithm/algorithm.py +++ b/agent_ppo/algorithm/algorithm.py @@ -17,6 +17,7 @@ import os import time import torch +import torch.nn.functional as F from agent_ppo.conf.conf import Config @@ -32,7 +33,6 @@ class Algorithm: self.clip_param = Config.CLIP_PARAM self.vf_coef = Config.VF_COEF - self.var_beta = Config.BETA_START self.label_size = Config.ACTION_NUM self.train_step = 0 @@ -52,32 +52,60 @@ class Algorithm: advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device) reward = torch.stack([s.reward for s in list_sample_data]).to(self.device) + if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1: + advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8) + self.model.set_train_mode() - self.optimizer.zero_grad() + batch_size = obs.shape[0] + mini_batch_size = min(Config.MINI_BATCH_SIZE, batch_size) + stat_sum = { + "total_loss": 0.0, + "value_loss": 0.0, + "policy_loss": 0.0, + "entropy_loss": 0.0, + "approx_kl": 0.0, + "clip_fraction": 0.0, + } + stat_count = 0 - rst_list = self.model(obs) - logits, value_pred = rst_list[0], rst_list[1] + for _ in range(Config.PPO_EPOCHS): + indices = torch.randperm(batch_size, device=self.device) - total_loss, info = self._compute_loss( - logits=logits, - value_pred=value_pred, - legal_action=legal_action, - old_action=act, - old_prob=old_prob, - old_value=old_value, - reward_sum=reward_sum, - advantage=advantage, - ) + for start in range(0, batch_size, mini_batch_size): + mb_idx = indices[start : start + mini_batch_size] - total_loss.backward() + rst_list = self.model(obs[mb_idx]) + logits, value_pred = rst_list[0], rst_list[1] - if Config.USE_GRAD_CLIP: - torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE) + total_loss, info = self._compute_loss( + logits=logits, + value_pred=value_pred, + legal_action=legal_action[mb_idx], + old_action=act[mb_idx], + old_prob=old_prob[mb_idx], + old_value=old_value[mb_idx], + reward_sum=reward_sum[mb_idx], + advantage=advantage[mb_idx], + ) - self.optimizer.step() - self.train_step += 1 + self.optimizer.zero_grad() + total_loss.backward() - results = {"total_loss": total_loss.item()} + if Config.USE_GRAD_CLIP: + torch.nn.utils.clip_grad_norm_(self.parameters, Config.GRAD_CLIP_RANGE) + + self.optimizer.step() + self.train_step += 1 + + for key in stat_sum: + stat_sum[key] += info[key] + stat_count += 1 + + if stat_count > 0 and stat_sum["approx_kl"] / stat_count > Config.TARGET_KL: + break + + info = {key: value / max(stat_count, 1) for key, value in stat_sum.items()} + results = {"total_loss": info["total_loss"]} # Periodic monitoring report # 定期上报监控 @@ -87,11 +115,15 @@ class Algorithm: results["policy_loss"] = round(info["policy_loss"], 4) results["entropy_loss"] = round(info["entropy_loss"], 4) results["reward"] = round(reward.mean().item(), 4) + results["approx_kl"] = round(info["approx_kl"], 4) + results["clip_fraction"] = round(info["clip_fraction"], 4) self.logger.info( f"policy_loss: {results['policy_loss']}, " f"value_loss: {results['value_loss']}, " - f"entropy_loss: {results['entropy_loss']}" + f"entropy_loss: {results['entropy_loss']}, " + f"approx_kl: {results['approx_kl']}, " + f"clip_fraction: {results['clip_fraction']}" ) if self.monitor: self.monitor.put_data({os.getpid(): results}) @@ -115,8 +147,8 @@ class Algorithm: value_loss = ( 0.5 * torch.maximum( - (tdret - vp) ** 2, - (tdret - vp_clip) ** 2, + F.smooth_l1_loss(vp, tdret, reduction="none"), + F.smooth_l1_loss(vp_clip, tdret, reduction="none"), ).mean() ) @@ -130,6 +162,9 @@ class Algorithm: old_action_prob = (one_hot * old_prob).sum(1, keepdim=True) ratio = new_prob / old_action_prob.clamp(1e-9) + log_ratio = torch.log(new_prob.clamp_min(1e-9)) - torch.log(old_action_prob.clamp_min(1e-9)) + approx_kl = (-log_ratio).mean() + clip_fraction = ((ratio - 1.0).abs() > self.clip_param).float().mean() adv = advantage.squeeze(-1) if advantage.dim() > 1 else advantage adv = adv.unsqueeze(-1) @@ -141,12 +176,16 @@ class Algorithm: # Total loss # 总损失 - total_loss = self.vf_coef * value_loss + policy_loss - self.var_beta * entropy_loss + entropy_beta = self._entropy_beta() + total_loss = self.vf_coef * value_loss + policy_loss - entropy_beta * entropy_loss return total_loss, { + "total_loss": total_loss.item(), "value_loss": value_loss.item(), "policy_loss": policy_loss.item(), "entropy_loss": entropy_loss.item(), + "approx_kl": approx_kl.item(), + "clip_fraction": clip_fraction.item(), } def _masked_softmax(self, logits, legal_action): @@ -154,8 +193,11 @@ class Algorithm: 对 logits 应用合法动作掩码后计算 softmax。 """ - label_max, _ = torch.max(logits * legal_action, dim=1, keepdim=True) - logits = logits - label_max - logits = logits * legal_action - logits = logits + 1e5 * (legal_action - 1) - return torch.nn.functional.softmax(logits, dim=1) + legal_mask = legal_action > 0.5 + safe_logits = logits.masked_fill(~legal_mask, -1e9) + return F.softmax(safe_logits, dim=1) + + def _entropy_beta(self): + """Linearly decay entropy regularization for fast early exploration.""" + progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0) + return Config.BETA_START + progress * (Config.BETA_END - Config.BETA_START) diff --git a/agent_ppo/conf/conf.py b/agent_ppo/conf/conf.py index 4f69555..5672d80 100644 --- a/agent_ppo/conf/conf.py +++ b/agent_ppo/conf/conf.py @@ -37,10 +37,16 @@ class Config: GAMMA = 0.99 LAMDA = 0.95 - INIT_LEARNING_RATE_START = 0.0003 - BETA_START = 0.001 + INIT_LEARNING_RATE_START = 0.00025 + BETA_START = 0.008 + BETA_END = 0.002 + BETA_DECAY_STEPS = 4000 CLIP_PARAM = 0.2 VF_COEF = 0.5 + PPO_EPOCHS = 3 + MINI_BATCH_SIZE = 256 + NORMALIZE_ADVANTAGE = True + TARGET_KL = 0.04 LABEL_SIZE_LIST = [ACTION_NUM] LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy() diff --git a/agent_ppo/conf/monitor_builder.py b/agent_ppo/conf/monitor_builder.py index 5cd685f..049c0ea 100644 --- a/agent_ppo/conf/monitor_builder.py +++ b/agent_ppo/conf/monitor_builder.py @@ -77,6 +77,26 @@ def build_monitor(): expr="avg(entropy_loss{})", ) .end_panel() + .add_panel( + name="近似KL", + name_en="approx_kl", + type="line", + ) + .add_metric( + metrics_name="approx_kl", + expr="avg(approx_kl{})", + ) + .end_panel() + .add_panel( + name="裁剪比例", + name_en="clip_fraction", + type="line", + ) + .add_metric( + metrics_name="clip_fraction", + expr="avg(clip_fraction{})", + ) + .end_panel() .end_group() .build() ) diff --git a/agent_ppo/feature/preprocessor.py b/agent_ppo/feature/preprocessor.py index ab97a93..31560b9 100644 --- a/agent_ppo/feature/preprocessor.py +++ b/agent_ppo/feature/preprocessor.py @@ -47,6 +47,11 @@ class Preprocessor: self.battery_max = 600 self.cur_pos = (0, 0) + self.prev_pos = None + self.has_position_history = False + self.current_visit_count = 0 + self.is_new_cell = False + self.last_action = -1 self.dirt_cleaned = 0 self.last_dirt_cleaned = 0 @@ -60,6 +65,7 @@ class Preprocessor: # 最近污渍距离 self.nearest_dirt_dist = 200.0 self.last_nearest_dirt_dist = 200.0 + self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16) self._view_map = np.zeros((21, 21), dtype=np.float32) self._legal_act = [1] * 8 @@ -74,8 +80,20 @@ class Preprocessor: env_info = observation["env_info"] hero = frame_state["heroes"] + self.last_action = int(last_action) self.step_no = int(observation["step_no"]) + self.prev_pos = self.cur_pos if self.has_position_history else None self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"])) + self.has_position_history = True + + hx, hz = self.cur_pos + if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE: + self.current_visit_count = int(self.visit_count_map[hx, hz]) + self.is_new_cell = self.current_visit_count == 0 + self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max) + else: + self.current_visit_count = 0 + self.is_new_cell = False # Battery / 电量 self.battery = int(hero["battery"]) @@ -238,9 +256,15 @@ class Preprocessor: local_view = self._get_local_view_feature() # 49D global_state = self._get_global_state_feature() # 12D legal_action = self.get_legal_action() # 8D - legal_arr = np.array(legal_action, dtype=np.float32) - feature = np.concatenate([local_view, global_state, legal_arr]) # 69D + last_action_feature = np.zeros(8, dtype=np.float32) + if 0 <= last_action < 8: + last_action_feature[last_action] = 1.0 + + # The legal action mask is passed separately to PPO. Reusing this 8D slot + # for action history makes the 69D observation more informative without + # breaking the framework's fixed tensor shape. + feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D reward = self.reward_process() @@ -249,9 +273,26 @@ class Preprocessor: def reward_process(self): # Cleaning reward / 清扫奖励 cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned) - cleaning_reward = 0.1 * cleaned_this_step + cleaning_reward = 0.25 * cleaned_this_step # Step penalty / 时间惩罚 - step_penalty = -0.001 + step_penalty = -0.002 - return cleaning_reward + step_penalty + # Dense guidance: prefer moving toward visible dirt. + # 稠密引导:鼓励向视野内污渍靠近。 + approach_reward = 0.0 + if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0: + dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0)) + approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta + + # Encourage covering new passable cells and mildly discourage loops. + # 鼓励探索新格子,轻微惩罚反复绕圈。 + exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5) + + # Collision/stuck signal: invalid moves waste both step and battery. + # 撞墙/原地不动会浪费步数和电量。 + stuck_penalty = 0.0 + if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8: + stuck_penalty = -0.03 + + return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty diff --git a/agent_ppo/workflow/train_workflow.py b/agent_ppo/workflow/train_workflow.py index 49a34fa..413c267 100644 --- a/agent_ppo/workflow/train_workflow.py +++ b/agent_ppo/workflow/train_workflow.py @@ -138,19 +138,22 @@ class EpisodeRunner: # Survived to max steps: higher cleaning ratio → more reward # 存活到最大步数:清扫比例越高奖励越多 cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1) - final_reward = 5.0 + 5.0 * cleaning_ratio + final_reward = 2.0 + 8.0 * cleaning_ratio result_str = "WIN" else: - # Early termination (battery depleted or collision): small penalty - # 提前结束(电量耗尽或碰撞):小惩罚 - final_reward = -2.0 + # Battery-depleted episodes are common with short runs; keep + # cleaning progress as the dominant terminal signal. + # 短训中电量耗尽较常见,终局奖励仍以清扫比例为主。 + cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1) + final_reward = -1.0 + 6.0 * cleaning_ratio result_str = "FAIL" self.logger.info( f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} " f"result:{result_str} final_bonus:{final_reward:.2f} " f"total_reward:{total_reward:.3f} " - f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt}" + f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} " + f"total_score:{total_score}" ) # Build sample frame