Skip to content

Commit

Permalink
Better multi-device support
Browse files Browse the repository at this point in the history
Provide helper functions for setting current device.
Throw useful error message if tensor is not on current device.
  • Loading branch information
dkoes committed Sep 14, 2021
1 parent ce386b1 commit 4467041
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/bindings.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ BOOST_PYTHON_MODULE(molgrid)
"Set if generated grids should be on GPU by default.");
def("tofloatptr", +[](long val) { return Pointer<float>((float*)val);}, "Return integer as float *");
def("todoubleptr", +[](long val) { return Pointer<double>((double*)val);}, "Return integer as double *");
def("set_gpu_device", +[](int device)->void {LMG_CUDA_CHECK(cudaSetDevice(device));}, "Set current GPU device.");
def("get_gpu_device", +[]()->int {int device = 0; LMG_CUDA_CHECK(cudaGetDevice(&device)); return device;}, "Get current GPU device.");

//type converters
py_pair<int, float>();
Expand Down
8 changes: 8 additions & 0 deletions python/bindings_grids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ struct Grid_from_python {
} else {
return false; //don't recognize
}
if(info.isGPU && hasattr(t,"device") && hasattr(t.attr("device"), "index")) {
int d = extract<int>(t.attr("device").attr("index"));
int currd = 0;
LMG_CUDA_CHECK(cudaGetDevice(&currd));
if(currd != d) {
throw std::invalid_argument("Attempt to use GPU tensor on different device ("+itoa(d)+") than current device ("+itoa(currd)+"). Change location of tensor or change current device.");
}
}
if(hasattr(t,"is_contiguous")) {
if(!t.attr("is_contiguous")()) {
throw std::invalid_argument("Attempt to use non-contiguous tensor in molgrid. Call clone first.");
Expand Down
26 changes: 26 additions & 0 deletions test/test_gridmaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,32 @@ def test_a_grid():
assert 2.094017 == approx(mgridout.tonumpy().max())
assert 2.094017 == approx(mgridgpu.tonumpy().max())

def test_devices():
    # Verify multi-GPU handling: molgrid must reject a tensor that lives on a
    # GPU other than the current one, and accept it once the current device
    # is switched to match.
    fname = datadir + "/small.types"
    provider = molgrid.ExampleProvider(data_root=datadir + "/structs")
    provider.populate(fname)
    example = provider.next()

    gmaker = molgrid.GridMaker()
    # NOTE(review): method name looks like it should be grid_dims or
    # get_grid_dims — confirm against the GridMaker binding.
    dims = gmaker.grid_dimensions(example.num_types())

    try:
        torchout = torch.zeros(dims, dtype=torch.float32, device='cuda:1')
    except RuntimeError:
        # Only one GPU present — multi-device behavior cannot be exercised.
        return

    # Tensor is on cuda:1 but the current device is still cuda:0, so forward
    # must raise (the C++ side throws for a device mismatch).
    raised = False
    try:
        gmaker.forward(example, torchout)
    except ValueError:
        raised = True
    assert raised

    # After switching the current device to match the tensor, forward works.
    molgrid.set_gpu_device(1)
    assert molgrid.get_gpu_device() == 1
    example = provider.next()
    gmaker.forward(example, torchout)


def test_radius_multiples():
g1 = molgrid.GridMaker(resolution=.1,dimension=6.0)
c = np.array([[0,0,0]],np.float32)
Expand Down

0 comments on commit 4467041

Please sign in to comment.