From 2ddee2a7d62d079dc2e169b5aca86c795db9753e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ruilong=20Li=28=E6=9D=8E=E7=91=9E=E9=BE=99=29?= Date: Mon, 29 Apr 2024 11:18:06 -0700 Subject: [PATCH] Implement AbsGS (#166) * absgrad * add version bump --------- Co-authored-by: Ruilong Li <397653553@qq.com> Co-authored-by: Jianbo Ye Co-authored-by: Justin Kerr --- gsplat/cuda/csrc/backward.cu | 15 +++++++++++++++ gsplat/cuda/csrc/backward.cuh | 2 ++ gsplat/cuda/csrc/bindings.cu | 10 ++++++++-- gsplat/cuda/csrc/bindings.h | 2 ++ gsplat/rasterize.py | 9 ++++++++- gsplat/version.py | 2 +- 6 files changed, 36 insertions(+), 4 deletions(-) diff --git a/gsplat/cuda/csrc/backward.cu b/gsplat/cuda/csrc/backward.cu index 8b06d0be1..908724b35 100644 --- a/gsplat/cuda/csrc/backward.cu +++ b/gsplat/cuda/csrc/backward.cu @@ -35,6 +35,7 @@ __global__ void nd_rasterize_backward_kernel( const float* __restrict__ v_output, const float* __restrict__ v_output_alpha, float2* __restrict__ v_xy, + float2* __restrict__ v_xy_abs, float3* __restrict__ v_conic, float* __restrict__ v_rgb, float* __restrict__ v_opacity @@ -90,6 +91,7 @@ __global__ void nd_rasterize_backward_kernel( float v_alpha = 0.f; float3 v_conic_local = {0.f, 0.f, 0.f}; float2 v_xy_local = {0.f, 0.f}; + float2 v_xy_abs_local = {0.f, 0.f}; float v_opacity_local = 0.f; if(valid){ // compute the current T for this gaussian @@ -114,19 +116,24 @@ __global__ void nd_rasterize_backward_kernel( 0.5f * v_sigma * delta.y * delta.y}; v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y), v_sigma * (conic.y * delta.x + conic.z * delta.y)}; + v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)}; v_opacity_local = vis * v_alpha; } warpSum3(v_conic_local, warp); warpSum2(v_xy_local, warp); + warpSum2(v_xy_abs_local, warp); warpSum(v_opacity_local, warp); if (warp.thread_rank() == 0) { float* v_conic_ptr = (float*)(v_conic); float* v_xy_ptr = (float*)(v_xy); + float* v_xy_abs_ptr = (float*)(v_xy_abs); atomicAdd(v_conic_ptr + 3*g + 0, v_conic_local.x); atomicAdd(v_conic_ptr + 3*g + 1, v_conic_local.y); atomicAdd(v_conic_ptr + 3*g + 2, v_conic_local.z); atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x); atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y); + atomicAdd(v_xy_abs_ptr + 2*g + 0, v_xy_abs_local.x); + atomicAdd(v_xy_abs_ptr + 2*g + 1, v_xy_abs_local.y); atomicAdd(v_opacity + g, v_opacity_local); } } @@ -147,6 +154,7 @@ __global__ void rasterize_backward_kernel( const float3* __restrict__ v_output, const float* __restrict__ v_output_alpha, float2* __restrict__ v_xy, + float2* __restrict__ v_xy_abs, float3* __restrict__ v_conic, float3* __restrict__ v_rgb, float* __restrict__ v_opacity @@ -251,6 +259,7 @@ __global__ void rasterize_backward_kernel( float3 v_rgb_local = {0.f, 0.f, 0.f}; float3 v_conic_local = {0.f, 0.f, 0.f}; float2 v_xy_local = {0.f, 0.f}; + float2 v_xy_abs_local = {0.f, 0.f}; float v_opacity_local = 0.f; //initialize everything to 0, only set if the lane is valid if(valid){ @@ -284,11 +293,13 @@ __global__ void rasterize_backward_kernel( 0.5f * v_sigma * delta.y * delta.y}; v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y), v_sigma * (conic.y * delta.x + conic.z * delta.y)}; + v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)}; v_opacity_local = vis * v_alpha; } warpSum3(v_rgb_local, warp); warpSum3(v_conic_local, warp); warpSum2(v_xy_local, warp); + warpSum2(v_xy_abs_local, warp); warpSum(v_opacity_local, warp); if (warp.thread_rank() == 0) { int32_t g = id_batch[t]; @@ -305,6 +316,10 @@ __global__ void rasterize_backward_kernel( float* v_xy_ptr = (float*)(v_xy); atomicAdd(v_xy_ptr + 2*g + 0, v_xy_local.x); atomicAdd(v_xy_ptr + 2*g + 1, v_xy_local.y); + + float* v_xy_abs_ptr = (float*)(v_xy_abs); + atomicAdd(v_xy_abs_ptr + 2*g + 0, v_xy_abs_local.x); + atomicAdd(v_xy_abs_ptr + 2*g + 1, v_xy_abs_local.y); atomicAdd(v_opacity + g, v_opacity_local); } diff --git a/gsplat/cuda/csrc/backward.cuh b/gsplat/cuda/csrc/backward.cuh index 91dd33d9c..9e67a8d2a 100644 --- a/gsplat/cuda/csrc/backward.cuh +++ b/gsplat/cuda/csrc/backward.cuh @@ -46,6 +46,7 @@ __global__ void nd_rasterize_backward_kernel( const float* __restrict__ v_output, const float* __restrict__ v_output_alpha, float2* __restrict__ v_xy, + float2* __restrict__ v_xy_abs, float3* __restrict__ v_conic, float* __restrict__ v_rgb, float* __restrict__ v_opacity @@ -66,6 +67,7 @@ __global__ void rasterize_backward_kernel( const float3* __restrict__ v_output, const float* __restrict__ v_output_alpha, float2* __restrict__ v_xy, + float2* __restrict__ v_xy_abs, float3* __restrict__ v_conic, float3* __restrict__ v_rgb, float* __restrict__ v_opacity diff --git a/gsplat/cuda/csrc/bindings.cu b/gsplat/cuda/csrc/bindings.cu index ff7812ef3..e0924ed74 100644 --- a/gsplat/cuda/csrc/bindings.cu +++ b/gsplat/cuda/csrc/bindings.cu @@ -525,6 +525,7 @@ nd_rasterize_forward_tensor( std:: tuple< torch::Tensor, // dL_dxy + torch::Tensor, // dL_dxy_abs torch::Tensor, // dL_dconic torch::Tensor, // dL_dcolors torch::Tensor // dL_dopacity @@ -568,6 +569,7 @@ std:: const int channels = colors.size(1); torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options()); + torch::Tensor v_xy_abs = torch::zeros({num_points, 2}, xys.options()); torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options()); torch::Tensor v_colors = torch::zeros({num_points, channels}, xys.options()); @@ -595,17 +597,19 @@ std:: v_output.contiguous().data_ptr(), v_output_alpha.contiguous().data_ptr(), (float2 *)v_xy.contiguous().data_ptr(), + (float2 *)v_xy_abs.contiguous().data_ptr(), (float3 *)v_conic.contiguous().data_ptr(), v_colors.contiguous().data_ptr(), v_opacity.contiguous().data_ptr() ); - return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); + return std::make_tuple(v_xy, v_xy_abs, v_conic, v_colors, v_opacity); } std:: tuple< torch::Tensor, // dL_dxy + torch::Tensor, // dL_dxy_abs torch::Tensor, // dL_dconic torch::Tensor, // dL_dcolors torch::Tensor // dL_dopacity @@ -649,6 +653,7 @@ std:: const int channels = colors.size(1); torch::Tensor v_xy = torch::zeros({num_points, 2}, xys.options()); + torch::Tensor v_xy_abs = torch::zeros({num_points, 2}, xys.options()); torch::Tensor v_conic = torch::zeros({num_points, 3}, xys.options()); torch::Tensor v_colors = torch::zeros({num_points, channels}, xys.options()); @@ -669,10 +674,11 @@ std:: (float3 *)v_output.contiguous().data_ptr(), v_output_alpha.contiguous().data_ptr(), (float2 *)v_xy.contiguous().data_ptr(), + (float2 *)v_xy_abs.contiguous().data_ptr(), (float3 *)v_conic.contiguous().data_ptr(), (float3 *)v_colors.contiguous().data_ptr(), v_opacity.contiguous().data_ptr() ); - return std::make_tuple(v_xy, v_conic, v_colors, v_opacity); + return std::make_tuple(v_xy, v_xy_abs, v_conic, v_colors, v_opacity); } diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 8cdc39052..c34a88c8e 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -149,6 +149,7 @@ std::tuple< std:: tuple< torch::Tensor, // dL_dxy + torch::Tensor, // dL_dxy_abs torch::Tensor, // dL_dconic torch::Tensor, // dL_dcolors torch::Tensor // dL_dopacity @@ -173,6 +174,7 @@ std:: std:: tuple< torch::Tensor, // dL_dxy + torch::Tensor, // dL_dxy_abs torch::Tensor, // dL_dconic torch::Tensor, // dL_dcolors torch::Tensor // dL_dopacity diff --git a/gsplat/rasterize.py b/gsplat/rasterize.py index e32cf5d31..fd8382ac2 100644 --- a/gsplat/rasterize.py +++ b/gsplat/rasterize.py @@ -8,6 +8,7 @@ from torch.autograd import Function import gsplat.cuda as _C + from .utils import bin_and_sort_gaussians, compute_cumulative_intersects @@ -205,6 +206,7 @@ def backward(ctx, v_out_img, v_out_alpha=None): if num_intersects < 1: v_xy = torch.zeros_like(xys) + v_xy_abs = torch.zeros_like(xys) v_conic = torch.zeros_like(conics) v_colors = torch.zeros_like(colors) v_opacity = torch.zeros_like(opacity) @@ -214,7 +216,7 @@ def backward(ctx, v_out_img, v_out_alpha=None): rasterize_fn = _C.rasterize_backward else: rasterize_fn = _C.nd_rasterize_backward - v_xy, v_conic, v_colors, v_opacity = rasterize_fn( + v_xy, v_xy_abs, v_conic, v_colors, v_opacity = rasterize_fn( img_height, img_width, ctx.block_width, @@ -231,6 +233,11 @@ def backward(ctx, v_out_img, v_out_alpha=None): v_out_alpha, ) + # Abs grad for gaussian splitting criterion. See + # - "AbsGS: Recovering Fine Details for 3D Gaussian Splatting" + # - "EfficientGS: Streamlining Gaussian Splatting for Large-Scale High-Resolution Scene Representation" + xys.absgrad = v_xy_abs + return ( v_xy, # xys None, # depths diff --git a/gsplat/version.py b/gsplat/version.py index 569b1212f..0c5c30071 100644 --- a/gsplat/version.py +++ b/gsplat/version.py @@ -1 +1 @@ -__version__ = "0.1.10" +__version__ = "0.1.11"