优化PPO充电与避障策略
扩展观测特征到157维,加入充电桩、NPC、电量安全余量、地图统计和本步清扫信息。 增加低电量回充动作过滤、NPC危险区过滤,并调整奖励和终局日志以突出充电、避障和真实清扫得分。
This commit is contained in:
@@ -55,9 +55,9 @@ class Agent(BaseAgent):
|
||||
self.last_reward = 0.0
|
||||
|
||||
def observation_process(self, env_obs):
|
||||
"""Convert raw env_obs to ObsData (69D feature + legal action mask).
|
||||
"""Convert raw env_obs to ObsData (feature + legal action mask).
|
||||
|
||||
将原始 env_obs 转换为 ObsData(69D 特征 + 合法动作掩码)。
|
||||
将原始 env_obs 转换为 ObsData(特征 + 合法动作掩码)。
|
||||
"""
|
||||
feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
|
||||
self.last_reward = reward
|
||||
@@ -135,8 +135,16 @@ class Agent(BaseAgent):
|
||||
加载模型检查点。
|
||||
"""
|
||||
model_file_path = f"{path}/model.ckpt-{id}.pkl"
|
||||
self.model.load_state_dict(torch.load(model_file_path, map_location=self.device))
|
||||
self.logger.info(f"load model {model_file_path} successfully")
|
||||
state_dict = torch.load(model_file_path, map_location=self.device)
|
||||
try:
|
||||
self.model.load_state_dict(state_dict)
|
||||
except RuntimeError as err:
|
||||
msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}"
|
||||
if self.logger:
|
||||
self.logger.warning(msg)
|
||||
return
|
||||
if self.logger:
|
||||
self.logger.info(f"load model {model_file_path} successfully")
|
||||
|
||||
def _run_model(self, feature):
|
||||
"""Gradient-free forward pass, returns (logits_np, value_np).
|
||||
|
||||
@@ -13,12 +13,12 @@ Configuration for Robot Vacuum PPO agent.
|
||||
|
||||
class Config:
|
||||
|
||||
# Feature dimensions (69D)
|
||||
# 特征维度(69D)
|
||||
# Feature dimensions (157D)
|
||||
# 特征维度(157D)
|
||||
FEATURES = [
|
||||
7 * 7,
|
||||
12,
|
||||
8,
|
||||
11 * 11, # wider local map view / 更大的局部地图视野
|
||||
28, # global, charger, NPC, and map-stat features / 全局、充电桩、NPC、地图统计特征
|
||||
8, # last action one-hot / 上一步动作 one-hot
|
||||
]
|
||||
FEATURE_SPLIT_SHAPE = FEATURES
|
||||
FEATURE_LEN = sum(FEATURES)
|
||||
|
||||
@@ -33,7 +33,7 @@ ActData = create_cls(
|
||||
# 训练样本数据:字段值为 int 时框架自动按维度处理
|
||||
SampleData = create_cls(
|
||||
"SampleData",
|
||||
obs=Config.DIM_OF_OBSERVATION, # 69D feature vector / 特征向量
|
||||
obs=Config.DIM_OF_OBSERVATION, # feature vector / 特征向量
|
||||
legal_action=Config.ACTION_NUM, # 8D legal action mask / 合法动作掩码
|
||||
act=1, # action index / 执行的动作
|
||||
reward=Config.VALUE_NUM, # 1D reward / 奖励
|
||||
|
||||
@@ -24,6 +24,13 @@ def _norm(v, v_max, v_min=0.0):
|
||||
return (v - v_min) / (v_max - v_min)
|
||||
|
||||
|
||||
def _signed_norm(v, v_max):
|
||||
"""Normalize signed value to [-1, 1]."""
|
||||
if v_max <= 0:
|
||||
return 0.0
|
||||
return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
|
||||
|
||||
|
||||
class Preprocessor:
|
||||
"""Feature preprocessor for Robot Vacuum.
|
||||
|
||||
@@ -32,7 +39,17 @@ class Preprocessor:
|
||||
|
||||
GRID_SIZE = 128
|
||||
VIEW_HALF = 10 # Full local view radius (21×21) / 完整局部视野半径
|
||||
LOCAL_HALF = 3 # Cropped view radius (7×7) / 裁剪后的视野半径
|
||||
LOCAL_HALF = 5 # Cropped view radius (11×11) / 裁剪后的视野半径
|
||||
ACTION_DIRS = (
|
||||
(1, 0),
|
||||
(1, -1),
|
||||
(0, -1),
|
||||
(-1, -1),
|
||||
(-1, 0),
|
||||
(-1, 1),
|
||||
(0, 1),
|
||||
(1, 1),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -56,6 +73,10 @@ class Preprocessor:
|
||||
self.dirt_cleaned = 0
|
||||
self.last_dirt_cleaned = 0
|
||||
self.total_dirt = 1
|
||||
self.total_score = 0
|
||||
self.clean_score = 0
|
||||
self.step_cleaned_count = 0
|
||||
self.max_step = 1000
|
||||
|
||||
# Global passable map (0=obstacle, 1=passable), used for ray computation
|
||||
# 维护全局通行地图(0=障碍, 1=可通行),用于射线计算
|
||||
@@ -69,6 +90,33 @@ class Preprocessor:
|
||||
|
||||
self._view_map = np.zeros((21, 21), dtype=np.float32)
|
||||
self._legal_act = [1] * 8
|
||||
self.terminated = False
|
||||
self.truncated = False
|
||||
|
||||
self.remaining_charge = 0
|
||||
self.charge_count = 0
|
||||
self.last_charge_count = 0
|
||||
self.charge_delta = 0
|
||||
self.nearest_charger_dx = 0.0
|
||||
self.nearest_charger_dz = 0.0
|
||||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||
self.battery_margin = 0.0
|
||||
self.has_charger = False
|
||||
self.low_battery = False
|
||||
self.on_charger = False
|
||||
self.recharge_mode = False
|
||||
self.recharge_steps = 0
|
||||
|
||||
self.nearest_npc_dx = 0.0
|
||||
self.nearest_npc_dz = 0.0
|
||||
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||||
self.npc_danger = False
|
||||
self.npcs = []
|
||||
|
||||
self.local_dirt_ratio = 0.0
|
||||
self.local_obstacle_ratio = 0.0
|
||||
|
||||
def pb2struct(self, env_obs, last_action):
|
||||
"""Parse and cache essential fields from observation dict.
|
||||
@@ -76,14 +124,20 @@ class Preprocessor:
|
||||
从 env_obs 字典中提取并缓存所有需要的状态量。
|
||||
"""
|
||||
observation = env_obs["observation"]
|
||||
frame_state = observation["frame_state"]
|
||||
env_info = observation["env_info"]
|
||||
hero = frame_state["heroes"]
|
||||
frame_state = observation.get("frame_state", {})
|
||||
extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {})
|
||||
env_info = observation.get("env_info", {})
|
||||
hero = frame_state.get("heroes", {})
|
||||
if isinstance(hero, list):
|
||||
hero = hero[0] if hero else {}
|
||||
|
||||
self.last_action = int(last_action)
|
||||
self.step_no = int(observation["step_no"])
|
||||
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
||||
self.terminated = bool(env_obs.get("terminated", False))
|
||||
self.truncated = bool(env_obs.get("truncated", False))
|
||||
self.prev_pos = self.cur_pos if self.has_position_history else None
|
||||
self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))
|
||||
hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]}
|
||||
self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
|
||||
self.has_position_history = True
|
||||
|
||||
hx, hz = self.cur_pos
|
||||
@@ -96,16 +150,30 @@ class Preprocessor:
|
||||
self.is_new_cell = False
|
||||
|
||||
# Battery / 电量
|
||||
self.battery = int(hero["battery"])
|
||||
self.battery_max = max(int(hero["battery_max"]), 1)
|
||||
self.battery = int(hero.get("battery", env_info.get("remaining_charge", self.battery)))
|
||||
self.battery_max = max(int(hero.get("battery_max", env_info.get("battery_max", self.battery_max))), 1)
|
||||
self.remaining_charge = int(env_info.get("remaining_charge", self.battery))
|
||||
self.max_step = max(int(env_info.get("max_step", self.max_step)), 1)
|
||||
|
||||
# Cleaning progress / 清扫进度
|
||||
self.last_dirt_cleaned = self.dirt_cleaned
|
||||
self.dirt_cleaned = int(hero["dirt_cleaned"])
|
||||
self.total_dirt = max(int(env_info["total_dirt"]), 1)
|
||||
self.dirt_cleaned = int(hero.get("dirt_cleaned", env_info.get("clean_score", self.dirt_cleaned)))
|
||||
self.total_dirt = max(int(env_info.get("total_dirt", self.total_dirt)), 1)
|
||||
self.total_score = int(env_info.get("total_score", self.total_score))
|
||||
self.clean_score = int(env_info.get("clean_score", self.dirt_cleaned))
|
||||
step_cleaned_cells = env_info.get("step_cleaned_cells") or []
|
||||
self.step_cleaned_count = len(step_cleaned_cells)
|
||||
|
||||
# Charge progress / 充电进度
|
||||
self.last_charge_count = self.charge_count
|
||||
self.charge_count = int(env_info.get("charge_count", self.charge_count))
|
||||
self.charge_delta = max(0, self.charge_count - self.last_charge_count)
|
||||
|
||||
# Legal actions / 合法动作
|
||||
self._legal_act = [int(x) for x in (observation.get("legal_action") or [1] * 8)]
|
||||
raw_legal_act = observation.get("legal_action") or observation.get("legal_act") or [1] * 8
|
||||
self._legal_act = [int(x) for x in raw_legal_act[:8]]
|
||||
if len(self._legal_act) < 8:
|
||||
self._legal_act.extend([1] * (8 - len(self._legal_act)))
|
||||
|
||||
# Local view map (21×21) / 局部视野地图
|
||||
map_info = observation.get("map_info")
|
||||
@@ -113,6 +181,14 @@ class Preprocessor:
|
||||
self._view_map = np.array(map_info, dtype=np.float32)
|
||||
hx, hz = self.cur_pos
|
||||
self._update_passable(hx, hz)
|
||||
self._update_local_map_stats()
|
||||
|
||||
organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
|
||||
npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
|
||||
self.npcs = list(npcs)
|
||||
self._update_charger_state(hx, hz, organs)
|
||||
self._update_npc_state(hx, hz, self.npcs)
|
||||
self._update_recharge_mode()
|
||||
|
||||
def _update_passable(self, hx, hz):
|
||||
"""Write local view into global passable map.
|
||||
@@ -132,10 +208,130 @@ class Preprocessor:
|
||||
# 0 = 障碍, 1/2 = 可通行
|
||||
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
||||
|
||||
def _get_local_view_feature(self):
|
||||
"""Local view feature (49D): crop center 7×7 from 21×21.
|
||||
def _update_local_map_stats(self):
|
||||
"""Cache coarse 21x21 map statistics."""
|
||||
view = self._view_map
|
||||
if view is None or view.size == 0:
|
||||
self.local_dirt_ratio = 0.0
|
||||
self.local_obstacle_ratio = 0.0
|
||||
return
|
||||
total = float(view.size)
|
||||
self.local_dirt_ratio = float(np.sum(view == 2) / total)
|
||||
self.local_obstacle_ratio = float(np.sum(view == 0) / total)
|
||||
|
||||
局部视野特征(49D):从 21×21 视野中心裁剪 7×7。
|
||||
def _update_charger_state(self, hx, hz, organs):
|
||||
"""Find nearest charger and cache distance/direction features."""
|
||||
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
||||
self.has_charger = False
|
||||
self.on_charger = False
|
||||
self.nearest_charger_dx = 0.0
|
||||
self.nearest_charger_dz = 0.0
|
||||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||
|
||||
best = None
|
||||
for organ in organs:
|
||||
if not isinstance(organ, dict):
|
||||
continue
|
||||
if int(organ.get("sub_type", 1)) != 1:
|
||||
continue
|
||||
pos = organ.get("pos") or {}
|
||||
ox = int(pos.get("x", 0))
|
||||
oz = int(pos.get("z", 0))
|
||||
w = max(int(organ.get("w", 3)), 1)
|
||||
h = max(int(organ.get("h", 3)), 1)
|
||||
|
||||
for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)):
|
||||
dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h)
|
||||
dist = float(np.sqrt(dx * dx + dz * dz))
|
||||
range_dist = float(max(abs(dx), abs(dz)))
|
||||
if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]):
|
||||
best = (range_dist, dist, dx, dz)
|
||||
|
||||
if best is None:
|
||||
self.battery_margin = float(self.battery)
|
||||
return
|
||||
|
||||
range_dist, dist, dx, dz = best
|
||||
self.has_charger = True
|
||||
self.nearest_charger_dx = float(dx)
|
||||
self.nearest_charger_dz = float(dz)
|
||||
self.nearest_charger_dist = float(dist)
|
||||
self.nearest_charger_range_dist = float(range_dist)
|
||||
self.on_charger = range_dist <= 0.0
|
||||
self.battery_margin = float(self.battery) - self.nearest_charger_range_dist
|
||||
|
||||
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
|
||||
"""Relative vector from point to the nearest cell in a rectangle."""
|
||||
if x < rx:
|
||||
dx = rx - x
|
||||
elif x > rx + w - 1:
|
||||
dx = x - (rx + w - 1)
|
||||
else:
|
||||
dx = 0
|
||||
|
||||
if z < rz:
|
||||
dz = rz - z
|
||||
elif z > rz + h - 1:
|
||||
dz = z - (rz + h - 1)
|
||||
else:
|
||||
dz = 0
|
||||
return float(dx), float(dz)
|
||||
|
||||
def _update_npc_state(self, hx, hz, npcs):
|
||||
"""Find nearest NPC and cache safety features."""
|
||||
self.nearest_npc_dx = 0.0
|
||||
self.nearest_npc_dz = 0.0
|
||||
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||||
self.npc_danger = False
|
||||
|
||||
best = None
|
||||
for npc in npcs:
|
||||
if not isinstance(npc, dict):
|
||||
continue
|
||||
pos = npc.get("pos") or {}
|
||||
nx = int(pos.get("x", 0))
|
||||
nz = int(pos.get("z", 0))
|
||||
dx = nx - hx
|
||||
dz = nz - hz
|
||||
cheb = float(max(abs(dx), abs(dz)))
|
||||
if best is None or cheb < best[0]:
|
||||
best = (cheb, dx, dz)
|
||||
|
||||
if best is None:
|
||||
return
|
||||
|
||||
cheb, dx, dz = best
|
||||
self.nearest_npc_dx = float(dx)
|
||||
self.nearest_npc_dz = float(dz)
|
||||
self.nearest_npc_dist = float(cheb)
|
||||
self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1
|
||||
|
||||
def _update_recharge_mode(self):
|
||||
"""Enter/exit low-battery recharge mode."""
|
||||
battery_ratio = self.battery / max(self.battery_max, 1)
|
||||
self.low_battery = battery_ratio < 0.35
|
||||
|
||||
if not self.has_charger:
|
||||
self.recharge_mode = False
|
||||
return
|
||||
|
||||
if self.charge_delta > 0 or (self.on_charger and battery_ratio > 0.85):
|
||||
self.recharge_mode = False
|
||||
elif self.battery <= self.nearest_charger_range_dist + 18 or battery_ratio < 0.22:
|
||||
self.recharge_mode = True
|
||||
elif self.recharge_mode and battery_ratio < 0.85:
|
||||
self.recharge_mode = True
|
||||
else:
|
||||
self.recharge_mode = False
|
||||
|
||||
if self.recharge_mode:
|
||||
self.recharge_steps += 1
|
||||
|
||||
def _get_local_view_feature(self):
|
||||
"""Local view feature (121D): crop center 11×11 from 21×21.
|
||||
|
||||
局部视野特征(121D):从 21×21 视野中心裁剪 11×11。
|
||||
"""
|
||||
center = self.VIEW_HALF
|
||||
h = self.LOCAL_HALF
|
||||
@@ -143,9 +339,9 @@ class Preprocessor:
|
||||
return (crop / 2.0).flatten()
|
||||
|
||||
def _get_global_state_feature(self):
|
||||
"""Global state feature (12D).
|
||||
"""Global state feature (28D).
|
||||
|
||||
全局状态特征(12D)。
|
||||
全局状态特征(28D)。
|
||||
|
||||
Dimensions / 维度说明:
|
||||
[0] step_norm step progress / 步数归一化 [0,1]
|
||||
@@ -160,8 +356,24 @@ class Preprocessor:
|
||||
[9] ray_W_dirt west ray distance / 向左(x-)方向
|
||||
[10] nearest_dirt_norm nearest dirt Euclidean distance / 最近污渍欧氏距离归一化
|
||||
[11] dirt_delta approaching dirt indicator / 是否在接近污渍(1=是, 0=否)
|
||||
[12] charger_dx nearest charger x direction / 最近充电桩 x 相对方向
|
||||
[13] charger_dz nearest charger z direction / 最近充电桩 z 相对方向
|
||||
[14] charger_dist nearest charger distance / 最近充电桩距离
|
||||
[15] battery_margin battery minus charger distance / 电量安全余量
|
||||
[16] low_battery low-battery flag / 低电量标记
|
||||
[17] recharge_mode recharge-mode flag / 回充模式标记
|
||||
[18] on_charger on charger flag / 是否在充电桩范围
|
||||
[19] charge_delta charge count increased / 本步是否成功充电
|
||||
[20] npc_dx nearest NPC x direction / 最近 NPC x 相对方向
|
||||
[21] npc_dz nearest NPC z direction / 最近 NPC z 相对方向
|
||||
[22] npc_dist nearest NPC Chebyshev distance / 最近 NPC 切比雪夫距离
|
||||
[23] npc_danger in NPC 3x3 danger zone / 是否处于 NPC 3x3 危险区
|
||||
[24] local_dirt_ratio dirt ratio in 21x21 view / 21x21 视野污渍比例
|
||||
[25] obstacle_ratio obstacle ratio in 21x21 view / 21x21 视野障碍比例
|
||||
[26] visit_count current cell visit count / 当前格访问次数
|
||||
[27] step_cleaned cells cleaned this step / 本步清扫格子数
|
||||
"""
|
||||
step_norm = _norm(self.step_no, 2000)
|
||||
step_norm = _norm(self.step_no, self.max_step)
|
||||
battery_ratio = _norm(self.battery, self.battery_max)
|
||||
cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
|
||||
remaining_dirt = 1.0 - cleaning_progress
|
||||
@@ -205,6 +417,10 @@ class Preprocessor:
|
||||
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
|
||||
|
||||
dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0
|
||||
charge_delta = 1.0 if self.charge_delta > 0 else 0.0
|
||||
battery_margin_norm = _signed_norm(self.battery_margin, self.battery_max)
|
||||
visit_count_norm = _norm(min(self.current_visit_count, 10), 10)
|
||||
step_cleaned_norm = _norm(self.step_cleaned_count, 9)
|
||||
|
||||
return np.array(
|
||||
[
|
||||
@@ -220,6 +436,22 @@ class Preprocessor:
|
||||
ray_dirt[3],
|
||||
nearest_dirt_norm,
|
||||
dirt_delta,
|
||||
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
|
||||
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
|
||||
_norm(self.nearest_charger_range_dist, self.GRID_SIZE),
|
||||
battery_margin_norm,
|
||||
1.0 if self.low_battery else 0.0,
|
||||
1.0 if self.recharge_mode else 0.0,
|
||||
1.0 if self.on_charger else 0.0,
|
||||
charge_delta,
|
||||
_signed_norm(self.nearest_npc_dx, 20),
|
||||
_signed_norm(self.nearest_npc_dz, 20),
|
||||
_norm(self.nearest_npc_dist, 20),
|
||||
1.0 if self.npc_danger else 0.0,
|
||||
self.local_dirt_ratio,
|
||||
self.local_obstacle_ratio,
|
||||
visit_count_norm,
|
||||
step_cleaned_norm,
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
@@ -244,27 +476,119 @@ class Preprocessor:
|
||||
|
||||
返回合法动作掩码(8D list)。
|
||||
"""
|
||||
return list(self._legal_act)
|
||||
legal = self._filter_blocked_actions(self._legal_act)
|
||||
legal = self._filter_npc_danger_actions(legal)
|
||||
if self.recharge_mode:
|
||||
legal = self._filter_recharge_actions(legal)
|
||||
return list(legal)
|
||||
|
||||
def _filter_blocked_actions(self, legal_action):
|
||||
"""Filter actions that are visibly blocked in the 21x21 view."""
|
||||
legal = [int(x) for x in legal_action]
|
||||
hx, hz = self.cur_pos
|
||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||
if legal[action] <= 0:
|
||||
continue
|
||||
if not self._is_visible_cell_passable(dx, dz):
|
||||
legal[action] = 0
|
||||
continue
|
||||
if dx != 0 and dz != 0:
|
||||
side_a = self._is_visible_cell_passable(dx, 0)
|
||||
side_b = self._is_visible_cell_passable(0, dz)
|
||||
if not (side_a or side_b):
|
||||
legal[action] = 0
|
||||
nx, nz = hx + dx, hz + dz
|
||||
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
|
||||
legal[action] = 0
|
||||
|
||||
return legal if any(legal) else [int(x) for x in legal_action]
|
||||
|
||||
def _is_visible_cell_passable(self, dx, dz):
|
||||
"""Whether a relative 21x21-view cell is passable."""
|
||||
ri = self.VIEW_HALF + dx
|
||||
ci = self.VIEW_HALF + dz
|
||||
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
||||
return True
|
||||
return int(self._view_map[ri, ci]) != 0
|
||||
|
||||
def _filter_npc_danger_actions(self, legal_action):
|
||||
"""Avoid actions that would enter any NPC 3x3 danger zone."""
|
||||
if not self.npcs:
|
||||
return list(legal_action)
|
||||
|
||||
hx, hz = self.cur_pos
|
||||
safe = [int(x) for x in legal_action]
|
||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||
if safe[action] <= 0:
|
||||
continue
|
||||
nx, nz = hx + dx, hz + dz
|
||||
if self._is_npc_danger_cell(nx, nz):
|
||||
safe[action] = 0
|
||||
|
||||
return safe if any(safe) else list(legal_action)
|
||||
|
||||
def _is_npc_danger_cell(self, x, z):
|
||||
for npc in self.npcs:
|
||||
if not isinstance(npc, dict):
|
||||
continue
|
||||
pos = npc.get("pos") or {}
|
||||
nx = int(pos.get("x", -999))
|
||||
nz = int(pos.get("z", -999))
|
||||
if abs(x - nx) <= 1 and abs(z - nz) <= 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _filter_recharge_actions(self, legal_action):
|
||||
"""Restrict low-battery actions to moves that approach the charger."""
|
||||
if not self.has_charger:
|
||||
return list(legal_action)
|
||||
|
||||
hx, hz = self.cur_pos
|
||||
current_dist = max(abs(self.nearest_charger_dx), abs(self.nearest_charger_dz))
|
||||
scored = []
|
||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||
if legal_action[action] <= 0:
|
||||
continue
|
||||
next_dx = self.nearest_charger_dx - dx
|
||||
next_dz = self.nearest_charger_dz - dz
|
||||
next_dist = max(abs(next_dx), abs(next_dz))
|
||||
improvement = current_dist - next_dist
|
||||
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
|
||||
scored.append((improvement, alignment, action))
|
||||
|
||||
if not scored:
|
||||
return list(legal_action)
|
||||
|
||||
best_improvement = max(item[0] for item in scored)
|
||||
recharge = [0] * 8
|
||||
if best_improvement > 0:
|
||||
for improvement, _, action in scored:
|
||||
if improvement >= best_improvement - 0.1:
|
||||
recharge[action] = 1
|
||||
else:
|
||||
best_alignment = max(item[1] for item in scored)
|
||||
for _, alignment, action in scored:
|
||||
if alignment >= best_alignment:
|
||||
recharge[action] = 1
|
||||
|
||||
return recharge if any(recharge) else list(legal_action)
|
||||
|
||||
def feature_process(self, env_obs, last_action):
|
||||
"""Generate 69D feature vector, legal action mask, and scalar reward.
|
||||
"""Generate feature vector, legal action mask, and scalar reward.
|
||||
|
||||
生成 69D 特征向量、合法动作掩码和标量奖励。
|
||||
生成特征向量、合法动作掩码和标量奖励。
|
||||
"""
|
||||
self.pb2struct(env_obs, last_action)
|
||||
|
||||
local_view = self._get_local_view_feature() # 49D
|
||||
global_state = self._get_global_state_feature() # 12D
|
||||
local_view = self._get_local_view_feature() # 121D
|
||||
global_state = self._get_global_state_feature() # 28D
|
||||
legal_action = self.get_legal_action() # 8D
|
||||
|
||||
last_action_feature = np.zeros(8, dtype=np.float32)
|
||||
if 0 <= last_action < 8:
|
||||
last_action_feature[last_action] = 1.0
|
||||
|
||||
# The legal action mask is passed separately to PPO. Reusing this 8D slot
|
||||
# for action history makes the 69D observation more informative without
|
||||
# breaking the framework's fixed tensor shape.
|
||||
feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D
|
||||
feature = np.concatenate([local_view, global_state, last_action_feature])
|
||||
|
||||
reward = self.reward_process()
|
||||
|
||||
@@ -273,7 +597,8 @@ class Preprocessor:
|
||||
def reward_process(self):
|
||||
# Cleaning reward / 清扫奖励
|
||||
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
||||
cleaning_reward = 0.25 * cleaned_this_step
|
||||
cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step
|
||||
cleaning_reward = 0.7 * cleaned_cells
|
||||
|
||||
# Step penalty / 时间惩罚
|
||||
step_penalty = -0.002
|
||||
@@ -281,13 +606,28 @@ class Preprocessor:
|
||||
# Dense guidance: prefer moving toward visible dirt.
|
||||
# 稠密引导:鼓励向视野内污渍靠近。
|
||||
approach_reward = 0.0
|
||||
if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0:
|
||||
if not self.recharge_mode and (self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0):
|
||||
dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
|
||||
approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta
|
||||
|
||||
# Recharge guidance only activates when battery safety is the bottleneck.
|
||||
# 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。
|
||||
charge_reward = 0.0
|
||||
if self.charge_delta > 0:
|
||||
charge_reward += 3.0 * self.charge_delta
|
||||
if self.has_charger and (self.recharge_mode or self.low_battery):
|
||||
dist_delta = float(
|
||||
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0)
|
||||
)
|
||||
charge_reward += 0.04 * dist_delta if dist_delta > 0 else 0.02 * dist_delta
|
||||
if self.battery_margin < 0:
|
||||
charge_reward -= min(0.25, abs(self.battery_margin) / max(self.battery_max, 1))
|
||||
elif self.on_charger and self.charge_delta == 0 and self.battery / max(self.battery_max, 1) > 0.85:
|
||||
charge_reward -= 0.01
|
||||
|
||||
# Encourage covering new passable cells and mildly discourage loops.
|
||||
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
||||
exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5)
|
||||
exploration_reward = 0.004 if self.is_new_cell else -0.0015 * min(self.current_visit_count, 6)
|
||||
|
||||
# Collision/stuck signal: invalid moves waste both step and battery.
|
||||
# 撞墙/原地不动会浪费步数和电量。
|
||||
@@ -295,4 +635,26 @@ class Preprocessor:
|
||||
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
|
||||
stuck_penalty = -0.03
|
||||
|
||||
return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty
|
||||
npc_penalty = 0.0
|
||||
if self.npc_danger:
|
||||
npc_penalty -= 4.0
|
||||
elif self.nearest_npc_dist <= 3:
|
||||
npc_penalty -= 0.05 * (4 - self.nearest_npc_dist)
|
||||
|
||||
terminal_penalty = 0.0
|
||||
if self.terminated and not self.truncated:
|
||||
if self.battery <= 0 or self.remaining_charge <= 0:
|
||||
terminal_penalty -= 4.0
|
||||
elif self.npc_danger or self.nearest_npc_dist <= 1:
|
||||
terminal_penalty -= 3.0
|
||||
|
||||
return (
|
||||
cleaning_reward
|
||||
+ approach_reward
|
||||
+ charge_reward
|
||||
+ exploration_reward
|
||||
+ stuck_penalty
|
||||
+ npc_penalty
|
||||
+ terminal_penalty
|
||||
+ step_penalty
|
||||
)
|
||||
|
||||
@@ -38,22 +38,22 @@ class Model(nn.Module):
|
||||
self.model_name = "robot_vacuum"
|
||||
self.device = device
|
||||
|
||||
obs_dim = Config.DIM_OF_OBSERVATION # 69
|
||||
obs_dim = Config.DIM_OF_OBSERVATION # 157
|
||||
act_num = Config.ACTION_NUM # 8
|
||||
|
||||
# Shared backbone / 共享骨干网络
|
||||
self.backbone = nn.Sequential(
|
||||
_make_fc(obs_dim, 128),
|
||||
_make_fc(obs_dim, 256),
|
||||
nn.ReLU(),
|
||||
_make_fc(128, 64),
|
||||
_make_fc(256, 128),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Actor head: outputs action logits / 策略头:输出动作 logits
|
||||
self.actor_head = _make_fc(64, act_num, gain=0.01)
|
||||
self.actor_head = _make_fc(128, act_num, gain=0.01)
|
||||
|
||||
# Critic head: outputs single state value / 价值头:输出单个状态价值
|
||||
self.critic_head = _make_fc(64, 1, gain=0.01)
|
||||
self.critic_head = _make_fc(128, 1, gain=0.01)
|
||||
|
||||
def forward(self, s, inference=False):
|
||||
"""Forward pass.
|
||||
|
||||
@@ -132,7 +132,14 @@ class EpisodeRunner:
|
||||
final_reward = 0.0
|
||||
if done:
|
||||
fm = self.agent.preprocessor
|
||||
total_score = env_obs["observation"]["env_info"]["total_score"]
|
||||
env_info = env_obs["observation"]["env_info"]
|
||||
extra_info = env_obs.get("extra_info", {})
|
||||
total_score = env_info.get("total_score", fm.total_score)
|
||||
remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
|
||||
charge_count = env_info.get("charge_count", fm.charge_count)
|
||||
finished_steps = env_info.get("finished_steps", step)
|
||||
result_message = extra_info.get("result_message", "")
|
||||
result_code = extra_info.get("result_code", "")
|
||||
|
||||
if truncated:
|
||||
# Survived to max steps: higher cleaning ratio → more reward
|
||||
@@ -141,19 +148,31 @@ class EpisodeRunner:
|
||||
final_reward = 2.0 + 8.0 * cleaning_ratio
|
||||
result_str = "WIN"
|
||||
else:
|
||||
# Battery-depleted episodes are common with short runs; keep
|
||||
# cleaning progress as the dominant terminal signal.
|
||||
# 短训中电量耗尽较常见,终局奖励仍以清扫比例为主。
|
||||
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
|
||||
final_reward = -1.0 + 6.0 * cleaning_ratio
|
||||
result_str = "FAIL"
|
||||
if fm.battery <= 0 or remaining_charge <= 0:
|
||||
final_reward = -4.0 + 6.0 * cleaning_ratio
|
||||
result_str = "BATTERY_FAIL"
|
||||
elif fm.npc_danger or fm.nearest_npc_dist <= 1:
|
||||
final_reward = -3.0 + 6.0 * cleaning_ratio
|
||||
result_str = "NPC_FAIL"
|
||||
else:
|
||||
final_reward = -2.0 + 6.0 * cleaning_ratio
|
||||
result_str = "FAIL"
|
||||
|
||||
self.logger.info(
|
||||
f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
|
||||
f"finished_steps:{finished_steps} "
|
||||
f"result:{result_str} final_bonus:{final_reward:.2f} "
|
||||
f"total_reward:{total_reward:.3f} "
|
||||
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} "
|
||||
f"total_score:{total_score}"
|
||||
f"total_score:{total_score} "
|
||||
f"remaining_charge:{remaining_charge} "
|
||||
f"charge_count:{charge_count} "
|
||||
f"recharge_steps:{fm.recharge_steps} "
|
||||
f"nearest_charger:{fm.nearest_charger_range_dist:.1f} "
|
||||
f"nearest_npc:{fm.nearest_npc_dist:.1f} "
|
||||
f"result_code:{result_code} "
|
||||
f"result_message:{result_message}"
|
||||
)
|
||||
|
||||
# Build sample frame
|
||||
|
||||
Reference in New Issue
Block a user