Transfer learning tutorial doesn't work with pytorch backend #20287
Comments
Is the augmented image (or the first image) a tensor on the GPU? @off6atomic If it is, can you do |
Hi @off6atomic - Could you let me know which Keras version gives you the error? I ran the same code with the backend set to "torch" and it runs fine on Keras 3.5.0. Gist attached for reference. |
Hi @off6atomic - As mentioned in your code snippet, the error occurs with the PyTorch backend. According to the error, the PyTorch tensor is on cuda:0 while the NumPy array resides on the CPU.
Attached is a gist that runs the entire code, for reference. |
@mehtamansi29 It works! |
Hi @off6atomic -
The code and description at https://keras.io/guides/transfer_learning/ are written for the Keras API, which covers all backends. |
OK. I mean that if I change the backend to torch, I expect not to have to change anything else in the code for it to work. In this case, we have to move the tensor to the CPU before it works. It's not a big hurdle though. I just wanted to know whether users are expected to know that they have to move the tensor to the CPU in the PyTorch case but not in the TensorFlow case. |
Hi @off6atomic -
Yes. We have to move the tensor to the CPU in the PyTorch case and don't have to do it in the TensorFlow case. |
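For readers who land here, the workaround described above (moving the tensor to the CPU before it reaches NumPy or matplotlib) could be sketched roughly as follows. This is only an illustration: `to_host_array` is a hypothetical helper name, not part of Keras or the tutorial, and it assumes the input is either a `torch.Tensor` (recognized by its `detach`, `cpu` and `numpy` methods) or something the plotting code already accepts.

```python
def to_host_array(x):
    """Return a CPU-side array for x if it looks like a torch tensor.

    Hypothetical helper, not a Keras or PyTorch API. Assumption: any
    object exposing detach(), cpu() and numpy() behaves like a
    torch.Tensor; everything else is returned unchanged.
    """
    if all(hasattr(x, name) for name in ("detach", "cpu", "numpy")):
        # detach() drops autograd tracking, cpu() copies from e.g. cuda:0
        # to host memory, numpy() then exposes the data as a NumPy array.
        return x.detach().cpu().numpy()
    # NumPy arrays and other backend outputs pass through untouched.
    return x
```

With the torch backend, a tensor such as the augmented image would be wrapped as `to_host_array(image)` before being handed to plotting code, where `image` is a placeholder name. With other backends the call is a no-op, so the same code can stay in place when switching.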
Is this difference in behavior documented somewhere? |
Hi @off6atomic - That is not documented anywhere, but from here you can see that if CUDA is available, the torch tensor is placed on the GPU. And from the error |
@mehtamansi29 Then there's probably no issue with the tutorial I guess. We can close it. Thank you! |
If you run this code with the backend set to "torch"
https://keras.io/guides/transfer_learning/
you will get an error in the following cell:
How do we fix this? This error happens on Colab and also on a local machine.