105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors

CNN + MLP actor-critic policy network for Robot Vacuum.

清扫大作战策略网络。
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from agent_ppo.conf.conf import Config
|
|
|
|
|
|
def _make_fc(in_dim, out_dim, gain=1.41421):
    """Return an ``nn.Linear(in_dim, out_dim)`` with orthogonal weights and zero bias.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        gain: Scale factor passed to ``nn.init.orthogonal_``; the default
            1.41421 approximates sqrt(2).

    Returns:
        The freshly initialized ``nn.Linear`` module.
    """
    fc = nn.Linear(in_dim, out_dim)
    # Zero the bias, then overwrite PyTorch's default weight initialization
    # with a scaled orthogonal matrix.
    nn.init.zeros_(fc.bias)
    nn.init.orthogonal_(fc.weight, gain=gain)
    return fc
|
|
|
|
|
|
class Model(nn.Module):
    """Dual-head CNN+MLP actor-critic network for Robot Vacuum.

    Each flat observation row is split into two parts:

    * the first ``map_dim`` columns, reshaped to
      ``(MAP_CHANNELS, VIEW_SIZE, VIEW_SIZE)`` and encoded by a small CNN;
    * the remaining ``scalar_dim`` columns (scalar plus last-action features,
      as sized by ``Config.FEATURES``), encoded by a one-layer MLP.

    The two encodings are concatenated, passed through a shared MLP backbone,
    and fed to an actor head (action logits) and a critic head (state value).
    """

    def __init__(self, device=None):
        """Build the encoders, backbone and heads from ``Config``.

        Args:
            device: Stored as ``self.device``; not used by this class's own
                methods.

        Raises:
            ValueError: If the map feature size in ``Config.FEATURES`` does
                not equal ``VIEW_SIZE * VIEW_SIZE * MAP_CHANNELS``.
        """
        super().__init__()
        self.model_name = "robot_vacuum"
        self.device = device

        map_dim, scalar_dim, last_action_dim = Config.FEATURES
        self.map_size = Config.VIEW_SIZE
        self.map_channels = Config.MAP_CHANNELS
        expected_map_dim = self.map_size * self.map_size * self.map_channels
        if map_dim != expected_map_dim:
            raise ValueError(f"local map feature must be {expected_map_dim}, got {map_dim}")
        self.map_dim = map_dim
        # Width of everything after the map block: scalar + last-action features.
        self.scalar_dim = scalar_dim + last_action_dim
        # Total observation width that forward() accepts.
        self.input_dim = self.map_dim + self.scalar_dim
        act_num = Config.ACTION_NUM  # number of discrete actions (logit count)

        # Encoder sizes. The fused backbone input size is derived from these
        # rather than hard-coded, so the two cannot drift apart.
        conv_channels = 48
        pooled_size = 4
        scalar_hidden = 64
        map_feat_dim = conv_channels * pooled_size * pooled_size

        # Local map encoder keeps spatial obstacle/dirt patterns. The 3x3
        # convs with padding=1 preserve spatial size; adaptive pooling then
        # fixes the output grid to pooled_size x pooled_size.
        # NOTE: module creation order and Sequential indices match the
        # original layout so parameter-init RNG order and state_dict keys
        # are unchanged.
        self.map_encoder = nn.Sequential(
            nn.Conv2d(self.map_channels, 24, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(24, conv_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(conv_channels, conv_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((pooled_size, pooled_size)),
            nn.Flatten(),
        )

        # Scalar encoder for battery, charger, NPC and last-action features.
        self.scalar_encoder = nn.Sequential(
            _make_fc(self.scalar_dim, scalar_hidden),
            nn.ReLU(),
        )

        # Shared fusion backbone over the concatenated map + scalar encodings.
        self.backbone = nn.Sequential(
            _make_fc(map_feat_dim + scalar_hidden, 256),
            nn.ReLU(),
            _make_fc(256, 128),
            nn.ReLU(),
        )

        # Actor head: outputs action logits (small init gain).
        self.actor_head = _make_fc(128, act_num, gain=0.01)

        # Critic head: outputs a single state value.
        self.critic_head = _make_fc(128, 1, gain=0.01)

    def forward(self, s, inference=False):
        """Run the network on a batch of flat observations.

        Args:
            s: Tensor of shape ``(batch, map_dim + scalar_dim)``; it is cast to
                float32 before use. The first ``map_dim`` columns are read as a
                ``(map_channels, map_size, map_size)`` local map, the rest as
                scalar features.
            inference: Unused in this implementation; kept for interface
                compatibility.

        Returns:
            ``[logits, value]`` with shapes ``(batch, ACTION_NUM)`` and
            ``(batch, 1)``.

        Raises:
            ValueError: If ``s`` is not 2-D or its feature width differs from
                ``map_dim + scalar_dim``.
        """
        x = s.to(torch.float32)
        if x.dim() != 2 or x.shape[1] != self.input_dim:
            raise ValueError(
                f"expected observation of shape (batch, {self.input_dim}), got {tuple(x.shape)}"
            )
        batch = x.shape[0]
        # Explicit batch size (not -1) so elements can never be silently
        # redistributed across samples; reshape (unlike view) also accepts
        # non-contiguous inputs.
        local_map = x[:, : self.map_dim].reshape(
            batch, self.map_channels, self.map_size, self.map_size
        )
        scalar = x[:, self.map_dim :]

        map_h = self.map_encoder(local_map)
        scalar_h = self.scalar_encoder(scalar)
        h = self.backbone(torch.cat([map_h, scalar_h], dim=1))

        logits = self.actor_head(h)
        value = self.critic_head(h)
        return [logits, value]

    def set_train_mode(self):
        """Switch the module to training mode via ``nn.Module.train``."""
        self.train()

    def set_eval_mode(self):
        """Switch the module to evaluation mode via ``nn.Module.eval``."""
        self.eval()
|