Skip to content

Commit

Permalink
batched forward of individual tensors
Browse files Browse the repository at this point in the history
This should make things mroe convenient when using pytorch data loaders
instead of our superior example providers.
  • Loading branch information
dkoes committed Aug 18, 2021
1 parent 3ff5ab9 commit eeec466
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 7 deletions.
48 changes: 45 additions & 3 deletions include/libmolgrid/grid_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,13 @@ class GridMaker {

// Docstring_GridMaker_forward_8
/* \brief Generate grid tensor from CPU atomic data. Grid must be properly sized.
* If TypesFromRadii, radii should be size of types and type information will be used
* to select the radii.
* @param[in] center of grid
* @param[in] coordinates (Nx3)
* @param[in] vectors (NxT)
* @param[in] radii (N) or (T)
* @param[out] a 4D grid
*/
template <typename Dtype, bool TypesFromRadii = false>
template <typename Dtype>
void forward(float3 grid_center, const Grid<float, 2, false>& coords,
const Grid<float, 2, false>& type_vector, const Grid<float, 1, false>& radii,
Grid<Dtype, 4, false>& out) const;
Expand All @@ -247,6 +245,50 @@ class GridMaker {
const Grid<float, 2, true>& type_vector, const Grid<float, 1, true>& radii,
Grid<Dtype, 4, true>& out) const;

// Docstring_GridMaker_forward_10
/* \brief Generate grid tensors from batched atomic data. Grid must be properly sized.
* @param[in] centers of grid (Bx3)
* @param[in] coordinates (BxNx3)
* @param[in] type vectors (BxNxT) or type indices (BxN)
* @param[in] radii (BxN) or (BxT)
* @param[out] a 5D grid
*/
template <typename Dtype, bool isCUDA, int N>
void forward(const Grid<float, 2, isCUDA> &centers,
const Grid<float, 3, isCUDA> &coords,
const Grid<float, N, isCUDA> &types,
const Grid<float, 2, isCUDA> &radii,Grid<Dtype, 5, isCUDA> &out) const{
size_t B = centers.dimension(0);
if(coords.dimension(0) != B)
throw std::invalid_argument(
"Mismatched batch sizes: " + itoa(coords.dimension(0)) + " vs " + itoa(B));
if(types.dimension(0) != B)
throw std::invalid_argument(
"Mismatched batch sizes: " + itoa(types.dimension(0)) + " vs " + itoa(B));
if(radii.dimension(0) != B)
throw std::invalid_argument(
"Mismatched batch sizes: " + itoa(radii.dimension(0)) + " vs " + itoa(B));
if(out.dimension(0) != B)
throw std::invalid_argument(
"Mismatched batch sizes: " + itoa(out.dimension(0)) + " vs " + itoa(B));

float3 center = { 0, };

for(unsigned i = 0; i < B; i++) {
if(isCUDA)
cudaMemcpy(&center, centers[i].data(), sizeof(center),cudaMemcpyDeviceToHost);
else
memcpy(&center, centers[i].data(), sizeof(center));

//convert from subgrids to full grids
Grid<float, 2, isCUDA> C = coords[i];
Grid<float, N-1, isCUDA> T = types[i];
Grid<float, 1, isCUDA> R = radii[i];
Grid<Dtype, 4, isCUDA> O = out[i];
forward(center, C, T, R, O);
}
}


// Docstring_GridMaker_backward_1
/* \brief Generate atom and type gradients from grid gradients. (CPU)
Expand Down
28 changes: 25 additions & 3 deletions src/grid_maker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void GridMaker::forward(float3 grid_center, const Grid<float, 2, false>& coords,
}
}

template<typename Dtype, bool TypesFromRadii>
template<typename Dtype>
void GridMaker::forward(float3 grid_center, const Grid<float, 2, false>& coords,
const Grid<float, 2, false>& type_vector, const Grid<float, 1, false>& radii,
Grid<Dtype, 4, false>& out) const {
Expand Down Expand Up @@ -271,7 +271,6 @@ void GridMaker::forward(float3 grid_center, const Grid<float, 2, false>& coords,
}



template void GridMaker::forward(const std::vector<Example>& in, Grid<float, 5, false>& out,
float random_translation, bool random_rotation) const;
template void GridMaker::forward(const std::vector<Example>& in, Grid<float, 5, true>& out,
Expand All @@ -298,7 +297,30 @@ template void GridMaker::forward(float3 grid_center, const Grid<float, 2, false>
template void GridMaker::forward(float3 grid_center, const Grid<float, 2, false>& coords,
const Grid<float, 2, false>& type_vector, const Grid<float, 1, false>& radii,
Grid<double, 4, false>& out) const;


//batched cpu float

template void GridMaker::forward<float,false,2>(const Grid<float, 2, false> &centers,
const Grid<float, 3, false> &coords,
const Grid<float, 2, false> &types,
const Grid<float, 2, false> &radii, Grid<float, 5, false> &out) const;
template void GridMaker::forward<float,false,3>(const Grid<float, 2, false> &centers,
const Grid<float, 3, false> &coords,
const Grid<float, 3, false> &types,
const Grid<float, 2, false> &radii,Grid<float, 5, false> &out) const;

//batched cpu double
template void GridMaker::forward<double,false,2>(const Grid<float, 2, false> &centers,
const Grid<float, 3, false> &coords,
const Grid<float, 2, false> &types,
const Grid<float, 2, false> &radii, Grid<double, 5, false> &out) const;
template void GridMaker::forward<double,false,3>(const Grid<float, 2, false> &centers,
const Grid<float, 3, false> &coords,
const Grid<float, 3, false> &types,
const Grid<float, 2, false> &radii,Grid<double, 5, false> &out) const;



//set a single atom gradient - note can't pass a slice by reference
template <typename Dtype>
float3 GridMaker::calc_atom_gradient_cpu(const float3& grid_origin, const Grid1f& coordr, const Grid<Dtype, 3, false>& diff, float radius) const {
Expand Down
24 changes: 23 additions & 1 deletion src/grid_maker.cu
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,29 @@ namespace libmolgrid {
const Grid<float, 2, true>& type_vector, const Grid<float, 1, true>& radii, Grid<float, 4, true>& out) const;
template void GridMaker::forward(float3 grid_center, const Grid<float, 2, true>& coords,
const Grid<float, 2, true>& type_vector, const Grid<float, 1, true>& radii, Grid<double, 4, true>& out) const;



//batched gpu float
template void GridMaker::forward<float,true,2>(const Grid<float, 2, true> &centers,
const Grid<float, 3, true> &coords,
const Grid<float, 2, true> &types,
const Grid<float, 2, true> &radii, Grid<float, 5, true> &out) const;
template void GridMaker::forward<float,true,3>(const Grid<float, 2, true> &centers,
const Grid<float, 3, true> &coords,
const Grid<float, 3, true> &types,
const Grid<float, 2, true> &radii,Grid<float, 5, true> &out) const;

//batched gpu double
template void GridMaker::forward<double,true,2>(const Grid<float, 2, true> &centers,
const Grid<float, 3, true> &coords,
const Grid<float, 2, true> &types,
const Grid<float, 2, true> &radii, Grid<double, 5, true> &out) const;
template void GridMaker::forward<double,true,3>(const Grid<float, 2, true> &centers,
const Grid<float, 3, true> &coords,
const Grid<float, 3, true> &types,
const Grid<float, 2, true> &radii,Grid<double, 5, true> &out) const;


//kernel launch - parallelize across whole atoms
//TODO: accelerate this more
template<typename Dtype>
Expand Down

0 comments on commit eeec466

Please sign in to comment.