This repository has been archived on 2026-05-02. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
-----/agent_ppo/feature/preprocessor.py
2026-04-26 12:46:00 +08:00

299 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
###########################################################################
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
###########################################################################
"""
Author: Tencent AI Arena Authors
Feature preprocessor for Robot Vacuum.
清扫大作战特征预处理器。
"""
import numpy as np
def _norm(v, v_max, v_min=0.0):
"""Normalize value to [0, 1].
将值线性归一化到 [0, 1]。
"""
v = float(np.clip(v, v_min, v_max))
if v_max == v_min:
return 0.0
return (v - v_min) / (v_max - v_min)
class Preprocessor:
    """Feature preprocessor for Robot Vacuum.

    Each call to :meth:`feature_process` does three things:

    - parses one environment observation;
    - updates the per-episode state kept on the instance (position history,
      visit counts, cleaning progress, local view);
    - returns a 69-D feature vector, an 8-D legal-action mask and a shaped
      scalar reward.
    """

    GRID_SIZE = 128
    VIEW_HALF = 10  # Full local view radius (21x21 view)
    LOCAL_HALF = 3  # Cropped view radius (7x7 crop)
    VIEW_SIZE = 2 * VIEW_HALF + 1  # Side length of the full local view (21)
    NUM_ACTIONS = 8  # Size of the action space / legal-action mask
    MAX_STEP = 2000  # Step count that maps to step_norm == 1.0
    MAX_RAY = 30  # Ray distance reported when a dirt ray hits nothing
    NO_DIRT_DIST = 200.0  # Sentinel nearest-dirt distance when no dirt is visible

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all per-episode state; call at the start of every episode."""
        self.step_no = 0
        self.battery = 600
        self.battery_max = 600
        self.cur_pos = (0, 0)
        self.prev_pos = None
        self.has_position_history = False
        self.current_visit_count = 0
        self.is_new_cell = False
        self.last_action = -1
        self.dirt_cleaned = 0
        self.last_dirt_cleaned = 0
        self.total_dirt = 1
        # Global passable map (0 = obstacle, 1 = passable), filled in from each
        # local view. NOTE: the feature and reward code in this class reads the
        # local view directly; this map is maintained but not read here.
        self.passable_map = np.ones((self.GRID_SIZE, self.GRID_SIZE), dtype=np.int8)
        # Euclidean distance to the nearest visible dirt (this step / previous step).
        self.nearest_dirt_dist = self.NO_DIRT_DIST
        self.last_nearest_dirt_dist = self.NO_DIRT_DIST
        # Number of times each global cell has been entered this episode.
        self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
        self._view_map = np.zeros((self.VIEW_SIZE, self.VIEW_SIZE), dtype=np.float32)
        self._legal_act = [1] * self.NUM_ACTIONS

    def pb2struct(self, env_obs, last_action):
        """Parse one observation and update the cached per-step state.

        Args:
            env_obs: dict whose ``"observation"`` entry holds
                ``frame_state["heroes"]``, ``env_info["total_dirt"]`` and
                ``step_no``, plus optional ``legal_action`` and ``map_info``.
                ``frame_state["heroes"]`` provides pos.x/z, battery,
                battery_max and dirt_cleaned.
            last_action: index of the action taken on the previous step
                (a negative value means "no action").
        """
        observation = env_obs["observation"]
        frame_state = observation["frame_state"]
        env_info = observation["env_info"]
        hero = frame_state["heroes"]

        self.last_action = int(last_action)
        self.step_no = int(observation["step_no"])

        # Position history and visit counting.
        self.prev_pos = self.cur_pos if self.has_position_history else None
        self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))
        self.has_position_history = True
        hx, hz = self.cur_pos
        if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
            self.current_visit_count = int(self.visit_count_map[hx, hz])
            self.is_new_cell = self.current_visit_count == 0
            # Saturate instead of wrapping around the uint16 range.
            self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max)
        else:
            self.current_visit_count = 0
            self.is_new_cell = False

        # Battery (battery_max floored at 1).
        self.battery = int(hero["battery"])
        self.battery_max = max(int(hero["battery_max"]), 1)

        # Cleaning progress: remember the previous cumulative count for the reward.
        self.last_dirt_cleaned = self.dirt_cleaned
        self.dirt_cleaned = int(hero["dirt_cleaned"])
        self.total_dirt = max(int(env_info["total_dirt"]), 1)

        # Legal actions. Do not use `value or default` here: the truth value of
        # a multi-element numpy array is ambiguous and raises ValueError.
        legal = observation.get("legal_action")
        if legal is None or len(legal) == 0:
            legal = [1] * self.NUM_ACTIONS
        self._legal_act = [int(x) for x in legal]

        # Local view map; when map_info is absent the previous view is kept.
        map_info = observation.get("map_info")
        if map_info is not None:
            self._view_map = np.array(map_info, dtype=np.float32)
            self._update_passable(hx, hz)

    def _update_passable(self, hx, hz):
        """Copy the local view, centred on (hx, hz), into the global passable map.

        View cells equal to 0 are written as obstacles (0). Every other value
        is written as passable (1). Parts of the view that fall outside the
        grid are ignored.
        """
        view = self._view_map
        size = view.shape[0]
        half = size // 2
        # Global coordinates of the view's [0, 0] cell.
        x0, z0 = hx - half, hz - half
        # Intersect the view window with the grid.
        gx_lo, gx_hi = max(x0, 0), min(x0 + size, self.GRID_SIZE)
        gz_lo, gz_hi = max(z0, 0), min(z0 + size, self.GRID_SIZE)
        if gx_lo >= gx_hi or gz_lo >= gz_hi:
            return
        window = view[gx_lo - x0 : gx_hi - x0, gz_lo - z0 : gz_hi - z0]
        self.passable_map[gx_lo:gx_hi, gz_lo:gz_hi] = (window != 0).astype(np.int8)

    def _get_local_view_feature(self):
        """Return the 49-D local view feature.

        The centre 7x7 of the 21x21 view is flattened row-major and divided
        by 2.
        """
        center = self.VIEW_HALF
        h = self.LOCAL_HALF
        crop = self._view_map[center - h : center + h + 1, center - h : center + h + 1]
        return (crop / 2.0).flatten()

    def _cast_dirt_rays(self, hx, hz):
        """Return step distances to the first dirt cell (value 2) along N, E, S, W.

        Directions: N = z-, E = x+, S = z+, W = x-. Rays walk the local view
        in a straight line from (hx, hz), and obstacles do not stop them.

        A ray that leaves the grid or the local view without hitting dirt
        reports ``MAX_RAY``. The view only reaches ``VIEW_HALF`` cells from
        its centre, so any distance actually found is at most ``VIEW_HALF``.
        """
        view = self._view_map
        dists = []
        for dx, dz in ((0, -1), (1, 0), (0, 1), (-1, 0)):  # N E S W
            found = self.MAX_RAY
            for step in range(1, self.MAX_RAY + 1):
                x, z = hx + dx * step, hz + dz * step
                if not (0 <= x < self.GRID_SIZE and 0 <= z < self.GRID_SIZE):
                    break
                lx, lz = dx * step + self.VIEW_HALF, dz * step + self.VIEW_HALF
                if not (0 <= lx < self.VIEW_SIZE and 0 <= lz < self.VIEW_SIZE):
                    # Past the view edge nothing is observed, and the ray only moves further out.
                    break
                if int(view[lx, lz]) == 2:
                    found = step
                    break
            dists.append(found)
        return dists

    def _get_global_state_feature(self):
        """Return the 12-D global state feature.

        Dimensions:
            [0]  step_norm          step_no / MAX_STEP, clipped to [0, 1]
            [1]  battery_ratio      battery / battery_max
            [2]  cleaning_progress  dirt_cleaned / total_dirt
            [3]  remaining_dirt     1 - cleaning_progress
            [4]  pos_x_norm         x / GRID_SIZE
            [5]  pos_z_norm         z / GRID_SIZE
            [6]  ray_N_dirt         nearest dirt along z- (normalized by MAX_RAY)
            [7]  ray_E_dirt         nearest dirt along x+
            [8]  ray_S_dirt         nearest dirt along z+
            [9]  ray_W_dirt         nearest dirt along x-
            [10] nearest_dirt_norm  Euclidean distance to nearest visible dirt / 180
            [11] dirt_delta         1.0 if that distance shrank since last call, else 0.0

        Side effect: this call advances ``last_nearest_dirt_dist`` and
        ``nearest_dirt_dist``, which :meth:`reward_process` reads. Call it
        exactly once per step.
        """
        step_norm = _norm(self.step_no, self.MAX_STEP)
        battery_ratio = _norm(self.battery, self.battery_max)
        cleaning_progress = _norm(self.dirt_cleaned, self.total_dirt)
        remaining_dirt = 1.0 - cleaning_progress
        hx, hz = self.cur_pos
        pos_x_norm = _norm(hx, self.GRID_SIZE)
        pos_z_norm = _norm(hz, self.GRID_SIZE)

        ray_dirt = [_norm(d, self.MAX_RAY) for d in self._cast_dirt_rays(hx, hz)]

        self.last_nearest_dirt_dist = self.nearest_dirt_dist
        self.nearest_dirt_dist = self._calc_nearest_dirt_dist()
        # Normalized by 180, so the NO_DIRT_DIST sentinel clips to 1.0.
        nearest_dirt_norm = _norm(self.nearest_dirt_dist, 180)
        dirt_delta = 1.0 if self.nearest_dirt_dist < self.last_nearest_dirt_dist else 0.0

        return np.array(
            [
                step_norm,
                battery_ratio,
                cleaning_progress,
                remaining_dirt,
                pos_x_norm,
                pos_z_norm,
                *ray_dirt,
                nearest_dirt_norm,
                dirt_delta,
            ],
            dtype=np.float32,
        )

    def _calc_nearest_dirt_dist(self):
        """Return the Euclidean distance from the view centre to the nearest dirt.

        Dirt cells are those equal to 2 in the local view. Returns
        ``NO_DIRT_DIST`` when no dirt is visible.
        """
        view = self._view_map
        if view is None:
            return self.NO_DIRT_DIST
        dirt_coords = np.argwhere(view == 2)
        if len(dirt_coords) == 0:
            return self.NO_DIRT_DIST
        center = self.VIEW_HALF
        dists = np.sqrt((dirt_coords[:, 0] - center) ** 2 + (dirt_coords[:, 1] - center) ** 2)
        return float(np.min(dists))

    def get_legal_action(self):
        """Return a copy of the 8-D legal action mask as a list of ints."""
        return list(self._legal_act)

    def feature_process(self, env_obs, last_action):
        """Return (69-D feature vector, legal action mask, scalar reward) for one step.

        Feature layout:
            [0:49]  local view crop
            [49:61] global state
            [61:69] one-hot of the last action
        """
        self.pb2struct(env_obs, last_action)
        local_view = self._get_local_view_feature()  # 49D
        global_state = self._get_global_state_feature()  # 12D
        legal_action = self.get_legal_action()  # 8D

        # Use the int-cast action from pb2struct, so float or numpy scalar
        # actions index correctly.
        last_action_feature = np.zeros(self.NUM_ACTIONS, dtype=np.float32)
        if 0 <= self.last_action < self.NUM_ACTIONS:
            last_action_feature[self.last_action] = 1.0

        # The legal action mask is passed separately to PPO. This 8-D slot is
        # reused for action history, which makes the 69-D observation more
        # informative without breaking the framework's fixed tensor shape.
        feature = np.concatenate([local_view, global_state, last_action_feature])  # 69D
        reward = self.reward_process()
        return feature, legal_action, reward

    def reward_process(self):
        """Return the shaped scalar reward for the current step.

        Reads state written by :meth:`pb2struct` and by
        :meth:`_get_global_state_feature` (nearest-dirt distances). It must
        therefore run after both, as :meth:`feature_process` does.

        Terms:
            cleaning     +0.25 per dirt unit cleaned this step
            step         -0.002 every step
            approach     +0.01 per unit closer to the nearest visible dirt and
                         -0.006 per unit farther. The distance delta is
                         clipped to [-5, 5], which also bounds transitions
                         to/from the NO_DIRT_DIST sentinel.
            exploration  +0.002 on a first visit to a cell. Otherwise
                         -0.0008 per earlier visit, capped at 5 visits.
            stuck        -0.03 when an action in [0, NUM_ACTIONS) left the
                         position unchanged
        """
        cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
        cleaning_reward = 0.25 * cleaned_this_step

        step_penalty = -0.002

        # NOTE(review): if a cleaned cell stops showing as dirt in the view, the
        # nearest distance jumps to the next dirt. This term can then be
        # negative on a successful cleaning step; confirm that is intended.
        approach_reward = 0.0
        if self.last_nearest_dirt_dist < self.NO_DIRT_DIST or self.nearest_dirt_dist < self.NO_DIRT_DIST:
            dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
            approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta

        # Encourage new cells; mildly discourage revisiting (loops).
        exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5)

        # Collisions / no-op moves waste both a step and battery.
        stuck_penalty = 0.0
        if (
            self.prev_pos is not None
            and self.cur_pos == self.prev_pos
            and 0 <= self.last_action < self.NUM_ACTIONS
        ):
            stuck_penalty = -0.03

        return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty