优化 PPO 清扫策略
This commit is contained in:
@@ -43,14 +43,14 @@ class Algorithm:
|
||||
|
||||
训练入口:接收一批 SampleData,执行一步梯度更新。
|
||||
"""
|
||||
obs = torch.stack([s.obs for s in list_sample_data]).to(self.device)
|
||||
legal_action = torch.stack([s.legal_action for s in list_sample_data]).to(self.device)
|
||||
act = torch.stack([s.act for s in list_sample_data]).to(self.device).view(-1, 1)
|
||||
old_prob = torch.stack([s.prob for s in list_sample_data]).to(self.device)
|
||||
old_value = torch.stack([s.value for s in list_sample_data]).to(self.device)
|
||||
reward_sum = torch.stack([s.reward_sum for s in list_sample_data]).to(self.device)
|
||||
advantage = torch.stack([s.advantage for s in list_sample_data]).to(self.device)
|
||||
reward = torch.stack([s.reward for s in list_sample_data]).to(self.device)
|
||||
obs = self._batch_tensor([s.obs for s in list_sample_data])
|
||||
legal_action = self._batch_tensor([s.legal_action for s in list_sample_data])
|
||||
act = self._batch_tensor([s.act for s in list_sample_data]).view(-1, 1)
|
||||
old_prob = self._batch_tensor([s.prob for s in list_sample_data])
|
||||
old_value = self._batch_tensor([s.value for s in list_sample_data])
|
||||
reward_sum = self._batch_tensor([s.reward_sum for s in list_sample_data])
|
||||
advantage = self._batch_tensor([s.advantage for s in list_sample_data])
|
||||
reward = self._batch_tensor([s.reward for s in list_sample_data])
|
||||
|
||||
if Config.NORMALIZE_ADVANTAGE and advantage.numel() > 1:
|
||||
advantage = (advantage - advantage.mean()) / (advantage.std(unbiased=False) + 1e-8)
|
||||
@@ -194,9 +194,22 @@ class Algorithm:
|
||||
对 logits 应用合法动作掩码后计算 softmax。
|
||||
"""
|
||||
legal_mask = legal_action > 0.5
|
||||
all_illegal = ~legal_mask.any(dim=1, keepdim=True)
|
||||
legal_mask = torch.where(all_illegal, torch.ones_like(legal_mask), legal_mask)
|
||||
safe_logits = logits.masked_fill(~legal_mask, -1e9)
|
||||
return F.softmax(safe_logits, dim=1)
|
||||
|
||||
def _batch_tensor(self, values):
|
||||
"""Stack framework tensors or raw numpy/list values into a float tensor."""
|
||||
tensors = []
|
||||
for value in values:
|
||||
if isinstance(value, torch.Tensor):
|
||||
tensor = value.to(self.device, dtype=torch.float32)
|
||||
else:
|
||||
tensor = torch.as_tensor(value, dtype=torch.float32, device=self.device)
|
||||
tensors.append(tensor)
|
||||
return torch.stack(tensors)
|
||||
|
||||
def _entropy_beta(self):
|
||||
"""Linearly decay entropy regularization for fast early exploration."""
|
||||
progress = min(float(self.train_step) / max(Config.BETA_DECAY_STEPS, 1), 1.0)
|
||||
|
||||
Reference in New Issue
Block a user