#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors

Feature preprocessor for Robot Vacuum.
"""

import os
from collections import deque

import numpy as np


def _norm(v, v_max, v_min=0.0):
    """Linearly normalize ``v`` from ``[v_min, v_max]`` to ``[0, 1]``.

    ``v`` is clipped into the range first, so out-of-range inputs saturate
    at 0.0 or 1.0. A degenerate range (``v_max == v_min``) returns 0.0
    instead of dividing by zero. Callers are expected to pass
    ``v_max >= v_min``.
    """
    if v_max == v_min:
        return 0.0
    v = float(np.clip(v, v_min, v_max))
    return (v - v_min) / (v_max - v_min)


def _signed_norm(v, v_max):
    """Scale a signed value by ``v_max`` and clip the result to ``[-1, 1]``.

    A non-positive ``v_max`` returns 0.0.
    """
    if v_max <= 0:
        return 0.0
    return float(np.clip(float(v) / float(v_max), -1.0, 1.0))


def _as_dict(value):
    """Return ``value`` if it is a dict, otherwise an empty dict.

    Lets optional nested observation fields be chained with ``.get`` without
    guarding against ``None`` or non-dict payloads at every level.
    """
    return value if isinstance(value, dict) else {}


class Preprocessor:
    """Feature preprocessor for Robot Vacuum.

    Accumulates global belief maps from successive local views and derives
    the observation features and recharge state used by the agent.
    """

    GRID_SIZE = 128  # side length of the global belief maps, indexed [x, z]
    VIEW_HALF = 10  # radius of the square local view
    VIEW_SIZE = 2 * VIEW_HALF + 1  # side length of the local view (21)
    MAP_CHANNELS = 6  # number of channels in the local-view feature
    # Minimum step gap before path caches flagged dirty are recomputed.
    PLANNER_UPDATE_INTERVAL = 4
    # Movement delta (dx, dz) for each of the 8 actions, indexed by action id.
    ACTION_DIRS = (
        (1, 0),
        (1, -1),
        (0, -1),
        (-1, -1),
        (-1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
    )
    INF_DIST = 1e6  # sentinel distance for unreachable cells
    # Absolute-position features are squeezed toward 0.5 by this factor.
    ABS_POS_FEATURE_SCALE = 0.2
    # Battery capacity assumed before the first observation arrives.
    DEFAULT_BATTERY_MAX = 600

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all internal state at the start of an episode.

        Re-initializes the progress counters, global belief maps, planner
        caches, charger/recharge state, and NPC tracking, then clears the
        diagnostic counters via ``self._reset_diagnostics()``.
        """
        self.map_id = -1
        self.step_no = 0
        self.battery = self.DEFAULT_BATTERY_MAX
        self.battery_max = self.DEFAULT_BATTERY_MAX
        self.cur_pos = (0, 0)
        self.prev_pos = None
        self.has_position_history = False
        self.current_visit_count = 0
        self.is_new_cell = False
        self.last_action = -1
        self.stuck_steps = 0
        self.recharge_no_progress_steps = 0
        self.fake_charger_steps = 0
        self.stuck_count = 0
        self.max_stuck_steps = 0
        self.recharge_escape_count = 0
        self.dirt_cleaned = 0
        self.last_dirt_cleaned = 0
        self.total_dirt = 1
        self.total_score = 0
        self.clean_score = 0
        self.step_cleaned_count = 0
        self.max_step = 1000

        # Global belief maps indexed by [x, z]; -1 in known_map means unseen.
        self.known_map = np.full((self.GRID_SIZE, self.GRID_SIZE), -1, dtype=np.int8)
        self.passable_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
        self.frontier_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
        self.dirty_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
        self._dirty_reverse_dist = None
        self._frontier_reverse_dist = None
        self._charger_reverse_dist = None
        self._path_cache_dirty = True
        # Negative start so the refresh-interval check already passes at step 0.
        self._planner_last_update_step = -self.PLANNER_UPDATE_INTERVAL
        self.known_ratio = 0.0
        self.known_passable_ratio = 0.0
        self.known_dirty_ratio = 0.0
        self.frontier_ratio = 0.0
        self.global_dirty_path_dist = float(self.GRID_SIZE)
        self.last_global_dirty_path_dist = float(self.GRID_SIZE)
        self.frontier_path_dist = float(self.GRID_SIZE)
        self.last_frontier_path_dist = float(self.GRID_SIZE)
        self.global_dirty_action_delta = np.zeros(8, dtype=np.float32)
        self.frontier_action_delta = np.zeros(8, dtype=np.float32)
        self.charger_action_delta = np.zeros(8, dtype=np.float32)
        self.charger_route_known = False
        self.charger_route_source = "none"

        # Nearest dirt path distance in the current local view; 200.0 is the
        # "no reachable dirt" sentinel.
        self.nearest_dirt_dist = 200.0
        self.last_nearest_dirt_dist = 200.0
        self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
        # Local view buffer; sized from VIEW_SIZE instead of a literal 21.
        self._view_map = np.zeros((self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
        self._legal_act = [1] * 8
        self.terminated = False
        self.truncated = False
        self.remaining_charge = 0

        # Previous-step snapshots used for deltas.
        self.prev_battery = self.DEFAULT_BATTERY_MAX
        self.prev_battery_max = self.DEFAULT_BATTERY_MAX
        self.prev_on_charger = False
        self.prev_low_battery = False
        self.was_recharge_mode = False
        self.charge_count = 0
        self.last_charge_count = 0
        self.charge_delta = 0

        # Charger geometry, route, and recharge-mode thresholds.
        self.nearest_charger_dx = 0.0
        self.nearest_charger_dz = 0.0
        self.nearest_charger_center_dx = 0.0
        self.nearest_charger_center_dz = 0.0
        self.nearest_charger_dist = float(self.GRID_SIZE)
        self.nearest_charger_range_dist = float(self.GRID_SIZE)
        self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
        self.nearest_charger_path_dist = float(self.GRID_SIZE)
        self.last_nearest_charger_path_dist = float(self.GRID_SIZE)
        self.charger_energy_cost = float(self.GRID_SIZE)
        self.charger_safety_buffer = 0.0
        self.charger_safety_margin = 0.0
        self.recharge_enter_margin = 0.0
        self.recharge_leave_margin = 0.0
        self.recharge_low_battery_ratio = 0.28
        self.full_charge_leave_ratio = 0.96
        self.battery_margin = 0.0
        self.has_charger = False
        self.low_battery = False
        self.on_charger = False
        self.charger_rects = []
        self.recharge_mode = False
        self.recharge_steps = 0

        # NPC tracking and safety flags.
        self.nearest_npc_dx = 0.0
        self.nearest_npc_dz = 0.0
        self.nearest_npc_vx = 0.0
        self.nearest_npc_vz = 0.0
        self.nearest_npc_dist = float(self.GRID_SIZE)
        self.predicted_npc_dist = float(self.GRID_SIZE)
        self.npc_danger = False
        self.npc_predicted_danger = False
        self.npcs = []
        self.prev_npc_positions = {}
        self.predicted_npcs = []
        self.npc_close_steps = 0
        self.npc_danger_steps = 0
        self.npc_collision = 0
        self.battery_fail = 0

        self.local_dirt_ratio = 0.0
        self.local_obstacle_ratio = 0.0
        # Reward profile selector; empty or whitespace-only values fall back to "current".
        self.reward_profile = os.environ.get("ROBOT_VACUUM_REWARD_PROFILE", "current").strip().lower() or "current"
        self._reset_diagnostics()
def _reset_diagnostics(self): """Reset episode-local diagnostic counters.""" self.diag_mask_steps = 0 self.diag_mask_count_sums = { "raw": 0, "blocked": 0, "npc": 0, "recharge": 0, "escape": 0, "leave": 0, "final": 0, } self.diag_mask_changed_steps = { "blocked": 0, "npc": 0, "recharge": 0, "escape": 0, "leave": 0, } self.diag_mask_active_steps = { "recharge": 0, "leave": 0, } self.diag_one_action_steps = 0 self.diag_two_or_less_action_steps = 0 self.diag_zero_final_steps = 0 self.diag_action_hist = [0] * 8 def pb2struct(self, env_obs, last_action): """Parse and cache essential fields from observation dict. 从 env_obs 字典中提取并缓存所有需要的状态量。 """ env_obs = _as_dict(env_obs) observation = _as_dict(env_obs.get("observation")) frame_state = _as_dict(observation.get("frame_state")) extra_info = _as_dict(env_obs.get("extra_info")) extra_frame_state = _as_dict(extra_info.get("frame_state")) env_info = _as_dict(observation.get("env_info")) hero = frame_state.get("heroes") or {} if isinstance(hero, list): hero = hero[0] if hero else {} hero = _as_dict(hero) self.last_action = int(last_action) map_id_value = extra_info.get("map_id", env_info.get("map_id", self.map_id)) try: self.map_id = int(map_id_value) except (TypeError, ValueError): pass self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no))) self.terminated = bool(env_obs.get("terminated", False)) self.truncated = bool(env_obs.get("truncated", False)) self.prev_battery = self.battery self.prev_battery_max = self.battery_max self.prev_on_charger = self.on_charger self.prev_low_battery = self.low_battery self.was_recharge_mode = self.recharge_mode self.prev_pos = self.cur_pos if self.has_position_history else None hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]}) self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1]))) self.has_position_history = True hx, hz = self.cur_pos if 0 <= hx < self.GRID_SIZE 
and 0 <= hz < self.GRID_SIZE: self.current_visit_count = int(self.visit_count_map[hx, hz]) self.is_new_cell = self.current_visit_count == 0 self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max) else: self.current_visit_count = 0 self.is_new_cell = False # Battery / 电量 self.battery = int(hero.get("battery", env_info.get("remaining_charge", self.battery))) self.battery_max = max(int(hero.get("battery_max", env_info.get("battery_max", self.battery_max))), 1) self.remaining_charge = int(env_info.get("remaining_charge", self.battery)) self.max_step = max(int(env_info.get("max_step", self.max_step)), 1) # Cleaning progress / 清扫进度 self.last_dirt_cleaned = self.dirt_cleaned self.dirt_cleaned = int(hero.get("dirt_cleaned", env_info.get("clean_score", self.dirt_cleaned))) self.total_dirt = max(int(env_info.get("total_dirt", self.total_dirt)), 1) self.total_score = int(env_info.get("total_score", self.total_score)) self.clean_score = int(env_info.get("clean_score", self.dirt_cleaned)) step_cleaned_cells = env_info.get("step_cleaned_cells") or [] self.step_cleaned_count = len(step_cleaned_cells) # Charge progress / 充电进度 self.last_charge_count = self.charge_count self.charge_count = int(env_info.get("charge_count", self.charge_count)) self.charge_delta = max(0, self.charge_count - self.last_charge_count) # Legal actions / 合法动作 raw_legal_act = observation.get("legal_action") or observation.get("legal_act") or [1] * 8 self._legal_act = [int(x) for x in raw_legal_act[:8]] if len(self._legal_act) < 8: self._legal_act.extend([1] * (8 - len(self._legal_act))) # Local view map (21×21) / 局部视野地图 map_info = observation.get("map_info") if map_info is not None: self._view_map = np.array(map_info, dtype=np.float32) hx, hz = self.cur_pos self._update_passable(hx, hz) self._mark_cleaned_cells(step_cleaned_cells) self._update_local_map_stats() organs = frame_state.get("organs") or extra_frame_state.get("organs") or [] npcs = frame_state.get("npcs") or 
extra_frame_state.get("npcs") or [] organs = organs if isinstance(organs, (list, tuple)) else [] self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else [] self._update_charger_state(hx, hz, organs) self._update_npc_state(hx, hz, self.npcs) self._update_global_planning_state() self._update_recharge_mode() self._update_motion_health() def _update_passable(self, hx, hz): """Write local view into global passable map. 将局部视野写入全局通行地图。 """ view = self._view_map vsize = view.shape[0] half = vsize // 2 for ri in range(vsize): for ci in range(vsize): gx = hx + ci - half gz = hz + ri - half if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE: cell = int(view[ri, ci]) self.known_map[gx, gz] = cell self.passable_map[gx, gz] = 1 if cell != 0 else 0 self.dirty_map[gx, gz] = 1 if cell == 2 else 0 if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE: self.known_map[hx, hz] = 1 self.passable_map[hx, hz] = 1 self.dirty_map[hx, hz] = 0 self._clear_path_caches() def _mark_cleaned_cells(self, step_cleaned_cells): """Mark cells cleaned in the current step in the global belief map.""" for pos in step_cleaned_cells or []: pos = _as_dict(pos) x = int(pos.get("x", -1)) z = int(pos.get("z", -1)) if 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE: self.known_map[x, z] = 1 self.passable_map[x, z] = 1 self.dirty_map[x, z] = 0 self._clear_path_caches() def _clear_path_caches(self): self._path_cache_dirty = True def _drop_path_caches(self): self._dirty_reverse_dist = None self._frontier_reverse_dist = None self._charger_reverse_dist = None def _view_index_to_global(self, ri, ci): """Convert local view row/col to global x/z coordinates.""" half = self._view_map.shape[0] // 2 hx, hz = self.cur_pos return hx + ci - half, hz + ri - half def _view_delta_to_index(self, dx, dz): """Convert global-coordinate dx/dz to local view row/col.""" return self.VIEW_HALF + dz, self.VIEW_HALF + dx def _view_cell(self, dx, dz, default=None): """Read a local-view cell by global-coordinate delta. 
`map_info` is row-major: row follows z/down, col follows x/right. """ ri, ci = self._view_delta_to_index(dx, dz) if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]): return default return int(self._view_map[ri, ci]) def _update_local_map_stats(self): """Cache coarse 21x21 map statistics.""" view = self._view_map if view is None or view.size == 0: self.local_dirt_ratio = 0.0 self.local_obstacle_ratio = 0.0 return total = float(view.size) self.local_dirt_ratio = float(np.sum(view == 2) / total) self.local_obstacle_ratio = float(np.sum(view == 0) / total) def _update_global_planning_state(self): """Refresh global coverage, frontier, and action-improvement features.""" self.last_global_dirty_path_dist = self.global_dirty_path_dist self.last_frontier_path_dist = self.frontier_path_dist self._update_frontier_map() hx, hz = self.cur_pos should_refresh_paths = ( self._dirty_reverse_dist is None or self._frontier_reverse_dist is None or (self.has_charger and self._charger_reverse_dist is None) or ( self._path_cache_dirty and self.step_no - self._planner_last_update_step >= self.PLANNER_UPDATE_INTERVAL ) ) if should_refresh_paths: self._drop_path_caches() self._planner_last_update_step = self.step_no self._path_cache_dirty = False known_count = float(np.sum(self.known_map >= 0)) passable_count = float(np.sum(self.passable_map > 0)) dirty_count = float(np.sum(self.dirty_map > 0)) frontier_count = float(np.sum(self.frontier_map > 0)) total_cells = float(self.GRID_SIZE * self.GRID_SIZE) self.known_ratio = known_count / total_cells self.known_passable_ratio = passable_count / total_cells self.known_dirty_ratio = dirty_count / max(float(self.total_dirt), 1.0) self.frontier_ratio = frontier_count / max(passable_count, 1.0) dirty_dist = self._get_dirty_reverse_dist() frontier_dist = self._get_frontier_reverse_dist() charger_dist = self._get_charger_reverse_dist() self.global_dirty_path_dist = self._dist_at(dirty_dist, hx, hz, 
default=float(self.GRID_SIZE)) self.frontier_path_dist = self._dist_at(frontier_dist, hx, hz, default=float(self.GRID_SIZE)) if charger_dist is not None: charger_path = self._dist_at(charger_dist, hx, hz, default=self.INF_DIST) if charger_path < self.INF_DIST: self.nearest_charger_path_dist = min(self.nearest_charger_path_dist, float(charger_path)) self.charger_energy_cost = self.nearest_charger_path_dist self.battery_margin = float(self.battery) - self.nearest_charger_path_dist self.charger_route_known = True self.charger_route_source = "global" self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist) self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist) current_charger = self._dist_at(charger_dist, hx, hz, default=self.nearest_charger_path_dist) self.charger_action_delta = self._action_distance_delta(charger_dist, current_charger) def _update_frontier_map(self): """Mark known passable cells adjacent to unseen space as exploration frontiers.""" self.frontier_map.fill(0) passable_coords = np.argwhere(self.passable_map > 0) for x, z in passable_coords: x = int(x) z = int(z) for dx, dz in ((1, 0), (-1, 0), (0, 1), (0, -1)): nx, nz = x + dx, z + dz if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE and self.known_map[nx, nz] < 0: self.frontier_map[x, z] = 1 break def _get_dirty_reverse_dist(self): if self._dirty_reverse_dist is None: targets = np.argwhere((self.dirty_map > 0) & (self.passable_map > 0)) self._dirty_reverse_dist = self._global_bfs_from_targets(targets) return self._dirty_reverse_dist def _get_frontier_reverse_dist(self): if self._frontier_reverse_dist is None: targets = np.argwhere((self.frontier_map > 0) & (self.passable_map > 0)) self._frontier_reverse_dist = self._global_bfs_from_targets(targets) return self._frontier_reverse_dist def _get_charger_reverse_dist(self): if not self.charger_rects: return None if self._charger_reverse_dist is None: 
self._charger_reverse_dist = self._global_bfs_from_targets(self._charger_target_cells()) return self._charger_reverse_dist def _charger_target_cells(self): targets = [] for rx, rz, w, h in self.charger_rects: for x in range(rx, rx + w): for z in range(rz, rz + h): if self._is_known_passable(x, z): targets.append((x, z)) return targets def _global_bfs_from_targets(self, targets): """Reverse BFS over the accumulated known passable map.""" dist = np.full((self.GRID_SIZE, self.GRID_SIZE), self.INF_DIST, dtype=np.float32) queue = deque() for target in targets: if len(target) < 2: continue x = int(target[0]) z = int(target[1]) if not self._is_known_passable(x, z) or dist[x, z] == 0.0: continue dist[x, z] = 0.0 queue.append((x, z)) while queue: x, z = queue.popleft() base = dist[x, z] for dx, dz in self.ACTION_DIRS: nx, nz = x + dx, z + dz if not self._can_global_move(x, z, dx, dz): continue if dist[nx, nz] < self.INF_DIST: continue dist[nx, nz] = base + 1.0 queue.append((nx, nz)) return dist def _is_known_passable(self, x, z): return 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE and self.passable_map[x, z] > 0 def _can_global_move(self, x, z, dx, dz): nx, nz = x + dx, z + dz if not self._is_known_passable(x, z) or not self._is_known_passable(nx, nz): return False if dx != 0 and dz != 0: return self._is_known_passable(x + dx, z) or self._is_known_passable(x, z + dz) return True def _dist_at(self, dist, x, z, default=None): if default is None: default = self.INF_DIST if dist is None or not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE): return float(default) value = float(dist[x, z]) return value if value < self.INF_DIST else float(default) def _action_distance_delta(self, dist, current_dist): delta = np.zeros(8, dtype=np.float32) if dist is None or current_dist >= self.INF_DIST: return delta hx, hz = self.cur_pos for action, (dx, dz) in enumerate(self.ACTION_DIRS): nx, nz = hx + dx, hz + dz if not self._can_global_move(hx, hz, dx, dz): continue next_dist = 
self._dist_at(dist, nx, nz, default=self.INF_DIST) if next_dist >= self.INF_DIST: continue delta[action] = np.float32(np.clip((current_dist - next_dist) / 4.0, -1.0, 1.0)) return delta def _update_charger_state(self, hx, hz, organs): """Find nearest charger and cache distance/direction features.""" self.last_nearest_charger_range_dist = self.nearest_charger_range_dist self.last_nearest_charger_path_dist = self.nearest_charger_path_dist self.has_charger = False self.on_charger = False self.nearest_charger_dx = 0.0 self.nearest_charger_dz = 0.0 self.nearest_charger_center_dx = 0.0 self.nearest_charger_center_dz = 0.0 self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE) self.nearest_charger_path_dist = float(self.GRID_SIZE) self.charger_energy_cost = float(self.GRID_SIZE) self.charger_safety_buffer = 0.0 self.charger_safety_margin = 0.0 self.charger_rects = [] self.charger_route_known = False self.charger_route_source = "none" best = None for organ in organs: if not isinstance(organ, dict): continue if int(organ.get("sub_type", 1)) != 1: continue pos = organ.get("pos") or {} ox = int(pos.get("x", 0)) oz = int(pos.get("z", 0)) w = max(int(organ.get("w", 3)), 1) h = max(int(organ.get("h", 3)), 1) for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)): self.charger_rects.append((rx, rz, w, h)) dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h) dist = float(np.sqrt(dx * dx + dz * dz)) range_dist = float(max(abs(dx), abs(dz))) center_dx = (rx + (w - 1) * 0.5) - hx center_dz = (rz + (h - 1) * 0.5) - hz if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]): best = (range_dist, dist, dx, dz, center_dx, center_dz) if best is None: self.battery_margin = float(self.battery) self.charger_safety_margin = float(self.battery) return range_dist, dist, dx, dz, center_dx, center_dz = best self.has_charger = True self.nearest_charger_dx = float(dx) self.nearest_charger_dz = float(dz) 
self.nearest_charger_center_dx = float(center_dx) self.nearest_charger_center_dz = float(center_dz) self.nearest_charger_dist = float(dist) self.nearest_charger_range_dist = float(range_dist) path_dist = self._global_path_dist_to_charger(hx, hz) if path_dist < self.INF_DIST: self.charger_route_known = True self.charger_route_source = "global" else: path_dist = self._local_path_dist_to_charger(hx, hz) if path_dist < self.INF_DIST: self.charger_route_known = True self.charger_route_source = "local" else: self.charger_route_known = False self.charger_route_source = "range" self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist) self.charger_energy_cost = self.nearest_charger_path_dist self.on_charger = range_dist <= 0.0 self.battery_margin = float(self.battery) - self.nearest_charger_path_dist def _relative_vector_to_rect(self, x, z, rx, rz, w, h): """Relative vector from point to the nearest cell in a rectangle.""" if x < rx: dx = rx - x elif x > rx + w - 1: dx = x - (rx + w - 1) else: dx = 0 if z < rz: dz = rz - z elif z > rz + h - 1: dz = z - (rz + h - 1) else: dz = 0 return float(dx), float(dz) def _update_npc_state(self, hx, hz, npcs): """Find nearest NPC and cache safety features.""" self.nearest_npc_dx = 0.0 self.nearest_npc_dz = 0.0 self.nearest_npc_vx = 0.0 self.nearest_npc_vz = 0.0 self.nearest_npc_dist = float(self.GRID_SIZE) self.predicted_npc_dist = float(self.GRID_SIZE) self.npc_danger = False self.npc_predicted_danger = False self.predicted_npcs = [] best = None current_positions = {} for npc in npcs: if not isinstance(npc, dict): continue pos = npc.get("pos") or {} nx = int(pos.get("x", 0)) nz = int(pos.get("z", 0)) npc_key = str(npc.get("npc_id", npc.get("idx", len(current_positions)))) prev_pos = self.prev_npc_positions.get(npc_key) vx = 0 vz = 0 if prev_pos is not None: vx = int(np.clip(nx - prev_pos[0], -1, 1)) vz = int(np.clip(nz - prev_pos[1], -1, 1)) px = int(np.clip(nx + vx, 0, self.GRID_SIZE - 1)) pz = 
int(np.clip(nz + vz, 0, self.GRID_SIZE - 1)) current_positions[npc_key] = (nx, nz) self.predicted_npcs.append((px, pz, 1)) dx = nx - hx dz = nz - hz cheb = float(max(abs(dx), abs(dz))) pred_cheb = float(max(abs(px - hx), abs(pz - hz))) if best is None or cheb < best[0]: best = (cheb, dx, dz, vx, vz, pred_cheb) self.prev_npc_positions = current_positions if best is None: return cheb, dx, dz, vx, vz, pred_cheb = best self.nearest_npc_dx = float(dx) self.nearest_npc_dz = float(dz) self.nearest_npc_vx = float(vx) self.nearest_npc_vz = float(vz) self.nearest_npc_dist = float(cheb) self.predicted_npc_dist = float(pred_cheb) self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1 self.npc_predicted_danger = pred_cheb <= 1 def _update_recharge_mode(self): """Enter/exit low-battery recharge mode.""" battery_ratio = self.battery / max(self.battery_max, 1) if not self.has_charger: self.recharge_mode = False self.charger_safety_margin = float(self.battery) self.recharge_enter_margin = 0.0 self.recharge_leave_margin = 0.0 self.recharge_low_battery_ratio = 0.28 self.full_charge_leave_ratio = 0.96 self.low_battery = battery_ratio < self.recharge_low_battery_ratio return self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0)) self.charger_safety_buffer = self._charger_safety_buffer() self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer self.recharge_enter_margin = self._recharge_enter_margin() self.recharge_leave_margin = self._recharge_leave_margin() self.recharge_low_battery_ratio = self._recharge_low_battery_ratio() self.full_charge_leave_ratio = self._full_charge_leave_ratio() self.low_battery = battery_ratio < self.recharge_low_battery_ratio should_recharge = self.charger_safety_margin <= self.recharge_enter_margin or self.low_battery safe_to_leave = ( battery_ratio >= self.full_charge_leave_ratio and self.charger_safety_margin >= self.recharge_leave_margin ) if self.on_charger and safe_to_leave: 
self.recharge_mode = False elif should_recharge: self.recharge_mode = True elif self.recharge_mode and not safe_to_leave: self.recharge_mode = True else: self.recharge_mode = False if self.recharge_mode: self.recharge_steps += 1 def _update_motion_health(self): """Track recharge-mode stalls so action masking can recover.""" if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8: if self.charge_delta <= 0: self.stuck_steps += 1 self.stuck_count += 1 self.max_stuck_steps = max(self.max_stuck_steps, self.stuck_steps) else: self.stuck_steps = 0 else: self.stuck_steps = 0 battery_ratio = self.battery / max(self.battery_max, 1) battery_increased = self.battery > self.prev_battery + 1 maybe_fake_charger = self.on_charger and battery_ratio < 0.9 and self.charge_delta <= 0 and not battery_increased self.fake_charger_steps = self.fake_charger_steps + 1 if maybe_fake_charger else 0 no_progress = ( self.recharge_mode and self.has_charger and self.charge_delta <= 0 and not battery_increased and self.nearest_charger_path_dist >= self.last_nearest_charger_path_dist - 0.1 ) self.recharge_no_progress_steps = self.recharge_no_progress_steps + 1 if no_progress else 0 if self.step_no > 0 and min(self.nearest_npc_dist, self.predicted_npc_dist) <= 3: self.npc_close_steps += 1 if self.step_no > 0 and (self.npc_danger or self.npc_predicted_danger): self.npc_danger_steps += 1 if self.terminated and not self.truncated: if self.battery <= 0 or self.remaining_charge <= 0: self.battery_fail = 1 if self.npc_danger or self.npc_predicted_danger or self.nearest_npc_dist <= 1: self.npc_collision = 1 def _need_recharge_escape(self): return self.stuck_steps >= 2 or self.recharge_no_progress_steps >= 5 or self.fake_charger_steps >= 2 def _charger_safety_buffer(self): # One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise. 
base = max(22.0, 0.14 * float(self.battery_max)) distance_buffer = min(18.0, 0.20 * float(max(self.nearest_charger_range_dist, 0.0))) obstacle_buffer = 14.0 * float(self.local_obstacle_ratio) route_uncertainty_buffer = 10.0 if self.has_charger and not self.charger_route_known else 0.0 return float(np.clip(base + distance_buffer + obstacle_buffer + route_uncertainty_buffer, 22.0, 58.0)) def _recharge_enter_margin(self): """Adaptive margin for entering recharge mode before the battery is barely enough.""" base = max(7.0, 0.025 * float(self.battery_max)) path_margin = min(14.0, 0.10 * float(max(self.nearest_charger_path_dist, 0.0))) obstacle_margin = 14.0 * float(self.local_obstacle_ratio) route_uncertainty_margin = 8.0 if self.has_charger and not self.charger_route_known else 0.0 recovery_margin = min(8.0, 1.5 * float(self.recharge_no_progress_steps + self.fake_charger_steps)) return float( np.clip( base + path_margin + obstacle_margin + route_uncertainty_margin + recovery_margin, 6.0, 42.0, ) ) def _recharge_leave_margin(self): """Adaptive safety margin required before leaving a charger.""" base = max(20.0, 0.08 * float(self.battery_max)) path_margin = min(18.0, 0.14 * float(max(self.nearest_charger_path_dist, 0.0))) obstacle_margin = 12.0 * float(self.local_obstacle_ratio) return float(np.clip(base + path_margin + obstacle_margin, 20.0, 64.0)) def _recharge_low_battery_ratio(self): """Adaptive low-battery ratio based on route length and local obstacle density.""" path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0) ratio = 0.32 + min(0.09, 0.42 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio)) if self.has_charger and not self.charger_route_known: ratio += 0.04 if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0: ratio += 0.02 return float(np.clip(ratio, 0.32, 0.46)) def _full_charge_leave_ratio(self): """Adaptive near-full threshold for leaving a charger.""" remaining_step_ratio = 
1.0 - _norm(self.step_no, self.max_step) path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0) ratio = 0.88 + 0.04 * remaining_step_ratio + min(0.02, 0.08 * path_pressure) ratio += min(0.01, 0.04 * float(self.local_obstacle_ratio)) return float(np.clip(ratio, 0.88, 0.95)) def _recharge_risk_score(self): """Risk score in [0, 1] used to scale recharge rewards and penalties.""" if not self.has_charger: return 0.0 battery_ratio = self.battery / max(self.battery_max, 1) margin_deficit = max(0.0, self.recharge_enter_margin - self.charger_safety_margin) margin_risk = margin_deficit / max(self.charger_safety_buffer + self.recharge_enter_margin, 1.0) low_battery_risk = max(0.0, self.recharge_low_battery_ratio - battery_ratio) low_battery_risk /= max(self.recharge_low_battery_ratio, 1e-6) progress_risk = min(1.0, float(self.recharge_no_progress_steps) / 5.0) return float(np.clip(0.55 * margin_risk + 0.35 * low_battery_risk + 0.10 * progress_risk, 0.0, 1.0)) def useful_charge_reward_weight(self): """Adaptive reward weight for charging that happens under real battery pressure.""" prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1) prev_low_risk = max(0.0, self.recharge_low_battery_ratio - prev_battery_ratio) prev_low_risk /= max(self.recharge_low_battery_ratio, 1e-6) risk = max(self._recharge_risk_score(), prev_low_risk) mode_bonus = 0.8 if self.was_recharge_mode or self.prev_low_battery else 0.0 return float(np.clip(3.0 + 2.8 * risk + mode_bonus, 3.0, 6.5)) def battery_fail_penalty(self): """Adaptive terminal penalty for running out of battery before max steps.""" step_ratio = _norm(self.step_no, self.max_step) early_fail_risk = 1.0 - step_ratio path_pressure = float(max(self.charger_energy_cost, 0.0)) / max(float(self.battery_max), 1.0) risk = max(self._recharge_risk_score(), min(1.0, path_pressure)) penalty = float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0)) if self.reward_profile == 
"battery_safe": penalty *= 1.25
        return penalty

    def _min_charger_range_dist(self, x, z):
        """Return the distance from global cell (x, z) to the closest known charger rectangle.

        The per-rectangle distance is ``max(|dx|, |dz|)`` of the offset returned by
        ``_relative_vector_to_rect``. Callers in this class treat ``0.0`` as
        "inside the charger range". Returns ``float(GRID_SIZE)`` when no charger
        rectangle is known.
        """
        if not self.charger_rects:
            return float(self.GRID_SIZE)
        dists = []
        for rx, rz, w, h in self.charger_rects:
            dx, dz = self._relative_vector_to_rect(x, z, rx, rz, w, h)
            dists.append(max(abs(dx), abs(dz)))
        return float(min(dists))

    def _is_charger_cell(self, x, z):
        """Return True if global cell (x, z) lies inside any known charger rectangle.

        Each rectangle ``(rx, rz, w, h)`` covers ``rx <= x < rx + w`` and
        ``rz <= z < rz + h``.
        """
        for rx, rz, w, h in self.charger_rects:
            if rx <= x < rx + w and rz <= z < rz + h:
                return True
        return False

    def _get_local_view_feature(self):
        """Local view feature: 21x21x6 multi-channel map, flattened.

        Channels: obstacle, clean, dirt, visit count, NPC danger, charger.
        Returns ``MAP_CHANNELS * VIEW_SIZE * VIEW_SIZE`` float32 values. The
        vector is all zeros when the view map is missing or is not
        ``VIEW_SIZE x VIEW_SIZE``.
        """
        view = self._view_map
        channels = np.zeros((self.MAP_CHANNELS, self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
        if view is None or view.shape[0] != self.VIEW_SIZE or view.shape[1] != self.VIEW_SIZE:
            return channels.flatten()
        # Channels 0-2 one-hot encode the raw view cell code (0, 1, 2).
        channels[0] = (view == 0).astype(np.float32)
        channels[1] = (view == 1).astype(np.float32)
        channels[2] = (view == 2).astype(np.float32)
        # Channels 3-5 come from global state, so map each view cell back to grid coordinates.
        for ri in range(self.VIEW_SIZE):
            for ci in range(self.VIEW_SIZE):
                gx, gz = self._view_index_to_global(ri, ci)
                if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
                    # View cells that fall outside the grid keep zeros.
                    continue
                # Visit count is capped at 10 and scaled to [0, 1].
                channels[3, ri, ci] = _norm(min(int(self.visit_count_map[gx, gz]), 10), 10)
                channels[4, ri, ci] = 1.0 if self._is_npc_danger_cell(gx, gz, expanded=True) else 0.0
                channels[5, ri, ci] = 1.0 if self._is_charger_cell(gx, gz) else 0.0
        return channels.flatten()

    def _get_global_state_feature(self):
        """Global state feature (66D).

        Layout: 42 scalar features come first. They cover progress, battery,
        weak absolute position, dirt rays and distances, charger and NPC
        geometry, belief-map statistics, charger-route safety, and predicted
        NPC motion. Three 8-D per-action distance-improvement vectors follow,
        toward the nearest dirty cell, the nearest frontier, and the charger.

        Side effect: moves ``nearest_dirt_dist`` into ``last_nearest_dirt_dist``
        and then recomputes ``nearest_dirt_dist``. Calling this twice in one step
        therefore overwrites the previous-step value.
        """
        step_norm = _norm(self.step_no, self.max_step)
        battery_ratio = _norm(self.battery, self.battery_max)
        cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
        remaining_dirt = 1.0 - cleaning_progress
        hx, hz = self.cur_pos
        pos_x_weak = self._weak_abs_position_feature(hx)
        pos_z_weak = self._weak_abs_position_feature(hz)

        # Cast 4 axis-aligned rays (N, E, S, W) and record the distance to the first
        # dirt cell (code 2). A ray stops at the grid border but not at obstacles.
        # Any cell the view lookup cannot resolve reads as ``default=0``, which is
        # not dirt. A ray with no dirt reports ``max_ray``. Every distance is then
        # normalized by ``max_ray``.
        ray_dirs = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # N E S W
        ray_dirt = []
        max_ray = 30
        for dx, dz in ray_dirs:
            found = max_ray
            for step in range(1, max_ray + 1):
                gx = hx + dx * step
                gz = hz + dz * step
                if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
                    break
                cell = self._view_cell(dx * step, dz * step, default=0)
                if cell == 2:
                    found = step
                    break
            ray_dirt.append(_norm(found, max_ray))

        # Path distance to the nearest dirt cell reachable inside the visible map.
        # It is 200.0 when there is none, and 180 is the normalization cap.
        self.last_nearest_dirt_dist = self.nearest_dirt_dist
        self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
        nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
        # 1.0 when the nearest visible dirt got closer since the previous call.
        dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0
        charge_delta = 1.0 if self.charge_delta > 0 else 0.0
        battery_margin_norm = _signed_norm(self.battery_margin, self.battery_max)
        visit_count_norm = _norm(min(self.current_visit_count, 10), 10)
        step_cleaned_norm = _norm(self.step_cleaned_count, 9)
        # Belief-map progress this step: clipped to [-4, 4], then scaled to [-1, 1].
        # Positive values mean the target got closer.
        global_dirty_delta = _signed_norm(
            np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -4.0, 4.0), 4.0
        )
        frontier_delta = _signed_norm(
            np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -4.0, 4.0), 4.0
        )
        # Battery minus the charger path distance minus the safety buffer.
        charger_margin_after_buffer = self.battery - self.nearest_charger_path_dist - self.charger_safety_buffer
        base_features = np.array(
            [
                # Progress and battery.
                step_norm,
                battery_ratio,
                cleaning_progress,
                remaining_dirt,
                # Weak absolute position.
                pos_x_weak,
                pos_z_weak,
                # Local dirt signals.
                ray_dirt[0],
                ray_dirt[1],
                ray_dirt[2],
                ray_dirt[3],
                nearest_dirt_norm,
                dirt_delta,
                # Charger geometry and battery safety.
                _signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
                _signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
                _norm(self.nearest_charger_path_dist, self.GRID_SIZE),
                battery_margin_norm,
                1.0 if self.low_battery else 0.0,
                1.0 if self.recharge_mode else 0.0,
                1.0 if self.on_charger else 0.0,
                charge_delta,
                # Current NPC geometry.
                _signed_norm(self.nearest_npc_dx, 20),
                _signed_norm(self.nearest_npc_dz, 20),
                _norm(self.nearest_npc_dist, 20),
                1.0 if self.npc_danger else 0.0,
                # Local map statistics.
                self.local_dirt_ratio,
                self.local_obstacle_ratio,
                visit_count_norm,
                step_cleaned_norm,
                # Belief-map distances and statistics.
                _norm(self.global_dirty_path_dist, self.GRID_SIZE),
                _norm(self.frontier_path_dist, self.GRID_SIZE),
                global_dirty_delta,
                frontier_delta,
                self.known_ratio,
                self.known_passable_ratio,
                _norm(self.known_dirty_ratio, 1.0),
                _norm(self.frontier_ratio, 1.0),
                # Charger-route safety.
                1.0 if self.charger_route_known else 0.0,
                _signed_norm(charger_margin_after_buffer, self.battery_max),
                # Predicted NPC motion.
                _signed_norm(self.nearest_npc_vx, 1.0),
                _signed_norm(self.nearest_npc_vz, 1.0),
                _norm(self.predicted_npc_dist, 20),
                1.0 if self.npc_predicted_danger else 0.0,
            ],
            dtype=np.float32,
        )
        return np.concatenate(
            [
                base_features,
                self.global_dirty_action_delta.astype(np.float32),
                self.frontier_action_delta.astype(np.float32),
                self.charger_action_delta.astype(np.float32),
            ]
        )

    def _weak_abs_position_feature(self, value):
        """Map a grid coordinate to a low-variance feature centred on 0.5.

        The normalized position in [0, 1] is squeezed by ``ABS_POS_FEATURE_SCALE``
        into ``[0.5 - s/2, 0.5 + s/2]``. With s = 0.2 that range is ``[0.4, 0.6]``.
        This keeps absolute position a weak input signal.
        """
        pos_norm = _norm(value, self.GRID_SIZE)
        return 0.5 + self.ABS_POS_FEATURE_SCALE * (pos_norm - 0.5)

    def _calc_nearest_dirt_dist(self):
        """Return the shortest in-view path distance to any visible dirt cell.

        Uses the 8-connected BFS over the current 21x21 local view. Returns 200.0
        when no dirt cell (code 2) is visible, or when none is reachable from the
        current position.
        """
        dist = self._local_bfs_distances()
        dirt_coords = np.argwhere(self._view_map == 2)
        if len(dirt_coords) == 0:
            return 200.0
        best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
        if best >= self.INF_DIST:
            return 200.0
        return best

    def _local_bfs_distances(self, start_dx=0, start_dz=0):
        """Shortest path distances inside the current 21x21 local view.

        Args:
            start_dx, start_dz: start offset relative to the agent. (0, 0) is the
                agent's own cell.

        Returns:
            A float32 array shaped like ``_view_map`` (row = z offset, column =
            x offset). Each entry is the step count from the start; unreachable
            cells hold ``INF_DIST``. Every entry is ``INF_DIST`` if the start lies
            outside the view or on an obstacle (cell code 0).
        """
        view = self._view_map
        shape = view.shape
        dist = np.full(shape, self.INF_DIST, dtype=np.float32)
        start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
        if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
            return dist
        if int(view[start_ri, start_ci]) == 0:
            return dist
        dist[start_ri, start_ci] = 0.0
        queue = deque([(start_ri, start_ci)])
        while queue:
            ri, ci = queue.popleft()
            base = dist[ri, ci]
            # Expand to the 8 action neighbours; every step costs 1.
            for dx, dz in self.ACTION_DIRS:
                nri = ri + dz
                nci = ci + dx
                if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
                    continue
                if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
                    continue
                if dx != 0 and dz != 0:
                    # A diagonal step needs at least one passable orthogonal side cell.
                    # This rule matches _filter_blocked_actions.
                    side_a = int(view[ri, nci]) != 0
                    side_b = int(view[nri, ci]) != 0
                    if not (side_a or side_b):
                        continue
                dist[nri, nci] = base + 1.0
                queue.append((nri, nci))
        return dist

    def _local_path_dist_to_charger(self, gx, gz):
        """Visible-map BFS distance from global (gx, gz) to the nearest charger cell.

        Only charger cells that fall inside the current view are considered.
        Returns ``INF_DIST`` when no visible charger cell is reachable.
        """
        best = self.INF_DIST
        # The BFS start is expressed as an offset from the agent's current position.
        start_dx = gx - self.cur_pos[0]
        start_dz = gz - self.cur_pos[1]
        dist = self._local_bfs_distances(start_dx, start_dz)
        for rx, rz, w, h in self.charger_rects:
            for tx in range(rx, rx + w):
                for tz in range(rz, rz + h):
                    dx = tx - self.cur_pos[0]
                    dz = tz - self.cur_pos[1]
                    ri, ci = self._view_delta_to_index(dx, dz)
                    if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
                        best = min(best, float(dist[ri, ci]))
        return best

    def _global_path_dist_to_charger(self, gx, gz):
        """Known-map BFS distance from a global cell to the nearest observed charger cell.

        Reads the cached reverse-distance map. It falls back to ``INF_DIST``
        (the ``default`` passed to ``_dist_at``).
        """
        dist = self._get_charger_reverse_dist()
        return self._dist_at(dist, gx, gz, default=self.INF_DIST)

    def _charger_move_distance(self, gx, gz):
        """Distance from global (gx, gz) to the charger, trying the best source first.

        The sources are tried in this order:
          1. known-map BFS;
          2. visible-map BFS;
          3. the geometric rectangle distance from ``_min_charger_range_dist``,
             which does not look at the map.
        """
        path_dist = self._global_path_dist_to_charger(gx, gz)
        if path_dist < self.INF_DIST:
            return path_dist
        path_dist = self._local_path_dist_to_charger(gx, gz)
        if path_dist < self.INF_DIST:
            return path_dist
        return self._min_charger_range_dist(gx, gz)

    def get_legal_action(self):
        """Return the legal action mask as an 8D list of ints.

        The mask is built as a pipeline:
          1. the raw environment mask (``self._legal_act``);
          2. remove visibly blocked moves;
          3. remove NPC danger zones;
          4a. in recharge mode: keep moves toward the charger, then apply the
              stuck-escape override. The override re-ranks the NPC-safe mask,
              so it may re-enable moves the recharge filter dropped;
          4b. otherwise, when on a charger with battery ratio >=
              ``full_charge_leave_ratio``: keep moves that leave the charger.
        If a filter would leave no action, it returns its input mask instead.
        Every intermediate mask is recorded for diagnostics.
        """
        raw_legal = [int(x) for x in self._legal_act]
        blocked_legal = self._filter_blocked_actions(raw_legal)
        npc_legal = self._filter_npc_danger_actions(blocked_legal)
        safe_legal = list(npc_legal)

        # Optional stage outputs; None means the stage did not run this step.
        recharge_legal = None
        escape_legal = None
        leave_legal = None
        legal = npc_legal
        if self.recharge_mode:
            recharge_legal = self._filter_recharge_actions(legal)
            escape_legal = self._filter_recharge_escape_actions(recharge_legal, safe_legal)
            legal = escape_legal
        elif self.on_charger and self.battery / max(self.battery_max, 1) >= self.full_charge_leave_ratio:
            leave_legal = self._filter_leave_charger_actions(legal)
            legal = leave_legal

        self._record_mask_diagnostics(
            raw_legal=raw_legal,
            blocked_legal=blocked_legal,
            npc_legal=npc_legal,
            recharge_legal=recharge_legal,
            escape_legal=escape_legal,
            leave_legal=leave_legal,
            final_legal=legal,
        )
        return list(legal)

    def record_action(self, action):
        """Record the chosen action for episode diagnostics.

        Silently ignores values that cannot be converted to int or that fall
        outside the action histogram.
        """
        try:
            action = int(action)
        except (TypeError, ValueError):
            return
        if 0 <= action < len(self.diag_action_hist):
            self.diag_action_hist[action] += 1

    def _record_mask_diagnostics(
        self,
        raw_legal,
        blocked_legal,
        npc_legal,
        recharge_legal,
        escape_legal,
        leave_legal,
        final_legal,
    ):
        """Record action-mask counts without changing mask behavior.

        For a stage that did not run this step (its argument is None), the mask
        that passed through that stage is counted instead. This keeps per-stage
        averages comparable across steps. Each "changed" counter compares a stage
        with its own input mask.
        """
        self.diag_mask_steps += 1
        stages = {
            "raw": raw_legal,
            "blocked": blocked_legal,
            "npc": npc_legal,
            "recharge": recharge_legal if recharge_legal is not None else npc_legal,
            "escape": escape_legal if escape_legal is not None else (recharge_legal if recharge_legal is not None else npc_legal),
            "leave": leave_legal if leave_legal is not None else npc_legal,
            "final": final_legal,
        }
        for name, mask in stages.items():
            self.diag_mask_count_sums[name] += self._mask_count(mask)

        if not self._same_mask(raw_legal, blocked_legal):
            self.diag_mask_changed_steps["blocked"] += 1
        if not self._same_mask(blocked_legal, npc_legal):
            self.diag_mask_changed_steps["npc"] += 1
        if recharge_legal is not None:
            self.diag_mask_active_steps["recharge"] += 1
            if not self._same_mask(npc_legal, recharge_legal):
                self.diag_mask_changed_steps["recharge"] += 1
        if escape_legal is not None and recharge_legal is not None:
            if not self._same_mask(recharge_legal, escape_legal):
                self.diag_mask_changed_steps["escape"] += 1
        if leave_legal is not None:
            self.diag_mask_active_steps["leave"] += 1
            if not self._same_mask(npc_legal, leave_legal):
                self.diag_mask_changed_steps["leave"] += 1

        # Bucket how constrained the final mask is. The buckets overlap:
        # a 1-action step also counts as "two or less".
        final_count = self._mask_count(final_legal)
        if final_count <= 0:
            self.diag_zero_final_steps += 1
        if final_count == 1:
            self.diag_one_action_steps += 1
        if final_count <= 2:
            self.diag_two_or_less_action_steps += 1

    def _mask_count(self, mask):
        """Return the number of enabled (> 0) entries in a mask."""
        return int(sum(1 for value in mask if int(value) > 0))

    def _same_mask(self, left, right):
        """Return True if two masks are element-wise equal after int conversion."""
        return [int(x) for x in left] == [int(x) for x in right]

    def get_diagnostic_summary(self):
        """Return episode-level diagnostic counters for logging.

        Average mask counts are divided by ``max(mask_steps, 1)`` to avoid
        division by zero when no step has been recorded.
        """
        steps = max(self.diag_mask_steps, 1)
        avg_mask_counts = {
            name: self.diag_mask_count_sums[name] / steps
            for name in sorted(self.diag_mask_count_sums)
        }
        return {
            "map_id": self.map_id,
            "mask_steps": self.diag_mask_steps,
            "avg_mask_counts": avg_mask_counts,
            "mask_changed_steps": dict(self.diag_mask_changed_steps),
            "mask_active_steps": dict(self.diag_mask_active_steps),
            "one_action_steps": self.diag_one_action_steps,
            "two_or_less_action_steps": self.diag_two_or_less_action_steps,
            "zero_final_steps": self.diag_zero_final_steps,
            "action_hist": list(self.diag_action_hist),
            "known_ratio": self.known_ratio,
            "known_dirty_ratio": self.known_dirty_ratio,
            "frontier_ratio": self.frontier_ratio,
            "local_dirt_ratio": self.local_dirt_ratio,
            "local_obstacle_ratio": self.local_obstacle_ratio,
            "global_dirty_path_dist": self.global_dirty_path_dist,
            "frontier_path_dist": self.frontier_path_dist,
            "charger_route_source": self.charger_route_source,
            "reward_profile": self.reward_profile,
        }

    def evaluation_action_score(self, action):
        """Heuristic score used only to break close evaluation-policy ties.

        Higher is better. Returns -1e6 for an out-of-range action index or for
        a move that would leave the grid. The score combines these terms:
          - the target cell type;
          - visit novelty;
          - a charger pull in recharge mode; otherwise a dirt or frontier pull,
            plus a charger pull on low battery;
          - a bonus or penalty for stepping onto a charger cell;
          - NPC danger;
          - a penalty for repeating the last action while stuck.
        """
        if not (0 <= int(action) < len(self.ACTION_DIRS)):
            return -1e6
        action = int(action)
        dx, dz = self.ACTION_DIRS[action]
        hx, hz = self.cur_pos
        nx, nz = hx + dx, hz + dz
        if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
            return -1e6

        score = 0.0
        # Target cell: 0 = obstacle, 2 = dirt. Any other code, including the
        # default of 1 for unresolved cells, gets a small cost.
        cell = self._view_cell(dx, dz, default=1)
        if cell == 0:
            score -= 8.0
        elif cell == 2:
            score += 3.0
        else:
            score -= 0.10

        # Favour unvisited cells; penalize revisits, capped at 10 visits.
        visit_count = int(self.visit_count_map[nx, nz]) if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE else 0
        score += 0.35 if visit_count == 0 else -0.05 * min(visit_count, 10)

        if self.recharge_mode:
            score += 2.2 * float(self.charger_action_delta[action])
            if self._charger_move_distance(nx, nz) < self._charger_move_distance(hx, hz):
                score += 0.8
        else:
            # Prefer known dirt; fall back to the frontier when no dirt path is known.
            if self.global_dirty_path_dist < self.GRID_SIZE:
                score += 1.8 * float(self.global_dirty_action_delta[action])
            elif self.frontier_path_dist < self.GRID_SIZE:
                score += 1.4 * float(self.frontier_action_delta[action])
            if self.low_battery and self.has_charger:
                score += 1.2 * float(self.charger_action_delta[action])

        # Charger cells are attractive only when battery is a concern.
        if self._is_charger_cell(nx, nz):
            score += 0.8 if self.low_battery or self.recharge_mode else -0.2
        if self._is_npc_danger_cell(nx, nz, expanded=False):
            score -= 6.0
        elif self._is_npc_danger_cell(nx, nz, expanded=True):
            score -= 1.5
        if action == self.last_action and self.stuck_steps >= 1:
            score -= 1.0
        return float(score)

    def _filter_blocked_actions(self, legal_action):
        """Filter out actions that are visibly blocked in the 21x21 view.

        An action is removed when any of these holds:
          - its target cell is an obstacle;
          - it is a diagonal move and both orthogonal side cells are blocked;
          - its target lies outside the grid.
        Cells the view cannot resolve count as passable. If every action would
        be removed, a copy of the input mask is returned instead.
        """
        legal = [int(x) for x in legal_action]
        hx, hz = self.cur_pos
        for action, (dx, dz) in enumerate(self.ACTION_DIRS):
            if legal[action] <= 0:
                continue
            if not self._is_visible_cell_passable(dx, dz):
                legal[action] = 0
                continue
            if dx != 0 and dz != 0:
                side_a = self._is_visible_cell_passable(dx, 0)
                side_b = self._is_visible_cell_passable(0, dz)
                if not (side_a or side_b):
                    legal[action] = 0
            nx, nz = hx + dx, hz + dz
            if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
                legal[action] = 0
        return legal if any(legal) else [int(x) for x in legal_action]

    def _is_visible_cell_passable(self, dx, dz):
        """Return whether a relative 21x21-view cell is passable.

        A cell that ``_view_cell`` resolves to None is treated as passable.
        Otherwise, only code 0 (obstacle) is impassable.
        """
        cell = self._view_cell(dx, dz, default=None)
        return True if cell is None else cell != 0

    def _filter_npc_danger_actions(self, legal_action):
        """Avoid current and predicted NPC danger zones.

        First drops moves into expanded danger cells. If that leaves nothing, it
        retries and drops only moves into the immediate 3x3 neighbourhood of a
        current NPC. If that still leaves nothing, the input mask is returned
        unchanged. When there are no NPCs, the input is returned as a list copy.
        """
        if not self.npcs:
            return list(legal_action)
        hx, hz = self.cur_pos
        safe = [int(x) for x in legal_action]
        for action, (dx, dz) in enumerate(self.ACTION_DIRS):
            if safe[action] <= 0:
                continue
            nx, nz = hx + dx, hz + dz
            if self._is_npc_danger_cell(nx, nz, expanded=True):
                safe[action] = 0
        if any(safe):
            return safe

        # Fallback: only avoid cells directly adjacent to a current NPC.
        hard_safe = [int(x) for x in legal_action]
        for action, (dx, dz) in enumerate(self.ACTION_DIRS):
            if hard_safe[action] <= 0:
                continue
            nx, nz = hx + dx, hz + dz
            if self._is_npc_danger_cell(nx, nz, expanded=False):
                hard_safe[action] = 0
        return hard_safe if any(hard_safe) else list(legal_action)

    def _is_npc_danger_cell(self, x, z, expanded=True):
        """Return True if global cell (x, z) counts as dangerous.

        A cell is always dangerous when it is within 1 cell (per axis) of a
        current NPC position. With ``expanded`` it is also dangerous when any of
        these holds:
          - it is within 2 cells of an NPC while the nearest NPC is at most 4 away;
          - it lies inside a predicted NPC square ``(px, pz, radius)``;
          - it is within 2 cells of a predicted position while the nearest NPC is
            at most 4 away.
        """
        for npc in self.npcs:
            if not isinstance(npc, dict):
                continue
            # A missing position defaults to (-999, -999), which never matches a grid cell.
            pos = npc.get("pos") or {}
            nx = int(pos.get("x", -999))
            nz = int(pos.get("z", -999))
            if abs(x - nx) <= 1 and abs(z - nz) <= 1:
                return True
            if expanded and abs(x - nx) <= 2 and abs(z - nz) <= 2 and self.nearest_npc_dist <= 4:
                return True
        if expanded:
            for px, pz, radius in self.predicted_npcs:
                if abs(x - px) <= radius and abs(z - pz) <= radius:
                    return True
                if self.nearest_npc_dist <= 4 and abs(x - px) <= 2 and abs(z - pz) <= 2:
                    return True
        return False

    def _filter_recharge_actions(self, legal_action):
        """Restrict recharge-mode actions to safe moves toward the charger range.

        Each legal move is scored by its next charger move distance, its
        alignment with ``(nearest_charger_dx, nearest_charger_dz)``, and its next
        rectangle distance.
          - Already inside the charger range with charging confirmed: keep only
            moves that stay inside the range.
          - Otherwise: keep up to 4 moves (route known) or 5 moves (route
            unknown). A kept move must be within ``dist_slack`` of the best next
            distance and make route progress. When the route is unknown, range
            or direction progress also qualifies.
          - If no move qualifies, keep the top-ranked moves.
        Returns the input mask copied when there is no charger or no legal move.
        """
        if not self.has_charger:
            return list(legal_action)
        hx, hz = self.cur_pos
        current_range_dist = self._min_charger_range_dist(hx, hz)
        current_move_dist = self._charger_move_distance(hx, hz)
        scored = []
        for action, (dx, dz) in enumerate(self.ACTION_DIRS):
            if legal_action[action] <= 0:
                continue
            nx, nz = hx + dx, hz + dz
            next_dist = self._charger_move_distance(nx, nz)
            # Dot product with the nearest-charger offset; positive means "toward".
            alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
            next_range_dist = self._min_charger_range_dist(nx, nz)
            scored.append((next_dist, alignment, next_range_dist, action))
        if not scored:
            return list(legal_action)

        # When already inside the charger range, stay inside until recharge mode exits.
        # Charging is "confirmed" by a positive charge delta, or by the battery rising
        # more than 1 above the previous step's value.
        confirmed_charger = self.charge_delta > 0 or self.battery > self.prev_battery + 1
        if current_range_dist <= 0.0 and confirmed_charger:
            stay = [0] * 8
            for _, _, next_range_dist, action in scored:
                if next_range_dist <= 0.0:
                    stay[action] = 1
            if any(stay):
                return stay

        recharge = [0] * 8
        best_next_dist = min(item[0] for item in scored)
        # Closest first; ties broken by stronger alignment.
        ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
        max_recharge_actions = 4 if self.charger_route_known else 5
        dist_slack = 2.5 if self.charger_route_known else 4.0
        for next_dist, alignment, next_range_dist, action in ranked:
            route_progress = next_dist <= current_move_dist + 0.1
            range_progress = next_range_dist <= current_range_dist
            direction_progress = alignment > 0
            if next_dist <= best_next_dist + dist_slack and (
                route_progress or (not self.charger_route_known and (range_progress or direction_progress))
            ):
                recharge[action] = 1
                if sum(recharge) >= max_recharge_actions:
                    break
        if not any(recharge):
            for _, _, _, action in ranked[: min(max_recharge_actions, len(ranked))]:
                recharge[action] = 1
        return recharge if any(recharge) else list(legal_action)

    def _filter_recharge_escape_actions(self, recharge_action, safe_action):
        """Escape repeated no-move states during low-battery recharge.

        When ``_need_recharge_escape()`` is false, the recharge mask is returned
        unchanged. Otherwise the moves enabled in ``safe_action`` are ranked
        (lower is better). The rank is the charger move distance plus three
        penalties:
          - a small revisit penalty (visit count capped at 20);
          - 6.0 for repeating the last action after being stuck for 2+ steps;
          - 1.5 for moving farther from the charger.
        The best 4 moves are enabled. When stuck for 2+ steps and other options
        remain, the last action is disabled. Increments ``recharge_escape_count``
        whenever ranking happens.
        """
        if not self._need_recharge_escape():
            return list(recharge_action)
        hx, hz = self.cur_pos
        current_dist = self._charger_move_distance(hx, hz)
        ranked = []
        for action, (dx, dz) in enumerate(self.ACTION_DIRS):
            if safe_action[action] <= 0:
                continue
            nx, nz = hx + dx, hz + dz
            next_dist = self._charger_move_distance(nx, nz)
            visit_count = 0
            if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE:
                visit_count = int(self.visit_count_map[nx, nz])
            failed_action_penalty = 6.0 if action == self.last_action and self.stuck_steps >= 2 else 0.0
            no_progress_penalty = 1.5 if next_dist > current_dist + 0.1 else 0.0
            ranked.append((next_dist + 0.05 * min(visit_count, 20) + failed_action_penalty + no_progress_penalty, action))
        if not ranked:
            return list(recharge_action)

        self.recharge_escape_count += 1
        ranked.sort()
        escape = [0] * 8
        for _, action in ranked[: min(4, len(ranked))]:
            escape[action] = 1
        if self.stuck_steps >= 2 and sum(escape) > 1 and 0 <= self.last_action < 8:
            escape[self.last_action] = 0
        return escape if any(escape) else list(recharge_action)

    def _filter_leave_charger_actions(self, legal_action):
        """Prefer moves that leave the charger range when the battery is healthy.

        Keeps the moves whose increase in charger-range distance is within 0.1 of
        the largest increase. If no move increases that distance, it keeps the
        moves with the highest ``away_score``. That score is the negated dot
        product with ``(nearest_charger_center_dx, nearest_charger_center_dz)``;
        presumably this means "pointing away from the charger centre" — confirm
        the sign convention of those fields.
        """
        if not self.has_charger:
            return list(legal_action)
        hx, hz = self.cur_pos
        current_dist = self._min_charger_range_dist(hx, hz)
        scored = []
        for action, (dx, dz) in enumerate(self.ACTION_DIRS):
            if legal_action[action] <= 0:
                continue
            nx, nz = hx + dx, hz + dz
            next_dist = self._min_charger_range_dist(nx, nz)
            away_score = -(dx * self.nearest_charger_center_dx + dz * self.nearest_charger_center_dz)
            scored.append((next_dist - current_dist, away_score, action))
        if not scored:
            return list(legal_action)

        best_escape = max(item[0] for item in scored)
        leave = [0] * 8
        if best_escape > 0:
            for escape, _, action in scored:
                if escape >= best_escape - 0.1:
                    leave[action] = 1
        else:
            best_away = max(item[1] for item in scored)
            for _, away_score, action in scored:
                if away_score >= best_away:
                    leave[action] = 1
        return leave if any(leave) else list(legal_action)

    def feature_process(self, env_obs, last_action):
        """Generate the feature vector, legal action mask, and scalar reward.

        Parses the observation with ``pb2struct``, then builds the feature
        ``[local view (2646) | global state (66) | one-hot last action (8)]``.
        A one-hot slot is set only for ``0 <= last_action < 8``.

        Returns:
            ``(feature, legal_action, reward)``: a float32 feature array, the 8D
            mask list from ``get_legal_action``, and the value of
            ``reward_process``.
        """
        self.pb2struct(env_obs, last_action)

        local_view = self._get_local_view_feature()  # 2646D
        global_state = self._get_global_state_feature()  # 66D
        legal_action = self.get_legal_action()  # 8D

        last_action_feature = np.zeros(8, dtype=np.float32)
        if 0 <= last_action < 8:
            last_action_feature[last_action] = 1.0

        feature = np.concatenate([local_view, global_state, last_action_feature])
        reward = self.reward_process()
        return feature, legal_action, reward

    def reward_process(self):
        """Compute the scalar shaped reward for the current step.

        The reward is the sum of six terms: cleaning reward, charger shaping,
        exploration shaping, stuck penalty, NPC proximity penalty, and a
        constant -0.002 step penalty. The first three are scaled by the
        multipliers of the active ``reward_profile``.
        """
        cleaning_multiplier, charge_multiplier, exploration_multiplier = self._reward_profile_scales()

        # Cleaning reward. Prefer the per-step cleaned-cell count; fall back to the
        # change in dirt_cleaned. Cleaning is worth less under battery pressure and
        # much less in recharge mode.
        cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
        cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step
        battery_ratio = self.battery / max(self.battery_max, 1)
        battery_pressure = self.has_charger and battery_ratio < self.recharge_low_battery_ratio + 0.06
        cleaning_scale = 0.2 if self.recharge_mode else (0.55 if battery_pressure else 0.7)
        cleaning_scale *= cleaning_multiplier
        cleaning_reward = cleaning_scale * cleaned_cells

        # Step penalty: a constant time cost per step.
        step_penalty = -0.002

        # Recharge guidance only activates when battery safety is the bottleneck
        # (low battery or recharge mode). The agent is not rewarded for sitting on
        # the charger at high battery.
        charge_reward = 0.0
        prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
        # Charging is "useful" only if the previous step was low-battery, in recharge
        # mode, or below 45% battery. Charging while above 65% is penalized.
        useful_charge = self.charge_delta > 0 and (
            self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.45
        )
        if useful_charge:
            charge_reward += self.useful_charge_reward_weight()
        elif self.charge_delta > 0 and battery_ratio > 0.65:
            charge_reward -= 0.25 * min(self.charge_delta, 3)

        if self.has_charger and (self.recharge_mode or self.low_battery):
            # Path-distance progress toward the charger, clipped to [-4, 4].
            dist_delta = float(
                np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
            )
            # Approach is rewarded more than retreat is penalized; both grow with risk.
            recharge_risk = self._recharge_risk_score()
            approach_scale = 0.07 + 0.06 * recharge_risk
            retreat_scale = 0.035 + 0.045 * recharge_risk
            if not self.charger_route_known:
                approach_scale += 0.02
                retreat_scale += 0.01
            charge_reward += approach_scale * dist_delta if dist_delta > 0 else retreat_scale * dist_delta
            # Penalize a safety margin below the recharge entry threshold, capped at 0.55.
            if self.charger_safety_margin < self.recharge_enter_margin:
                safety_shortage = self.recharge_enter_margin - self.charger_safety_margin
                charge_reward -= min(0.55, safety_shortage / max(self.battery_max, 1))
        elif self.on_charger and battery_ratio > 0.65:
            # Mild penalty for camping on the charger with a healthy battery.
            charge_reward -= 0.08
        charge_reward *= charge_multiplier

        # Encourage covering new passable cells and mildly discourage loops.
        # Exploration shaping is disabled in recharge mode. Progress toward known dirt
        # takes priority over progress toward the frontier.
        if self.recharge_mode:
            exploration_reward = 0.0
        else:
            exploration_reward = 0.004 if self.is_new_cell else -0.0015 * min(self.current_visit_count, 6)
            if self.global_dirty_path_dist < self.GRID_SIZE:
                dirty_progress = np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -3.0, 3.0)
                exploration_reward += 0.008 * dirty_progress
            elif self.frontier_path_dist < self.GRID_SIZE:
                frontier_progress = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
                exploration_reward += 0.005 * frontier_progress
        exploration_reward *= exploration_multiplier

        # Collision/stuck signal: a move that leaves the position unchanged wastes
        # both a step and battery.
        stuck_penalty = 0.0
        if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
            stuck_penalty = -0.03
            if self.recharge_mode:
                stuck_penalty -= 0.02 * min(self.stuck_steps, 5)

        # NPC penalty: strongest for current danger, then predicted danger,
        # then a small distance-based term within 3 cells.
        npc_penalty = 0.0
        if self.npc_danger:
            npc_penalty -= 4.0
        elif self.npc_predicted_danger:
            npc_penalty -= 0.4
        elif self.nearest_npc_dist <= 3:
            npc_penalty -= 0.05 * (4 - self.nearest_npc_dist)

        return (
            cleaning_reward
            + charge_reward
            + exploration_reward
            + stuck_penalty
            + npc_penalty
            + step_penalty
        )

    def _reward_profile_scales(self):
        """Return the ``(cleaning, charge, exploration)`` reward multipliers.

        These support quick reward-shaping ablations selected by
        ``self.reward_profile``. An unrecognized profile gets ``(1.0, 1.0, 1.0)``.
        """
        if self.reward_profile == "lower_recharge":
            return 1.0, 0.70, 1.0
        if self.reward_profile == "clean_explore":
            return 1.15, 0.85, 1.50
        if self.reward_profile == "battery_safe":
            return 0.95, 1.25, 0.90
        return 1.0, 1.0, 1.0