From 06f0565a83d7343ff9043a09976cd0f3e34de548 Mon Sep 17 00:00:00 2001 From: David Koes Date: Thu, 2 Sep 2021 09:37:21 -0400 Subject: [PATCH] Better error message when wrong types used. --- python/torch_bindings.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/torch_bindings.py b/python/torch_bindings.py index d6be6d7..ead81f1 100644 --- a/python/torch_bindings.py +++ b/python/torch_bindings.py @@ -75,6 +75,9 @@ def forward(ctx, gmaker, center, coords, types, radii): ctx.save_for_backward(coords, types, radii) ctx.gmaker = gmaker ctx.center = center + if len(types.shape) != 2: + raise ValueError('Vector types required in Coords2Grid.') + shape = gmaker.grid_dimensions(types.shape[1]) #ntypes == nchannels output = torch.empty(*shape,dtype=coords.dtype,device=coords.device) gmaker.forward(center, coords, types, radii, output) @@ -97,13 +100,16 @@ class BatchedCoords2GridFunction(torch.autograd.Function): @staticmethod def forward(ctx, gmaker, center, coords, types, radii): - '''coords are Nx3, types are NxT, radii are N''' + '''coords are BxNx3, types are BxNxT, radii are BxN''' ctx.save_for_backward(coords, types, radii) ctx.gmaker = gmaker ctx.center = center batch_size = coords.shape[0] if batch_size != types.shape[0] or batch_size != radii.shape[0]: raise RuntimeError("Inconsistent batch sizes in Coords2Grid inputs") + + if len(types.shape) != 3: + raise ValueError('Vector types required in BatchedCoords2Grid.') shape = gmaker.grid_dimensions(types.shape[2]) #ntypes == nchannels output = torch.empty(batch_size,*shape,dtype=coords.dtype,device=coords.device) for i in range(batch_size):