This repository has been archived on 2026-05-02. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
-----/agent_ppo/feature/preprocessor.py

1304 lines
55 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors
Feature preprocessor for Robot Vacuum.
清扫大作战特征预处理器。
"""
from collections import deque
import numpy as np
def _norm(v, v_max, v_min=0.0):
    """Linearly rescale ``v`` from ``[v_min, v_max]`` onto ``[0, 1]``.

    The value is clamped into the range first. A degenerate range
    (``v_max == v_min``) maps to 0.0 instead of dividing by zero.
    """
    clamped = float(np.clip(v, v_min, v_max))
    if v_max == v_min:
        return 0.0
    span = v_max - v_min
    return (clamped - v_min) / span
def _signed_norm(v, v_max):
    """Divide ``v`` by ``v_max`` and clamp the quotient to ``[-1, 1]``.

    A non-positive ``v_max`` yields 0.0.
    """
    if v_max <= 0:
        return 0.0
    scaled = float(v) / float(v_max)
    return float(np.clip(scaled, -1.0, 1.0))
def _as_dict(value):
    """Pass dicts through unchanged; substitute a fresh empty dict for anything else."""
    if isinstance(value, dict):
        return value
    return {}
class Preprocessor:
"""Feature preprocessor for Robot Vacuum.
清扫大作战特征预处理器。
"""
GRID_SIZE = 128
VIEW_HALF = 10 # Full local view radius (21×21) / 完整局部视野半径
VIEW_SIZE = 21
MAP_CHANNELS = 6
PLANNER_UPDATE_INTERVAL = 4
ACTION_DIRS = (
(1, 0),
(1, -1),
(0, -1),
(-1, -1),
(-1, 0),
(-1, 1),
(0, 1),
(1, 1),
)
INF_DIST = 1e6
ABS_POS_FEATURE_SCALE = 0.2
def __init__(self):
    """Create the preprocessor; all state is initialised by ``reset()``."""
    self.reset()
def reset(self):
"""Reset all internal state at episode start.
对局开始时重置所有状态。
"""
self.step_no = 0
self.battery = 600
self.battery_max = 600
self.cur_pos = (0, 0)
self.prev_pos = None
self.has_position_history = False
self.current_visit_count = 0
self.is_new_cell = False
self.last_action = -1
self.stuck_steps = 0
self.recharge_no_progress_steps = 0
self.fake_charger_steps = 0
self.stuck_count = 0
self.max_stuck_steps = 0
self.recharge_escape_count = 0
self.dirt_cleaned = 0
self.last_dirt_cleaned = 0
self.total_dirt = 1
self.total_score = 0
self.clean_score = 0
self.step_cleaned_count = 0
self.max_step = 1000
# Global belief maps indexed by [x, z].
# 全局 belief map索引为 [x, z]。
self.known_map = np.full((self.GRID_SIZE, self.GRID_SIZE), -1, dtype=np.int8)
self.passable_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
self.frontier_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
self.dirty_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
self._dirty_reverse_dist = None
self._frontier_reverse_dist = None
self._charger_reverse_dist = None
self._path_cache_dirty = True
self._planner_last_update_step = -self.PLANNER_UPDATE_INTERVAL
self.known_ratio = 0.0
self.known_passable_ratio = 0.0
self.known_dirty_ratio = 0.0
self.frontier_ratio = 0.0
self.global_dirty_path_dist = float(self.GRID_SIZE)
self.last_global_dirty_path_dist = float(self.GRID_SIZE)
self.frontier_path_dist = float(self.GRID_SIZE)
self.last_frontier_path_dist = float(self.GRID_SIZE)
self.global_dirty_action_delta = np.zeros(8, dtype=np.float32)
self.frontier_action_delta = np.zeros(8, dtype=np.float32)
self.charger_action_delta = np.zeros(8, dtype=np.float32)
self.charger_route_known = False
# Nearest dirt path distance in the current local view.
# 当前局部视野内最近污渍路径距离。
self.nearest_dirt_dist = 200.0
self.last_nearest_dirt_dist = 200.0
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
self._view_map = np.zeros((21, 21), dtype=np.float32)
self._legal_act = [1] * 8
self.terminated = False
self.truncated = False
self.remaining_charge = 0
self.prev_battery = 600
self.prev_battery_max = 600
self.prev_on_charger = False
self.prev_low_battery = False
self.was_recharge_mode = False
self.charge_count = 0
self.last_charge_count = 0
self.charge_delta = 0
self.nearest_charger_dx = 0.0
self.nearest_charger_dz = 0.0
self.nearest_charger_center_dx = 0.0
self.nearest_charger_center_dz = 0.0
self.nearest_charger_dist = float(self.GRID_SIZE)
self.nearest_charger_range_dist = float(self.GRID_SIZE)
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
self.nearest_charger_path_dist = float(self.GRID_SIZE)
self.last_nearest_charger_path_dist = float(self.GRID_SIZE)
self.charger_energy_cost = float(self.GRID_SIZE)
self.charger_safety_buffer = 0.0
self.charger_safety_margin = 0.0
self.recharge_enter_margin = 0.0
self.recharge_leave_margin = 0.0
self.recharge_low_battery_ratio = 0.28
self.full_charge_leave_ratio = 0.96
self.battery_margin = 0.0
self.has_charger = False
self.low_battery = False
self.on_charger = False
self.charger_rects = []
self.recharge_mode = False
self.recharge_steps = 0
self.nearest_npc_dx = 0.0
self.nearest_npc_dz = 0.0
self.nearest_npc_vx = 0.0
self.nearest_npc_vz = 0.0
self.nearest_npc_dist = float(self.GRID_SIZE)
self.predicted_npc_dist = float(self.GRID_SIZE)
self.npc_danger = False
self.npc_predicted_danger = False
self.npcs = []
self.prev_npc_positions = {}
self.predicted_npcs = []
self.npc_close_steps = 0
self.npc_danger_steps = 0
self.npc_collision = 0
self.battery_fail = 0
self.local_dirt_ratio = 0.0
self.local_obstacle_ratio = 0.0
def pb2struct(self, env_obs, last_action):
    """Parse one environment observation and refresh every cached state field.

    Args:
        env_obs: Observation dict. The keys read here are ``observation``
            (with ``frame_state``, ``env_info``, ``step_no``,
            ``legal_action``/``legal_act``, ``map_info``), ``extra_info``,
            ``terminated`` and ``truncated``. Missing or non-dict sections
            are treated as empty.
        last_action: Index of the action taken on the previous step.

    The statement order matters: the ``prev_*`` / ``last_*`` snapshots are
    taken before the corresponding live fields are overwritten, and the
    derived-state updaters at the end consume the fields parsed above them.
    """
    env_obs = _as_dict(env_obs)
    observation = _as_dict(env_obs.get("observation"))
    frame_state = _as_dict(observation.get("frame_state"))
    extra_info = _as_dict(env_obs.get("extra_info"))
    extra_frame_state = _as_dict(extra_info.get("frame_state"))
    env_info = _as_dict(observation.get("env_info"))
    # "heroes" may be a dict or a list; only the first list entry is used.
    hero = frame_state.get("heroes") or {}
    if isinstance(hero, list):
        hero = hero[0] if hero else {}
    hero = _as_dict(hero)
    self.last_action = int(last_action)
    self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
    self.terminated = bool(env_obs.get("terminated", False))
    self.truncated = bool(env_obs.get("truncated", False))
    # Snapshot last step's values before they are overwritten below.
    self.prev_battery = self.battery
    self.prev_battery_max = self.battery_max
    self.prev_on_charger = self.on_charger
    self.prev_low_battery = self.low_battery
    self.was_recharge_mode = self.recharge_mode
    # No previous position exists on the first call after reset().
    self.prev_pos = self.cur_pos if self.has_position_history else None
    hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]})
    self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
    self.has_position_history = True
    hx, hz = self.cur_pos
    if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
        # Read the visit count before incrementing so a first visit reads 0.
        self.current_visit_count = int(self.visit_count_map[hx, hz])
        self.is_new_cell = self.current_visit_count == 0
        # Saturate instead of wrapping the uint16 counter.
        self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max)
    else:
        self.current_visit_count = 0
        self.is_new_cell = False
    # Battery
    self.battery = int(hero.get("battery", env_info.get("remaining_charge", self.battery)))
    self.battery_max = max(int(hero.get("battery_max", env_info.get("battery_max", self.battery_max))), 1)
    self.remaining_charge = int(env_info.get("remaining_charge", self.battery))
    self.max_step = max(int(env_info.get("max_step", self.max_step)), 1)
    # Cleaning progress
    self.last_dirt_cleaned = self.dirt_cleaned
    self.dirt_cleaned = int(hero.get("dirt_cleaned", env_info.get("clean_score", self.dirt_cleaned)))
    self.total_dirt = max(int(env_info.get("total_dirt", self.total_dirt)), 1)
    self.total_score = int(env_info.get("total_score", self.total_score))
    self.clean_score = int(env_info.get("clean_score", self.dirt_cleaned))
    step_cleaned_cells = env_info.get("step_cleaned_cells") or []
    self.step_cleaned_count = len(step_cleaned_cells)
    # Charge progress
    self.last_charge_count = self.charge_count
    self.charge_count = int(env_info.get("charge_count", self.charge_count))
    self.charge_delta = max(0, self.charge_count - self.last_charge_count)
    # Legal actions: truncate or pad (with 1) to exactly 8 entries.
    raw_legal_act = observation.get("legal_action") or observation.get("legal_act") or [1] * 8
    self._legal_act = [int(x) for x in raw_legal_act[:8]]
    if len(self._legal_act) < 8:
        self._legal_act.extend([1] * (8 - len(self._legal_act)))
    # Local view map (21×21)
    # NOTE(review): the nesting of the next two statements under this `if`
    # is reconstructed; the viewed source has lost its indentation — confirm.
    map_info = observation.get("map_info")
    if map_info is not None:
        self._view_map = np.array(map_info, dtype=np.float32)
        hx, hz = self.cur_pos
        self._update_passable(hx, hz)
    self._mark_cleaned_cells(step_cleaned_cells)
    self._update_local_map_stats()
    # Organs and NPCs fall back to extra_info's frame_state when absent.
    organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
    npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
    organs = organs if isinstance(organs, (list, tuple)) else []
    self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else []
    # Derived state, refreshed from the fields parsed above.
    self._update_charger_state(hx, hz, organs)
    self._update_npc_state(hx, hz, self.npcs)
    self._update_global_planning_state()
    self._update_recharge_mode()
    self._update_motion_health()
def _update_passable(self, hx, hz):
"""Write local view into global passable map.
将局部视野写入全局通行地图。
"""
view = self._view_map
vsize = view.shape[0]
half = vsize // 2
for ri in range(vsize):
for ci in range(vsize):
gx = hx + ci - half
gz = hz + ri - half
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
cell = int(view[ri, ci])
self.known_map[gx, gz] = cell
self.passable_map[gx, gz] = 1 if cell != 0 else 0
self.dirty_map[gx, gz] = 1 if cell == 2 else 0
if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
self.known_map[hx, hz] = 1
self.passable_map[hx, hz] = 1
self.dirty_map[hx, hz] = 0
self._clear_path_caches()
def _mark_cleaned_cells(self, step_cleaned_cells):
    """Record this step's cleaned cells as known, passable and no longer dirty.

    Entries that are not dicts, or whose coordinates fall outside the grid,
    are ignored.
    """
    limit = self.GRID_SIZE
    for entry in step_cleaned_cells or []:
        entry = _as_dict(entry)
        cx = int(entry.get("x", -1))
        cz = int(entry.get("z", -1))
        if not (0 <= cx < limit and 0 <= cz < limit):
            continue
        self.known_map[cx, cz] = 1
        self.passable_map[cx, cz] = 1
        self.dirty_map[cx, cz] = 0
        self._clear_path_caches()
def _clear_path_caches(self):
    """Flag the cached path-distance fields as stale.

    Despite the name this only sets the dirty flag; the cached arrays are
    actually released by ``_drop_path_caches``.
    """
    self._path_cache_dirty = True
def _drop_path_caches(self):
self._dirty_reverse_dist = None
self._frontier_reverse_dist = None
self._charger_reverse_dist = None
def _view_index_to_global(self, ri, ci):
"""Convert local view row/col to global x/z coordinates."""
half = self._view_map.shape[0] // 2
hx, hz = self.cur_pos
return hx + ci - half, hz + ri - half
def _view_delta_to_index(self, dx, dz):
"""Convert global-coordinate dx/dz to local view row/col."""
return self.VIEW_HALF + dz, self.VIEW_HALF + dx
def _view_cell(self, dx, dz, default=None):
"""Read a local-view cell by global-coordinate delta.
`map_info` is row-major: row follows z/down, col follows x/right.
"""
ri, ci = self._view_delta_to_index(dx, dz)
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
return default
return int(self._view_map[ri, ci])
def _update_local_map_stats(self):
"""Cache coarse 21x21 map statistics."""
view = self._view_map
if view is None or view.size == 0:
self.local_dirt_ratio = 0.0
self.local_obstacle_ratio = 0.0
return
total = float(view.size)
self.local_dirt_ratio = float(np.sum(view == 2) / total)
self.local_obstacle_ratio = float(np.sum(view == 0) / total)
def _update_global_planning_state(self):
"""Refresh global coverage, frontier, and action-improvement features."""
self.last_global_dirty_path_dist = self.global_dirty_path_dist
self.last_frontier_path_dist = self.frontier_path_dist
self._update_frontier_map()
hx, hz = self.cur_pos
should_refresh_paths = (
self._dirty_reverse_dist is None
or self._frontier_reverse_dist is None
or (self.has_charger and self._charger_reverse_dist is None)
or (
self._path_cache_dirty
and self.step_no - self._planner_last_update_step >= self.PLANNER_UPDATE_INTERVAL
)
)
if should_refresh_paths:
self._drop_path_caches()
self._planner_last_update_step = self.step_no
self._path_cache_dirty = False
known_count = float(np.sum(self.known_map >= 0))
passable_count = float(np.sum(self.passable_map > 0))
dirty_count = float(np.sum(self.dirty_map > 0))
frontier_count = float(np.sum(self.frontier_map > 0))
total_cells = float(self.GRID_SIZE * self.GRID_SIZE)
self.known_ratio = known_count / total_cells
self.known_passable_ratio = passable_count / total_cells
self.known_dirty_ratio = dirty_count / max(float(self.total_dirt), 1.0)
self.frontier_ratio = frontier_count / max(passable_count, 1.0)
dirty_dist = self._get_dirty_reverse_dist()
frontier_dist = self._get_frontier_reverse_dist()
charger_dist = self._get_charger_reverse_dist()
self.global_dirty_path_dist = self._dist_at(dirty_dist, hx, hz, default=float(self.GRID_SIZE))
self.frontier_path_dist = self._dist_at(frontier_dist, hx, hz, default=float(self.GRID_SIZE))
if charger_dist is not None:
charger_path = self._dist_at(charger_dist, hx, hz, default=self.INF_DIST)
if charger_path < self.INF_DIST:
self.nearest_charger_path_dist = min(self.nearest_charger_path_dist, float(charger_path))
self.charger_energy_cost = self.nearest_charger_path_dist
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
self.charger_route_known = True
self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist)
self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist)
current_charger = self._dist_at(charger_dist, hx, hz, default=self.nearest_charger_path_dist)
self.charger_action_delta = self._action_distance_delta(charger_dist, current_charger)
def _update_frontier_map(self):
"""Mark known passable cells adjacent to unseen space as exploration frontiers."""
self.frontier_map.fill(0)
passable_coords = np.argwhere(self.passable_map > 0)
for x, z in passable_coords:
x = int(x)
z = int(z)
for dx, dz in ((1, 0), (-1, 0), (0, 1), (0, -1)):
nx, nz = x + dx, z + dz
if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE and self.known_map[nx, nz] < 0:
self.frontier_map[x, z] = 1
break
def _get_dirty_reverse_dist(self):
if self._dirty_reverse_dist is None:
targets = np.argwhere((self.dirty_map > 0) & (self.passable_map > 0))
self._dirty_reverse_dist = self._global_bfs_from_targets(targets)
return self._dirty_reverse_dist
def _get_frontier_reverse_dist(self):
if self._frontier_reverse_dist is None:
targets = np.argwhere((self.frontier_map > 0) & (self.passable_map > 0))
self._frontier_reverse_dist = self._global_bfs_from_targets(targets)
return self._frontier_reverse_dist
def _get_charger_reverse_dist(self):
if not self.charger_rects:
return None
if self._charger_reverse_dist is None:
self._charger_reverse_dist = self._global_bfs_from_targets(self._charger_target_cells())
return self._charger_reverse_dist
def _charger_target_cells(self):
targets = []
for rx, rz, w, h in self.charger_rects:
for x in range(rx, rx + w):
for z in range(rz, rz + h):
if self._is_known_passable(x, z):
targets.append((x, z))
return targets
def _global_bfs_from_targets(self, targets):
"""Reverse BFS over the accumulated known passable map."""
dist = np.full((self.GRID_SIZE, self.GRID_SIZE), self.INF_DIST, dtype=np.float32)
queue = deque()
for target in targets:
if len(target) < 2:
continue
x = int(target[0])
z = int(target[1])
if not self._is_known_passable(x, z) or dist[x, z] == 0.0:
continue
dist[x, z] = 0.0
queue.append((x, z))
while queue:
x, z = queue.popleft()
base = dist[x, z]
for dx, dz in self.ACTION_DIRS:
nx, nz = x + dx, z + dz
if not self._can_global_move(x, z, dx, dz):
continue
if dist[nx, nz] < self.INF_DIST:
continue
dist[nx, nz] = base + 1.0
queue.append((nx, nz))
return dist
def _is_known_passable(self, x, z):
return 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE and self.passable_map[x, z] > 0
def _can_global_move(self, x, z, dx, dz):
nx, nz = x + dx, z + dz
if not self._is_known_passable(x, z) or not self._is_known_passable(nx, nz):
return False
if dx != 0 and dz != 0:
return self._is_known_passable(x + dx, z) or self._is_known_passable(x, z + dz)
return True
def _dist_at(self, dist, x, z, default=None):
if default is None:
default = self.INF_DIST
if dist is None or not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
return float(default)
value = float(dist[x, z])
return value if value < self.INF_DIST else float(default)
def _action_distance_delta(self, dist, current_dist):
delta = np.zeros(8, dtype=np.float32)
if dist is None or current_dist >= self.INF_DIST:
return delta
hx, hz = self.cur_pos
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
nx, nz = hx + dx, hz + dz
if not self._can_global_move(hx, hz, dx, dz):
continue
next_dist = self._dist_at(dist, nx, nz, default=self.INF_DIST)
if next_dist >= self.INF_DIST:
continue
delta[action] = np.float32(np.clip((current_dist - next_dist) / 4.0, -1.0, 1.0))
return delta
def _update_charger_state(self, hx, hz, organs):
"""Find nearest charger and cache distance/direction features."""
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
self.last_nearest_charger_path_dist = self.nearest_charger_path_dist
self.has_charger = False
self.on_charger = False
self.nearest_charger_dx = 0.0
self.nearest_charger_dz = 0.0
self.nearest_charger_center_dx = 0.0
self.nearest_charger_center_dz = 0.0
self.nearest_charger_dist = float(self.GRID_SIZE)
self.nearest_charger_range_dist = float(self.GRID_SIZE)
self.nearest_charger_path_dist = float(self.GRID_SIZE)
self.charger_energy_cost = float(self.GRID_SIZE)
self.charger_safety_buffer = 0.0
self.charger_safety_margin = 0.0
self.charger_rects = []
self.charger_route_known = False
best = None
for organ in organs:
if not isinstance(organ, dict):
continue
if int(organ.get("sub_type", 1)) != 1:
continue
pos = organ.get("pos") or {}
ox = int(pos.get("x", 0))
oz = int(pos.get("z", 0))
w = max(int(organ.get("w", 3)), 1)
h = max(int(organ.get("h", 3)), 1)
for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)):
self.charger_rects.append((rx, rz, w, h))
dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h)
dist = float(np.sqrt(dx * dx + dz * dz))
range_dist = float(max(abs(dx), abs(dz)))
center_dx = (rx + (w - 1) * 0.5) - hx
center_dz = (rz + (h - 1) * 0.5) - hz
if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]):
best = (range_dist, dist, dx, dz, center_dx, center_dz)
if best is None:
self.battery_margin = float(self.battery)
self.charger_safety_margin = float(self.battery)
return
range_dist, dist, dx, dz, center_dx, center_dz = best
self.has_charger = True
self.nearest_charger_dx = float(dx)
self.nearest_charger_dz = float(dz)
self.nearest_charger_center_dx = float(center_dx)
self.nearest_charger_center_dz = float(center_dz)
self.nearest_charger_dist = float(dist)
self.nearest_charger_range_dist = float(range_dist)
path_dist = self._global_path_dist_to_charger(hx, hz)
self.charger_route_known = path_dist < self.INF_DIST
if not self.charger_route_known:
path_dist = self._local_path_dist_to_charger(hx, hz)
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
self.charger_energy_cost = self.nearest_charger_path_dist
self.on_charger = range_dist <= 0.0
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
"""Relative vector from point to the nearest cell in a rectangle."""
if x < rx:
dx = rx - x
elif x > rx + w - 1:
dx = x - (rx + w - 1)
else:
dx = 0
if z < rz:
dz = rz - z
elif z > rz + h - 1:
dz = z - (rz + h - 1)
else:
dz = 0
return float(dx), float(dz)
def _update_npc_state(self, hx, hz, npcs):
"""Find nearest NPC and cache safety features."""
self.nearest_npc_dx = 0.0
self.nearest_npc_dz = 0.0
self.nearest_npc_vx = 0.0
self.nearest_npc_vz = 0.0
self.nearest_npc_dist = float(self.GRID_SIZE)
self.predicted_npc_dist = float(self.GRID_SIZE)
self.npc_danger = False
self.npc_predicted_danger = False
self.predicted_npcs = []
best = None
current_positions = {}
for npc in npcs:
if not isinstance(npc, dict):
continue
pos = npc.get("pos") or {}
nx = int(pos.get("x", 0))
nz = int(pos.get("z", 0))
npc_key = str(npc.get("npc_id", npc.get("idx", len(current_positions))))
prev_pos = self.prev_npc_positions.get(npc_key)
vx = 0
vz = 0
if prev_pos is not None:
vx = int(np.clip(nx - prev_pos[0], -1, 1))
vz = int(np.clip(nz - prev_pos[1], -1, 1))
px = int(np.clip(nx + vx, 0, self.GRID_SIZE - 1))
pz = int(np.clip(nz + vz, 0, self.GRID_SIZE - 1))
current_positions[npc_key] = (nx, nz)
self.predicted_npcs.append((px, pz, 1))
dx = nx - hx
dz = nz - hz
cheb = float(max(abs(dx), abs(dz)))
pred_cheb = float(max(abs(px - hx), abs(pz - hz)))
if best is None or cheb < best[0]:
best = (cheb, dx, dz, vx, vz, pred_cheb)
self.prev_npc_positions = current_positions
if best is None:
return
cheb, dx, dz, vx, vz, pred_cheb = best
self.nearest_npc_dx = float(dx)
self.nearest_npc_dz = float(dz)
self.nearest_npc_vx = float(vx)
self.nearest_npc_vz = float(vz)
self.nearest_npc_dist = float(cheb)
self.predicted_npc_dist = float(pred_cheb)
self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1
self.npc_predicted_danger = pred_cheb <= 1
def _update_recharge_mode(self):
"""Enter/exit low-battery recharge mode."""
battery_ratio = self.battery / max(self.battery_max, 1)
if not self.has_charger:
self.recharge_mode = False
self.charger_safety_margin = float(self.battery)
self.recharge_enter_margin = 0.0
self.recharge_leave_margin = 0.0
self.recharge_low_battery_ratio = 0.28
self.full_charge_leave_ratio = 0.96
self.low_battery = battery_ratio < self.recharge_low_battery_ratio
return
self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0))
self.charger_safety_buffer = self._charger_safety_buffer()
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
self.recharge_enter_margin = self._recharge_enter_margin()
self.recharge_leave_margin = self._recharge_leave_margin()
self.recharge_low_battery_ratio = self._recharge_low_battery_ratio()
self.full_charge_leave_ratio = self._full_charge_leave_ratio()
self.low_battery = battery_ratio < self.recharge_low_battery_ratio
should_recharge = self.charger_safety_margin <= self.recharge_enter_margin or self.low_battery
safe_to_leave = (
battery_ratio >= self.full_charge_leave_ratio
and self.charger_safety_margin >= self.recharge_leave_margin
)
if self.on_charger and safe_to_leave:
self.recharge_mode = False
elif should_recharge:
self.recharge_mode = True
elif self.recharge_mode and not safe_to_leave:
self.recharge_mode = True
else:
self.recharge_mode = False
if self.recharge_mode:
self.recharge_steps += 1
def _update_motion_health(self):
"""Track recharge-mode stalls so action masking can recover."""
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
if self.charge_delta <= 0:
self.stuck_steps += 1
self.stuck_count += 1
self.max_stuck_steps = max(self.max_stuck_steps, self.stuck_steps)
else:
self.stuck_steps = 0
else:
self.stuck_steps = 0
battery_ratio = self.battery / max(self.battery_max, 1)
battery_increased = self.battery > self.prev_battery + 1
maybe_fake_charger = self.on_charger and battery_ratio < 0.9 and self.charge_delta <= 0 and not battery_increased
self.fake_charger_steps = self.fake_charger_steps + 1 if maybe_fake_charger else 0
no_progress = (
self.recharge_mode
and self.has_charger
and self.charge_delta <= 0
and not battery_increased
and self.nearest_charger_path_dist >= self.last_nearest_charger_path_dist - 0.1
)
self.recharge_no_progress_steps = self.recharge_no_progress_steps + 1 if no_progress else 0
if self.step_no > 0 and min(self.nearest_npc_dist, self.predicted_npc_dist) <= 3:
self.npc_close_steps += 1
if self.step_no > 0 and (self.npc_danger or self.npc_predicted_danger):
self.npc_danger_steps += 1
if self.terminated and not self.truncated:
if self.battery <= 0 or self.remaining_charge <= 0:
self.battery_fail = 1
if self.npc_danger or self.npc_predicted_danger or self.nearest_npc_dist <= 1:
self.npc_collision = 1
def _need_recharge_escape(self):
return self.stuck_steps >= 2 or self.recharge_no_progress_steps >= 5 or self.fake_charger_steps >= 2
def _charger_safety_buffer(self):
# One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise.
base = max(18.0, 0.12 * float(self.battery_max))
distance_buffer = min(16.0, 0.18 * float(max(self.nearest_charger_range_dist, 0.0)))
obstacle_buffer = 12.0 * float(self.local_obstacle_ratio)
return float(np.clip(base + distance_buffer + obstacle_buffer, 18.0, 48.0))
def _recharge_enter_margin(self):
"""Adaptive margin for entering recharge mode before the battery is barely enough."""
base = max(5.0, 0.018 * float(self.battery_max))
path_margin = min(12.0, 0.08 * float(max(self.nearest_charger_path_dist, 0.0)))
obstacle_margin = 12.0 * float(self.local_obstacle_ratio)
recovery_margin = min(8.0, 1.5 * float(self.recharge_no_progress_steps + self.fake_charger_steps))
return float(np.clip(base + path_margin + obstacle_margin + recovery_margin, 4.0, 32.0))
def _recharge_leave_margin(self):
"""Adaptive safety margin required before leaving a charger."""
base = max(20.0, 0.08 * float(self.battery_max))
path_margin = min(18.0, 0.14 * float(max(self.nearest_charger_path_dist, 0.0)))
obstacle_margin = 12.0 * float(self.local_obstacle_ratio)
return float(np.clip(base + path_margin + obstacle_margin, 20.0, 64.0))
def _recharge_low_battery_ratio(self):
"""Adaptive low-battery ratio based on route length and local obstacle density."""
path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0)
ratio = 0.25 + min(0.08, 0.40 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio))
if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0:
ratio += 0.02
return float(np.clip(ratio, 0.25, 0.40))
def _full_charge_leave_ratio(self):
    """Battery ratio counted as "full enough" to leave a charger, clipped to [0.88, 0.95].

    The threshold is higher early in the episode, for longer charger routes
    and for denser local obstacles.
    """
    steps_left_share = 1.0 - _norm(self.step_no, self.max_step)
    route_len = float(max(self.nearest_charger_path_dist, 0.0))
    pressure = route_len / max(float(self.battery_max), 1.0)
    threshold = 0.88 + 0.04 * steps_left_share + min(0.02, 0.08 * pressure)
    threshold += min(0.01, 0.04 * float(self.local_obstacle_ratio))
    return float(np.clip(threshold, 0.88, 0.95))
def _recharge_risk_score(self):
"""Risk score in [0, 1] used to scale recharge rewards and penalties."""
if not self.has_charger:
return 0.0
battery_ratio = self.battery / max(self.battery_max, 1)
margin_deficit = max(0.0, self.recharge_enter_margin - self.charger_safety_margin)
margin_risk = margin_deficit / max(self.charger_safety_buffer + self.recharge_enter_margin, 1.0)
low_battery_risk = max(0.0, self.recharge_low_battery_ratio - battery_ratio)
low_battery_risk /= max(self.recharge_low_battery_ratio, 1e-6)
progress_risk = min(1.0, float(self.recharge_no_progress_steps) / 5.0)
return float(np.clip(0.55 * margin_risk + 0.35 * low_battery_risk + 0.10 * progress_risk, 0.0, 1.0))
def useful_charge_reward_weight(self):
"""Adaptive reward weight for charging that happens under real battery pressure."""
prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
prev_low_risk = max(0.0, self.recharge_low_battery_ratio - prev_battery_ratio)
prev_low_risk /= max(self.recharge_low_battery_ratio, 1e-6)
risk = max(self._recharge_risk_score(), prev_low_risk)
mode_bonus = 0.8 if self.was_recharge_mode or self.prev_low_battery else 0.0
return float(np.clip(3.0 + 2.8 * risk + mode_bonus, 3.0, 6.5))
def battery_fail_penalty(self):
    """Terminal penalty in [8, 14] for running out of battery before max steps.

    Larger the earlier in the episode the failure happens and the higher the
    recharge risk or charger route pressure was.
    """
    progress = _norm(self.step_no, self.max_step)
    earliness = 1.0 - progress
    route_cost = float(max(self.charger_energy_cost, 0.0))
    pressure = route_cost / max(float(self.battery_max), 1.0)
    danger = max(self._recharge_risk_score(), min(1.0, pressure))
    return float(np.clip(8.0 + 4.0 * earliness + 2.0 * danger, 8.0, 14.0))
def _min_charger_range_dist(self, x, z):
if not self.charger_rects:
return float(self.GRID_SIZE)
dists = []
for rx, rz, w, h in self.charger_rects:
dx, dz = self._relative_vector_to_rect(x, z, rx, rz, w, h)
dists.append(max(abs(dx), abs(dz)))
return float(min(dists))
def _is_charger_cell(self, x, z):
for rx, rz, w, h in self.charger_rects:
if rx <= x < rx + w and rz <= z < rz + h:
return True
return False
def _get_local_view_feature(self):
"""Local view feature: 21×21×6 multi-channel map.
Channels: obstacle, clean, dirt, visit count, NPC danger, charger.
"""
view = self._view_map
channels = np.zeros((self.MAP_CHANNELS, self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
if view is None or view.shape[0] != self.VIEW_SIZE or view.shape[1] != self.VIEW_SIZE:
return channels.flatten()
channels[0] = (view == 0).astype(np.float32)
channels[1] = (view == 1).astype(np.float32)
channels[2] = (view == 2).astype(np.float32)
for ri in range(self.VIEW_SIZE):
for ci in range(self.VIEW_SIZE):
gx, gz = self._view_index_to_global(ri, ci)
if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
continue
channels[3, ri, ci] = _norm(min(int(self.visit_count_map[gx, gz]), 10), 10)
channels[4, ri, ci] = 1.0 if self._is_npc_danger_cell(gx, gz, expanded=True) else 0.0
channels[5, ri, ci] = 1.0 if self._is_charger_cell(gx, gz) else 0.0
return channels.flatten()
def _get_global_state_feature(self):
    """Global state feature (66D).

    Existing global state plus belief-map distances, action distance improvements,
    known charger-route safety, and predicted NPC motion.

    Layout: 42 scalar features in the order listed in ``base_features``,
    followed by three 8-entry per-action vectors (dirt, frontier, charger).

    Note: this method is not a pure getter. It overwrites
    ``last_nearest_dirt_dist`` and ``nearest_dirt_dist``, so calling it more
    than once per step changes the ``dirt_delta`` feature.
    """
    step_norm = _norm(self.step_no, self.max_step)
    battery_ratio = _norm(self.battery, self.battery_max)
    cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
    remaining_dirt = 1.0 - cleaning_progress
    hx, hz = self.cur_pos
    pos_x_weak = self._weak_abs_position_feature(hx)
    pos_z_weak = self._weak_abs_position_feature(hz)
    # 4-directional ray to find nearest dirt.
    # Offsets outside the local view read as 0 via _view_cell's default, so
    # a hit can only occur inside the view; a miss reports max_ray.
    ray_dirs = [(0, -1), (1, 0), (0, 1), (-1, 0)]  # N E S W
    ray_dirt = []
    max_ray = 30
    for dx, dz in ray_dirs:
        found = max_ray
        for step in range(1, max_ray + 1):
            gx = hx + dx * step
            gz = hz + dz * step
            # Stop the ray at the edge of the global grid.
            if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
                break
            cell = self._view_cell(dx * step, dz * step, default=0)
            if cell == 2:
                found = step
                break
        ray_dirt.append(_norm(found, max_ray))
    # Nearest dirt path distance in the visible map.
    # Side effect: shifts the current value into last_nearest_dirt_dist.
    self.last_nearest_dirt_dist = self.nearest_dirt_dist
    self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
    nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
    # Binary flags: 1.0 when dirt got closer / a charge was registered.
    dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0
    charge_delta = 1.0 if self.charge_delta > 0 else 0.0
    battery_margin_norm = _signed_norm(self.battery_margin, self.battery_max)
    visit_count_norm = _norm(min(self.current_visit_count, 10), 10)
    step_cleaned_norm = _norm(self.step_cleaned_count, 9)
    # Step-to-step change in global path distances, limited to ±4 cells.
    global_dirty_delta = _signed_norm(
        np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -4.0, 4.0), 4.0
    )
    frontier_delta = _signed_norm(
        np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -4.0, 4.0), 4.0
    )
    charger_margin_after_buffer = self.battery - self.nearest_charger_path_dist - self.charger_safety_buffer
    # The order of the entries below is the feature layout; do not reorder.
    base_features = np.array(
        [
            step_norm,
            battery_ratio,
            cleaning_progress,
            remaining_dirt,
            pos_x_weak,
            pos_z_weak,
            ray_dirt[0],
            ray_dirt[1],
            ray_dirt[2],
            ray_dirt[3],
            nearest_dirt_norm,
            dirt_delta,
            _signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
            _signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
            _norm(self.nearest_charger_path_dist, self.GRID_SIZE),
            battery_margin_norm,
            1.0 if self.low_battery else 0.0,
            1.0 if self.recharge_mode else 0.0,
            1.0 if self.on_charger else 0.0,
            charge_delta,
            _signed_norm(self.nearest_npc_dx, 20),
            _signed_norm(self.nearest_npc_dz, 20),
            _norm(self.nearest_npc_dist, 20),
            1.0 if self.npc_danger else 0.0,
            self.local_dirt_ratio,
            self.local_obstacle_ratio,
            visit_count_norm,
            step_cleaned_norm,
            _norm(self.global_dirty_path_dist, self.GRID_SIZE),
            _norm(self.frontier_path_dist, self.GRID_SIZE),
            global_dirty_delta,
            frontier_delta,
            self.known_ratio,
            self.known_passable_ratio,
            _norm(self.known_dirty_ratio, 1.0),
            _norm(self.frontier_ratio, 1.0),
            1.0 if self.charger_route_known else 0.0,
            _signed_norm(charger_margin_after_buffer, self.battery_max),
            _signed_norm(self.nearest_npc_vx, 1.0),
            _signed_norm(self.nearest_npc_vz, 1.0),
            _norm(self.predicted_npc_dist, 20),
            1.0 if self.npc_predicted_danger else 0.0,
        ],
        dtype=np.float32,
    )
    # 42 scalars + 3 × 8 per-action deltas = 66 values.
    return np.concatenate(
        [
            base_features,
            self.global_dirty_action_delta.astype(np.float32),
            self.frontier_action_delta.astype(np.float32),
            self.charger_action_delta.astype(np.float32),
        ]
    )
def _weak_abs_position_feature(self, value):
    """Map an absolute grid coordinate to a damped value centred on 0.5.

    The coordinate is normalized against ``GRID_SIZE`` with ``_norm`` and
    its offset from the midpoint is scaled by ``ABS_POS_FEATURE_SCALE``,
    so the network only sees a weak absolute-position signal.
    """
    offset_from_center = _norm(value, self.GRID_SIZE) - 0.5
    return 0.5 + self.ABS_POS_FEATURE_SCALE * offset_from_center
def _calc_nearest_dirt_dist(self):
"""Find nearest dirt path distance from local view.
从局部视野中找最近污渍路径距离。
"""
dist = self._local_bfs_distances()
dirt_coords = np.argwhere(self._view_map == 2)
if len(dirt_coords) == 0:
return 200.0
best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
if best >= self.INF_DIST:
return 200.0
return best
def _local_bfs_distances(self, start_dx=0, start_dz=0):
"""Shortest path distances inside the current 21x21 local view."""
view = self._view_map
shape = view.shape
dist = np.full(shape, self.INF_DIST, dtype=np.float32)
start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
return dist
if int(view[start_ri, start_ci]) == 0:
return dist
dist[start_ri, start_ci] = 0.0
queue = deque([(start_ri, start_ci)])
while queue:
ri, ci = queue.popleft()
base = dist[ri, ci]
for dx, dz in self.ACTION_DIRS:
nri = ri + dz
nci = ci + dx
if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
continue
if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
continue
if dx != 0 and dz != 0:
side_a = int(view[ri, nci]) != 0
side_b = int(view[nri, ci]) != 0
if not (side_a or side_b):
continue
dist[nri, nci] = base + 1.0
queue.append((nri, nci))
return dist
def _local_path_dist_to_charger(self, gx, gz):
"""Visible-map BFS distance from global x/z to nearest charger cell."""
best = self.INF_DIST
start_dx = gx - self.cur_pos[0]
start_dz = gz - self.cur_pos[1]
dist = self._local_bfs_distances(start_dx, start_dz)
for rx, rz, w, h in self.charger_rects:
for tx in range(rx, rx + w):
for tz in range(rz, rz + h):
dx = tx - self.cur_pos[0]
dz = tz - self.cur_pos[1]
ri, ci = self._view_delta_to_index(dx, dz)
if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
best = min(best, float(dist[ri, ci]))
return best
def _global_path_dist_to_charger(self, gx, gz):
"""Known-map BFS distance from a global cell to the nearest observed charger cell."""
dist = self._get_charger_reverse_dist()
return self._dist_at(dist, gx, gz, default=self.INF_DIST)
def _charger_move_distance(self, gx, gz):
"""Use known-map BFS to the charger when available, then visible BFS, then Chebyshev."""
path_dist = self._global_path_dist_to_charger(gx, gz)
if path_dist < self.INF_DIST:
return path_dist
path_dist = self._local_path_dist_to_charger(gx, gz)
if path_dist < self.INF_DIST:
return path_dist
return self._min_charger_range_dist(gx, gz)
def get_legal_action(self):
    """Return the legal action mask as a list.

    The environment mask ``self._legal_act`` is first reduced by the
    blocked-cell filter and the NPC-danger filter.  In recharge mode the
    recharge filter and then the recharge-escape filter are applied (the
    latter also receives the NPC-safe mask).  Otherwise, when standing on
    the charger with a battery ratio of at least
    ``full_charge_leave_ratio``, the leave-charger filter is applied.
    """
    mask = self._filter_npc_danger_actions(
        self._filter_blocked_actions(self._legal_act)
    )
    if self.recharge_mode:
        npc_safe_mask = list(mask)
        toward_charger = self._filter_recharge_actions(mask)
        return list(self._filter_recharge_escape_actions(toward_charger, npc_safe_mask))
    if self.on_charger:
        charge_level = self.battery / max(self.battery_max, 1)
        if charge_level >= self.full_charge_leave_ratio:
            mask = self._filter_leave_charger_actions(mask)
    return list(mask)
def _filter_blocked_actions(self, legal_action):
"""Filter actions that are visibly blocked in the 21x21 view."""
legal = [int(x) for x in legal_action]
hx, hz = self.cur_pos
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if legal[action] <= 0:
continue
if not self._is_visible_cell_passable(dx, dz):
legal[action] = 0
continue
if dx != 0 and dz != 0:
side_a = self._is_visible_cell_passable(dx, 0)
side_b = self._is_visible_cell_passable(0, dz)
if not (side_a or side_b):
legal[action] = 0
nx, nz = hx + dx, hz + dz
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
legal[action] = 0
return legal if any(legal) else [int(x) for x in legal_action]
def _is_visible_cell_passable(self, dx, dz):
"""Whether a relative 21x21-view cell is passable."""
cell = self._view_cell(dx, dz, default=None)
return True if cell is None else cell != 0
def _filter_npc_danger_actions(self, legal_action):
"""Avoid current and predicted NPC danger zones."""
if not self.npcs:
return list(legal_action)
hx, hz = self.cur_pos
safe = [int(x) for x in legal_action]
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if safe[action] <= 0:
continue
nx, nz = hx + dx, hz + dz
if self._is_npc_danger_cell(nx, nz, expanded=True):
safe[action] = 0
if any(safe):
return safe
hard_safe = [int(x) for x in legal_action]
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if hard_safe[action] <= 0:
continue
nx, nz = hx + dx, hz + dz
if self._is_npc_danger_cell(nx, nz, expanded=False):
hard_safe[action] = 0
return hard_safe if any(hard_safe) else list(legal_action)
def _is_npc_danger_cell(self, x, z, expanded=True):
for npc in self.npcs:
if not isinstance(npc, dict):
continue
pos = npc.get("pos") or {}
nx = int(pos.get("x", -999))
nz = int(pos.get("z", -999))
if abs(x - nx) <= 1 and abs(z - nz) <= 1:
return True
if expanded and abs(x - nx) <= 2 and abs(z - nz) <= 2 and self.nearest_npc_dist <= 4:
return True
if expanded:
for px, pz, radius in self.predicted_npcs:
if abs(x - px) <= radius and abs(z - pz) <= radius:
return True
if self.nearest_npc_dist <= 4 and abs(x - px) <= 2 and abs(z - pz) <= 2:
return True
return False
def _filter_recharge_actions(self, legal_action):
"""Restrict recharge-mode actions to safe moves toward the charger range."""
if not self.has_charger:
return list(legal_action)
hx, hz = self.cur_pos
current_range_dist = self._min_charger_range_dist(hx, hz)
current_move_dist = self._charger_move_distance(hx, hz)
scored = []
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if legal_action[action] <= 0:
continue
nx, nz = hx + dx, hz + dz
next_dist = self._charger_move_distance(nx, nz)
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
next_range_dist = self._min_charger_range_dist(nx, nz)
scored.append((next_dist, alignment, next_range_dist, action))
if not scored:
return list(legal_action)
# When already inside the charger range, stay inside until recharge mode exits.
# 已经在充电区域内时,回充模式退出前不要离开充电区域。
confirmed_charger = self.charge_delta > 0 or self.battery > self.prev_battery + 1
if current_range_dist <= 0.0 and confirmed_charger:
stay = [0] * 8
for _, _, next_range_dist, action in scored:
if next_range_dist <= 0.0:
stay[action] = 1
if any(stay):
return stay
recharge = [0] * 8
best_next_dist = min(item[0] for item in scored)
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
for next_dist, _, _, action in ranked:
if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1:
recharge[action] = 1
if sum(recharge) >= 3:
break
if not any(recharge):
for _, _, _, action in ranked[: min(3, len(ranked))]:
recharge[action] = 1
return recharge if any(recharge) else list(legal_action)
def _filter_recharge_escape_actions(self, recharge_action, safe_action):
"""Escape repeated no-move states during low-battery recharge."""
if not self._need_recharge_escape():
return list(recharge_action)
hx, hz = self.cur_pos
current_dist = self._charger_move_distance(hx, hz)
ranked = []
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if safe_action[action] <= 0:
continue
nx, nz = hx + dx, hz + dz
next_dist = self._charger_move_distance(nx, nz)
visit_count = 0
if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE:
visit_count = int(self.visit_count_map[nx, nz])
failed_action_penalty = 6.0 if action == self.last_action and self.stuck_steps >= 2 else 0.0
no_progress_penalty = 1.5 if next_dist > current_dist + 0.1 else 0.0
ranked.append((next_dist + 0.05 * min(visit_count, 20) + failed_action_penalty + no_progress_penalty, action))
if not ranked:
return list(recharge_action)
self.recharge_escape_count += 1
ranked.sort()
escape = [0] * 8
for _, action in ranked[: min(4, len(ranked))]:
escape[action] = 1
if self.stuck_steps >= 2 and sum(escape) > 1 and 0 <= self.last_action < 8:
escape[self.last_action] = 0
return escape if any(escape) else list(recharge_action)
def _filter_leave_charger_actions(self, legal_action):
"""Prefer moves that leave charger range when battery is healthy."""
if not self.has_charger:
return list(legal_action)
hx, hz = self.cur_pos
current_dist = self._min_charger_range_dist(hx, hz)
scored = []
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if legal_action[action] <= 0:
continue
nx, nz = hx + dx, hz + dz
next_dist = self._min_charger_range_dist(nx, nz)
away_score = -(dx * self.nearest_charger_center_dx + dz * self.nearest_charger_center_dz)
scored.append((next_dist - current_dist, away_score, action))
if not scored:
return list(legal_action)
best_escape = max(item[0] for item in scored)
leave = [0] * 8
if best_escape > 0:
for escape, _, action in scored:
if escape >= best_escape - 0.1:
leave[action] = 1
else:
best_away = max(item[1] for item in scored)
for _, away_score, action in scored:
if away_score >= best_away:
leave[action] = 1
return leave if any(leave) else list(legal_action)
def feature_process(self, env_obs, last_action):
    """Build the feature vector, legal-action mask and reward for one step.

    Args:
        env_obs: Raw environment observation, forwarded to ``pb2struct``.
        last_action: Index of the previously executed action; values
            outside ``[0, 8)`` produce an all-zero one-hot segment.

    Returns:
        ``(feature, legal_action, reward)`` where ``feature`` concatenates
        the local-view features, the global-state features and an 8-entry
        one-hot of ``last_action``.  The original notes give the segment
        sizes as 2646, 66 and 8.
    """
    self.pb2struct(env_obs, last_action)

    local_part = self._get_local_view_feature()
    global_part = self._get_global_state_feature()
    legal_action = self.get_legal_action()

    action_one_hot = np.zeros(8, dtype=np.float32)
    if 0 <= last_action < 8:
        action_one_hot[last_action] = 1.0

    feature = np.concatenate([local_part, global_part, action_one_hot])
    return feature, legal_action, self.reward_process()
def reward_process(self):
"""Compute the scalar reward for the current step.

The result is the sum of six terms built below: cleaning reward, charge
shaping, exploration shaping, stuck penalty, NPC penalty, and a constant
per-step penalty of -0.002.

NOTE(review): this listing has lost its indentation, so the nesting of
the conditionals below cannot be confirmed from here. Verify branch
structure against the original file.

Returns:
    The summed scalar reward.
"""
# Cleaning reward
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
# Prefer the per-step cleaned-cell count; fall back to the change in the
# cumulative cleaned counter when that count is not positive.
cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step
# Each cleaned cell is worth 0.2 in recharge mode and 0.7 otherwise.
cleaning_scale = 0.2 if self.recharge_mode else 0.7
cleaning_reward = cleaning_scale * cleaned_cells
# Step (time) penalty
step_penalty = -0.002
# Recharge guidance only activates when battery safety is the bottleneck.
# Only guide the agent toward the charger in low-battery / recharge mode,
# so that it does not camp on the charger at high battery.
charge_reward = 0.0
battery_ratio = self.battery / max(self.battery_max, 1)
prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
# A charge gain counts as useful when the previous state was low-battery,
# in recharge mode, or below a 0.45 battery ratio.
useful_charge = self.charge_delta > 0 and (
self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.45
)
if useful_charge:
charge_reward += self.useful_charge_reward_weight()
elif self.charge_delta > 0 and battery_ratio > 0.65:
# Charging above a 0.65 battery ratio is penalized; charge_delta is capped at 3.
charge_reward -= 0.25 * min(self.charge_delta, 3)
if self.has_charger and (self.recharge_mode or self.low_battery):
# Positive dist_delta means the charger path distance shrank since the
# previous step; the change is clipped to [-4, 4].
dist_delta = float(
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
)
recharge_risk = self._recharge_risk_score()
# Both scales grow with the risk score; approaching uses the larger scale.
approach_scale = 0.07 + 0.06 * recharge_risk
retreat_scale = 0.035 + 0.045 * recharge_risk
charge_reward += approach_scale * dist_delta if dist_delta > 0 else retreat_scale * dist_delta
if self.charger_safety_margin < self.recharge_enter_margin:
# Penalty proportional to the margin shortage relative to battery_max, capped at 0.55.
safety_shortage = self.recharge_enter_margin - self.charger_safety_margin
charge_reward -= min(0.55, safety_shortage / max(self.battery_max, 1))
elif self.on_charger and battery_ratio > 0.65:
# Flat penalty for sitting on the charger above a 0.65 battery ratio.
charge_reward -= 0.08
# Encourage covering new passable cells and mildly discourage loops.
if self.recharge_mode:
exploration_reward = 0.0
else:
# Revisit penalty scales with the visit count, capped at 6 visits.
exploration_reward = 0.004 if self.is_new_cell else -0.0015 * min(self.current_visit_count, 6)
if self.global_dirty_path_dist < self.GRID_SIZE:
# Progress toward the global dirty target, clipped to [-3, 3].
dirty_progress = np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -3.0, 3.0)
exploration_reward += 0.008 * dirty_progress
elif self.frontier_path_dist < self.GRID_SIZE:
# Progress toward the frontier target, clipped to [-3, 3].
frontier_progress = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
exploration_reward += 0.005 * frontier_progress
# Collision/stuck signal: invalid moves waste both step and battery.
stuck_penalty = 0.0
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
stuck_penalty = -0.03
if self.recharge_mode:
# Extra penalty per stuck step, capped at 5 steps.
stuck_penalty -= 0.02 * min(self.stuck_steps, 5)
# NPC proximity penalty: current danger, then predicted danger, then a
# graded penalty when the nearest NPC is within distance 3.
npc_penalty = 0.0
if self.npc_danger:
npc_penalty -= 4.0
elif self.npc_predicted_danger:
npc_penalty -= 0.4
elif self.nearest_npc_dist <= 3:
npc_penalty -= 0.05 * (4 - self.nearest_npc_dist)
return (
cleaning_reward
+ charge_reward
+ exploration_reward
+ stuck_penalty
+ npc_penalty
+ step_penalty
)