tensorflow v2, parameters contains 'Tensor' so cannot be numpy()'d #131
Comments
I switched to PyTorch (the build against CUDA 11.3) with a PyTorch loss function:
I get this error while training, similar to above:
And if I move y_pred over to the host machine (CPU),
Then I see this:
This is roughly the same problem as in TF v2, where you cannot evaluate a Parameter as the network output without taking it out of the graph and losing the gradient.
When I try to compute an autodiff gradient with PyTorch on a simple case (based on your examples), it does seem possible to get the gradient of a cvxpy solution w.r.t. a cvxpy Parameter. Note the device=cpu in the print output, suggesting this operation happens on the CPU.
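The simple case is along the lines of the example in the cvxpylayers README; a sketch of that setup (not my exact code, but the same shape):

```python
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

n, m = 2, 3
x = cp.Variable(n)
A = cp.Parameter((m, n))
b = cp.Parameter(m)
problem = cp.Problem(cp.Minimize(0.5 * cp.pnorm(A @ x - b, p=1)), [x >= 0])
assert problem.is_dpp()

layer = CvxpyLayer(problem, parameters=[A, b], variables=[x])
A_tch = torch.randn(m, n, requires_grad=True)
b_tch = torch.randn(m, requires_grad=True)

# solve the problem inside the autograd graph
solution, = layer(A_tch, b_tch)

# gradient of (sum of the solution) w.r.t. the Parameter values A and b
solution.sum().backward()
print(b_tch.grad)  # a CPU tensor, since the inputs live on the CPU
```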
Hi all,
We are trying to compute the Wasserstein distance (minimize cᵀx s.t. Ax = b) where b is the neural network output. We aim to get d(Wasserstein dist) / d(network parameters θ) so we can train.
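For concreteness, the shape of what we are after, written with the TensorFlow CvxpyLayer, is roughly the following sketch (not our actual code: the sizes, ground cost, and the standard optimal-transport formulation with the transport plan as the variable are placeholders):

```python
import cvxpy as cp
import numpy as np
import tensorflow as tf
from cvxpylayers.tensorflow import CvxpyLayer

n = 5  # number of histogram bins (placeholder)
C = np.abs(np.subtract.outer(np.arange(n), np.arange(n))).astype(np.float32)  # ground cost

P = cp.Variable((n, n), nonneg=True)  # transport plan
mu = cp.Parameter(n)                  # source marginal (data)
nu = cp.Parameter(n)                  # target marginal (network output)
problem = cp.Problem(
    cp.Minimize(cp.sum(cp.multiply(C, P))),
    [cp.sum(P, axis=1) == mu, cp.sum(P, axis=0) == nu],
)
layer = CvxpyLayer(problem, parameters=[mu, nu], variables=[P])

# Both marginals must be valid distributions for the LP to be feasible.
mu_tf = tf.constant(np.full(n, 1.0 / n, dtype=np.float32))
nu_tf = tf.Variable(tf.nn.softmax(tf.random.normal((n,))))  # stand-in for the network output

with tf.GradientTape() as tape:
    plan, = layer(mu_tf, nu_tf)
    w_dist = tf.reduce_sum(plan * tf.constant(C))  # Wasserstein distance = <C, P*>
print(tape.gradient(w_dist, nu_tf))  # d(distance) / d(target marginal)
```

If something like this is the right shape, the remaining question is how to make nu the actual (Eager)Tensor coming out of the network during training.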
Since TensorFlow is the backend and we are using DeepXDE, we are dealing with TensorFlow v2's EagerTensor vs. Tensor distinction: EagerTensors have .numpy(), Tensors do not.
I've used your example code to achieve what I want in a small example like the one below. Here y_pred (a placeholder for the network output) is an EagerTensor, so concatenating it with another EagerTensor gives an EagerTensor, which is numpy()-able.
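Roughly (this is a reconstruction; sizes, values, and the toy LP are placeholders), the eager version that works looks like:

```python
import numpy as np
import tensorflow as tf
import cvxpy as cp
from cvxpylayers.tensorflow import CvxpyLayer

# Stand-ins for the network output: in eager mode these are EagerTensors.
y_pred = tf.constant([0.3, 0.7])
y_other = tf.constant([0.5, 0.5])
b_val = tf.concat([y_pred, y_other], axis=0)  # EagerTensor concat EagerTensor -> EagerTensor
print(b_val.numpy())                          # works, since everything is eager

m = int(b_val.shape[0])
b = cp.Parameter(m)
x = cp.Variable(m, nonneg=True)
c = np.arange(1.0, m + 1, dtype=np.float32)
problem = cp.Problem(cp.Minimize(c @ x), [x == b])
layer = CvxpyLayer(problem, parameters=[b], variables=[x])

x_star, = layer(b_val)  # fine outside the training graph: b_val is eager
```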
Here is the intent implemented as a loss function:
But this throws the following exception:
I know there is a way to evaluate a TF v2 Tensor into something that can be numpy()'d, by defining a py_func, and that is what we did before. But because we could not get a gradient of the Wasserstein distance w.r.t. the network parameters, the loss never converged, which is why we came to your library.
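A stripped-down sketch of that earlier py_func approach (a toy LP in place of the real one): the solve runs, but nothing flows back to the input.

```python
import numpy as np
import tensorflow as tf
import cvxpy as cp

n = 3
c_np = np.arange(1.0, n + 1)

def solve_lp(b):
    # b arrives as an EagerTensor inside tf.py_function, so .numpy() is fine here
    x = cp.Variable(n, nonneg=True)
    problem = cp.Problem(cp.Minimize(c_np @ x), [cp.sum(x) == float(np.sum(b.numpy()))])
    return np.float32(problem.solve())

b_var = tf.Variable([0.2, 0.3, 0.5])
with tf.GradientTape() as tape:
    w = tf.py_function(solve_lp, inp=[b_var], Tout=tf.float32)
print(tape.gradient(w, b_var))  # None: the cvxpy solve is opaque to TF autodiff
```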
This is related to #121, in that f(network params) => Parameters. But I think the core of our problem is that cvxpylayers, as I understand it, can only take Parameter values that are 'eager' / numpy()-able. Or is there some way around this?
Maybe we should use a different backend?
Please share any insights / advice; thank you in advance.