-
Notifications
You must be signed in to change notification settings - Fork 12
/
normal_shading.py
83 lines (72 loc) · 3.02 KB
/
normal_shading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Edited from https://pytorch3d.readthedocs.io/en/v0.6.0/_modules/pytorch3d/renderer/mesh/shader.html#HardPhongShader
import warnings
from typing import Optional
import torch
import torch.nn as nn
from pytorch3d.structures.meshes import Meshes
from pytorch3d.renderer.blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.utils import TensorProperties
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.shading import flat_shading, gouraud_shading
from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes
class HardPhongNormalShader(nn.Module):
    """
    Shader that returns per-pixel vertex-normal data instead of shaded colors.

    Adapted from PyTorch3D's ``HardPhongShader``. Unlike that shader,
    ``forward`` applies no lighting model and no blending: it returns the
    normal attribute gathered for every face listed in
    ``fragments.pix_to_face``. The ``lights``, ``materials`` and
    ``blend_params`` arguments are stored (and moved by ``to``) for interface
    compatibility, but ``forward`` does not use them.

    To use the default values, simply initialize the shader with the desired
    device e.g.

    .. code-block::

        shader = HardPhongNormalShader(device=torch.device("cuda:0"))
    """

    def __init__(
        self,
        device="cpu",
        cameras: Optional[TensorProperties] = None,
        lights: Optional[TensorProperties] = None,
        materials: Optional[Materials] = None,
        blend_params: Optional[BlendParams] = None,
    ) -> None:
        """
        Args:
            device: Device used to build the default ``PointLights`` and
                ``Materials`` when ``lights`` / ``materials`` are not given.
            cameras: Cameras; may instead be supplied to ``forward`` via a
                ``cameras=`` keyword argument.
            lights: Lights; defaults to ``PointLights(device=device)``.
            materials: Materials; defaults to ``Materials(device=device)``.
            blend_params: Blend parameters; defaults to ``BlendParams()``.
        """
        super().__init__()
        self.lights = lights if lights is not None else PointLights(device=device)
        self.materials = (
            materials if materials is not None else Materials(device=device)
        )
        self.cameras = cameras
        self.blend_params = blend_params if blend_params is not None else BlendParams()

    def to(self, device):
        """Move cameras (if set), materials and lights to ``device``; return ``self``."""
        # These objects are not nn.Module subclasses, so they must be moved
        # manually rather than through nn.Module's own device handling.
        cameras = self.cameras
        if cameras is not None:
            self.cameras = cameras.to(device)
        self.materials = self.materials.to(device)
        self.lights = self.lights.to(device)
        return self

    def phong_normal_shading(self, meshes, fragments) -> torch.Tensor:
        """
        Gather vertex normals for every face hit by each pixel.

        Args:
            meshes: Meshes whose packed faces and vertex normals are indexed
                by ``fragments.pix_to_face``.
            fragments: Rasterizer output providing ``pix_to_face`` and
                ``bary_coords``.

        Returns:
            The tensor produced by ``interpolate_face_attributes``; presumably
            shaped (N, H, W, K, 3) per the PyTorch3D convention -- confirm.
        """
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]  # (F, 3, 3): normals of each face's vertices
        # NOTE(review): the upstream Phong shader weights by fragments.bary_coords.
        # Here every barycentric weight is 1, so each pixel presumably receives
        # the unnormalized sum of its face's three vertex normals (constant
        # across the face) rather than a smoothly interpolated normal. Kept
        # as-is because it looks deliberate -- confirm this is intended.
        ones = torch.ones_like(fragments.bary_coords)
        return interpolate_face_attributes(fragments.pix_to_face, ones, faces_normals)

    def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
        """
        Return per-pixel normals for ``meshes`` rasterized into ``fragments``.

        Args:
            fragments: Rasterizer output for ``meshes``.
            meshes: The rasterized meshes.
            **kwargs: Only ``cameras`` is read; when present it overrides
                ``self.cameras``.

        Raises:
            ValueError: If no cameras were given at initialization or in
                ``kwargs``.
        """
        cameras = kwargs.get("cameras", self.cameras)
        # Cameras are not used by the normal computation below; this check
        # appears to be inherited from HardPhongShader's contract.
        if cameras is None:
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of HardPhongNormalShader"
            )
        return self.phong_normal_shading(meshes=meshes, fragments=fragments)