Device conflict when training BC #843

Open
IsaacSheidlower opened this issue Mar 28, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@IsaacSheidlower

Bug description

When training a BC agent on CUDA with the latest stable release on Python 3.11, I got a torch error: a device mismatch when calculating the loss. I worked around it with a rough patch on the torch side.

Steps to reproduce

Here is the code I used to set up the trainer:

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    rng=rng,
    device="cuda:0",
)

bc_trainer.train(n_epochs=100)

I fixed it by changing the log_prob function in torch's Categorical distribution class:

def log_prob(self, value):
    if self._validate_args:
        self._validate_sample(value)
    value = value.long().unsqueeze(-1)
    value, log_pmf = torch.broadcast_tensors(value, self.logits)
    value = value[..., :1]
    # send value to the same device as self.logits, in case value is on a different device
    value = value.to(log_pmf.device)
    return log_pmf.gather(-1, value).squeeze(-1)
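
For anyone who would rather not edit the installed torch source, the same idea can be applied one layer up by wrapping stable-baselines3's CategoricalDistribution.log_prob. This is only a sketch of that workaround, not the project's actual fix:

import torch as th
from stable_baselines3.common.distributions import CategoricalDistribution

_original_log_prob = CategoricalDistribution.log_prob

def _log_prob_on_logits_device(self, actions: th.Tensor) -> th.Tensor:
    # self.distribution is the underlying torch Categorical; its logits live
    # on the policy device (cuda:0 here), so move the actions there in case
    # they arrived on the CPU.
    actions = actions.to(self.distribution.logits.device)
    return _original_log_prob(self, actions)

CategoricalDistribution.log_prob = _log_prob_on_logits_device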

Here is the error message:

RuntimeError Traceback (most recent call last)
Cell In[8], line 2
1 print("Training a policy using Behavior Cloning")
----> 2 bc_trainer.train(n_epochs=100)

File c:\Users\qiglo\AppData\Local\Programs\Python\Python311\Lib\site-packages\imitation\algorithms\bc.py:495, in BC.train(self, n_epochs, n_batches, on_epoch_end, on_batch_end, log_interval, log_rollouts_venv, log_rollouts_n_episodes, progress_bar, reset_tensorboard)
490 obs_tensor = types.map_maybe_dict(
491 lambda x: util.safe_to_tensor(x, device=self.policy.device),
492 types.maybe_unwrap_dictobs(batch["obs"]),
493 )
494 acts = util.safe_to_tensor(batch["acts"], device=self.policy.device)
--> 495 training_metrics = self.loss_calculator(self.policy, obs_tensor, acts)
497 # Renormalise the loss to be averaged over the whole
498 # batch size instead of the minibatch size.
499 # If there is an incomplete batch, its gradients will be
500 # smaller, which may be helpful for stability.
501 loss = training_metrics.loss * minibatch_size / self.batch_size

File c:\Users\qiglo\AppData\Local\Programs\Python\Python311\Lib\site-packages\imitation\algorithms\bc.py:130, in BehaviorCloningLossCalculator.__call__(self, policy, obs, acts)
126 acts = util.safe_to_tensor(acts)
128 # policy.evaluate_actions's type signatures are incorrect.
129 # See DLR-RM/stable-baselines3#1679
--> 130 (_, log_prob, entropy) = policy.evaluate_actions(
131 tensor_obs, # type: ignore[arg-type]
132 acts,
133 )
134 prob_true_act = th.exp(log_prob).mean()
135 log_prob = log_prob.mean()

File c:\Users\qiglo\AppData\Local\Programs\Python\Python311\Lib\site-packages\stable_baselines3\common\policies.py:736, in ActorCriticPolicy.evaluate_actions(self, obs, actions)
734 latent_vf = self.mlp_extractor.forward_critic(vf_features)
735 distribution = self._get_action_dist_from_latent(latent_pi)
--> 736 log_prob = distribution.log_prob(actions)
737 values = self.value_net(latent_vf)
738 entropy = distribution.entropy()

File c:\Users\qiglo\AppData\Local\Programs\Python\Python311\Lib\site-packages\stable_baselines3\common\distributions.py:292, in CategoricalDistribution.log_prob(self, actions)
291 def log_prob(self, actions: th.Tensor) -> th.Tensor:
--> 292 return self.distribution.log_prob(actions)

File c:\Users\qiglo\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\distributions\categorical.py:143, in Categorical.log_prob(self, value)
140 value = value[..., :1]
141 # # send value to the same device as self.logits, in case value is on a different device
142 # value = value.to(log_pmf.device)
--> 143 return log_pmf.gather(-1, value).squeeze(-1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA_gather)
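
Judging from the frames above, BC.train converts the observations with device=self.policy.device, but BehaviorCloningLossCalculator.__call__ converts the actions with util.safe_to_tensor(acts) and no device argument, so the action indices stay on the CPU while the logits sit on cuda:0. The failing operation is easy to reproduce in isolation; the following standalone sketch (not part of the original report) shows the same gather call with mismatched devices, and that moving the index tensor, as in the patch above, makes it succeed:

import torch

if torch.cuda.is_available():
    # log-probabilities on the GPU, index tensor on the CPU, as in the traceback
    log_pmf = torch.log_softmax(torch.randn(4, 3, device="cuda:0"), dim=-1)
    value = torch.tensor([[0], [1], [2], [1]])

    try:
        log_pmf.gather(-1, value)  # raises the device-mismatch RuntimeError
    except RuntimeError as err:
        print(err)

    # moving the index onto log_pmf's device makes gather succeed
    print(log_pmf.gather(-1, value.to(log_pmf.device)).squeeze(-1))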

@IsaacSheidlower added the bug (Something isn't working) label on Mar 28, 2024
@Bpoole908

This issue looks like it is being handled in this pull request.
