Skip to content

Commit

Permalink
Update for libtorch 2.0
Browse files Browse the repository at this point in the history
PyTorch 2.0 deprecated passing a int-valued mask to masked_fill.
This change updates one useage in dqn.cc to use bools.
  • Loading branch information
tacertain committed Apr 28, 2024
1 parent 1208f83 commit c07251f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions open_spiel/algorithms/dqn_torch/dqn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ void DQN::Learn() {
legal_actions_mask.push_back(
torch::from_blob(t.legal_actions_mask.data(),
{1, t.legal_actions_mask.size()},
torch::TensorOptions().dtype(torch::kInt32))
torch::TensorOptions().dtype(torch::kBool))
.clone());
actions.push_back(t.action);
rewards.push_back(t.reward);
Expand All @@ -231,7 +231,7 @@ void DQN::Learn() {
next_info_states_tensor).detach();

torch::Tensor illegal_action_masks_tensor =
1.0 - torch::stack(legal_actions_mask, 0);
torch::stack(legal_actions_mask, 0).bitwise_not();
torch::Tensor legal_q_values =
torch::masked_fill(target_q_values, illegal_action_masks_tensor,
kIllegalActionLogitsPenalty);
Expand Down

0 comments on commit c07251f

Please sign in to comment.