优化 PPO 清扫策略

This commit is contained in:
2026-04-26 17:29:03 +08:00
parent f04feb0cd9
commit f44e2483fc
6 changed files with 223 additions and 86 deletions

View File

@@ -43,14 +43,14 @@ class Algorithm:
训练入口:接收一批 SampleData执行一步梯度更新。
"""
obs = torch.stack([s.obs for s in list_sample_data]).to(self.device)
legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device)
act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1)
old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device)
old_value = torch.stack([s.value for s in list_sample_data]).to(self.device)
reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device)
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
obs = self._batch_tensor([s.obs for s in list_sample_data])
legal_action = self._batch_tensor([s.legal_action for s in list_sample_data])
act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1)
old_prob = self._batch_tensor([s.prob for s in list_sample_data])
old_value = self._batch_tensor([s.value for s in list_sample_data])
reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data])
advantage = self._batch_tensor([s.advantage for s in list_sample_data])
reward = self._batch_tensor([s.reward for s in list_sample_data])
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
@@ -194,9 +194,22 @@ class Algorithm:
对 logits 应用合法动作掩码后计算 softmax。
"""
legal_mask = legal_action > 0.5
all_illegal = ~legal_mask.any(dim=1, keepdim=True)
legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask)
safe_logits = logits.masked_fill(~legal_mask, -1e9)
return F.softmax(safe_logits, dim=1)
def _batch_tensor(self, values):
"""Stack framework tensors or raw numpy/list values into a float tensor."""
tensors = []
for value in values:
if isinstance(value, torch.Tensor):
tensor = value.to(self.device, dtype=torch.float32)
else:
tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
tensors.append(tensor)
return torch.stack(tensors)
def _entropy_beta(self):
"""Linearly decay entropy regularization for fast early exploration."""
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)