Compare commits
2 Commits
5b6133db13
...
524ca8c070
| Author | SHA1 | Date | |
|---|---|---|---|
| 524ca8c070 | |||
| 69b8a692db |
@@ -76,6 +76,8 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
action = act_data.action if is_stochastic else act_data.d_action
|
action = act_data.action if is_stochastic else act_data.d_action
|
||||||
self.last_action = int(action[0])
|
self.last_action = int(action[0])
|
||||||
|
if hasattr(self.preprocessor, "record_action"):
|
||||||
|
self.preprocessor.record_action(self.last_action)
|
||||||
return self.last_action
|
return self.last_action
|
||||||
|
|
||||||
def predict(self, list_obs_data):
|
def predict(self, list_obs_data):
|
||||||
@@ -110,7 +112,16 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
obs_data, _ = self.observation_process(env_obs)
|
obs_data, _ = self.observation_process(env_obs)
|
||||||
act_data = self.predict([obs_data])[0]
|
logits, value = self._run_model(obs_data.feature)
|
||||||
|
legal_arr = np.array(obs_data.legal_action, dtype=np.float32)
|
||||||
|
prob = self._legal_soft_max(logits, legal_arr)
|
||||||
|
action = self._tie_break_eval_action(prob, legal_arr)
|
||||||
|
act_data = ActData(
|
||||||
|
action=[action],
|
||||||
|
d_action=[action],
|
||||||
|
prob=list(prob),
|
||||||
|
value=value,
|
||||||
|
)
|
||||||
return self.action_process(act_data, is_stochastic=False)
|
return self.action_process(act_data, is_stochastic=False)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
if self.logger:
|
if self.logger:
|
||||||
@@ -127,6 +138,11 @@ class Agent(BaseAgent):
|
|||||||
"""
|
"""
|
||||||
return self.algorithm.learn(list_sample_data)
|
return self.algorithm.learn(list_sample_data)
|
||||||
|
|
||||||
|
def estimate_value(self, obs_data):
|
||||||
|
"""Estimate critic value for a processed observation."""
|
||||||
|
_, value = self._run_model(obs_data.feature)
|
||||||
|
return np.asarray(value, dtype=np.float32).reshape(-1)[: Config.VALUE_NUM]
|
||||||
|
|
||||||
def save_model(self, path=None, id="1"):
|
def save_model(self, path=None, id="1"):
|
||||||
"""Save model checkpoint.
|
"""Save model checkpoint.
|
||||||
|
|
||||||
@@ -220,3 +236,30 @@ class Agent(BaseAgent):
|
|||||||
if use_max:
|
if use_max:
|
||||||
return int(np.argmax(probs))
|
return int(np.argmax(probs))
|
||||||
return int(np.random.choice(len(probs), p=probs))
|
return int(np.random.choice(len(probs), p=probs))
|
||||||
|
|
||||||
|
def _tie_break_eval_action(self, probs, legal_action):
|
||||||
|
"""Use a light heuristic only when evaluation probabilities are close."""
|
||||||
|
probs = np.asarray(probs, dtype=np.float64)
|
||||||
|
legal = np.asarray(legal_action, dtype=np.float32) > 0.5
|
||||||
|
if not np.any(legal):
|
||||||
|
legal = np.ones(Config.ACTION_NUM, dtype=bool)
|
||||||
|
legal_indices = np.flatnonzero(legal)
|
||||||
|
best_action = int(legal_indices[np.argmax(probs[legal_indices])])
|
||||||
|
best_prob = float(probs[best_action])
|
||||||
|
candidates = [
|
||||||
|
int(action)
|
||||||
|
for action in legal_indices
|
||||||
|
if best_prob - float(probs[int(action)]) <= Config.EVAL_TIE_BREAK_PROB_GAP
|
||||||
|
]
|
||||||
|
if len(candidates) <= 1:
|
||||||
|
return best_action
|
||||||
|
|
||||||
|
scored = []
|
||||||
|
for action in candidates:
|
||||||
|
heuristic = 0.0
|
||||||
|
if hasattr(self.preprocessor, "evaluation_action_score"):
|
||||||
|
heuristic = self.preprocessor.evaluation_action_score(action)
|
||||||
|
combined = float(probs[action]) + Config.EVAL_TIE_BREAK_SCORE_SCALE * heuristic
|
||||||
|
scored.append((combined, float(probs[action]), -action, action))
|
||||||
|
scored.sort(reverse=True)
|
||||||
|
return int(scored[0][3])
|
||||||
|
|||||||
@@ -50,6 +50,11 @@ class Config:
|
|||||||
NORMALIZE_ADVANTAGE = True
|
NORMALIZE_ADVANTAGE = True
|
||||||
TARGET_KL = 0.04
|
TARGET_KL = 0.04
|
||||||
|
|
||||||
|
# Evaluation tie-break: when policy probabilities are close, prefer safer
|
||||||
|
# coverage/recharge actions with a lightweight heuristic.
|
||||||
|
EVAL_TIE_BREAK_PROB_GAP = 0.015
|
||||||
|
EVAL_TIE_BREAK_SCORE_SCALE = 0.01
|
||||||
|
|
||||||
LABEL_SIZE_LIST = [ACTION_NUM]
|
LABEL_SIZE_LIST = [ACTION_NUM]
|
||||||
LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy()
|
LEGAL_ACTION_SIZE_LIST = LABEL_SIZE_LIST.copy()
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,10 @@ def build_monitor():
|
|||||||
metrics_name="recharge_escape_count",
|
metrics_name="recharge_escape_count",
|
||||||
expr="avg(recharge_escape_count{})",
|
expr="avg(recharge_escape_count{})",
|
||||||
)
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="recharge_steps",
|
||||||
|
expr="avg(recharge_steps{})",
|
||||||
|
)
|
||||||
.end_panel()
|
.end_panel()
|
||||||
.add_panel(
|
.add_panel(
|
||||||
name="NPC危险接近",
|
name="NPC危险接近",
|
||||||
@@ -172,6 +176,42 @@ def build_monitor():
|
|||||||
expr="avg(remaining_charge{})",
|
expr="avg(remaining_charge{})",
|
||||||
)
|
)
|
||||||
.end_panel()
|
.end_panel()
|
||||||
|
.add_panel(
|
||||||
|
name="动作掩码健康",
|
||||||
|
name_en="mask_health",
|
||||||
|
type="line",
|
||||||
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="mask_final_avg",
|
||||||
|
expr="avg(mask_final_avg{})",
|
||||||
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="mask_one_action_steps",
|
||||||
|
expr="avg(mask_one_action_steps{})",
|
||||||
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="mask_two_or_less_action_steps",
|
||||||
|
expr="avg(mask_two_or_less_action_steps{})",
|
||||||
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="mask_zero_final_steps",
|
||||||
|
expr="avg(mask_zero_final_steps{})",
|
||||||
|
)
|
||||||
|
.end_panel()
|
||||||
|
.add_panel(
|
||||||
|
name="回充动作掩码",
|
||||||
|
name_en="recharge_mask",
|
||||||
|
type="line",
|
||||||
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="mask_recharge_active",
|
||||||
|
expr="avg(mask_recharge_active{})",
|
||||||
|
)
|
||||||
|
.add_metric(
|
||||||
|
metrics_name="mask_recharge_changed",
|
||||||
|
expr="avg(mask_recharge_changed{})",
|
||||||
|
)
|
||||||
|
.end_panel()
|
||||||
.end_group()
|
.end_group()
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ Feature preprocessor for Robot Vacuum.
|
|||||||
清扫大作战特征预处理器。
|
清扫大作战特征预处理器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -70,6 +71,7 @@ class Preprocessor:
|
|||||||
|
|
||||||
对局开始时重置所有状态。
|
对局开始时重置所有状态。
|
||||||
"""
|
"""
|
||||||
|
self.map_id = -1
|
||||||
self.step_no = 0
|
self.step_no = 0
|
||||||
self.battery = 600
|
self.battery = 600
|
||||||
self.battery_max = 600
|
self.battery_max = 600
|
||||||
@@ -118,6 +120,7 @@ class Preprocessor:
|
|||||||
self.frontier_action_delta = np.zeros(8, dtype=np.float32)
|
self.frontier_action_delta = np.zeros(8, dtype=np.float32)
|
||||||
self.charger_action_delta = np.zeros(8, dtype=np.float32)
|
self.charger_action_delta = np.zeros(8, dtype=np.float32)
|
||||||
self.charger_route_known = False
|
self.charger_route_known = False
|
||||||
|
self.charger_route_source = "none"
|
||||||
|
|
||||||
# Nearest dirt path distance in the current local view.
|
# Nearest dirt path distance in the current local view.
|
||||||
# 当前局部视野内最近污渍路径距离。
|
# 当前局部视野内最近污渍路径距离。
|
||||||
@@ -181,6 +184,36 @@ class Preprocessor:
|
|||||||
|
|
||||||
self.local_dirt_ratio = 0.0
|
self.local_dirt_ratio = 0.0
|
||||||
self.local_obstacle_ratio = 0.0
|
self.local_obstacle_ratio = 0.0
|
||||||
|
self.reward_profile = os.environ.get("ROBOT_VACUUM_REWARD_PROFILE", "current").strip().lower() or "current"
|
||||||
|
self._reset_diagnostics()
|
||||||
|
|
||||||
|
def _reset_diagnostics(self):
|
||||||
|
"""Reset episode-local diagnostic counters."""
|
||||||
|
self.diag_mask_steps = 0
|
||||||
|
self.diag_mask_count_sums = {
|
||||||
|
"raw": 0,
|
||||||
|
"blocked": 0,
|
||||||
|
"npc": 0,
|
||||||
|
"recharge": 0,
|
||||||
|
"escape": 0,
|
||||||
|
"leave": 0,
|
||||||
|
"final": 0,
|
||||||
|
}
|
||||||
|
self.diag_mask_changed_steps = {
|
||||||
|
"blocked": 0,
|
||||||
|
"npc": 0,
|
||||||
|
"recharge": 0,
|
||||||
|
"escape": 0,
|
||||||
|
"leave": 0,
|
||||||
|
}
|
||||||
|
self.diag_mask_active_steps = {
|
||||||
|
"recharge": 0,
|
||||||
|
"leave": 0,
|
||||||
|
}
|
||||||
|
self.diag_one_action_steps = 0
|
||||||
|
self.diag_two_or_less_action_steps = 0
|
||||||
|
self.diag_zero_final_steps = 0
|
||||||
|
self.diag_action_hist = [0] * 8
|
||||||
|
|
||||||
def pb2struct(self, env_obs, last_action):
|
def pb2struct(self, env_obs, last_action):
|
||||||
"""Parse and cache essential fields from observation dict.
|
"""Parse and cache essential fields from observation dict.
|
||||||
@@ -199,6 +232,11 @@ class Preprocessor:
|
|||||||
hero = _as_dict(hero)
|
hero = _as_dict(hero)
|
||||||
|
|
||||||
self.last_action = int(last_action)
|
self.last_action = int(last_action)
|
||||||
|
map_id_value = extra_info.get("map_id", env_info.get("map_id", self.map_id))
|
||||||
|
try:
|
||||||
|
self.map_id = int(map_id_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
self.step_no = int(observation.get("step_no", env_info.get("step_no", self.step_no)))
|
||||||
self.terminated = bool(env_obs.get("terminated", False))
|
self.terminated = bool(env_obs.get("terminated", False))
|
||||||
self.truncated = bool(env_obs.get("truncated", False))
|
self.truncated = bool(env_obs.get("truncated", False))
|
||||||
@@ -387,6 +425,7 @@ class Preprocessor:
|
|||||||
self.charger_energy_cost = self.nearest_charger_path_dist
|
self.charger_energy_cost = self.nearest_charger_path_dist
|
||||||
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
self.battery_margin = float(self.battery) - self.nearest_charger_path_dist
|
||||||
self.charger_route_known = True
|
self.charger_route_known = True
|
||||||
|
self.charger_route_source = "global"
|
||||||
|
|
||||||
self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist)
|
self.global_dirty_action_delta = self._action_distance_delta(dirty_dist, self.global_dirty_path_dist)
|
||||||
self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist)
|
self.frontier_action_delta = self._action_distance_delta(frontier_dist, self.frontier_path_dist)
|
||||||
@@ -513,6 +552,7 @@ class Preprocessor:
|
|||||||
self.charger_safety_margin = 0.0
|
self.charger_safety_margin = 0.0
|
||||||
self.charger_rects = []
|
self.charger_rects = []
|
||||||
self.charger_route_known = False
|
self.charger_route_known = False
|
||||||
|
self.charger_route_source = "none"
|
||||||
|
|
||||||
best = None
|
best = None
|
||||||
for organ in organs:
|
for organ in organs:
|
||||||
@@ -550,9 +590,17 @@ class Preprocessor:
|
|||||||
self.nearest_charger_dist = float(dist)
|
self.nearest_charger_dist = float(dist)
|
||||||
self.nearest_charger_range_dist = float(range_dist)
|
self.nearest_charger_range_dist = float(range_dist)
|
||||||
path_dist = self._global_path_dist_to_charger(hx, hz)
|
path_dist = self._global_path_dist_to_charger(hx, hz)
|
||||||
self.charger_route_known = path_dist < self.INF_DIST
|
if path_dist < self.INF_DIST:
|
||||||
if not self.charger_route_known:
|
self.charger_route_known = True
|
||||||
|
self.charger_route_source = "global"
|
||||||
|
else:
|
||||||
path_dist = self._local_path_dist_to_charger(hx, hz)
|
path_dist = self._local_path_dist_to_charger(hx, hz)
|
||||||
|
if path_dist < self.INF_DIST:
|
||||||
|
self.charger_route_known = True
|
||||||
|
self.charger_route_source = "local"
|
||||||
|
else:
|
||||||
|
self.charger_route_known = False
|
||||||
|
self.charger_route_source = "range"
|
||||||
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
|
self.nearest_charger_path_dist = float(path_dist if path_dist < self.INF_DIST else range_dist)
|
||||||
self.charger_energy_cost = self.nearest_charger_path_dist
|
self.charger_energy_cost = self.nearest_charger_path_dist
|
||||||
self.on_charger = range_dist <= 0.0
|
self.on_charger = range_dist <= 0.0
|
||||||
@@ -711,18 +759,26 @@ class Preprocessor:
|
|||||||
|
|
||||||
def _charger_safety_buffer(self):
|
def _charger_safety_buffer(self):
|
||||||
# One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise.
|
# One move roughly costs one charge; reserve extra for detours, local obstacles, and policy noise.
|
||||||
base = max(18.0, 0.12 * float(self.battery_max))
|
base = max(22.0, 0.14 * float(self.battery_max))
|
||||||
distance_buffer = min(16.0, 0.18 * float(max(self.nearest_charger_range_dist, 0.0)))
|
distance_buffer = min(18.0, 0.20 * float(max(self.nearest_charger_range_dist, 0.0)))
|
||||||
obstacle_buffer = 12.0 * float(self.local_obstacle_ratio)
|
obstacle_buffer = 14.0 * float(self.local_obstacle_ratio)
|
||||||
return float(np.clip(base + distance_buffer + obstacle_buffer, 18.0, 48.0))
|
route_uncertainty_buffer = 10.0 if self.has_charger and not self.charger_route_known else 0.0
|
||||||
|
return float(np.clip(base + distance_buffer + obstacle_buffer + route_uncertainty_buffer, 22.0, 58.0))
|
||||||
|
|
||||||
def _recharge_enter_margin(self):
|
def _recharge_enter_margin(self):
|
||||||
"""Adaptive margin for entering recharge mode before the battery is barely enough."""
|
"""Adaptive margin for entering recharge mode before the battery is barely enough."""
|
||||||
base = max(5.0, 0.018 * float(self.battery_max))
|
base = max(7.0, 0.025 * float(self.battery_max))
|
||||||
path_margin = min(12.0, 0.08 * float(max(self.nearest_charger_path_dist, 0.0)))
|
path_margin = min(14.0, 0.10 * float(max(self.nearest_charger_path_dist, 0.0)))
|
||||||
obstacle_margin = 12.0 * float(self.local_obstacle_ratio)
|
obstacle_margin = 14.0 * float(self.local_obstacle_ratio)
|
||||||
|
route_uncertainty_margin = 8.0 if self.has_charger and not self.charger_route_known else 0.0
|
||||||
recovery_margin = min(8.0, 1.5 * float(self.recharge_no_progress_steps + self.fake_charger_steps))
|
recovery_margin = min(8.0, 1.5 * float(self.recharge_no_progress_steps + self.fake_charger_steps))
|
||||||
return float(np.clip(base + path_margin + obstacle_margin + recovery_margin, 4.0, 32.0))
|
return float(
|
||||||
|
np.clip(
|
||||||
|
base + path_margin + obstacle_margin + route_uncertainty_margin + recovery_margin,
|
||||||
|
6.0,
|
||||||
|
42.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def _recharge_leave_margin(self):
|
def _recharge_leave_margin(self):
|
||||||
"""Adaptive safety margin required before leaving a charger."""
|
"""Adaptive safety margin required before leaving a charger."""
|
||||||
@@ -734,10 +790,12 @@ class Preprocessor:
|
|||||||
def _recharge_low_battery_ratio(self):
|
def _recharge_low_battery_ratio(self):
|
||||||
"""Adaptive low-battery ratio based on route length and local obstacle density."""
|
"""Adaptive low-battery ratio based on route length and local obstacle density."""
|
||||||
path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0)
|
path_pressure = float(max(self.nearest_charger_path_dist, 0.0)) / max(float(self.battery_max), 1.0)
|
||||||
ratio = 0.25 + min(0.08, 0.40 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio))
|
ratio = 0.32 + min(0.09, 0.42 * path_pressure) + min(0.04, 0.14 * float(self.local_obstacle_ratio))
|
||||||
|
if self.has_charger and not self.charger_route_known:
|
||||||
|
ratio += 0.04
|
||||||
if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0:
|
if self.recharge_no_progress_steps > 0 or self.fake_charger_steps > 0:
|
||||||
ratio += 0.02
|
ratio += 0.02
|
||||||
return float(np.clip(ratio, 0.25, 0.40))
|
return float(np.clip(ratio, 0.32, 0.46))
|
||||||
|
|
||||||
def _full_charge_leave_ratio(self):
|
def _full_charge_leave_ratio(self):
|
||||||
"""Adaptive near-full threshold for leaving a charger."""
|
"""Adaptive near-full threshold for leaving a charger."""
|
||||||
@@ -774,7 +832,10 @@ class Preprocessor:
|
|||||||
early_fail_risk = 1.0 - step_ratio
|
early_fail_risk = 1.0 - step_ratio
|
||||||
path_pressure = float(max(self.charger_energy_cost, 0.0)) / max(float(self.battery_max), 1.0)
|
path_pressure = float(max(self.charger_energy_cost, 0.0)) / max(float(self.battery_max), 1.0)
|
||||||
risk = max(self._recharge_risk_score(), min(1.0, path_pressure))
|
risk = max(self._recharge_risk_score(), min(1.0, path_pressure))
|
||||||
return float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0))
|
penalty = float(np.clip(8.0 + 4.0 * early_fail_risk + 2.0 * risk, 8.0, 14.0))
|
||||||
|
if self.reward_profile == "battery_safe":
|
||||||
|
penalty *= 1.25
|
||||||
|
return penalty
|
||||||
|
|
||||||
def _min_charger_range_dist(self, x, z):
|
def _min_charger_range_dist(self, x, z):
|
||||||
if not self.charger_rects:
|
if not self.charger_rects:
|
||||||
@@ -1011,16 +1072,173 @@ class Preprocessor:
|
|||||||
|
|
||||||
返回合法动作掩码(8D list)。
|
返回合法动作掩码(8D list)。
|
||||||
"""
|
"""
|
||||||
legal = self._filter_blocked_actions(self._legal_act)
|
raw_legal = [int(x) for x in self._legal_act]
|
||||||
legal = self._filter_npc_danger_actions(legal)
|
blocked_legal = self._filter_blocked_actions(raw_legal)
|
||||||
safe_legal = list(legal)
|
npc_legal = self._filter_npc_danger_actions(blocked_legal)
|
||||||
|
safe_legal = list(npc_legal)
|
||||||
|
recharge_legal = None
|
||||||
|
escape_legal = None
|
||||||
|
leave_legal = None
|
||||||
|
legal = npc_legal
|
||||||
if self.recharge_mode:
|
if self.recharge_mode:
|
||||||
legal = self._filter_recharge_actions(legal)
|
recharge_legal = self._filter_recharge_actions(legal)
|
||||||
legal = self._filter_recharge_escape_actions(legal, safe_legal)
|
escape_legal = self._filter_recharge_escape_actions(recharge_legal, safe_legal)
|
||||||
|
legal = escape_legal
|
||||||
elif self.on_charger and self.battery / max(self.battery_max, 1) >= self.full_charge_leave_ratio:
|
elif self.on_charger and self.battery / max(self.battery_max, 1) >= self.full_charge_leave_ratio:
|
||||||
legal = self._filter_leave_charger_actions(legal)
|
leave_legal = self._filter_leave_charger_actions(legal)
|
||||||
|
legal = leave_legal
|
||||||
|
self._record_mask_diagnostics(
|
||||||
|
raw_legal=raw_legal,
|
||||||
|
blocked_legal=blocked_legal,
|
||||||
|
npc_legal=npc_legal,
|
||||||
|
recharge_legal=recharge_legal,
|
||||||
|
escape_legal=escape_legal,
|
||||||
|
leave_legal=leave_legal,
|
||||||
|
final_legal=legal,
|
||||||
|
)
|
||||||
return list(legal)
|
return list(legal)
|
||||||
|
|
||||||
|
def record_action(self, action):
|
||||||
|
"""Record the chosen action for episode diagnostics."""
|
||||||
|
try:
|
||||||
|
action = int(action)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return
|
||||||
|
if 0 <= action < len(self.diag_action_hist):
|
||||||
|
self.diag_action_hist[action] += 1
|
||||||
|
|
||||||
|
def _record_mask_diagnostics(
|
||||||
|
self,
|
||||||
|
raw_legal,
|
||||||
|
blocked_legal,
|
||||||
|
npc_legal,
|
||||||
|
recharge_legal,
|
||||||
|
escape_legal,
|
||||||
|
leave_legal,
|
||||||
|
final_legal,
|
||||||
|
):
|
||||||
|
"""Record action-mask counts without changing mask behavior."""
|
||||||
|
self.diag_mask_steps += 1
|
||||||
|
stages = {
|
||||||
|
"raw": raw_legal,
|
||||||
|
"blocked": blocked_legal,
|
||||||
|
"npc": npc_legal,
|
||||||
|
"recharge": recharge_legal if recharge_legal is not None else npc_legal,
|
||||||
|
"escape": escape_legal if escape_legal is not None else (recharge_legal if recharge_legal is not None else npc_legal),
|
||||||
|
"leave": leave_legal if leave_legal is not None else npc_legal,
|
||||||
|
"final": final_legal,
|
||||||
|
}
|
||||||
|
for name, mask in stages.items():
|
||||||
|
self.diag_mask_count_sums[name] += self._mask_count(mask)
|
||||||
|
|
||||||
|
if not self._same_mask(raw_legal, blocked_legal):
|
||||||
|
self.diag_mask_changed_steps["blocked"] += 1
|
||||||
|
if not self._same_mask(blocked_legal, npc_legal):
|
||||||
|
self.diag_mask_changed_steps["npc"] += 1
|
||||||
|
if recharge_legal is not None:
|
||||||
|
self.diag_mask_active_steps["recharge"] += 1
|
||||||
|
if not self._same_mask(npc_legal, recharge_legal):
|
||||||
|
self.diag_mask_changed_steps["recharge"] += 1
|
||||||
|
if escape_legal is not None and recharge_legal is not None:
|
||||||
|
if not self._same_mask(recharge_legal, escape_legal):
|
||||||
|
self.diag_mask_changed_steps["escape"] += 1
|
||||||
|
if leave_legal is not None:
|
||||||
|
self.diag_mask_active_steps["leave"] += 1
|
||||||
|
if not self._same_mask(npc_legal, leave_legal):
|
||||||
|
self.diag_mask_changed_steps["leave"] += 1
|
||||||
|
|
||||||
|
final_count = self._mask_count(final_legal)
|
||||||
|
if final_count <= 0:
|
||||||
|
self.diag_zero_final_steps += 1
|
||||||
|
if final_count == 1:
|
||||||
|
self.diag_one_action_steps += 1
|
||||||
|
if final_count <= 2:
|
||||||
|
self.diag_two_or_less_action_steps += 1
|
||||||
|
|
||||||
|
def _mask_count(self, mask):
|
||||||
|
return int(sum(1 for value in mask if int(value) > 0))
|
||||||
|
|
||||||
|
def _same_mask(self, left, right):
|
||||||
|
return [int(x) for x in left] == [int(x) for x in right]
|
||||||
|
|
||||||
|
def get_diagnostic_summary(self):
|
||||||
|
"""Return episode-level diagnostic counters for logging."""
|
||||||
|
steps = max(self.diag_mask_steps, 1)
|
||||||
|
avg_mask_counts = {
|
||||||
|
name: self.diag_mask_count_sums[name] / steps for name in sorted(self.diag_mask_count_sums)
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"map_id": self.map_id,
|
||||||
|
"mask_steps": self.diag_mask_steps,
|
||||||
|
"avg_mask_counts": avg_mask_counts,
|
||||||
|
"mask_changed_steps": dict(self.diag_mask_changed_steps),
|
||||||
|
"mask_active_steps": dict(self.diag_mask_active_steps),
|
||||||
|
"one_action_steps": self.diag_one_action_steps,
|
||||||
|
"two_or_less_action_steps": self.diag_two_or_less_action_steps,
|
||||||
|
"zero_final_steps": self.diag_zero_final_steps,
|
||||||
|
"action_hist": list(self.diag_action_hist),
|
||||||
|
"known_ratio": self.known_ratio,
|
||||||
|
"known_dirty_ratio": self.known_dirty_ratio,
|
||||||
|
"frontier_ratio": self.frontier_ratio,
|
||||||
|
"local_dirt_ratio": self.local_dirt_ratio,
|
||||||
|
"local_obstacle_ratio": self.local_obstacle_ratio,
|
||||||
|
"global_dirty_path_dist": self.global_dirty_path_dist,
|
||||||
|
"frontier_path_dist": self.frontier_path_dist,
|
||||||
|
"charger_route_source": self.charger_route_source,
|
||||||
|
"reward_profile": self.reward_profile,
|
||||||
|
}
|
||||||
|
|
||||||
|
def evaluation_action_score(self, action):
|
||||||
|
"""Heuristic score used only to break close evaluation-policy ties."""
|
||||||
|
if not (0 <= int(action) < len(self.ACTION_DIRS)):
|
||||||
|
return -1e6
|
||||||
|
action = int(action)
|
||||||
|
dx, dz = self.ACTION_DIRS[action]
|
||||||
|
hx, hz = self.cur_pos
|
||||||
|
nx, nz = hx + dx, hz + dz
|
||||||
|
if not (0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE):
|
||||||
|
return -1e6
|
||||||
|
|
||||||
|
score = 0.0
|
||||||
|
cell = self._view_cell(dx, dz, default=1)
|
||||||
|
if cell == 0:
|
||||||
|
score -= 8.0
|
||||||
|
elif cell == 2:
|
||||||
|
score += 3.0
|
||||||
|
else:
|
||||||
|
score -= 0.10
|
||||||
|
|
||||||
|
visit_count = int(self.visit_count_map[nx, nz]) if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE else 0
|
||||||
|
score += 0.35 if visit_count == 0 else -0.05 * min(visit_count, 10)
|
||||||
|
|
||||||
|
if self.recharge_mode:
|
||||||
|
if self.charger_route_known:
|
||||||
|
score += 2.2 * float(self.charger_action_delta[action])
|
||||||
|
if self._charger_move_distance(nx, nz) < self._charger_move_distance(hx, hz):
|
||||||
|
score += 0.8
|
||||||
|
else:
|
||||||
|
score += 2.0 * float(self.frontier_action_delta[action])
|
||||||
|
score += 0.7 * max(float(self.global_dirty_action_delta[action]), 0.0)
|
||||||
|
if self._min_charger_range_dist(nx, nz) < self._min_charger_range_dist(hx, hz):
|
||||||
|
score += 0.15
|
||||||
|
else:
|
||||||
|
if self.global_dirty_path_dist < self.GRID_SIZE:
|
||||||
|
score += 1.8 * float(self.global_dirty_action_delta[action])
|
||||||
|
elif self.frontier_path_dist < self.GRID_SIZE:
|
||||||
|
score += 1.4 * float(self.frontier_action_delta[action])
|
||||||
|
|
||||||
|
if self.low_battery and self.has_charger:
|
||||||
|
score += 1.2 * float(self.charger_action_delta[action])
|
||||||
|
if self._is_charger_cell(nx, nz):
|
||||||
|
score += 0.8 if self.low_battery or self.recharge_mode else -0.2
|
||||||
|
if self._is_npc_danger_cell(nx, nz, expanded=False):
|
||||||
|
score -= 6.0
|
||||||
|
elif self._is_npc_danger_cell(nx, nz, expanded=True):
|
||||||
|
score -= 1.5
|
||||||
|
if action == self.last_action and self.stuck_steps >= 1:
|
||||||
|
score -= 1.0
|
||||||
|
return float(score)
|
||||||
|
|
||||||
def _filter_blocked_actions(self, legal_action):
|
def _filter_blocked_actions(self, legal_action):
|
||||||
"""Filter actions that are visibly blocked in the 21x21 view."""
|
"""Filter actions that are visibly blocked in the 21x21 view."""
|
||||||
legal = [int(x) for x in legal_action]
|
legal = [int(x) for x in legal_action]
|
||||||
@@ -1124,21 +1342,67 @@ class Preprocessor:
|
|||||||
if any(stay):
|
if any(stay):
|
||||||
return stay
|
return stay
|
||||||
|
|
||||||
|
if not self.charger_route_known:
|
||||||
|
return self._filter_recharge_discovery_actions(legal_action, scored, current_range_dist)
|
||||||
|
|
||||||
recharge = [0] * 8
|
recharge = [0] * 8
|
||||||
best_next_dist = min(item[0] for item in scored)
|
best_next_dist = min(item[0] for item in scored)
|
||||||
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
|
ranked = sorted(scored, key=lambda item: (item[0], -item[1]))
|
||||||
for next_dist, _, _, action in ranked:
|
max_recharge_actions = 4
|
||||||
if next_dist <= best_next_dist + 2.0 and next_dist <= current_move_dist + 0.1:
|
dist_slack = 2.5
|
||||||
|
for next_dist, alignment, next_range_dist, action in ranked:
|
||||||
|
route_progress = next_dist <= current_move_dist + 0.1
|
||||||
|
if next_dist <= best_next_dist + dist_slack and route_progress:
|
||||||
recharge[action] = 1
|
recharge[action] = 1
|
||||||
if sum(recharge) >= 3:
|
if sum(recharge) >= max_recharge_actions:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not any(recharge):
|
if not any(recharge):
|
||||||
for _, _, _, action in ranked[: min(3, len(ranked))]:
|
for _, _, _, action in ranked[: min(max_recharge_actions, len(ranked))]:
|
||||||
recharge[action] = 1
|
recharge[action] = 1
|
||||||
|
|
||||||
return recharge if any(recharge) else list(legal_action)
|
return recharge if any(recharge) else list(legal_action)
|
||||||
|
|
||||||
|
def _filter_recharge_discovery_actions(self, legal_action, scored, current_range_dist):
|
||||||
|
"""When charger route is unknown, search for a route instead of pushing into walls."""
|
||||||
|
ranked = []
|
||||||
|
hx, hz = self.cur_pos
|
||||||
|
for next_dist, alignment, next_range_dist, action in scored:
|
||||||
|
if legal_action[action] <= 0:
|
||||||
|
continue
|
||||||
|
dx, dz = self.ACTION_DIRS[action]
|
||||||
|
nx, nz = hx + dx, hz + dz
|
||||||
|
visit_count = int(self.visit_count_map[nx, nz]) if 0 <= nx < self.GRID_SIZE and 0 <= nz < self.GRID_SIZE else 0
|
||||||
|
frontier_gain = float(self.frontier_action_delta[action])
|
||||||
|
dirty_gain = float(self.global_dirty_action_delta[action])
|
||||||
|
range_gain = float(np.clip(current_range_dist - next_range_dist, -2.0, 2.0)) / 2.0
|
||||||
|
alignment_gain = 0.25 if alignment > 0 else 0.0
|
||||||
|
repeat_penalty = 0.8 if action == self.last_action and self.recharge_no_progress_steps >= 2 else 0.0
|
||||||
|
wall_hug_penalty = 0.35 * float(self.local_obstacle_ratio)
|
||||||
|
score = (
|
||||||
|
2.4 * frontier_gain
|
||||||
|
+ 0.8 * max(dirty_gain, 0.0)
|
||||||
|
+ 0.35 * range_gain
|
||||||
|
+ alignment_gain
|
||||||
|
- 0.04 * min(visit_count, 12)
|
||||||
|
- repeat_penalty
|
||||||
|
- wall_hug_penalty
|
||||||
|
)
|
||||||
|
ranked.append((score, action))
|
||||||
|
|
||||||
|
if not ranked:
|
||||||
|
return list(legal_action)
|
||||||
|
|
||||||
|
ranked.sort(reverse=True)
|
||||||
|
best_score = ranked[0][0]
|
||||||
|
discovery = [0] * 8
|
||||||
|
for score, action in ranked:
|
||||||
|
if score >= best_score - 0.35 or sum(discovery) < 3:
|
||||||
|
discovery[action] = 1
|
||||||
|
if sum(discovery) >= 5:
|
||||||
|
break
|
||||||
|
return discovery if any(discovery) else list(legal_action)
|
||||||
|
|
||||||
def _filter_recharge_escape_actions(self, recharge_action, safe_action):
|
def _filter_recharge_escape_actions(self, recharge_action, safe_action):
|
||||||
"""Escape repeated no-move states during low-battery recharge."""
|
"""Escape repeated no-move states during low-battery recharge."""
|
||||||
if not self._need_recharge_escape():
|
if not self._need_recharge_escape():
|
||||||
@@ -1228,10 +1492,15 @@ class Preprocessor:
|
|||||||
return feature, legal_action, reward
|
return feature, legal_action, reward
|
||||||
|
|
||||||
def reward_process(self):
|
def reward_process(self):
|
||||||
|
cleaning_multiplier, charge_multiplier, exploration_multiplier = self._reward_profile_scales()
|
||||||
|
|
||||||
# Cleaning reward / 清扫奖励
|
# Cleaning reward / 清扫奖励
|
||||||
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
||||||
cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step
|
cleaned_cells = self.step_cleaned_count if self.step_cleaned_count > 0 else cleaned_this_step
|
||||||
cleaning_scale = 0.2 if self.recharge_mode else 0.7
|
battery_ratio = self.battery / max(self.battery_max, 1)
|
||||||
|
battery_pressure = self.has_charger and battery_ratio < self.recharge_low_battery_ratio + 0.06
|
||||||
|
cleaning_scale = 0.2 if self.recharge_mode else (0.55 if battery_pressure else 0.7)
|
||||||
|
cleaning_scale *= cleaning_multiplier
|
||||||
cleaning_reward = cleaning_scale * cleaned_cells
|
cleaning_reward = cleaning_scale * cleaned_cells
|
||||||
|
|
||||||
# Step penalty / 时间惩罚
|
# Step penalty / 时间惩罚
|
||||||
@@ -1240,7 +1509,6 @@ class Preprocessor:
|
|||||||
# Recharge guidance only activates when battery safety is the bottleneck.
|
# Recharge guidance only activates when battery safety is the bottleneck.
|
||||||
# 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。
|
# 仅在低电量/回充模式下引导靠近充电桩,避免高电量蹲充电桩。
|
||||||
charge_reward = 0.0
|
charge_reward = 0.0
|
||||||
battery_ratio = self.battery / max(self.battery_max, 1)
|
|
||||||
prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
|
prev_battery_ratio = self.prev_battery / max(self.prev_battery_max, 1)
|
||||||
useful_charge = self.charge_delta > 0 and (
|
useful_charge = self.charge_delta > 0 and (
|
||||||
self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.45
|
self.prev_low_battery or self.was_recharge_mode or prev_battery_ratio < 0.45
|
||||||
@@ -1251,10 +1519,23 @@ class Preprocessor:
|
|||||||
charge_reward -= 0.25 * min(self.charge_delta, 3)
|
charge_reward -= 0.25 * min(self.charge_delta, 3)
|
||||||
|
|
||||||
if self.has_charger and (self.recharge_mode or self.low_battery):
|
if self.has_charger and (self.recharge_mode or self.low_battery):
|
||||||
|
recharge_risk = self._recharge_risk_score()
|
||||||
|
if not self.charger_route_known:
|
||||||
|
frontier_progress = float(
|
||||||
|
np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
|
||||||
|
)
|
||||||
|
range_delta = float(
|
||||||
|
np.clip(self.last_nearest_charger_range_dist - self.nearest_charger_range_dist, -2.0, 2.0)
|
||||||
|
)
|
||||||
|
discovery_scale = 0.035 + 0.035 * recharge_risk
|
||||||
|
range_scale = 0.015 + 0.015 * recharge_risk
|
||||||
|
charge_reward += discovery_scale * frontier_progress
|
||||||
|
if self.prev_pos is not None and self.cur_pos != self.prev_pos and self.stuck_steps == 0:
|
||||||
|
charge_reward += range_scale * range_delta
|
||||||
|
else:
|
||||||
dist_delta = float(
|
dist_delta = float(
|
||||||
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
|
np.clip(self.last_nearest_charger_path_dist - self.nearest_charger_path_dist, -4.0, 4.0)
|
||||||
)
|
)
|
||||||
recharge_risk = self._recharge_risk_score()
|
|
||||||
approach_scale = 0.07 + 0.06 * recharge_risk
|
approach_scale = 0.07 + 0.06 * recharge_risk
|
||||||
retreat_scale = 0.035 + 0.045 * recharge_risk
|
retreat_scale = 0.035 + 0.045 * recharge_risk
|
||||||
charge_reward += approach_scale * dist_delta if dist_delta > 0 else retreat_scale * dist_delta
|
charge_reward += approach_scale * dist_delta if dist_delta > 0 else retreat_scale * dist_delta
|
||||||
@@ -1263,6 +1544,7 @@ class Preprocessor:
|
|||||||
charge_reward -= min(0.55, safety_shortage / max(self.battery_max, 1))
|
charge_reward -= min(0.55, safety_shortage / max(self.battery_max, 1))
|
||||||
elif self.on_charger and battery_ratio > 0.65:
|
elif self.on_charger and battery_ratio > 0.65:
|
||||||
charge_reward -= 0.08
|
charge_reward -= 0.08
|
||||||
|
charge_reward *= charge_multiplier
|
||||||
|
|
||||||
# Encourage covering new passable cells and mildly discourage loops.
|
# Encourage covering new passable cells and mildly discourage loops.
|
||||||
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
||||||
@@ -1276,6 +1558,7 @@ class Preprocessor:
|
|||||||
elif self.frontier_path_dist < self.GRID_SIZE:
|
elif self.frontier_path_dist < self.GRID_SIZE:
|
||||||
frontier_progress = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
|
frontier_progress = np.clip(self.last_frontier_path_dist - self.frontier_path_dist, -3.0, 3.0)
|
||||||
exploration_reward += 0.005 * frontier_progress
|
exploration_reward += 0.005 * frontier_progress
|
||||||
|
exploration_reward *= exploration_multiplier
|
||||||
|
|
||||||
# Collision/stuck signal: invalid moves waste both step and battery.
|
# Collision/stuck signal: invalid moves waste both step and battery.
|
||||||
# 撞墙/原地不动会浪费步数和电量。
|
# 撞墙/原地不动会浪费步数和电量。
|
||||||
@@ -1301,3 +1584,13 @@ class Preprocessor:
|
|||||||
+ npc_penalty
|
+ npc_penalty
|
||||||
+ step_penalty
|
+ step_penalty
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _reward_profile_scales(self):
    """Look up the reward-shaping multipliers for the active reward profile.

    Returns:
        A 3-tuple ``(cleaning, charge, exploration)`` of float multipliers.
        Any profile name that is not listed below yields the neutral
        ``(1.0, 1.0, 1.0)``.
    """
    # Ordered (profile name, multipliers) pairs; matched with ``==`` so the
    # comparison behaves the same for any value stored in reward_profile.
    known_profiles = (
        ("lower_recharge", (1.0, 0.70, 1.0)),
        ("clean_explore", (1.15, 0.85, 1.50)),
        ("battery_safe", (0.95, 1.25, 0.90)),
    )
    for profile_name, multipliers in known_profiles:
        if self.reward_profile == profile_name:
            return multipliers
    return 1.0, 1.0, 1.0
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
|||||||
last_save_model_time = time.time()
|
last_save_model_time = time.time()
|
||||||
env = envs[0]
|
env = envs[0]
|
||||||
agent = agents[0]
|
agent = agents[0]
|
||||||
|
diag_max_episodes = _read_diag_max_episodes(logger)
|
||||||
|
diag_log_only = _read_bool_env("ROBOT_VACUUM_DIAG_LOG_ONLY")
|
||||||
|
|
||||||
# Read and validate user configuration
|
# Read and validate user configuration
|
||||||
# 读取和校验用户配置
|
# 读取和校验用户配置
|
||||||
@@ -33,6 +35,7 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
|||||||
if usr_conf is None:
|
if usr_conf is None:
|
||||||
logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
|
logger.error("usr_conf is None, please check agent_ppo/conf/train_env_conf.toml")
|
||||||
return
|
return
|
||||||
|
_apply_diag_env_overrides(usr_conf, logger)
|
||||||
|
|
||||||
episode_runner = EpisodeRunner(
|
episode_runner = EpisodeRunner(
|
||||||
env=env,
|
env=env,
|
||||||
@@ -40,9 +43,11 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
|||||||
usr_conf=usr_conf,
|
usr_conf=usr_conf,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
monitor=monitor,
|
monitor=monitor,
|
||||||
|
diag_max_episodes=diag_max_episodes,
|
||||||
|
diag_log_only=diag_log_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while not episode_runner.stop_requested:
|
||||||
for g_data in episode_runner.run_episodes():
|
for g_data in episode_runner.run_episodes():
|
||||||
agent.send_sample_data(g_data)
|
agent.send_sample_data(g_data)
|
||||||
g_data.clear()
|
g_data.clear()
|
||||||
@@ -51,10 +56,56 @@ def workflow(envs, agents, logger=None, monitor=None, *args, **kwargs):
|
|||||||
if now - last_save_model_time >= 1800:
|
if now - last_save_model_time >= 1800:
|
||||||
agent.save_model()
|
agent.save_model()
|
||||||
last_save_model_time = now
|
last_save_model_time = now
|
||||||
|
if episode_runner.stop_requested:
|
||||||
|
break
|
||||||
|
|
||||||
|
if episode_runner.stop_requested:
|
||||||
|
logger.info(f"diagnostic max episodes reached: {episode_runner.episode_cnt}")
|
||||||
|
|
||||||
|
|
||||||
|
def _read_diag_max_episodes(logger):
    """Return the diagnostic episode cap from ``ROBOT_VACUUM_DIAG_MAX_EPISODES``.

    Delegates to ``_read_positive_int_env`` so that both diagnostic settings
    share one parsing and warning path instead of two copies of the same code.

    Args:
        logger: Object exposing ``warning(msg)``; a falsy value disables the
            warning that is emitted for an unparsable setting.

    Returns:
        The configured value as a non-negative int, or ``0`` when the
        variable is unset, blank, negative, or not a valid integer.
    """
    return _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_EPISODES", logger)


def _read_positive_int_env(name, logger):
    """Read environment variable ``name`` as a non-negative integer.

    Despite the name, the result may be ``0``: that value is returned for
    every "not configured" case, and the callers in this file only act on
    results greater than zero.

    Args:
        name: Name of the environment variable to read.
        logger: Object exposing ``warning(msg)``; a falsy value disables the
            warning that is emitted for an unparsable setting.

    Returns:
        ``int(value)`` clamped to be at least ``0``; ``0`` when the variable
        is unset, blank, or cannot be parsed by ``int``.
    """
    raw_value = os.environ.get(name, "").strip()
    if not raw_value:
        return 0
    try:
        value = int(raw_value)
    except ValueError:
        # A malformed setting is reported but never aborts the run.
        if logger:
            logger.warning(f"ignore invalid {name}={raw_value!r}")
        return 0
    # Negative numbers are treated the same as "not configured".
    return max(value, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_bool_env(name, *, default=False):
    """Interpret environment variable ``name`` as a boolean flag.

    Args:
        name: Name of the environment variable to read.
        default: Result to use when the variable is unset or contains only
            whitespace. Defaults to ``False``, which keeps the behaviour of
            the original single-argument form.

    Returns:
        ``True`` when the stripped, lower-cased value is one of ``"1"``,
        ``"true"``, ``"yes"`` or ``"on"``; ``bool(default)`` when the
        variable is unset or blank; ``False`` for any other value.
    """
    raw_value = os.environ.get(name, "").strip().lower()
    if not raw_value:
        return bool(default)
    return raw_value in ("1", "true", "yes", "on")
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_diag_env_overrides(usr_conf, logger):
    """Override ``usr_conf["env_conf"]["max_step"]`` from the environment.

    Reads ``ROBOT_VACUUM_DIAG_MAX_STEP``; when it yields a value greater than
    zero, that value replaces ``max_step`` in place (creating the
    ``env_conf`` mapping if it is missing) and the change is logged.
    Otherwise ``usr_conf`` is left untouched.

    Args:
        usr_conf: Mutable configuration mapping; modified in place.
        logger: Object exposing ``info(msg)``; a falsy value disables logging.
            It is also forwarded to the environment reader.
    """
    diag_max_step = _read_positive_int_env("ROBOT_VACUUM_DIAG_MAX_STEP", logger)
    if diag_max_step > 0:
        env_section = usr_conf.setdefault("env_conf", {})
        # Capture the previous value before overwriting so the log shows both.
        old_max_step = env_section.get("max_step")
        env_section["max_step"] = diag_max_step
        if logger:
            logger.info(f"diagnostic max_step override: {old_max_step} -> {diag_max_step}")
|
||||||
|
|
||||||
|
|
||||||
class EpisodeRunner:
|
class EpisodeRunner:
|
||||||
def __init__(self, env, agent, usr_conf, logger, monitor):
|
def __init__(self, env, agent, usr_conf, logger, monitor, diag_max_episodes=0, diag_log_only=False):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.usr_conf = usr_conf
|
self.usr_conf = usr_conf
|
||||||
@@ -63,6 +114,9 @@ class EpisodeRunner:
|
|||||||
self.episode_cnt = 0
|
self.episode_cnt = 0
|
||||||
self.last_report_monitor_time = 0
|
self.last_report_monitor_time = 0
|
||||||
self.last_get_training_metrics_time = 0
|
self.last_get_training_metrics_time = 0
|
||||||
|
self.diag_max_episodes = int(diag_max_episodes)
|
||||||
|
self.diag_log_only = bool(diag_log_only)
|
||||||
|
self.stop_requested = False
|
||||||
|
|
||||||
def run_episodes(self):
|
def run_episodes(self):
|
||||||
"""Run a single episode and yield collected samples.
|
"""Run a single episode and yield collected samples.
|
||||||
@@ -70,6 +124,8 @@ class EpisodeRunner:
|
|||||||
单局流程(generator),完成一局后 yield 整局样本。
|
单局流程(generator),完成一局后 yield 整局样本。
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
|
if self.stop_requested:
|
||||||
|
return
|
||||||
# Periodically get training metrics
|
# Periodically get training metrics
|
||||||
# 定期打印训练指标
|
# 定期打印训练指标
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@@ -188,6 +244,39 @@ class EpisodeRunner:
|
|||||||
f"result_code:{result_code} "
|
f"result_code:{result_code} "
|
||||||
f"result_message:{result_message}"
|
f"result_message:{result_message}"
|
||||||
)
|
)
|
||||||
|
diag = fm.get_diagnostic_summary()
|
||||||
|
self.logger.info(
|
||||||
|
f"[DIAG] ep:{self.episode_cnt} map:{diag['map_id']} "
|
||||||
|
f"steps:{step} result:{result_str} "
|
||||||
|
f"profile:{diag['reward_profile']} route:{diag['charger_route_source']} "
|
||||||
|
f"score:{float(total_score):.1f} reward:{total_reward + final_reward:.3f} "
|
||||||
|
f"mask_avg(raw/block/npc/recharge/escape/leave/final):"
|
||||||
|
f"{diag['avg_mask_counts']['raw']:.2f}/"
|
||||||
|
f"{diag['avg_mask_counts']['blocked']:.2f}/"
|
||||||
|
f"{diag['avg_mask_counts']['npc']:.2f}/"
|
||||||
|
f"{diag['avg_mask_counts']['recharge']:.2f}/"
|
||||||
|
f"{diag['avg_mask_counts']['escape']:.2f}/"
|
||||||
|
f"{diag['avg_mask_counts']['leave']:.2f}/"
|
||||||
|
f"{diag['avg_mask_counts']['final']:.2f} "
|
||||||
|
f"mask_changed(block/npc/recharge/escape/leave):"
|
||||||
|
f"{diag['mask_changed_steps']['blocked']}/"
|
||||||
|
f"{diag['mask_changed_steps']['npc']}/"
|
||||||
|
f"{diag['mask_changed_steps']['recharge']}/"
|
||||||
|
f"{diag['mask_changed_steps']['escape']}/"
|
||||||
|
f"{diag['mask_changed_steps']['leave']} "
|
||||||
|
f"mask_active(recharge/leave):"
|
||||||
|
f"{diag['mask_active_steps']['recharge']}/"
|
||||||
|
f"{diag['mask_active_steps']['leave']} "
|
||||||
|
f"tight(one/<=2/zero):"
|
||||||
|
f"{diag['one_action_steps']}/"
|
||||||
|
f"{diag['two_or_less_action_steps']}/"
|
||||||
|
f"{diag['zero_final_steps']} "
|
||||||
|
f"actions:{diag['action_hist']} "
|
||||||
|
f"known:{diag['known_ratio']:.3f} dirty_known:{diag['known_dirty_ratio']:.3f} "
|
||||||
|
f"frontier:{diag['frontier_ratio']:.3f} "
|
||||||
|
f"path_dirty/frontier:{diag['global_dirty_path_dist']:.1f}/"
|
||||||
|
f"{diag['frontier_path_dist']:.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
# Build sample frame
|
# Build sample frame
|
||||||
# 构造样本帧
|
# 构造样本帧
|
||||||
@@ -212,6 +301,9 @@ class EpisodeRunner:
|
|||||||
# Add terminal reward to last frame
|
# Add terminal reward to last frame
|
||||||
# 终局奖励叠加到最后一步
|
# 终局奖励叠加到最后一步
|
||||||
collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32)
|
collector[-1].reward = collector[-1].reward + np.array([final_reward], dtype=np.float32)
|
||||||
|
if truncated and not terminated:
|
||||||
|
collector[-1].next_value = self.agent.estimate_value(_obs_data)
|
||||||
|
collector[-1].done = np.array([0.0], dtype=np.float32)
|
||||||
|
|
||||||
# Monitor reporting / 监控上报
|
# Monitor reporting / 监控上报
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@@ -231,6 +323,13 @@ class EpisodeRunner:
|
|||||||
"battery_fail": float(fm.battery_fail),
|
"battery_fail": float(fm.battery_fail),
|
||||||
"charge_count": float(charge_count),
|
"charge_count": float(charge_count),
|
||||||
"remaining_charge": float(remaining_charge),
|
"remaining_charge": float(remaining_charge),
|
||||||
|
"recharge_steps": float(fm.recharge_steps),
|
||||||
|
"mask_final_avg": float(diag["avg_mask_counts"]["final"]),
|
||||||
|
"mask_recharge_active": float(diag["mask_active_steps"]["recharge"]),
|
||||||
|
"mask_recharge_changed": float(diag["mask_changed_steps"]["recharge"]),
|
||||||
|
"mask_one_action_steps": float(diag["one_action_steps"]),
|
||||||
|
"mask_two_or_less_action_steps": float(diag["two_or_less_action_steps"]),
|
||||||
|
"mask_zero_final_steps": float(diag["zero_final_steps"]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -239,6 +338,13 @@ class EpisodeRunner:
|
|||||||
# Compute GAE and yield samples
|
# Compute GAE and yield samples
|
||||||
# GAE 计算并 yield 样本
|
# GAE 计算并 yield 样本
|
||||||
if collector:
|
if collector:
|
||||||
|
if self.diag_max_episodes > 0 and self.episode_cnt >= self.diag_max_episodes:
|
||||||
|
self.stop_requested = True
|
||||||
|
if self.diag_log_only:
|
||||||
|
collector.clear()
|
||||||
|
if self.stop_requested:
|
||||||
|
return
|
||||||
|
break
|
||||||
collector = sample_process(collector)
|
collector = sample_process(collector)
|
||||||
yield collector
|
yield collector
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ if __name__ == "__main__":
|
|||||||
algorithm_name=algorithm_name,
|
algorithm_name=algorithm_name,
|
||||||
algorithm_name_list=algorithm_name_list,
|
algorithm_name_list=algorithm_name_list,
|
||||||
env_vars={
|
env_vars={
|
||||||
"replay_buffer_capacity": "10",
|
"replay_buffer_capacity": "8",
|
||||||
"preload_ratio": "0.2",
|
"preload_ratio": "0.1",
|
||||||
"train_batch_size": "2",
|
"train_batch_size": "2",
|
||||||
"dump_model_freq": "1",
|
"dump_model_freq": "100",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user