diff --git a/agent_ppo/agent.py b/agent_ppo/agent.py index 0509878..8cd84d2 100644 --- a/agent_ppo/agent.py +++ b/agent_ppo/agent.py @@ -192,18 +192,31 @@ class Agent(BaseAgent): 合法动作掩码下的 softmax。 """ - _w, _e = 1e20, 1e-5 - tmp = logits - _w * (1.0 - legal_action) - tmp_max = np.max(tmp, keepdims=True) - tmp = np.clip(tmp - tmp_max, -_w, 1) - tmp = (np.exp(tmp) + _e) * legal_action - return tmp / (np.sum(tmp, keepdims=True) * 1.00001) + legal = np.asarray(legal_action, dtype=np.float32) > 0.5 + if not np.any(legal): + legal = np.ones(Config.ACTION_NUM, dtype=bool) + + masked_logits = np.asarray(logits, dtype=np.float32).copy() + masked_logits[~legal] = -1e9 + shifted = masked_logits - np.max(masked_logits) + probs = np.exp(shifted) * legal.astype(np.float32) + prob_sum = float(np.sum(probs)) + if prob_sum <= 0.0 or not np.isfinite(prob_sum): + probs = legal.astype(np.float32) + prob_sum = float(np.sum(probs)) + return probs / prob_sum def _legal_sample(self, probs, use_max=False): """Sample action from probability distribution (argmax if use_max=True). 按概率分布采样动作(use_max=True 时取 argmax)。 """ + probs = np.asarray(probs, dtype=np.float64) + prob_sum = float(np.sum(probs)) + if prob_sum <= 0.0 or not np.isfinite(prob_sum): + probs = np.ones(Config.ACTION_NUM, dtype=np.float64) / Config.ACTION_NUM + else: + probs = probs / prob_sum if use_max: return int(np.argmax(probs)) - return int(np.argmax(np.random.multinomial(1, probs, size=1))) + return int(np.random.choice(len(probs), p=probs)) diff --git a/agent_ppo/algorithm/algorithm.py b/agent_ppo/algorithm/algorithm.py index c1d97cc..53cf183 100644 --- a/agent_ppo/algorithm/algorithm.py +++ b/agent_ppo/algorithm/algorithm.py @@ -43,14 +43,14 @@ class Algorithm: 训练入口:接收一批 SampleData,执行一步梯度更新。 """ - obs = torch.stack([s.obs for s in list_sample_data]).to(self.device) - legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device) - act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1) - old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device) - old_value = torch.stack([s.value for s in list_sample_data]).to(self.device) - reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device) - advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device) - reward = torch.stack([s.reward for s in list_sample_data]).to(self.device) + obs = self._batch_tensor([s.obs for s in list_sample_data]) + legal_action = self._batch_tensor([s.legal_action for s in list_sample_data]) + act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1) + old_prob = self._batch_tensor([s.prob for s in list_sample_data]) + old_value = self._batch_tensor([s.value for s in list_sample_data]) + reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data]) + advantage = self._batch_tensor([s.advantage for s in list_sample_data]) + reward = self._batch_tensor([s.reward for s in list_sample_data]) if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1: advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8) @@ -194,9 +194,22 @@ class Algorithm: 对 logits 应用合法动作掩码后计算 softmax。 """ legal_mask = legal_action > 0.5 + all_illegal = ~legal_mask.any(dim=1, keepdim=True) + legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask) safe_logits = logits.masked_fill(~legal_mask, -1e9) return F.softmax(safe_logits, dim=1) + def _batch_tensor(self, values): + """Stack framework tensors or raw numpy/list values into a float tensor.""" + tensors = [] + for value in values: + if isinstance(value, torch.Tensor): + tensor = value.to(self.device, dtype=torch.float32) + else: + tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device) + tensors.append(tensor) + return torch.stack(tensors) + def _entropy_beta(self): """Linearly decay entropy regularization for fast early exploration.""" progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0) diff --git a/agent_ppo/conf/train_env_conf.toml b/agent_ppo/conf/train_env_conf.toml index 48f815b..5fdf0ec 100644 --- a/agent_ppo/conf/train_env_conf.toml +++ b/agent_ppo/conf/train_env_conf.toml @@ -6,7 +6,7 @@ map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # Whether to randomly select maps. Boolean. # true = randomly pick one from configured maps per episode, false = used sequentially. # 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张,false表示按顺序抽取地图训练。 -map_random = false +map_random = true # Number of official robots. Range: 1~4 (integer). # In each round, official robots will be randomly generated on the road according to the configured. @@ -23,4 +23,4 @@ max_step = 1000 # Maximum battery. The battery level when fully charged. Range: 100~999. # 最大电量。满电状态下的电量。可配置范围100~999。 -battery_max = 200 \ No newline at end of file +battery_max = 200 diff --git a/agent_ppo/feature/definition.py b/agent_ppo/feature/definition.py index 84438fe..4f2e511 100644 --- a/agent_ppo/feature/definition.py +++ b/agent_ppo/feature/definition.py @@ -67,7 +67,19 @@ def _calc_gae(list_sample_data): gamma = Config.GAMMA lamda = Config.LAMDA for sample in reversed(list_sample_data): - delta = -sample.value + sample.reward + gamma * sample.next_value - gae = gae * gamma * lamda + delta + value = _scalar(sample.value) + reward = _scalar(sample.reward) + next_value = _scalar(sample.next_value) + nonterminal = 1.0 - _scalar(sample.done) + delta = reward + gamma * next_value * nonterminal - value + gae = delta + gamma * lamda * nonterminal * gae sample.advantage = gae - sample.reward_sum = gae + sample.value + sample.reward_sum = gae + value + + +def _scalar(value): + """Read the first scalar from numpy/tensor/list values.""" + if hasattr(value, "detach"): + value = value.detach().cpu().numpy() + arr = np.asarray(value, dtype=np.float32).reshape(-1) + return float(arr[0]) if arr.size else 0.0 diff --git a/agent_ppo/feature/preprocessor.py b/agent_ppo/feature/preprocessor.py index 1a59529..396b619 100644 --- a/agent_ppo/feature/preprocessor.py +++ b/agent_ppo/feature/preprocessor.py @@ -10,6 +10,8 @@ Feature preprocessor for Robot Vacuum. 清扫大作战特征预处理器。 """ +from collections import deque + import numpy as np @@ -55,6 +57,7 @@ class Preprocessor: (0, 1), (1, 1), ) + INF_DIST = 1e6 def __init__(self): self.reset() @@ -83,12 +86,12 @@ class Preprocessor: self.step_cleaned_count = 0 self.max_step = 1000 - # Global passable map (0=obstacle, 1=passable), used for ray computation - # 维护全局通行地图(0=障碍, 1=可通行),用于射线计算 + # Global passable map (0=obstacle, 1=passable), indexed by [x, z]. + # 维护全局通行地图(0=障碍, 1=可通行),索引为 [x, z]。 self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8) - # Nearest dirt distance - # 最近污渍距离 + # Nearest dirt path distance in the current local view. + # 当前局部视野内最近污渍路径距离。 self.nearest_dirt_dist = 200.0 self.last_nearest_dirt_dist = 200.0 self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16) @@ -114,6 +117,8 @@ class Preprocessor: self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE) self.last_nearest_charger_range_dist = float(self.GRID_SIZE) + self.nearest_charger_path_dist = float(self.GRID_SIZE) + self.last_nearest_charger_path_dist = float(self.GRID_SIZE) self.charger_energy_cost = float(self.GRID_SIZE) self.charger_safety_buffer = 0.0 self.charger_safety_margin = 0.0 @@ -226,13 +231,33 @@ class Preprocessor: for ri in range(vsize): for ci in range(vsize): - gx = hx - half + ri - gz = hz - half + ci + gx = hx + ci - half + gz = hz + ri - half if 0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE: # 0 = obstacle, 1/2 = passable # 0 = 障碍, 1/2 = 可通行 self.passable_map[gx, gz] = 1 if view[ri, ci] != 0 else 0 + def _view_index_to_global(self, ri, ci): + """Convert local view row/col to global x/z coordinates.""" + half = self._view_map.shape[0] // 2 + hx, hz = self.cur_pos + return hx + ci - half, hz + ri - half + + def _view_delta_to_index(self, dx, dz): + """Convert global-coordinate dx/dz to local view row/col.""" + return self.VIEW_HALF + dz, self.VIEW_HALF + dx + + def _view_cell(self, dx, dz, default=None): + """Read a local-view cell by global-coordinate delta. + + `map_info` is row-major: row follows z/down, col follows x/right. + """ + ri, ci = self._view_delta_to_index(dx, dz) + if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]): + return default + return int(self._view_map[ri, ci]) + def _update_local_map_stats(self): """Cache coarse 21x21 map statistics.""" view = self._view_map @@ -247,6 +272,7 @@ class Preprocessor: def _update_charger_state(self, hx, hz, organs): """Find nearest charger and cache distance/direction features.""" self.last_nearest_charger_range_dist = self.nearest_charger_range_dist + self.last_nearest_charger_path_dist = self.nearest_charger_path_dist self.has_charger = False self.on_charger = False self.nearest_charger_dx = 0.0 @@ -255,6 +281,7 @@ class Preprocessor: self.nearest_charger_center_dz = 0.0 self.nearest_charger_dist = float(self.GRID_SIZE) self.nearest_charger_range_dist = float(self.GRID_SIZE) + self.nearest_charger_path_dist = float(self.GRID_SIZE) self.charger_energy_cost = float(self.GRID_SIZE) self.charger_safety_buffer = 0.0 self.charger_safety_margin = 0.0 @@ -295,9 +322,11 @@ class Preprocessor: self.nearest_charger_center_dz = float(center_dz) self.nearest_charger_dist = float(dist) self.nearest_charger_range_dist = float(range_dist) - self.charger_energy_cost = float(range_dist) + path_dist = self._local_path_dist_to_charger(hx, hz) + self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist) + self.charger_energy_cost = self.nearest_charger_path_dist self.on_charger = range_dist <= 0.0 - self.battery_margin = float(self.battery) - self.nearest_charger_range_dist + self.battery_margin = float(self.battery) - self.nearest_charger_path_dist def _relative_vector_to_rect(self, x, z, rx, rz, w, h): """Relative vector from point to the nearest cell in a rectangle.""" @@ -355,7 +384,7 @@ class Preprocessor: self.charger_safety_margin = float(self.battery) return - self.charger_energy_cost = float(max(self.nearest_charger_range_dist, 0.0)) + self.charger_energy_cost = float(max(self.nearest_charger_path_dist, 0.0)) self.charger_safety_buffer = self._charger_safety_buffer() self.charger_safety_margin = float(self.battery) - self.charger_energy_cost - self.charger_safety_buffer @@ -450,30 +479,20 @@ class Preprocessor: ray_dirt = [] max_ray = 30 for dx, dz in ray_dirs: - x, z = hx, hz found = max_ray for step in range(1, max_ray + 1): - x += dx - z += dz - if not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE): + gx = hx + dx * step + gz = hz + dz * step + if not (0 <= gx < self.GRID_SIZE and 0 <= gz < self.GRID_SIZE): + break + cell = self._view_cell(dx * step, dz * step, default=0) + if cell == 2: + found = step break - if self._view_map is not None: - cell = ( - int( - self._view_map[ - np.clip(x - (hx - self.VIEW_HALF), 0, 20), np.clip(z - (hz - self.VIEW_HALF), 0, 20) - ] - ) - if (0 <= x - hx + self.VIEW_HALF < 21 and 0 <= z - hz + self.VIEW_HALF < 21) - else 0 - ) - if cell == 2: - found = step - break ray_dirt.append(_norm(found, max_ray)) - # Nearest dirt Euclidean distance (estimated from 7×7 crop) - # 最近污渍欧氏距离(视野内 7×7 粗估) + # Nearest dirt path distance in the visible map. + # 视野内最近污渍路径距离。 self.last_nearest_dirt_dist = self.nearest_dirt_dist self.nearest_dirt_dist = self._calc_nearest_dirt_dist() nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180) @@ -500,7 +519,7 @@ class Preprocessor: dirt_delta, _signed_norm(self.nearest_charger_dx, self.GRID_SIZE), _signed_norm(self.nearest_charger_dz, self.GRID_SIZE), - _norm(self.nearest_charger_range_dist, self.GRID_SIZE), + _norm(self.nearest_charger_path_dist, self.GRID_SIZE), battery_margin_norm, 1.0 if self.low_battery else 0.0, 1.0 if self.recharge_mode else 0.0, @@ -519,19 +538,73 @@ class Preprocessor: ) def _calc_nearest_dirt_dist(self): - """Find nearest dirt Euclidean distance from local view. + """Find nearest dirt path distance from local view. - 从局部视野中找最近污渍的欧氏距离。 + 从局部视野中找最近污渍路径距离。 """ - view = self._view_map - if view is None: - return 200.0 - dirt_coords = np.argwhere(view == 2) + dist = self._local_bfs_distances() + dirt_coords = np.argwhere(self._view_map == 2) if len(dirt_coords) == 0: return 200.0 - center = self.VIEW_HALF - dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2) - return float(np.min(dists)) + best = min(float(dist[ri, ci]) for ri, ci in dirt_coords) + if best >= self.INF_DIST: + return 200.0 + return best + + def _local_bfs_distances(self, start_dx=0, start_dz=0): + """Shortest path distances inside the current 21x21 local view.""" + view = self._view_map + shape = view.shape + dist = np.full(shape, self.INF_DIST, dtype=np.float32) + start_ri, start_ci = self._view_delta_to_index(start_dx, start_dz) + if not (0 <= start_ri < shape[0] and 0 <= start_ci < shape[1]): + return dist + if int(view[start_ri, start_ci]) == 0: + return dist + + dist[start_ri, start_ci] = 0.0 + queue = deque([(start_ri, start_ci)]) + while queue: + ri, ci = queue.popleft() + base = dist[ri, ci] + for dx, dz in self.ACTION_DIRS: + nri = ri + dz + nci = ci + dx + if not (0 <= nri < shape[0] and 0 <= nci < shape[1]): + continue + if int(view[nri, nci]) == 0 or dist[nri, nci] < self.INF_DIST: + continue + if dx != 0 and dz != 0: + side_a = int(view[ri, nci]) != 0 + side_b = int(view[nri, ci]) != 0 + if not (side_a or side_b): + continue + dist[nri, nci] = base + 1.0 + queue.append((nri, nci)) + return dist + + def _local_path_dist_to_charger(self, gx, gz): + """Visible-map BFS distance from global x/z to nearest charger cell.""" + best = self.INF_DIST + start_dx = gx - self.cur_pos[0] + start_dz = gz - self.cur_pos[1] + dist = self._local_bfs_distances(start_dx, start_dz) + for rx, rz, w, h in self.charger_rects: + for tx in range(rx, rx + w): + for tz in range(rz, rz + h): + dx = tx - self.cur_pos[0] + dz = tz - self.cur_pos[1] + ri, ci = self._view_delta_to_index(dx, dz) + if 0 <= ri < dist.shape[0] and 0 <= ci < dist.shape[1]: + best = min(best, float(dist[ri, ci])) + return best + + def _charger_move_distance(self, gx, gz): + """Use visible BFS to the charger when available, otherwise Chebyshev distance.""" + path_dist = self._local_path_dist_to_charger(gx, gz) + if path_dist < self.INF_DIST: + return path_dist + return self._min_charger_range_dist(gx, gz) def get_legal_action(self): """Return legal action mask (8D list). @@ -569,11 +642,8 @@ class Preprocessor: def _is_visible_cell_passable(self, dx, dz): """Whether a relative 21x21-view cell is passable.""" - ri = self.VIEW_HALF + dx - ci = self.VIEW_HALF + dz - if not (0 <= ri < self._view_map.shape[0] and 0 <= ci < self._view_map.shape[1]): - return True - return int(self._view_map[ri, ci]) != 0 + cell = self._view_cell(dx, dz, default=None) + return True if cell is None else cell != 0 def _filter_npc_danger_actions(self, legal_action): """Avoid actions that would enter any NPC 3x3 danger zone.""" @@ -608,42 +678,43 @@ class Preprocessor: return list(legal_action) hx, hz = self.cur_pos - current_dist = self._min_charger_range_dist(hx, hz) + current_range_dist = self._min_charger_range_dist(hx, hz) + current_move_dist = self._charger_move_distance(hx, hz) scored = [] for action, (dx, dz) in enumerate(self.ACTION_DIRS): if legal_action[action] <= 0: continue nx, nz = hx + dx, hz + dz - next_dist = self._min_charger_range_dist(nx, nz) + next_dist = self._charger_move_distance(nx, nz) alignment = dx * self.nearest_charger_dx + dz * self.nearest_charger_dz - scored.append((next_dist, alignment, action)) + next_range_dist = self._min_charger_range_dist(nx, nz) + scored.append((next_dist, alignment, next_range_dist, action)) if not scored: return list(legal_action) # When already inside the charger range, stay inside until recharge mode exits. # 已经在充电区域内时,回充模式退出前不要离开充电区域。 - if current_dist <= 0.0: + if current_range_dist <= 0.0: stay = [0] * 8 - for next_dist, _, action in scored: - if next_dist <= 0.0: + for _, _, next_range_dist, action in scored: + if next_range_dist <= 0.0: stay[action] = 1 if any(stay): return stay - best_next_dist = min(item[0] for item in scored) - best_alignment = max(alignment for next_dist, alignment, _ in scored if next_dist <= best_next_dist + 0.1) - recharge = [0] * 8 - for next_dist, alignment, action in scored: - if next_dist <= best_next_dist + 0.1 and alignment >= best_alignment - 0.1: + best_next_dist = min(item[0] for item in scored) + ranked = sorted(scored, key=lambda item: (item[0], -item[1])) + for next_dist, _, _, action in ranked: + if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1: recharge[action] = 1 + if sum(recharge) >= 3: + break if not any(recharge): - best_alignment = max(item[1] for item in scored) - for _, alignment, action in scored: - if alignment >= best_alignment - 0.1: - recharge[action] = 1 + for _, _, _, action in ranked[: min(3, len(ranked))]: + recharge[action] = 1 return recharge if any(recharge) else list(legal_action) @@ -733,7 +804,7 @@ class Preprocessor: if self.has_charger and (self.recharge_mode or self.low_battery): dist_delta = float( - np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -4.0, 4.0) + np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0) ) approach_scale = 0.06 if self.charger_safety_margin <= 0 else 0.04 retreat_scale = 0.03 if self.charger_safety_margin <= 0 else 0.02 diff --git a/agent_ppo/model/model.py b/agent_ppo/model/model.py index 337ef4e..7b16153 100644 --- a/agent_ppo/model/model.py +++ b/agent_ppo/model/model.py @@ -6,7 +6,7 @@ """ Author: Tencent AI Arena Authors -Simple MLP policy network for Robot Vacuum. +CNN + MLP policy network for Robot Vacuum. 清扫大作战策略网络。 """ @@ -28,9 +28,9 @@ def _make_fc(in_dim, out_dim, gain=1.41421): class Model(nn.Module): - """Dual-head MLP for Robot Vacuum. + """Dual-head CNN+MLP actor-critic for Robot Vacuum. - 清扫大作战双头 MLP 策略网络。 + 清扫大作战双头 CNN+MLP 策略网络。 """ def __init__(self, device=None): @@ -38,12 +38,36 @@ class Model(nn.Module): self.model_name = "robot_vacuum" self.device = device - obs_dim = Config.DIM_OF_OBSERVATION # 157 + map_dim, scalar_dim, last_action_dim = Config.FEATURES + map_size = int(map_dim**0.5) + if map_size * map_size != map_dim: + raise ValueError(f"local map feature must be square, got {map_dim}") + self.map_size = map_size + self.map_dim = map_dim + self.scalar_dim = scalar_dim + last_action_dim act_num = Config.ACTION_NUM # 8 - # Shared backbone / 共享骨干网络 + # Local map encoder keeps spatial obstacle/dirt patterns. + # 局部地图编码器保留障碍/污渍空间结构。 + self.map_encoder = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.AdaptiveAvgPool2d((3, 3)), + nn.Flatten(), + ) + + # Scalar encoder for battery, charger, NPC and last-action features. + # 标量编码器处理电量、充电桩、NPC、上一步动作等特征。 + self.scalar_encoder = nn.Sequential( + _make_fc(self.scalar_dim, 64), + nn.ReLU(), + ) + + # Shared fusion backbone / 共享融合骨干网络 self.backbone = nn.Sequential( - _make_fc(obs_dim, 256), + _make_fc(32 * 3 * 3 + 64, 256), nn.ReLU(), _make_fc(256, 128), nn.ReLU(), @@ -61,7 +85,11 @@ class Model(nn.Module): 前向传播。 """ x = s.to(torch.float32) - h = self.backbone(x) + local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size) + scalar = x[:, self.map_dim :] + map_h = self.map_encoder(local_map) + scalar_h = self.scalar_encoder(scalar) + h = self.backbone(torch.cat([map_h, scalar_h], dim=1)) logits = self.actor_head(h) value = self.critic_head(h) return [logits, value]