diff --git a/fast_gauss/__init__.py b/fast_gauss/__init__.py index fdef091..0d1401f 100644 --- a/fast_gauss/__init__.py +++ b/fast_gauss/__init__.py @@ -1,6 +1,6 @@ import torch from torch import nn -from typing import NamedTuple +from typing import NamedTuple, List, Dict from .gsplat_utils import GSplatContextManager @@ -23,13 +23,20 @@ class GaussianRasterizationSettings(NamedTuple): class GaussianRasterizer: - def __init__(self, raster_settings: GaussianRasterizationSettings, dtype=torch.float, tex_dtype=torch.half): + def __init__(self, + raster_settings: GaussianRasterizationSettings, + init_buffer_size: int = 32768, + init_texture_size: List[int] = [512, 512], + dtype=torch.float, + tex_dtype=torch.half): super().__init__() self.raster_settings = raster_settings global raster_context + dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype + tex_dtype = getattr(torch, tex_dtype) if isinstance(tex_dtype, str) else tex_dtype if raster_context is None or raster_context.dtype != dtype or raster_context.tex_dtype != tex_dtype: - raster_context = GSplatContextManager(dtype=dtype, tex_dtype=tex_dtype) # only created once + raster_context = GSplatContextManager(init_buffer_size=init_buffer_size, init_texture_size=init_texture_size, dtype=dtype, tex_dtype=tex_dtype) # only created once def __call__(self, means3D: torch.Tensor, diff --git a/requirements.txt b/requirements.txt index 1b1206a..ba8dfaa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ torch numpy cuda-python PyGLM -PyOpenGL \ No newline at end of file +PyOpenGL>=3.1.7 \ No newline at end of file