Depth Supervision for Gaussian Splatting #3182

Draft · wants to merge 7 commits into `main`
29 changes: 14 additions & 15 deletions README.md
@@ -222,21 +222,20 @@ ns-export pointcloud --help

Using an existing dataset is great, but you likely want to use your own data! We support various methods for using your own data. Before it can be used in nerfstudio, the camera locations and orientations must be determined and then converted into our format using `ns-process-data`. We rely on external tools for this; instructions and information can be found in the documentation.

| Data                                                                                           | Capture Device                                                              | Requirements                                                                                                           | `ns-process-data` Speed |
| ---------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------- | ----------------------- |
| 📷 [Images](https://docs.nerf.studio/quickstart/custom_dataset.html#images-or-video)           | Any                                                                         | [COLMAP](https://colmap.github.io/install.html)                                                                        | 🐢                      |
| 📹 [Video](https://docs.nerf.studio/quickstart/custom_dataset.html#images-or-video)            | Any                                                                         | [COLMAP](https://colmap.github.io/install.html)                                                                        | 🐢                      |
| 🌎 [360 Data](https://docs.nerf.studio/quickstart/custom_dataset.html#data-equirectangular)    | Any                                                                         | [COLMAP](https://colmap.github.io/install.html)                                                                        | 🐢                      |
| 📱 [Polycam](https://docs.nerf.studio/quickstart/custom_dataset.html#polycam-capture)          | iOS with LiDAR                                                              | [Polycam App](https://poly.cam/)                                                                                       | 🐇                      |
| 📱 [KIRI Engine](https://docs.nerf.studio/quickstart/custom_dataset.html#kiri-engine-capture)  | iOS or Android                                                              | [KIRI Engine App](https://www.kiriengine.com/)                                                                         | 🐇                      |
| 📱 [Record3D](https://docs.nerf.studio/quickstart/custom_dataset.html#record3d-capture)        | iOS with LiDAR                                                              | [Record3D app](https://record3d.app/)                                                                                  | 🐇                      |
| 📱 [Spectacular AI](https://docs.nerf.studio/quickstart/custom_dataset.html#spectacularai)     | iOS, OAK, [others](https://www.spectacularai.com/mapping#supported-devices) | [App](https://apps.apple.com/us/app/spectacular-rec/id6473188128) / [`sai-cli`](https://www.spectacularai.com/mapping) | 🐇                      |
| 🖥 [Metashape](https://docs.nerf.studio/quickstart/custom_dataset.html#metashape)              | Any                                                                         | [Metashape](https://www.agisoft.com/)                                                                                  | 🐇                      |
| 🖥 [RealityCapture](https://docs.nerf.studio/quickstart/custom_dataset.html#realitycapture)    | Any                                                                         | [RealityCapture](https://www.capturingreality.com/realitycapture)                                                      | 🐇                      |
| 🖥 [ODM](https://docs.nerf.studio/quickstart/custom_dataset.html#odm)                          | Any                                                                         | [ODM](https://github.com/OpenDroneMap/ODM)                                                                             | 🐇                      |
| 👓 [Aria](https://docs.nerf.studio/quickstart/custom_dataset.html#aria)                        | Aria glasses                                                                | [Project Aria](https://projectaria.com/)                                                                               | 🐇                      |
| 🛠 [Custom](https://docs.nerf.studio/quickstart/data_conventions.html)                         | Any                                                                         | Camera Poses                                                                                                           | 🐇                      |
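For the 🛠 Custom route above, the input is a `transforms.json` holding camera intrinsics and per-frame camera-to-world poses. Below is a minimal sketch of that structure, written from Python; field names follow the data-conventions docs, and all values are illustrative only:

```python
# Minimal sketch of a nerfstudio transforms.json for custom data.
# Field names follow the data-conventions docs; values are illustrative.
import json

transforms = {
    "camera_model": "OPENCV",
    "fl_x": 1072.0, "fl_y": 1068.0,  # focal lengths in pixels
    "cx": 960.0, "cy": 540.0,        # principal point
    "w": 1920, "h": 1080,            # image size
    "frames": [
        {
            "file_path": "images/frame_00001.png",
            "transform_matrix": [    # camera-to-world pose
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
        }
    ],
}

with open("transforms.json", "w") as f:
    json.dump(transforms, f, indent=2)
```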

## 5. Advanced Options

30 changes: 20 additions & 10 deletions docs/nerfology/methods/splat.md
@@ -1,4 +1,5 @@
# Splatfacto

<h4>Nerfstudio's Gaussian Splatting Implementation</h4>
<iframe width="560" height="315" src="https://www.youtube.com/embed/0yueTFx-MdQ?si=GxiYnFAeYVVl-soJ" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>

@@ -17,53 +18,62 @@ To avoid confusion with the original paper, we refer to nerfstudio's implementation
```{button-link} https://docs.gsplat.studio/
:color: primary
:outline:
GSplat
```

Nerfstudio uses [gsplat](https://github.com/nerfstudio-project/gsplat) as its gaussian rasterization backend, an in-house re-implementation designed to be more developer-friendly. It can be installed with `pip install gsplat`; the associated CUDA code is compiled the first time gsplat is executed. Some users on PyTorch 2.0 have experienced issues with this compilation, which can be resolved by either installing gsplat from source or upgrading torch to 2.1.
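
To check your environment before a long training run, a minimal sanity check is the following (a sketch; nothing here is specific to this PR):

```python
import torch
import gsplat  # the CUDA extension is built lazily, on first use of a kernel

print("torch:", torch.__version__)  # upgrading to 2.1 avoids the reported issue
print("gsplat:", getattr(gsplat, "__version__", "unknown"))
```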

### Data
Gaussian splatting works much better if you initialize it from pre-existing geometry, such as SfM points from COLMAP. COLMAP datasets or datasets from `ns-process-data` will automatically save these points and initialize gaussians on them. Other datasets currently do not support initialization and will initialize gaussians randomly. Initializing from other data inputs (e.g. depth from phone app scanners) may be supported in the future.

Because the method trains on _full images_ instead of bundles of rays, there is a new datamanager in `full_images_datamanager.py` which undistorts input images, caches them, and provides single images at each train step.
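
For intuition, here is a schematic training loop over full images (a sketch only: `next_train` returning a `(camera, batch)` pair follows nerfstudio's datamanager conventions, and the loss shown is a stand-in for the model's actual losses):

```python
# Sketch of how a full-image datamanager is consumed (names follow nerfstudio's
# datamanager conventions; treat the details as assumptions, not this PR's API).
def train_loop(datamanager, model, optimizer, num_steps: int = 30000) -> None:
    for step in range(num_steps):
        camera, batch = datamanager.next_train(step)  # one full (undistorted, cached) image
        outputs = model.get_outputs(camera)           # render the whole image at once
        loss = ((outputs["rgb"] - batch["image"]) ** 2).mean()  # placeholder photometric loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```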

### Running the Method

To run splatfacto, run `ns-train splatfacto --data <data>`. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported.

We provide a few additional variants:

| Method             | Description                      | Memory | Speed  |
| ------------------ | -------------------------------- | ------ | ------ |
| `splatfacto`       | Default Model                    | ~6GB   | Fast   |
| `depth-splatfacto` | Default Model, Depth Supervision | ~6GB   | Fast   |
| `splatfacto-big`   | More Gaussians, Higher Quality   | ~12GB  | Slower |

The new `depth-splatfacto` variant is launched the same way (`ns-train depth-splatfacto --data <data>`); it trains splatfacto with an additional depth loss, using the `DepthDataset`-backed datamanager registered in `method_configs.py`.

A full evaluation of Nerfstudio's implementation of Gaussian Splatting against the original Inria method can be found [here](https://docs.gsplat.studio/main/tests/eval.html).
#### Quality and Regularization

The default settings maintain a balance between speed, quality, and splat file size, but if you care more about quality than training speed or size, you can decrease the alpha cull threshold (the threshold for deleting translucent gaussians) and disable culling after 15k steps like so: `ns-train splatfacto --pipeline.model.cull_alpha_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data <data>`

A common artifact in splatting is long, spiky gaussians. [PhysGaussian](https://xpandora.github.io/PhysGaussian/) proposes a scale regularizer that encourages gaussians to be more evenly shaped. To enable this, set the `pipeline.model.use_scale_regularization` flag to `True`.
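
For intuition, a minimal sketch of a regularizer in this spirit is shown below (it assumes log-parameterized per-axis scales as in splatfacto; the function name and default ratio are illustrative, not the exact code behind the flag):

```python
import torch


def scale_regularization(log_scales: torch.Tensor, max_ratio: float = 10.0) -> torch.Tensor:
    """Penalize overly elongated gaussians (PhysGaussian-style sketch).

    Args:
        log_scales: [N, 3] log of each gaussian's per-axis scale.
        max_ratio: largest allowed ratio of longest to shortest axis.
    """
    scales = torch.exp(log_scales)                     # [N, 3] positive scales
    ratio = scales.amax(dim=-1) / scales.amin(dim=-1)  # elongation per gaussian
    # Only gaussians stretched beyond max_ratio are penalized; compact ones are free.
    return (torch.clamp(ratio, min=max_ratio) - max_ratio).mean()
```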

### Details

For more details on the method, see the [original paper](https://arxiv.org/abs/2308.04079). Additionally, for a detailed derivation of the gradients used in the gsplat library, see [here](https://arxiv.org/abs/2312.02121).

### Exporting splats

Gaussian splats can be exported as a `.ply` file, which can be ingested by a variety of online web viewers. You can do this via the viewer, or with `ns-export gaussian-splat --load-config <config> --output-dir exports/splat`. Currently splats can only be exported from trained splats, not from nerfacto.

Nerfstudio's splat export currently supports multiple third-party splat viewers:

- [Polycam Viewer](https://poly.cam/tools/gaussian-splatting)
- [Playcanvas SuperSplat](https://playcanvas.com/super-splat)
- [WebGL Viewer by antimatter15](https://antimatter15.com/splat/)
- [Spline](https://spline.design/)
- [Three.js Viewer by mkkellogg](https://github.com/mkkellogg/GaussianSplats3D)

### FAQ

- Can I export a mesh or pointcloud?

Currently these export options are not supported, but may be in the future. Contributions are always welcome!

- Can I render fisheye, equirectangular, or orthographic images?

Currently, no. Gaussian rasterization assumes a perspective camera for its rasterization pipeline. Implementing other camera models is of interest but not currently planned.
54 changes: 53 additions & 1 deletion nerfstudio/configs/method_configs.py
@@ -27,7 +27,7 @@
from nerfstudio.configs.base_config import ViewerConfig
from nerfstudio.configs.external_methods import ExternalMethodDummyTrainerConfig, get_external_methods
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager, VanillaDataManagerConfig
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager, FullImageDatamanagerConfig
from nerfstudio.data.datamanagers.parallel_datamanager import ParallelDataManagerConfig
from nerfstudio.data.datamanagers.random_cameras_datamanager import RandomCamerasDataManagerConfig
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
@@ -51,6 +51,7 @@
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.fields.sdf_field import SDFFieldConfig
from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig
from nerfstudio.models.depth_splatfacto import DepthSplatfactoModelConfig
from nerfstudio.models.generfacto import GenerfactoModelConfig
from nerfstudio.models.instant_ngp import InstantNGPModelConfig
from nerfstudio.models.mipnerf import MipNerfModel
@@ -82,6 +83,7 @@
"neus": "Implementation of NeuS. (slow)",
"neus-facto": "Implementation of NeuS-Facto. (slow)",
"splatfacto": "Gaussian Splatting model",
"depth-splatfacto": "Depth supervised Gaussian Splatting model",
"splatfacto-big": "Larger version of Splatfacto with higher quality.",
}

@@ -642,6 +644,56 @@
vis="viewer",
)

method_configs["depth-splatfacto"] = TrainerConfig(
method_name="depth-splatfacto",
steps_per_eval_image=100,
steps_per_eval_batch=0,
steps_per_save=2000,
steps_per_eval_all_images=1000,
max_num_iterations=30000,
mixed_precision=False,
pipeline=VanillaPipelineConfig(
datamanager=FullImageDatamanagerConfig(
_target=FullImageDatamanager[DepthDataset],
dataparser=NerfstudioDataParserConfig(load_3D_points=True),
cache_images_type="uint8",
),
model=DepthSplatfactoModelConfig(),
),
optimizers={
"means": {
"optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=1.6e-6,
max_steps=30000,
),
},
"features_dc": {
"optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15),
"scheduler": None,
},
"features_rest": {
"optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15),
"scheduler": None,
},
"opacities": {
"optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
"scheduler": None,
},
"scales": {
"optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
"scheduler": None,
},
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
"camera_opt": {
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
},
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
vis="viewer",
)

method_configs["splatfacto-big"] = TrainerConfig(
method_name="splatfacto",
steps_per_eval_image=100,
44 changes: 44 additions & 0 deletions nerfstudio/model_components/losses.py
@@ -22,6 +22,7 @@
import torch
from jaxtyping import Bool, Float
from torch import Tensor, nn
from torchmetrics.functional.regression import pearson_corrcoef

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.field_components.field_heads import FieldHeadNames
@@ -44,6 +45,8 @@ class DepthLossType(Enum):
    DS_NERF = 1
    URF = 2
    SPARSENERF_RANKING = 3
    MSE = 4
    PEARSON_LOSS = 5


FORCE_PSEUDODEPTH_LOSS = False
@@ -221,6 +224,44 @@ def pred_normal_loss(
"""Loss between normals calculated from density and normals from prediction network."""
return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1)


def mse_depth_loss(
    termination_depth: Float[Tensor, "*batch 1"],
    predicted_depth: Float[Tensor, "*batch 1"],
) -> Float[Tensor, ""]:
    """Masked MSE depth loss.

    Pixels without ground-truth depth (encoded as depth <= 0) are masked out
    so they do not pull predictions toward zero.

    Args:
        termination_depth: Ground truth depth of rays.
        predicted_depth: Predicted depths.
    Returns:
        Depth loss scalar.
    """
    depth_mask = termination_depth > 0
    squared_error = (termination_depth - predicted_depth) ** 2
    # Average only over supervised pixels; a plain mean over everything would
    # shrink the loss as the fraction of unsupervised pixels grows.
    return (squared_error * depth_mask).sum() / depth_mask.sum().clamp(min=1)


def pearson_correlation_depth_loss(
    termination_depth: Float[Tensor, "*batch 1"],
    predicted_depth: Float[Tensor, "*batch 1"],
) -> Float[Tensor, ""]:
    """Pearson correlation depth loss.

    Invariant to a global scale and shift of the predicted depth, which makes
    it suitable for supervision from monocular (relative) depth estimates.

    Args:
        termination_depth: Ground truth depth of rays.
        predicted_depth: Rendered depth from the radiance field.
    Returns:
        Depth loss scalar.
    """
    termination_depth = termination_depth.reshape(-1, 1)
    predicted_depth = predicted_depth.reshape(-1, 1)
    # pearson_corrcoef returns 1.0 for perfectly correlated inputs, giving zero loss.
    loss = 1 - pearson_corrcoef(predicted_depth, termination_depth)
    return torch.mean(loss)


def ds_nerf_depth_loss(
    weights: Float[Tensor, "*batch num_samples 1"],

@@ -318,6 +359,9 @@ def depth_loss(
    if depth_loss_type == DepthLossType.DS_NERF:
        lengths = ray_samples.frustums.ends - ray_samples.frustums.starts
        return ds_nerf_depth_loss(weights, termination_depth, steps, lengths, sigma)

    if depth_loss_type == DepthLossType.MSE:
        return mse_depth_loss(termination_depth, predicted_depth)

    if depth_loss_type == DepthLossType.URF:
        return urban_radiance_field_depth_loss(weights, termination_depth, predicted_depth, steps, sigma)
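
To make the behavior of the two new losses concrete, here is a small self-contained usage sketch (illustrative values; it assumes this branch is installed so the functions are importable from `nerfstudio.model_components.losses`):

```python
import torch

from nerfstudio.model_components.losses import (
    mse_depth_loss,
    pearson_correlation_depth_loss,
)

# Masked MSE: a non-positive ground-truth depth marks a pixel with no sensor
# reading, so the wildly wrong prediction at that pixel contributes nothing.
gt = torch.tensor([[1.0], [0.0], [2.5]])
pred = torch.tensor([[1.1], [9.0], [2.4]])
print(mse_depth_loss(gt, pred))  # driven only by the two supervised pixels

# Pearson loss: invariant to a global scale and shift of the prediction, so a
# depth map that is an affine transform of the target scores ~0 loss. This is
# what makes it usable with monocular depth estimates, which are only relative.
target = torch.rand(64, 64)
pred_affine = 3.0 * target + 0.5
print(pearson_correlation_depth_loss(target, pred_affine))  # ~0
```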