优化 PPO 清扫策略
This commit is contained in:
@@ -192,18 +192,31 @@ class Agent(BaseAgent):
|
|||||||
|
|
||||||
合法动作掩码下的 softmax。
|
合法动作掩码下的 softmax。
|
||||||
"""
|
"""
|
||||||
_w, _e = 1e20, 1e-5
|
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||||||
tmp = logits - _w * (1.0 - legal_action)
|
if not np.any(legal):
|
||||||
tmp_max = np.max(tmp, keepdims=True)
|
legal = np.ones(Config.ACTION_NUM, dtype=bool)
|
||||||
tmp = np.clip(tmp - tmp_max, -_w, 1)
|
|
||||||
tmp = (np.exp(tmp) + _e) * legal_action
|
masked_logits = np.asarray(logits, dtype=np.float32).copy()
|
||||||
return tmp / (np.sum(tmp, keepdims=True) * 1.00001)
|
masked_logits[~legal] = -1e9
|
||||||
|
shifted = masked_logits - np.max(masked_logits)
|
||||||
|
probs = np.exp(shifted) * legal.astype(np.float32)
|
||||||
|
prob_sum = float(np.sum(probs))
|
||||||
|
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
|
||||||
|
probs = legal.astype(np.float32)
|
||||||
|
prob_sum = float(np.sum(probs))
|
||||||
|
return probs / prob_sum
|
||||||
|
|
||||||
def _legal_sample(self, probs, use_max=False):
|
def _legal_sample(self, probs, use_max=False):
|
||||||
"""Sample action from probability distribution (argmax if use_max=True).
|
"""Sample action from probability distribution (argmax if use_max=True).
|
||||||
|
|
||||||
按概率分布采样动作(use_max=True 时取 argmax)。
|
按概率分布采样动作(use_max=True 时取 argmax)。
|
||||||
"""
|
"""
|
||||||
|
probs = np.asarray(probs, dtype=np.float64)
|
||||||
|
prob_sum = float(np.sum(probs))
|
||||||
|
if prob_sum <= 0.0 or not np.isfinite(prob_sum):
|
||||||
|
probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM
|
||||||
|
else:
|
||||||
|
probs = probs / prob_sum
|
||||||
if use_max:
|
if use_max:
|
||||||
return int(np.argmax(probs))
|
return int(np.argmax(probs))
|
||||||
return int(np.argmax(np.random.multinomial(1, probs, size=1)))
|
return int(np.random.choice(len(probs), p=probs))
|
||||||
|
|||||||
@@ -43,14 +43,14 @@ class Algorithm:
|
|||||||
|
|
||||||
训练入口:接收一批 SampleData,执行一步梯度更新。
|
训练入口:接收一批 SampleData,执行一步梯度更新。
|
||||||
"""
|
"""
|
||||||
obs = torch.stack([s.obs for s in list_sample_data]).to(self.device)
|
obs = self._batch_tensor([s.obs for s in list_sample_data])
|
||||||
legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device)
|
legal_action = self._batch_tensor([s.legal_action for s in list_sample_data])
|
||||||
act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1)
|
act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1)
|
||||||
old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device)
|
old_prob = self._batch_tensor([s.prob for s in list_sample_data])
|
||||||
old_value = torch.stack([s.value for s in list_sample_data]).to(self.device)
|
old_value = self._batch_tensor([s.value for s in list_sample_data])
|
||||||
reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device)
|
reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data])
|
||||||
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
|
advantage = self._batch_tensor([s.advantage for s in list_sample_data])
|
||||||
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
|
reward = self._batch_tensor([s.reward for s in list_sample_data])
|
||||||
|
|
||||||
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
|
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
|
||||||
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
|
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
|
||||||
@@ -194,9 +194,22 @@ class Algorithm:
|
|||||||
对 logits 应用合法动作掩码后计算 softmax。
|
对 logits 应用合法动作掩码后计算 softmax。
|
||||||
"""
|
"""
|
||||||
legal_mask = legal_action > 0.5
|
legal_mask = legal_action > 0.5
|
||||||
|
all_illegal = ~legal_mask.any(dim=1, keepdim=True)
|
||||||
|
legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask)
|
||||||
safe_logits = logits.masked_fill(~legal_mask, -1e9)
|
safe_logits = logits.masked_fill(~legal_mask, -1e9)
|
||||||
return F.softmax(safe_logits, dim=1)
|
return F.softmax(safe_logits, dim=1)
|
||||||
|
|
||||||
|
def _batch_tensor(self, values):
|
||||||
|
"""Stack framework tensors or raw numpy/list values into a float tensor."""
|
||||||
|
tensors = []
|
||||||
|
for value in values:
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
tensor = value.to(self.device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
|
||||||
|
tensors.append(tensor)
|
||||||
|
return torch.stack(tensors)
|
||||||
|
|
||||||
def _entropy_beta(self):
|
def _entropy_beta(self):
|
||||||
"""Linearly decay entropy regularization for fast early exploration."""
|
"""Linearly decay entropy regularization for fast early exploration."""
|
||||||
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
|
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
|||||||
# Whether to randomly select maps. Boolean.
|
# Whether to randomly select maps. Boolean.
|
||||||
# true = randomly pick one from configured maps per episode, false = used sequentially.
|
# true = randomly pick one from configured maps per episode, false = used sequentially.
|
||||||
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张,false表示按顺序抽取地图训练。
|
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张,false表示按顺序抽取地图训练。
|
||||||
map_random = false
|
map_random = true
|
||||||
|
|
||||||
# Number of official robots. Range: 1~4 (integer).
|
# Number of official robots. Range: 1~4 (integer).
|
||||||
# In each round, official robots will be randomly generated on the road according to the configured.
|
# In each round, official robots will be randomly generated on the road according to the configured.
|
||||||
@@ -23,4 +23,4 @@ max_step = 1000
|
|||||||
|
|
||||||
# Maximum battery. The battery level when fully charged. Range: 100~999.
|
# Maximum battery. The battery level when fully charged. Range: 100~999.
|
||||||
# 最大电量。满电状态下的电量。可配置范围100~999。
|
# 最大电量。满电状态下的电量。可配置范围100~999。
|
||||||
battery_max = 200
|
battery_max = 200
|
||||||
|
|||||||
@@ -67,7 +67,19 @@ def _calc_gae(list_sample_data):
|
|||||||
gamma = Config.GAMMA
|
gamma = Config.GAMMA
|
||||||
lamda = Config.LAMDA
|
lamda = Config.LAMDA
|
||||||
for sample in reversed(list_sample_data):
|
for sample in reversed(list_sample_data):
|
||||||
delta = -sample.value + sample.reward + gamma * sample.next_value
|
value = _scalar(sample.value)
|
||||||
gae = gae * gamma * lamda + delta
|
reward = _scalar(sample.reward)
|
||||||
|
next_value = _scalar(sample.next_value)
|
||||||
|
nonterminal = 1.0 - _scalar(sample.done)
|
||||||
|
delta = reward + gamma * next_value * nonterminal - value
|
||||||
|
gae = delta + gamma * lamda * nonterminal * gae
|
||||||
sample.advantage = gae
|
sample.advantage = gae
|
||||||
sample.reward_sum = gae + sample.value
|
sample.reward_sum = gae + value
|
||||||
|
|
||||||
|
|
||||||
|
def _scalar(value):
|
||||||
|
"""Read the first scalar from numpy/tensor/list values."""
|
||||||
|
if hasattr(value, "detach"):
|
||||||
|
value = value.detach().cpu().numpy()
|
||||||
|
arr = np.asarray(value, dtype=np.float32).reshape(-1)
|
||||||
|
return float(arr[0]) if arr.size else 0.0
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ Feature preprocessor for Robot Vacuum.
|
|||||||
清扫大作战特征预处理器。
|
清扫大作战特征预处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@@ -55,6 +57,7 @@ class Preprocessor:
|
|||||||
(0, 1),
|
(0, 1),
|
||||||
(1, 1),
|
(1, 1),
|
||||||
)
|
)
|
||||||
|
INF_DIST = 1e6
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -83,12 +86,12 @@ class Preprocessor:
|
|||||||
self.step_cleaned_count = 0
|
self.step_cleaned_count = 0
|
||||||
self.max_step = 1000
|
self.max_step = 1000
|
||||||
|
|
||||||
# Global passable map (0=obstacle, 1=passable), used for ray computation
|
# Global passable map (0=obstacle, 1=passable), indexed by [x, z].
|
||||||
# 维护全局通行地图(0=障碍, 1=可通行),用于射线计算
|
# 维护全局通行地图(0=障碍, 1=可通行),索引为 [x, z]。
|
||||||
self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
|
||||||
|
|
||||||
# Nearest dirt distance
|
# Nearest dirt path distance in the current local view.
|
||||||
# 最近污渍距离
|
# 当前局部视野内最近污渍路径距离。
|
||||||
self.nearest_dirt_dist = 200.0
|
self.nearest_dirt_dist = 200.0
|
||||||
self.last_nearest_dirt_dist = 200.0
|
self.last_nearest_dirt_dist = 200.0
|
||||||
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
|
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
|
||||||
@@ -114,6 +117,8 @@ class Preprocessor:
|
|||||||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||||
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
|
self.last_nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||||
|
self.nearest_charger_path_dist = float(self.GRID_SIZE)
|
||||||
|
self.last_nearest_charger_path_dist = float(self.GRID_SIZE)
|
||||||
self.charger_energy_cost = float(self.GRID_SIZE)
|
self.charger_energy_cost = float(self.GRID_SIZE)
|
||||||
self.charger_safety_buffer = 0.0
|
self.charger_safety_buffer = 0.0
|
||||||
self.charger_safety_margin = 0.0
|
self.charger_safety_margin = 0.0
|
||||||
@@ -226,13 +231,33 @@ class Preprocessor:
|
|||||||
|
|
||||||
for ri in range(vsize):
|
for ri in range(vsize):
|
||||||
for ci in range(vsize):
|
for ci in range(vsize):
|
||||||
gx = hx - half + ri
|
gx = hx + ci - half
|
||||||
gz = hz - half + ci
|
gz = hz + ri - half
|
||||||
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
|
if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE:
|
||||||
# 0 = obstacle, 1/2 = passable
|
# 0 = obstacle, 1/2 = passable
|
||||||
# 0 = 障碍, 1/2 = 可通行
|
# 0 = 障碍, 1/2 = 可通行
|
||||||
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0
|
||||||
|
|
||||||
|
def _view_index_to_global(self, ri, ci):
|
||||||
|
"""Convert local view row/col to global x/z coordinates."""
|
||||||
|
half = self._view_map.shape[0] // 2
|
||||||
|
hx, hz = self.cur_pos
|
||||||
|
return hx + ci - half, hz + ri - half
|
||||||
|
|
||||||
|
def _view_delta_to_index(self, dx, dz):
|
||||||
|
"""Convert global-coordinate dx/dz to local view row/col."""
|
||||||
|
return self.VIEW_HALF + dz, self.VIEW_HALF + dx
|
||||||
|
|
||||||
|
def _view_cell(self, dx, dz, default=None):
|
||||||
|
"""Read a local-view cell by global-coordinate delta.
|
||||||
|
|
||||||
|
`map_info` is row-major: row follows z/down, col follows x/right.
|
||||||
|
"""
|
||||||
|
ri, ci = self._view_delta_to_index(dx, dz)
|
||||||
|
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
||||||
|
return default
|
||||||
|
return int(self._view_map[ri, ci])
|
||||||
|
|
||||||
def _update_local_map_stats(self):
|
def _update_local_map_stats(self):
|
||||||
"""Cache coarse 21x21 map statistics."""
|
"""Cache coarse 21x21 map statistics."""
|
||||||
view = self._view_map
|
view = self._view_map
|
||||||
@@ -247,6 +272,7 @@ class Preprocessor:
|
|||||||
def _update_charger_state(self, hx, hz, organs):
|
def _update_charger_state(self, hx, hz, organs):
|
||||||
"""Find nearest charger and cache distance/direction features."""
|
"""Find nearest charger and cache distance/direction features."""
|
||||||
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
self.last_nearest_charger_range_dist = self.nearest_charger_range_dist
|
||||||
|
self.last_nearest_charger_path_dist = self.nearest_charger_path_dist
|
||||||
self.has_charger = False
|
self.has_charger = False
|
||||||
self.on_charger = False
|
self.on_charger = False
|
||||||
self.nearest_charger_dx = 0.0
|
self.nearest_charger_dx = 0.0
|
||||||
@@ -255,6 +281,7 @@ class Preprocessor:
|
|||||||
self.nearest_charger_center_dz = 0.0
|
self.nearest_charger_center_dz = 0.0
|
||||||
self.nearest_charger_dist = float(self.GRID_SIZE)
|
self.nearest_charger_dist = float(self.GRID_SIZE)
|
||||||
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
self.nearest_charger_range_dist = float(self.GRID_SIZE)
|
||||||
|
self.nearest_charger_path_dist = float(self.GRID_SIZE)
|
||||||
self.charger_energy_cost = float(self.GRID_SIZE)
|
self.charger_energy_cost = float(self.GRID_SIZE)
|
||||||
self.charger_safety_buffer = 0.0
|
self.charger_safety_buffer = 0.0
|
||||||
self.charger_safety_margin = 0.0
|
self.charger_safety_margin = 0.0
|
||||||
@@ -295,9 +322,11 @@ class Preprocessor:
|
|||||||
self.nearest_charger_center_dz = float(center_dz)
|
self.nearest_charger_center_dz = float(center_dz)
|
||||||
self.nearest_charger_dist = float(dist)
|
self.nearest_charger_dist = float(dist)
|
||||||
self.nearest_charger_range_dist = float(range_dist)
|
self.nearest_charger_range_dist = float(range_dist)
|
||||||
self.charger_energy_cost = float(range_dist)
|
path_dist = self._local_path_dist_to_charger(hx, hz)
|
||||||
|
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
|
||||||
|
self.charger_energy_cost = self.nearest_charger_path_dist
|
||||||
self.on_charger = range_dist <= 0.0
|
self.on_charger = range_dist <= 0.0
|
||||||
self.battery_margin = float(self.battery) - self.nearest_charger_range_dist
|
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
||||||
|
|
||||||
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
|
def _relative_vector_to_rect(self, x, z, rx, rz, w, h):
|
||||||
"""Relative vector from point to the nearest cell in a rectangle."""
|
"""Relative vector from point to the nearest cell in a rectangle."""
|
||||||
@@ -355,7 +384,7 @@ class Preprocessor:
|
|||||||
self.charger_safety_margin = float(self.battery)
|
self.charger_safety_margin = float(self.battery)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.charger_energy_cost = float(max(self.nearest_charger_range_dist, 0.0))
|
self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0))
|
||||||
self.charger_safety_buffer = self._charger_safety_buffer()
|
self.charger_safety_buffer = self._charger_safety_buffer()
|
||||||
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
|
self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer
|
||||||
|
|
||||||
@@ -450,30 +479,20 @@ class Preprocessor:
|
|||||||
ray_dirt = []
|
ray_dirt = []
|
||||||
max_ray = 30
|
max_ray = 30
|
||||||
for dx, dz in ray_dirs:
|
for dx, dz in ray_dirs:
|
||||||
x, z = hx, hz
|
|
||||||
found = max_ray
|
found = max_ray
|
||||||
for step in range(1, max_ray + 1):
|
for step in range(1, max_ray + 1):
|
||||||
x += dx
|
gx = hx + dx * step
|
||||||
z += dz
|
gz = hz + dz * step
|
||||||
if not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
|
if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE):
|
||||||
|
break
|
||||||
|
cell = self._view_cell(dx * step, dz * step, default=0)
|
||||||
|
if cell == 2:
|
||||||
|
found = step
|
||||||
break
|
break
|
||||||
if self._view_map is not None:
|
|
||||||
cell = (
|
|
||||||
int(
|
|
||||||
self._view_map[
|
|
||||||
np.clip(x - (hx - self.VIEW_HALF), 0, 20), np.clip(z - (hz - self.VIEW_HALF), 0, 20)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
if (0 <= x - hx + self.VIEW_HALF < 21 and 0 <= z - hz + self.VIEW_HALF < 21)
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
if cell == 2:
|
|
||||||
found = step
|
|
||||||
break
|
|
||||||
ray_dirt.append(_norm(found, max_ray))
|
ray_dirt.append(_norm(found, max_ray))
|
||||||
|
|
||||||
# Nearest dirt Euclidean distance (estimated from 7×7 crop)
|
# Nearest dirt path distance in the visible map.
|
||||||
# 最近污渍欧氏距离(视野内 7×7 粗估)
|
# 视野内最近污渍路径距离。
|
||||||
self.last_nearest_dirt_dist = self.nearest_dirt_dist
|
self.last_nearest_dirt_dist = self.nearest_dirt_dist
|
||||||
self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
|
self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
|
||||||
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
|
nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
|
||||||
@@ -500,7 +519,7 @@ class Preprocessor:
|
|||||||
dirt_delta,
|
dirt_delta,
|
||||||
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
|
_signed_norm(self.nearest_charger_dx, self.GRID_SIZE),
|
||||||
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
|
_signed_norm(self.nearest_charger_dz, self.GRID_SIZE),
|
||||||
_norm(self.nearest_charger_range_dist, self.GRID_SIZE),
|
_norm(self.nearest_charger_path_dist, self.GRID_SIZE),
|
||||||
battery_margin_norm,
|
battery_margin_norm,
|
||||||
1.0 if self.low_battery else 0.0,
|
1.0 if self.low_battery else 0.0,
|
||||||
1.0 if self.recharge_mode else 0.0,
|
1.0 if self.recharge_mode else 0.0,
|
||||||
@@ -519,19 +538,73 @@ class Preprocessor:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _calc_nearest_dirt_dist(self):
|
def _calc_nearest_dirt_dist(self):
|
||||||
"""Find nearest dirt Euclidean distance from local view.
|
"""Find nearest dirt path distance from local view.
|
||||||
|
|
||||||
从局部视野中找最近污渍的欧氏距离。
|
从局部视野中找最近污渍路径距离。
|
||||||
"""
|
"""
|
||||||
view = self._view_map
|
dist = self._local_bfs_distances()
|
||||||
if view is None:
|
dirt_coords = np.argwhere(self._view_map == 2)
|
||||||
return 200.0
|
|
||||||
dirt_coords = np.argwhere(view == 2)
|
|
||||||
if len(dirt_coords) == 0:
|
if len(dirt_coords) == 0:
|
||||||
return 200.0
|
return 200.0
|
||||||
center = self.VIEW_HALF
|
best = min(float(dist[ri, ci]) for ri, ci in dirt_coords)
|
||||||
dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2)
|
if best >= self.INF_DIST:
|
||||||
return float(np.min(dists))
|
return 200.0
|
||||||
|
return best
|
||||||
|
|
||||||
|
def _local_bfs_distances(self, start_dx=0, start_dz=0):
|
||||||
|
"""Shortest path distances inside the current 21x21 local view."""
|
||||||
|
view = self._view_map
|
||||||
|
shape = view.shape
|
||||||
|
dist = np.full(shape, self.INF_DIST, dtype=np.float32)
|
||||||
|
start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz)
|
||||||
|
if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]):
|
||||||
|
return dist
|
||||||
|
if int(view[start_ri, start_ci]) == 0:
|
||||||
|
return dist
|
||||||
|
|
||||||
|
dist[start_ri, start_ci] = 0.0
|
||||||
|
queue = deque([(start_ri, start_ci)])
|
||||||
|
while queue:
|
||||||
|
ri, ci = queue.popleft()
|
||||||
|
base = dist[ri, ci]
|
||||||
|
for dx, dz in self.ACTION_DIRS:
|
||||||
|
nri = ri + dz
|
||||||
|
nci = ci + dx
|
||||||
|
if not (0 <= nri < shape[0] and 0 <= nci < shape[1]):
|
||||||
|
continue
|
||||||
|
if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST:
|
||||||
|
continue
|
||||||
|
if dx != 0 and dz != 0:
|
||||||
|
side_a = int(view[ri, nci]) != 0
|
||||||
|
side_b = int(view[nri, ci]) != 0
|
||||||
|
if not (side_a or side_b):
|
||||||
|
continue
|
||||||
|
dist[nri, nci] = base + 1.0
|
||||||
|
queue.append((nri, nci))
|
||||||
|
return dist
|
||||||
|
|
||||||
|
def _local_path_dist_to_charger(self, gx, gz):
|
||||||
|
"""Visible-map BFS distance from global x/z to nearest charger cell."""
|
||||||
|
best = self.INF_DIST
|
||||||
|
start_dx = gx - self.cur_pos[0]
|
||||||
|
start_dz = gz - self.cur_pos[1]
|
||||||
|
dist = self._local_bfs_distances(start_dx, start_dz)
|
||||||
|
for rx, rz, w, h in self.charger_rects:
|
||||||
|
for tx in range(rx, rx + w):
|
||||||
|
for tz in range(rz, rz + h):
|
||||||
|
dx = tx - self.cur_pos[0]
|
||||||
|
dz = tz - self.cur_pos[1]
|
||||||
|
ri, ci = self._view_delta_to_index(dx, dz)
|
||||||
|
if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]:
|
||||||
|
best = min(best, float(dist[ri, ci]))
|
||||||
|
return best
|
||||||
|
|
||||||
|
def _charger_move_distance(self, gx, gz):
|
||||||
|
"""Use visible BFS to the charger when available, otherwise Chebyshev distance."""
|
||||||
|
path_dist = self._local_path_dist_to_charger(gx, gz)
|
||||||
|
if path_dist < self.INF_DIST:
|
||||||
|
return path_dist
|
||||||
|
return self._min_charger_range_dist(gx, gz)
|
||||||
|
|
||||||
def get_legal_action(self):
|
def get_legal_action(self):
|
||||||
"""Return legal action mask (8D list).
|
"""Return legal action mask (8D list).
|
||||||
@@ -569,11 +642,8 @@ class Preprocessor:
|
|||||||
|
|
||||||
def _is_visible_cell_passable(self, dx, dz):
|
def _is_visible_cell_passable(self, dx, dz):
|
||||||
"""Whether a relative 21x21-view cell is passable."""
|
"""Whether a relative 21x21-view cell is passable."""
|
||||||
ri = self.VIEW_HALF + dx
|
cell = self._view_cell(dx, dz, default=None)
|
||||||
ci = self.VIEW_HALF + dz
|
return True if cell is None else cell != 0
|
||||||
if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]):
|
|
||||||
return True
|
|
||||||
return int(self._view_map[ri, ci]) != 0
|
|
||||||
|
|
||||||
def _filter_npc_danger_actions(self, legal_action):
|
def _filter_npc_danger_actions(self, legal_action):
|
||||||
"""Avoid actions that would enter any NPC 3x3 danger zone."""
|
"""Avoid actions that would enter any NPC 3x3 danger zone."""
|
||||||
@@ -608,42 +678,43 @@ class Preprocessor:
|
|||||||
return list(legal_action)
|
return list(legal_action)
|
||||||
|
|
||||||
hx, hz = self.cur_pos
|
hx, hz = self.cur_pos
|
||||||
current_dist = self._min_charger_range_dist(hx, hz)
|
current_range_dist = self._min_charger_range_dist(hx, hz)
|
||||||
|
current_move_dist = self._charger_move_distance(hx, hz)
|
||||||
scored = []
|
scored = []
|
||||||
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
for action, (dx, dz) in enumerate(self.ACTION_DIRS):
|
||||||
if legal_action[action] <= 0:
|
if legal_action[action] <= 0:
|
||||||
continue
|
continue
|
||||||
nx, nz = hx + dx, hz + dz
|
nx, nz = hx + dx, hz + dz
|
||||||
next_dist = self._min_charger_range_dist(nx, nz)
|
next_dist = self._charger_move_distance(nx, nz)
|
||||||
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
|
alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz
|
||||||
scored.append((next_dist, alignment, action))
|
next_range_dist = self._min_charger_range_dist(nx, nz)
|
||||||
|
scored.append((next_dist, alignment, next_range_dist, action))
|
||||||
|
|
||||||
if not scored:
|
if not scored:
|
||||||
return list(legal_action)
|
return list(legal_action)
|
||||||
|
|
||||||
# When already inside the charger range, stay inside until recharge mode exits.
|
# When already inside the charger range, stay inside until recharge mode exits.
|
||||||
# 已经在充电区域内时,回充模式退出前不要离开充电区域。
|
# 已经在充电区域内时,回充模式退出前不要离开充电区域。
|
||||||
if current_dist <= 0.0:
|
if current_range_dist <= 0.0:
|
||||||
stay = [0] * 8
|
stay = [0] * 8
|
||||||
for next_dist, _, action in scored:
|
for _, _, next_range_dist, action in scored:
|
||||||
if next_dist <= 0.0:
|
if next_range_dist <= 0.0:
|
||||||
stay[action] = 1
|
stay[action] = 1
|
||||||
if any(stay):
|
if any(stay):
|
||||||
return stay
|
return stay
|
||||||
|
|
||||||
best_next_dist = min(item[0] for item in scored)
|
|
||||||
best_alignment = max(alignment for next_dist, alignment, _ in scored if next_dist <= best_next_dist + 0.1)
|
|
||||||
|
|
||||||
recharge = [0] * 8
|
recharge = [0] * 8
|
||||||
for next_dist, alignment, action in scored:
|
best_next_dist = min(item[0] for item in scored)
|
||||||
if next_dist <= best_next_dist + 0.1 and alignment >= best_alignment - 0.1:
|
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
|
||||||
|
for next_dist, _, _, action in ranked:
|
||||||
|
if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1:
|
||||||
recharge[action] = 1
|
recharge[action] = 1
|
||||||
|
if sum(recharge) >= 3:
|
||||||
|
break
|
||||||
|
|
||||||
if not any(recharge):
|
if not any(recharge):
|
||||||
best_alignment = max(item[1] for item in scored)
|
for _, _, _, action in ranked[: min(3, len(ranked))]:
|
||||||
for _, alignment, action in scored:
|
recharge[action] = 1
|
||||||
if alignment >= best_alignment - 0.1:
|
|
||||||
recharge[action] = 1
|
|
||||||
|
|
||||||
return recharge if any(recharge) else list(legal_action)
|
return recharge if any(recharge) else list(legal_action)
|
||||||
|
|
||||||
@@ -733,7 +804,7 @@ class Preprocessor:
|
|||||||
|
|
||||||
if self.has_charger and (self.recharge_mode or self.low_battery):
|
if self.has_charger and (self.recharge_mode or self.low_battery):
|
||||||
dist_delta = float(
|
dist_delta = float(
|
||||||
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0)
|
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
|
||||||
)
|
)
|
||||||
approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04
|
approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04
|
||||||
retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02
|
retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
"""
|
"""
|
||||||
Author: Tencent AI Arena Authors
|
Author: Tencent AI Arena Authors
|
||||||
|
|
||||||
Simple MLP policy network for Robot Vacuum.
|
CNN + MLP policy network for Robot Vacuum.
|
||||||
清扫大作战策略网络。
|
清扫大作战策略网络。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -28,9 +28,9 @@ def _make_fc(in_dim, out_dim, gain=1.41421):
|
|||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
"""Dual-head MLP for Robot Vacuum.
|
"""Dual-head CNN+MLP actor-critic for Robot Vacuum.
|
||||||
|
|
||||||
清扫大作战双头 MLP 策略网络。
|
清扫大作战双头 CNN+MLP 策略网络。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device=None):
|
def __init__(self, device=None):
|
||||||
@@ -38,12 +38,36 @@ class Model(nn.Module):
|
|||||||
self.model_name = "robot_vacuum"
|
self.model_name = "robot_vacuum"
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
obs_dim = Config.DIM_OF_OBSERVATION # 157
|
map_dim, scalar_dim, last_action_dim = Config.FEATURES
|
||||||
|
map_size = int(map_dim**0.5)
|
||||||
|
if map_size * map_size != map_dim:
|
||||||
|
raise ValueError(f"local map feature must be square, got {map_dim}")
|
||||||
|
self.map_size = map_size
|
||||||
|
self.map_dim = map_dim
|
||||||
|
self.scalar_dim = scalar_dim + last_action_dim
|
||||||
act_num = Config.ACTION_NUM # 8
|
act_num = Config.ACTION_NUM # 8
|
||||||
|
|
||||||
# Shared backbone / 共享骨干网络
|
# Local map encoder keeps spatial obstacle/dirt patterns.
|
||||||
|
# 局部地图编码器保留障碍/污渍空间结构。
|
||||||
|
self.map_encoder = nn.Sequential(
|
||||||
|
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.AdaptiveAvgPool2d((3, 3)),
|
||||||
|
nn.Flatten(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scalar encoder for battery, charger, NPC and last-action features.
|
||||||
|
# 标量编码器处理电量、充电桩、NPC、上一步动作等特征。
|
||||||
|
self.scalar_encoder = nn.Sequential(
|
||||||
|
_make_fc(self.scalar_dim, 64),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shared fusion backbone / 共享融合骨干网络
|
||||||
self.backbone = nn.Sequential(
|
self.backbone = nn.Sequential(
|
||||||
_make_fc(obs_dim, 256),
|
_make_fc(32 * 3 * 3 + 64, 256),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
_make_fc(256, 128),
|
_make_fc(256, 128),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
@@ -61,7 +85,11 @@ class Model(nn.Module):
|
|||||||
前向传播。
|
前向传播。
|
||||||
"""
|
"""
|
||||||
x = s.to(torch.float32)
|
x = s.to(torch.float32)
|
||||||
h = self.backbone(x)
|
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
|
||||||
|
scalar = x[:, self.map_dim :]
|
||||||
|
map_h = self.map_encoder(local_map)
|
||||||
|
scalar_h = self.scalar_encoder(scalar)
|
||||||
|
h = self.backbone(torch.cat([map_h, scalar_h], dim=1))
|
||||||
logits = self.actor_head(h)
|
logits = self.actor_head(h)
|
||||||
value = self.critic_head(h)
|
value = self.critic_head(h)
|
||||||
return [logits, value]
|
return [logits, value]
|
||||||
|
|||||||
Reference in New Issue
Block a user