优化 PPO 清扫策略

This commit is contained in:
2026-04-26 17:29:03 +08:00
parent f04feb0cd9
commit f44e2483fc
6 changed files with 223 additions and 86 deletions

View File

@@ -6,7 +6,7 @@
"""
Author: Tencent AI Arena Authors
Simple MLP policy network for Robot Vacuum.
CNN + MLP policy network for Robot Vacuum.
清扫大作战策略网络。
"""
@@ -28,9 +28,9 @@ def _make_fc(in_dim, out_dim, gain=1.41421):
class Model(nn.Module):
"""Dual-head MLP for Robot Vacuum.
"""Dual-head CNN+MLP actor-critic for Robot Vacuum.
清扫大作战双头 MLP 策略网络。
清扫大作战双头 CNN+MLP 策略网络。
"""
def __init__(self, device=None):
@@ -38,12 +38,36 @@ class Model(nn.Module):
self.model_name = "robot_vacuum"
self.device = device
obs_dim = Config.DIM_OF_OBSERVATION # 157
map_dim, scalar_dim, last_action_dim = Config.FEATURES
map_size = int(map_dim**0.5)
if map_size * map_size != map_dim:
raise ValueError(f"local map feature must be square, got {map_dim}")
self.map_size = map_size
self.map_dim = map_dim
self.scalar_dim = scalar_dim + last_action_dim
act_num = Config.ACTION_NUM # 8
# Shared backbone / 共享骨干网络
# Local map encoder keeps spatial obstacle/dirt patterns.
# 局部地图编码器保留障碍/污渍空间结构。
self.map_encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((3, 3)),
nn.Flatten(),
)
# Scalar encoder for battery, charger, NPC and last-action features.
# 标量编码器处理电量、充电桩、NPC、上一步动作等特征。
self.scalar_encoder = nn.Sequential(
_make_fc(self.scalar_dim, 64),
nn.ReLU(),
)
# Shared fusion backbone / 共享融合骨干网络
self.backbone = nn.Sequential(
_make_fc(obs_dim, 256),
_make_fc(32 * 3 * 3 + 64, 256),
nn.ReLU(),
_make_fc(256, 128),
nn.ReLU(),
@@ -61,7 +85,11 @@ class Model(nn.Module):
前向传播。
"""
x = s.to(torch.float32)
h = self.backbone(x)
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
scalar = x[:, self.map_dim :]
map_h = self.map_encoder(local_map)
scalar_h = self.scalar_encoder(scalar)
h = self.backbone(torch.cat([map_h, scalar_h], dim=1))
logits = self.actor_head(h)
value = self.critic_head(h)
return [logits, value]