Initial robot vacuum code
This commit is contained in:
0
agent_diy/__init__.py
Normal file
0
agent_diy/__init__.py
Normal file
96
agent_diy/agent.py
Normal file
96
agent_diy/agent.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Robot Vacuum DIY Agent class based on kaiwudrl BaseAgent interface.
|
||||
清扫大作战 DIY Agent 主类,基于 kaiwudrl BaseAgent 接口。
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
from kaiwudrl.interface.agent import BaseAgent
|
||||
from agent_diy.model.model import Model
|
||||
from agent_diy.conf.conf import Config
|
||||
|
||||
|
||||
class Agent(BaseAgent):
|
||||
def __init__(self, agent_type="player", device=None, logger=None, monitor=None):
|
||||
"""Initialize the agent.
|
||||
|
||||
初始化 Agent。
|
||||
"""
|
||||
super().__init__(agent_type, device, logger, monitor)
|
||||
|
||||
def predict(self, list_obs_data):
|
||||
"""Predict action from observation data.
|
||||
|
||||
根据观测数据推理动作。
|
||||
"""
|
||||
pass
|
||||
|
||||
def exploit(self, list_obs_data):
|
||||
"""Evaluation mode inference (greedy).
|
||||
|
||||
评估模式推理(贪心)。
|
||||
"""
|
||||
pass
|
||||
|
||||
def learn(self, list_sample_data):
|
||||
"""Train the model.
|
||||
|
||||
训练模型。
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_model(self, path=None, id="1"):
|
||||
"""Save model checkpoint.
|
||||
|
||||
保存模型检查点。
|
||||
"""
|
||||
pass
|
||||
|
||||
def load_model(self, path=None, id="1"):
|
||||
"""Load model checkpoint.
|
||||
|
||||
加载模型检查点。
|
||||
"""
|
||||
pass
|
||||
|
||||
def observation_process(self, obs, preprocessor, extra_info=None):
|
||||
"""
|
||||
This function is an important feature processing function, mainly responsible for:
|
||||
- Parsing information in the raw data
|
||||
- Parsing preprocessed feature data
|
||||
- Processing the features and returning the processed feature vector
|
||||
- Concatenation of features
|
||||
- Annotation of legal actions
|
||||
Function inputs:
|
||||
- obs: Local observation information returned by the environment
|
||||
- preprocessor: Preprocessor
|
||||
- extra_info: Global information returned by the environment
|
||||
Function outputs:
|
||||
- ObsData: Observation data for model inference
|
||||
- remain_info: Other data for reward calculation
|
||||
|
||||
该函数是特征处理的重要函数, 主要负责:
|
||||
- 解析原始数据里的信息
|
||||
- 解析预处理后的特征数据
|
||||
- 对特征进行处理, 并返回处理后的特征向量
|
||||
- 特征的拼接
|
||||
- 合法动作的标注
|
||||
函数的输入:
|
||||
- obs: 环境返回的局部观测信息
|
||||
- preprocessor: 预处理器
|
||||
- extra_info: 环境返回的全局状态信息
|
||||
函数的输出:
|
||||
- ObsData: 用于模型推理的观测数据
|
||||
- remain_info: 用于奖励计算的其他数据
|
||||
"""
|
||||
pass
|
||||
|
||||
def action_process(self, act_data):
|
||||
pass
|
||||
0
agent_diy/algorithm/__init__.py
Normal file
0
agent_diy/algorithm/__init__.py
Normal file
32
agent_diy/algorithm/algorithm.py
Normal file
32
agent_diy/algorithm/algorithm.py
Normal file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Robot Vacuum DIY algorithm implementation.
|
||||
清扫大作战 DIY 算法实现。
|
||||
"""
|
||||
|
||||
|
||||
class Algorithm:
|
||||
"""DIY algorithm class.
|
||||
|
||||
DIY 算法类。
|
||||
"""
|
||||
|
||||
def __init__(self, model, optimizer, scheduler, device=None, logger=None, monitor=None):
|
||||
"""Initialize the algorithm.
|
||||
|
||||
初始化算法。
|
||||
"""
|
||||
pass
|
||||
|
||||
def learn(self, list_sample_data):
|
||||
"""Training entry.
|
||||
|
||||
训练入口。
|
||||
"""
|
||||
pass
|
||||
0
agent_diy/conf/__init__.py
Normal file
0
agent_diy/conf/__init__.py
Normal file
43
agent_diy/conf/conf.py
Normal file
43
agent_diy/conf/conf.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
"""
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Configuration, including dimension settings and algorithm parameter settings.
|
||||
# 配置,包含维度设置,算法参数设置
|
||||
class Config:
|
||||
|
||||
# Whether to use CNN networks
|
||||
# 是否使用CNN网络
|
||||
USE_CNN = False
|
||||
VIEW_SIZE = 50 if USE_CNN else 0
|
||||
|
||||
FEATURE_VECTOR_SHAPE = (153,)
|
||||
FEATURE_IMAGE_SHAPE = (4, VIEW_SIZE + 1, VIEW_SIZE + 1)
|
||||
|
||||
ACTION_SHAPE = (8,)
|
||||
VALUE_SHAPE = (1,)
|
||||
|
||||
# Discount factor GAMMA in RL
|
||||
# RL中的回报折扣GAMMA
|
||||
GAMMA = 0.95
|
||||
|
||||
# Initial learning rate
|
||||
# 初始的学习率
|
||||
START_LR = 5e-4
|
||||
|
||||
# Value function loss coefficient
|
||||
# 价值函数损失系数
|
||||
VALUE_LOSS_COEFF = 0.5
|
||||
|
||||
# Entropy regularization coefficient
|
||||
# 熵正则化系数
|
||||
ENTROPY_LOSS_COEFF = 0.025
|
||||
83
agent_diy/conf/monitor_builder.py
Normal file
83
agent_diy/conf/monitor_builder.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Monitor panel configuration builder for Robot Vacuum.
|
||||
清扫大作战监控面板配置构建器。
|
||||
"""
|
||||
|
||||
|
||||
from kaiwudrl.common.monitor.monitor_config_builder import MonitorConfigBuilder
|
||||
|
||||
|
||||
def build_monitor():
|
||||
"""
|
||||
This function is used to create monitoring panel configurations for custom indicators.
|
||||
该函数用于创建自定义指标的监控面板配置。
|
||||
"""
|
||||
monitor = MonitorConfigBuilder()
|
||||
|
||||
config_dict = (
|
||||
monitor.title("扫地机器人")
|
||||
.add_group(
|
||||
group_name="算法指标",
|
||||
group_name_en="algorithm",
|
||||
)
|
||||
.add_panel(
|
||||
name="累积回报",
|
||||
name_en="reward",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="reward",
|
||||
expr="avg(reward{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="总损失",
|
||||
name_en="total_loss",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="total_loss",
|
||||
expr="avg(total_loss{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="价值损失",
|
||||
name_en="value_loss",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="value_loss",
|
||||
expr="avg(value_loss{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="策略损失",
|
||||
name_en="policy_loss",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="policy_loss",
|
||||
expr="avg(policy_loss{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="熵损失",
|
||||
name_en="entropy_loss",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="entropy_loss",
|
||||
expr="avg(entropy_loss{})",
|
||||
)
|
||||
.end_panel()
|
||||
.end_group()
|
||||
.build()
|
||||
)
|
||||
return config_dict
|
||||
26
agent_diy/conf/train_env_conf.toml
Normal file
26
agent_diy/conf/train_env_conf.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[env_conf]
|
||||
# Maps used for training. Customize by keeping only desired map IDs, e.g. [1, 2] for maps 1 and 2.
|
||||
# 训练使用的地图。可自定义选择期望用来训练的地图,如只期望使用1、2号地图训练数组内仅保留[1,2]即可。
|
||||
map = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
|
||||
# Whether to randomly select maps. Boolean.
|
||||
# true = randomly pick one from configured maps per episode, false = used sequentially.
|
||||
# 是否随机抽取地图。布尔值。true表示每局从配置的地图中随机抽取一张,false表示按顺序抽取地图训练。
|
||||
map_random = false
|
||||
|
||||
# Number of official robots. Range: 1~4 (integer).
|
||||
# In each round, official robots will be randomly generated on the road according to the configured.
|
||||
# 官方机器人数量。可配置范围为1~4(整数)。每局将按照配置数量在道路上随机生成官方机器人。
|
||||
robot_count = 4
|
||||
|
||||
# Number of chargers. Range: 1~4 (integer). When less than 4, spawn points are randomly chosen.
|
||||
# 充电桩数量。可配置范围为1~4(整数)。当配置小于4时,将从每张地图可生成充电桩的点位随机选择对应数量的点位生成。
|
||||
charger_count = 4
|
||||
|
||||
# Maximum steps. The task ends when the predicted steps in a single round reach the maximum. Range: 1~2000.
|
||||
# 最大步数。单局任务预测步数达到最大步数时,任务结束。可配置范围为1~2000。
|
||||
max_step = 1000
|
||||
|
||||
# Maximum battery. The battery level when fully charged. Range: 100~999.
|
||||
# 最大电量。满电状态下的电量。可配置范围100~999。
|
||||
battery_max = 200
|
||||
0
agent_diy/feature/__init__.py
Normal file
0
agent_diy/feature/__init__.py
Normal file
59
agent_diy/feature/definition.py
Normal file
59
agent_diy/feature/definition.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
"""
|
||||
|
||||
|
||||
from common_python.utils.common_func import create_cls
|
||||
import numpy as np
|
||||
from agent_diy.conf.conf import Config
|
||||
|
||||
# The create_cls function is used to dynamically create a class. The first parameter of the function is the type name,
|
||||
# and the remaining parameters are the attributes of the class, which should have a default value of None.
|
||||
# create_cls函数用于动态创建一个类,函数第一个参数为类型名称,剩余参数为类的属性,属性默认值应设为None
|
||||
ObsData = create_cls(
|
||||
"ObsData",
|
||||
feature=None,
|
||||
legal_act=None,
|
||||
)
|
||||
|
||||
|
||||
ActData = create_cls(
|
||||
"ActData",
|
||||
act=None,
|
||||
)
|
||||
|
||||
|
||||
# SampleData is used to transfer training samples between aisrv and learner.
|
||||
# SampleData用于在aisrv和learner之间传递训练样本
|
||||
SampleData = create_cls(
|
||||
"SampleData",
|
||||
obs=153, # Observation dimension / 观测维度
|
||||
legal_actions=8, # Legal action dimension / 合法动作维度
|
||||
actions=1, # Action dimension / 动作维度
|
||||
probs=8, # Action probability distribution dimension / 动作概率分布维度
|
||||
rewards=1, # Reward / 奖励
|
||||
advantages=1, # Advantage function / 优势函数
|
||||
values=1, # Value function / 价值函数
|
||||
dones=1, # Whether terminated / 是否结束
|
||||
)
|
||||
|
||||
|
||||
def reward_shaping(frame_no, score, terminated, truncated, remain_info, _remain_info, obs, _obs):
|
||||
"""Reward shaping function.
|
||||
|
||||
奖励塑形函数。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def sample_process(list_game_data):
|
||||
"""Sample processing function.
|
||||
|
||||
样本处理函数。
|
||||
"""
|
||||
pass
|
||||
0
agent_diy/model/__init__.py
Normal file
0
agent_diy/model/__init__.py
Normal file
34
agent_diy/model/model.py
Normal file
34
agent_diy/model/model.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Robot Vacuum DIY model implementation.
|
||||
清扫大作战 DIY 模型实现。
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
"""DIY model class.
|
||||
|
||||
DIY 模型类。
|
||||
"""
|
||||
|
||||
def __init__(self, state_shape, action_shape=0, softmax=False):
|
||||
"""Initialize the model.
|
||||
|
||||
初始化模型。
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# User-defined network
|
||||
# 用户自定义网络
|
||||
0
agent_diy/workflow/__init__.py
Normal file
0
agent_diy/workflow/__init__.py
Normal file
43
agent_diy/workflow/train_workflow.py
Normal file
43
agent_diy/workflow/train_workflow.py
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
"""
|
||||
|
||||
|
||||
import time
|
||||
from common_python.utils.common_func import Frame
|
||||
from agent_diy.feature.definition import (
|
||||
sample_process,
|
||||
reward_shaping,
|
||||
)
|
||||
from tools.train_env_conf_validate import read_usr_conf
|
||||
from tools.metrics_utils import get_training_metrics
|
||||
from common_python.utils.workflow_disaster_recovery import handle_disaster_recovery
|
||||
|
||||
|
||||
def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
env, agent = envs[0], agents[0]
|
||||
|
||||
# Read and validate configuration file
|
||||
# 配置文件读取和校验
|
||||
usr_conf = read_usr_conf("agent_diy/conf/train_env_conf.toml", logger)
|
||||
if usr_conf is None:
|
||||
logger.error(f"usr_conf is None, please check agent_diy/conf/train_env_conf.toml")
|
||||
return
|
||||
|
||||
# Please write your DIY training process below.
|
||||
# 请在下方写你DIY的训练流程
|
||||
|
||||
# At the start of each game, support loading the latest model file
|
||||
# 每次对局开始时, 支持加载最新model文件, 该调用会从远程的训练节点加载最新模型
|
||||
agent.load_model(id="latest")
|
||||
|
||||
# Model saving
|
||||
# 保存模型
|
||||
agent.save_model()
|
||||
|
||||
return
|
||||
Reference in New Issue
Block a user