优化PPO充电与避障策略
扩展观测特征到157维,加入充电桩、NPC、电量安全余量、地图统计和本步清扫信息。 增加低电量回充动作过滤、NPC危险区过滤,并调整奖励和终局日志以突出充电、避障和真实清扫得分。
This commit is contained in:
@@ -55,9 +55,9 @@ class Agent(BaseAgent):
|
|||||||
self.last_reward = 0.0
|
self.last_reward = 0.0
|
||||||
|
|
||||||
def observation_process(self, env_obs):
|
def observation_process(self, env_obs):
|
||||||
"""Convert raw env_obs to ObsData (69D feature + legal action mask).
|
"""Convert raw env_obs to ObsData (feature + legal action mask).
|
||||||
|
|
||||||
将原始 env_obs 转换为 ObsData(69D 特征 + 合法动作掩码)。
|
将原始 env_obs 转换为 ObsData(特征 + 合法动作掩码)。
|
||||||
"""
|
"""
|
||||||
feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
|
feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
|
||||||
self.last_reward = reward
|
self.last_reward = reward
|
||||||
@@ -135,7 +135,15 @@ class Agent(BaseAgent):
|
|||||||
加载模型检查点。
|
加载模型检查点。
|
||||||
"""
|
"""
|
||||||
model_file_path = f"{path}/model.ckpt-{id}.pkl"
|
model_file_path = f"{path}/model.ckpt-{id}.pkl"
|
||||||
self.model.load_state_dict(torch.load(model_file_path, map_location=self.device))
|
state_dict = torch.load(model_file_path, map_location=self.device)
|
||||||
|
try:
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
except RuntimeError as err:
|
||||||
|
msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(msg)
|
||||||
|
return
|
||||||
|
if self.logger:
|
||||||
self.logger.info(f"load model {model_file_path} successfully")
|
self.logger.info(f"load model {model_file_path} successfully")
|
||||||
|
|
||||||
def _run_model(self, feature):
|
def _run_model(self, feature):
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ Configuration for Robot Vacuum PPO agent.
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
||||||
# Feature dimensions (69D)
|
# Feature dimensions (157D)
|
||||||
# 特征维度(69D)
|
# 特征维度(157D)
|
||||||
FEATURES = [
|
FEATURES = [
|
||||||
7 * 7,
|
11 * 11, # wider local map view / 更大的局部地图视野
|
||||||
12,
|
28, # global, charger, NPC, and map-stat features / 全局、充电桩、NPC、地图统计特征
|
||||||
8,
|
8, # last action one-hot / 上一步动作 one-hot
|
||||||
]
|
]
|
||||||
FEATURE_SPLIT_SHAPE = FEATURES
|
FEATURE_SPLIT_SHAPE = FEATURES
|
||||||
FEATURE_LEN = sum(FEATURES)
|
FEATURE_LEN = sum(FEATURES)
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ ActData = create_cls(
|
|||||||
# 训练样本数据:字段值为 int 时框架自动按维度处理
|
# 训练样本数据:字段值为 int 时框架自动按维度处理
|
||||||
SampleData = create_cls(
|
SampleData = create_cls(
|
||||||
"SampleData",
|
"SampleData",
|
||||||
obs=Config.DIM_OF_OBSERVATION, # 69D feature vector / 特征向量
|
obs=Config.DIM_OF_OBSERVATION, # feature vector / 特征向量
|
||||||
legal_action=Config.ACTION_NUM, # 8D legal action mask / 合法动作掩码
|
legal_action=Config.ACTION_NUM, # 8D legal action mask / 合法动作掩码
|
||||||
act=1, # action index / 执行的动作
|
act=1, # action index / 执行的动作
|
||||||
reward=Config.VALUE_NUM, # 1D reward / 奖励
|
reward=Config.VALUE_NUM, # 1D reward / 奖励
|
||||||
|
|||||||
@@ -24,6 +24,13 @@ def _norm(v, v_max, v_min=0.0):
|
|||||||
return (v - v_min) / (v_max - v_min)
|
return (v - v_min) / (v_max - v_min)
|
||||||
|
|
||||||
|
|
||||||
|
def _signed_norm(v, v_max):
|
||||||
|
"""Normalize signed value to [-1, 1]."""
|
||||||
|
if v_max <= 0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.clip(float(v) / float(v_max), -1.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
class Preprocessor:
|
class Preprocessor:
|
||||||
"""Feature preprocessor for Robot Vacuum.
|
"""Feature preprocessor for Robot Vacuum.
|
||||||
|
|
||||||
@@ -32,7 +39,17 @@ class Preprocessor:
|
|||||||
|
|
||||||
GRID_SIZE = 128
|
GRID_SIZE = 128
|
||||||
VIEW_HALF = 10 # Full local view radius (21×21) / 完整局部视野半径
|
VIEW_HALF = 10 # Full local view radius (21×21) / 完整局部视野半径
|
||||||
LOCAL_HALF = 3 # Cropped view radius (7×7) / 裁剪后的视野半径
|
LOCAL_HALF = 5 # Cropped view radius (11×11) / 裁剪后的视野半径
|
||||||
|
ACTION_DIRS = (
|
||||||
|
(1, 0),
|
||||||
|
(1, -1),
|
||||||
|
(0, -1),
|
||||||
|
(-1, -1),
|
||||||
|
(-1, 0),
|
||||||
|
(-1, 1),
|
||||||
|
(0, 1),
|
||||||
|
(1, 1),
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -56,6 +73,10 @@ class Preprocessor:
|
|||||||
self.dirt_cleaned = 0
|
self.dirt_cleaned = 0
|
||||||
self.last_dirt_cleaned = 0
|
self.last_dirt_cleaned = 0
|
||||||
self.total_dirt = 1
|
self.total_dirt = 1
|
||||||
|
self.total_score = 0
|
||||||
|
self.clean_score = 0
|
||||||
|
self.step_cleaned_count = 0
|
||||||
|
self.max_step = 1000
|
||||||
|
|
||||||
# Global passable map (0=obstacle, 1=passable), used for ray computation
|
# Global passable map (0=obstacle, 1=passable), used for ray computation
|
||||||
# 维护全局通行地图(0=障碍, 1=可通行),用于射线计算
|
# 维护全局通行地图(0=障碍, 1=可通行),用于射线计算
|
||||||
@@ -69,6 +90,33 @@ class Preprocessor:
|
|||||||
|
|
||||||
self._view_map = np.zeros((21, 21), dtype=np.float32)
|
self._view_map = np.zeros((21, 21), dtype=np.float32)
|
||||||
self._legal_act = [1] * 8
|
self._legal_act = [1] * 8
|
||||||
|
self.terminated = False
|
||||||
|
self.truncated = False
|
||||||
|
|
||||||
|
self.remaining_charge = 0
|
||||||
|
self.charge_count = 0
|
||||||
|
self.last_charge_count = 0
|
||||||
|
self.charge_delta = 0
|
||||||
|
self.nearest_charger_dx = 0.0
|
||||||
|
self.nearest_charger_dz = 0.0
|
||||||
|
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||||
|
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||||
|
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||||
|
self.battery_margin = 0.0
|
||||||
|
self.has_charger = False
|
||||||
|
self.low_battery = False
|
||||||
|
self.on_charger = False
|
||||||
|
self.recharge_mode = False
|
||||||
|
self.recharge_steps = 0
|
||||||
|
|
||||||
|
self.nearest_npc_dx = 0.0
|
||||||
|
self.nearest_npc_dz = 0.0
|
||||||
|
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||||||
|
self.npc_danger = False
|
||||||
|
self.npcs = []
|
||||||
|
|
||||||
|
self.local_dirt_ratio = 0.0
|
||||||
|
self.local_obstacle_ratio = 0.0
|
||||||
|
|
||||||
def pb2struct(self, env_obs, last_action):
|
def pb2struct(self, env_obs, last_action):
|
||||||
"""Parse and cache essential fields from observation dict.
|
"""Parse and cache essential fields from observation dict.
|
||||||
@@ -76,14 +124,20 @@ class Preprocessor:
|
|||||||
从 env_obs 字典中提取并缓存所有需要的状态量。
|
从 env_obs 字典中提取并缓存所有需要的状态量。
|
||||||
"""
|
"""
|
||||||
observation = env_obs["observation"]
|
observation = env_obs["observation"]
|
||||||
frame_state = observation["frame_state"]
|
frame_state = observation.get("frame_state", {})
|
||||||
env_info = observation["env_info"]
|
extra_frame_state = env_obs.get("extra_info", {}).get("frame_state", {})
|
||||||
hero = frame_state["heroes"]
|
env_info = observation.get("env_info", {})
|
||||||
|
hero = frame_state.get("heroes", {})
|
||||||
|
if isinstance(hero, list):
|
||||||
|
hero = hero[0] if hero else {}
|
||||||
|
|
||||||
self.last_action = int(last_action)
|
self.last_action = int(last_action)
|
||||||
self.step_no = int(observation["step_no"])
|
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
||||||
|
self.terminated = bool(env_obs.get("terminated", False))
|
||||||
|
self.truncated = bool(env_obs.get("truncated", False))
|
||||||
self.prev_pos = self.cur_pos if self.has_position_history else None
|
self.prev_pos = self.cur_pos if self.has_position_history else None
|
||||||
self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))
|
hero_pos = hero.get("pos") or env_info.get("pos") or {"x": self.cur_pos[0], "z": self.cur_pos[1]}
|
||||||
|
self.cur_pos = (int(hero_pos.get("x", self.cur_pos[0])), int(hero_pos.get("z", self.cur_pos[1])))
|
||||||
self.has_position_history = True
|
self.has_position_history = True
|
||||||
|
|
||||||
hx, hz = self.cur_pos
|
hx, hz = self.cur_pos
|
||||||
@@ -96,16 +150,30 @@ class Preprocessor:
|
|||||||
self.is_new_cell = False
|
self.is_new_cell = False
|
||||||
|
|
||||||
# Battery / 电量
|
# Battery / 电量
|
||||||
self.battery = int(hero["battery"])
|
self.battery = int(hero.get("battery", env_info.get("remaining_charge", self.battery)))
|
||||||
self.battery_max = max(int(hero["battery_max"]), 1)
|
self.battery_max = max(int(hero.get("battery_max", env_info.get("battery_max", self.battery_max))), 1)
|
||||||
|
self.remaining_charge = int(env_info.get("remaining_charge", self.battery))
|
||||||
|
self.max_step = max(int(env_info.get("max_step", self.max_step)), 1)
|
||||||
|
|
||||||
# Cleaning progress / 清扫进度
|
# Cleaning progress / 清扫进度
|
||||||
self.last_dirt_cleaned = self.dirt_cleaned
|
self.last_dirt_cleaned = self.dirt_cleaned
|
||||||
self.dirt_cleaned = int(hero["dirt_cleaned"])
|
self.dirt_cleaned = int(hero.get("dirt_cleaned", env_info.get("clean_score", self.dirt_cleaned)))
|
||||||
self.total_dirt = max(int(env_info["total_dirt"]), 1)
|
self.total_dirt = max(int(env_info.get("total_dirt", self.total_dirt)), 1)
|
||||||
|
self.total_score = int(env_info.get("total_score", self.total_score))
|
||||||
|
self.clean_score = int(env_info.get("clean_score", self.dirt_cleaned))
|
||||||
|
step_cleaned_cells = env_info.get("step_cleaned_cells") or []
|
||||||
|
self.step_cleaned_count = len(step_cleaned_cells)
|
||||||
|
|
||||||
|
# Charge progress / 充电进度
|
||||||
|
self.last_charge_count = self.charge_count
|
||||||
|
self.charge_count = int(env_info.get("charge_count", self.charge_count))
|
||||||
|
self.charge_delta = max(0, self.charge_count - self.last_charge_count)
|
||||||
|
|
||||||
# Legal actions / 合法动作
|
# Legal actions / 合法动作
|
||||||
self._legal_act = [int(x) for x in (observation.get("legal_action") or [1] * 8)]
|
raw_legal_act = observation.get("legal_action") or observation.get("legal_act") or [1] * 8
|
||||||
|
self._legal_act = [int(x) for x in raw_legal_act[:8]]
|
||||||
|
if len(self._legal_act) < 8:
|
||||||
|
self._legal_act.extend([1] * (8 - len(self._legal_act)))
|
||||||
|
|
||||||
# Local view map (21×21) / 局部视野地图
|
# Local view map (21×21) / 局部视野地图
|
||||||
map_info = observation.get("map_info")
|
map_info = observation.get("map_info")
|
||||||
@@ -113,6 +181,14 @@ class Preprocessor:
|
|||||||
self._view_map = np.array(map_info, dtype=np.float32)
|
self._view_map = np.array(map_info, dtype=np.float32)
|
||||||
hx, hz = self.cur_pos
|
hx, hz = self.cur_pos
|
||||||
self._update_passable(hx, hz)
|
self._update_passable(hx, hz)
|
||||||
|
self._update_local_map_stats()
|
||||||
|
|
||||||
|
organs = frame_state.get("organs") or extra_frame_state.get("organs") or []
|
||||||
|
npcs = frame_state.get("npcs") or extra_frame_state.get("npcs") or []
|
||||||
|
self.npcs = list(npcs)
|
||||||
|
self._update_charger_state(hx, hz, organs)
|
||||||
|
self._update_npc_state(hx, hz, self.npcs)
|
||||||
|
self._update_recharge_mode()
|
||||||
|
|
||||||
def _update_passable(self, hx, hz):
|
def _update_passable(self, hx, hz):
|
||||||
"""Write local view into global passable map.
|
"""Write local view into global passable map.
|
||||||
@@ -132,10 +208,130 @@ class Preprocessor:
|
|||||||
# 0 = 障碍, 1/2 = 可通行
|
# 0 = 障碍, 1/2 = 可通行
|
||||||
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
||||||
|
|
||||||
def _get_local_view_feature(self):
|
def _update_local_map_stats(self):
|
||||||
"""Local view feature (49D): crop center 7×7 from 21×21.
|
"""Cache coarse 21x21 map statistics."""
|
||||||
|
view = self._view_map
|
||||||
|
if view is None or view.size == 0:
|
||||||
|
self.local_dirt_ratio = 0.0
|
||||||
|
self.local_obstacle_ratio = 0.0
|
||||||
|
return
|
||||||
|
total = float(view.size)
|
||||||
|
self.local_dirt_ratio = float(np.sum(view == 2) / total)
|
||||||
|
self.local_obstacle_ratio = float(np.sum(view == 0) / total)
|
||||||
|
|
||||||
局部视野特征(49D):从 21×21 视野中心裁剪 7×7。
|
def _update_charger_state(self, hx, hz, organs):
|
||||||
|
"""Find nearest charger and cache distance/direction features."""
|
||||||
|
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
||||||
|
self.has_charger = False
|
||||||
|
self.on_charger = False
|
||||||
|
self.nearest_charger_dx = 0.0
|
||||||
|
self.nearest_charger_dz = 0.0
|
||||||
|
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||||
|
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||||
|
|
||||||
|
best = None
|
||||||
|
for organ in organs:
|
||||||
|
if not isinstance(organ, dict):
|
||||||
|
continue
|
||||||
|
if int(organ.get("sub_type", 1)) != 1:
|
||||||
|
continue
|
||||||
|
pos = organ.get("pos") or {}
|
||||||
|
ox = int(pos.get("x", 0))
|
||||||
|
oz = int(pos.get("z", 0))
|
||||||
|
w = max(int(organ.get("w", 3)), 1)
|
||||||
|
h = max(int(organ.get("h", 3)), 1)
|
||||||
|
|
||||||
|
for rx, rz in ((ox, oz), (ox - w // 2, oz - h // 2)):
|
||||||
|
dx, dz = self._relative_vector_to_rect(hx, hz, rx, rz, w, h)
|
||||||
|
dist = float(np.sqrt(dx * dx + dz * dz))
|
||||||
|
range_dist = float(max(abs(dx), abs(dz)))
|
||||||
|
if best is None or range_dist < best[0] or (range_dist == best[0] and dist < best[1]):
|
||||||
|
best = (range_dist, dist, dx, dz)
|
||||||
|
|
||||||
|
if best is None:
|
||||||
|
self.battery_margin = float(self.battery)
|
||||||
|
return
|
||||||
|
|
||||||
|
range_dist, dist, dx, dz = best
|
||||||
|
self.has_charger = True
|
||||||
|
self.nearest_charger_dx = float(dx)
|
||||||
|
self.nearest_charger_dz = float(dz)
|
||||||
|
self.nearest_charger_dist = float(dist)
|
||||||
|
self.nearest_charger_range_dist = float(range_dist)
|
||||||
|
self.on_charger = range_dist <= 0.0
|
||||||
|
self.battery_margin = float(self.battery) - self.nearest_charger_range_dist
|
||||||
|
|
||||||
|
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
|
||||||
|
"""Relative vector from point to the nearest cell in a rectangle."""
|
||||||
|
if x < rx:
|
||||||
|
dx = rx - x
|
||||||
|
elif x > rx + w - 1:
|
||||||
|
dx = x - (rx + w - 1)
|
||||||
|
else:
|
||||||
|
dx = 0
|
||||||
|
|
||||||
|
if z < rz:
|
||||||
|
dz = rz - z
|
||||||
|
elif z > rz + h - 1:
|
||||||
|
dz = z - (rz + h - 1)
|
||||||
|
else:
|
||||||
|
dz = 0
|
||||||
|
return float(dx), float(dz)
|
||||||
|
|
||||||
|
def _update_npc_state(self, hx, hz, npcs):
|
||||||
|
"""Find nearest NPC and cache safety features."""
|
||||||
|
self.nearest_npc_dx = 0.0
|
||||||
|
self.nearest_npc_dz = 0.0
|
||||||
|
self.nearest_npc_dist = float(self.GRID_SIZE)
|
||||||
|
self.npc_danger = False
|
||||||
|
|
||||||
|
best = None
|
||||||
|
for npc in npcs:
|
||||||
|
if not isinstance(npc, dict):
|
||||||
|
continue
|
||||||
|
pos = npc.get("pos") or {}
|
||||||
|
nx = int(pos.get("x", 0))
|
||||||
|
nz = int(pos.get("z", 0))
|
||||||
|
dx = nx - hx
|
||||||
|
dz = nz - hz
|
||||||
|
cheb = float(max(abs(dx), abs(dz)))
|
||||||
|
if best is None or cheb < best[0]:
|
||||||
|
best = (cheb, dx, dz)
|
||||||
|
|
||||||
|
if best is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
cheb, dx, dz = best
|
||||||
|
self.nearest_npc_dx = float(dx)
|
||||||
|
self.nearest_npc_dz = float(dz)
|
||||||
|
self.nearest_npc_dist = float(cheb)
|
||||||
|
self.npc_danger = abs(dx) <= 1 and abs(dz) <= 1
|
||||||
|
|
||||||
|
def _update_recharge_mode(self):
|
||||||
|
"""Enter/exit low-battery recharge mode."""
|
||||||
|
battery_ratio = self.battery / max(self.battery_max, 1)
|
||||||
|
self.low_battery = battery_ratio < 0.35
|
||||||
|
|
||||||
|
if not self.has_charger:
|
||||||
|
self.recharge_mode = False
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.charge_delta > 0 or (self.on_charger and battery_ratio > 0.85):
|
||||||
|
self.recharge_mode = False
|
||||||
|
elif self.battery <= self.nearest_charger_range_dist + 18 or battery_ratio < 0.22:
|
||||||
|
self.recharge_mode = True
|
||||||
|
elif self.recharge_mode and battery_ratio < 0.85:
|
||||||
|
self.recharge_mode = True
|
||||||
|
else:
|
||||||
|
self.recharge_mode = False
|
||||||
|
|
||||||
|
if self.recharge_mode:
|
||||||
|
self.recharge_steps += 1
|
||||||
|
|
||||||
|
def _get_local_view_feature(self):
|
||||||
|
"""Local view feature (121D): crop center 11×11 from 21×21.
|
||||||
|
|
||||||
|
局部视野特征(121D):从 21×21 视野中心裁剪 11×11。
|
||||||
"""
|
"""
|
||||||
center = self.VIEW_HALF
|
center = self.VIEW_HALF
|
||||||
h = self.LOCAL_HALF
|
h = self.LOCAL_HALF
|
||||||
@@ -143,9 +339,9 @@ class Preprocessor:
|
|||||||
return (crop / 2.0).flatten()
|
return (crop / 2.0).flatten()
|
||||||
|
|
||||||
def _get_global_state_feature(self):
|
def _get_global_state_feature(self):
|
||||||
"""Global state feature (12D).
|
"""Global state feature (28D).
|
||||||
|
|
||||||
全局状态特征(12D)。
|
全局状态特征(28D)。
|
||||||
|
|
||||||
Dimensions / 维度说明:
|
Dimensions / 维度说明:
|
||||||
[0] step_norm step progress / 步数归一化 [0,1]
|
[0] step_norm step progress / 步数归一化 [0,1]
|
||||||
@@ -160,8 +356,24 @@ class Preprocessor:
|
|||||||
[9] ray_W_dirt west ray distance / 向左(x-)方向
|
[9] ray_W_dirt west ray distance / 向左(x-)方向
|
||||||
[10] nearest_dirt_norm nearest dirt Euclidean distance / 最近污渍欧氏距离归一化
|
[10] nearest_dirt_norm nearest dirt Euclidean distance / 最近污渍欧氏距离归一化
|
||||||
[11] dirt_delta approaching dirt indicator / 是否在接近污渍(1=是, 0=否)
|
[11] dirt_delta approaching dirt indicator / 是否在接近污渍(1=是, 0=否)
|
||||||
|
[12] charger_dx nearest charger x direction / 最近充电桩 x 相对方向
|
||||||
|
[13] charger_dz nearest charger z direction / 最近充电桩 z 相对方向
|
||||||
|
[14] charger_dist nearest charger distance / 最近充电桩距离
|
||||||
|
[15] battery_margin battery minus charger distance / 电量安全余量
|
||||||
|
[16] low_battery low-battery flag / 低电量标记
|
||||||
|
[17] recharge_mode recharge-mode flag / 回充模式标记
|
||||||
|
[18] on_charger on charger flag / 是否在充电桩范围
|
||||||
|
[19] charge_delta charge count increased / 本步是否成功充电
|
||||||
|
[20] npc_dx nearest NPC x direction / 最近 NPC x 相对方向
|
||||||
|
[21] npc_dz nearest NPC z direction / 最近 NPC z 相对方向
|
||||||
|
[22] npc_dist nearest NPC Chebyshev distance / 最近 NPC 切比雪夫距离
|
||||||
|
[23] npc_danger in NPC 3x3 danger zone / 是否处于 NPC 3x3 危险区
|
||||||
|
[24] local_dirt_ratio dirt ratio in 21x21 view / 21x21 视野污渍比例
|
||||||
|
[25] obstacle_ratio obstacle ratio in 21x21 view / 21x21 视野障碍比例
|
||||||
|
[26] visit_count current cell visit count / 当前格访问次数
|
||||||
|
[27] step_cleaned cells cleaned this step / 本步清扫格子数
|
||||||
"""
|
"""
|
||||||
step_norm = _norm(self.step_no, 2000)
|
step_norm = _norm(self.step_no, self.max_step)
|
||||||
battery_ratio = _norm(self.battery, self.battery_max)
|
battery_ratio = _norm(self.battery, self.battery_max)
|
||||||
cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
|
cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
|
||||||
remaining_dirt = 1.0 - cleaning_progress
|
remaining_dirt = 1.0 - cleaning_progress
|
||||||
@@ -205,6 +417,10 @@ class Preprocessor:
|
|||||||
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
|
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
|
||||||
|
|
||||||
dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0
|
dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0
|
||||||
|
charge_delta = 1.0 if self.charge_delta > 0 else 0.0
|
||||||
|
battery_margin_norm = _signed_norm(self.battery_margin, self.battery_max)
|
||||||
|
visit_count_norm = _norm(min(self.current_visit_count, 10), 10)
|
||||||
|
step_cleaned_norm = _norm(self.step_cleaned_count, 9)
|
||||||
|
|
||||||
return np.array(
|
return np.array(
|
||||||
[
|
[
|
||||||
@@ -220,6 +436,22 @@ class Preprocessor:
|
|||||||
ray_dirt[3],
|
ray_dirt[3],
|
||||||
nearest_dirt_norm,
|
nearest_dirt_norm,
|
||||||
dirt_delta,
|
dirt_delta,
|
||||||
|
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
|
||||||
|
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
|
||||||
|
_norm(self.nearest_charger_range_dist, self.GRID_SIZE),
|
||||||
|
battery_margin_norm,
|
||||||
|
1.0 if self.low_battery else 0.0,
|
||||||
|
1.0 if self.recharge_mode else 0.0,
|
||||||
|
1.0 if self.on_charger else 0.0,
|
||||||
|
charge_delta,
|
||||||
|
_signed_norm(self.nearest_npc_dx, 20),
|
||||||
|
_signed_norm(self.nearest_npc_dz, 20),
|
||||||
|
_norm(self.nearest_npc_dist, 20),
|
||||||
|
1.0 if self.npc_danger else 0.0,
|
||||||
|
self.local_dirt_ratio,
|
||||||
|
self.local_obstacle_ratio,
|
||||||
|
visit_count_norm,
|
||||||
|
step_cleaned_norm,
|
||||||
],
|
],
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
@@ -244,27 +476,119 @@ class Preprocessor:
|
|||||||
|
|
||||||
返回合法动作掩码(8D list)。
|
返回合法动作掩码(8D list)。
|
||||||
"""
|
"""
|
||||||
return list(self._legal_act)
|
legal = self._filter_blocked_actions(self._legal_act)
|
||||||
|
legal = self._filter_npc_danger_actions(legal)
|
||||||
|
if self.recharge_mode:
|
||||||
|
legal = self._filter_recharge_actions(legal)
|
||||||
|
return list(legal)
|
||||||
|
|
||||||
|
def _filter_blocked_actions(self, legal_action):
|
||||||
|
"""Filter actions that are visibly blocked in the 21x21 view."""
|
||||||
|
legal = [int(x) for x in legal_action]
|
||||||
|
hx, hz = self.cur_pos
|
||||||
|
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||||
|
if legal[action] <= 0:
|
||||||
|
continue
|
||||||
|
if not self._is_visible_cell_passable(dx, dz):
|
||||||
|
legal[action] = 0
|
||||||
|
continue
|
||||||
|
if dx != 0 and dz != 0:
|
||||||
|
side_a = self._is_visible_cell_passable(dx, 0)
|
||||||
|
side_b = self._is_visible_cell_passable(0, dz)
|
||||||
|
if not (side_a or side_b):
|
||||||
|
legal[action] = 0
|
||||||
|
nx, nz = hx + dx, hz + dz
|
||||||
|
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
|
||||||
|
legal[action] = 0
|
||||||
|
|
||||||
|
return legal if any(legal) else [int(x) for x in legal_action]
|
||||||
|
|
||||||
|
def _is_visible_cell_passable(self, dx, dz):
|
||||||
|
"""Whether a relative 21x21-view cell is passable."""
|
||||||
|
ri = self.VIEW_HALF + dx
|
||||||
|
ci = self.VIEW_HALF + dz
|
||||||
|
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
||||||
|
return True
|
||||||
|
return int(self._view_map[ri, ci]) != 0
|
||||||
|
|
||||||
|
def _filter_npc_danger_actions(self, legal_action):
|
||||||
|
"""Avoid actions that would enter any NPC 3x3 danger zone."""
|
||||||
|
if not self.npcs:
|
||||||
|
return list(legal_action)
|
||||||
|
|
||||||
|
hx, hz = self.cur_pos
|
||||||
|
safe = [int(x) for x in legal_action]
|
||||||
|
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||||
|
if safe[action] <= 0:
|
||||||
|
continue
|
||||||
|
nx, nz = hx + dx, hz + dz
|
||||||
|
if self._is_npc_danger_cell(nx, nz):
|
||||||
|
safe[action] = 0
|
||||||
|
|
||||||
|
return safe if any(safe) else list(legal_action)
|
||||||
|
|
||||||
|
def _is_npc_danger_cell(self, x, z):
|
||||||
|
for npc in self.npcs:
|
||||||
|
if not isinstance(npc, dict):
|
||||||
|
continue
|
||||||
|
pos = npc.get("pos") or {}
|
||||||
|
nx = int(pos.get("x", -999))
|
||||||
|
nz = int(pos.get("z", -999))
|
||||||
|
if abs(x - nx) <= 1 and abs(z - nz) <= 1:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _filter_recharge_actions(self, legal_action):
|
||||||
|
"""Restrict low-battery actions to moves that approach the charger."""
|
||||||
|
if not self.has_charger:
|
||||||
|
return list(legal_action)
|
||||||
|
|
||||||
|
hx, hz = self.cur_pos
|
||||||
|
current_dist = max(abs(self.nearest_charger_dx), abs(self.nearest_charger_dz))
|
||||||
|
scored = []
|
||||||
|
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||||
|
if legal_action[action] <= 0:
|
||||||
|
continue
|
||||||
|
next_dx = self.nearest_charger_dx - dx
|
||||||
|
next_dz = self.nearest_charger_dz - dz
|
||||||
|
next_dist = max(abs(next_dx), abs(next_dz))
|
||||||
|
improvement = current_dist - next_dist
|
||||||
|
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
|
||||||
|
scored.append((improvement, alignment, action))
|
||||||
|
|
||||||
|
if not scored:
|
||||||
|
return list(legal_action)
|
||||||
|
|
||||||
|
best_improvement = max(item[0] for item in scored)
|
||||||
|
recharge = [0] * 8
|
||||||
|
if best_improvement > 0:
|
||||||
|
for improvement, _, action in scored:
|
||||||
|
if improvement >= best_improvement - 0.1:
|
||||||
|
recharge[action] = 1
|
||||||
|
else:
|
||||||
|
best_alignment = max(item[1] for item in scored)
|
||||||
|
for _, alignment, action in scored:
|
||||||
|
if alignment >= best_alignment:
|
||||||
|
recharge[action] = 1
|
||||||
|
|
||||||
|
return recharge if any(recharge) else list(legal_action)
|
||||||
|
|
||||||
def feature_process(self, env_obs, last_action):
|
def feature_process(self, env_obs, last_action):
|
||||||
"""Generate 69D feature vector, legal action mask, and scalar reward.
|
"""Generate feature vector, legal action mask, and scalar reward.
|
||||||
|
|
||||||
生成 69D 特征向量、合法动作掩码和标量奖励。
|
生成特征向量、合法动作掩码和标量奖励。
|
||||||
"""
|
"""
|
||||||
self.pb2struct(env_obs, last_action)
|
self.pb2struct(env_obs, last_action)
|
||||||
|
|
||||||
local_view = self._get_local_view_feature() # 49D
|
local_view = self._get_local_view_feature() # 121D
|
||||||
global_state = self._get_global_state_feature() # 12D
|
global_state = self._get_global_state_feature() # 28D
|
||||||
legal_action = self.get_legal_action() # 8D
|
legal_action = self.get_legal_action() # 8D
|
||||||
|
|
||||||
last_action_feature = np.zeros(8, dtype=np.float32)
|
last_action_feature = np.zeros(8, dtype=np.float32)
|
||||||
if 0 <= last_action < 8:
|
if 0 <= last_action < 8:
|
||||||
last_action_feature[last_action] = 1.0
|
last_action_feature[last_action] = 1.0
|
||||||
|
|
||||||
# The legal action mask is passed separately to PPO. Reusing this 8D slot
|
feature = np.concatenate([local_view, global_state, last_action_feature])
|
||||||
# for action history makes the 69D observation more informative without
|
|
||||||
# breaking the framework's fixed tensor shape.
|
|
||||||
feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D
|
|
||||||
|
|
||||||
reward = self.reward_process()
|
reward = self.reward_process()
|
||||||
|
|
||||||
@@ -273,7 +597,8 @@ class Preprocessor:
|
|||||||
def reward_process(self):
|
def reward_process(self):
|
||||||
# Cleaning reward / 清扫奖励
|
# Cleaning reward / 清扫奖励
|
||||||
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
||||||
cleaning_reward = 0.25 * cleaned_this_step
|
cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step
|
||||||
|
cleaning_reward = 0.7 * cleaned_cells
|
||||||
|
|
||||||
# Step penalty / 时间惩罚
|
# Step penalty / 时间惩罚
|
||||||
step_penalty = -0.002
|
step_penalty = -0.002
|
||||||
@@ -281,13 +606,28 @@ class Preprocessor:
|
|||||||
# Dense guidance: prefer moving toward visible dirt.
|
# Dense guidance: prefer moving toward visible dirt.
|
||||||
# 稠密引导:鼓励向视野内污渍靠近。
|
# 稠密引导:鼓励向视野内污渍靠近。
|
||||||
approach_reward = 0.0
|
approach_reward = 0.0
|
||||||
if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0:
|
if not self.recharge_mode and (self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0):
|
||||||
dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
|
dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
|
||||||
approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta
|
approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta
|
||||||
|
|
||||||
|
# Recharge guidance only activates when battery safety is the bottleneck.
|
||||||
|
# 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。
|
||||||
|
charge_reward = 0.0
|
||||||
|
if self.charge_delta > 0:
|
||||||
|
charge_reward += 3.0 * self.charge_delta
|
||||||
|
if self.has_charger and (self.recharge_mode or self.low_battery):
|
||||||
|
dist_delta = float(
|
||||||
|
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0)
|
||||||
|
)
|
||||||
|
charge_reward += 0.04 * dist_delta if dist_delta > 0 else 0.02 * dist_delta
|
||||||
|
if self.battery_margin < 0:
|
||||||
|
charge_reward -= min(0.25, abs(self.battery_margin) / max(self.battery_max, 1))
|
||||||
|
elif self.on_charger and self.charge_delta == 0 and self.battery / max(self.battery_max, 1) > 0.85:
|
||||||
|
charge_reward -= 0.01
|
||||||
|
|
||||||
# Encourage covering new passable cells and mildly discourage loops.
|
# Encourage covering new passable cells and mildly discourage loops.
|
||||||
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
||||||
exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5)
|
exploration_reward = 0.004 if self.is_new_cell else -0.0015 * min(self.current_visit_count, 6)
|
||||||
|
|
||||||
# Collision/stuck signal: invalid moves waste both step and battery.
|
# Collision/stuck signal: invalid moves waste both step and battery.
|
||||||
# 撞墙/原地不动会浪费步数和电量。
|
# 撞墙/原地不动会浪费步数和电量。
|
||||||
@@ -295,4 +635,26 @@ class Preprocessor:
|
|||||||
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
|
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
|
||||||
stuck_penalty = -0.03
|
stuck_penalty = -0.03
|
||||||
|
|
||||||
return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty
|
npc_penalty = 0.0
|
||||||
|
if self.npc_danger:
|
||||||
|
npc_penalty -= 4.0
|
||||||
|
elif self.nearest_npc_dist <= 3:
|
||||||
|
npc_penalty -= 0.05 * (4 - self.nearest_npc_dist)
|
||||||
|
|
||||||
|
terminal_penalty = 0.0
|
||||||
|
if self.terminated and not self.truncated:
|
||||||
|
if self.battery <= 0 or self.remaining_charge <= 0:
|
||||||
|
terminal_penalty -= 4.0
|
||||||
|
elif self.npc_danger or self.nearest_npc_dist <= 1:
|
||||||
|
terminal_penalty -= 3.0
|
||||||
|
|
||||||
|
return (
|
||||||
|
cleaning_reward
|
||||||
|
+ approach_reward
|
||||||
|
+ charge_reward
|
||||||
|
+ exploration_reward
|
||||||
|
+ stuck_penalty
|
||||||
|
+ npc_penalty
|
||||||
|
+ terminal_penalty
|
||||||
|
+ step_penalty
|
||||||
|
)
|
||||||
|
|||||||
@@ -38,22 +38,22 @@ class Model(nn.Module):
|
|||||||
self.model_name = "robot_vacuum"
|
self.model_name = "robot_vacuum"
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
obs_dim = Config.DIM_OF_OBSERVATION # 69
|
obs_dim = Config.DIM_OF_OBSERVATION # 157
|
||||||
act_num = Config.ACTION_NUM # 8
|
act_num = Config.ACTION_NUM # 8
|
||||||
|
|
||||||
# Shared backbone / 共享骨干网络
|
# Shared backbone / 共享骨干网络
|
||||||
self.backbone = nn.Sequential(
|
self.backbone = nn.Sequential(
|
||||||
_make_fc(obs_dim, 128),
|
_make_fc(obs_dim, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
_make_fc(128, 64),
|
_make_fc(256, 128),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Actor head: outputs action logits / 策略头:输出动作 logits
|
# Actor head: outputs action logits / 策略头:输出动作 logits
|
||||||
self.actor_head = _make_fc(64, act_num, gain=0.01)
|
self.actor_head = _make_fc(128, act_num, gain=0.01)
|
||||||
|
|
||||||
# Critic head: outputs single state value / 价值头:输出单个状态价值
|
# Critic head: outputs single state value / 价值头:输出单个状态价值
|
||||||
self.critic_head = _make_fc(64, 1, gain=0.01)
|
self.critic_head = _make_fc(128, 1, gain=0.01)
|
||||||
|
|
||||||
def forward(self, s, inference=False):
|
def forward(self, s, inference=False):
|
||||||
"""Forward pass.
|
"""Forward pass.
|
||||||
|
|||||||
@@ -132,7 +132,14 @@ class EpisodeRunner:
|
|||||||
final_reward = 0.0
|
final_reward = 0.0
|
||||||
if done:
|
if done:
|
||||||
fm = self.agent.preprocessor
|
fm = self.agent.preprocessor
|
||||||
total_score = env_obs["observation"]["env_info"]["total_score"]
|
env_info = env_obs["observation"]["env_info"]
|
||||||
|
extra_info = env_obs.get("extra_info", {})
|
||||||
|
total_score = env_info.get("total_score", fm.total_score)
|
||||||
|
remaining_charge = env_info.get("remaining_charge", fm.remaining_charge)
|
||||||
|
charge_count = env_info.get("charge_count", fm.charge_count)
|
||||||
|
finished_steps = env_info.get("finished_steps", step)
|
||||||
|
result_message = extra_info.get("result_message", "")
|
||||||
|
result_code = extra_info.get("result_code", "")
|
||||||
|
|
||||||
if truncated:
|
if truncated:
|
||||||
# Survived to max steps: higher cleaning ratio → more reward
|
# Survived to max steps: higher cleaning ratio → more reward
|
||||||
@@ -141,19 +148,31 @@ class EpisodeRunner:
|
|||||||
final_reward = 2.0 + 8.0 * cleaning_ratio
|
final_reward = 2.0 + 8.0 * cleaning_ratio
|
||||||
result_str = "WIN"
|
result_str = "WIN"
|
||||||
else:
|
else:
|
||||||
# Battery-depleted episodes are common with short runs; keep
|
|
||||||
# cleaning progress as the dominant terminal signal.
|
|
||||||
# 短训中电量耗尽较常见,终局奖励仍以清扫比例为主。
|
|
||||||
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
|
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
|
||||||
final_reward = -1.0 + 6.0 * cleaning_ratio
|
if fm.battery <= 0 or remaining_charge <= 0:
|
||||||
|
final_reward = -4.0 + 6.0 * cleaning_ratio
|
||||||
|
result_str = "BATTERY_FAIL"
|
||||||
|
elif fm.npc_danger or fm.nearest_npc_dist <= 1:
|
||||||
|
final_reward = -3.0 + 6.0 * cleaning_ratio
|
||||||
|
result_str = "NPC_FAIL"
|
||||||
|
else:
|
||||||
|
final_reward = -2.0 + 6.0 * cleaning_ratio
|
||||||
result_str = "FAIL"
|
result_str = "FAIL"
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
|
f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
|
||||||
|
f"finished_steps:{finished_steps} "
|
||||||
f"result:{result_str} final_bonus:{final_reward:.2f} "
|
f"result:{result_str} final_bonus:{final_reward:.2f} "
|
||||||
f"total_reward:{total_reward:.3f} "
|
f"total_reward:{total_reward:.3f} "
|
||||||
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} "
|
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt} "
|
||||||
f"total_score:{total_score}"
|
f"total_score:{total_score} "
|
||||||
|
f"remaining_charge:{remaining_charge} "
|
||||||
|
f"charge_count:{charge_count} "
|
||||||
|
f"recharge_steps:{fm.recharge_steps} "
|
||||||
|
f"nearest_charger:{fm.nearest_charger_range_dist:.1f} "
|
||||||
|
f"nearest_npc:{fm.nearest_npc_dist:.1f} "
|
||||||
|
f"result_code:{result_code} "
|
||||||
|
f"result_message:{result_message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build sample frame
|
# Build sample frame
|
||||||
|
|||||||
Reference in New Issue
Block a user