-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdepth.py
61 lines (45 loc) · 1.58 KB
/
depth.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from __future__ import annotations # tolerate "'type' object is not subscriptable" on Python < 3.9
from functools import partial
from typing import NamedTuple
import jax
from jaxtyping import Array, Float
from jaxtyping import jaxtyped # pyright: ignore[reportUnknownVariableType]
from .._backport import Tuple
from .._meta_utils import add_tracing_name
from .._meta_utils import typed_jit as jit
from ..geometry import Camera, to_homogeneous
from ..shader import ID, PerVertex, Shader
from ..types import Vec4f
# Set the JAX configuration flag "jax_array" to True as a side effect of
# importing this module.
# NOTE(review): jax.Array is reportedly the default array type in JAX >= 0.4.1;
# confirm this flag is still needed/accepted by the supported JAX versions.
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
class DepthExtraInput(NamedTuple):
    """Extra input for Depth Shader.

    This is the `extra` argument of `DepthShader.vertex`, which indexes
    `position` by `gl_VertexID`.

    Attributes:
      - position: in world space, of each vertex; annotated as a float array
        of shape (vertices, 3).
    """

    position: Float[Array, "vertices 3"]  # in world space
class DepthExtraFragmentData(NamedTuple):
    """Field-less extra fragment data for the Depth Shader.

    `DepthShader.vertex` returns an instance of this type with no fields
    alongside its `PerVertex` output.
    """
class DepthExtraMixerOutput(NamedTuple):
    """Field-less extra mixer output for the Depth Shader.

    Used as the third type parameter of `Shader` in `DepthShader`; it carries
    no fields.
    """
class DepthShader(
    Shader[DepthExtraInput, DepthExtraFragmentData, DepthExtraMixerOutput]
):
    """Depth Shading.

    Parameterises `Shader` with `DepthExtraInput`, `DepthExtraFragmentData`
    and `DepthExtraMixerOutput`. Only the `vertex` stage is defined in this
    class.
    """

    @staticmethod
    @jaxtyped
    @partial(jit, inline=True)
    @add_tracing_name
    def vertex(
        gl_VertexID: ID,
        gl_InstanceID: ID,
        camera: Camera,
        extra: DepthExtraInput,
    ) -> Tuple[PerVertex, DepthExtraFragmentData]:
        """Compute `gl_Position` for a single vertex.

        Parameters:
          - gl_VertexID: index of the vertex; selects its entry in
            `extra.position`.
          - gl_InstanceID: not used by this shader.
          - camera: its `to_clip` is applied to the homogeneous position to
            obtain `gl_Position`.
          - extra: extra input of the shader; only `position` is read.

        Returns:
          A tuple of `PerVertex` (holding `gl_Position`) and a field-less
          `DepthExtraFragmentData`.
        """
        # Use gl_VertexID to index in `extra` buffer.
        position: Vec4f = to_homogeneous(extra.position[gl_VertexID])
        gl_Position: Vec4f = camera.to_clip(position)
        # Check that the value returned by `to_clip` is a `Vec4f`.
        assert isinstance(gl_Position, Vec4f)

        return (
            PerVertex(gl_Position=gl_Position),
            DepthExtraFragmentData(),
        )