Optimize PPO short-run training
This commit is contained in:
@@ -47,6 +47,11 @@ class Preprocessor:
|
||||
self.battery_max = 600
|
||||
|
||||
self.cur_pos = (0, 0)
|
||||
self.prev_pos = None
|
||||
self.has_position_history = False
|
||||
self.current_visit_count = 0
|
||||
self.is_new_cell = False
|
||||
self.last_action = -1
|
||||
|
||||
self.dirt_cleaned = 0
|
||||
self.last_dirt_cleaned = 0
|
||||
@@ -60,6 +65,7 @@ class Preprocessor:
|
||||
# 最近污渍距离
|
||||
self.nearest_dirt_dist = 200.0
|
||||
self.last_nearest_dirt_dist = 200.0
|
||||
self.visit_count_map = np.zeros((self.GRID_SIZE, self.GRID_SIZE), dtype=np.uint16)
|
||||
|
||||
self._view_map = np.zeros((21, 21), dtype=np.float32)
|
||||
self._legal_act = [1] * 8
|
||||
@@ -74,8 +80,20 @@ class Preprocessor:
|
||||
env_info = observation["env_info"]
|
||||
hero = frame_state["heroes"]
|
||||
|
||||
self.last_action = int(last_action)
|
||||
self.step_no = int(observation["step_no"])
|
||||
self.prev_pos = self.cur_pos if self.has_position_history else None
|
||||
self.cur_pos = (int(hero["pos"]["x"]), int(hero["pos"]["z"]))
|
||||
self.has_position_history = True
|
||||
|
||||
hx, hz = self.cur_pos
|
||||
if 0 <= hx < self.GRID_SIZE and 0 <= hz < self.GRID_SIZE:
|
||||
self.current_visit_count = int(self.visit_count_map[hx, hz])
|
||||
self.is_new_cell = self.current_visit_count == 0
|
||||
self.visit_count_map[hx, hz] = min(self.current_visit_count + 1, np.iinfo(np.uint16).max)
|
||||
else:
|
||||
self.current_visit_count = 0
|
||||
self.is_new_cell = False
|
||||
|
||||
# Battery / 电量
|
||||
self.battery = int(hero["battery"])
|
||||
@@ -238,9 +256,15 @@ class Preprocessor:
|
||||
local_view = self._get_local_view_feature() # 49D
|
||||
global_state = self._get_global_state_feature() # 12D
|
||||
legal_action = self.get_legal_action() # 8D
|
||||
legal_arr = np.array(legal_action, dtype=np.float32)
|
||||
|
||||
feature = np.concatenate([local_view, global_state, legal_arr]) # 69D
|
||||
last_action_feature = np.zeros(8, dtype=np.float32)
|
||||
if 0 <= last_action < 8:
|
||||
last_action_feature[last_action] = 1.0
|
||||
|
||||
# The legal action mask is passed separately to PPO. Reusing this 8D slot
|
||||
# for action history makes the 69D observation more informative without
|
||||
# breaking the framework's fixed tensor shape.
|
||||
feature = np.concatenate([local_view, global_state, last_action_feature]) # 69D
|
||||
|
||||
reward = self.reward_process()
|
||||
|
||||
@@ -249,9 +273,26 @@ class Preprocessor:
|
||||
def reward_process(self):
|
||||
# Cleaning reward / 清扫奖励
|
||||
cleaned_this_step = max(0, self.dirt_cleaned - self.last_dirt_cleaned)
|
||||
cleaning_reward = 0.1 * cleaned_this_step
|
||||
cleaning_reward = 0.25 * cleaned_this_step
|
||||
|
||||
# Step penalty / 时间惩罚
|
||||
step_penalty = -0.001
|
||||
step_penalty = -0.002
|
||||
|
||||
return cleaning_reward + step_penalty
|
||||
# Dense guidance: prefer moving toward visible dirt.
|
||||
# 稠密引导:鼓励向视野内污渍靠近。
|
||||
approach_reward = 0.0
|
||||
if self.last_nearest_dirt_dist < 200.0 or self.nearest_dirt_dist < 200.0:
|
||||
dist_delta = float(np.clip(self.last_nearest_dirt_dist - self.nearest_dirt_dist, -5.0, 5.0))
|
||||
approach_reward = 0.01 * dist_delta if dist_delta > 0 else 0.006 * dist_delta
|
||||
|
||||
# Encourage covering new passable cells and mildly discourage loops.
|
||||
# 鼓励探索新格子,轻微惩罚反复绕圈。
|
||||
exploration_reward = 0.002 if self.is_new_cell else -0.0008 * min(self.current_visit_count, 5)
|
||||
|
||||
# Collision/stuck signal: invalid moves waste both step and battery.
|
||||
# 撞墙/原地不动会浪费步数和电量。
|
||||
stuck_penalty = 0.0
|
||||
if self.prev_pos is not None and self.cur_pos == self.prev_pos and 0 <= self.last_action < 8:
|
||||
stuck_penalty = -0.03
|
||||
|
||||
return cleaning_reward + approach_reward + exploration_reward + stuck_penalty + step_penalty
|
||||
|
||||
Reference in New Issue
Block a user