Skip to content

Commit

Permalink
Implement AbsGS (#166)
Browse files Browse the repository at this point in the history
* absgrad

* add version bump

---------

Co-authored-by: Ruilong Li <[email protected]>
Co-authored-by: Jianbo Ye <[email protected]>
Co-authored-by: Justin Kerr <[email protected]>
  • Loading branch information
4 people authored Apr 29, 2024
1 parent 8a19034 commit 2ddee2a
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 4 deletions.
15 changes: 15 additions & 0 deletions gsplat/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ __global__ void nd_rasterize_backward_kernel(
const float* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float2* __restrict__ v_xy_abs,
float3* __restrict__ v_conic,
float* __restrict__ v_rgb,
float* __restrict__ v_opacity
Expand Down Expand Up @@ -90,6 +91,7 @@ __global__ void nd_rasterize_backward_kernel(
float v_alpha = 0.f;
float3 v_conic_local = {0.f, 0.f, 0.f};
float2 v_xy_local = {0.f, 0.f};
float2 v_xy_abs_local = {0.f, 0.f};
float v_opacity_local = 0.f;
if(valid){
// compute the current T for this gaussian
Expand All @@ -114,19 +116,24 @@ __global__ void nd_rasterize_backward_kernel(
0.5f * v_sigma * delta.y * delta.y};
v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
v_sigma * (conic.y * delta.x + conic.z * delta.y)};
v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)};
v_opacity_local = vis * v_alpha;
}
warpSum3(v_conic_local, warp);
warpSum2(v_xy_local, warp);
warpSum2(v_xy_abs_local, warp);
warpSum(v_opacity_local, warp);
if (warp.thread_rank() == 0) {
float* v_conic_ptr = (float*)(v_conic);
float* v_xy_ptr = (float*)(v_xy);
float* v_xy_abs_ptr = (float*)(v_xy_abs);
atomicAdd(v_conic_ptr + 3*g + 0, v_conic_local.x);
atomicAdd(v_conic_ptr + 3*g + 1, v_conic_local.y);
atomicAdd(v_conic_ptr + 3*g + 2, v_conic_local.z);
atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x);
atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y);
atomicAdd(v_xy_abs_ptr + 2*g + 0, v_xy_abs_local.x);
atomicAdd(v_xy_abs_ptr + 2*g + 1, v_xy_abs_local.y);
atomicAdd(v_opacity + g, v_opacity_local);
}
}
Expand All @@ -147,6 +154,7 @@ __global__ void rasterize_backward_kernel(
const float3* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float2* __restrict__ v_xy_abs,
float3* __restrict__ v_conic,
float3* __restrict__ v_rgb,
float* __restrict__ v_opacity
Expand Down Expand Up @@ -251,6 +259,7 @@ __global__ void rasterize_backward_kernel(
float3 v_rgb_local = {0.f, 0.f, 0.f};
float3 v_conic_local = {0.f, 0.f, 0.f};
float2 v_xy_local = {0.f, 0.f};
float2 v_xy_abs_local = {0.f, 0.f};
float v_opacity_local = 0.f;
//initialize everything to 0, only set if the lane is valid
if(valid){
Expand Down Expand Up @@ -284,11 +293,13 @@ __global__ void rasterize_backward_kernel(
0.5f * v_sigma * delta.y * delta.y};
v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
v_sigma * (conic.y * delta.x + conic.z * delta.y)};
v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)};
v_opacity_local = vis * v_alpha;
}
warpSum3(v_rgb_local, warp);
warpSum3(v_conic_local, warp);
warpSum2(v_xy_local, warp);
warpSum2(v_xy_abs_local, warp);
warpSum(v_opacity_local, warp);
if (warp.thread_rank() == 0) {
int32_t g = id_batch[t];
Expand All @@ -305,6 +316,10 @@ __global__ void rasterize_backward_kernel(
float* v_xy_ptr = (float*)(v_xy);
atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x);
atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y);

float* v_xy_abs_ptr = (float*)(v_xy_abs);
atomicAdd(v_xy_abs_ptr + 2*g + 0, v_xy_abs_local.x);
atomicAdd(v_xy_abs_ptr + 2*g + 1, v_xy_abs_local.y);

atomicAdd(v_opacity + g, v_opacity_local);
}
Expand Down
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/backward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ __global__ void nd_rasterize_backward_kernel(
const float* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float2* __restrict__ v_xy_abs,
float3* __restrict__ v_conic,
float* __restrict__ v_rgb,
float* __restrict__ v_opacity
Expand All @@ -66,6 +67,7 @@ __global__ void rasterize_backward_kernel(
const float3* __restrict__ v_output,
const float* __restrict__ v_output_alpha,
float2* __restrict__ v_xy,
float2* __restrict__ v_xy_abs,
float3* __restrict__ v_conic,
float3* __restrict__ v_rgb,
float* __restrict__ v_opacity
Expand Down
10 changes: 8 additions & 2 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ nd_rasterize_forward_tensor(
std::
tuple<
torch::Tensor, // dL_dxy
torch::Tensor, // dL_dxy_abs
torch::Tensor, // dL_dconic
torch::Tensor, // dL_dcolors
torch::Tensor // dL_dopacity
Expand Down Expand Up @@ -568,6 +569,7 @@ std::
const int channels = colors.size(1);

torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options());
torch::Tensor v_xy_abs = torch::zeros({num_points, 2}, xys.options());
torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options());
torch::Tensor v_colors =
torch::zeros({num_points, channels}, xys.options());
Expand Down Expand Up @@ -595,17 +597,19 @@ std::
v_output.contiguous().data_ptr<float>(),
v_output_alpha.contiguous().data_ptr<float>(),
(float2 *)v_xy.contiguous().data_ptr<float>(),
(float2 *)v_xy_abs.contiguous().data_ptr<float>(),
(float3 *)v_conic.contiguous().data_ptr<float>(),
v_colors.contiguous().data_ptr<float>(),
v_opacity.contiguous().data_ptr<float>()
);

return std::make_tuple(v_xy, v_conic, v_colors, v_opacity);
return std::make_tuple(v_xy, v_xy_abs, v_conic, v_colors, v_opacity);
}

std::
tuple<
torch::Tensor, // dL_dxy
torch::Tensor, // dL_dxy_abs
torch::Tensor, // dL_dconic
torch::Tensor, // dL_dcolors
torch::Tensor // dL_dopacity
Expand Down Expand Up @@ -649,6 +653,7 @@ std::
const int channels = colors.size(1);

torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options());
torch::Tensor v_xy_abs = torch::zeros({num_points, 2}, xys.options());
torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options());
torch::Tensor v_colors =
torch::zeros({num_points, channels}, xys.options());
Expand All @@ -669,10 +674,11 @@ std::
(float3 *)v_output.contiguous().data_ptr<float>(),
v_output_alpha.contiguous().data_ptr<float>(),
(float2 *)v_xy.contiguous().data_ptr<float>(),
(float2 *)v_xy_abs.contiguous().data_ptr<float>(),
(float3 *)v_conic.contiguous().data_ptr<float>(),
(float3 *)v_colors.contiguous().data_ptr<float>(),
v_opacity.contiguous().data_ptr<float>()
);

return std::make_tuple(v_xy, v_conic, v_colors, v_opacity);
return std::make_tuple(v_xy, v_xy_abs, v_conic, v_colors, v_opacity);
}
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ std::tuple<
std::
tuple<
torch::Tensor, // dL_dxy
torch::Tensor, // dL_dxy_abs
torch::Tensor, // dL_dconic
torch::Tensor, // dL_dcolors
torch::Tensor // dL_dopacity
Expand All @@ -173,6 +174,7 @@ std::
std::
tuple<
torch::Tensor, // dL_dxy
torch::Tensor, // dL_dxy_abs
torch::Tensor, // dL_dconic
torch::Tensor, // dL_dcolors
torch::Tensor // dL_dopacity
Expand Down
9 changes: 8 additions & 1 deletion gsplat/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch.autograd import Function

import gsplat.cuda as _C

from .utils import bin_and_sort_gaussians, compute_cumulative_intersects


Expand Down Expand Up @@ -205,6 +206,7 @@ def backward(ctx, v_out_img, v_out_alpha=None):

if num_intersects < 1:
v_xy = torch.zeros_like(xys)
v_xy_abs = torch.zeros_like(xys)
v_conic = torch.zeros_like(conics)
v_colors = torch.zeros_like(colors)
v_opacity = torch.zeros_like(opacity)
Expand All @@ -214,7 +216,7 @@ def backward(ctx, v_out_img, v_out_alpha=None):
rasterize_fn = _C.rasterize_backward
else:
rasterize_fn = _C.nd_rasterize_backward
v_xy, v_conic, v_colors, v_opacity = rasterize_fn(
v_xy, v_xy_abs, v_conic, v_colors, v_opacity = rasterize_fn(
img_height,
img_width,
ctx.block_width,
Expand All @@ -231,6 +233,11 @@ def backward(ctx, v_out_img, v_out_alpha=None):
v_out_alpha,
)

# Abs grad for gaussian splitting criterion. See
# - "AbsGS: Recovering Fine Details for 3D Gaussian Splatting"
# - "EfficientGS: Streamlining Gaussian Splatting for Large-Scale High-Resolution Scene Representation"
xys.absgrad = v_xy_abs

return (
v_xy, # xys
None, # depths
Expand Down
2 changes: 1 addition & 1 deletion gsplat/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.10"
__version__ = "0.1.11"

0 comments on commit 2ddee2a

Please sign in to comment.