Compare commits
6 Commits
00b26af3ed
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| dc86a3f338 | |||
| 524ca8c070 | |||
| 69b8a692db | |||
| 5b6133db13 | |||
| 220de372e0 | |||
| e99a224d86 |
@@ -76,6 +76,8 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
action = act_data.action if is_stochastic else act_data.d_action
|
||||
self.last_action = int(action[0])
|
||||
if hasattr(self.preprocessor, "record_action"):
|
||||
self.preprocessor.record_action(self.last_action)
|
||||
return self.last_action
|
||||
|
||||
def predict(self, list_obs_data):
|
||||
@@ -110,7 +112,20 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
try:
|
||||
obs_data, _ = self.observation_process(env_obs)
|
||||
act_data = self.predict([obs_data])[0]
|
||||
logits, value = self._run_model(obs_data.feature)
|
||||
legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
|
||||
prob = self._legal_soft_max(logits, legal_arr)
|
||||
action = None
|
||||
if hasattr(self.preprocessor, "planned_eval_action"):
|
||||
action = self.preprocessor.planned_eval_action(prob, legal_arr)
|
||||
if action is None:
|
||||
action = self._tie_break_eval_action(prob, legal_arr)
|
||||
act_data = ActData(
|
||||
action=[action],
|
||||
d_action=[action],
|
||||
prob=list(prob),
|
||||
value=value,
|
||||
)
|
||||
return self.action_process(act_data, is_stochastic=False)
|
||||
except Exception as err:
|
||||
if self.logger:
|
||||
@@ -127,6 +142,11 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
return self.algorithm.learn(list_sample_data)
|
||||
|
||||
def estimate_value(self, obs_data):
|
||||
"""Estimate critic value for a processed observation."""
|
||||
_, value = self._run_model(obs_data.feature)
|
||||
return np.asarray(value, dtype=np.float32).reshape(-1)[: Config.VALUE_NUM]
|
||||
|
||||
def save_model(self, path=None, id="1"):
|
||||
"""Save model checkpoint.
|
||||
|
||||
@@ -220,3 +240,30 @@ class Agent(BaseAgent):
|
||||
if use_max:
|
||||
return int(np.argmax(probs))
|
||||
return int(np.random.choice(len(probs), p=probs))
|
||||
|
||||
def _tie_break_eval_action(self, probs, legal_action):
|
||||
"""Use a light heuristic only when evaluation probabilities are close."""
|
||||
probs = np.asarray(probs, dtype=np.float64)
|
||||
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||||
if not np.any(legal):
|
||||
legal = np.ones(Config.ACTION_NUM, dtype=bool)
|
||||
legal_indices = np.flatnonzero(legal)
|
||||
best_action = int(legal_indices[np.argmax(probs[legal_indices])])
|
||||
best_prob = float(probs[best_action])
|
||||
candidates = [
|
||||
int(action)
|
||||
for action in legal_indices
|
||||
if best_prob - float(probs[int(action)]) <= Config.EVAL_TIE_BREAK_PROB_GAP
|
||||
]
|
||||
if len(candidates) <= 1:
|
||||
return best_action
|
||||
|
||||
scored = []
|
||||
for action in candidates:
|
||||
heuristic = 0.0
|
||||
if hasattr(self.preprocessor, "evaluation_action_score"):
|
||||
heuristic = self.preprocessor.evaluation_action_score(action)
|
||||
combined = float(probs[action]) + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic
|
||||
scored.append((combined, float(probs[action]), -action, action))
|
||||
scored.sort(reverse=True)
|
||||
return int(scored[0][3])
|
||||
|
||||
@@ -13,11 +13,13 @@ Configuration for Robot Vacuum PPO agent.
|
||||
|
||||
class Config:
|
||||
|
||||
# Feature dimensions (157D)
|
||||
# 特征维度(157D)
|
||||
# Feature dimensions: 21x21x6 local map + scalar planning features + last action.
|
||||
# 特征维度:21x21x6 多通道局部地图 + 标量规划特征 + 上一步动作。
|
||||
VIEW_SIZE = 21
|
||||
MAP_CHANNELS = 6
|
||||
FEATURES = [
|
||||
11 * 11, # wider local map view / 更大的局部地图视野
|
||||
28, # global, charger, NPC, and map-stat features / 全局、充电桩、NPC、地图统计特征
|
||||
VIEW_SIZE * VIEW_SIZE * MAP_CHANNELS,
|
||||
66, # global memory, charger, NPC, and action-improvement features
|
||||
8, # last action one-hot / 上一步动作 one-hot
|
||||
]
|
||||
FEATURE_SPLIT_SHAPE = FEATURES
|
||||
@@ -48,6 +50,11 @@ class Config:
|
||||
NORMALIZE_ADVANTAGE = True
|
||||
TARGET_KL = 0.04
|
||||
|
||||
# Evaluation tie-break: when policy probabilities are close, prefer safer
|
||||
# coverage/recharge actions with a lightweight heuristic.
|
||||
EVAL_TIE_BREAK_PROB_GAP = 0.015
|
||||
EVAL_TIE_BREAK_SCORE_SCALE = 0.01
|
||||
|
||||
LABEL_SIZE_LIST = [ACTION_NUM]
|
||||
LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy()
|
||||
|
||||
|
||||
@@ -125,6 +125,10 @@ def build_monitor():
|
||||
metrics_name="recharge_escape_count",
|
||||
expr="avg(recharge_escape_count{})",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="recharge_steps",
|
||||
expr="avg(recharge_steps{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="NPC危险接近",
|
||||
@@ -172,6 +176,42 @@ def build_monitor():
|
||||
expr="avg(remaining_charge{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="动作掩码健康",
|
||||
name_en="mask_health",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="mask_final_avg",
|
||||
expr="avg(mask_final_avg{})",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="mask_one_action_steps",
|
||||
expr="avg(mask_one_action_steps{})",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="mask_two_or_less_action_steps",
|
||||
expr="avg(mask_two_or_less_action_steps{})",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="mask_zero_final_steps",
|
||||
expr="avg(mask_zero_final_steps{})",
|
||||
)
|
||||
.end_panel()
|
||||
.add_panel(
|
||||
name="回充动作掩码",
|
||||
name_en="recharge_mask",
|
||||
type="line",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="mask_recharge_active",
|
||||
expr="avg(mask_recharge_active{})",
|
||||
)
|
||||
.add_metric(
|
||||
metrics_name="mask_recharge_changed",
|
||||
expr="avg(mask_recharge_changed{})",
|
||||
)
|
||||
.end_panel()
|
||||
.end_group()
|
||||
.build()
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -39,10 +39,11 @@ class Model(nn.Module):
|
||||
self.device = device
|
||||
|
||||
map_dim, scalar_dim, last_action_dim = Config.FEATURES
|
||||
map_size = int(map_dim**0.5)
|
||||
if map_size * map_size != map_dim:
|
||||
raise ValueError(f"local map feature must be square, got {map_dim}")
|
||||
self.map_size = map_size
|
||||
self.map_size = Config.VIEW_SIZE
|
||||
self.map_channels = Config.MAP_CHANNELS
|
||||
expected_map_dim = self.map_size * self.map_size * self.map_channels
|
||||
if map_dim != expected_map_dim:
|
||||
raise ValueError(f"local map feature must be {expected_map_dim}, got {map_dim}")
|
||||
self.map_dim = map_dim
|
||||
self.scalar_dim = scalar_dim + last_action_dim
|
||||
act_num = Config.ACTION_NUM # 8
|
||||
@@ -50,11 +51,13 @@ class Model(nn.Module):
|
||||
# Local map encoder keeps spatial obstacle/dirt patterns.
|
||||
# 局部地图编码器保留障碍/污渍空间结构。
|
||||
self.map_encoder = nn.Sequential(
|
||||
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
||||
nn.Conv2d(self.map_channels, 24, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||
nn.Conv2d(24, 48, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d((3, 3)),
|
||||
nn.Conv2d(48, 48, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d((4, 4)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
|
||||
@@ -67,7 +70,7 @@ class Model(nn.Module):
|
||||
|
||||
# Shared fusion backbone / 共享融合骨干网络
|
||||
self.backbone = nn.Sequential(
|
||||
_make_fc(32 * 3 * 3 + 64, 256),
|
||||
_make_fc(48 * 4 * 4 + 64, 256),
|
||||
nn.ReLU(),
|
||||
_make_fc(256, 128),
|
||||
nn.ReLU(),
|
||||
@@ -85,7 +88,7 @@ class Model(nn.Module):
|
||||
前向传播。
|
||||
"""
|
||||
x = s.to(torch.float32)
|
||||
local_map = x[:, : self.map_dim].view(-1, 1, self.map_size, self.map_size)
|
||||
local_map = x[:, : self.map_dim].view(-1, self.map_channels, self.map_size, self.map_size)
|
||||
scalar = x[:, self.map_dim :]
|
||||
map_h = self.map_encoder(local_map)
|
||||
scalar_h = self.scalar_encoder(scalar)
|
||||
|
||||
@@ -26,6 +26,8 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
last_save_model_time = time.time()
|
||||
env = envs[0]
|
||||
agent = agents[0]
|
||||
diag_max_episodes = _read_diag_max_episodes(logger)
|
||||
diag_log_only = _read_bool_env("ROBOT_VACUUM_DIAG_LOG_ONLY")
|
||||
|
||||
# Read and validate user configuration
|
||||
# 读取和校验用户配置
|
||||
@@ -33,6 +35,7 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
if usr_conf is None:
|
||||
logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
|
||||
return
|
||||
_apply_diag_env_overrides(usr_conf, logger)
|
||||
|
||||
episode_runner = EpisodeRunner(
|
||||
env=env,
|
||||
@@ -40,9 +43,11 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
usr_conf=usr_conf,
|
||||
logger=logger,
|
||||
monitor=monitor,
|
||||
diag_max_episodes=diag_max_episodes,
|
||||
diag_log_only=diag_log_only,
|
||||
)
|
||||
|
||||
while True:
|
||||
while not episode_runner.stop_requested:
|
||||
for g_data in episode_runner.run_episodes():
|
||||
agent.send_sample_data(g_data)
|
||||
g_data.clear()
|
||||
@@ -51,10 +56,56 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
||||
if now - last_save_model_time >= 1800:
|
||||
agent.save_model()
|
||||
last_save_model_time = now
|
||||
if episode_runner.stop_requested:
|
||||
break
|
||||
|
||||
if episode_runner.stop_requested:
|
||||
logger.info(f"diagnostic max episodes reached: {episode_runner.episode_cnt}")
|
||||
|
||||
|
||||
def _read_diag_max_episodes(logger):
|
||||
raw_value = os.environ.get("ROBOT_VACUUM_DIAG_MAX_EPISODES", "").strip()
|
||||
if not raw_value:
|
||||
return 0
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except ValueError:
|
||||
if logger:
|
||||
logger.warning(f"ignore invalid ROBOT_VACUUM_DIAG_MAX_EPISODES={raw_value!r}")
|
||||
return 0
|
||||
return max(value, 0)
|
||||
|
||||
|
||||
def _read_positive_int_env(name, logger):
|
||||
raw_value = os.environ.get(name, "").strip()
|
||||
if not raw_value:
|
||||
return 0
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except ValueError:
|
||||
if logger:
|
||||
logger.warning(f"ignore invalid {name}={raw_value!r}")
|
||||
return 0
|
||||
return max(value, 0)
|
||||
|
||||
|
||||
def _read_bool_env(name):
|
||||
return os.environ.get(name, "").strip().lower() in ("1", "true", "yes", "on")
|
||||
|
||||
|
||||
def _apply_diag_env_overrides(usr_conf, logger):
|
||||
diag_max_step = _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_STEP", logger)
|
||||
if diag_max_step <= 0:
|
||||
return
|
||||
env_conf = usr_conf.setdefault("env_conf", {})
|
||||
old_max_step = env_conf.get("max_step")
|
||||
env_conf["max_step"] = diag_max_step
|
||||
if logger:
|
||||
logger.info(f"diagnostic max_step override: {old_max_step} -> {diag_max_step}")
|
||||
|
||||
|
||||
class EpisodeRunner:
|
||||
def __init__(self, env, agent, usr_conf, logger, monitor):
|
||||
def __init__(self, env, agent, usr_conf, logger, monitor, diag_max_episodes=0, diag_log_only=False):
|
||||
self.env = env
|
||||
self.agent = agent
|
||||
self.usr_conf = usr_conf
|
||||
@@ -63,6 +114,9 @@ class EpisodeRunner:
|
||||
self.episode_cnt = 0
|
||||
self.last_report_monitor_time = 0
|
||||
self.last_get_training_metrics_time = 0
|
||||
self.diag_max_episodes = int(diag_max_episodes)
|
||||
self.diag_log_only = bool(diag_log_only)
|
||||
self.stop_requested = False
|
||||
|
||||
def run_episodes(self):
|
||||
"""Run a single episode and yield collected samples.
|
||||
@@ -70,6 +124,8 @@ class EpisodeRunner:
|
||||
单局流程(generator),完成一局后 yield 整局样本。
|
||||
"""
|
||||
while True:
|
||||
if self.stop_requested:
|
||||
return
|
||||
# Periodically get training metrics
|
||||
# 定期打印训练指标
|
||||
now = time.time()
|
||||
@@ -158,9 +214,9 @@ class EpisodeRunner:
|
||||
result_str = "WIN"
|
||||
else:
|
||||
if fm.battery <= 0 or remaining_charge <= 0:
|
||||
final_reward = -4.0 + 6.0 * cleaning_ratio
|
||||
final_reward = -fm.battery_fail_penalty() + 4.0 * cleaning_ratio
|
||||
result_str = "BATTERY_FAIL"
|
||||
elif fm.npc_danger or fm.nearest_npc_dist <= 1:
|
||||
elif fm.npc_danger or fm.npc_predicted_danger or fm.nearest_npc_dist <= 1:
|
||||
final_reward = -3.0 + 6.0 * cleaning_ratio
|
||||
result_str = "NPC_FAIL"
|
||||
else:
|
||||
@@ -188,6 +244,39 @@ class EpisodeRunner:
|
||||
f"result_code:{result_code} "
|
||||
f"result_message:{result_message}"
|
||||
)
|
||||
diag = fm.get_diagnostic_summary()
|
||||
self.logger.info(
|
||||
f"[DIAG] ep:{self.episode_cnt} map:{diag['map_id']} "
|
||||
f"steps:{step} result:{result_str} "
|
||||
f"profile:{diag['reward_profile']} route:{diag['charger_route_source']} "
|
||||
f"score:{float(total_score):.1f} reward:{total_reward + final_reward:.3f} "
|
||||
f"mask_avg(raw/block/npc/recharge/escape/leave/final):"
|
||||
f"{diag['avg_mask_counts']['raw']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['blocked']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['npc']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['recharge']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['escape']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['leave']:.2f}/"
|
||||
f"{diag['avg_mask_counts']['final']:.2f} "
|
||||
f"mask_changed(block/npc/recharge/escape/leave):"
|
||||
f"{diag['mask_changed_steps']['blocked']}/"
|
||||
f"{diag['mask_changed_steps']['npc']}/"
|
||||
f"{diag['mask_changed_steps']['recharge']}/"
|
||||
f"{diag['mask_changed_steps']['escape']}/"
|
||||
f"{diag['mask_changed_steps']['leave']} "
|
||||
f"mask_active(recharge/leave):"
|
||||
f"{diag['mask_active_steps']['recharge']}/"
|
||||
f"{diag['mask_active_steps']['leave']} "
|
||||
f"tight(one/<=2/zero):"
|
||||
f"{diag['one_action_steps']}/"
|
||||
f"{diag['two_or_less_action_steps']}/"
|
||||
f"{diag['zero_final_steps']} "
|
||||
f"actions:{diag['action_hist']} "
|
||||
f"known:{diag['known_ratio']:.3f} dirty_known:{diag['known_dirty_ratio']:.3f} "
|
||||
f"frontier:{diag['frontier_ratio']:.3f} "
|
||||
f"path_dirty/frontier:{diag['global_dirty_path_dist']:.1f}/"
|
||||
f"{diag['frontier_path_dist']:.1f}"
|
||||
)
|
||||
|
||||
# Build sample frame
|
||||
# 构造样本帧
|
||||
@@ -212,6 +301,9 @@ class EpisodeRunner:
|
||||
# Add terminal reward to last frame
|
||||
# 终局奖励叠加到最后一步
|
||||
collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32)
|
||||
if truncated and not terminated:
|
||||
collector[-1].next_value = self.agent.estimate_value(_obs_data)
|
||||
collector[-1].done = np.array([0.0], dtype=np.float32)
|
||||
|
||||
# Monitor reporting / 监控上报
|
||||
now = time.time()
|
||||
@@ -231,6 +323,13 @@ class EpisodeRunner:
|
||||
"battery_fail": float(fm.battery_fail),
|
||||
"charge_count": float(charge_count),
|
||||
"remaining_charge": float(remaining_charge),
|
||||
"recharge_steps": float(fm.recharge_steps),
|
||||
"mask_final_avg": float(diag["avg_mask_counts"]["final"]),
|
||||
"mask_recharge_active": float(diag["mask_active_steps"]["recharge"]),
|
||||
"mask_recharge_changed": float(diag["mask_changed_steps"]["recharge"]),
|
||||
"mask_one_action_steps": float(diag["one_action_steps"]),
|
||||
"mask_two_or_less_action_steps": float(diag["two_or_less_action_steps"]),
|
||||
"mask_zero_final_steps": float(diag["zero_final_steps"]),
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -239,6 +338,13 @@ class EpisodeRunner:
|
||||
# Compute GAE and yield samples
|
||||
# GAE 计算并 yield 样本
|
||||
if collector:
|
||||
if self.diag_max_episodes > 0 and self.episode_cnt >= self.diag_max_episodes:
|
||||
self.stop_requested = True
|
||||
if self.diag_log_only:
|
||||
collector.clear()
|
||||
if self.stop_requested:
|
||||
return
|
||||
break
|
||||
collector = sample_process(collector)
|
||||
yield collector
|
||||
break
|
||||
|
||||
@@ -21,9 +21,9 @@ if __name__ == "__main__":
|
||||
algorithm_name=algorithm_name,
|
||||
algorithm_name_list=algorithm_name_list,
|
||||
env_vars={
|
||||
"replay_buffer_capacity": "10",
|
||||
"preload_ratio": "0.2",
|
||||
"replay_buffer_capacity": "8",
|
||||
"preload_ratio": "0.1",
|
||||
"train_batch_size": "2",
|
||||
"dump_model_freq": "1",
|
||||
"dump_model_freq": "100",
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user