diff --git a/python/torch_bindings.py b/python/torch_bindings.py
index d6be6d7..ead81f1 100644
--- a/python/torch_bindings.py
+++ b/python/torch_bindings.py
@@ -75,6 +75,9 @@ def forward(ctx, gmaker, center, coords, types, radii):
         ctx.save_for_backward(coords, types, radii)
         ctx.gmaker = gmaker
         ctx.center = center
+        if len(types.shape) != 2:
+            raise ValueError('Vector types required in Coords2Grid.')
+
         shape = gmaker.grid_dimensions(types.shape[1]) #ntypes == nchannels
         output = torch.empty(*shape,dtype=coords.dtype,device=coords.device)
         gmaker.forward(center, coords, types, radii, output)
@@ -97,13 +100,16 @@ class BatchedCoords2GridFunction(torch.autograd.Function):
 
     @staticmethod
    def forward(ctx, gmaker, center, coords, types, radii):
-        '''coords are Nx3, types are NxT, radii are N'''
+        '''coords are BxNx3, types are BxNxT, radii are BxN'''
        ctx.save_for_backward(coords, types, radii)
        ctx.gmaker = gmaker
        ctx.center = center
        batch_size = coords.shape[0]
        if batch_size != types.shape[0] or batch_size != radii.shape[0]:
            raise RuntimeError("Inconsistent batch sizes in Coords2Grid inputs")
+
+        if len(types.shape) != 3:
+            raise ValueError('Vector types required in BatchedCoords2Grid.')
        shape = gmaker.grid_dimensions(types.shape[2]) #ntypes == nchannels
        output = torch.empty(batch_size,*shape,dtype=coords.dtype,device=coords.device)
        for i in range(batch_size):
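
A minimal usage sketch of the new requirement, for reviewers: callers that previously passed 1-D index types now get a `ValueError` and must supply an NxT "vector" types matrix (e.g. one-hot). `molgrid.GridMaker`, `molgrid.float3`, and `Coords2GridFunction` are taken from this file/libmolgrid; the values and the one-hot conversion are illustrative assumptions, not part of the patch.

```python
import torch
import molgrid
from molgrid.torch_bindings import Coords2GridFunction  # module path assumed

gmaker = molgrid.GridMaker()
center = molgrid.float3(0.0, 0.0, 0.0)

n_atoms, n_types = 4, 8
coords = torch.rand(n_atoms, 3)                          # Nx3
index_types = torch.randint(0, n_types, (n_atoms,))      # 1-D index types
radii = torch.full((n_atoms,), 1.5)                      # N

# With this patch, passing index_types directly raises:
#   ValueError: Vector types required in Coords2Grid.
# Convert to an NxT vector-types matrix first (one-hot shown as one option).
vector_types = torch.nn.functional.one_hot(index_types, n_types).to(coords.dtype)

grid = Coords2GridFunction.apply(gmaker, center, coords, vector_types, radii)
```

The batched path is analogous: types must be BxNxT, or `BatchedCoords2GridFunction` raises the corresponding `ValueError`.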