Initial robot vacuum code
This commit is contained in:
0
agent_ppo/workflow/__init__.py
Normal file
0
agent_ppo/workflow/__init__.py
Normal file
201
agent_ppo/workflow/train_workflow.py
Normal file
201
agent_ppo/workflow/train_workflow.py
Normal file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: UTF-8 -*-
|
||||
###########################################################################
|
||||
# Copyright © 1998 - 2026 Tencent. All Rights Reserved.
|
||||
###########################################################################
|
||||
"""
|
||||
Author: Tencent AI Arena Authors
|
||||
|
||||
Training workflow for Robot Vacuum.
|
||||
清扫大作战训练工作流。
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from agent_ppo.conf.conf import Config
|
||||
from agent_ppo.feature.definition import SampleData, sample_process
|
||||
from tools.metrics_utils import get_training_metrics
|
||||
from tools.train_env_conf_validate import read_usr_conf
|
||||
from common_python.utils.workflow_disaster_recovery import handle_disaster_recovery
|
||||
|
||||
|
||||
def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
last_save_model_time = time.time()
|
||||
env = envs[0]
|
||||
agent = agents[0]
|
||||
|
||||
# Read and validate user configuration
|
||||
# 读取和校验用户配置
|
||||
usr_conf = read_usr_conf("agent_ppo/conf/train_env_conf.toml", logger)
|
||||
if usr_conf is None:
|
||||
logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
|
||||
return
|
||||
|
||||
episode_runner = EpisodeRunner(
|
||||
env=env,
|
||||
agent=agent,
|
||||
usr_conf=usr_conf,
|
||||
logger=logger,
|
||||
monitor=monitor,
|
||||
)
|
||||
|
||||
while True:
|
||||
for g_data in episode_runner.run_episodes():
|
||||
agent.send_sample_data(g_data)
|
||||
g_data.clear()
|
||||
|
||||
now = time.time()
|
||||
if now - last_save_model_time >= 1800:
|
||||
agent.save_model()
|
||||
last_save_model_time = now
|
||||
|
||||
|
||||
class EpisodeRunner:
|
||||
def __init__(self, env, agent, usr_conf, logger, monitor):
|
||||
self.env = env
|
||||
self.agent = agent
|
||||
self.usr_conf = usr_conf
|
||||
self.logger = logger
|
||||
self.monitor = monitor
|
||||
self.episode_cnt = 0
|
||||
self.last_report_monitor_time = 0
|
||||
self.last_get_training_metrics_time = 0
|
||||
|
||||
def run_episodes(self):
|
||||
"""Run a single episode and yield collected samples.
|
||||
|
||||
单局流程(generator),完成一局后 yield 整局样本。
|
||||
"""
|
||||
while True:
|
||||
# Periodically get training metrics
|
||||
# 定期打印训练指标
|
||||
now = time.time()
|
||||
if now - self.last_get_training_metrics_time >= 60:
|
||||
training_metrics = get_training_metrics()
|
||||
self.last_get_training_metrics_time = now
|
||||
if training_metrics is not None:
|
||||
self.logger.info(f"training_metrics: {training_metrics}")
|
||||
|
||||
# Reset environment
|
||||
# 重置环境
|
||||
env_obs = self.env.reset(self.usr_conf)
|
||||
if handle_disaster_recovery(env_obs, self.logger):
|
||||
continue
|
||||
|
||||
# Reset agent and load latest model
|
||||
# 重置 Agent,加载最新模型
|
||||
self.agent.reset(env_obs)
|
||||
self.agent.load_model(id="latest")
|
||||
|
||||
# Initial observation processing
|
||||
# 初始观测
|
||||
obs_data, remain_info = self.agent.observation_process(env_obs)
|
||||
|
||||
collector = []
|
||||
self.episode_cnt += 1
|
||||
done = False
|
||||
step = 0
|
||||
total_reward = 0.0
|
||||
|
||||
self.logger.info(f"Episode {self.episode_cnt} start")
|
||||
|
||||
while not done:
|
||||
# Agent inference / 推理动作
|
||||
act_data_list = self.agent.predict([obs_data])
|
||||
act_data = act_data_list[0]
|
||||
act = self.agent.action_process(act_data)
|
||||
|
||||
# Environment step / 与环境交互
|
||||
env_reward, env_obs = self.env.step(act)
|
||||
if handle_disaster_recovery(env_obs, self.logger):
|
||||
break
|
||||
|
||||
terminated = env_obs["terminated"]
|
||||
truncated = env_obs["truncated"]
|
||||
frame_no = env_obs["frame_no"]
|
||||
step += 1
|
||||
done = terminated or truncated
|
||||
|
||||
# Process next observation
|
||||
# 特征处理
|
||||
_obs_data, _ = self.agent.observation_process(env_obs)
|
||||
_obs_data.frame_no = frame_no
|
||||
|
||||
reward_scalar = float(self.agent.last_reward)
|
||||
total_reward += reward_scalar
|
||||
|
||||
# Terminal reward calculation
|
||||
# 终局奖励
|
||||
final_reward = 0.0
|
||||
if done:
|
||||
fm = self.agent.preprocessor
|
||||
total_score = env_obs["observation"]["env_info"]["total_score"]
|
||||
|
||||
if truncated:
|
||||
# Survived to max steps: higher cleaning ratio → more reward
|
||||
# 存活到最大步数:清扫比例越高奖励越多
|
||||
cleaning_ratio = fm.dirt_cleaned / max(fm.total_dirt, 1)
|
||||
final_reward = 5.0 + 5.0 * cleaning_ratio
|
||||
result_str = "WIN"
|
||||
else:
|
||||
# Early termination (battery depleted or collision): small penalty
|
||||
# 提前结束(电量耗尽或碰撞):小惩罚
|
||||
final_reward = -2.0
|
||||
result_str = "FAIL"
|
||||
|
||||
self.logger.info(
|
||||
f"[GAMEOVER] ep:{self.episode_cnt} steps:{step} "
|
||||
f"result:{result_str} final_bonus:{final_reward:.2f} "
|
||||
f"total_reward:{total_reward:.3f} "
|
||||
f"dirt_cleaned:{fm.dirt_cleaned}/{fm.total_dirt}"
|
||||
)
|
||||
|
||||
# Build sample frame
|
||||
# 构造样本帧
|
||||
reward_arr = np.array([reward_scalar], dtype=np.float32)
|
||||
value_arr = act_data.value.flatten()[: Config.VALUE_NUM]
|
||||
|
||||
frame = SampleData(
|
||||
obs=np.array(obs_data.feature, dtype=np.float32),
|
||||
legal_action=np.array(obs_data.legal_action, dtype=np.float32),
|
||||
act=np.array(act_data.action),
|
||||
reward=reward_arr,
|
||||
done=np.array([float(done)]),
|
||||
reward_sum=np.zeros(Config.VALUE_NUM, dtype=np.float32),
|
||||
value=value_arr,
|
||||
next_value=np.zeros(Config.VALUE_NUM, dtype=np.float32),
|
||||
advantage=np.zeros(Config.VALUE_NUM, dtype=np.float32),
|
||||
prob=np.array(act_data.prob, dtype=np.float32),
|
||||
)
|
||||
collector.append(frame)
|
||||
|
||||
if done:
|
||||
# Add terminal reward to last frame
|
||||
# 终局奖励叠加到最后一步
|
||||
collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32)
|
||||
|
||||
# Monitor reporting / 监控上报
|
||||
now = time.time()
|
||||
if now - self.last_report_monitor_time >= 60 and self.monitor:
|
||||
self.monitor.put_data(
|
||||
{
|
||||
os.getpid(): {
|
||||
"reward": total_reward + final_reward,
|
||||
"episode_cnt": self.episode_cnt,
|
||||
}
|
||||
}
|
||||
)
|
||||
self.last_report_monitor_time = now
|
||||
|
||||
# Compute GAE and yield samples
|
||||
# GAE 计算并 yield 样本
|
||||
if collector:
|
||||
collector = sample_process(collector)
|
||||
yield collector
|
||||
break
|
||||
|
||||
# Advance state / 状态推进
|
||||
obs_data = _obs_data
|
||||
Reference in New Issue
Block a user