Skip to content

Commit

Permalink
Add bf16 support to cuda healpixpad implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
daviddpruitt committed Aug 7, 2024
1 parent 39031e2 commit 20d386d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
11 changes: 11 additions & 0 deletions earth2grid/csrc/healpixpad/healpixpad_cuda_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,17 @@ std::vector<torch::Tensor> healpixpad_cuda_backward(
goutput.data_ptr<at::Half>(),
my_stream);
break;
case torch::ScalarType::BFloat16:
HEALPixPadBck<at::BFloat16>(pad,
batch_size,
num_faces,
num_channels,
face_size,
face_size,
ginput.data_ptr<at::BFloat16>(),
goutput.data_ptr<at::BFloat16>(),
my_stream);
break;
}

return {goutput};
Expand Down
12 changes: 12 additions & 0 deletions earth2grid/csrc/healpixpad/healpixpad_cuda_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,18 @@ std::vector<torch::Tensor> healpixpad_cuda_forward(
output.data_ptr<at::Half>(),
my_stream);
break;
case torch::ScalarType::BFloat16:
HEALPixPadFwd<at::BFloat16>(
pad,
batch_size,
num_faces,
num_channels,
face_size,
face_size,
input.data_ptr<at::BFloat16>(),
output.data_ptr<at::BFloat16>(),
my_stream);
break;
}

return {output};
Expand Down

0 comments on commit 20d386d

Please sign in to comment.