Skip to content

Commit

Permalink
stricter req & better init control
Browse files Browse the repository at this point in the history
  • Loading branch information
dendenxu committed Apr 18, 2024
1 parent b6d3e04 commit c88faf1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions fast_gauss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import nn
from typing import NamedTuple
from typing import NamedTuple, List, Dict

from .gsplat_utils import GSplatContextManager

Expand All @@ -23,13 +23,20 @@ class GaussianRasterizationSettings(NamedTuple):


class GaussianRasterizer:
def __init__(self, raster_settings: GaussianRasterizationSettings, dtype=torch.float, tex_dtype=torch.half):
def __init__(self,
raster_settings: GaussianRasterizationSettings,
init_buffer_size: int = 32768,
init_texture_size: List[int] = [512, 512],
dtype=torch.float,
tex_dtype=torch.half):
super().__init__()
self.raster_settings = raster_settings

global raster_context
dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
tex_dtype = getattr(torch, tex_dtype) if isinstance(tex_dtype, str) else tex_dtype
if raster_context is None or raster_context.dtype != dtype or raster_context.tex_dtype != tex_dtype:
raster_context = GSplatContextManager(dtype=dtype, tex_dtype=tex_dtype) # only created once
raster_context = GSplatContextManager(init_buffer_size=init_buffer_size, init_texture_size=init_texture_size, dtype=dtype, tex_dtype=tex_dtype) # only created once

def __call__(self,
means3D: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ torch
numpy
cuda-python
PyGLM
PyOpenGL
PyOpenGL>=3.1.7

0 comments on commit c88faf1

Please sign in to comment.