Optimize PPO coverage and recharge strategy
This commit is contained in:
@@ -46,7 +46,9 @@ class Preprocessor:
|
||||
|
||||
GRID_SIZE = 128
|
||||
VIEW_HALF = 10 # Full local view radius (21×21) / 完整局部视野半径
|
||||
LOCAL_HALF = 5 # Cropped view radius (11×11) / 裁剪后的视野半径
|
||||
VIEW_SIZE = 21
|
||||
MAP_CHANNELS = 6
|
||||
PLANNER_UPDATE_INTERVAL = 4
|
||||
ACTION_DIRS = (
|
||||
(1, 0),
|
||||
(1, -1),
|
||||
@@ -93,9 +95,29 @@ class Preprocessor:
|
||||
self.step_cleaned_count = 0
|
||||
self.max_step = 1000
|
||||
|
||||
# Global passable map (0=obstacle, 1=passable), indexed by [x, z].
|
||||
# 维护全局通行地图(0=障碍, 1=可通行),索引为 [x, z]。
|
||||
self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
||||
# Global belief maps indexed by [x, z].
|
||||
# 全局 belief map,索引为 [x, z]。
|
||||
self.known_map = np.full((self.GRID_SIZE, self.GRID_SIZE), -1, dtype=np.int8)
|
||||
self.passable_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
||||
self.frontier_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
||||
self.dirty_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
||||
self._dirty_reverse_dist = None
|
||||
self._frontier_reverse_dist = None
|
||||
self._charger_reverse_dist = None
|
||||
self._path_cache_dirty = True
|
||||
self._planner_last_update_step = -self.PLANNER_UPDATE_INTERVAL
|
||||
self.known_ratio = 0.0
|
||||
self.known_passable_ratio = 0.0
|
||||
self.known_dirty_ratio = 0.0
|
||||
self.frontier_ratio = 0.0
|
||||
self.global_dirty_path_dist = float(self.GRID_SIZE)
|
||||
self.last_global_dirty_path_dist = float(self.GRID_SIZE)
|
||||
self.frontier_path_dist = float(self.GRID_SIZE)
|
||||
self.last_frontier_path_dist = float(self.GRID_SIZE)
|
||||
self.global_dirty_action_delta = np.zeros(8, dtype=np.float32)
|
||||
self.frontier_action_delta = np.zeros(8, dtype=np.float32)
|
||||
self.charger_action_delta = np.zeros(8, dtype=np.float32)
|
||||
self.charger_route_known = False
|
||||
|
||||
# Nearest dirt path distance in the current local view.
|
||||
# 当前局部视野内最近污渍路径距离。
|
||||
@@ -131,7 +153,7 @@ class Preprocessor:
|
||||
self.charger_safety_margin = 0.0
|
||||
self.recharge_enter_margin = 0.0
|
||||
self.recharge_leave_margin = 0.0
|
||||
self.recharge_low_battery_ratio = 0.35
|
||||
self.recharge_low_battery_ratio = 0.28
|
||||
self.full_charge_leave_ratio = 0.96
|
||||
self.battery_margin = 0.0
|
||||
self.has_charger = False
|
||||
@@ -143,9 +165,15 @@ class Preprocessor:
|
||||
|
||||
self.nearest_npc_dx = 0.0
|
||||
self.nearest_npc_dz = 0.0
|
||||
self.nearest_npc_vx = 0.0
|
||||
self.nearest_npc_vz = 0.0
|
||||
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||||
self.predicted_npc_dist = float(self.GRID_SIZE)
|
||||
self.npc_danger = False
|
||||
self.npc_predicted_danger = False
|
||||
self.npcs = []
|
||||
self.prev_npc_positions = {}
|
||||
self.predicted_npcs = []
|
||||
self.npc_close_steps = 0
|
||||
self.npc_danger_steps = 0
|
||||
self.npc_collision = 0
|
||||
@@ -225,6 +253,7 @@ class Preprocessor:
|
||||
self._view_map = np.array(map_info, dtype=np.float32)
|
||||
hx, hz = self.cur_pos
|
||||
self._update_passable(hx, hz)
|
||||
self._mark_cleaned_cells(step_cleaned_cells)
|
||||
self._update_local_map_stats()
|
||||
|
||||
organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
|
||||
@@ -233,6 +262,7 @@ class Preprocessor:
|
||||
self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else []
|
||||
self._update_charger_state(hx, hz, organs)
|
||||
self._update_npc_state(hx, hz, self.npcs)
|
||||
self._update_global_planning_state()
|
||||
self._update_recharge_mode()
|
||||
self._update_motion_health()
|
||||
|
||||
@@ -250,9 +280,36 @@ class Preprocessor:
|
||||
gx = hx + ci - half
|
||||
gz = hz + ri - half
|
||||
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
|
||||
# 0 = obstacle, 1/2 = passable
|
||||
# 0 = 障碍, 1/2 = 可通行
|
||||
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
||||
cell = int(view[ri, ci])
|
||||
self.known_map[gx, gz] = cell
|
||||
self.passable_map[gx, gz] = 1 if cell != 0 else 0
|
||||
self.dirty_map[gx, gz] = 1 if cell == 2 else 0
|
||||
|
||||
if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
|
||||
self.known_map[hx, hz] = 1
|
||||
self.passable_map[hx, hz] = 1
|
||||
self.dirty_map[hx, hz] = 0
|
||||
self._clear_path_caches()
|
||||
|
||||
def _mark_cleaned_cells(self, step_cleaned_cells):
|
||||
"""Mark cells cleaned in the current step in the global belief map."""
|
||||
for pos in step_cleaned_cells or []:
|
||||
pos = _as_dict(pos)
|
||||
x = int(pos.get("x", -1))
|
||||
z = int(pos.get("z", -1))
|
||||
if 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE:
|
||||
self.known_map[x, z] = 1
|
||||
self.passable_map[x, z] = 1
|
||||
self.dirty_map[x, z] = 0
|
||||
self._clear_path_caches()
|
||||
|
||||
def _clear_path_caches(self):
|
||||
self._path_cache_dirty = True
|
||||
|
||||
def _drop_path_caches(self):
|
||||
self._dirty_reverse_dist = None
|
||||
self._frontier_reverse_dist = None
|
||||
self._charger_reverse_dist = None
|
||||
|
||||
def _view_index_to_global(self, ri, ci):
|
||||
"""Convert local view row/col to global x/z coordinates."""
|
||||
@@ -285,6 +342,159 @@ class Preprocessor:
|
||||
self.local_dirt_ratio = float(np.sum(view == 2) / total)
|
||||
self.local_obstacle_ratio = float(np.sum(view == 0) / total)
|
||||
|
||||
def _update_global_planning_state(self):
|
||||
"""Refresh global coverage, frontier, and action-improvement features."""
|
||||
self.last_global_dirty_path_dist = self.global_dirty_path_dist
|
||||
self.last_frontier_path_dist = self.frontier_path_dist
|
||||
|
||||
self._update_frontier_map()
|
||||
hx, hz = self.cur_pos
|
||||
|
||||
should_refresh_paths = (
|
||||
self._dirty_reverse_dist is None
|
||||
or self._frontier_reverse_dist is None
|
||||
or (self.has_charger and self._charger_reverse_dist is None)
|
||||
or (
|
||||
self._path_cache_dirty
|
||||
and self.step_no - self._planner_last_update_step >= self.PLANNER_UPDATE_INTERVAL
|
||||
)
|
||||
)
|
||||
if should_refresh_paths:
|
||||
self._drop_path_caches()
|
||||
self._planner_last_update_step = self.step_no
|
||||
self._path_cache_dirty = False
|
||||
|
||||
known_count = float(np.sum(self.known_map >= 0))
|
||||
passable_count = float(np.sum(self.passable_map > 0))
|
||||
dirty_count = float(np.sum(self.dirty_map > 0))
|
||||
frontier_count = float(np.sum(self.frontier_map > 0))
|
||||
total_cells = float(self.GRID_SIZE * self.GRID_SIZE)
|
||||
self.known_ratio = known_count / total_cells
|
||||
self.known_passable_ratio = passable_count / total_cells
|
||||
self.known_dirty_ratio = dirty_count / max(float(self.total_dirt), 1.0)
|
||||
self.frontier_ratio = frontier_count / max(passable_count, 1.0)
|
||||
|
||||
dirty_dist = self._get_dirty_reverse_dist()
|
||||
frontier_dist = self._get_frontier_reverse_dist()
|
||||
charger_dist = self._get_charger_reverse_dist()
|
||||
|
||||
self.global_dirty_path_dist = self._dist_at(dirty_dist, hx, hz, default=float(self.GRID_SIZE))
|
||||
self.frontier_path_dist = self._dist_at(frontier_dist, hx, hz, default=float(self.GRID_SIZE))
|
||||
if charger_dist is not None:
|
||||
charger_path = self._dist_at(charger_dist, hx, hz, default=self.INF_DIST)
|
||||
if charger_path < self.INF_DIST:
|
||||
self.nearest_charger_path_dist = min(self.nearest_charger_path_dist, float(charger_path))
|
||||
self.charger_energy_cost = self.nearest_charger_path_dist
|
||||
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
||||
self.charger_route_known = True
|
||||
|
||||
self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist)
|
||||
self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist)
|
||||
current_charger = self._dist_at(charger_dist, hx, hz, default=self.nearest_charger_path_dist)
|
||||
self.charger_action_delta = self._action_distance_delta(charger_dist, current_charger)
|
||||
|
||||
def _update_frontier_map(self):
|
||||
"""Mark known passable cells adjacent to unseen space as exploration frontiers."""
|
||||
self.frontier_map.fill(0)
|
||||
passable_coords = np.argwhere(self.passable_map > 0)
|
||||
for x, z in passable_coords:
|
||||
x = int(x)
|
||||
z = int(z)
|
||||
for dx, dz in ((1, 0), (-1, 0), (0, 1), (0, -1)):
|
||||
nx, nz = x + dx, z + dz
|
||||
if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE and self.known_map[nx, nz] < 0:
|
||||
self.frontier_map[x, z] = 1
|
||||
break
|
||||
|
||||
def _get_dirty_reverse_dist(self):
|
||||
if self._dirty_reverse_dist is None:
|
||||
targets = np.argwhere((self.dirty_map > 0) & (self.passable_map > 0))
|
||||
self._dirty_reverse_dist = self._global_bfs_from_targets(targets)
|
||||
return self._dirty_reverse_dist
|
||||
|
||||
def _get_frontier_reverse_dist(self):
|
||||
if self._frontier_reverse_dist is None:
|
||||
targets = np.argwhere((self.frontier_map > 0) & (self.passable_map > 0))
|
||||
self._frontier_reverse_dist = self._global_bfs_from_targets(targets)
|
||||
return self._frontier_reverse_dist
|
||||
|
||||
def _get_charger_reverse_dist(self):
|
||||
if not self.charger_rects:
|
||||
return None
|
||||
if self._charger_reverse_dist is None:
|
||||
self._charger_reverse_dist = self._global_bfs_from_targets(self._charger_target_cells())
|
||||
return self._charger_reverse_dist
|
||||
|
||||
def _charger_target_cells(self):
|
||||
targets = []
|
||||
for rx, rz, w, h in self.charger_rects:
|
||||
for x in range(rx, rx + w):
|
||||
for z in range(rz, rz + h):
|
||||
if self._is_known_passable(x, z):
|
||||
targets.append((x, z))
|
||||
return targets
|
||||
|
||||
def _global_bfs_from_targets(self, targets):
|
||||
"""Reverse BFS over the accumulated known passable map."""
|
||||
dist = np.full((self.GRID_SIZE, self.GRID_SIZE), self.INF_DIST, dtype=np.float32)
|
||||
queue = deque()
|
||||
for target in targets:
|
||||
if len(target) < 2:
|
||||
continue
|
||||
x = int(target[0])
|
||||
z = int(target[1])
|
||||
if not self._is_known_passable(x, z) or dist[x, z] == 0.0:
|
||||
continue
|
||||
dist[x, z] = 0.0
|
||||
queue.append((x, z))
|
||||
|
||||
while queue:
|
||||
x, z = queue.popleft()
|
||||
base = dist[x, z]
|
||||
for dx, dz in self.ACTION_DIRS:
|
||||
nx, nz = x + dx, z + dz
|
||||
if not self._can_global_move(x, z, dx, dz):
|
||||
continue
|
||||
if dist[nx, nz] < self.INF_DIST:
|
||||
continue
|
||||
dist[nx, nz] = base + 1.0
|
||||
queue.append((nx, nz))
|
||||
return dist
|
||||
|
||||
def _is_known_passable(self, x, z):
|
||||
return 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE and self.passable_map[x, z] > 0
|
||||
|
||||
def _can_global_move(self, x, z, dx, dz):
|
||||
nx, nz = x + dx, z + dz
|
||||
if not self._is_known_passable(x, z) or not self._is_known_passable(nx, nz):
|
||||
return False
|
||||
if dx != 0 and dz != 0:
|
||||
return self._is_known_passable(x + dx, z) or self._is_known_passable(x, z + dz)
|
||||
return True
|
||||
|
||||
def _dist_at(self, dist, x, z, default=None):
|
||||
if default is None:
|
||||
default = self.INF_DIST
|
||||
if dist is None or not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
|
||||
return float(default)
|
||||
value = float(dist[x, z])
|
||||
return value if value < self.INF_DIST else float(default)
|
||||
|
||||
def _action_distance_delta(self, dist, current_dist):
|
||||
delta = np.zeros(8, dtype=np.float32)
|
||||
if dist is None or current_dist >= self.INF_DIST:
|
||||
return delta
|
||||
hx, hz = self.cur_pos
|
||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||
nx, nz = hx + dx, hz + dz
|
||||
if not self._can_global_move(hx, hz, dx, dz):
|
||||
continue
|
||||
next_dist = self._dist_at(dist, nx, nz, default=self.INF_DIST)
|
||||
if next_dist >= self.INF_DIST:
|
||||
continue
|
||||
delta[action] = np.float32(np.clip((current_dist - next_dist) / 4.0, -1.0, 1.0))
|
||||
return delta
|
||||
|
||||
def _update_charger_state(self, hx, hz, organs):
|
||||
"""Find nearest charger and cache distance/direction features."""
|
||||
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
||||
@@ -302,6 +512,7 @@ class Preprocessor:
|
||||
self.charger_safety_buffer = 0.0
|
||||
self.charger_safety_margin = 0.0
|
||||
self.charger_rects = []
|
||||
self.charger_route_known = False
|
||||
|
||||
best = None
|
||||
for organ in organs:
|
||||
@@ -338,7 +549,10 @@ class Preprocessor:
|
||||
self.nearest_charger_center_dz = float(center_dz)
|
||||
self.nearest_charger_dist = float(dist)
|
||||
self.nearest_charger_range_dist = float(range_dist)
|
||||
path_dist = self._local_path_dist_to_charger(hx, hz)
|
||||
path_dist = self._global_path_dist_to_charger(hx, hz)
|
||||
self.charger_route_known = path_dist < self.INF_DIST
|
||||
if not self.charger_route_known:
|
||||
path_dist = self._local_path_dist_to_charger(hx, hz)
|
||||
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
|
||||
self.charger_energy_cost = self.nearest_charger_path_dist
|
||||
self.on_charger = range_dist <= 0.0
|
||||
@@ -365,30 +579,54 @@ class Preprocessor:
|
||||
"""Find nearest NPC and cache safety features."""
|
||||
self.nearest_npc_dx = 0.0
|
||||
self.nearest_npc_dz = 0.0
|
||||
self.nearest_npc_vx = 0.0
|
||||
self.nearest_npc_vz = 0.0
|
||||
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||||
self.predicted_npc_dist = float(self.GRID_SIZE)
|
||||
self.npc_danger = False
|
||||
self.npc_predicted_danger = False
|
||||
self.predicted_npcs = []
|
||||
|
||||
best = None
|
||||
current_positions = {}
|
||||
for npc in npcs:
|
||||
if not isinstance(npc, dict):
|
||||
continue
|
||||
pos = npc.get("pos") or {}
|
||||
nx = int(pos.get("x", 0))
|
||||
nz = int(pos.get("z", 0))
|
||||
npc_key = str(npc.get("npc_id", npc.get("idx", len(current_positions))))
|
||||
prev_pos = self.prev_npc_positions.get(npc_key)
|
||||
vx = 0
|
||||
vz = 0
|
||||
if prev_pos is not None:
|
||||
vx = int(np.clip(nx - prev_pos[0], -1, 1))
|
||||
vz = int(np.clip(nz - prev_pos[1], -1, 1))
|
||||
px = int(np.clip(nx + vx, 0, self.GRID_SIZE - 1))
|
||||
pz = int(np.clip(nz + vz, 0, self.GRID_SIZE - 1))
|
||||
current_positions[npc_key] = (nx, nz)
|
||||
self.predicted_npcs.append((px, pz, 1))
|
||||
|
||||
dx = nx - hx
|
||||
dz = nz - hz
|
||||
cheb = float(max(abs(dx), abs(dz)))
|
||||
pred_cheb = float(max(abs(px - hx), abs(pz - hz)))
|
||||
if best is None or cheb < best[0]:
|
||||
best = (cheb, dx, dz)
|
||||
best = (cheb, dx, dz, vx, vz, pred_cheb)
|
||||
|
||||
self.prev_npc_positions = current_positions
|
||||
if best is None:
|
||||
return
|
||||
|
||||
cheb, dx, dz = best
|
||||
cheb, dx, dz, vx, vz, pred_cheb = best
|
||||
self.nearest_npc_dx = float(dx)
|
||||
self.nearest_npc_dz = float(dz)
|
||||
self.nearest_npc_vx = float(vx)
|
||||
self.nearest_npc_vz = float(vz)
|
||||
self.nearest_npc_dist = float(cheb)
|
||||
self.predicted_npc_dist = float(pred_cheb)
|
||||
self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1
|
||||
self.npc_predicted_danger = pred_cheb <= 1
|
||||
|
||||
def _update_recharge_mode(self):
|
||||
"""Enter/exit low-battery recharge mode."""
|
||||
@@ -399,7 +637,7 @@ class Preprocessor:
|
||||
self.charger_safety_margin = float(self.battery)
|
||||
self.recharge_enter_margin = 0.0
|
||||
self.recharge_leave_margin = 0.0
|
||||
self.recharge_low_battery_ratio = 0.35
|
||||
self.recharge_low_battery_ratio = 0.28
|
||||
self.full_charge_leave_ratio = 0.96
|
||||
self.low_battery = battery_ratio < self.recharge_low_battery_ratio
|
||||
return
|
||||
@@ -457,15 +695,15 @@ class Preprocessor:
|
||||
)
|
||||
self.recharge_no_progress_steps = self.recharge_no_progress_steps + 1 if no_progress else 0
|
||||
|
||||
if self.step_no > 0 and self.nearest_npc_dist <= 3:
|
||||
if self.step_no > 0 and min(self.nearest_npc_dist, self.predicted_npc_dist) <= 3:
|
||||
self.npc_close_steps += 1
|
||||
if self.step_no > 0 and self.npc_danger:
|
||||
if self.step_no > 0 and (self.npc_danger or self.npc_predicted_danger):
|
||||
self.npc_danger_steps += 1
|
||||
|
||||
if self.terminated and not self.truncated:
|
||||
if self.battery <= 0 or self.remaining_charge <= 0:
|
||||
self.battery_fail = 1
|
||||
if self.npc_danger or self.nearest_npc_dist <= 1:
|
||||
if self.npc_danger or self.npc_predicted_danger or self.nearest_npc_dist <= 1:
|
||||
self.npc_collision = 1
|
||||
|
||||
def _need_recharge_escape(self):
|
||||
@@ -473,41 +711,41 @@ class Preprocessor:
|
||||
|
||||
def _charger_safety_buffer(self):
|
||||
# One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise.
|
||||
base = max(24.0, 0.16 * float(self.battery_max))
|
||||
distance_buffer = min(24.0, 0.25 * float(max(self.nearest_charger_range_dist, 0.0)))
|
||||
obstacle_buffer = 18.0 * float(self.local_obstacle_ratio)
|
||||
return float(np.clip(base + distance_buffer + obstacle_buffer, 24.0, 64.0))
|
||||
base = max(18.0, 0.12 * float(self.battery_max))
|
||||
distance_buffer = min(16.0, 0.18 * float(max(self.nearest_charger_range_dist, 0.0)))
|
||||
obstacle_buffer = 12.0 * float(self.local_obstacle_ratio)
|
||||
return float(np.clip(base + distance_buffer + obstacle_buffer, 18.0, 48.0))
|
||||
|
||||
def _recharge_enter_margin(self):
|
||||
"""Adaptive margin for entering recharge mode before the battery is barely enough."""
|
||||
base = max(8.0, 0.025 * float(self.battery_max))
|
||||
path_margin = min(18.0, 0.12 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||||
obstacle_margin = 20.0 * float(self.local_obstacle_ratio)
|
||||
recovery_margin = min(10.0, 2.0 * float(self.recharge_no_progress_steps + self.fake_charger_steps))
|
||||
return float(np.clip(base + path_margin + obstacle_margin + recovery_margin, 8.0, 48.0))
|
||||
base = max(5.0, 0.018 * float(self.battery_max))
|
||||
path_margin = min(12.0, 0.08 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||||
obstacle_margin = 12.0 * float(self.local_obstacle_ratio)
|
||||
recovery_margin = min(8.0, 1.5 * float(self.recharge_no_progress_steps + self.fake_charger_steps))
|
||||
return float(np.clip(base + path_margin + obstacle_margin + recovery_margin, 4.0, 32.0))
|
||||
|
||||
def _recharge_leave_margin(self):
|
||||
"""Adaptive safety margin required before leaving a charger."""
|
||||
base = max(28.0, 0.10 * float(self.battery_max))
|
||||
path_margin = min(24.0, 0.18 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||||
obstacle_margin = 16.0 * float(self.local_obstacle_ratio)
|
||||
return float(np.clip(base + path_margin + obstacle_margin, 28.0, 88.0))
|
||||
base = max(20.0, 0.08 * float(self.battery_max))
|
||||
path_margin = min(18.0, 0.14 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||||
obstacle_margin = 12.0 * float(self.local_obstacle_ratio)
|
||||
return float(np.clip(base + path_margin + obstacle_margin, 20.0, 64.0))
|
||||
|
||||
def _recharge_low_battery_ratio(self):
|
||||
"""Adaptive low-battery ratio based on route length and local obstacle density."""
|
||||
path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0)
|
||||
ratio = 0.32 + min(0.10, 0.55 * path_pressure) + min(0.06, 0.20 * float(self.local_obstacle_ratio))
|
||||
ratio = 0.25 + min(0.08, 0.40 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio))
|
||||
if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0:
|
||||
ratio += 0.03
|
||||
return float(np.clip(ratio, 0.32, 0.48))
|
||||
ratio += 0.02
|
||||
return float(np.clip(ratio, 0.25, 0.40))
|
||||
|
||||
def _full_charge_leave_ratio(self):
|
||||
"""Adaptive near-full threshold for leaving a charger."""
|
||||
remaining_step_ratio = 1.0 - _norm(self.step_no, self.max_step)
|
||||
path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0)
|
||||
ratio = 0.94 + 0.03 * remaining_step_ratio + min(0.02, 0.10 * path_pressure)
|
||||
ratio += min(0.01, 0.05 * float(self.local_obstacle_ratio))
|
||||
return float(np.clip(ratio, 0.94, 0.985))
|
||||
ratio = 0.88 + 0.04 * remaining_step_ratio + min(0.02, 0.08 * path_pressure)
|
||||
ratio += min(0.01, 0.04 * float(self.local_obstacle_ratio))
|
||||
return float(np.clip(ratio, 0.88, 0.95))
|
||||
|
||||
def _recharge_risk_score(self):
|
||||
"""Risk score in [0, 1] used to scale recharge rewards and penalties."""
|
||||
@@ -527,8 +765,8 @@ class Preprocessor:
|
||||
prev_low_risk = max(0.0, self.recharge_low_battery_ratio - prev_battery_ratio)
|
||||
prev_low_risk /= max(self.recharge_low_battery_ratio, 1e-6)
|
||||
risk = max(self._recharge_risk_score(), prev_low_risk)
|
||||
mode_bonus = 0.4 if self.was_recharge_mode or self.prev_low_battery else 0.0
|
||||
return float(np.clip(2.0 + 1.8 * risk + mode_bonus, 2.0, 4.2))
|
||||
mode_bonus = 0.8 if self.was_recharge_mode or self.prev_low_battery else 0.0
|
||||
return float(np.clip(3.0 + 2.8 * risk + mode_bonus, 3.0, 6.5))
|
||||
|
||||
def battery_fail_penalty(self):
|
||||
"""Adaptive terminal penalty for running out of battery before max steps."""
|
||||
@@ -536,7 +774,7 @@ class Preprocessor:
|
||||
early_fail_risk = 1.0 - step_ratio
|
||||
path_pressure = float(max(self.charger_energy_cost, 0.0)) / max(float(self.battery_max), 1.0)
|
||||
risk = max(self._recharge_risk_score(), min(1.0, path_pressure))
|
||||
return float(np.clip(5.5 + 2.5 * early_fail_risk + 1.0 * risk, 5.5, 9.0))
|
||||
return float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0))
|
||||
|
||||
def _min_charger_range_dist(self, x, z):
|
||||
if not self.charger_rects:
|
||||
@@ -547,50 +785,42 @@ class Preprocessor:
|
||||
dists.append(max(abs(dx), abs(dz)))
|
||||
return float(min(dists))
|
||||
|
||||
def _get_local_view_feature(self):
|
||||
"""Local view feature (121D): crop center 11×11 from 21×21.
|
||||
def _is_charger_cell(self, x, z):
|
||||
for rx, rz, w, h in self.charger_rects:
|
||||
if rx <= x < rx + w and rz <= z < rz + h:
|
||||
return True
|
||||
return False
|
||||
|
||||
局部视野特征(121D):从 21×21 视野中心裁剪 11×11。
|
||||
def _get_local_view_feature(self):
|
||||
"""Local view feature: 21×21×6 multi-channel map.
|
||||
|
||||
Channels: obstacle, clean, dirt, visit count, NPC danger, charger.
|
||||
"""
|
||||
center = self.VIEW_HALF
|
||||
h = self.LOCAL_HALF
|
||||
crop = self._view_map[center - h : center + h + 1, center - h : center + h + 1]
|
||||
return (crop / 2.0).flatten()
|
||||
view = self._view_map
|
||||
channels = np.zeros((self.MAP_CHANNELS, self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
|
||||
if view is None or view.shape[0] != self.VIEW_SIZE or view.shape[1] != self.VIEW_SIZE:
|
||||
return channels.flatten()
|
||||
|
||||
channels[0] = (view == 0).astype(np.float32)
|
||||
channels[1] = (view == 1).astype(np.float32)
|
||||
channels[2] = (view == 2).astype(np.float32)
|
||||
|
||||
for ri in range(self.VIEW_SIZE):
|
||||
for ci in range(self.VIEW_SIZE):
|
||||
gx, gz = self._view_index_to_global(ri, ci)
|
||||
if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
|
||||
continue
|
||||
channels[3, ri, ci] = _norm(min(int(self.visit_count_map[gx, gz]), 10), 10)
|
||||
channels[4, ri, ci] = 1.0 if self._is_npc_danger_cell(gx, gz, expanded=True) else 0.0
|
||||
channels[5, ri, ci] = 1.0 if self._is_charger_cell(gx, gz) else 0.0
|
||||
|
||||
return channels.flatten()
|
||||
|
||||
def _get_global_state_feature(self):
|
||||
"""Global state feature (28D).
|
||||
"""Global state feature (66D).
|
||||
|
||||
全局状态特征(28D)。
|
||||
|
||||
Dimensions / 维度说明:
|
||||
[0] step_norm step progress / 步数归一化 [0,1]
|
||||
[1] battery_ratio battery level / 电量比 [0,1]
|
||||
[2] cleaning_progress cleaned ratio / 已清扫比例 [0,1]
|
||||
[3] remaining_dirt remaining dirt ratio / 剩余污渍比例 [0,1]
|
||||
[4] pos_x_weak weak x position / 弱化后的 x 坐标 [0.4,0.6]
|
||||
[5] pos_z_weak weak z position / 弱化后的 z 坐标 [0.4,0.6]
|
||||
[6] ray_N_dirt north ray distance / 向上(z-)方向最近污渍距离
|
||||
[7] ray_E_dirt east ray distance / 向右(x+)方向
|
||||
[8] ray_S_dirt south ray distance / 向下(z+)方向
|
||||
[9] ray_W_dirt west ray distance / 向左(x-)方向
|
||||
[10] nearest_dirt_norm nearest dirt Euclidean distance / 最近污渍欧氏距离归一化
|
||||
[11] dirt_delta approaching dirt indicator / 是否在接近污渍(1=是, 0=否)
|
||||
[12] charger_dx nearest charger x direction / 最近充电桩 x 相对方向
|
||||
[13] charger_dz nearest charger z direction / 最近充电桩 z 相对方向
|
||||
[14] charger_dist nearest charger distance / 最近充电桩距离
|
||||
[15] battery_margin battery minus charger distance / 电量安全余量
|
||||
[16] low_battery low-battery flag / 低电量标记
|
||||
[17] recharge_mode recharge-mode flag / 回充模式标记
|
||||
[18] on_charger on charger flag / 是否在充电桩范围
|
||||
[19] charge_delta charge count increased / 本步是否成功充电
|
||||
[20] npc_dx nearest NPC x direction / 最近 NPC x 相对方向
|
||||
[21] npc_dz nearest NPC z direction / 最近 NPC z 相对方向
|
||||
[22] npc_dist nearest NPC Chebyshev distance / 最近 NPC 切比雪夫距离
|
||||
[23] npc_danger in NPC 3x3 danger zone / 是否处于 NPC 3x3 危险区
|
||||
[24] local_dirt_ratio dirt ratio in 21x21 view / 21x21 视野污渍比例
|
||||
[25] obstacle_ratio obstacle ratio in 21x21 view / 21x21 视野障碍比例
|
||||
[26] visit_count current cell visit count / 当前格访问次数
|
||||
[27] step_cleaned cells cleaned this step / 本步清扫格子数
|
||||
Existing global state plus belief-map distances, action distance improvements,
|
||||
known charger-route safety, and predicted NPC motion.
|
||||
"""
|
||||
step_norm = _norm(self.step_no, self.max_step)
|
||||
battery_ratio = _norm(self.battery, self.battery_max)
|
||||
@@ -630,8 +860,15 @@ class Preprocessor:
|
||||
battery_margin_norm = _signed_norm(self.battery_margin, self.battery_max)
|
||||
visit_count_norm = _norm(min(self.current_visit_count, 10), 10)
|
||||
step_cleaned_norm = _norm(self.step_cleaned_count, 9)
|
||||
global_dirty_delta = _signed_norm(
|
||||
np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -4.0, 4.0), 4.0
|
||||
)
|
||||
frontier_delta = _signed_norm(
|
||||
np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -4.0, 4.0), 4.0
|
||||
)
|
||||
charger_margin_after_buffer = self.battery - self.nearest_charger_path_dist - self.charger_safety_buffer
|
||||
|
||||
return np.array(
|
||||
base_features = np.array(
|
||||
[
|
||||
step_norm,
|
||||
battery_ratio,
|
||||
@@ -661,10 +898,33 @@ class Preprocessor:
|
||||
self.local_obstacle_ratio,
|
||||
visit_count_norm,
|
||||
step_cleaned_norm,
|
||||
_norm(self.global_dirty_path_dist, self.GRID_SIZE),
|
||||
_norm(self.frontier_path_dist, self.GRID_SIZE),
|
||||
global_dirty_delta,
|
||||
frontier_delta,
|
||||
self.known_ratio,
|
||||
self.known_passable_ratio,
|
||||
_norm(self.known_dirty_ratio, 1.0),
|
||||
_norm(self.frontier_ratio, 1.0),
|
||||
1.0 if self.charger_route_known else 0.0,
|
||||
_signed_norm(charger_margin_after_buffer, self.battery_max),
|
||||
_signed_norm(self.nearest_npc_vx, 1.0),
|
||||
_signed_norm(self.nearest_npc_vz, 1.0),
|
||||
_norm(self.predicted_npc_dist, 20),
|
||||
1.0 if self.npc_predicted_danger else 0.0,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return np.concatenate(
|
||||
[
|
||||
base_features,
|
||||
self.global_dirty_action_delta.astype(np.float32),
|
||||
self.frontier_action_delta.astype(np.float32),
|
||||
self.charger_action_delta.astype(np.float32),
|
||||
]
|
||||
)
|
||||
|
||||
def _weak_abs_position_feature(self, value):
|
||||
pos_norm = _norm(value, self.GRID_SIZE)
|
||||
return 0.5 + self.ABS_POS_FEATURE_SCALE * (pos_norm - 0.5)
|
||||
@@ -731,8 +991,16 @@ class Preprocessor:
|
||||
best = min(best, float(dist[ri, ci]))
|
||||
return best
|
||||
|
||||
def _global_path_dist_to_charger(self, gx, gz):
|
||||
"""Known-map BFS distance from a global cell to the nearest observed charger cell."""
|
||||
dist = self._get_charger_reverse_dist()
|
||||
return self._dist_at(dist, gx, gz, default=self.INF_DIST)
|
||||
|
||||
def _charger_move_distance(self, gx, gz):
|
||||
"""Use visible BFS to the charger when available, otherwise Chebyshev distance."""
|
||||
"""Use known-map BFS to the charger when available, then visible BFS, then Chebyshev."""
|
||||
path_dist = self._global_path_dist_to_charger(gx, gz)
|
||||
if path_dist < self.INF_DIST:
|
||||
return path_dist
|
||||
path_dist = self._local_path_dist_to_charger(gx, gz)
|
||||
if path_dist < self.INF_DIST:
|
||||
return path_dist
|
||||
@@ -780,7 +1048,7 @@ class Preprocessor:
|
||||
return True if cell is None else cell != 0
|
||||
|
||||
def _filter_npc_danger_actions(self, legal_action):
|
||||
"""Avoid actions that would enter any NPC 3x3 danger zone."""
|
||||
"""Avoid current and predicted NPC danger zones."""
|
||||
if not self.npcs:
|
||||
return list(legal_action)
|
||||
|
||||
@@ -790,12 +1058,22 @@ class Preprocessor:
|
||||
if safe[action] <= 0:
|
||||
continue
|
||||
nx, nz = hx + dx, hz + dz
|
||||
if self._is_npc_danger_cell(nx, nz):
|
||||
if self._is_npc_danger_cell(nx, nz, expanded=True):
|
||||
safe[action] = 0
|
||||
|
||||
return safe if any(safe) else list(legal_action)
|
||||
if any(safe):
|
||||
return safe
|
||||
|
||||
def _is_npc_danger_cell(self, x, z):
|
||||
hard_safe = [int(x) for x in legal_action]
|
||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||
if hard_safe[action] <= 0:
|
||||
continue
|
||||
nx, nz = hx + dx, hz + dz
|
||||
if self._is_npc_danger_cell(nx, nz, expanded=False):
|
||||
hard_safe[action] = 0
|
||||
return hard_safe if any(hard_safe) else list(legal_action)
|
||||
|
||||
def _is_npc_danger_cell(self, x, z, expanded=True):
|
||||
for npc in self.npcs:
|
||||
if not isinstance(npc, dict):
|
||||
continue
|
||||
@@ -804,6 +1082,14 @@ class Preprocessor:
|
||||
nz = int(pos.get("z", -999))
|
||||
if abs(x - nx) <= 1 and abs(z - nz) <= 1:
|
||||
return True
|
||||
if expanded and abs(x - nx) <= 2 and abs(z - nz) <= 2 and self.nearest_npc_dist <= 4:
|
||||
return True
|
||||
if expanded:
|
||||
for px, pz, radius in self.predicted_npcs:
|
||||
if abs(x - px) <= radius and abs(z - pz) <= radius:
|
||||
return True
|
||||
if self.nearest_npc_dist <= 4 and abs(x - px) <= 2 and abs(z - pz) <= 2:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _filter_recharge_actions(self, legal_action):
|
||||
@@ -927,8 +1213,8 @@ class Preprocessor:
|
||||
"""
|
||||
self.pb2struct(env_obs, last_action)
|
||||
|
||||
local_view = self._get_local_view_feature() # 121D
|
||||
global_state = self._get_global_state_feature() # 28D
|
||||
local_view = self._get_local_view_feature() # 2646D
|
||||
global_state = self._get_global_state_feature() # 66D
|
||||
legal_action = self.get_legal_action() # 8D
|
||||
|
||||
last_action_feature = np.zeros(8, dtype=np.float32)
|
||||
@@ -969,8 +1255,8 @@ class Preprocessor:
|
||||
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
|
||||
)
|
||||
recharge_risk = self._recharge_risk_score()
|
||||
approach_scale = 0.04 + 0.04 * recharge_risk
|
||||
retreat_scale = 0.02 + 0.03 * recharge_risk
|
||||
approach_scale = 0.07 + 0.06 * recharge_risk
|
||||
retreat_scale = 0.035 + 0.045 * recharge_risk
|
||||
charge_reward += approach_scale * dist_delta if dist_delta > 0 else retreat_scale * dist_delta
|
||||
if self.charger_safety_margin < self.recharge_enter_margin:
|
||||
safety_shortage = self.recharge_enter_margin - self.charger_safety_margin
|
||||
@@ -984,6 +1270,12 @@ class Preprocessor:
|
||||
exploration_reward = 0.0
|
||||
else:
|
||||
exploration_reward = 0.004 if self.is_new_cell else -0.0015 * min(self.current_visit_count, 6)
|
||||
if self.global_dirty_path_dist < self.GRID_SIZE:
|
||||
dirty_progress = np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -3.0, 3.0)
|
||||
exploration_reward += 0.008 * dirty_progress
|
||||
elif self.frontier_path_dist < self.GRID_SIZE:
|
||||
frontier_progress = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
|
||||
exploration_reward += 0.005 * frontier_progress
|
||||
|
||||
# Collision/stuck signal: invalid moves waste both step and battery.
|
||||
# 撞墙/原地不动会浪费步数和电量。
|
||||
@@ -996,22 +1288,16 @@ class Preprocessor:
|
||||
npc_penalty = 0.0
|
||||
if self.npc_danger:
|
||||
npc_penalty -= 4.0
|
||||
elif self.npc_predicted_danger:
|
||||
npc_penalty -= 0.4
|
||||
elif self.nearest_npc_dist <= 3:
|
||||
npc_penalty -= 0.05 * (4 - self.nearest_npc_dist)
|
||||
|
||||
terminal_penalty = 0.0
|
||||
if self.terminated and not self.truncated:
|
||||
if self.battery <= 0 or self.remaining_charge <= 0:
|
||||
terminal_penalty -= self.battery_fail_penalty()
|
||||
elif self.npc_danger or self.nearest_npc_dist <= 1:
|
||||
terminal_penalty -= 3.0
|
||||
|
||||
return (
|
||||
cleaning_reward
|
||||
+ charge_reward
|
||||
+ exploration_reward
|
||||
+ stuck_penalty
|
||||
+ npc_penalty
|
||||
+ terminal_penalty
|
||||
+ step_penalty
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user