优化 PPO 清扫策略

This commit is contained in:
2026-04-26 17:29:03 +08:00
parent f04feb0cd9
commit f44e2483fc
6 changed files with 223 additions and 86 deletions

View File

@@ -192,18 +192,31 @@ class Agent(BaseAgent):
合法动作掩码下的 softmax。
"""
_w, _e = 1e20, 1e-5
tmp = logits - _w * (1.0 - legal_action)
tmp_max = np.max(tmp, keepdims=True)
tmp = np.clip(tmp - tmp_max, -_w, 1)
tmp = (np.exp(tmp) + _e) * legal_action
return tmp / (np.sum(tmp, keepdims=True) * 1.00001)
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
if not np.any(legal):
legal = np.ones(Config.ACTION_NUM, dtype=bool)
masked_logits = np.asarray(logits, dtype=np.float32).copy()
masked_logits[~legal] = -1e9
shifted = masked_logits - np.max(masked_logits)
probs = np.exp(shifted) * legal.astype(np.float32)
prob_sum = float(np.sum(probs))
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
probs = legal.astype(np.float32)
prob_sum = float(np.sum(probs))
return probs / prob_sum
def _legal_sample(self, probs, use_max=False):
"""Sample action from probability distribution (argmax if use_max=True).
按概率分布采样动作use_max=True 时取 argmax
"""
probs = np.asarray(probs, dtype=np.float64)
prob_sum = float(np.sum(probs))
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
else:
probs = probs / prob_sum
if use_max:
return int(np.argmax(probs))
return int(np.argmax(np.random.multinomial(1, probs, size=1)))
return int(np.random.choice(len(probs), p=probs))

View File

@@ -43,14 +43,14 @@ class Algorithm:
训练入口:接收一批 SampleData执行一步梯度更新。
"""
obs = torch.stack([s.obs for s in list_sample_data]).to(self.device)
legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device)
act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1)
old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device)
old_value = torch.stack([s.value for s in list_sample_data]).to(self.device)
reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device)
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
obs = self._batch_tensor([s.obs for s in list_sample_data])
legal_action = self._batch_tensor([s.legal_action for s in list_sample_data])
act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1)
old_prob = self._batch_tensor([s.prob for s in list_sample_data])
old_value = self._batch_tensor([s.value for s in list_sample_data])
reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data])
advantage = self._batch_tensor([s.advantage for s in list_sample_data])
reward = self._batch_tensor([s.reward for s in list_sample_data])
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
@@ -194,9 +194,22 @@ class Algorithm:
对 logits 应用合法动作掩码后计算 softmax。
"""
legal_mask = legal_action > 0.5
all_illegal = ~legal_mask.any(dim=1, keepdim=True)
legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask)
safe_logits = logits.masked_fill(~legal_mask, -1e9)
return F.softmax(safe_logits, dim=1)
def _batch_tensor(self, values):
"""Stack framework tensors or raw numpy/list values into a float tensor."""
tensors = []
for value in values:
if isinstance(value, torch.Tensor):
tensor = value.to(self.device, dtype=torch.float32)
else:
tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
tensors.append(tensor)
return torch.stack(tensors)
def _entropy_beta(self):
"""Linearly decay entropy regularization for fast early exploration."""
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)

View File

@@ -6,7 +6,7 @@ map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Whether to randomly select maps. Boolean.
# true = randomly pick one from configured maps per episode, false = used sequentially.
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张false表示按顺序抽取地图训练。
map_random = false
map_random = true
# Number of official robots. Range: 1~4 (integer).
# In each round, official robots will be randomly generated on the road according to the configured.

View File

@@ -67,7 +67,19 @@ def _calc_gae(list_sample_data):
gamma = Config.GAMMA
lamda = Config.LAMDA
for sample in reversed(list_sample_data):
delta = -sample.value + sample.reward + gamma * sample.next_value
gae = gae * gamma * lamda + delta
value = _scalar(sample.value)
reward = _scalar(sample.reward)
next_value = _scalar(sample.next_value)
nonterminal = 1.0 - _scalar(sample.done)
delta = reward + gamma * next_value * nonterminal - value
gae = delta + gamma * lamda * nonterminal * gae
sample.advantage = gae
sample.reward_sum = gae + sample.value
sample.reward_sum = gae + value
def _scalar(value):
"""Read the first scalar from numpy/tensor/list values."""
if hasattr(value, "detach"):
value = value.detach().cpu().numpy()
arr = np.asarray(value, dtype=np.float32).reshape(-1)
return float(arr[0]) if arr.size else 0.0

View File

@@ -10,6 +10,8 @@ Feature preprocessor for Robot Vacuum.
清扫大作战特征预处理器。
"""
from collections import deque
import numpy as np
@@ -55,6 +57,7 @@ class Preprocessor:
(0, 1),
(1, 1),
)
INF_DIST = 1e6
def __init__(self):
self.reset()
@@ -83,12 +86,12 @@ class Preprocessor:
self.step_cleaned_count = 0
self.max_step = 1000
# Global passable map (0=obstacle, 1=passable), used for ray computation
# 维护全局通行地图0=障碍, 1=可通行),用于射线计算
# Global passable map (0=obstacle, 1=passable), indexed by [x, z].
# 维护全局通行地图0=障碍, 1=可通行),索引为 [x, z]。
self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
# Nearest dirt distance
# 最近污渍距离
# Nearest dirt path distance in the current local view.
# 当前局部视野内最近污渍路径距离
self.nearest_dirt_dist = 200.0
self.last_nearest_dirt_dist = 200.0
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
@@ -114,6 +117,8 @@ class Preprocessor:
self.nearest_charger_dist = float(self.GRID_SIZE)
self.nearest_charger_range_dist = float(self.GRID_SIZE)
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
self.nearest_charger_path_dist = float(self.GRID_SIZE)
self.last_nearest_charger_path_dist = float(self.GRID_SIZE)
self.charger_energy_cost = float(self.GRID_SIZE)
self.charger_safety_buffer = 0.0
self.charger_safety_margin = 0.0
@@ -226,13 +231,33 @@ class Preprocessor:
for ri in range(vsize):
for ci in range(vsize):
gx = hx - half + ri
gz = hz - half + ci
gx = hx + ci - half
gz = hz + ri - half
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
# 0 = obstacle, 1/2 = passable
# 0 = 障碍, 1/2 = 可通行
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
def _view_index_to_global(self, ri, ci):
"""Convert local view row/col to global x/z coordinates."""
half = self._view_map.shape[0] // 2
hx, hz = self.cur_pos
return hx + ci - half, hz + ri - half
def _view_delta_to_index(self, dx, dz):
"""Convert global-coordinate dx/dz to local view row/col."""
return self.VIEW_HALF + dz, self.VIEW_HALF + dx
def _view_cell(self, dx, dz, default=None):
"""Read a local-view cell by global-coordinate delta.
`map_info` is row-major: row follows z/down, col follows x/right.
"""
ri, ci = self._view_delta_to_index(dx, dz)
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
return default
return int(self._view_map[ri, ci])
def _update_local_map_stats(self):
"""Cache coarse 21x21 map statistics."""
view = self._view_map
@@ -247,6 +272,7 @@ class Preprocessor:
def _update_charger_state(self, hx, hz, organs):
"""Find nearest charger and cache distance/direction features."""
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
self.last_nearest_charger_path_dist = self.nearest_charger_path_dist
self.has_charger = False
self.on_charger = False
self.nearest_charger_dx = 0.0
@@ -255,6 +281,7 @@ class Preprocessor:
self.nearest_charger_center_dz = 0.0
self.nearest_charger_dist = float(self.GRID_SIZE)
self.nearest_charger_range_dist = float(self.GRID_SIZE)
self.nearest_charger_path_dist = float(self.GRID_SIZE)
self.charger_energy_cost = float(self.GRID_SIZE)
self.charger_safety_buffer = 0.0
self.charger_safety_margin = 0.0
@@ -295,9 +322,11 @@ class Preprocessor:
self.nearest_charger_center_dz = float(center_dz)
self.nearest_charger_dist = float(dist)
self.nearest_charger_range_dist = float(range_dist)
self.charger_energy_cost = float(range_dist)
path_dist = self._local_path_dist_to_charger(hx, hz)
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
self.charger_energy_cost = self.nearest_charger_path_dist
self.on_charger = range_dist <= 0.0
self.battery_margin = float(self.battery) - self.nearest_charger_range_dist
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
"""Relative vector from point to the nearest cell in a rectangle."""
@@ -355,7 +384,7 @@ class Preprocessor:
self.charger_safety_margin = float(self.battery)
return
self.charger_energy_cost = float(max(self.nearest_charger_range_dist, 0.0))
self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0))
self.charger_safety_buffer = self._charger_safety_buffer()
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
@@ -450,30 +479,20 @@ class Preprocessor:
ray_dirt = []
max_ray = 30
for dx, dz in ray_dirs:
x, z = hx, hz
found = max_ray
for step in range(1, max_ray + 1):
x += dx
z += dz
if not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
gx = hx + dx * step
gz = hz + dz * step
if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
break
cell = self._view_cell(dx * step, dz * step, default=0)
if cell == 2:
found = step
break
if self._view_map is not None:
cell = (
int(
self._view_map[
np.clip(x - (hx - self.VIEW_HALF), 0, 20), np.clip(z - (hz - self.VIEW_HALF), 0, 20)
]
)
if (0 <= x - hx + self.VIEW_HALF < 21 and 0 <= z - hz + self.VIEW_HALF < 21)
else 0
)
if cell == 2:
found = step
break
ray_dirt.append(_norm(found, max_ray))
# Nearest dirt Euclidean distance (estimated from 7×7 crop)
# 最近污渍欧氏距离(视野内 7×7 粗估)
# Nearest dirt path distance in the visible map.
# 视野内最近污渍路径距离。
self.last_nearest_dirt_dist = self.nearest_dirt_dist
self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
@@ -500,7 +519,7 @@ class Preprocessor:
dirt_delta,
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
_norm(self.nearest_charger_range_dist, self.GRID_SIZE),
_norm(self.nearest_charger_path_dist, self.GRID_SIZE),
battery_margin_norm,
1.0 if self.low_battery else 0.0,
1.0 if self.recharge_mode else 0.0,
@@ -519,19 +538,73 @@ class Preprocessor:
)
def _calc_nearest_dirt_dist(self):
"""Find nearest dirt Euclidean distance from local view.
"""Find nearest dirt path distance from local view.
从局部视野中找最近污渍的欧氏距离。
从局部视野中找最近污渍路径距离。
"""
view = self._view_map
if view is None:
return 200.0
dirt_coords = np.argwhere(view == 2)
dist = self._local_bfs_distances()
dirt_coords = np.argwhere(self._view_map == 2)
if len(dirt_coords) == 0:
return 200.0
center = self.VIEW_HALF
dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2)
return float(np.min(dists))
best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
if best >= self.INF_DIST:
return 200.0
return best
def _local_bfs_distances(self, start_dx=0, start_dz=0):
"""Shortest path distances inside the current 21x21 local view."""
view = self._view_map
shape = view.shape
dist = np.full(shape, self.INF_DIST, dtype=np.float32)
start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
return dist
if int(view[start_ri, start_ci]) == 0:
return dist
dist[start_ri, start_ci] = 0.0
queue = deque([(start_ri, start_ci)])
while queue:
ri, ci = queue.popleft()
base = dist[ri, ci]
for dx, dz in self.ACTION_DIRS:
nri = ri + dz
nci = ci + dx
if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
continue
if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
continue
if dx != 0 and dz != 0:
side_a = int(view[ri, nci]) != 0
side_b = int(view[nri, ci]) != 0
if not (side_a or side_b):
continue
dist[nri, nci] = base + 1.0
queue.append((nri, nci))
return dist
def _local_path_dist_to_charger(self, gx, gz):
"""Visible-map BFS distance from global x/z to nearest charger cell."""
best = self.INF_DIST
start_dx = gx - self.cur_pos[0]
start_dz = gz - self.cur_pos[1]
dist = self._local_bfs_distances(start_dx, start_dz)
for rx, rz, w, h in self.charger_rects:
for tx in range(rx, rx + w):
for tz in range(rz, rz + h):
dx = tx - self.cur_pos[0]
dz = tz - self.cur_pos[1]
ri, ci = self._view_delta_to_index(dx, dz)
if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
best = min(best, float(dist[ri, ci]))
return best
def _charger_move_distance(self, gx, gz):
"""Use visible BFS to the charger when available, otherwise Chebyshev distance."""
path_dist = self._local_path_dist_to_charger(gx, gz)
if path_dist < self.INF_DIST:
return path_dist
return self._min_charger_range_dist(gx, gz)
def get_legal_action(self):
"""Return legal action mask (8D list).
@@ -569,11 +642,8 @@ class Preprocessor:
def _is_visible_cell_passable(self, dx, dz):
"""Whether a relative 21x21-view cell is passable."""
ri = self.VIEW_HALF + dx
ci = self.VIEW_HALF + dz
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
return True
return int(self._view_map[ri, ci]) != 0
cell = self._view_cell(dx, dz, default=None)
return True if cell is None else cell != 0
def _filter_npc_danger_actions(self, legal_action):
"""Avoid actions that would enter any NPC 3x3 danger zone."""
@@ -608,42 +678,43 @@ class Preprocessor:
return list(legal_action)
hx, hz = self.cur_pos
current_dist = self._min_charger_range_dist(hx, hz)
current_range_dist = self._min_charger_range_dist(hx, hz)
current_move_dist = self._charger_move_distance(hx, hz)
scored = []
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if legal_action[action] <= 0:
continue
nx, nz = hx + dx, hz + dz
next_dist = self._min_charger_range_dist(nx, nz)
next_dist = self._charger_move_distance(nx, nz)
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
scored.append((next_dist, alignment, action))
next_range_dist = self._min_charger_range_dist(nx, nz)
scored.append((next_dist, alignment, next_range_dist, action))
if not scored:
return list(legal_action)
# When already inside the charger range, stay inside until recharge mode exits.
# 已经在充电区域内时,回充模式退出前不要离开充电区域。
if current_dist <= 0.0:
if current_range_dist <= 0.0:
stay = [0] * 8
for next_dist, _, action in scored:
if next_dist <= 0.0:
for _, _, next_range_dist, action in scored:
if next_range_dist <= 0.0:
stay[action] = 1
if any(stay):
return stay
best_next_dist = min(item[0] for item in scored)
best_alignment = max(alignment for next_dist, alignment, _ in scored if next_dist <= best_next_dist + 0.1)
recharge = [0] * 8
for next_dist, alignment, action in scored:
if next_dist <= best_next_dist + 0.1 and alignment >= best_alignment - 0.1:
best_next_dist = min(item[0] for item in scored)
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
for next_dist, _, _, action in ranked:
if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1:
recharge[action] = 1
if sum(recharge) >= 3:
break
if not any(recharge):
best_alignment = max(item[1] for item in scored)
for _, alignment, action in scored:
if alignment >= best_alignment - 0.1:
recharge[action] = 1
for _, _, _, action in ranked[: min(3, len(ranked))]:
recharge[action] = 1
return recharge if any(recharge) else list(legal_action)
@@ -733,7 +804,7 @@ class Preprocessor:
if self.has_charger and (self.recharge_mode or self.low_battery):
dist_delta = float(
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0)
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
)
approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04
retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02

View File

@@ -6,7 +6,7 @@
"""
Author: Tencent AI Arena Authors
Simple MLP policy network for Robot Vacuum.
CNN + MLP policy network for Robot Vacuum.
清扫大作战策略网络。
"""
@@ -28,9 +28,9 @@ def _make_fc(in_dim, out_dim, gain=1.41421):
class Model(nn.Module):
"""Dual-head MLP for Robot Vacuum.
"""Dual-head CNN+MLP actor-critic for Robot Vacuum.
清扫大作战双头 MLP 策略网络。
清扫大作战双头 CNN+MLP 策略网络。
"""
def __init__(self, device=None):
@@ -38,12 +38,36 @@ class Model(nn.Module):
self.model_name = "robot_vacuum"
self.device = device
obs_dim = Config.DIM_OF_OBSERVATION # 157
map_dim, scalar_dim, last_action_dim = Config.FEATURES
map_size = int(map_dim**0.5)
if map_size * map_size != map_dim:
raise ValueError(f"local map feature must be square, got {map_dim}")
self.map_size = map_size
self.map_dim = map_dim
self.scalar_dim = scalar_dim + last_action_dim
act_num = Config.ACTION_NUM # 8
# Shared backbone / 共享骨干网络
# Local map encoder keeps spatial obstacle/dirt patterns.
# 局部地图编码器保留障碍/污渍空间结构。
self.map_encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((3, 3)),
nn.Flatten(),
)
# Scalar encoder for battery, charger, NPC and last-action features.
# 标量编码器处理电量、充电桩、NPC、上一步动作等特征。
self.scalar_encoder = nn.Sequential(
_make_fc(self.scalar_dim, 64),
nn.ReLU(),
)
# Shared fusion backbone / 共享融合骨干网络
self.backbone = nn.Sequential(
_make_fc(obs_dim, 256),
_make_fc(32 * 3 * 3 + 64, 256),
nn.ReLU(),
_make_fc(256, 128),
nn.ReLU(),
@@ -61,7 +85,11 @@ class Model(nn.Module):
前向传播。
"""
x = s.to(torch.float32)
h = self.backbone(x)
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
scalar = x[:, self.map_dim :]
map_h = self.map_encoder(local_map)
scalar_h = self.scalar_encoder(scalar)
h = self.backbone(torch.cat([map_h, scalar_h], dim=1))
logits = self.actor_head(h)
value = self.critic_head(h)
return [logits, value]