diff --git a/open_spiel/algorithms/dqn_torch/dqn.cc b/open_spiel/algorithms/dqn_torch/dqn.cc index e744067d4c..948e535fc7 100644 --- a/open_spiel/algorithms/dqn_torch/dqn.cc +++ b/open_spiel/algorithms/dqn_torch/dqn.cc @@ -216,7 +216,7 @@ void DQN::Learn() { legal_actions_mask.push_back( torch::from_blob(t.legal_actions_mask.data(), {1, t.legal_actions_mask.size()}, - torch::TensorOptions().dtype(torch::kInt32)) + torch::TensorOptions().dtype(torch::kBool)) .clone()); actions.push_back(t.action); rewards.push_back(t.reward); @@ -231,7 +231,7 @@ void DQN::Learn() { next_info_states_tensor).detach(); torch::Tensor illegal_action_masks_tensor = - 1.0 - torch::stack(legal_actions_mask, 0); + torch::stack(legal_actions_mask, 0).bitwise_not(); torch::Tensor legal_q_values = torch::masked_fill(target_q_values, illegal_action_masks_tensor, kIllegalActionLogitsPenalty);