Optimize PPO coverage and recharge strategy
This commit is contained in:
@@ -39,10 +39,11 @@ class Model(nn.Module):
|
||||
self.device = device
|
||||
|
||||
map_dim, scalar_dim, last_action_dim = Config.FEATURES
|
||||
map_size = int(map_dim**0.5)
|
||||
if map_size * map_size != map_dim:
|
||||
raise ValueError(f"local map feature must be square, got {map_dim}")
|
||||
self.map_size = map_size
|
||||
self.map_size = Config.VIEW_SIZE
|
||||
self.map_channels = Config.MAP_CHANNELS
|
||||
expected_map_dim = self.map_size * self.map_size * self.map_channels
|
||||
if map_dim != expected_map_dim:
|
||||
raise ValueError(f"local map feature must be {expected_map_dim}, got {map_dim}")
|
||||
self.map_dim = map_dim
|
||||
self.scalar_dim = scalar_dim + last_action_dim
|
||||
act_num = Config.ACTION_NUM # 8
|
||||
@@ -50,11 +51,13 @@ class Model(nn.Module):
|
||||
# Local map encoder keeps spatial obstacle/dirt patterns.
|
||||
# 局部地图编码器保留障碍/污渍空间结构。
|
||||
self.map_encoder = nn.Sequential(
|
||||
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
||||
nn.Conv2d(self.map_channels, 24, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||
nn.Conv2d(24, 48, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d((3, 3)),
|
||||
nn.Conv2d(48, 48, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d((4, 4)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
|
||||
@@ -67,7 +70,7 @@ class Model(nn.Module):
|
||||
|
||||
# Shared fusion backbone / 共享融合骨干网络
|
||||
self.backbone = nn.Sequential(
|
||||
_make_fc(32 * 3 * 3 + 64, 256),
|
||||
_make_fc(48 * 4 * 4 + 64, 256),
|
||||
nn.ReLU(),
|
||||
_make_fc(256, 128),
|
||||
nn.ReLU(),
|
||||
@@ -85,7 +88,7 @@ class Model(nn.Module):
|
||||
前向传播。
|
||||
"""
|
||||
x = s.to(torch.float32)
|
||||
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
|
||||
local_map = x[:, : self.map_dim].view(-1, self.map_channels, self.map_size, self.map_size)
|
||||
scalar = x[:, self.map_dim :]
|
||||
map_h = self.map_encoder(local_map)
|
||||
scalar_h = self.scalar_encoder(scalar)
|
||||
|
||||
Reference in New Issue
Block a user