Skip to content

Commit

Permalink
Merge branch 'grid_interp'
Browse files Browse the repository at this point in the history
  • Loading branch information
dkoes committed Sep 2, 2021
2 parents 64e4134 + 06f0565 commit ce386b1
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/torch_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def forward(ctx, gmaker, center, coords, types, radii):
ctx.save_for_backward(coords, types, radii)
ctx.gmaker = gmaker
ctx.center = center
if len(types.shape) != 2:
raise ValueError('Vector types required in Coords2Grid.')

shape = gmaker.grid_dimensions(types.shape[1]) #ntypes == nchannels
output = torch.empty(*shape,dtype=coords.dtype,device=coords.device)
gmaker.forward(center, coords, types, radii, output)
Expand All @@ -97,13 +100,16 @@ class BatchedCoords2GridFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, gmaker, center, coords, types, radii):
'''coords are Nx3, types are NxT, radii are N'''
'''coords are BxNx3, types are BxNxT, radii are BxN'''
ctx.save_for_backward(coords, types, radii)
ctx.gmaker = gmaker
ctx.center = center
batch_size = coords.shape[0]
if batch_size != types.shape[0] or batch_size != radii.shape[0]:
raise RuntimeError("Inconsistent batch sizes in Coords2Grid inputs")

if len(types.shape) != 3:
raise ValueError('Vector types required in BatchedCoords2Grid.')
shape = gmaker.grid_dimensions(types.shape[2]) #ntypes == nchannels
output = torch.empty(batch_size,*shape,dtype=coords.dtype,device=coords.device)
for i in range(batch_size):
Expand Down

0 comments on commit ce386b1

Please sign in to comment.