Skip to content

Commit

Permalink
Merge pull request #392 from MeasureTransport/torch-pickle
Browse files Browse the repository at this point in the history
Pickle Bugfix
  • Loading branch information
mparno authored Feb 27, 2024
2 parents fd3ebeb + c93aa58 commit 6e84606
Show file tree
Hide file tree
Showing 12 changed files with 157 additions and 127 deletions.
1 change: 1 addition & 0 deletions .docker/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ dependencies:
- gcc
- gxx
- make
- dill
2 changes: 1 addition & 1 deletion .github/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ dependencies:
- cereal >= 1.3
- nlopt >= 2.7
- pytorch

- dill
1 change: 1 addition & 0 deletions .github/workflows/build-bindings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name: binding-tests
on:
push:
branches:
- release
- main
pull_request: {}

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release

jobs:
build-docs:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-external-lib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release
pull_request: {}

env:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-push-docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release

jobs:
docker:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/build-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- release
pull_request: {}

jobs:
Expand Down
121 changes: 5 additions & 116 deletions bindings/python/package/torch.py
Original file line number Diff line number Diff line change
@@ -1,124 +1,13 @@
import torch

def ExtractTorchTensorData(tensor):
""" Extracts the pointer, shape, and stride from a pytorch tensor and returns a tuple
that can be passed to MParT functions that have been overloaded to accept
(double*, std::tuple<int,int>, std::tuple<int,int>) instead of a Kokkos::View.
Arguments:
------------
tensor: pytorch.Tensor
The pytorch tensor we want to eventually wrap with a Kokkos view.
from .torch_helpers import ExtractTorchTensorData, MpartTorchAutograd

Returns:
------------
Tuple[int, Tuple[int,int], Tuple[int,int]]
A python tuple that contains all information needed to construct a Kokkos::View.
After casting to c++ types using pybind, this output can be passed to the
mpart::ConstructViewFromPointer function.
"""

# Make sure the tensor has double data type
if tensor.dtype != torch.float64:
raise ValueError(f'Currently only tensors with float64 datatype can be converted. Current dtype is {tensor.dtype}')

if len(tensor.shape)==1:
return tensor.data_ptr(), tensor.shape[0], tensor.stride()[0]
elif len(tensor.shape)==2:
return tensor.data_ptr(), tuple(tensor.shape), tuple(tensor.stride())
else:
raise ValueError(f'Currently only 1d and 2d tensors can be converted.')


class MpartTorchAutograd(torch.autograd.Function):

@staticmethod
def forward(ctx, input, coeffs, f, return_logdet):
ctx.save_for_backward(input, coeffs)
ctx.f = f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()

output = torch.zeros(f.outputDim, input.shape[1], dtype=torch.double)
f.EvaluateImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(output))

if return_logdet:
logdet = torch.zeros(input.shape[1], dtype=torch.double)
f.LogDeterminantImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(logdet))
return output.type(input.dtype), logdet.type(input.dtype)
else:
return output.type(input.dtype)

@staticmethod
def backward(ctx, output_sens, logdet_sens=None):
input, coeffs = ctx.saved_tensors
f = ctx.f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()
output_sens_dbl = output_sens.double()

logdet_sens_dbl = None
if logdet_sens is not None:
logdet_sens_dbl = logdet_sens.double()

# Get the gradient wrt input
grad = None
if input.requires_grad:
grad = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.GradientImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(grad))

if logdet_sens is not None:
grad2 = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.LogDeterminantInputGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))
grad += grad2*logdet_sens_dbl[None,:]

coeff_grad = None
if coeffs is not None:
if coeffs.requires_grad:
coeff_grad = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)
f.CoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(coeff_grad))

coeff_grad = coeff_grad.sum(axis=1) # pytorch expects total gradient not per-sample gradient

if logdet_sens is not None:
grad2 = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)

f.LogDeterminantCoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))

coeff_grad += torch.sum(grad2*logdet_sens[None,:],axis=1)

if coeff_grad is not None:
coeff_grad = coeff_grad.type(input.dtype)

if grad is not None:
grad = grad.type(input.dtype)

return grad, coeff_grad, None, None



class TorchParameterizedFunctionBase(torch.nn.Module):
""" Defines a wrapper around the MParT ParameterizedFunctionBase class that
can be used with pytorch.
"""

def __init__(self, f, store_coeffs=True, dtype=torch.double):
def __init__(self, f=None, store_coeffs=True, dtype=torch.double):
super().__init__()

self.f = f
Expand All @@ -129,7 +18,7 @@ def __init__(self, f, store_coeffs=True, dtype=torch.double):
self.coeffs = torch.nn.Parameter(coeff_tensor)
else:
self.coeffs = None

def forward(self, x, coeffs=None):

if coeffs is None:
Expand All @@ -148,7 +37,7 @@ class TorchConditionalMapBase(torch.nn.Module):
This can be done either in the constructor or afterwards.
"""

def __init__(self, f, store_coeffs=True, return_logdet=False, dtype=torch.double):
def __init__(self, f=None, store_coeffs=True, return_logdet=False, dtype=torch.double):
super().__init__()

self.return_logdet = return_logdet
Expand All @@ -159,7 +48,7 @@ def __init__(self, f, store_coeffs=True, return_logdet=False, dtype=torch.double
self.coeffs = torch.nn.Parameter(coeff_tensor)
else:
self.coeffs = None

def forward(self, x, coeffs=None):

if coeffs is None:
Expand Down
115 changes: 115 additions & 0 deletions bindings/python/package/torch_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch

def ExtractTorchTensorData(tensor):
""" Extracts the pointer, shape, and stride from a pytorch tensor and returns a tuple
that can be passed to MParT functions that have been overloaded to accept
(double*, std::tuple<int,int>, std::tuple<int,int>) instead of a Kokkos::View.
Arguments:
------------
tensor: pytorch.Tensor
The pytorch tensor we want to eventually wrap with a Kokkos view.
Returns:
------------
Tuple[int, Tuple[int,int], Tuple[int,int]]
A python tuple that contains all information needed to construct a Kokkos::View.
After casting to c++ types using pybind, this output can be passed to the
mpart::ConstructViewFromPointer function.
"""

# Make sure the tensor has double data type
if tensor.dtype != torch.float64:
raise ValueError(f'Currently only tensors with float64 datatype can be converted. Current dtype is {tensor.dtype}')

if len(tensor.shape)==1:
return tensor.data_ptr(), tensor.shape[0], tensor.stride()[0]
elif len(tensor.shape)==2:
return tensor.data_ptr(), tuple(tensor.shape), tuple(tensor.stride())
else:
raise ValueError(f'Currently only 1d and 2d tensors can be converted.')


class MpartTorchAutograd(torch.autograd.Function):

def __reduce__(self):
return (self.__class__, (None,))

@staticmethod
def forward(ctx, input, coeffs, f, return_logdet):
ctx.save_for_backward(input, coeffs)
ctx.f = f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()

output = torch.zeros(f.outputDim, input.shape[1], dtype=torch.double)
f.EvaluateImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(output))

if return_logdet:
logdet = torch.zeros(input.shape[1], dtype=torch.double)
f.LogDeterminantImpl(ExtractTorchTensorData(input_dbl), ExtractTorchTensorData(logdet))
return output.type(input.dtype), logdet.type(input.dtype)
else:
return output.type(input.dtype)

@staticmethod
def backward(ctx, output_sens, logdet_sens=None):
input, coeffs = ctx.saved_tensors
f = ctx.f

coeffs_dbl = None
if coeffs is not None:
coeffs_dbl = coeffs.double()
f.WrapCoeffs(ExtractTorchTensorData(coeffs_dbl))
input_dbl = input.double()
output_sens_dbl = output_sens.double()

logdet_sens_dbl = None
if logdet_sens is not None:
logdet_sens_dbl = logdet_sens.double()

# Get the gradient wrt input
grad = None
if input.requires_grad:
grad = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.GradientImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(grad))

if logdet_sens is not None:
grad2 = torch.zeros(f.inputDim, input.shape[1], dtype=torch.double)

f.LogDeterminantInputGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))
grad += grad2*logdet_sens_dbl[None,:]

coeff_grad = None
if coeffs is not None:
if coeffs.requires_grad:
coeff_grad = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)
f.CoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(output_sens_dbl),
ExtractTorchTensorData(coeff_grad))

coeff_grad = coeff_grad.sum(axis=1) # pytorch expects total gradient not per-sample gradient

if logdet_sens is not None:
grad2 = torch.zeros(f.numCoeffs, input.shape[1], dtype=torch.double)

f.LogDeterminantCoeffGradImpl(ExtractTorchTensorData(input_dbl),
ExtractTorchTensorData(grad2))

coeff_grad += torch.sum(grad2*logdet_sens[None,:],axis=1)

if coeff_grad is not None:
coeff_grad = coeff_grad.type(input.dtype)

if grad is not None:
grad = grad.type(input.dtype)

return grad, coeff_grad, None, None
35 changes: 26 additions & 9 deletions bindings/python/tests/test_TorchWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import mpart as mt
import numpy as np
import dill

if haveTorch:

Expand Down Expand Up @@ -81,6 +82,21 @@ def test_AutogradCoeffs():
loss.backward()
assert tmap2.coeffs.grad is not None

def test_AutogradCoeffAsInput():

opts = mt.MapOptions()
tmap = mt.CreateTriangular(dim,dim,3,opts) # Simple third order map

tmap2 = tmap.torch(store_coeffs=False)

x = torch.randn(numSamps, dim, dtype=torch.double)
coeffs = torch.randn(tmap.numCoeffs, dtype=torch.double)

y = tmap2(x,coeffs)
assert y.shape[0] == numSamps
assert y.shape[1] == dim
assert not y.isnan().any()

def test_TorchMethod():
opts = mt.MapOptions()
tmap = mt.CreateTriangular(dim,dim,3,opts) # Simple third order map
Expand All @@ -96,20 +112,20 @@ def test_TorchMethod():
assert np.all(y.detach().numpy() == tmap.Evaluate(x.T.detach().numpy()).T)
assert np.all(logdet.detach().numpy() == tmap.LogDeterminant(x.T.detach().numpy()))

def test_AutogradCoeffAsInput():

def test_TorchPickle():
opts = mt.MapOptions()
tmap = mt.CreateTriangular(dim,dim,3,opts) # Simple third order map

tmap2 = tmap.torch(store_coeffs=False)

x = torch.randn(numSamps, dim, dtype=torch.double)
coeffs = torch.randn(tmap.numCoeffs, dtype=torch.double)
tmap2 = tmap.torch(store_coeffs=True)
y = tmap2.forward(x)


y = tmap2(x,coeffs)
assert y.shape[0] == numSamps
assert y.shape[1] == dim
assert not y.isnan().any()
map_bytes = dill.dumps(tmap2, dill.HIGHEST_PROTOCOL)
tmap3 = dill.loads(map_bytes)

y2 = tmap3.forward(x)
assert (y2-y).abs().max() < 1e-8


if __name__=='__main__':
Expand All @@ -118,4 +134,5 @@ def test_AutogradCoeffAsInput():
test_Autograd()
test_AutogradCoeffs()
test_TorchMethod()
test_TorchPickle()
test_AutogradCoeffAsInput()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ license={file="LICENSE.txt"}
readme="README.md"
requires-python = ">=3.7"
description="A Monotone Parameterization Toolkit"
version="2.2.1"
version="2.2.2"
keywords=["Measure Transport", "Monotone", "Transport Map", "Isotonic Regression", "Triangular", "Knothe-Rosenblatt"]

[project.urls]
Expand Down
Loading

0 comments on commit 6e84606

Please sign in to comment.