Skip to content

Commit

Permalink
remove projmat from API (#149)
Browse files Browse the repository at this point in the history
* remove 'projmat' arg
  • Loading branch information
kerrj authored Mar 27, 2024
1 parent 1516b9d commit 2d69aaf
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 17 deletions.
1 change: 0 additions & 1 deletion examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ def train(
1,
self.quats,
self.viewmat,
self.viewmat,
self.focal,
self.focal,
self.W / 2,
Expand Down
18 changes: 10 additions & 8 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
w, x, y, z = torch.unbind(quat, dim=-1)
mat = torch.stack(
[
1 - 2 * (y**2 + z**2),
1 - 2 * (y ** 2 + z ** 2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x**2 + z**2),
1 - 2 * (x ** 2 + z ** 2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x**2 + y**2),
1 - 2 * (x ** 2 + y ** 2),
],
dim=-1,
)
Expand Down Expand Up @@ -165,7 +165,7 @@ def project_cov3d_ewa(
t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3)

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz**2 # (...,)
rz2 = rz ** 2 # (...,)

lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device)
lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device)
Expand Down Expand Up @@ -220,24 +220,26 @@ def compute_cov2d_bounds(cov2d_mat: Tensor):
dim=-1,
) # (..., 3)
b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,)
v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,)
radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device)
conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device)
radius_all[valid] = radius
conic_all[valid] = conic
return conic_all, radius_all, valid


def project_pix(fxfy, p_view, center, eps=1e-6):
fx, fy = fxfy
cx, cy = center

rw = 1.0 / (p_view[..., 2] + 1e-6)
p_proj = ( p_view[..., 0] * rw, p_view[..., 1] * rw )
u, v = ( p_proj[0] * fx + cx, p_proj[1] * fy + cy )
p_proj = (p_view[..., 0] * rw, p_view[..., 1] * rw)
u, v = (p_proj[0] * fx + cx, p_proj[1] * fy + cy)
return torch.stack([u, v], dim=-1)


def clip_near_plane(p, viewmat, clip_thresh=0.01):
R = viewmat[:3, :3]
T = viewmat[:3, 3]
Expand Down
10 changes: 3 additions & 7 deletions gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def project_gaussians(
glob_scale: float,
quats: Float[Tensor, "*batch 4"],
viewmat: Float[Tensor, "4 4"],
projmat: Optional[Float[Tensor, "4 4"]],
fx: float,
fy: float,
cx: float,
Expand All @@ -37,7 +36,6 @@ def project_gaussians(
glob_scale (float): A global scaling factor applied to the scene.
quats (Tensor): rotations in quaternion [w,x,y,z] format.
viewmat (Tensor): view matrix for rendering.
projmat (Tensor): DEPRECATED and ignored. Set to None
fx (float): focal length x.
fy (float): focal length y.
cx (float): principal point x.
Expand Down Expand Up @@ -65,7 +63,6 @@ def project_gaussians(
glob_scale,
quats.contiguous(),
viewmat.contiguous(),
None,
fx,
fy,
cx,
Expand All @@ -88,7 +85,6 @@ def forward(
glob_scale: float,
quats: Float[Tensor, "*batch 4"],
viewmat: Float[Tensor, "4 4"],
projmat: Optional[Float[Tensor, "4 4"]],
fx: float,
fy: float,
cx: float,
Expand Down Expand Up @@ -227,7 +223,9 @@ def backward(
# gradent w.r.t. view matrix rotation
for j in range(3):
for l in range(3):
v_viewmat[..., j, l] = torch.dot(v_mean3d_cam[..., j], means3d[..., l])
v_viewmat[..., j, l] = torch.dot(
v_mean3d_cam[..., j], means3d[..., l]
)
else:
v_viewmat = None

Expand All @@ -243,8 +241,6 @@ def backward(
v_quat,
# viewmat: Float[Tensor, "4 4"],
v_viewmat,
# projmat: Float[Tensor, "4 4"],
None,
# fx: float,
None,
# fy: float,
Expand Down
1 change: 0 additions & 1 deletion tests/test_project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def test_project_gaussians_forward():
glob_scale,
quats,
viewmat,
None, # deprecated projmat/fullmat
fx,
fy,
cx,
Expand Down

0 comments on commit 2d69aaf

Please sign in to comment.