#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors

CNN + MLP actor-critic policy network for Robot Vacuum.
"""

import math

import torch
import torch.nn as nn

from agent_ppo.conf.conf import Config


def _make_fc(in_dim, out_dim, gain=1.41421):
    """Create a linear layer with orthogonal weights and a zero bias.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features.
        gain: Scale passed to ``nn.init.orthogonal_``; the default 1.41421
            approximates sqrt(2).

    Returns:
        The initialized ``nn.Linear`` layer.
    """
    layer = nn.Linear(in_dim, out_dim)
    nn.init.orthogonal_(layer.weight, gain=gain)
    nn.init.zeros_(layer.bias)
    return layer


class Model(nn.Module):
    """Dual-head CNN+MLP actor-critic network for Robot Vacuum.

    Each flat observation row is split into a square local map (encoded by a
    small CNN) and the remaining scalar features (encoded by an MLP). The two
    embeddings are concatenated, passed through a shared backbone, and fed to
    an actor head (action logits) and a critic head (state value).
    """

    def __init__(self, device=None):
        """Build the network from the feature layout in ``Config``.

        Args:
            device: Stored on ``self.device``; this constructor does not move
                any parameters to it.

        Raises:
            ValueError: If the local-map feature length from
                ``Config.FEATURES`` is not a positive perfect square.
        """
        super().__init__()
        self.model_name = "robot_vacuum"
        self.device = device

        map_dim, scalar_dim, last_action_dim = Config.FEATURES
        if map_dim <= 0:
            raise ValueError(f"local map feature must be non-empty, got {map_dim}")
        # math.isqrt is exact for integers; int(map_dim ** 0.5) goes through a
        # float and can round a large perfect square to the wrong side.
        map_size = math.isqrt(map_dim)
        if map_size * map_size != map_dim:
            raise ValueError(f"local map feature must be square, got {map_dim}")
        self.map_size = map_size
        self.map_dim = map_dim
        # Everything after the map columns (scalar + last-action features) is
        # fed to the scalar encoder as one vector.
        self.scalar_dim = scalar_dim + last_action_dim
        act_num = Config.ACTION_NUM  # original note: 8 actions

        # The last conv's channel count and the pooled grid size together fix
        # the flattened CNN embedding width used by the backbone below.
        conv_channels = 32
        pool_size = 3
        map_feat_dim = conv_channels * pool_size * pool_size

        # Local map encoder keeps spatial obstacle/dirt patterns. Adaptive
        # pooling makes the output width independent of map_size.
        # Layer order is kept fixed so state_dict keys stay checkpoint-compatible.
        self.map_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, conv_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((pool_size, pool_size)),
            nn.Flatten(),
        )

        # Scalar encoder for battery, charger, NPC and last-action features.
        self.scalar_encoder = nn.Sequential(
            _make_fc(self.scalar_dim, 64),
            nn.ReLU(),
        )

        # Shared fusion backbone over the concatenated map + scalar embeddings.
        self.backbone = nn.Sequential(
            _make_fc(map_feat_dim + 64, 256),
            nn.ReLU(),
            _make_fc(256, 128),
            nn.ReLU(),
        )

        # Actor head: action logits. Small gain keeps the initial policy
        # close to uniform.
        self.actor_head = _make_fc(128, act_num, gain=0.01)
        # Critic head: a single state-value estimate.
        self.critic_head = _make_fc(128, 1, gain=0.01)

    def forward(self, s, inference=False):
        """Compute action logits and a state value.

        Args:
            s: Observation tensor of shape ``(batch, map_dim + self.scalar_dim)``.
                The first ``map_dim`` columns are the flattened local map; the
                rest are the scalar features. A single 1-D observation is
                treated as a batch of one.
            inference: Accepted for interface compatibility; not used here.

        Returns:
            ``[logits, value]`` with shapes ``(batch, Config.ACTION_NUM)`` and
            ``(batch, 1)``.
        """
        x = s.to(torch.float32)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        batch = x.shape[0]

        # reshape (not view) so non-contiguous inputs also work; it still
        # returns a view whenever the memory layout allows. An explicit batch
        # size (instead of -1) also keeps empty batches valid.
        local_map = x[:, : self.map_dim].reshape(batch, 1, self.map_size, self.map_size)
        scalar = x[:, self.map_dim :]

        map_h = self.map_encoder(local_map)
        scalar_h = self.scalar_encoder(scalar)
        h = self.backbone(torch.cat([map_h, scalar_h], dim=1))

        logits = self.actor_head(h)
        value = self.critic_head(h)
        return [logits, value]

    def set_train_mode(self):
        """Put the module in training mode (``nn.Module.train``)."""
        self.train()

    def set_eval_mode(self):
        """Put the module in evaluation mode (``nn.Module.eval``)."""
        self.eval()