Optimize PPO coverage and recharge strategy

This commit is contained in:
2026-04-26 19:25:05 +08:00
parent 220de372e0
commit 5b6133db13
4 changed files with 399 additions and 108 deletions

View File

@@ -39,10 +39,11 @@ class Model(nn.Module):
self.device = device
map_dim, scalar_dim, last_action_dim = Config.FEATURES
map_size = int(map_dim**0.5)
if map_size * map_size != map_dim:
raise ValueError(f"local map feature must be square, got {map_dim}")
self.map_size = map_size
self.map_size = Config.VIEW_SIZE
self.map_channels = Config.MAP_CHANNELS
expected_map_dim = self.map_size * self.map_size * self.map_channels
if map_dim != expected_map_dim:
raise ValueError(f"local map feature must be {expected_map_dim}, got {map_dim}")
self.map_dim = map_dim
self.scalar_dim = scalar_dim + last_action_dim
act_num = Config.ACTION_NUM # 8
@@ -50,11 +51,13 @@ class Model(nn.Module):
# Local map encoder keeps spatial obstacle/dirt patterns.
# 局部地图编码器保留障碍/污渍空间结构。
self.map_encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, padding=1),
nn.Conv2d(self.map_channels, 24, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=3, padding=1),
nn.Conv2d(24, 48, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((3, 3)),
nn.Conv2d(48, 48, kernel_size=3, padding=1),
nn.ReLU(),
nn.AdaptiveAvgPool2d((4, 4)),
nn.Flatten(),
)
@@ -67,7 +70,7 @@ class Model(nn.Module):
# Shared fusion backbone / 共享融合骨干网络
self.backbone = nn.Sequential(
_make_fc(32 * 3 * 3 + 64, 256),
_make_fc(48 * 4 * 4 + 64, 256),
nn.ReLU(),
_make_fc(256, 128),
nn.ReLU(),
@@ -85,7 +88,7 @@ class Model(nn.Module):
前向传播。
"""
x = s.to(torch.float32)
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
local_map = x[:, : self.map_dim].view(-1, self.map_channels, self.map_size, self.map_size)
scalar = x[:, self.map_dim :]
map_h = self.map_encoder(local_map)
scalar_h = self.scalar_encoder(scalar)