#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors

Feature preprocessor for Robot Vacuum.
清扫大作战特征预处理器。
"""
|
||
|
||
import os
|
||
from collections import deque
|
||
|
||
import numpy as np
|
||
|
||
|
||
def _norm(v, v_max, v_min=0.0):
|
||
"""Normalize value to [0, 1].
|
||
|
||
将值线性归一化到 [0, 1]。
|
||
"""
|
||
v = float(np.clip(v, v_min, v_max))
|
||
if v_max == v_min:
|
||
return 0.0
|
||
return (v - v_min) / (v_max - v_min)
|
||
|
||
|
||
def _signed_norm(v, v_max):
|
||
"""Normalize signed value to [-1, 1]."""
|
||
if v_max <= 0:
|
||
return 0.0
|
||
return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
|
||
|
||
|
||
def _as_dict(value):
|
||
"""Return a dict for optional nested observation fields."""
|
||
return value if isinstance(value, dict) else {}
|
||
|
||
|
||
class Preprocessor:
    """Feature preprocessor for Robot Vacuum."""

    # Side length of the square global maps (indexed [x, z]).
    GRID_SIZE = 128
    # Radius of the full local view (21x21 cells).
    VIEW_HALF = 10
    # Side length of the local view, derived so it cannot drift from VIEW_HALF.
    VIEW_SIZE = 2 * VIEW_HALF + 1
    # Number of channels in the local view feature map.
    MAP_CHANNELS = 6
    # Minimum number of steps between refreshes of the cached path maps.
    PLANNER_UPDATE_INTERVAL = 4
    # (dx, dz) offset for each of the 8 action indices.
    ACTION_DIRS = (
        (1, 0),
        (1, -1),
        (0, -1),
        (-1, -1),
        (-1, 0),
        (-1, 1),
        (0, 1),
        (1, 1),
    )
    # Sentinel distance meaning "unreachable / unknown".
    INF_DIST = 1e6
    ABS_POS_FEATURE_SCALE = 0.2
|
||
|
||
def __init__(self):
|
||
self.reset()
|
||
|
||
def reset(self):
    """Reset all internal state at episode start.

    Re-creates the global belief maps, clears cached path maps and every
    per-episode counter, re-reads the reward profile from the
    ``ROBOT_VACUUM_REWARD_PROFILE`` environment variable, and finally
    resets the diagnostic counters.
    """
    grid = self.GRID_SIZE
    far = float(grid)  # default "far away" distance used before anything is observed

    # --- Episode / hero basics ---
    self.map_id = -1
    self.step_no = 0
    self.battery = 600
    self.battery_max = 600

    self.cur_pos = (0, 0)
    self.prev_pos = None
    self.has_position_history = False
    self.current_visit_count = 0
    self.is_new_cell = False
    self.last_action = -1
    self.stuck_steps = 0
    self.recharge_no_progress_steps = 0
    self.fake_charger_steps = 0
    self.stuck_count = 0
    self.max_stuck_steps = 0
    self.recharge_escape_count = 0

    # --- Cleaning progress ---
    self.dirt_cleaned = 0
    self.last_dirt_cleaned = 0
    self.total_dirt = 1
    self.total_score = 0
    self.clean_score = 0
    self.step_cleaned_count = 0
    self.max_step = 1000

    # --- Global belief maps indexed by [x, z] ---
    self.known_map = np.full((grid, grid), -1, dtype=np.int8)
    self.passable_map = np.zeros((grid, grid), dtype=np.int8)
    self.frontier_map = np.zeros((grid, grid), dtype=np.int8)
    self.dirty_map = np.zeros((grid, grid), dtype=np.int8)
    self._dirty_reverse_dist = None
    self._frontier_reverse_dist = None
    self._charger_reverse_dist = None
    self._path_cache_dirty = True
    # Start one full interval in the past so the first step may refresh.
    self._planner_last_update_step = -self.PLANNER_UPDATE_INTERVAL
    self.known_ratio = 0.0
    self.known_passable_ratio = 0.0
    self.known_dirty_ratio = 0.0
    self.frontier_ratio = 0.0
    self.global_dirty_path_dist = far
    self.last_global_dirty_path_dist = far
    self.frontier_path_dist = far
    self.last_frontier_path_dist = far
    self.global_dirty_action_delta = np.zeros(8, dtype=np.float32)
    self.frontier_action_delta = np.zeros(8, dtype=np.float32)
    self.charger_action_delta = np.zeros(8, dtype=np.float32)
    self.charger_route_known = False
    self.charger_route_source = "none"

    # Nearest dirt path distance in the current local view.
    self.nearest_dirt_dist = 200.0
    self.last_nearest_dirt_dist = 200.0
    self.visit_count_map = np.zeros((grid, grid), dtype=np.uint16)

    # --- Latest raw observation pieces ---
    self._view_map = np.zeros((self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
    self._legal_act = [1] * 8
    self.terminated = False
    self.truncated = False

    # --- Battery / charger state ---
    self.remaining_charge = 0
    self.prev_battery = 600
    self.prev_battery_max = 600
    self.prev_on_charger = False
    self.prev_low_battery = False
    self.was_recharge_mode = False
    self.charge_count = 0
    self.last_charge_count = 0
    self.charge_delta = 0
    self.nearest_charger_dx = 0.0
    self.nearest_charger_dz = 0.0
    self.nearest_charger_center_dx = 0.0
    self.nearest_charger_center_dz = 0.0
    self.nearest_charger_dist = far
    self.nearest_charger_range_dist = far
    self.last_nearest_charger_range_dist = far
    self.nearest_charger_path_dist = far
    self.last_nearest_charger_path_dist = far
    self.charger_energy_cost = far
    self.charger_safety_buffer = 0.0
    self.charger_safety_margin = 0.0
    self.recharge_enter_margin = 0.0
    self.recharge_leave_margin = 0.0
    self.recharge_low_battery_ratio = 0.28
    self.full_charge_leave_ratio = 0.96
    self.battery_margin = 0.0
    self.has_charger = False
    self.low_battery = False
    self.on_charger = False
    self.charger_rects = []
    self.recharge_mode = False
    self.recharge_steps = 0

    # --- NPC state ---
    self.nearest_npc_dx = 0.0
    self.nearest_npc_dz = 0.0
    self.nearest_npc_vx = 0.0
    self.nearest_npc_vz = 0.0
    self.nearest_npc_dist = far
    self.predicted_npc_dist = far
    self.npc_danger = False
    self.npc_predicted_danger = False
    self.npcs = []
    self.prev_npc_positions = {}
    self.predicted_npcs = []
    self.npc_close_steps = 0
    self.npc_danger_steps = 0
    self.npc_collision = 0
    self.battery_fail = 0

    # --- Local view statistics and reward configuration ---
    self.local_dirt_ratio = 0.0
    self.local_obstacle_ratio = 0.0
    # An empty / whitespace-only variable falls back to "current".
    self.reward_profile = os.environ.get("ROBOT_VACUUM_REWARD_PROFILE", "current").strip().lower() or "current"
    self._reset_diagnostics()
|
||
|
||
def _reset_diagnostics(self):
|
||
"""Reset episode-local diagnostic counters."""
|
||
self.diag_mask_steps = 0
|
||
self.diag_mask_count_sums = {
|
||
"raw": 0,
|
||
"blocked": 0,
|
||
"npc": 0,
|
||
"recharge": 0,
|
||
"escape": 0,
|
||
"leave": 0,
|
||
"final": 0,
|
||
}
|
||
self.diag_mask_changed_steps = {
|
||
"blocked": 0,
|
||
"npc": 0,
|
||
"recharge": 0,
|
||
"escape": 0,
|
||
"leave": 0,
|
||
}
|
||
self.diag_mask_active_steps = {
|
||
"recharge": 0,
|
||
"leave": 0,
|
||
}
|
||
self.diag_one_action_steps = 0
|
||
self.diag_two_or_less_action_steps = 0
|
||
self.diag_zero_final_steps = 0
|
||
self.diag_action_hist = [0] * 8
|
||
|
||
def pb2struct(self, env_obs, last_action):
    """Parse one environment observation and refresh all cached state.

    Missing or malformed nested fields fall back to the previously cached
    value, so a partial observation never raises on lookup.

    Args:
        env_obs: Observation dict; reads ``observation``, ``extra_info``,
            ``terminated`` and ``truncated``.
        last_action: Index of the action taken on the previous step.
            ``None`` is stored as -1, the same "no action" value ``reset``
            uses.
    """
    env_obs = _as_dict(env_obs)
    observation = _as_dict(env_obs.get("observation"))
    frame_state = _as_dict(observation.get("frame_state"))
    extra_info = _as_dict(env_obs.get("extra_info"))
    extra_frame_state = _as_dict(extra_info.get("frame_state"))
    env_info = _as_dict(observation.get("env_info"))

    # "heroes" may be a single dict or a list of dicts; use the first entry.
    hero = frame_state.get("heroes") or {}
    if isinstance(hero, list):
        hero = hero[0] if hero else {}
    hero = _as_dict(hero)

    self.last_action = -1 if last_action is None else int(last_action)
    map_id_value = extra_info.get("map_id", env_info.get("map_id", self.map_id))
    try:
        self.map_id = int(map_id_value)
    except (TypeError, ValueError):
        # Deliberate best effort: keep the previous map id when unparsable.
        pass
    self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
    self.terminated = bool(env_obs.get("terminated", False))
    self.truncated = bool(env_obs.get("truncated", False))

    # Snapshot the previous step's values before they are overwritten.
    self.prev_battery = self.battery
    self.prev_battery_max = self.battery_max
    self.prev_on_charger = self.on_charger
    self.prev_low_battery = self.low_battery
    self.was_recharge_mode = self.recharge_mode
    self.prev_pos = self.cur_pos if self.has_position_history else None

    # Position
    hero_pos = _as_dict(hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]})
    self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
    self.has_position_history = True

    hx, hz = self.cur_pos
    if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
        self.current_visit_count = int(self.visit_count_map[hx, hz])
        self.is_new_cell = self.current_visit_count == 0
        # Saturate instead of wrapping around the uint16 counter.
        self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max)
    else:
        self.current_visit_count = 0
        self.is_new_cell = False

    # Battery
    self.battery = int(hero.get("battery", env_info.get("remaining_charge", self.battery)))
    self.battery_max = max(int(hero.get("battery_max", env_info.get("battery_max", self.battery_max))), 1)
    self.remaining_charge = int(env_info.get("remaining_charge", self.battery))
    self.max_step = max(int(env_info.get("max_step", self.max_step)), 1)

    # Cleaning progress
    self.last_dirt_cleaned = self.dirt_cleaned
    self.dirt_cleaned = int(hero.get("dirt_cleaned", env_info.get("clean_score", self.dirt_cleaned)))
    self.total_dirt = max(int(env_info.get("total_dirt", self.total_dirt)), 1)
    self.total_score = int(env_info.get("total_score", self.total_score))
    self.clean_score = int(env_info.get("clean_score", self.dirt_cleaned))
    step_cleaned_cells = env_info.get("step_cleaned_cells") or []
    self.step_cleaned_count = len(step_cleaned_cells)

    # Charge progress
    self.last_charge_count = self.charge_count
    self.charge_count = int(env_info.get("charge_count", self.charge_count))
    self.charge_delta = max(0, self.charge_count - self.last_charge_count)

    # Legal actions: truncate or pad with 1 to exactly 8 entries.
    raw_legal_act = observation.get("legal_action") or observation.get("legal_act") or [1] * 8
    self._legal_act = [int(x) for x in raw_legal_act[:8]]
    if len(self._legal_act) < 8:
        self._legal_act.extend([1] * (8 - len(self._legal_act)))

    # Local view map (21x21)
    map_info = observation.get("map_info")
    if map_info is not None:
        self._view_map = np.array(map_info, dtype=np.float32)
        # Only stamp a freshly received view into the global maps; a stale
        # view must not be written at the new position.
        self._update_passable(hx, hz)
    # NOTE(review): the source's indentation was lost, so whether the next
    # two calls sat inside the `if` above is uncertain. Running them always
    # is harmless: cleaned cells are independent of the view, and the stats
    # are recomputed from the cached view.
    self._mark_cleaned_cells(step_cleaned_cells)
    self._update_local_map_stats()

    # Entities: prefer the regular frame state, fall back to extra_info.
    organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
    npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
    organs = organs if isinstance(organs, (list, tuple)) else []
    self.npcs = list(npcs) if isinstance(npcs, (list, tuple)) else []

    # Order matters: the charger and NPC state feed the planner, the
    # recharge-mode decision and the motion-health counters.
    self._update_charger_state(hx, hz, organs)
    self._update_npc_state(hx, hz, self.npcs)
    self._update_global_planning_state()
    self._update_recharge_mode()
    self._update_motion_health()
|
||
|
||
def _update_passable(self, hx, hz):
|
||
"""Write local view into global passable map.
|
||
|
||
将局部视野写入全局通行地图。
|
||
"""
|
||
view = self._view_map
|
||
vsize = view.shape[0]
|
||
half = vsize // 2
|
||
|
||
for ri in range(vsize):
|
||
for ci in range(vsize):
|
||
gx = hx + ci - half
|
||
gz = hz + ri - half
|
||
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
|
||
cell = int(view[ri, ci])
|
||
self.known_map[gx, gz] = cell
|
||
self.passable_map[gx, gz] = 1 if cell != 0 else 0
|
||
self.dirty_map[gx, gz] = 1 if cell == 2 else 0
|
||
|
||
if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
|
||
self.known_map[hx, hz] = 1
|
||
self.passable_map[hx, hz] = 1
|
||
self.dirty_map[hx, hz] = 0
|
||
self._clear_path_caches()
|
||
|
||
def _mark_cleaned_cells(self, step_cleaned_cells):
    """Mark cells cleaned in the current step in the global belief maps.

    Each in-grid cell becomes known, passable and not dirty. Entries that
    are not dicts, or lack coordinates, resolve to (-1, -1) and are skipped.

    Args:
        step_cleaned_cells: Iterable of ``{"x": ..., "z": ...}`` dicts;
            ``None`` or empty is a no-op.
    """
    changed = False
    for pos in step_cleaned_cells or []:
        pos = _as_dict(pos)
        x = int(pos.get("x", -1))
        z = int(pos.get("z", -1))
        if 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE:
            self.known_map[x, z] = 1
            self.passable_map[x, z] = 1
            self.dirty_map[x, z] = 0
            changed = True
    if changed:
        # Only flag the path caches as stale when a map cell was touched.
        self._clear_path_caches()
|
||
|
||
def _clear_path_caches(self):
|
||
self._path_cache_dirty = True
|
||
|
||
def _drop_path_caches(self):
|
||
self._dirty_reverse_dist = None
|
||
self._frontier_reverse_dist = None
|
||
self._charger_reverse_dist = None
|
||
|
||
def _view_index_to_global(self, ri, ci):
|
||
"""Convert local view row/col to global x/z coordinates."""
|
||
half = self._view_map.shape[0] // 2
|
||
hx, hz = self.cur_pos
|
||
return hx + ci - half, hz + ri - half
|
||
|
||
def _view_delta_to_index(self, dx, dz):
|
||
"""Convert global-coordinate dx/dz to local view row/col."""
|
||
return self.VIEW_HALF + dz, self.VIEW_HALF + dx
|
||
|
||
def _view_cell(self, dx, dz, default=None):
|
||
"""Read a local-view cell by global-coordinate delta.
|
||
|
||
`map_info` is row-major: row follows z/down, col follows x/right.
|
||
"""
|
||
ri, ci = self._view_delta_to_index(dx, dz)
|
||
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
||
return default
|
||
return int(self._view_map[ri, ci])
|
||
|
||
def _update_local_map_stats(self):
|
||
"""Cache coarse 21x21 map statistics."""
|
||
view = self._view_map
|
||
if view is None or view.size == 0:
|
||
self.local_dirt_ratio = 0.0
|
||
self.local_obstacle_ratio = 0.0
|
||
return
|
||
total = float(view.size)
|
||
self.local_dirt_ratio = float(np.sum(view == 2) / total)
|
||
self.local_obstacle_ratio = float(np.sum(view == 0) / total)
|
||
|
||
def _update_global_planning_state(self):
|
||
"""Refresh global coverage, frontier, and action-improvement features."""
|
||
self.last_global_dirty_path_dist = self.global_dirty_path_dist
|
||
self.last_frontier_path_dist = self.frontier_path_dist
|
||
|
||
self._update_frontier_map()
|
||
hx, hz = self.cur_pos
|
||
|
||
should_refresh_paths = (
|
||
self._dirty_reverse_dist is None
|
||
or self._frontier_reverse_dist is None
|
||
or (self.has_charger and self._charger_reverse_dist is None)
|
||
or (
|
||
self._path_cache_dirty
|
||
and self.step_no - self._planner_last_update_step >= self.PLANNER_UPDATE_INTERVAL
|
||
)
|
||
)
|
||
if should_refresh_paths:
|
||
self._drop_path_caches()
|
||
self._planner_last_update_step = self.step_no
|
||
self._path_cache_dirty = False
|
||
|
||
known_count = float(np.sum(self.known_map >= 0))
|
||
passable_count = float(np.sum(self.passable_map > 0))
|
||
dirty_count = float(np.sum(self.dirty_map > 0))
|
||
frontier_count = float(np.sum(self.frontier_map > 0))
|
||
total_cells = float(self.GRID_SIZE * self.GRID_SIZE)
|
||
self.known_ratio = known_count / total_cells
|
||
self.known_passable_ratio = passable_count / total_cells
|
||
self.known_dirty_ratio = dirty_count / max(float(self.total_dirt), 1.0)
|
||
self.frontier_ratio = frontier_count / max(passable_count, 1.0)
|
||
|
||
dirty_dist = self._get_dirty_reverse_dist()
|
||
frontier_dist = self._get_frontier_reverse_dist()
|
||
charger_dist = self._get_charger_reverse_dist()
|
||
|
||
self.global_dirty_path_dist = self._dist_at(dirty_dist, hx, hz, default=float(self.GRID_SIZE))
|
||
self.frontier_path_dist = self._dist_at(frontier_dist, hx, hz, default=float(self.GRID_SIZE))
|
||
if charger_dist is not None:
|
||
charger_path = self._dist_at(charger_dist, hx, hz, default=self.INF_DIST)
|
||
if charger_path < self.INF_DIST:
|
||
self.nearest_charger_path_dist = min(self.nearest_charger_path_dist, float(charger_path))
|
||
self.charger_energy_cost = self.nearest_charger_path_dist
|
||
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
||
self.charger_route_known = True
|
||
self.charger_route_source = "global"
|
||
|
||
self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist)
|
||
self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist)
|
||
current_charger = self._dist_at(charger_dist, hx, hz, default=self.nearest_charger_path_dist)
|
||
self.charger_action_delta = self._action_distance_delta(charger_dist, current_charger)
|
||
|
||
def _update_frontier_map(self):
|
||
"""Mark known passable cells adjacent to unseen space as exploration frontiers."""
|
||
self.frontier_map.fill(0)
|
||
passable_coords = np.argwhere(self.passable_map > 0)
|
||
for x, z in passable_coords:
|
||
x = int(x)
|
||
z = int(z)
|
||
for dx, dz in ((1, 0), (-1, 0), (0, 1), (0, -1)):
|
||
nx, nz = x + dx, z + dz
|
||
if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE and self.known_map[nx, nz] < 0:
|
||
self.frontier_map[x, z] = 1
|
||
break
|
||
|
||
def _get_dirty_reverse_dist(self):
|
||
if self._dirty_reverse_dist is None:
|
||
targets = np.argwhere((self.dirty_map > 0) & (self.passable_map > 0))
|
||
self._dirty_reverse_dist = self._global_bfs_from_targets(targets)
|
||
return self._dirty_reverse_dist
|
||
|
||
def _get_frontier_reverse_dist(self):
|
||
if self._frontier_reverse_dist is None:
|
||
targets = np.argwhere((self.frontier_map > 0) & (self.passable_map > 0))
|
||
self._frontier_reverse_dist = self._global_bfs_from_targets(targets)
|
||
return self._frontier_reverse_dist
|
||
|
||
def _get_charger_reverse_dist(self):
|
||
if not self.charger_rects:
|
||
return None
|
||
if self._charger_reverse_dist is None:
|
||
self._charger_reverse_dist = self._global_bfs_from_targets(self._charger_target_cells())
|
||
return self._charger_reverse_dist
|
||
|
||
def _charger_target_cells(self):
|
||
targets = []
|
||
for rx, rz, w, h in self.charger_rects:
|
||
for x in range(rx, rx + w):
|
||
for z in range(rz, rz + h):
|
||
if self._is_known_passable(x, z):
|
||
targets.append((x, z))
|
||
return targets
|
||
|
||
def _global_bfs_from_targets(self, targets):
|
||
"""Reverse BFS over the accumulated known passable map."""
|
||
dist = np.full((self.GRID_SIZE, self.GRID_SIZE), self.INF_DIST, dtype=np.float32)
|
||
queue = deque()
|
||
for target in targets:
|
||
if len(target) < 2:
|
||
continue
|
||
x = int(target[0])
|
||
z = int(target[1])
|
||
if not self._is_known_passable(x, z) or dist[x, z] == 0.0:
|
||
continue
|
||
dist[x, z] = 0.0
|
||
queue.append((x, z))
|
||
|
||
while queue:
|
||
x, z = queue.popleft()
|
||
base = dist[x, z]
|
||
for dx, dz in self.ACTION_DIRS:
|
||
nx, nz = x + dx, z + dz
|
||
if not self._can_global_move(x, z, dx, dz):
|
||
continue
|
||
if dist[nx, nz] < self.INF_DIST:
|
||
continue
|
||
dist[nx, nz] = base + 1.0
|
||
queue.append((nx, nz))
|
||
return dist
|
||
|
||
def _is_known_passable(self, x, z):
|
||
return 0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE and self.passable_map[x, z] > 0
|
||
|
||
def _can_global_move(self, x, z, dx, dz):
|
||
nx, nz = x + dx, z + dz
|
||
if not self._is_known_passable(x, z) or not self._is_known_passable(nx, nz):
|
||
return False
|
||
if dx != 0 and dz != 0:
|
||
return self._is_known_passable(x + dx, z) or self._is_known_passable(x, z + dz)
|
||
return True
|
||
|
||
def _dist_at(self, dist, x, z, default=None):
|
||
if default is None:
|
||
default = self.INF_DIST
|
||
if dist is None or not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
|
||
return float(default)
|
||
value = float(dist[x, z])
|
||
return value if value < self.INF_DIST else float(default)
|
||
|
||
def _action_distance_delta(self, dist, current_dist):
|
||
delta = np.zeros(8, dtype=np.float32)
|
||
if dist is None or current_dist >= self.INF_DIST:
|
||
return delta
|
||
hx, hz = self.cur_pos
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
nx, nz = hx + dx, hz + dz
|
||
if not self._can_global_move(hx, hz, dx, dz):
|
||
continue
|
||
next_dist = self._dist_at(dist, nx, nz, default=self.INF_DIST)
|
||
if next_dist >= self.INF_DIST:
|
||
continue
|
||
delta[action] = np.float32(np.clip((current_dist - next_dist) / 4.0, -1.0, 1.0))
|
||
return delta
|
||
|
||
def _update_charger_state(self, hx, hz, organs):
|
||
"""Find nearest charger and cache distance/direction features."""
|
||
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
||
self.last_nearest_charger_path_dist = self.nearest_charger_path_dist
|
||
self.has_charger = False
|
||
self.on_charger = False
|
||
self.nearest_charger_dx = 0.0
|
||
self.nearest_charger_dz = 0.0
|
||
self.nearest_charger_center_dx = 0.0
|
||
self.nearest_charger_center_dz = 0.0
|
||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||
self.nearest_charger_path_dist = float(self.GRID_SIZE)
|
||
self.charger_energy_cost = float(self.GRID_SIZE)
|
||
self.charger_safety_buffer = 0.0
|
||
self.charger_safety_margin = 0.0
|
||
self.charger_rects = []
|
||
self.charger_route_known = False
|
||
self.charger_route_source = "none"
|
||
|
||
best = None
|
||
for organ in organs:
|
||
if not isinstance(organ, dict):
|
||
continue
|
||
if int(organ.get("sub_type", 1)) != 1:
|
||
continue
|
||
pos = organ.get("pos") or {}
|
||
ox = int(pos.get("x", 0))
|
||
oz = int(pos.get("z", 0))
|
||
w = max(int(organ.get("w", 3)), 1)
|
||
h = max(int(organ.get("h", 3)), 1)
|
||
|
||
for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)):
|
||
self.charger_rects.append((rx, rz, w, h))
|
||
dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h)
|
||
dist = float(np.sqrt(dx * dx + dz * dz))
|
||
range_dist = float(max(abs(dx), abs(dz)))
|
||
center_dx = (rx + (w - 1) * 0.5) - hx
|
||
center_dz = (rz + (h - 1) * 0.5) - hz
|
||
if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]):
|
||
best = (range_dist, dist, dx, dz, center_dx, center_dz)
|
||
|
||
if best is None:
|
||
self.battery_margin = float(self.battery)
|
||
self.charger_safety_margin = float(self.battery)
|
||
return
|
||
|
||
range_dist, dist, dx, dz, center_dx, center_dz = best
|
||
self.has_charger = True
|
||
self.nearest_charger_dx = float(dx)
|
||
self.nearest_charger_dz = float(dz)
|
||
self.nearest_charger_center_dx = float(center_dx)
|
||
self.nearest_charger_center_dz = float(center_dz)
|
||
self.nearest_charger_dist = float(dist)
|
||
self.nearest_charger_range_dist = float(range_dist)
|
||
path_dist = self._global_path_dist_to_charger(hx, hz)
|
||
if path_dist < self.INF_DIST:
|
||
self.charger_route_known = True
|
||
self.charger_route_source = "global"
|
||
else:
|
||
path_dist = self._local_path_dist_to_charger(hx, hz)
|
||
if path_dist < self.INF_DIST:
|
||
self.charger_route_known = True
|
||
self.charger_route_source = "local"
|
||
else:
|
||
self.charger_route_known = False
|
||
self.charger_route_source = "range"
|
||
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
|
||
self.charger_energy_cost = self.nearest_charger_path_dist
|
||
self.on_charger = range_dist <= 0.0
|
||
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
||
|
||
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
|
||
"""Relative vector from point to the nearest cell in a rectangle."""
|
||
if x < rx:
|
||
dx = rx - x
|
||
elif x > rx + w - 1:
|
||
dx = x - (rx + w - 1)
|
||
else:
|
||
dx = 0
|
||
|
||
if z < rz:
|
||
dz = rz - z
|
||
elif z > rz + h - 1:
|
||
dz = z - (rz + h - 1)
|
||
else:
|
||
dz = 0
|
||
return float(dx), float(dz)
|
||
|
||
def _update_npc_state(self, hx, hz, npcs):
|
||
"""Find nearest NPC and cache safety features."""
|
||
self.nearest_npc_dx = 0.0
|
||
self.nearest_npc_dz = 0.0
|
||
self.nearest_npc_vx = 0.0
|
||
self.nearest_npc_vz = 0.0
|
||
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||
self.predicted_npc_dist = float(self.GRID_SIZE)
|
||
self.npc_danger = False
|
||
self.npc_predicted_danger = False
|
||
self.predicted_npcs = []
|
||
|
||
best = None
|
||
current_positions = {}
|
||
for npc in npcs:
|
||
if not isinstance(npc, dict):
|
||
continue
|
||
pos = npc.get("pos") or {}
|
||
nx = int(pos.get("x", 0))
|
||
nz = int(pos.get("z", 0))
|
||
npc_key = str(npc.get("npc_id", npc.get("idx", len(current_positions))))
|
||
prev_pos = self.prev_npc_positions.get(npc_key)
|
||
vx = 0
|
||
vz = 0
|
||
if prev_pos is not None:
|
||
vx = int(np.clip(nx - prev_pos[0], -1, 1))
|
||
vz = int(np.clip(nz - prev_pos[1], -1, 1))
|
||
px = int(np.clip(nx + vx, 0, self.GRID_SIZE - 1))
|
||
pz = int(np.clip(nz + vz, 0, self.GRID_SIZE - 1))
|
||
current_positions[npc_key] = (nx, nz)
|
||
self.predicted_npcs.append((px, pz, 1))
|
||
|
||
dx = nx - hx
|
||
dz = nz - hz
|
||
cheb = float(max(abs(dx), abs(dz)))
|
||
pred_cheb = float(max(abs(px - hx), abs(pz - hz)))
|
||
if best is None or cheb < best[0]:
|
||
best = (cheb, dx, dz, vx, vz, pred_cheb)
|
||
|
||
self.prev_npc_positions = current_positions
|
||
if best is None:
|
||
return
|
||
|
||
cheb, dx, dz, vx, vz, pred_cheb = best
|
||
self.nearest_npc_dx = float(dx)
|
||
self.nearest_npc_dz = float(dz)
|
||
self.nearest_npc_vx = float(vx)
|
||
self.nearest_npc_vz = float(vz)
|
||
self.nearest_npc_dist = float(cheb)
|
||
self.predicted_npc_dist = float(pred_cheb)
|
||
self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1
|
||
self.npc_predicted_danger = pred_cheb <= 1
|
||
|
||
def _update_recharge_mode(self):
|
||
"""Enter/exit low-battery recharge mode."""
|
||
battery_ratio = self.battery / max(self.battery_max, 1)
|
||
|
||
if not self.has_charger:
|
||
self.recharge_mode = False
|
||
self.charger_safety_margin = float(self.battery)
|
||
self.recharge_enter_margin = 0.0
|
||
self.recharge_leave_margin = 0.0
|
||
self.recharge_low_battery_ratio = 0.28
|
||
self.full_charge_leave_ratio = 0.96
|
||
self.low_battery = battery_ratio < self.recharge_low_battery_ratio
|
||
return
|
||
|
||
self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0))
|
||
self.charger_safety_buffer = self._charger_safety_buffer()
|
||
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
|
||
self.recharge_enter_margin = self._recharge_enter_margin()
|
||
self.recharge_leave_margin = self._recharge_leave_margin()
|
||
self.recharge_low_battery_ratio = self._recharge_low_battery_ratio()
|
||
self.full_charge_leave_ratio = self._full_charge_leave_ratio()
|
||
|
||
self.low_battery = battery_ratio < self.recharge_low_battery_ratio
|
||
should_recharge = self.charger_safety_margin <= self.recharge_enter_margin or self.low_battery
|
||
safe_to_leave = (
|
||
battery_ratio >= self.full_charge_leave_ratio
|
||
and self.charger_safety_margin >= self.recharge_leave_margin
|
||
)
|
||
|
||
if self.on_charger and safe_to_leave:
|
||
self.recharge_mode = False
|
||
elif should_recharge:
|
||
self.recharge_mode = True
|
||
elif self.recharge_mode and not safe_to_leave:
|
||
self.recharge_mode = True
|
||
else:
|
||
self.recharge_mode = False
|
||
|
||
if self.recharge_mode:
|
||
self.recharge_steps += 1
|
||
|
||
def _update_motion_health(self):
|
||
"""Track recharge-mode stalls so action masking can recover."""
|
||
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
|
||
if self.charge_delta <= 0:
|
||
self.stuck_steps += 1
|
||
self.stuck_count += 1
|
||
self.max_stuck_steps = max(self.max_stuck_steps, self.stuck_steps)
|
||
else:
|
||
self.stuck_steps = 0
|
||
else:
|
||
self.stuck_steps = 0
|
||
|
||
battery_ratio = self.battery / max(self.battery_max, 1)
|
||
battery_increased = self.battery > self.prev_battery + 1
|
||
maybe_fake_charger = self.on_charger and battery_ratio < 0.9 and self.charge_delta <= 0 and not battery_increased
|
||
self.fake_charger_steps = self.fake_charger_steps + 1 if maybe_fake_charger else 0
|
||
|
||
no_progress = (
|
||
self.recharge_mode
|
||
and self.has_charger
|
||
and self.charge_delta <= 0
|
||
and not battery_increased
|
||
and self.nearest_charger_path_dist >= self.last_nearest_charger_path_dist - 0.1
|
||
)
|
||
self.recharge_no_progress_steps = self.recharge_no_progress_steps + 1 if no_progress else 0
|
||
|
||
if self.step_no > 0 and min(self.nearest_npc_dist, self.predicted_npc_dist) <= 3:
|
||
self.npc_close_steps += 1
|
||
if self.step_no > 0 and (self.npc_danger or self.npc_predicted_danger):
|
||
self.npc_danger_steps += 1
|
||
|
||
if self.terminated and not self.truncated:
|
||
if self.battery <= 0 or self.remaining_charge <= 0:
|
||
self.battery_fail = 1
|
||
if self.npc_danger or self.npc_predicted_danger or self.nearest_npc_dist <= 1:
|
||
self.npc_collision = 1
|
||
|
||
def _need_recharge_escape(self):
|
||
return self.stuck_steps >= 2 or self.recharge_no_progress_steps >= 5 or self.fake_charger_steps >= 2
|
||
|
||
def _charger_safety_buffer(self):
|
||
# One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise.
|
||
base = max(12.0, 0.07 * float(self.battery_max))
|
||
distance_buffer = min(10.0, 0.12 * float(max(self.nearest_charger_range_dist, 0.0)))
|
||
obstacle_buffer = 10.0 * float(self.local_obstacle_ratio)
|
||
route_uncertainty_buffer = 6.0 if self.has_charger and not self.charger_route_known else 0.0
|
||
return float(np.clip(base + distance_buffer + obstacle_buffer + route_uncertainty_buffer, 12.0, 34.0))
|
||
|
||
def _recharge_enter_margin(self):
|
||
"""Adaptive margin for entering recharge mode before the battery is barely enough."""
|
||
base = max(4.0, 0.018 * float(self.battery_max))
|
||
path_margin = min(8.0, 0.06 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||
obstacle_margin = 8.0 * float(self.local_obstacle_ratio)
|
||
route_uncertainty_margin = 5.0 if self.has_charger and not self.charger_route_known else 0.0
|
||
recovery_margin = min(6.0, 1.2 * float(self.recharge_no_progress_steps + self.fake_charger_steps))
|
||
return float(
|
||
np.clip(
|
||
base + path_margin + obstacle_margin + route_uncertainty_margin + recovery_margin,
|
||
4.0,
|
||
26.0,
|
||
)
|
||
)
|
||
|
||
def _recharge_leave_margin(self):
|
||
"""Adaptive safety margin required before leaving a charger."""
|
||
base = max(12.0, 0.05 * float(self.battery_max))
|
||
path_margin = min(12.0, 0.10 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||
obstacle_margin = 8.0 * float(self.local_obstacle_ratio)
|
||
return float(np.clip(base + path_margin + obstacle_margin, 12.0, 42.0))
|
||
|
||
def _recharge_low_battery_ratio(self):
|
||
"""Adaptive low-battery ratio based on route length and local obstacle density."""
|
||
path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0)
|
||
ratio = 0.22 + min(0.10, 0.36 * path_pressure) + min(0.035, 0.12 * float(self.local_obstacle_ratio))
|
||
if self.has_charger and not self.charger_route_known:
|
||
ratio += 0.035
|
||
if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0:
|
||
ratio += 0.02
|
||
return float(np.clip(ratio, 0.22, 0.38))
|
||
|
||
def _full_charge_leave_ratio(self):
    """Adaptive near-full threshold for leaving a charger.

    Higher early in the episode (more steps remaining), and slightly
    higher with a long path to the charger or dense local obstacles.

    Returns:
        float: Battery ratio required before leaving, clipped to
        [0.84, 0.92].
    """
    remaining_step_ratio = 1.0 - _norm(self.step_no, self.max_step)
    path_len = float(max(self.nearest_charger_path_dist, 0.0))
    path_pressure = path_len / max(float(self.battery_max), 1.0)
    ratio = 0.84 + 0.04 * remaining_step_ratio + min(0.02, 0.08 * path_pressure)
    ratio += min(0.01, 0.04 * float(self.local_obstacle_ratio))
    return float(np.clip(ratio, 0.84, 0.92))
|
||
|
||
def _recharge_risk_score(self):
|
||
"""Risk score in [0, 1] used to scale recharge rewards and penalties."""
|
||
if not self.has_charger:
|
||
return 0.0
|
||
battery_ratio = self.battery / max(self.battery_max, 1)
|
||
margin_deficit = max(0.0, self.recharge_enter_margin - self.charger_safety_margin)
|
||
margin_risk = margin_deficit / max(self.charger_safety_buffer + self.recharge_enter_margin, 1.0)
|
||
low_battery_risk = max(0.0, self.recharge_low_battery_ratio - battery_ratio)
|
||
low_battery_risk /= max(self.recharge_low_battery_ratio, 1e-6)
|
||
progress_risk = min(1.0, float(self.recharge_no_progress_steps) / 5.0)
|
||
return float(np.clip(0.55 * margin_risk + 0.35 * low_battery_risk + 0.10 * progress_risk, 0.0, 1.0))
|
||
|
||
def useful_charge_reward_weight(self):
|
||
"""Adaptive reward weight for charging that happens under real battery pressure."""
|
||
prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
|
||
prev_low_risk = max(0.0, self.recharge_low_battery_ratio - prev_battery_ratio)
|
||
prev_low_risk /= max(self.recharge_low_battery_ratio, 1e-6)
|
||
risk = max(self._recharge_risk_score(), prev_low_risk)
|
||
mode_bonus = 0.25 if self.was_recharge_mode or self.prev_low_battery else 0.0
|
||
return float(np.clip(0.60 + 0.65 * risk + mode_bonus, 0.60, 1.45))
|
||
|
||
def battery_fail_penalty(self):
    """Return the terminal penalty for running out of battery.

    The base penalty is ``8 + 4 * (1 - step_ratio) + 2 * risk`` clamped to
    [8, 14], so failing earlier in the episode costs more. ``risk`` is the
    larger of the recharge risk score and the charger energy cost relative
    to battery capacity (capped at 1). The "battery_safe" reward profile
    multiplies the clamped penalty by 1.25.
    """
    elapsed = _norm(self.step_no, self.max_step)
    early_fail_risk = 1.0 - elapsed

    energy_pressure = float(max(self.charger_energy_cost, 0.0)) / max(float(self.battery_max), 1.0)
    risk = max(self._recharge_risk_score(), min(1.0, energy_pressure))

    penalty = float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0))
    return penalty * 1.25 if self.reward_profile == "battery_safe" else penalty
|
||
|
||
def _min_charger_range_dist(self, x, z):
|
||
if not self.charger_rects:
|
||
return float(self.GRID_SIZE)
|
||
dists = []
|
||
for rx, rz, w, h in self.charger_rects:
|
||
dx, dz = self._relative_vector_to_rect(x, z, rx, rz, w, h)
|
||
dists.append(max(abs(dx), abs(dz)))
|
||
return float(min(dists))
|
||
|
||
def _is_charger_cell(self, x, z):
|
||
for rx, rz, w, h in self.charger_rects:
|
||
if rx <= x < rx + w and rz <= z < rz + h:
|
||
return True
|
||
return False
|
||
|
||
def _get_local_view_feature(self):
|
||
"""Local view feature: 21×21×6 multi-channel map.
|
||
|
||
Channels: obstacle, clean, dirt, visit count, NPC danger, charger.
|
||
"""
|
||
view = self._view_map
|
||
channels = np.zeros((self.MAP_CHANNELS, self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
|
||
if view is None or view.shape[0] != self.VIEW_SIZE or view.shape[1] != self.VIEW_SIZE:
|
||
return channels.flatten()
|
||
|
||
channels[0] = (view == 0).astype(np.float32)
|
||
channels[1] = (view == 1).astype(np.float32)
|
||
channels[2] = (view == 2).astype(np.float32)
|
||
|
||
for ri in range(self.VIEW_SIZE):
|
||
for ci in range(self.VIEW_SIZE):
|
||
gx, gz = self._view_index_to_global(ri, ci)
|
||
if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
|
||
continue
|
||
channels[3, ri, ci] = _norm(min(int(self.visit_count_map[gx, gz]), 10), 10)
|
||
channels[4, ri, ci] = 1.0 if self._is_npc_danger_cell(gx, gz, expanded=True) else 0.0
|
||
channels[5, ri, ci] = 1.0 if self._is_charger_cell(gx, gz) else 0.0
|
||
|
||
return channels.flatten()
|
||
|
||
def _get_global_state_feature(self):
    """Assemble the global state vector (documented as 66-D).

    Layout: 42 scalar features followed by the three per-action delta
    arrays (global dirty, frontier, charger), which are presumably 8 entries
    each to reach 66. The scalars are, in order: episode progress and
    battery (4), weak absolute position (2), four dirt rays (N, E, S, W),
    nearest-dirt distance and improvement flag, charger geometry and
    recharge flags (8), nearest NPC (4), local ratios and visit/clean counts
    (4), belief-map distances, deltas and ratios (8), charger-route safety
    (2), and predicted NPC motion (4).

    Side effect: stores the previous ``nearest_dirt_dist`` in
    ``last_nearest_dirt_dist`` and refreshes ``nearest_dirt_dist``.
    """
    grid = self.GRID_SIZE

    step_norm = _norm(self.step_no, self.max_step)
    battery_ratio = _norm(self.battery, self.battery_max)
    cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)

    hx, hz = self.cur_pos
    pos_x_weak = self._weak_abs_position_feature(hx)
    pos_z_weak = self._weak_abs_position_feature(hz)

    # Straight-line rays (N, E, S, W): steps to the first dirt cell, or the
    # ray length when the ray leaves the grid or finds nothing.
    max_ray = 30

    def steps_to_dirt(dir_x, dir_z):
        for k in range(1, max_ray + 1):
            off_x, off_z = dir_x * k, dir_z * k
            if not (0 <= hx + off_x < grid and 0 <= hz + off_z < grid):
                break
            if self._view_cell(off_x, off_z, default=0) == 2:
                return k
        return max_ray

    ray_dirt = [
        _norm(steps_to_dirt(dir_x, dir_z), max_ray)
        for dir_x, dir_z in ((0, -1), (1, 0), (0, 1), (-1, 0))
    ]

    # Nearest dirt path distance inside the visible map; remember the old one.
    self.last_nearest_dirt_dist = self.nearest_dirt_dist
    self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
    dirt_got_closer = self.nearest_dirt_dist < self.last_nearest_dirt_dist

    dirty_step_gain = np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -4.0, 4.0)
    frontier_step_gain = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -4.0, 4.0)
    charger_margin_after_buffer = self.battery - self.nearest_charger_path_dist - self.charger_safety_buffer

    progress_block = [
        step_norm,
        battery_ratio,
        cleaning_progress,
        1.0 - cleaning_progress,
        pos_x_weak,
        pos_z_weak,
    ]
    dirt_block = ray_dirt + [
        _norm(self.nearest_dirt_dist, 180),
        1.0 if dirt_got_closer else 0.0,
    ]
    charger_block = [
        _signed_norm(self.nearest_charger_dx, grid),
        _signed_norm(self.nearest_charger_dz, grid),
        _norm(self.nearest_charger_path_dist, grid),
        _signed_norm(self.battery_margin, self.battery_max),
        1.0 if self.low_battery else 0.0,
        1.0 if self.recharge_mode else 0.0,
        1.0 if self.on_charger else 0.0,
        1.0 if self.charge_delta > 0 else 0.0,
    ]
    npc_block = [
        _signed_norm(self.nearest_npc_dx, 20),
        _signed_norm(self.nearest_npc_dz, 20),
        _norm(self.nearest_npc_dist, 20),
        1.0 if self.npc_danger else 0.0,
    ]
    local_block = [
        self.local_dirt_ratio,
        self.local_obstacle_ratio,
        _norm(min(self.current_visit_count, 10), 10),
        _norm(self.step_cleaned_count, 9),
    ]
    belief_block = [
        _norm(self.global_dirty_path_dist, grid),
        _norm(self.frontier_path_dist, grid),
        _signed_norm(dirty_step_gain, 4.0),
        _signed_norm(frontier_step_gain, 4.0),
        self.known_ratio,
        self.known_passable_ratio,
        _norm(self.known_dirty_ratio, 1.0),
        _norm(self.frontier_ratio, 1.0),
        1.0 if self.charger_route_known else 0.0,
        _signed_norm(charger_margin_after_buffer, self.battery_max),
    ]
    prediction_block = [
        _signed_norm(self.nearest_npc_vx, 1.0),
        _signed_norm(self.nearest_npc_vz, 1.0),
        _norm(self.predicted_npc_dist, 20),
        1.0 if self.npc_predicted_danger else 0.0,
    ]

    base_features = np.array(
        progress_block
        + dirt_block
        + charger_block
        + npc_block
        + local_block
        + belief_block
        + prediction_block,
        dtype=np.float32,
    )

    per_action = [
        delta.astype(np.float32)
        for delta in (
            self.global_dirty_action_delta,
            self.frontier_action_delta,
            self.charger_action_delta,
        )
    ]
    return np.concatenate([base_features] + per_action)
|
||
|
||
def _weak_abs_position_feature(self, value):
    """Return a damped absolute-position feature centred on 0.5.

    The coordinate is normalised by the grid size, then its offset from the
    centre (0.5) is shrunk by ABS_POS_FEATURE_SCALE.
    """
    offset_from_centre = _norm(value, self.GRID_SIZE) - 0.5
    return 0.5 + self.ABS_POS_FEATURE_SCALE * offset_from_centre
|
||
|
||
def _calc_nearest_dirt_dist(self):
|
||
"""Find nearest dirt path distance from local view.
|
||
|
||
从局部视野中找最近污渍路径距离。
|
||
"""
|
||
dist = self._local_bfs_distances()
|
||
dirt_coords = np.argwhere(self._view_map == 2)
|
||
if len(dirt_coords) == 0:
|
||
return 200.0
|
||
best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
|
||
if best >= self.INF_DIST:
|
||
return 200.0
|
||
return best
|
||
|
||
def _local_bfs_distances(self, start_dx=0, start_dz=0):
|
||
"""Shortest path distances inside the current 21x21 local view."""
|
||
view = self._view_map
|
||
shape = view.shape
|
||
dist = np.full(shape, self.INF_DIST, dtype=np.float32)
|
||
start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
|
||
if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
|
||
return dist
|
||
if int(view[start_ri, start_ci]) == 0:
|
||
return dist
|
||
|
||
dist[start_ri, start_ci] = 0.0
|
||
queue = deque([(start_ri, start_ci)])
|
||
while queue:
|
||
ri, ci = queue.popleft()
|
||
base = dist[ri, ci]
|
||
for dx, dz in self.ACTION_DIRS:
|
||
nri = ri + dz
|
||
nci = ci + dx
|
||
if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
|
||
continue
|
||
if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
|
||
continue
|
||
if dx != 0 and dz != 0:
|
||
side_a = int(view[ri, nci]) != 0
|
||
side_b = int(view[nri, ci]) != 0
|
||
if not (side_a or side_b):
|
||
continue
|
||
dist[nri, nci] = base + 1.0
|
||
queue.append((nri, nci))
|
||
return dist
|
||
|
||
def _local_path_dist_to_charger(self, gx, gz):
|
||
"""Visible-map BFS distance from global x/z to nearest charger cell."""
|
||
best = self.INF_DIST
|
||
start_dx = gx - self.cur_pos[0]
|
||
start_dz = gz - self.cur_pos[1]
|
||
dist = self._local_bfs_distances(start_dx, start_dz)
|
||
for rx, rz, w, h in self.charger_rects:
|
||
for tx in range(rx, rx + w):
|
||
for tz in range(rz, rz + h):
|
||
dx = tx - self.cur_pos[0]
|
||
dz = tz - self.cur_pos[1]
|
||
ri, ci = self._view_delta_to_index(dx, dz)
|
||
if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
|
||
best = min(best, float(dist[ri, ci]))
|
||
return best
|
||
|
||
def _global_path_dist_to_charger(self, gx, gz):
|
||
"""Known-map BFS distance from a global cell to the nearest observed charger cell."""
|
||
dist = self._get_charger_reverse_dist()
|
||
return self._dist_at(dist, gx, gz, default=self.INF_DIST)
|
||
|
||
def _charger_move_distance(self, gx, gz):
|
||
"""Use known-map BFS to the charger when available, then visible BFS, then Chebyshev."""
|
||
path_dist = self._global_path_dist_to_charger(gx, gz)
|
||
if path_dist < self.INF_DIST:
|
||
return path_dist
|
||
path_dist = self._local_path_dist_to_charger(gx, gz)
|
||
if path_dist < self.INF_DIST:
|
||
return path_dist
|
||
return self._min_charger_range_dist(gx, gz)
|
||
|
||
def get_legal_action(self):
|
||
"""Return legal action mask (8D list).
|
||
|
||
返回合法动作掩码(8D list)。
|
||
"""
|
||
raw_legal = [int(x) for x in self._legal_act]
|
||
blocked_legal = self._filter_blocked_actions(raw_legal)
|
||
npc_legal = self._filter_npc_danger_actions(blocked_legal)
|
||
safe_legal = list(npc_legal)
|
||
recharge_legal = None
|
||
escape_legal = None
|
||
leave_legal = None
|
||
legal = npc_legal
|
||
if self.recharge_mode:
|
||
recharge_legal = self._filter_recharge_actions(legal)
|
||
escape_legal = self._filter_recharge_escape_actions(recharge_legal, safe_legal)
|
||
legal = escape_legal
|
||
elif self.on_charger and self.battery / max(self.battery_max, 1) >= self.full_charge_leave_ratio:
|
||
leave_legal = self._filter_leave_charger_actions(legal)
|
||
legal = leave_legal
|
||
self._record_mask_diagnostics(
|
||
raw_legal=raw_legal,
|
||
blocked_legal=blocked_legal,
|
||
npc_legal=npc_legal,
|
||
recharge_legal=recharge_legal,
|
||
escape_legal=escape_legal,
|
||
leave_legal=leave_legal,
|
||
final_legal=legal,
|
||
)
|
||
return list(legal)
|
||
|
||
def record_action(self, action):
|
||
"""Record the chosen action for episode diagnostics."""
|
||
try:
|
||
action = int(action)
|
||
except (TypeError, ValueError):
|
||
return
|
||
if 0 <= action < len(self.diag_action_hist):
|
||
self.diag_action_hist[action] += 1
|
||
|
||
def _record_mask_diagnostics(
|
||
self,
|
||
raw_legal,
|
||
blocked_legal,
|
||
npc_legal,
|
||
recharge_legal,
|
||
escape_legal,
|
||
leave_legal,
|
||
final_legal,
|
||
):
|
||
"""Record action-mask counts without changing mask behavior."""
|
||
self.diag_mask_steps += 1
|
||
stages = {
|
||
"raw": raw_legal,
|
||
"blocked": blocked_legal,
|
||
"npc": npc_legal,
|
||
"recharge": recharge_legal if recharge_legal is not None else npc_legal,
|
||
"escape": escape_legal if escape_legal is not None else (recharge_legal if recharge_legal is not None else npc_legal),
|
||
"leave": leave_legal if leave_legal is not None else npc_legal,
|
||
"final": final_legal,
|
||
}
|
||
for name, mask in stages.items():
|
||
self.diag_mask_count_sums[name] += self._mask_count(mask)
|
||
|
||
if not self._same_mask(raw_legal, blocked_legal):
|
||
self.diag_mask_changed_steps["blocked"] += 1
|
||
if not self._same_mask(blocked_legal, npc_legal):
|
||
self.diag_mask_changed_steps["npc"] += 1
|
||
if recharge_legal is not None:
|
||
self.diag_mask_active_steps["recharge"] += 1
|
||
if not self._same_mask(npc_legal, recharge_legal):
|
||
self.diag_mask_changed_steps["recharge"] += 1
|
||
if escape_legal is not None and recharge_legal is not None:
|
||
if not self._same_mask(recharge_legal, escape_legal):
|
||
self.diag_mask_changed_steps["escape"] += 1
|
||
if leave_legal is not None:
|
||
self.diag_mask_active_steps["leave"] += 1
|
||
if not self._same_mask(npc_legal, leave_legal):
|
||
self.diag_mask_changed_steps["leave"] += 1
|
||
|
||
final_count = self._mask_count(final_legal)
|
||
if final_count <= 0:
|
||
self.diag_zero_final_steps += 1
|
||
if final_count == 1:
|
||
self.diag_one_action_steps += 1
|
||
if final_count <= 2:
|
||
self.diag_two_or_less_action_steps += 1
|
||
|
||
def _mask_count(self, mask):
|
||
return int(sum(1 for value in mask if int(value) > 0))
|
||
|
||
def _same_mask(self, left, right):
|
||
return [int(x) for x in left] == [int(x) for x in right]
|
||
|
||
def get_diagnostic_summary(self):
|
||
"""Return episode-level diagnostic counters for logging."""
|
||
steps = max(self.diag_mask_steps, 1)
|
||
avg_mask_counts = {
|
||
name: self.diag_mask_count_sums[name] / steps for name in sorted(self.diag_mask_count_sums)
|
||
}
|
||
return {
|
||
"map_id": self.map_id,
|
||
"mask_steps": self.diag_mask_steps,
|
||
"avg_mask_counts": avg_mask_counts,
|
||
"mask_changed_steps": dict(self.diag_mask_changed_steps),
|
||
"mask_active_steps": dict(self.diag_mask_active_steps),
|
||
"one_action_steps": self.diag_one_action_steps,
|
||
"two_or_less_action_steps": self.diag_two_or_less_action_steps,
|
||
"zero_final_steps": self.diag_zero_final_steps,
|
||
"action_hist": list(self.diag_action_hist),
|
||
"known_ratio": self.known_ratio,
|
||
"known_dirty_ratio": self.known_dirty_ratio,
|
||
"frontier_ratio": self.frontier_ratio,
|
||
"local_dirt_ratio": self.local_dirt_ratio,
|
||
"local_obstacle_ratio": self.local_obstacle_ratio,
|
||
"global_dirty_path_dist": self.global_dirty_path_dist,
|
||
"frontier_path_dist": self.frontier_path_dist,
|
||
"charger_route_source": self.charger_route_source,
|
||
"reward_profile": self.reward_profile,
|
||
}
|
||
|
||
def evaluation_action_score(self, action):
|
||
"""Heuristic score used only to break close evaluation-policy ties."""
|
||
if not (0 <= int(action) < len(self.ACTION_DIRS)):
|
||
return -1e6
|
||
action = int(action)
|
||
dx, dz = self.ACTION_DIRS[action]
|
||
hx, hz = self.cur_pos
|
||
nx, nz = hx + dx, hz + dz
|
||
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
|
||
return -1e6
|
||
|
||
score = 0.0
|
||
cell = self._view_cell(dx, dz, default=1)
|
||
if cell == 0:
|
||
score -= 8.0
|
||
elif cell == 2:
|
||
score += 3.0
|
||
else:
|
||
score -= 0.10
|
||
|
||
visit_count = int(self.visit_count_map[nx, nz]) if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE else 0
|
||
score += 0.35 if visit_count == 0 else -0.05 * min(visit_count, 10)
|
||
|
||
if self.recharge_mode:
|
||
if self.charger_route_known:
|
||
score += 2.2 * float(self.charger_action_delta[action])
|
||
if self._charger_move_distance(nx, nz) < self._charger_move_distance(hx, hz):
|
||
score += 0.8
|
||
else:
|
||
score += 2.0 * float(self.frontier_action_delta[action])
|
||
score += 0.7 * max(float(self.global_dirty_action_delta[action]), 0.0)
|
||
if self._min_charger_range_dist(nx, nz) < self._min_charger_range_dist(hx, hz):
|
||
score += 0.15
|
||
else:
|
||
if self.global_dirty_path_dist < self.GRID_SIZE:
|
||
score += 1.8 * float(self.global_dirty_action_delta[action])
|
||
elif self.frontier_path_dist < self.GRID_SIZE:
|
||
score += 1.4 * float(self.frontier_action_delta[action])
|
||
|
||
if self.low_battery and self.has_charger:
|
||
score += 1.2 * float(self.charger_action_delta[action])
|
||
if self._is_charger_cell(nx, nz):
|
||
score += 0.8 if self.low_battery or self.recharge_mode else -0.2
|
||
if self._is_npc_danger_cell(nx, nz, expanded=False):
|
||
score -= 6.0
|
||
elif self._is_npc_danger_cell(nx, nz, expanded=True):
|
||
score -= 1.5
|
||
if action == self.last_action and self.stuck_steps >= 1:
|
||
score -= 1.0
|
||
return float(score)
|
||
|
||
def planned_eval_action(self, probs, legal_action):
|
||
"""Return a planner action for evaluation when it clearly beats the policy.
|
||
|
||
The planner is only used by exploit(). Training samples still come from
|
||
the stochastic PPO policy.
|
||
"""
|
||
probs = np.asarray(probs, dtype=np.float64)
|
||
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||
if not np.any(legal):
|
||
legal = np.ones(8, dtype=bool)
|
||
|
||
legal_indices = np.flatnonzero(legal)
|
||
if legal_indices.size == 0:
|
||
return None
|
||
|
||
scored = []
|
||
for action in legal_indices:
|
||
action = int(action)
|
||
score = self._planned_eval_score(action)
|
||
if score <= -1e5:
|
||
continue
|
||
scored.append((score, float(probs[action]), -action, action))
|
||
|
||
if not scored:
|
||
return None
|
||
|
||
scored.sort(reverse=True)
|
||
best_score, _, _, planned_action = scored[0]
|
||
policy_action = int(legal_indices[np.argmax(probs[legal_indices])])
|
||
if planned_action == policy_action:
|
||
return planned_action
|
||
|
||
policy_score = self._planned_eval_score(policy_action)
|
||
policy_prob = float(probs[policy_action])
|
||
planned_prob = float(probs[planned_action])
|
||
force_safety = (
|
||
self.recharge_mode
|
||
or self.low_battery
|
||
or self.npc_danger
|
||
or self.npc_predicted_danger
|
||
or self.stuck_steps >= 1
|
||
)
|
||
if force_safety:
|
||
return planned_action
|
||
|
||
# Strongly prefer deterministic coverage when the learned policy is
|
||
# uncertain or the planner sees a much better cleaning/frontier move.
|
||
if policy_prob < 0.45 and best_score >= policy_score + 0.50:
|
||
return planned_action
|
||
if policy_prob - planned_prob <= 0.35 and best_score >= policy_score + 2.20:
|
||
return planned_action
|
||
return None
|
||
|
||
def _planned_eval_score(self, action):
|
||
"""Score one legal action for evaluation-time coverage planning."""
|
||
if not (0 <= int(action) < len(self.ACTION_DIRS)):
|
||
return -1e6
|
||
action = int(action)
|
||
dx, dz = self.ACTION_DIRS[action]
|
||
hx, hz = self.cur_pos
|
||
nx, nz = hx + dx, hz + dz
|
||
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
|
||
return -1e6
|
||
if not self._is_visible_cell_passable(dx, dz):
|
||
return -1e6
|
||
if dx != 0 and dz != 0:
|
||
if not (self._is_visible_cell_passable(dx, 0) or self._is_visible_cell_passable(0, dz)):
|
||
return -1e6
|
||
if self._is_npc_danger_cell(nx, nz, expanded=False):
|
||
return -1e6
|
||
|
||
score = self.evaluation_action_score(action)
|
||
cell = self._view_cell(dx, dz, default=1)
|
||
battery_ratio = self.battery / max(self.battery_max, 1)
|
||
visit_count = int(self.visit_count_map[nx, nz])
|
||
|
||
recharge_required = (
|
||
self.has_charger
|
||
and (
|
||
self.recharge_mode
|
||
or self.low_battery
|
||
or self.charger_safety_margin <= self.recharge_enter_margin + 4.0
|
||
)
|
||
)
|
||
if recharge_required:
|
||
cur_dist = self._charger_move_distance(hx, hz)
|
||
next_dist = self._charger_move_distance(nx, nz)
|
||
dist_delta = float(np.clip(cur_dist - next_dist, -2.0, 2.0))
|
||
score += 10.0 * dist_delta
|
||
if next_dist < cur_dist:
|
||
score += 3.0
|
||
if self._is_charger_cell(nx, nz):
|
||
score += 5.0
|
||
if cell == 2 and self.charger_safety_margin > self.recharge_enter_margin + 10.0:
|
||
score += 1.0
|
||
return float(score)
|
||
|
||
if cell == 2:
|
||
score += 10.0
|
||
else:
|
||
score -= 0.15
|
||
|
||
current_local_dirt = self.nearest_dirt_dist
|
||
next_local_dirt = self._nearest_local_dirt_dist_from(dx, dz)
|
||
if current_local_dirt < 200.0 and next_local_dirt < 200.0:
|
||
score += 3.0 * float(np.clip(current_local_dirt - next_local_dirt, -2.0, 2.0))
|
||
|
||
if self.global_dirty_path_dist < self.GRID_SIZE:
|
||
score += 5.0 * float(self.global_dirty_action_delta[action])
|
||
elif self.frontier_path_dist < self.GRID_SIZE:
|
||
score += 3.5 * float(self.frontier_action_delta[action])
|
||
|
||
score += 0.65 if visit_count == 0 else -0.16 * min(visit_count, 12)
|
||
if action == self.last_action and self.stuck_steps == 0:
|
||
score += 0.10
|
||
if self.has_charger and self.charger_safety_margin <= self.recharge_enter_margin + 12.0:
|
||
score += 2.0 * float(self.charger_action_delta[action])
|
||
if self._is_charger_cell(nx, nz) and battery_ratio > 0.55:
|
||
score -= 4.0
|
||
if self._is_npc_danger_cell(nx, nz, expanded=True):
|
||
score -= 3.0
|
||
return float(score)
|
||
|
||
def _nearest_local_dirt_dist_from(self, dx, dz):
|
||
"""Nearest visible dirt path distance after applying a candidate move."""
|
||
cell = self._view_cell(dx, dz, default=0)
|
||
if cell == 0:
|
||
return 200.0
|
||
if cell == 2:
|
||
return 0.0
|
||
|
||
dirt_coords = np.argwhere(self._view_map == 2)
|
||
if len(dirt_coords) == 0:
|
||
return 200.0
|
||
|
||
dist = self._local_bfs_distances(dx, dz)
|
||
best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
|
||
return best if best < self.INF_DIST else 200.0
|
||
|
||
def _filter_blocked_actions(self, legal_action):
|
||
"""Filter actions that are visibly blocked in the 21x21 view."""
|
||
legal = [int(x) for x in legal_action]
|
||
hx, hz = self.cur_pos
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
if legal[action] <= 0:
|
||
continue
|
||
if not self._is_visible_cell_passable(dx, dz):
|
||
legal[action] = 0
|
||
continue
|
||
if dx != 0 and dz != 0:
|
||
side_a = self._is_visible_cell_passable(dx, 0)
|
||
side_b = self._is_visible_cell_passable(0, dz)
|
||
if not (side_a or side_b):
|
||
legal[action] = 0
|
||
nx, nz = hx + dx, hz + dz
|
||
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
|
||
legal[action] = 0
|
||
|
||
return legal if any(legal) else [int(x) for x in legal_action]
|
||
|
||
def _is_visible_cell_passable(self, dx, dz):
|
||
"""Whether a relative 21x21-view cell is passable."""
|
||
cell = self._view_cell(dx, dz, default=None)
|
||
return True if cell is None else cell != 0
|
||
|
||
def _filter_npc_danger_actions(self, legal_action):
|
||
"""Avoid current and predicted NPC danger zones."""
|
||
if not self.npcs:
|
||
return list(legal_action)
|
||
|
||
hx, hz = self.cur_pos
|
||
safe = [int(x) for x in legal_action]
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
if safe[action] <= 0:
|
||
continue
|
||
nx, nz = hx + dx, hz + dz
|
||
if self._is_npc_danger_cell(nx, nz, expanded=True):
|
||
safe[action] = 0
|
||
|
||
if any(safe):
|
||
return safe
|
||
|
||
hard_safe = [int(x) for x in legal_action]
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
if hard_safe[action] <= 0:
|
||
continue
|
||
nx, nz = hx + dx, hz + dz
|
||
if self._is_npc_danger_cell(nx, nz, expanded=False):
|
||
hard_safe[action] = 0
|
||
return hard_safe if any(hard_safe) else list(legal_action)
|
||
|
||
def _is_npc_danger_cell(self, x, z, expanded=True):
|
||
for npc in self.npcs:
|
||
if not isinstance(npc, dict):
|
||
continue
|
||
pos = npc.get("pos") or {}
|
||
nx = int(pos.get("x", -999))
|
||
nz = int(pos.get("z", -999))
|
||
if abs(x - nx) <= 1 and abs(z - nz) <= 1:
|
||
return True
|
||
if expanded and abs(x - nx) <= 2 and abs(z - nz) <= 2 and self.nearest_npc_dist <= 4:
|
||
return True
|
||
if expanded:
|
||
for px, pz, radius in self.predicted_npcs:
|
||
if abs(x - px) <= radius and abs(z - pz) <= radius:
|
||
return True
|
||
if self.nearest_npc_dist <= 4 and abs(x - px) <= 2 and abs(z - pz) <= 2:
|
||
return True
|
||
return False
|
||
|
||
def _filter_recharge_actions(self, legal_action):
|
||
"""Restrict recharge-mode actions to safe moves toward the charger range."""
|
||
if not self.has_charger:
|
||
return list(legal_action)
|
||
|
||
hx, hz = self.cur_pos
|
||
current_range_dist = self._min_charger_range_dist(hx, hz)
|
||
current_move_dist = self._charger_move_distance(hx, hz)
|
||
scored = []
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
if legal_action[action] <= 0:
|
||
continue
|
||
nx, nz = hx + dx, hz + dz
|
||
next_dist = self._charger_move_distance(nx, nz)
|
||
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
|
||
next_range_dist = self._min_charger_range_dist(nx, nz)
|
||
scored.append((next_dist, alignment, next_range_dist, action))
|
||
|
||
if not scored:
|
||
return list(legal_action)
|
||
|
||
# When already inside the charger range, stay inside until recharge mode exits.
|
||
# 已经在充电区域内时,回充模式退出前不要离开充电区域。
|
||
confirmed_charger = self.charge_delta > 0 or self.battery > self.prev_battery + 1
|
||
if current_range_dist <= 0.0 and confirmed_charger:
|
||
stay = [0] * 8
|
||
for _, _, next_range_dist, action in scored:
|
||
if next_range_dist <= 0.0:
|
||
stay[action] = 1
|
||
if any(stay):
|
||
return stay
|
||
|
||
if not self.charger_route_known:
|
||
return self._filter_recharge_discovery_actions(legal_action, scored, current_range_dist)
|
||
|
||
recharge = [0] * 8
|
||
best_next_dist = min(item[0] for item in scored)
|
||
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
|
||
max_recharge_actions = 4
|
||
dist_slack = 2.5
|
||
for next_dist, alignment, next_range_dist, action in ranked:
|
||
route_progress = next_dist <= current_move_dist + 0.1
|
||
if next_dist <= best_next_dist + dist_slack and route_progress:
|
||
recharge[action] = 1
|
||
if sum(recharge) >= max_recharge_actions:
|
||
break
|
||
|
||
if not any(recharge):
|
||
for _, _, _, action in ranked[: min(max_recharge_actions, len(ranked))]:
|
||
recharge[action] = 1
|
||
|
||
return recharge if any(recharge) else list(legal_action)
|
||
|
||
def _filter_recharge_discovery_actions(self, legal_action, scored, current_range_dist):
|
||
"""When charger route is unknown, search for a route instead of pushing into walls."""
|
||
ranked = []
|
||
hx, hz = self.cur_pos
|
||
for next_dist, alignment, next_range_dist, action in scored:
|
||
if legal_action[action] <= 0:
|
||
continue
|
||
dx, dz = self.ACTION_DIRS[action]
|
||
nx, nz = hx + dx, hz + dz
|
||
visit_count = int(self.visit_count_map[nx, nz]) if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE else 0
|
||
frontier_gain = float(self.frontier_action_delta[action])
|
||
dirty_gain = float(self.global_dirty_action_delta[action])
|
||
range_gain = float(np.clip(current_range_dist - next_range_dist, -2.0, 2.0)) / 2.0
|
||
alignment_gain = 0.25 if alignment > 0 else 0.0
|
||
repeat_penalty = 0.8 if action == self.last_action and self.recharge_no_progress_steps >= 2 else 0.0
|
||
wall_hug_penalty = 0.35 * float(self.local_obstacle_ratio)
|
||
score = (
|
||
2.4 * frontier_gain
|
||
+ 0.8 * max(dirty_gain, 0.0)
|
||
+ 0.35 * range_gain
|
||
+ alignment_gain
|
||
- 0.04 * min(visit_count, 12)
|
||
- repeat_penalty
|
||
- wall_hug_penalty
|
||
)
|
||
ranked.append((score, action))
|
||
|
||
if not ranked:
|
||
return list(legal_action)
|
||
|
||
ranked.sort(reverse=True)
|
||
best_score = ranked[0][0]
|
||
discovery = [0] * 8
|
||
for score, action in ranked:
|
||
if score >= best_score - 0.35 or sum(discovery) < 3:
|
||
discovery[action] = 1
|
||
if sum(discovery) >= 5:
|
||
break
|
||
return discovery if any(discovery) else list(legal_action)
|
||
|
||
def _filter_recharge_escape_actions(self, recharge_action, safe_action):
|
||
"""Escape repeated no-move states during low-battery recharge."""
|
||
if not self._need_recharge_escape():
|
||
return list(recharge_action)
|
||
|
||
hx, hz = self.cur_pos
|
||
current_dist = self._charger_move_distance(hx, hz)
|
||
ranked = []
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
if safe_action[action] <= 0:
|
||
continue
|
||
nx, nz = hx + dx, hz + dz
|
||
next_dist = self._charger_move_distance(nx, nz)
|
||
visit_count = 0
|
||
if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE:
|
||
visit_count = int(self.visit_count_map[nx, nz])
|
||
failed_action_penalty = 6.0 if action == self.last_action and self.stuck_steps >= 2 else 0.0
|
||
no_progress_penalty = 1.5 if next_dist > current_dist + 0.1 else 0.0
|
||
ranked.append((next_dist + 0.05 * min(visit_count, 20) + failed_action_penalty + no_progress_penalty, action))
|
||
|
||
if not ranked:
|
||
return list(recharge_action)
|
||
|
||
self.recharge_escape_count += 1
|
||
ranked.sort()
|
||
escape = [0] * 8
|
||
for _, action in ranked[: min(4, len(ranked))]:
|
||
escape[action] = 1
|
||
|
||
if self.stuck_steps >= 2 and sum(escape) > 1 and 0 <= self.last_action < 8:
|
||
escape[self.last_action] = 0
|
||
|
||
return escape if any(escape) else list(recharge_action)
|
||
|
||
def _filter_leave_charger_actions(self, legal_action):
|
||
"""Prefer moves that leave charger range when battery is healthy."""
|
||
if not self.has_charger:
|
||
return list(legal_action)
|
||
|
||
hx, hz = self.cur_pos
|
||
current_dist = self._min_charger_range_dist(hx, hz)
|
||
scored = []
|
||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||
if legal_action[action] <= 0:
|
||
continue
|
||
nx, nz = hx + dx, hz + dz
|
||
next_dist = self._min_charger_range_dist(nx, nz)
|
||
away_score = -(dx * self.nearest_charger_center_dx + dz * self.nearest_charger_center_dz)
|
||
scored.append((next_dist - current_dist, away_score, action))
|
||
|
||
if not scored:
|
||
return list(legal_action)
|
||
|
||
best_escape = max(item[0] for item in scored)
|
||
leave = [0] * 8
|
||
if best_escape > 0:
|
||
for escape, _, action in scored:
|
||
if escape >= best_escape - 0.1:
|
||
leave[action] = 1
|
||
else:
|
||
best_away = max(item[1] for item in scored)
|
||
for _, away_score, action in scored:
|
||
if away_score >= best_away:
|
||
leave[action] = 1
|
||
|
||
return leave if any(leave) else list(legal_action)
|
||
|
||
def feature_process(self, env_obs, last_action):
|
||
"""Generate feature vector, legal action mask, and scalar reward.
|
||
|
||
生成特征向量、合法动作掩码和标量奖励。
|
||
"""
|
||
self.pb2struct(env_obs, last_action)
|
||
|
||
local_view = self._get_local_view_feature() # 2646D
|
||
global_state = self._get_global_state_feature() # 66D
|
||
legal_action = self.get_legal_action() # 8D
|
||
|
||
last_action_feature = np.zeros(8, dtype=np.float32)
|
||
if 0 <= last_action < 8:
|
||
last_action_feature[last_action] = 1.0
|
||
|
||
feature = np.concatenate([local_view, global_state, last_action_feature])
|
||
|
||
reward = self.reward_process()
|
||
|
||
return feature, legal_action, reward
|
||
|
||
def reward_process(self):
    """Compute the shaped scalar reward for the current step.

    The result is the sum, in this order, of the cleaning reward, the
    charging / recharge-guidance reward, the exploration reward, the stuck
    penalty, the NPC penalty and a constant per-step penalty. The cleaning,
    charging and exploration terms are scaled by the three multipliers
    returned by ``self._reward_profile_scales()``.

    NOTE(review): the nesting below was reconstructed from a copy of the
    file whose indentation was lost. Confirm against the canonical source
    that the safety-margin check runs for both the known-route and the
    unknown-route case, and that the ``on_charger`` penalty is the ``elif``
    of the recharge-guidance condition.

    Returns:
        The summed reward. It may be a NumPy scalar rather than a plain
        ``float`` when an un-cast ``np.clip`` result feeds the exploration
        term.
    """
    clean_scale, charge_scale, explore_scale = self._reward_profile_scales()

    # Cleaning reward: prefer the per-step cleaned-cell count, otherwise use
    # the non-negative change of the cumulative dirt counter.
    dirt_gain = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
    if self.step_cleaned_count > 0:
        cells_cleaned = self.step_cleaned_count
    else:
        cells_cleaned = dirt_gain
    battery_ratio = self.battery / max(self.battery_max, 1)
    cleaning_reward = clean_scale * float(cells_cleaned)

    # Step penalty: constant cost applied every step.
    step_penalty = -0.004

    # Charge reward: a charge gain counts as useful when the previous step was
    # low on battery / in recharge mode; topping up at high battery is penalised.
    charge_reward = 0.0
    prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
    if self.charge_delta > 0:
        if self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.35:
            charge_reward += self.useful_charge_reward_weight()
        elif battery_ratio > 0.55:
            charge_reward -= 0.45 * min(self.charge_delta, 3)

    # Recharge guidance only activates in recharge mode or at low battery, so
    # that a healthy battery is not rewarded for hovering near the charger.
    if self.has_charger and (self.recharge_mode or self.low_battery):
        risk = self._recharge_risk_score()
        if self.charger_route_known:
            # Route known: reward path-distance progress; retreat uses its own scale.
            path_gain = float(
                np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
            )
            if path_gain > 0:
                charge_reward += (0.040 + 0.045 * risk) * path_gain
            else:
                charge_reward += (0.020 + 0.035 * risk) * path_gain
        else:
            # Route unknown: reward frontier progress and, when the robot
            # actually moved, progress in charger range distance.
            frontier_gain = float(
                np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
            )
            range_gain = float(
                np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -2.0, 2.0)
            )
            charge_reward += (0.020 + 0.030 * risk) * frontier_gain
            has_moved = self.prev_pos is not None and self.cur_pos != self.prev_pos and self.stuck_steps == 0
            if has_moved:
                charge_reward += (0.010 + 0.018 * risk) * range_gain
        if self.charger_safety_margin < self.recharge_enter_margin:
            shortage = self.recharge_enter_margin - self.charger_safety_margin
            charge_reward -= min(0.35, shortage / max(self.battery_max, 1))
    elif self.on_charger and battery_ratio > 0.55:
        charge_reward -= 0.18
    charge_reward *= charge_scale

    # Exploration: bonus for a new cell, small penalty growing with revisits,
    # plus progress toward the nearest dirty cell (or frontier). Disabled in
    # recharge mode.
    exploration_reward = 0.0
    if not self.recharge_mode:
        if self.is_new_cell:
            exploration_reward = 0.020
        else:
            exploration_reward = -0.006 * min(self.current_visit_count, 8)
        if self.global_dirty_path_dist < self.GRID_SIZE:
            dirty_gain = np.clip(self.last_global_dirty_path_dist - self.global_dirty_path_dist, -3.0, 3.0)
            exploration_reward += 0.020 * dirty_gain
        elif self.frontier_path_dist < self.GRID_SIZE:
            explore_gain = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
            exploration_reward += 0.014 * explore_gain
    exploration_reward *= explore_scale

    # Collision / stuck signal: a move that leaves the position unchanged
    # wastes both a step and battery.
    stuck_penalty = 0.0
    did_not_move = self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8
    if did_not_move:
        stuck_penalty = -0.08
        if self.recharge_mode:
            stuck_penalty -= 0.04 * min(self.stuck_steps, 5)

    # NPC proximity: actual danger, then predicted danger, then closeness.
    npc_penalty = 0.0
    if self.npc_danger:
        npc_penalty -= 4.0
    elif self.npc_predicted_danger:
        npc_penalty -= 0.4
    elif self.nearest_npc_dist <= 3:
        npc_penalty -= 0.05 * (4 - self.nearest_npc_dist)

    # Accumulate left to right so floating-point rounding matches a plain
    # chained addition of the terms.
    total = cleaning_reward
    for term in (charge_reward, exploration_reward, stuck_penalty, npc_penalty, step_penalty):
        total = total + term
    return total
|
||
|
||
def _reward_profile_scales(self):
    """Look up the reward multipliers for the active ``reward_profile``.

    Returns:
        A ``(cleaning, charge, exploration)`` tuple of multipliers for the
        profile named by ``self.reward_profile``; ``(1.0, 1.0, 1.0)`` for any
        name that is not listed below.
    """
    presets = (
        ("lower_recharge", (1.0, 0.45, 1.0)),
        ("clean_explore", (1.10, 0.60, 1.35)),
        ("battery_safe", (0.95, 0.85, 0.90)),
    )
    # Compare with == one entry at a time so that any profile value, hashable
    # or not, simply falls through to the neutral default.
    for profile_name, scales in presets:
        if self.reward_profile == profile_name:
            return scales
    return 1.0, 1.0, 1.0
|