优化 PPO 清扫策略
This commit is contained in:
@@ -192,18 +192,31 @@ class Agent(BaseAgent):
|
||||
|
||||
合法动作掩码下的 softmax。
|
||||
"""
|
||||
_w, _e = 1e20, 1e-5
|
||||
tmp = logits - _w * (1.0 - legal_action)
|
||||
tmp_max = np.max(tmp, keepdims=True)
|
||||
tmp = np.clip(tmp - tmp_max, -_w, 1)
|
||||
tmp = (np.exp(tmp) + _e) * legal_action
|
||||
return tmp / (np.sum(tmp, keepdims=True) * 1.00001)
|
||||
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||||
if not np.any(legal):
|
||||
legal = np.ones(Config.ACTION_NUM, dtype=bool)
|
||||
|
||||
masked_logits = np.asarray(logits, dtype=np.float32).copy()
|
||||
masked_logits[~legal] = -1e9
|
||||
shifted = masked_logits - np.max(masked_logits)
|
||||
probs = np.exp(shifted) * legal.astype(np.float32)
|
||||
prob_sum = float(np.sum(probs))
|
||||
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
|
||||
probs = legal.astype(np.float32)
|
||||
prob_sum = float(np.sum(probs))
|
||||
return probs / prob_sum
|
||||
|
||||
def _legal_sample(self, probs, use_max=False):
|
||||
"""Sample action from probability distribution (argmax if use_max=True).
|
||||
|
||||
按概率分布采样动作(use_max=True 时取 argmax)。
|
||||
"""
|
||||
probs = np.asarray(probs, dtype=np.float64)
|
||||
prob_sum = float(np.sum(probs))
|
||||
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
|
||||
probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
|
||||
else:
|
||||
probs = probs / prob_sum
|
||||
if use_max:
|
||||
return int(np.argmax(probs))
|
||||
return int(np.argmax(np.random.multinomial(1, probs, size=1)))
|
||||
return int(np.random.choice(len(probs), p=probs))
|
||||
|
||||
@@ -43,14 +43,14 @@ class Algorithm:
|
||||
|
||||
训练入口:接收一批 SampleData,执行一步梯度更新。
|
||||
"""
|
||||
obs = torch.stack([s.obs for s in list_sample_data]).to(self.device)
|
||||
legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device)
|
||||
act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1)
|
||||
old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device)
|
||||
old_value = torch.stack([s.value for s in list_sample_data]).to(self.device)
|
||||
reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device)
|
||||
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
|
||||
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
|
||||
obs = self._batch_tensor([s.obs for s in list_sample_data])
|
||||
legal_action = self._batch_tensor([s.legal_action for s in list_sample_data])
|
||||
act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1)
|
||||
old_prob = self._batch_tensor([s.prob for s in list_sample_data])
|
||||
old_value = self._batch_tensor([s.value for s in list_sample_data])
|
||||
reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data])
|
||||
advantage = self._batch_tensor([s.advantage for s in list_sample_data])
|
||||
reward = self._batch_tensor([s.reward for s in list_sample_data])
|
||||
|
||||
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
|
||||
@@ -194,9 +194,22 @@ class Algorithm:
|
||||
对 logits 应用合法动作掩码后计算 softmax。
|
||||
"""
|
||||
legal_mask = legal_action > 0.5
|
||||
all_illegal = ~legal_mask.any(dim=1, keepdim=True)
|
||||
legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask)
|
||||
safe_logits = logits.masked_fill(~legal_mask, -1e9)
|
||||
return F.softmax(safe_logits, dim=1)
|
||||
|
||||
def _batch_tensor(self, values):
|
||||
"""Stack framework tensors or raw numpy/list values into a float tensor."""
|
||||
tensors = []
|
||||
for value in values:
|
||||
if isinstance(value, torch.Tensor):
|
||||
tensor = value.to(self.device, dtype=torch.float32)
|
||||
else:
|
||||
tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
|
||||
tensors.append(tensor)
|
||||
return torch.stack(tensors)
|
||||
|
||||
def _entropy_beta(self):
|
||||
"""Linearly decay entropy regularization for fast early exploration."""
|
||||
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
|
||||
|
||||
@@ -6,7 +6,7 @@ map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
# Whether to randomly select maps. Boolean.
|
||||
# true = randomly pick one from configured maps per episode, false = used sequentially.
|
||||
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张,false表示按顺序抽取地图训练。
|
||||
map_random = false
|
||||
map_random = true
|
||||
|
||||
# Number of official robots. Range: 1~4 (integer).
|
||||
# In each round, official robots will be randomly generated on the road according to the configured.
|
||||
|
||||
@@ -67,7 +67,19 @@ def _calc_gae(list_sample_data):
|
||||
gamma = Config.GAMMA
|
||||
lamda = Config.LAMDA
|
||||
for sample in reversed(list_sample_data):
|
||||
delta = -sample.value + sample.reward + gamma * sample.next_value
|
||||
gae = gae * gamma * lamda + delta
|
||||
value = _scalar(sample.value)
|
||||
reward = _scalar(sample.reward)
|
||||
next_value = _scalar(sample.next_value)
|
||||
nonterminal = 1.0 - _scalar(sample.done)
|
||||
delta = reward + gamma * next_value * nonterminal - value
|
||||
gae = delta + gamma * lamda * nonterminal * gae
|
||||
sample.advantage = gae
|
||||
sample.reward_sum = gae + sample.value
|
||||
sample.reward_sum = gae + value
|
||||
|
||||
|
||||
def _scalar(value):
|
||||
"""Read the first scalar from numpy/tensor/list values."""
|
||||
if hasattr(value, "detach"):
|
||||
value = value.detach().cpu().numpy()
|
||||
arr = np.asarray(value, dtype=np.float32).reshape(-1)
|
||||
return float(arr[0]) if arr.size else 0.0
|
||||
|
||||
@@ -10,6 +10,8 @@ Feature preprocessor for Robot Vacuum.
|
||||
清扫大作战特征预处理器。
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -55,6 +57,7 @@ class Preprocessor:
|
||||
(0, 1),
|
||||
(1, 1),
|
||||
)
|
||||
INF_DIST = 1e6
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -83,12 +86,12 @@ class Preprocessor:
|
||||
self.step_cleaned_count = 0
|
||||
self.max_step = 1000
|
||||
|
||||
# Global passable map (0=obstacle, 1=passable), used for ray computation
|
||||
# 维护全局通行地图(0=障碍, 1=可通行),用于射线计算
|
||||
# Global passable map (0=obstacle, 1=passable), indexed by [x, z].
|
||||
# 维护全局通行地图(0=障碍, 1=可通行),索引为 [x, z]。
|
||||
self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
||||
|
||||
# Nearest dirt distance
|
||||
# 最近污渍距离
|
||||
# Nearest dirt path distance in the current local view.
|
||||
# 当前局部视野内最近污渍路径距离。
|
||||
self.nearest_dirt_dist = 200.0
|
||||
self.last_nearest_dirt_dist = 200.0
|
||||
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
|
||||
@@ -114,6 +117,8 @@ class Preprocessor:
|
||||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||
self.nearest_charger_path_dist = float(self.GRID_SIZE)
|
||||
self.last_nearest_charger_path_dist = float(self.GRID_SIZE)
|
||||
self.charger_energy_cost = float(self.GRID_SIZE)
|
||||
self.charger_safety_buffer = 0.0
|
||||
self.charger_safety_margin = 0.0
|
||||
@@ -226,13 +231,33 @@ class Preprocessor:
|
||||
|
||||
for ri in range(vsize):
|
||||
for ci in range(vsize):
|
||||
gx = hx - half + ri
|
||||
gz = hz - half + ci
|
||||
gx = hx + ci - half
|
||||
gz = hz + ri - half
|
||||
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
|
||||
# 0 = obstacle, 1/2 = passable
|
||||
# 0 = 障碍, 1/2 = 可通行
|
||||
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
||||
|
||||
def _view_index_to_global(self, ri, ci):
|
||||
"""Convert local view row/col to global x/z coordinates."""
|
||||
half = self._view_map.shape[0] // 2
|
||||
hx, hz = self.cur_pos
|
||||
return hx + ci - half, hz + ri - half
|
||||
|
||||
def _view_delta_to_index(self, dx, dz):
|
||||
"""Convert global-coordinate dx/dz to local view row/col."""
|
||||
return self.VIEW_HALF + dz, self.VIEW_HALF + dx
|
||||
|
||||
def _view_cell(self, dx, dz, default=None):
|
||||
"""Read a local-view cell by global-coordinate delta.
|
||||
|
||||
`map_info` is row-major: row follows z/down, col follows x/right.
|
||||
"""
|
||||
ri, ci = self._view_delta_to_index(dx, dz)
|
||||
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
||||
return default
|
||||
return int(self._view_map[ri, ci])
|
||||
|
||||
def _update_local_map_stats(self):
|
||||
"""Cache coarse 21x21 map statistics."""
|
||||
view = self._view_map
|
||||
@@ -247,6 +272,7 @@ class Preprocessor:
|
||||
def _update_charger_state(self, hx, hz, organs):
|
||||
"""Find nearest charger and cache distance/direction features."""
|
||||
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
||||
self.last_nearest_charger_path_dist = self.nearest_charger_path_dist
|
||||
self.has_charger = False
|
||||
self.on_charger = False
|
||||
self.nearest_charger_dx = 0.0
|
||||
@@ -255,6 +281,7 @@ class Preprocessor:
|
||||
self.nearest_charger_center_dz = 0.0
|
||||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||
self.nearest_charger_path_dist = float(self.GRID_SIZE)
|
||||
self.charger_energy_cost = float(self.GRID_SIZE)
|
||||
self.charger_safety_buffer = 0.0
|
||||
self.charger_safety_margin = 0.0
|
||||
@@ -295,9 +322,11 @@ class Preprocessor:
|
||||
self.nearest_charger_center_dz = float(center_dz)
|
||||
self.nearest_charger_dist = float(dist)
|
||||
self.nearest_charger_range_dist = float(range_dist)
|
||||
self.charger_energy_cost = float(range_dist)
|
||||
path_dist = self._local_path_dist_to_charger(hx, hz)
|
||||
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
|
||||
self.charger_energy_cost = self.nearest_charger_path_dist
|
||||
self.on_charger = range_dist <= 0.0
|
||||
self.battery_margin = float(self.battery) - self.nearest_charger_range_dist
|
||||
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
||||
|
||||
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
|
||||
"""Relative vector from point to the nearest cell in a rectangle."""
|
||||
@@ -355,7 +384,7 @@ class Preprocessor:
|
||||
self.charger_safety_margin = float(self.battery)
|
||||
return
|
||||
|
||||
self.charger_energy_cost = float(max(self.nearest_charger_range_dist, 0.0))
|
||||
self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0))
|
||||
self.charger_safety_buffer = self._charger_safety_buffer()
|
||||
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
|
||||
|
||||
@@ -450,30 +479,20 @@ class Preprocessor:
|
||||
ray_dirt = []
|
||||
max_ray = 30
|
||||
for dx, dz in ray_dirs:
|
||||
x, z = hx, hz
|
||||
found = max_ray
|
||||
for step in range(1, max_ray + 1):
|
||||
x += dx
|
||||
z += dz
|
||||
if not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
|
||||
gx = hx + dx * step
|
||||
gz = hz + dz * step
|
||||
if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
|
||||
break
|
||||
cell = self._view_cell(dx * step, dz * step, default=0)
|
||||
if cell == 2:
|
||||
found = step
|
||||
break
|
||||
if self._view_map is not None:
|
||||
cell = (
|
||||
int(
|
||||
self._view_map[
|
||||
np.clip(x - (hx - self.VIEW_HALF), 0, 20), np.clip(z - (hz - self.VIEW_HALF), 0, 20)
|
||||
]
|
||||
)
|
||||
if (0 <= x - hx + self.VIEW_HALF < 21 and 0 <= z - hz + self.VIEW_HALF < 21)
|
||||
else 0
|
||||
)
|
||||
if cell == 2:
|
||||
found = step
|
||||
break
|
||||
ray_dirt.append(_norm(found, max_ray))
|
||||
|
||||
# Nearest dirt Euclidean distance (estimated from 7×7 crop)
|
||||
# 最近污渍欧氏距离(视野内 7×7 粗估)
|
||||
# Nearest dirt path distance in the visible map.
|
||||
# 视野内最近污渍路径距离。
|
||||
self.last_nearest_dirt_dist = self.nearest_dirt_dist
|
||||
self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
|
||||
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
|
||||
@@ -500,7 +519,7 @@ class Preprocessor:
|
||||
dirt_delta,
|
||||
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
|
||||
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
|
||||
_norm(self.nearest_charger_range_dist, self.GRID_SIZE),
|
||||
_norm(self.nearest_charger_path_dist, self.GRID_SIZE),
|
||||
battery_margin_norm,
|
||||
1.0 if self.low_battery else 0.0,
|
||||
1.0 if self.recharge_mode else 0.0,
|
||||
@@ -519,19 +538,73 @@ class Preprocessor:
|
||||
)
|
||||
|
||||
def _calc_nearest_dirt_dist(self):
|
||||
"""Find nearest dirt Euclidean distance from local view.
|
||||
"""Find nearest dirt path distance from local view.
|
||||
|
||||
从局部视野中找最近污渍的欧氏距离。
|
||||
从局部视野中找最近污渍路径距离。
|
||||
"""
|
||||
view = self._view_map
|
||||
if view is None:
|
||||
return 200.0
|
||||
dirt_coords = np.argwhere(view == 2)
|
||||
dist = self._local_bfs_distances()
|
||||
dirt_coords = np.argwhere(self._view_map == 2)
|
||||
if len(dirt_coords) == 0:
|
||||
return 200.0
|
||||
center = self.VIEW_HALF
|
||||
dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2)
|
||||
return float(np.min(dists))
|
||||
best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
|
||||
if best >= self.INF_DIST:
|
||||
return 200.0
|
||||
return best
|
||||
|
||||
def _local_bfs_distances(self, start_dx=0, start_dz=0):
|
||||
"""Shortest path distances inside the current 21x21 local view."""
|
||||
view = self._view_map
|
||||
shape = view.shape
|
||||
dist = np.full(shape, self.INF_DIST, dtype=np.float32)
|
||||
start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
|
||||
if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
|
||||
return dist
|
||||
if int(view[start_ri, start_ci]) == 0:
|
||||
return dist
|
||||
|
||||
dist[start_ri, start_ci] = 0.0
|
||||
queue = deque([(start_ri, start_ci)])
|
||||
while queue:
|
||||
ri, ci = queue.popleft()
|
||||
base = dist[ri, ci]
|
||||
for dx, dz in self.ACTION_DIRS:
|
||||
nri = ri + dz
|
||||
nci = ci + dx
|
||||
if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
|
||||
continue
|
||||
if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
|
||||
continue
|
||||
if dx != 0 and dz != 0:
|
||||
side_a = int(view[ri, nci]) != 0
|
||||
side_b = int(view[nri, ci]) != 0
|
||||
if not (side_a or side_b):
|
||||
continue
|
||||
dist[nri, nci] = base + 1.0
|
||||
queue.append((nri, nci))
|
||||
return dist
|
||||
|
||||
def _local_path_dist_to_charger(self, gx, gz):
|
||||
"""Visible-map BFS distance from global x/z to nearest charger cell."""
|
||||
best = self.INF_DIST
|
||||
start_dx = gx - self.cur_pos[0]
|
||||
start_dz = gz - self.cur_pos[1]
|
||||
dist = self._local_bfs_distances(start_dx, start_dz)
|
||||
for rx, rz, w, h in self.charger_rects:
|
||||
for tx in range(rx, rx + w):
|
||||
for tz in range(rz, rz + h):
|
||||
dx = tx - self.cur_pos[0]
|
||||
dz = tz - self.cur_pos[1]
|
||||
ri, ci = self._view_delta_to_index(dx, dz)
|
||||
if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
|
||||
best = min(best, float(dist[ri, ci]))
|
||||
return best
|
||||
|
||||
def _charger_move_distance(self, gx, gz):
|
||||
"""Use visible BFS to the charger when available, otherwise Chebyshev distance."""
|
||||
path_dist = self._local_path_dist_to_charger(gx, gz)
|
||||
if path_dist < self.INF_DIST:
|
||||
return path_dist
|
||||
return self._min_charger_range_dist(gx, gz)
|
||||
|
||||
def get_legal_action(self):
|
||||
"""Return legal action mask (8D list).
|
||||
@@ -569,11 +642,8 @@ class Preprocessor:
|
||||
|
||||
def _is_visible_cell_passable(self, dx, dz):
|
||||
"""Whether a relative 21x21-view cell is passable."""
|
||||
ri = self.VIEW_HALF + dx
|
||||
ci = self.VIEW_HALF + dz
|
||||
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
||||
return True
|
||||
return int(self._view_map[ri, ci]) != 0
|
||||
cell = self._view_cell(dx, dz, default=None)
|
||||
return True if cell is None else cell != 0
|
||||
|
||||
def _filter_npc_danger_actions(self, legal_action):
|
||||
"""Avoid actions that would enter any NPC 3x3 danger zone."""
|
||||
@@ -608,42 +678,43 @@ class Preprocessor:
|
||||
return list(legal_action)
|
||||
|
||||
hx, hz = self.cur_pos
|
||||
current_dist = self._min_charger_range_dist(hx, hz)
|
||||
current_range_dist = self._min_charger_range_dist(hx, hz)
|
||||
current_move_dist = self._charger_move_distance(hx, hz)
|
||||
scored = []
|
||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||
if legal_action[action] <= 0:
|
||||
continue
|
||||
nx, nz = hx + dx, hz + dz
|
||||
next_dist = self._min_charger_range_dist(nx, nz)
|
||||
next_dist = self._charger_move_distance(nx, nz)
|
||||
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
|
||||
scored.append((next_dist, alignment, action))
|
||||
next_range_dist = self._min_charger_range_dist(nx, nz)
|
||||
scored.append((next_dist, alignment, next_range_dist, action))
|
||||
|
||||
if not scored:
|
||||
return list(legal_action)
|
||||
|
||||
# When already inside the charger range, stay inside until recharge mode exits.
|
||||
# 已经在充电区域内时,回充模式退出前不要离开充电区域。
|
||||
if current_dist <= 0.0:
|
||||
if current_range_dist <= 0.0:
|
||||
stay = [0] * 8
|
||||
for next_dist, _, action in scored:
|
||||
if next_dist <= 0.0:
|
||||
for _, _, next_range_dist, action in scored:
|
||||
if next_range_dist <= 0.0:
|
||||
stay[action] = 1
|
||||
if any(stay):
|
||||
return stay
|
||||
|
||||
best_next_dist = min(item[0] for item in scored)
|
||||
best_alignment = max(alignment for next_dist, alignment, _ in scored if next_dist <= best_next_dist + 0.1)
|
||||
|
||||
recharge = [0] * 8
|
||||
for next_dist, alignment, action in scored:
|
||||
if next_dist <= best_next_dist + 0.1 and alignment >= best_alignment - 0.1:
|
||||
best_next_dist = min(item[0] for item in scored)
|
||||
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
|
||||
for next_dist, _, _, action in ranked:
|
||||
if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1:
|
||||
recharge[action] = 1
|
||||
if sum(recharge) >= 3:
|
||||
break
|
||||
|
||||
if not any(recharge):
|
||||
best_alignment = max(item[1] for item in scored)
|
||||
for _, alignment, action in scored:
|
||||
if alignment >= best_alignment - 0.1:
|
||||
recharge[action] = 1
|
||||
for _, _, _, action in ranked[: min(3, len(ranked))]:
|
||||
recharge[action] = 1
|
||||
|
||||
return recharge if any(recharge) else list(legal_action)
|
||||
|
||||
@@ -733,7 +804,7 @@ class Preprocessor:
|
||||
|
||||
if self.has_charger and (self.recharge_mode or self.low_battery):
|
||||
dist_delta = float(
|
||||
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0)
|
||||
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
|
||||
)
|
||||
approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04
|
||||
retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Simple MLP policy network for Robot Vacuum.
|
||||
CNN + MLP policy network for Robot Vacuum.
|
||||
清扫大作战策略网络。
|
||||
"""
|
||||
|
||||
@@ -28,9 +28,9 @@ def _make_fc(in_dim, out_dim, gain=1.41421):
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""Dual-head MLP for Robot Vacuum.
|
||||
"""Dual-head CNN+MLP actor-critic for Robot Vacuum.
|
||||
|
||||
清扫大作战双头 MLP 策略网络。
|
||||
清扫大作战双头 CNN+MLP 策略网络。
|
||||
"""
|
||||
|
||||
def __init__(self, device=None):
|
||||
@@ -38,12 +38,36 @@ class Model(nn.Module):
|
||||
self.model_name = "robot_vacuum"
|
||||
self.device = device
|
||||
|
||||
obs_dim = Config.DIM_OF_OBSERVATION # 157
|
||||
map_dim, scalar_dim, last_action_dim = Config.FEATURES
|
||||
map_size = int(map_dim**0.5)
|
||||
if map_size * map_size != map_dim:
|
||||
raise ValueError(f"local map feature must be square, got {map_dim}")
|
||||
self.map_size = map_size
|
||||
self.map_dim = map_dim
|
||||
self.scalar_dim = scalar_dim + last_action_dim
|
||||
act_num = Config.ACTION_NUM # 8
|
||||
|
||||
# Shared backbone / 共享骨干网络
|
||||
# Local map encoder keeps spatial obstacle/dirt patterns.
|
||||
# 局部地图编码器保留障碍/污渍空间结构。
|
||||
self.map_encoder = nn.Sequential(
|
||||
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d((3, 3)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
|
||||
# Scalar encoder for battery, charger, NPC and last-action features.
|
||||
# 标量编码器处理电量、充电桩、NPC、上一步动作等特征。
|
||||
self.scalar_encoder = nn.Sequential(
|
||||
_make_fc(self.scalar_dim, 64),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
# Shared fusion backbone / 共享融合骨干网络
|
||||
self.backbone = nn.Sequential(
|
||||
_make_fc(obs_dim, 256),
|
||||
_make_fc(32 * 3 * 3 + 64, 256),
|
||||
nn.ReLU(),
|
||||
_make_fc(256, 128),
|
||||
nn.ReLU(),
|
||||
@@ -61,7 +85,11 @@ class Model(nn.Module):
|
||||
前向传播。
|
||||
"""
|
||||
x = s.to(torch.float32)
|
||||
h = self.backbone(x)
|
||||
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
|
||||
scalar = x[:, self.map_dim :]
|
||||
map_h = self.map_encoder(local_map)
|
||||
scalar_h = self.scalar_encoder(scalar)
|
||||
h = self.backbone(torch.cat([map_h, scalar_h], dim=1))
|
||||
logits = self.actor_head(h)
|
||||
value = self.critic_head(h)
|
||||
return [logits, value]
|
||||
|
||||
Reference in New Issue
Block a user