优化 PPO 清扫策略

This commit is contained in:
2026-04-26 17:29:03 +08:00
parent f04feb0cd9
commit f44e2483fc
6 changed files with 223 additions and 86 deletions

View File

@@ -192,18 +192,31 @@ class Agent(BaseAgent):
合法动作掩码下的 softmax。 合法动作掩码下的 softmax。
""" """
_w, _e = 1e20, 1e-5 legal = np.asarray(legal_action, dtype=np.float32) > 0.5
tmp = logits - _w * (1.0 - legal_action) if not np.any(legal):
tmp_max = np.max(tmp, keepdims=True) legal = np.ones(Config.ACTION_NUM, dtype=bool)
tmp = np.clip(tmp - tmp_max, -_w, 1)
tmp = (np.exp(tmp) + _e) * legal_action masked_logits = np.asarray(logits, dtype=np.float32).copy()
return tmp / (np.sum(tmp, keepdims=True) * 1.00001) masked_logits[~legal] = -1e9
shifted = masked_logits - np.max(masked_logits)
probs = np.exp(shifted) * legal.astype(np.float32)
prob_sum = float(np.sum(probs))
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
probs = legal.astype(np.float32)
prob_sum = float(np.sum(probs))
return probs / prob_sum
def _legal_sample(self, probs, use_max=False): def _legal_sample(self, probs, use_max=False):
"""Sample action from probability distribution (argmax if use_max=True). """Sample action from probability distribution (argmax if use_max=True).
按概率分布采样动作use_max=True 时取 argmax 按概率分布采样动作use_max=True 时取 argmax
""" """
probs = np.asarray(probs, dtype=np.float64)
prob_sum = float(np.sum(probs))
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
else:
probs = probs / prob_sum
if use_max: if use_max:
return int(np.argmax(probs)) return int(np.argmax(probs))
return int(np.argmax(np.random.multinomial(1, probs, size=1))) return int(np.random.choice(len(probs), p=probs))

View File

@@ -43,14 +43,14 @@ class Algorithm:
训练入口:接收一批 SampleData执行一步梯度更新。 训练入口:接收一批 SampleData执行一步梯度更新。
""" """
obs = torch.stack([s.obs for s in list_sample_data]).to(self.device) obs = self._batch_tensor([s.obs for s in list_sample_data])
legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device) legal_action = self._batch_tensor([s.legal_action for s in list_sample_data])
act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1) act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1)
old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device) old_prob = self._batch_tensor([s.prob for s in list_sample_data])
old_value = torch.stack([s.value for s in list_sample_data]).to(self.device) old_value = self._batch_tensor([s.value for s in list_sample_data])
reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device) reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data])
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device) advantage = self._batch_tensor([s.advantage for s in list_sample_data])
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device) reward = self._batch_tensor([s.reward for s in list_sample_data])
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1: if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8) advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
@@ -194,9 +194,22 @@ class Algorithm:
对 logits 应用合法动作掩码后计算 softmax。 对 logits 应用合法动作掩码后计算 softmax。
""" """
legal_mask = legal_action > 0.5 legal_mask = legal_action > 0.5
all_illegal = ~legal_mask.any(dim=1, keepdim=True)
legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask)
safe_logits = logits.masked_fill(~legal_mask, -1e9) safe_logits = logits.masked_fill(~legal_mask, -1e9)
return F.softmax(safe_logits, dim=1) return F.softmax(safe_logits, dim=1)
def _batch_tensor(self, values):
"""Stack framework tensors or raw numpy/list values into a float tensor."""
tensors = []
for value in values:
if isinstance(value, torch.Tensor):
tensor = value.to(self.device, dtype=torch.float32)
else:
tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
tensors.append(tensor)
return torch.stack(tensors)
def _entropy_beta(self): def _entropy_beta(self):
"""Linearly decay entropy regularization for fast early exploration.""" """Linearly decay entropy regularization for fast early exploration."""
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0) progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)

View File

@@ -6,7 +6,7 @@ map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# Whether to randomly select maps. Boolean. # Whether to randomly select maps. Boolean.
# true = randomly pick one from configured maps per episode, false = used sequentially. # true = randomly pick one from configured maps per episode, false = used sequentially.
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张false表示按顺序抽取地图训练。 # 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张false表示按顺序抽取地图训练。
map_random = false map_random = true
# Number of official robots. Range: 1~4 (integer). # Number of official robots. Range: 1~4 (integer).
# In each round, official robots will be randomly generated on the road according to the configured. # In each round, official robots will be randomly generated on the road according to the configured.

View File

@@ -67,7 +67,19 @@ def _calc_gae(list_sample_data):
gamma = Config.GAMMA gamma = Config.GAMMA
lamda = Config.LAMDA lamda = Config.LAMDA
for sample in reversed(list_sample_data): for sample in reversed(list_sample_data):
delta = -sample.value + sample.reward + gamma * sample.next_value value = _scalar(sample.value)
gae = gae * gamma * lamda + delta reward = _scalar(sample.reward)
next_value = _scalar(sample.next_value)
nonterminal = 1.0 - _scalar(sample.done)
delta = reward + gamma * next_value * nonterminal - value
gae = delta + gamma * lamda * nonterminal * gae
sample.advantage = gae sample.advantage = gae
sample.reward_sum = gae + sample.value sample.reward_sum = gae + value
def _scalar(value):
"""Read the first scalar from numpy/tensor/list values."""
if hasattr(value, "detach"):
value = value.detach().cpu().numpy()
arr = np.asarray(value, dtype=np.float32).reshape(-1)
return float(arr[0]) if arr.size else 0.0

View File

@@ -10,6 +10,8 @@ Feature preprocessor for Robot Vacuum.
清扫大作战特征预处理器。 清扫大作战特征预处理器。
""" """
from collections import deque
import numpy as np import numpy as np
@@ -55,6 +57,7 @@ class Preprocessor:
(0, 1), (0, 1),
(1, 1), (1, 1),
) )
INF_DIST = 1e6
def __init__(self): def __init__(self):
self.reset() self.reset()
@@ -83,12 +86,12 @@ class Preprocessor:
self.step_cleaned_count = 0 self.step_cleaned_count = 0
self.max_step = 1000 self.max_step = 1000
# Global passable map (0=obstacle, 1=passable), used for ray computation # Global passable map (0=obstacle, 1=passable), indexed by [x, z].
# 维护全局通行地图0=障碍, 1=可通行),用于射线计算 # 维护全局通行地图0=障碍, 1=可通行),索引为 [x, z]。
self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8) self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
# Nearest dirt distance # Nearest dirt path distance in the current local view.
# 最近污渍距离 # 当前局部视野内最近污渍路径距离
self.nearest_dirt_dist = 200.0 self.nearest_dirt_dist = 200.0
self.last_nearest_dirt_dist = 200.0 self.last_nearest_dirt_dist = 200.0
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16) self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
@@ -114,6 +117,8 @@ class Preprocessor:
self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_dist = float(self.GRID_SIZE)
self.nearest_charger_range_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE)
self.last_nearest_charger_range_dist = float(self.GRID_SIZE) self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
self.nearest_charger_path_dist = float(self.GRID_SIZE)
self.last_nearest_charger_path_dist = float(self.GRID_SIZE)
self.charger_energy_cost = float(self.GRID_SIZE) self.charger_energy_cost = float(self.GRID_SIZE)
self.charger_safety_buffer = 0.0 self.charger_safety_buffer = 0.0
self.charger_safety_margin = 0.0 self.charger_safety_margin = 0.0
@@ -226,13 +231,33 @@ class Preprocessor:
for ri in range(vsize): for ri in range(vsize):
for ci in range(vsize): for ci in range(vsize):
gx = hx - half + ri gx = hx + ci - half
gz = hz - half + ci gz = hz + ri - half
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE: if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
# 0 = obstacle, 1/2 = passable # 0 = obstacle, 1/2 = passable
# 0 = 障碍, 1/2 = 可通行 # 0 = 障碍, 1/2 = 可通行
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0 self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
def _view_index_to_global(self, ri, ci):
"""Convert local view row/col to global x/z coordinates."""
half = self._view_map.shape[0] // 2
hx, hz = self.cur_pos
return hx + ci - half, hz + ri - half
def _view_delta_to_index(self, dx, dz):
"""Convert global-coordinate dx/dz to local view row/col."""
return self.VIEW_HALF + dz, self.VIEW_HALF + dx
def _view_cell(self, dx, dz, default=None):
"""Read a local-view cell by global-coordinate delta.
`map_info` is row-major: row follows z/down, col follows x/right.
"""
ri, ci = self._view_delta_to_index(dx, dz)
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
return default
return int(self._view_map[ri, ci])
def _update_local_map_stats(self): def _update_local_map_stats(self):
"""Cache coarse 21x21 map statistics.""" """Cache coarse 21x21 map statistics."""
view = self._view_map view = self._view_map
@@ -247,6 +272,7 @@ class Preprocessor:
def _update_charger_state(self, hx, hz, organs): def _update_charger_state(self, hx, hz, organs):
"""Find nearest charger and cache distance/direction features.""" """Find nearest charger and cache distance/direction features."""
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
self.last_nearest_charger_path_dist = self.nearest_charger_path_dist
self.has_charger = False self.has_charger = False
self.on_charger = False self.on_charger = False
self.nearest_charger_dx = 0.0 self.nearest_charger_dx = 0.0
@@ -255,6 +281,7 @@ class Preprocessor:
self.nearest_charger_center_dz = 0.0 self.nearest_charger_center_dz = 0.0
self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_dist = float(self.GRID_SIZE)
self.nearest_charger_range_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE)
self.nearest_charger_path_dist = float(self.GRID_SIZE)
self.charger_energy_cost = float(self.GRID_SIZE) self.charger_energy_cost = float(self.GRID_SIZE)
self.charger_safety_buffer = 0.0 self.charger_safety_buffer = 0.0
self.charger_safety_margin = 0.0 self.charger_safety_margin = 0.0
@@ -295,9 +322,11 @@ class Preprocessor:
self.nearest_charger_center_dz = float(center_dz) self.nearest_charger_center_dz = float(center_dz)
self.nearest_charger_dist = float(dist) self.nearest_charger_dist = float(dist)
self.nearest_charger_range_dist = float(range_dist) self.nearest_charger_range_dist = float(range_dist)
self.charger_energy_cost = float(range_dist) path_dist = self._local_path_dist_to_charger(hx, hz)
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
self.charger_energy_cost = self.nearest_charger_path_dist
self.on_charger = range_dist <= 0.0 self.on_charger = range_dist <= 0.0
self.battery_margin = float(self.battery) - self.nearest_charger_range_dist self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
def _relative_vector_to_rect(self, x, z, rx, rz, w, h): def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
"""Relative vector from point to the nearest cell in a rectangle.""" """Relative vector from point to the nearest cell in a rectangle."""
@@ -355,7 +384,7 @@ class Preprocessor:
self.charger_safety_margin = float(self.battery) self.charger_safety_margin = float(self.battery)
return return
self.charger_energy_cost = float(max(self.nearest_charger_range_dist, 0.0)) self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0))
self.charger_safety_buffer = self._charger_safety_buffer() self.charger_safety_buffer = self._charger_safety_buffer()
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
@@ -450,30 +479,20 @@ class Preprocessor:
ray_dirt = [] ray_dirt = []
max_ray = 30 max_ray = 30
for dx, dz in ray_dirs: for dx, dz in ray_dirs:
x, z = hx, hz
found = max_ray found = max_ray
for step in range(1, max_ray + 1): for step in range(1, max_ray + 1):
x += dx gx = hx + dx * step
z += dz gz = hz + dz * step
if not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE): if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
break break
if self._view_map is not None: cell = self._view_cell(dx * step, dz * step, default=0)
cell = (
int(
self._view_map[
np.clip(x - (hx - self.VIEW_HALF), 0, 20), np.clip(z - (hz - self.VIEW_HALF), 0, 20)
]
)
if (0 <= x - hx + self.VIEW_HALF < 21 and 0 <= z - hz + self.VIEW_HALF < 21)
else 0
)
if cell == 2: if cell == 2:
found = step found = step
break break
ray_dirt.append(_norm(found, max_ray)) ray_dirt.append(_norm(found, max_ray))
# Nearest dirt Euclidean distance (estimated from 7×7 crop) # Nearest dirt path distance in the visible map.
# 最近污渍欧氏距离(视野内 7×7 粗估) # 视野内最近污渍路径距离。
self.last_nearest_dirt_dist = self.nearest_dirt_dist self.last_nearest_dirt_dist = self.nearest_dirt_dist
self.nearest_dirt_dist = self._calc_nearest_dirt_dist() self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180) nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
@@ -500,7 +519,7 @@ class Preprocessor:
dirt_delta, dirt_delta,
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE), _signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE), _signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
_norm(self.nearest_charger_range_dist, self.GRID_SIZE), _norm(self.nearest_charger_path_dist, self.GRID_SIZE),
battery_margin_norm, battery_margin_norm,
1.0 if self.low_battery else 0.0, 1.0 if self.low_battery else 0.0,
1.0 if self.recharge_mode else 0.0, 1.0 if self.recharge_mode else 0.0,
@@ -519,19 +538,73 @@ class Preprocessor:
) )
def _calc_nearest_dirt_dist(self): def _calc_nearest_dirt_dist(self):
"""Find nearest dirt Euclidean distance from local view. """Find nearest dirt path distance from local view.
从局部视野中找最近污渍的欧氏距离。 从局部视野中找最近污渍路径距离。
""" """
view = self._view_map dist = self._local_bfs_distances()
if view is None: dirt_coords = np.argwhere(self._view_map == 2)
return 200.0
dirt_coords = np.argwhere(view == 2)
if len(dirt_coords) == 0: if len(dirt_coords) == 0:
return 200.0 return 200.0
center = self.VIEW_HALF best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2) if best >= self.INF_DIST:
return float(np.min(dists)) return 200.0
return best
def _local_bfs_distances(self, start_dx=0, start_dz=0):
"""Shortest path distances inside the current 21x21 local view."""
view = self._view_map
shape = view.shape
dist = np.full(shape, self.INF_DIST, dtype=np.float32)
start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
return dist
if int(view[start_ri, start_ci]) == 0:
return dist
dist[start_ri, start_ci] = 0.0
queue = deque([(start_ri, start_ci)])
while queue:
ri, ci = queue.popleft()
base = dist[ri, ci]
for dx, dz in self.ACTION_DIRS:
nri = ri + dz
nci = ci + dx
if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
continue
if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
continue
if dx != 0 and dz != 0:
side_a = int(view[ri, nci]) != 0
side_b = int(view[nri, ci]) != 0
if not (side_a or side_b):
continue
dist[nri, nci] = base + 1.0
queue.append((nri, nci))
return dist
def _local_path_dist_to_charger(self, gx, gz):
"""Visible-map BFS distance from global x/z to nearest charger cell."""
best = self.INF_DIST
start_dx = gx - self.cur_pos[0]
start_dz = gz - self.cur_pos[1]
dist = self._local_bfs_distances(start_dx, start_dz)
for rx, rz, w, h in self.charger_rects:
for tx in range(rx, rx + w):
for tz in range(rz, rz + h):
dx = tx - self.cur_pos[0]
dz = tz - self.cur_pos[1]
ri, ci = self._view_delta_to_index(dx, dz)
if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
best = min(best, float(dist[ri, ci]))
return best
def _charger_move_distance(self, gx, gz):
"""Use visible BFS to the charger when available, otherwise Chebyshev distance."""
path_dist = self._local_path_dist_to_charger(gx, gz)
if path_dist < self.INF_DIST:
return path_dist
return self._min_charger_range_dist(gx, gz)
def get_legal_action(self): def get_legal_action(self):
"""Return legal action mask (8D list). """Return legal action mask (8D list).
@@ -569,11 +642,8 @@ class Preprocessor:
def _is_visible_cell_passable(self, dx, dz): def _is_visible_cell_passable(self, dx, dz):
"""Whether a relative 21x21-view cell is passable.""" """Whether a relative 21x21-view cell is passable."""
ri = self.VIEW_HALF + dx cell = self._view_cell(dx, dz, default=None)
ci = self.VIEW_HALF + dz return True if cell is None else cell != 0
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
return True
return int(self._view_map[ri, ci]) != 0
def _filter_npc_danger_actions(self, legal_action): def _filter_npc_danger_actions(self, legal_action):
"""Avoid actions that would enter any NPC 3x3 danger zone.""" """Avoid actions that would enter any NPC 3x3 danger zone."""
@@ -608,41 +678,42 @@ class Preprocessor:
return list(legal_action) return list(legal_action)
hx, hz = self.cur_pos hx, hz = self.cur_pos
current_dist = self._min_charger_range_dist(hx, hz) current_range_dist = self._min_charger_range_dist(hx, hz)
current_move_dist = self._charger_move_distance(hx, hz)
scored = [] scored = []
for action, (dx, dz) in enumerate(self.ACTION_DIRS): for action, (dx, dz) in enumerate(self.ACTION_DIRS):
if legal_action[action] <= 0: if legal_action[action] <= 0:
continue continue
nx, nz = hx + dx, hz + dz nx, nz = hx + dx, hz + dz
next_dist = self._min_charger_range_dist(nx, nz) next_dist = self._charger_move_distance(nx, nz)
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
scored.append((next_dist, alignment, action)) next_range_dist = self._min_charger_range_dist(nx, nz)
scored.append((next_dist, alignment, next_range_dist, action))
if not scored: if not scored:
return list(legal_action) return list(legal_action)
# When already inside the charger range, stay inside until recharge mode exits. # When already inside the charger range, stay inside until recharge mode exits.
# 已经在充电区域内时,回充模式退出前不要离开充电区域。 # 已经在充电区域内时,回充模式退出前不要离开充电区域。
if current_dist <= 0.0: if current_range_dist <= 0.0:
stay = [0] * 8 stay = [0] * 8
for next_dist, _, action in scored: for _, _, next_range_dist, action in scored:
if next_dist <= 0.0: if next_range_dist <= 0.0:
stay[action] = 1 stay[action] = 1
if any(stay): if any(stay):
return stay return stay
best_next_dist = min(item[0] for item in scored)
best_alignment = max(alignment for next_dist, alignment, _ in scored if next_dist <= best_next_dist + 0.1)
recharge = [0] * 8 recharge = [0] * 8
for next_dist, alignment, action in scored: best_next_dist = min(item[0] for item in scored)
if next_dist <= best_next_dist + 0.1 and alignment >= best_alignment - 0.1: ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
for next_dist, _, _, action in ranked:
if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1:
recharge[action] = 1 recharge[action] = 1
if sum(recharge) >= 3:
break
if not any(recharge): if not any(recharge):
best_alignment = max(item[1] for item in scored) for _, _, _, action in ranked[: min(3, len(ranked))]:
for _, alignment, action in scored:
if alignment >= best_alignment - 0.1:
recharge[action] = 1 recharge[action] = 1
return recharge if any(recharge) else list(legal_action) return recharge if any(recharge) else list(legal_action)
@@ -733,7 +804,7 @@ class Preprocessor:
if self.has_charger and (self.recharge_mode or self.low_battery): if self.has_charger and (self.recharge_mode or self.low_battery):
dist_delta = float( dist_delta = float(
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0) np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
) )
approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04 approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04
retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02 retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02

View File

@@ -6,7 +6,7 @@
""" """
Author: Tencent AI Arena Authors Author: Tencent AI Arena Authors
Simple MLP policy network for Robot Vacuum. CNN + MLP policy network for Robot Vacuum.
清扫大作战策略网络。 清扫大作战策略网络。
""" """
@@ -28,9 +28,9 @@ def _make_fc(in_dim, out_dim, gain=1.41421):
class Model(nn.Module): class Model(nn.Module):
"""Dual-head MLP for Robot Vacuum. """Dual-head CNN+MLP actor-critic for Robot Vacuum.
清扫大作战双头 MLP 策略网络。 清扫大作战双头 CNN+MLP 策略网络。
""" """
def __init__(self, device=None): def __init__(self, device=None):
@@ -38,12 +38,36 @@ class Model(nn.Module):
self.model_name = "robot_vacuum" self.model_name = "robot_vacuum"
self.device = device self.device = device
obs_dim = Config.DIM_OF_OBSERVATION # 157 map_dim, scalar_dim, last_action_dim = Config.FEATURES
map_size = int(map_dim**0.5)
if map_size * map_size != map_dim:
raise ValueError(f"local map feature must be square, got {map_dim}")
self.map_size = map_size
self.map_dim = map_dim
self.scalar_dim = scalar_dim + last_action_dim
act_num = Config.ACTION_NUM # 8 act_num = Config.ACTION_NUM # 8
# Shared backbone / 共享骨干网络 # Local map encoder keeps spatial obstacle/dirt patterns.
# 局部地图编码器保留障碍/污渍空间结构。
self.map_encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((3, 3)),
nn.Flatten(),
)
# Scalar encoder for battery, charger, NPC and last-action features.
# 标量编码器处理电量、充电桩、NPC、上一步动作等特征。
self.scalar_encoder = nn.Sequential(
_make_fc(self.scalar_dim, 64),
nn.ReLU(),
)
# Shared fusion backbone / 共享融合骨干网络
self.backbone = nn.Sequential( self.backbone = nn.Sequential(
_make_fc(obs_dim, 256), _make_fc(32 * 3 * 3 + 64, 256),
nn.ReLU(), nn.ReLU(),
_make_fc(256, 128), _make_fc(256, 128),
nn.ReLU(), nn.ReLU(),
@@ -61,7 +85,11 @@ class Model(nn.Module):
前向传播。 前向传播。
""" """
x = s.to(torch.float32) x = s.to(torch.float32)
h = self.backbone(x) local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
scalar = x[:, self.map_dim :]
map_h = self.map_encoder(local_map)
scalar_h = self.scalar_encoder(scalar)
h = self.backbone(torch.cat([map_h, scalar_h], dim=1))
logits = self.actor_head(h) logits = self.actor_head(h)
value = self.critic_head(h) value = self.critic_head(h)
return [logits, value] return [logits, value]