diff --git a/agent_ppo/agent.py b/agent_ppo/agent.py index d2def52..56a2621 100644 --- a/agent_ppo/agent.py +++ b/agent_ppo/agent.py @@ -55,9 +55,9 @@ class Agent(BaseAgent): self.last_reward = 0.0 def observation_process(self, env_obs): - """Convert raw env_obs to ObsData (69D feature + legal action mask). + """Convert raw env_obs to ObsData (feature + legal action mask). - 将原始 env_obs 转换为 ObsData(69D 特征 + 合法动作掩码)。 + 将原始 env_obs 转换为 ObsData(特征 + 合法动作掩码)。 """ feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action) self.last_reward = reward @@ -135,8 +135,16 @@ class Agent(BaseAgent): 加载模型检查点。 """ model_file_path = f"{path}/model.ckpt-{id}.pkl" - self.model.load_state_dict(torch.load(model_file_path, map_location=self.device)) - self.logger.info(f"load model {model_file_path} successfully") + state_dict = torch.load(model_file_path, map_location=self.device) + try: + self.model.load_state_dict(state_dict) + except RuntimeError as err: + msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}" + if self.logger: + self.logger.warning(msg) + return + if self.logger: + self.logger.info(f"load model {model_file_path} successfully") def _run_model(self, feature): """Gradient-free forward pass, returns (logits_np, value_np). diff --git a/agent_ppo/conf/conf.py b/agent_ppo/conf/conf.py index 5672d80..b291606 100644 --- a/agent_ppo/conf/conf.py +++ b/agent_ppo/conf/conf.py @@ -13,12 +13,12 @@ Configuration for Robot Vacuum PPO agent. 
class Config: - # Feature dimensions (69D) - # 特征维度(69D) + # Feature dimensions (157D) + # 特征维度(157D) FEATURES = [ - 7 * 7, - 12, - 8, + 11 * 11, # wider local map view / 更大的局部地图视野 + 28, # global, charger, NPC, and map-stat features / 全局、充电桩、NPC、地图统计特征 + 8, # last action one-hot / 上一步动作 one-hot ] FEATURE_SPLIT_SHAPE = FEATURES FEATURE_LEN = sum(FEATURES) diff --git a/agent_ppo/feature/definition.py b/agent_ppo/feature/definition.py index a137444..84438fe 100644 --- a/agent_ppo/feature/definition.py +++ b/agent_ppo/feature/definition.py @@ -33,7 +33,7 @@ ActData = create_cls( # 训练样本数据:字段值为 int 时框架自动按维度处理 SampleData = create_cls( "SampleData", - obs=Config.DIM_OF_OBSERVATION, # 69D feature vector / 特征向量 + obs=Config.DIM_OF_OBSERVATION, # feature vector / 特征向量 legal_action=Config.ACTION_NUM, # 8D legal action mask / 合法动作掩码 act=1, # action index / 执行的动作 reward=Config.VALUE_NUM, # 1D reward / 奖励 diff --git a/agent_ppo/feature/preprocessor.py b/agent_ppo/feature/preprocessor.py index 31560b9..9a51a29 100644 --- a/agent_ppo/feature/preprocessor.py +++ b/agent_ppo/feature/preprocessor.py @@ -24,6 +24,13 @@ def _norm(v, v_max, v_min=0.0): return (v - v_min) / (v_max - v_min) +def _signed_norm(v, v_max): + """Normalize signed value to [-1, 1].""" + if v_max <= 0: + return 0.0 + return float(np.clip(float(v) / float(v_max), -1.0, 1.0)) + + class Preprocessor: """Feature preprocessor for Robot Vacuum. 
@@ -32,7 +39,17 @@ class Preprocessor: GRID_SIZE = 128 VIEW_HALF = 10 # Full local view radius (21×21) / 完整局部视野半径 - LOCAL_HALF = 3 # Cropped view radius (7×7) / 裁剪后的视野半径 + LOCAL_HALF = 5 # Cropped view radius (11×11) / 裁剪后的视野半径 + ACTION_DIRS = ( + (1, 0), + (1, -1), + (0, -1), + (-1, -1), + (-1, 0), + (-1, 1), + (0, 1), + (1, 1), + ) def __init__(self): self.reset() @@ -56,6 +73,10 @@ class Preprocessor: self.dirt_cleaned = 0 self.last_dirt_cleaned = 0 self.total_dirt = 1 + self.total_score = 0 + self.clean_score = 0 + self.step_cleaned_count = 0 + self.max_step = 1000 # Global passable map (0=obstacle, 1=passable), used for ray computation # 维护全局通行地图(0=障碍, 1=可通行),用于射线计算 @@ -69,6 +90,33 @@ class Preprocessor: self._view_map = np.zeros((21, 21), dtype=np.float32) self._legal_act = [1] * 8 + self.terminated = False + self.truncated = False + + self.remaining_charge = 0 + self.charge_count = 0 + self.last_charge_count = 0 + self.charge_delta = 0 + self.nearest_charger_dx = 0.0 + self.nearest_charger_dz = 0.0 + self.nearest_charger_dist = float(self.GRID_SIZE) + self.nearest_charger_range_dist = float(self.GRID_SIZE) + self.last_nearest_charger_range_dist = float(self.GRID_SIZE) + self.battery_margin = 0.0 + self.has_charger = False + self.low_battery = False + self.on_charger = False + self.recharge_mode = False + self.recharge_steps = 0 + + self.nearest_npc_dx = 0.0 + self.nearest_npc_dz = 0.0 + self.nearest_npc_dist = float(self.GRID_SIZE) + self.npc_danger = False + self.npcs = [] + + self.local_dirt_ratio = 0.0 + self.local_obstacle_ratio = 0.0 def pb2struct(self, env_obs, last_action): """Parse and cache essential fields from observation dict. 
@@ -76,14 +124,20 @@ class Preprocessor: 从 env_obs 字典中提取并缓存所有需要的状态量。 """ observation = env_obs["observation"] - frame_state = observation["frame_state"] - env_info = observation["env_info"] - hero = frame_state["heroes"] + frame_state = observation.get("frame_state", {}) + extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {}) + env_info = observation.get("env_info", {}) + hero = frame_state.get("heroes", {}) + if isinstance(hero, list): + hero = hero[0] if hero else {} self.last_action = int(last_action) - self.step_no = int(observation["step_no"]) + self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no))) + self.terminated = bool(env_obs.get("terminated", False)) + self.truncated = bool(env_obs.get("truncated", False)) self.prev_pos = self.cur_pos if self.has_position_history else None - self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"])) + hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]} + self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1]))) self.has_position_history = True hx, hz = self.cur_pos @@ -96,16 +150,30 @@ class Preprocessor: self.is_new_cell = False # Battery / 电量 - self.battery = int(hero["battery"]) - self.battery_max = max(int(hero["battery_max"]), 1) + self.battery = int(hero.get("battery", env_info.get("remaining_charge", self.battery))) + self.battery_max = max(int(hero.get("battery_max", env_info.get("battery_max", self.battery_max))), 1) + self.remaining_charge = int(env_info.get("remaining_charge", self.battery)) + self.max_step = max(int(env_info.get("max_step", self.max_step)), 1) # Cleaning progress / 清扫进度 self.last_dirt_cleaned = self.dirt_cleaned - self.dirt_cleaned = int(hero["dirt_cleaned"]) - self.total_dirt = max(int(env_info["total_dirt"]), 1) + self.dirt_cleaned = int(hero.get("dirt_cleaned", env_info.get("clean_score", self.dirt_cleaned))) + self.total_dirt = 
max(int(env_info.get("total_dirt", self.total_dirt)), 1) + self.total_score = int(env_info.get("total_score", self.total_score)) + self.clean_score = int(env_info.get("clean_score", self.dirt_cleaned)) + step_cleaned_cells = env_info.get("step_cleaned_cells") or [] + self.step_cleaned_count = len(step_cleaned_cells) + + # Charge progress / 充电进度 + self.last_charge_count = self.charge_count + self.charge_count = int(env_info.get("charge_count", self.charge_count)) + self.charge_delta = max(0, self.charge_count - self.last_charge_count) # Legal actions / 合法动作 - self._legal_act = [int(x) for x in (observation.get("legal_action") or [1] * 8)] + raw_legal_act = observation.get("legal_action") or observation.get("legal_act") or [1] * 8 + self._legal_act = [int(x) for x in raw_legal_act[:8]] + if len(self._legal_act) < 8: + self._legal_act.extend([1] * (8 - len(self._legal_act))) # Local view map (21×21) / 局部视野地图 map_info = observation.get("map_info") @@ -113,6 +181,14 @@ class Preprocessor: self._view_map = np.array(map_info, dtype=np.float32) hx, hz = self.cur_pos self._update_passable(hx, hz) + self._update_local_map_stats() + + organs = frame_state.get("organs") or extra_frame_state.get("organs") or [] + npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or [] + self.npcs = list(npcs) + self._update_charger_state(hx, hz, organs) + self._update_npc_state(hx, hz, self.npcs) + self._update_recharge_mode() def _update_passable(self, hx, hz): """Write local view into global passable map. @@ -132,10 +208,130 @@ class Preprocessor: # 0 = 障碍, 1/2 = 可通行 self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0 - def _get_local_view_feature(self): - """Local view feature (49D): crop center 7×7 from 21×21. 
def _update_local_map_stats(self):
    """Cache coarse dirt/obstacle ratios of the cached local view.

    Reads ``self._view_map`` and stores the share of cells equal to 2 in
    ``self.local_dirt_ratio`` and the share of cells equal to 0 in
    ``self.local_obstacle_ratio``.  Both are reset to 0.0 when the view is
    missing or empty.
    """
    view = self._view_map
    if view is None or view.size == 0:
        self.local_dirt_ratio = 0.0
        self.local_obstacle_ratio = 0.0
        return
    total = float(view.size)
    self.local_dirt_ratio = float(np.sum(view == 2) / total)
    self.local_obstacle_ratio = float(np.sum(view == 0) / total)


def _update_charger_state(self, hx, hz, organs):
    """Find the nearest charger and cache distance/direction features.

    Args:
        hx, hz: current robot cell.
        organs: iterable of organ entries; non-dict entries and entries
            whose ``sub_type`` is not 1 are ignored (a missing ``sub_type``
            is treated as 1).

    Side effects:
        Saves the previous ``nearest_charger_range_dist`` into
        ``last_nearest_charger_range_dist``, then rewrites ``has_charger``,
        ``on_charger``, ``nearest_charger_dx/dz`` (signed offset towards the
        charger), ``nearest_charger_dist`` (Euclidean),
        ``nearest_charger_range_dist`` (Chebyshev) and ``battery_margin``.
    """
    # Keep last step's distance so the reward can measure progress.
    self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
    self.has_charger = False
    self.on_charger = False
    self.nearest_charger_dx = 0.0
    self.nearest_charger_dz = 0.0
    self.nearest_charger_dist = float(self.GRID_SIZE)
    self.nearest_charger_range_dist = float(self.GRID_SIZE)

    best = None  # (range_dist, dist, dx, dz) of the closest candidate so far
    for organ in organs:
        if not isinstance(organ, dict):
            continue
        if int(organ.get("sub_type", 1)) != 1:
            continue
        # NOTE(review): an organ without "pos" is placed at (0, 0) here.
        pos = organ.get("pos") or {}
        ox = int(pos.get("x", 0))
        oz = int(pos.get("z", 0))
        w = max(int(organ.get("w", 3)), 1)
        h = max(int(organ.get("h", 3)), 1)

        # Two candidate rectangles are tried: "pos" as the min corner and
        # "pos" as the centre -- presumably because the env convention is
        # not known here; confirm against the environment spec.
        for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)):
            dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h)
            dist = float(np.sqrt(dx * dx + dz * dz))
            range_dist = float(max(abs(dx), abs(dz)))
            # Smallest Chebyshev distance wins; Euclidean breaks ties.
            if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]):
                best = (range_dist, dist, dx, dz)

    if best is None:
        # No charger known: the margin is simply the remaining battery.
        self.battery_margin = float(self.battery)
        return

    range_dist, dist, dx, dz = best
    self.has_charger = True
    self.nearest_charger_dx = float(dx)
    self.nearest_charger_dz = float(dz)
    self.nearest_charger_dist = float(dist)
    self.nearest_charger_range_dist = float(range_dist)
    self.on_charger = range_dist <= 0.0
    self.battery_margin = float(self.battery) - self.nearest_charger_range_dist


def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
    """Return the signed offset from a point to the nearest rectangle cell.

    The rectangle covers cells ``rx .. rx + w - 1`` by ``rz .. rz + h - 1``.

    Args:
        x, z: query cell.
        rx, rz: minimum-corner cell of the rectangle.
        w, h: rectangle extent along x and z, in cells.

    Returns:
        ``(dx, dz)`` as floats.  Each component is positive when the
        rectangle lies towards larger coordinates, negative when it lies
        towards smaller coordinates, and 0.0 when the point is already
        inside the rectangle's span on that axis.

    The sign matters: callers feed the result into signed direction
    features and into the recharge action filter, which subtracts the action
    direction from this vector.  Returning magnitudes only (as before) made
    chargers on the -x / -z side look as if they were on the +x / +z side.
    """
    x_max = rx + w - 1
    z_max = rz + h - 1

    if x < rx:
        dx = rx - x
    elif x > x_max:
        dx = x_max - x  # negative: rectangle is towards -x
    else:
        dx = 0

    if z < rz:
        dz = rz - z
    elif z > z_max:
        dz = z_max - z  # negative: rectangle is towards -z
    else:
        dz = 0
    return float(dx), float(dz)


def _update_npc_state(self, hx, hz, npcs):
    """Find the nearest NPC and cache safety features.

    Writes ``nearest_npc_dx/dz`` (signed offset to the NPC),
    ``nearest_npc_dist`` (Chebyshev distance) and ``npc_danger`` (robot is
    within one cell of the nearest NPC on both axes).  When no usable NPC
    is present the defaults (0.0, 0.0, GRID_SIZE, False) are kept.

    NPC entries that are not dicts, or that lack an ``x``/``z`` position,
    are skipped.  This matches ``_is_npc_danger_cell``, whose -999 sentinel
    means position-less NPCs never match any cell; treating them as
    standing at (0, 0) would create a phantom NPC near the origin.
    """
    self.nearest_npc_dx = 0.0
    self.nearest_npc_dz = 0.0
    self.nearest_npc_dist = float(self.GRID_SIZE)
    self.npc_danger = False

    best = None  # (chebyshev, dx, dz) of the closest NPC so far
    for npc in npcs:
        if not isinstance(npc, dict):
            continue
        pos = npc.get("pos") or {}
        if "x" not in pos or "z" not in pos:
            continue
        dx = int(pos["x"]) - hx
        dz = int(pos["z"]) - hz
        cheb = float(max(abs(dx), abs(dz)))
        if best is None or cheb < best[0]:
            best = (cheb, dx, dz)

    if best is None:
        return

    cheb, dx, dz = best
    self.nearest_npc_dx = float(dx)
    self.nearest_npc_dz = float(dz)
    self.nearest_npc_dist = float(cheb)
    self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1


def _update_recharge_mode(self):
    """Enter or leave low-battery recharge mode.

    Reads the battery and charger state cached earlier in the step and
    updates ``low_battery``, ``recharge_mode`` and ``recharge_steps``.

    Rules, evaluated in order once a charger is known:
      1. A charge happened this step, or the robot is on the charger with
         more than 85% battery -> leave recharge mode.
      2. Battery is at most the charger distance plus 18, or below 22% ->
         enter recharge mode.
      3. Already in recharge mode and below 85% -> stay in it (hysteresis).
      4. Otherwise -> not in recharge mode.
    Without a known charger, recharge mode is always off.
    """
    battery_ratio = self.battery / max(self.battery_max, 1)
    self.low_battery = battery_ratio < 0.35

    if not self.has_charger:
        self.recharge_mode = False
        return

    if self.charge_delta > 0 or (self.on_charger and battery_ratio > 0.85):
        self.recharge_mode = False
    elif self.battery <= self.nearest_charger_range_dist + 18 or battery_ratio < 0.22:
        self.recharge_mode = True
    elif self.recharge_mode and battery_ratio < 0.85:
        self.recharge_mode = True
    else:
        self.recharge_mode = False

    # Count the steps spent in recharge mode (reported in the episode log).
    if self.recharge_mode:
        self.recharge_steps += 1
- 全局状态特征(12D)。 + 全局状态特征(28D)。 Dimensions / 维度说明: [0] step_norm step progress / 步数归一化 [0,1] @@ -160,8 +356,24 @@ class Preprocessor: [9] ray_W_dirt west ray distance / 向左(x-)方向 [10] nearest_dirt_norm nearest dirt Euclidean distance / 最近污渍欧氏距离归一化 [11] dirt_delta approaching dirt indicator / 是否在接近污渍(1=是, 0=否) + [12] charger_dx nearest charger x direction / 最近充电桩 x 相对方向 + [13] charger_dz nearest charger z direction / 最近充电桩 z 相对方向 + [14] charger_dist nearest charger distance / 最近充电桩距离 + [15] battery_margin battery minus charger distance / 电量安全余量 + [16] low_battery low-battery flag / 低电量标记 + [17] recharge_mode recharge-mode flag / 回充模式标记 + [18] on_charger on charger flag / 是否在充电桩范围 + [19] charge_delta charge count increased / 本步是否成功充电 + [20] npc_dx nearest NPC x direction / 最近 NPC x 相对方向 + [21] npc_dz nearest NPC z direction / 最近 NPC z 相对方向 + [22] npc_dist nearest NPC Chebyshev distance / 最近 NPC 切比雪夫距离 + [23] npc_danger in NPC 3x3 danger zone / 是否处于 NPC 3x3 危险区 + [24] local_dirt_ratio dirt ratio in 21x21 view / 21x21 视野污渍比例 + [25] obstacle_ratio obstacle ratio in 21x21 view / 21x21 视野障碍比例 + [26] visit_count current cell visit count / 当前格访问次数 + [27] step_cleaned cells cleaned this step / 本步清扫格子数 """ - step_norm = _norm(self.step_no, 2000) + step_norm = _norm(self.step_no, self.max_step) battery_ratio = _norm(self.battery, self.battery_max) cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt) remaining_dirt = 1.0 - cleaning_progress @@ -205,6 +417,10 @@ class Preprocessor: nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180) dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0 + charge_delta = 1.0 if self.charge_delta > 0 else 0.0 + battery_margin_norm = _signed_norm(self.battery_margin, self.battery_max) + visit_count_norm = _norm(min(self.current_visit_count, 10), 10) + step_cleaned_norm = _norm(self.step_cleaned_count, 9) return np.array( [ @@ -220,6 +436,22 @@ class Preprocessor: ray_dirt[3], nearest_dirt_norm, dirt_delta, + 
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE), + _signed_norm(self.nearest_charger_dz, self.GRID_SIZE), + _norm(self.nearest_charger_range_dist, self.GRID_SIZE), + battery_margin_norm, + 1.0 if self.low_battery else 0.0, + 1.0 if self.recharge_mode else 0.0, + 1.0 if self.on_charger else 0.0, + charge_delta, + _signed_norm(self.nearest_npc_dx, 20), + _signed_norm(self.nearest_npc_dz, 20), + _norm(self.nearest_npc_dist, 20), + 1.0 if self.npc_danger else 0.0, + self.local_dirt_ratio, + self.local_obstacle_ratio, + visit_count_norm, + step_cleaned_norm, ], dtype=np.float32, ) @@ -244,27 +476,119 @@ class Preprocessor: 返回合法动作掩码(8D list)。 """ - return list(self._legal_act) + legal = self._filter_blocked_actions(self._legal_act) + legal = self._filter_npc_danger_actions(legal) + if self.recharge_mode: + legal = self._filter_recharge_actions(legal) + return list(legal) + + def _filter_blocked_actions(self, legal_action): + """Filter actions that are visibly blocked in the 21x21 view.""" + legal = [int(x) for x in legal_action] + hx, hz = self.cur_pos + for action, (dx, dz) in enumerate(self.ACTION_DIRS): + if legal[action] <= 0: + continue + if not self._is_visible_cell_passable(dx, dz): + legal[action] = 0 + continue + if dx != 0 and dz != 0: + side_a = self._is_visible_cell_passable(dx, 0) + side_b = self._is_visible_cell_passable(0, dz) + if not (side_a or side_b): + legal[action] = 0 + nx, nz = hx + dx, hz + dz + if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE): + legal[action] = 0 + + return legal if any(legal) else [int(x) for x in legal_action] + + def _is_visible_cell_passable(self, dx, dz): + """Whether a relative 21x21-view cell is passable.""" + ri = self.VIEW_HALF + dx + ci = self.VIEW_HALF + dz + if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]): + return True + return int(self._view_map[ri, ci]) != 0 + + def _filter_npc_danger_actions(self, legal_action): + """Avoid actions that would enter any NPC 3x3 
danger zone.""" + if not self.npcs: + return list(legal_action) + + hx, hz = self.cur_pos + safe = [int(x) for x in legal_action] + for action, (dx, dz) in enumerate(self.ACTION_DIRS): + if safe[action] <= 0: + continue + nx, nz = hx + dx, hz + dz + if self._is_npc_danger_cell(nx, nz): + safe[action] = 0 + + return safe if any(safe) else list(legal_action) + + def _is_npc_danger_cell(self, x, z): + for npc in self.npcs: + if not isinstance(npc, dict): + continue + pos = npc.get("pos") or {} + nx = int(pos.get("x", -999)) + nz = int(pos.get("z", -999)) + if abs(x - nx) <= 1 and abs(z - nz) <= 1: + return True + return False + + def _filter_recharge_actions(self, legal_action): + """Restrict low-battery actions to moves that approach the charger.""" + if not self.has_charger: + return list(legal_action) + + hx, hz = self.cur_pos + current_dist = max(abs(self.nearest_charger_dx), abs(self.nearest_charger_dz)) + scored = [] + for action, (dx, dz) in enumerate(self.ACTION_DIRS): + if legal_action[action] <= 0: + continue + next_dx = self.nearest_charger_dx - dx + next_dz = self.nearest_charger_dz - dz + next_dist = max(abs(next_dx), abs(next_dz)) + improvement = current_dist - next_dist + alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz + scored.append((improvement, alignment, action)) + + if not scored: + return list(legal_action) + + best_improvement = max(item[0] for item in scored) + recharge = [0] * 8 + if best_improvement > 0: + for improvement, _, action in scored: + if improvement >= best_improvement - 0.1: + recharge[action] = 1 + else: + best_alignment = max(item[1] for item in scored) + for _, alignment, action in scored: + if alignment >= best_alignment: + recharge[action] = 1 + + return recharge if any(recharge) else list(legal_action) def feature_process(self, env_obs, last_action): - """Generate 69D feature vector, legal action mask, and scalar reward. + """Generate feature vector, legal action mask, and scalar reward. 
- 生成 69D 特征向量、合法动作掩码和标量奖励。 + 生成特征向量、合法动作掩码和标量奖励。 """ self.pb2struct(env_obs, last_action) - local_view = self._get_local_view_feature() # 49D - global_state = self._get_global_state_feature() # 12D + local_view = self._get_local_view_feature() # 121D + global_state = self._get_global_state_feature() # 28D legal_action = self.get_legal_action() # 8D last_action_feature = np.zeros(8, dtype=np.float32) if 0 <= last_action < 8: last_action_feature[last_action] = 1.0 - # The legal action mask is passed separately to PPO. Reusing this 8D slot - # for action history makes the 69D observation more informative without - # breaking the framework's fixed tensor shape. - feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D + feature = np.concatenate([local_view, global_state, last_action_feature]) reward = self.reward_process() @@ -273,7 +597,8 @@ class Preprocessor: def reward_process(self): # Cleaning reward / 清扫奖励 cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned) - cleaning_reward = 0.25 * cleaned_this_step + cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step + cleaning_reward = 0.7 * cleaned_cells # Step penalty / 时间惩罚 step_penalty = -0.002 @@ -281,13 +606,28 @@ class Preprocessor: # Dense guidance: prefer moving toward visible dirt. # 稠密引导:鼓励向视野内污渍靠近。 approach_reward = 0.0 - if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0: + if not self.recharge_mode and (self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0): dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0)) approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta + # Recharge guidance only activates when battery safety is the bottleneck. 
+ # 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。 + charge_reward = 0.0 + if self.charge_delta > 0: + charge_reward += 3.0 * self.charge_delta + if self.has_charger and (self.recharge_mode or self.low_battery): + dist_delta = float( + np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0) + ) + charge_reward += 0.04 * dist_delta if dist_delta > 0 else 0.02 * dist_delta + if self.battery_margin < 0: + charge_reward -= min(0.25, abs(self.battery_margin) / max(self.battery_max, 1)) + elif self.on_charger and self.charge_delta == 0 and self.battery / max(self.battery_max, 1) > 0.85: + charge_reward -= 0.01 + # Encourage covering new passable cells and mildly discourage loops. # 鼓励探索新格子,轻微惩罚反复绕圈。 - exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5) + exploration_reward = 0.004 if self.is_new_cell else -0.0015 * min(self.current_visit_count, 6) # Collision/stuck signal: invalid moves waste both step and battery. # 撞墙/原地不动会浪费步数和电量。 @@ -295,4 +635,26 @@ class Preprocessor: if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8: stuck_penalty = -0.03 - return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty + npc_penalty = 0.0 + if self.npc_danger: + npc_penalty -= 4.0 + elif self.nearest_npc_dist <= 3: + npc_penalty -= 0.05 * (4 - self.nearest_npc_dist) + + terminal_penalty = 0.0 + if self.terminated and not self.truncated: + if self.battery <= 0 or self.remaining_charge <= 0: + terminal_penalty -= 4.0 + elif self.npc_danger or self.nearest_npc_dist <= 1: + terminal_penalty -= 3.0 + + return ( + cleaning_reward + + approach_reward + + charge_reward + + exploration_reward + + stuck_penalty + + npc_penalty + + terminal_penalty + + step_penalty + ) diff --git a/agent_ppo/model/model.py b/agent_ppo/model/model.py index 0fb930b..337ef4e 100644 --- a/agent_ppo/model/model.py +++ b/agent_ppo/model/model.py @@ -38,22 +38,22 @@ class 
Model(nn.Module): self.model_name = "robot_vacuum" self.device = device - obs_dim = Config.DIM_OF_OBSERVATION # 69 + obs_dim = Config.DIM_OF_OBSERVATION # 157 act_num = Config.ACTION_NUM # 8 # Shared backbone / 共享骨干网络 self.backbone = nn.Sequential( - _make_fc(obs_dim, 128), + _make_fc(obs_dim, 256), nn.ReLU(), - _make_fc(128, 64), + _make_fc(256, 128), nn.ReLU(), ) # Actor head: outputs action logits / 策略头:输出动作 logits - self.actor_head = _make_fc(64, act_num, gain=0.01) + self.actor_head = _make_fc(128, act_num, gain=0.01) # Critic head: outputs single state value / 价值头:输出单个状态价值 - self.critic_head = _make_fc(64, 1, gain=0.01) + self.critic_head = _make_fc(128, 1, gain=0.01) def forward(self, s, inference=False): """Forward pass. diff --git a/agent_ppo/workflow/train_workflow.py b/agent_ppo/workflow/train_workflow.py index 413c267..4b399b6 100644 --- a/agent_ppo/workflow/train_workflow.py +++ b/agent_ppo/workflow/train_workflow.py @@ -132,7 +132,14 @@ class EpisodeRunner: final_reward = 0.0 if done: fm = self.agent.preprocessor - total_score = env_obs["observation"]["env_info"]["total_score"] + env_info = env_obs["observation"]["env_info"] + extra_info = env_obs.get("extra_info", {}) + total_score = env_info.get("total_score", fm.total_score) + remaining_charge = env_info.get("remaining_charge", fm.remaining_charge) + charge_count = env_info.get("charge_count", fm.charge_count) + finished_steps = env_info.get("finished_steps", step) + result_message = extra_info.get("result_message", "") + result_code = extra_info.get("result_code", "") if truncated: # Survived to max steps: higher cleaning ratio → more reward @@ -141,19 +148,31 @@ class EpisodeRunner: final_reward = 2.0 + 8.0 * cleaning_ratio result_str = "WIN" else: - # Battery-depleted episodes are common with short runs; keep - # cleaning progress as the dominant terminal signal. 
- # 短训中电量耗尽较常见,终局奖励仍以清扫比例为主。 cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1) - final_reward = -1.0 + 6.0 * cleaning_ratio - result_str = "FAIL" + if fm.battery <= 0 or remaining_charge <= 0: + final_reward = -4.0 + 6.0 * cleaning_ratio + result_str = "BATTERY_FAIL" + elif fm.npc_danger or fm.nearest_npc_dist <= 1: + final_reward = -3.0 + 6.0 * cleaning_ratio + result_str = "NPC_FAIL" + else: + final_reward = -2.0 + 6.0 * cleaning_ratio + result_str = "FAIL" self.logger.info( f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} " + f"finished_steps:{finished_steps} " f"result:{result_str} final_bonus:{final_reward:.2f} " f"total_reward:{total_reward:.3f} " f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} " - f"total_score:{total_score}" + f"total_score:{total_score} " + f"remaining_charge:{remaining_charge} " + f"charge_count:{charge_count} " + f"recharge_steps:{fm.recharge_steps} " + f"nearest_charger:{fm.nearest_charger_range_dist:.1f} " + f"nearest_npc:{fm.nearest_npc_dist:.1f} " + f"result_code:{result_code} " + f"result_message:{result_message}" ) # Build sample frame