From ba6cf2a797e6f80d2e8baf6c6de2a281512e866b Mon Sep 17 00:00:00 2001 From: gqt <3217233537@qq.com> Date: Sun, 26 Apr 2026 15:08:43 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3PPO=E5=85=85=E7=94=B5?= =?UTF-8?q?=E5=A5=96=E5=8A=B1=E9=98=B2=E6=AD=A2=E8=B9=B2=E6=A1=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_ppo/feature/preprocessor.py | 85 ++++++++++++++++++++++++++-- agent_ppo/workflow/train_workflow.py | 16 ++++-- 2 files changed, 89 insertions(+), 12 deletions(-) diff --git a/agent_ppo/feature/preprocessor.py b/agent_ppo/feature/preprocessor.py index 9a51a29..432574d 100644 --- a/agent_ppo/feature/preprocessor.py +++ b/agent_ppo/feature/preprocessor.py @@ -94,11 +94,18 @@ class Preprocessor: self.truncated = False self.remaining_charge = 0 + self.prev_battery = 600 + self.prev_battery_max = 600 + self.prev_on_charger = False + self.prev_low_battery = False + self.was_recharge_mode = False self.charge_count = 0 self.last_charge_count = 0 self.charge_delta = 0 self.nearest_charger_dx = 0.0 self.nearest_charger_dz = 0.0 + self.nearest_charger_center_dx = 0.0 + self.nearest_charger_center_dz = 0.0 self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE) self.last_nearest_charger_range_dist = float(self.GRID_SIZE) @@ -106,6 +113,7 @@ class Preprocessor: self.has_charger = False self.low_battery = False self.on_charger = False + self.charger_rects = [] self.recharge_mode = False self.recharge_steps = 0 @@ -135,6 +143,11 @@ class Preprocessor: self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no))) self.terminated = bool(env_obs.get("terminated", False)) self.truncated = bool(env_obs.get("truncated", False)) + self.prev_battery = self.battery + self.prev_battery_max = self.battery_max + self.prev_on_charger = self.on_charger + self.prev_low_battery = self.low_battery + self.was_recharge_mode = self.recharge_mode self.prev_pos = self.cur_pos if self.has_position_history else None hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]} self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1]))) @@ -226,8 +239,11 @@ class Preprocessor: self.on_charger = False self.nearest_charger_dx = 0.0 self.nearest_charger_dz = 0.0 + self.nearest_charger_center_dx = 0.0 + self.nearest_charger_center_dz = 0.0 self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE) + self.charger_rects = [] best = None for organ in organs: @@ -242,20 +258,25 @@ class Preprocessor: h = max(int(organ.get("h", 3)), 1) for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)): + self.charger_rects.append((rx, rz, w, h)) dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h) dist = float(np.sqrt(dx * dx + dz * dz)) range_dist = float(max(abs(dx), abs(dz))) + center_dx = (rx + (w - 1) * 0.5) - hx + center_dz = (rz + (h - 1) * 0.5) - hz if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]): - best = (range_dist, dist, dx, dz) + best = (range_dist, dist, dx, dz, center_dx, center_dz) if best is None: self.battery_margin = float(self.battery) return - range_dist, dist, dx, dz = best + range_dist, dist, dx, dz, center_dx, center_dz = best self.has_charger = True self.nearest_charger_dx = float(dx) self.nearest_charger_dz = float(dz) + self.nearest_charger_center_dx = float(center_dx) + self.nearest_charger_center_dz = float(center_dz) self.nearest_charger_dist = float(dist) self.nearest_charger_range_dist = float(range_dist) self.on_charger = range_dist <= 0.0 @@ -328,6 +349,15 @@ class Preprocessor: if self.recharge_mode: self.recharge_steps += 1 + def _min_charger_range_dist(self, x, z): + if not self.charger_rects: + return float(self.GRID_SIZE) + dists = [] + for rx, rz, w, h in self.charger_rects: + dx, dz = self._relative_vector_to_rect(x, z, rx, rz, w, h) + dists.append(max(abs(dx), abs(dz))) + return float(min(dists)) + def _get_local_view_feature(self): """Local view feature (121D): crop center 11×11 from 21×21. @@ -480,6 +510,8 @@ class Preprocessor: legal = self._filter_npc_danger_actions(legal) if self.recharge_mode: legal = self._filter_recharge_actions(legal) + elif self.on_charger and self.battery / max(self.battery_max, 1) > 0.65: + legal = self._filter_leave_charger_actions(legal) return list(legal) def _filter_blocked_actions(self, legal_action): @@ -573,6 +605,39 @@ class Preprocessor: return recharge if any(recharge) else list(legal_action) + def _filter_leave_charger_actions(self, legal_action): + """Prefer moves that leave charger range when battery is healthy.""" + if not self.has_charger: + return list(legal_action) + + hx, hz = self.cur_pos + current_dist = self._min_charger_range_dist(hx, hz) + scored = [] + for action, (dx, dz) in enumerate(self.ACTION_DIRS): + if legal_action[action] <= 0: + continue + nx, nz = hx + dx, hz + dz + next_dist = self._min_charger_range_dist(nx, nz) + away_score = -(dx * self.nearest_charger_center_dx + dz * self.nearest_charger_center_dz) + scored.append((next_dist - current_dist, away_score, action)) + + if not scored: + return list(legal_action) + + best_escape = max(item[0] for item in scored) + leave = [0] * 8 + if best_escape > 0: + for escape, _, action in scored: + if escape >= best_escape - 0.1: + leave[action] = 1 + else: + best_away = max(item[1] for item in scored) + for _, away_score, action in scored: + if away_score >= best_away: + leave[action] = 1 + + return leave if any(leave) else list(legal_action) + def feature_process(self, env_obs, last_action): """Generate feature vector, legal action mask, and scalar reward. @@ -613,8 +678,16 @@ class Preprocessor: # Recharge guidance only activates when battery safety is the bottleneck. # 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。 charge_reward = 0.0 - if self.charge_delta > 0: - charge_reward += 3.0 * self.charge_delta + battery_ratio = self.battery / max(self.battery_max, 1) + prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1) + useful_charge = self.charge_delta > 0 and ( + self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.45 + ) + if useful_charge: + charge_reward += 1.0 + elif self.charge_delta > 0 and battery_ratio > 0.65: + charge_reward -= 0.25 * min(self.charge_delta, 3) + if self.has_charger and (self.recharge_mode or self.low_battery): dist_delta = float( np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0) @@ -622,8 +695,8 @@ class Preprocessor: charge_reward += 0.04 * dist_delta if dist_delta > 0 else 0.02 * dist_delta if self.battery_margin < 0: charge_reward -= min(0.25, abs(self.battery_margin) / max(self.battery_max, 1)) - elif self.on_charger and self.charge_delta == 0 and self.battery / max(self.battery_max, 1) > 0.85: - charge_reward -= 0.01 + elif self.on_charger and battery_ratio > 0.65: + charge_reward -= 0.08 # Encourage covering new passable cells and mildly discourage loops. # 鼓励探索新格子,轻微惩罚反复绕圈。 diff --git a/agent_ppo/workflow/train_workflow.py b/agent_ppo/workflow/train_workflow.py index 4b399b6..dec00fa 100644 --- a/agent_ppo/workflow/train_workflow.py +++ b/agent_ppo/workflow/train_workflow.py @@ -140,15 +140,19 @@ class EpisodeRunner: finished_steps = env_info.get("finished_steps", step) result_message = extra_info.get("result_message", "") result_code = extra_info.get("result_code", "") + cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1) + score_per_step = total_score / max(finished_steps, 1) if truncated: - # Survived to max steps: higher cleaning ratio → more reward - # 存活到最大步数:清扫比例越高奖励越多 - cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1) - final_reward = 2.0 + 8.0 * cleaning_ratio - result_str = "WIN" + if score_per_step < 0.25: + final_reward = -3.0 + 6.0 * cleaning_ratio + result_str = "STALL_TRUNCATED" + else: + # Survived to max steps: higher cleaning ratio → more reward + # 存活到最大步数:清扫比例越高奖励越多 + final_reward = 2.0 + 8.0 * cleaning_ratio + result_str = "WIN" else: - cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1) if fm.battery <= 0 or remaining_charge <= 0: final_reward = -4.0 + 6.0 * cleaning_ratio result_str = "BATTERY_FAIL"