优化PPO充电与避障策略
扩展观测特征到157维,加入充电桩、NPC、电量安全余量、地图统计和本步清扫信息。增加低电量回充动作过滤、NPC危险区过滤,并调整奖励和终局日志以突出充电、避障和真实清扫得分。
This commit is contained in:
@@ -55,9 +55,9 @@ class Agent(BaseAgent):
|
||||
self.last_reward = 0.0
|
||||
|
||||
def observation_process(self, env_obs):
|
||||
"""Convert raw env_obs to ObsData (69D feature + legal action mask).
|
||||
"""Convert raw env_obs to ObsData (feature + legal action mask).
|
||||
|
||||
将原始 env_obs 转换为 ObsData(69D 特征 + 合法动作掩码)。
|
||||
将原始 env_obs 转换为 ObsData(特征 + 合法动作掩码)。
|
||||
"""
|
||||
feature, legal_action, reward = self.preprocessor.feature_process(env_obs, self.last_action)
|
||||
self.last_reward = reward
|
||||
@@ -135,8 +135,16 @@ class Agent(BaseAgent):
|
||||
加载模型检查点。
|
||||
"""
|
||||
model_file_path = f"{path}/model.ckpt-{id}.pkl"
|
||||
self.model.load_state_dict(torch.load(model_file_path, map_location=self.device))
|
||||
self.logger.info(f"load model {model_file_path} successfully")
|
||||
state_dict = torch.load(model_file_path, map_location=self.device)
|
||||
try:
|
||||
self.model.load_state_dict(state_dict)
|
||||
except RuntimeError as err:
|
||||
msg = f"skip incompatible model {model_file_path}, use current initialized model instead: {err}"
|
||||
if self.logger:
|
||||
self.logger.warning(msg)
|
||||
return
|
||||
if self.logger:
|
||||
self.logger.info(f"load model {model_file_path} successfully")
|
||||
|
||||
def _run_model(self, feature):
|
||||
"""Gradient-free forward pass, returns (logits_np, value_np).
|
||||
|
||||
Reference in New Issue
Block a user