Skip to content

Commit

Permalink
add tests for shader program compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishavlin committed Jun 28, 2024
1 parent 53c5ec4 commit 470f096
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 14 deletions.
38 changes: 25 additions & 13 deletions yt_idv/shader_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ class Shader(traitlets.HasTraits):
"""

_shader = None
allow_null = traitlets.Bool(True)
source = traitlets.Any()
shader_name = traitlets.CUnicode()
info = traitlets.CUnicode()
Expand All @@ -252,6 +253,7 @@ class Shader(traitlets.HasTraits):
)
blend_equation = GLValue("func add")
depth_test = GLValue("always")
_filename = traitlets.Unicode()

use_separate_blend = traitlets.Bool(False)
blend_equation_separate = traitlets.Tuple(
Expand Down Expand Up @@ -295,6 +297,7 @@ def _get_source(self, source):
fn = os.path.join(sh_directory, fn)
if not os.path.isfile(fn):
raise YTInvalidShaderType(fn)
self._filename = fn
full_source.append(open(fn).read())
return "\n\n".join(full_source)

Expand All @@ -320,7 +323,11 @@ def compile(self, source=None):
GL.glCompileShader(shader)
result = GL.glGetShaderiv(shader, GL.GL_COMPILE_STATUS)
if not (result):
raise RuntimeError(GL.glGetShaderInfoLog(shader))
msg = f"shader complilation error for {self.shader_name} {self.shader_type}"
if self._filename is not None:
msg += f" in {self._filename}."
msg += f"\n GL.glGetShaderInfoLog: \n {GL.glGetShaderInfoLog(shader)}"
raise RuntimeError(msg)
self._shader = shader

def setup_blend(self):
Expand All @@ -340,10 +347,13 @@ def shader(self):
try:
self.compile()
except RuntimeError as exc:
print(exc)
for line_num, line in enumerate(self.shader_source.split("\n")):
print(f"{line_num + 1:05}: {line}")
self._enable_null_shader()
if self.allow_null:
print(exc)
for line_num, line in enumerate(self.shader_source.split("\n")):
print(f"{line_num + 1:05}: {line}")
self._enable_null_shader()
else:
raise exc
return self._shader

def delete_shader(self):
Expand All @@ -356,6 +366,15 @@ def __del__(self):
self.delete_shader()


def _validate_shader(shader_type, value, allow_null=True):
shader_info = known_shaders[shader_type][value]
shader_info.setdefault("shader_type", shader_type)
shader_info["use_separate_blend"] = bool("blend_func_separate" in shader_info)
shader_info.setdefault("shader_name", value)
shader = Shader(allow_null=allow_null, **shader_info)
return shader


class ShaderTrait(traitlets.TraitType):
default_value = None
info_text = "A shader (vertex, fragment or geometry)"
Expand All @@ -364,14 +383,7 @@ def validate(self, obj, value):
if isinstance(value, str):
try:
shader_type = self.metadata.get("shader_type", "vertex")
shader_info = known_shaders[shader_type][value]
shader_info.setdefault("shader_type", shader_type)
shader_info["use_separate_blend"] = bool(
"blend_func_separate" in shader_info
)
shader_info.setdefault("shader_name", value)
shader = Shader(**shader_info)
return shader
return _validate_shader(shader_type, value)
except KeyError:
self.error(obj, value)
elif isinstance(value, Shader):
Expand Down
18 changes: 18 additions & 0 deletions yt_idv/shaders/bad_shader.vert.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
in vec4 model_vertex; // The location of the vertex in model space
in vec3 in_dx;
in vec3 in_left_edge;
in vec3 in_right_edge;
out vec4 v_model;
flat out vec3 v_camera_pos;
flat out mat4 inverse_proj;
flat out mat4 inverse_mvm;
flat out mat4 inverse_pmvm;
flat out vec3 dx;
flat out vec3 left_edge;
flat out vec3 right_edge;

void main()
{
this_is a very bad shader. do_not use it. it is for_testing that the
pytest tests_do indeed catch shader compilation errors.
}
60 changes: 59 additions & 1 deletion yt_idv/tests/test_yt_idv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytest_html import extras

import yt_idv
from yt_idv import shader_objects
from yt_idv.scene_components.curves import CurveCollectionRendering, CurveRendering
from yt_idv.scene_data.curve import CurveCollection, CurveData

Expand Down Expand Up @@ -151,10 +152,67 @@ def test_curves(osmesa_fake_amr, image_store):
curve_collection.add_data() # call add_data() after done adding curves

cc_render = CurveCollectionRendering(
data=curve_collection, curve_rgb=(0.2, 0.2, 0.2, 1.0), line_width=4
data=curve_collection, curve_rgba=(0.2, 0.2, 0.2, 1.0), line_width=4
)
cc_render.display_name = "multiple streamlines"
osmesa_fake_amr.scene.data_objects.append(curve_collection)
osmesa_fake_amr.scene.components.append(cc_render)

image_store(osmesa_fake_amr)


@pytest.fixture()
def set_very_bad_shader():
known_shaders = shader_objects.known_shaders
good_shader = known_shaders["vertex"]["default"]["source"]
known_shaders["vertex"]["default"]["source"] = "bad_shader.vert.glsl"
yield known_shaders
known_shaders["vertex"]["default"]["source"] = good_shader


def test_bad_shader(osmesa_empty, set_very_bad_shader):
shader_name = "box_outline"
program = shader_objects.component_shaders[shader_name]["default"]

vertex_shader = shader_objects._validate_shader(
"vertex", program["first_vertex"], allow_null=False
)
fragment_shader = shader_objects._validate_shader(
"fragment", program["first_fragment"], allow_null=False
)
with pytest.raises(RuntimeError, match="shader complilation error"):
_ = shader_objects.ShaderProgram(vertex_shader, fragment_shader, None)


@pytest.mark.parametrize("shader_name", list(shader_objects.component_shaders.keys()))
def test_shader_programs(osmesa_empty, shader_name):
for program in shader_objects.component_shaders[shader_name].values():

vertex_shader = shader_objects._validate_shader(
"vertex", program["first_vertex"], allow_null=False
)
assert isinstance(vertex_shader, shader_objects.Shader)
fragment_shader = shader_objects._validate_shader(
"fragment", program["first_fragment"], allow_null=False
)
assert isinstance(fragment_shader, shader_objects.Shader)
geometry_shader = program.get("first_geometry", None)
if geometry_shader is not None:
geometry_shader = shader_objects._validate_shader(
"geometry", geometry_shader, allow_null=False
)
assert isinstance(geometry_shader, shader_objects.Shader)

_ = shader_objects.ShaderProgram(
vertex_shader, fragment_shader, geometry_shader
)

colormap_vertex = shader_objects._validate_shader(
"vertex", program["second_vertex"], allow_null=False
)
assert isinstance(colormap_vertex, shader_objects.Shader)
colormap_fragment = shader_objects._validate_shader(
"fragment", program["second_fragment"], allow_null=False
)
assert isinstance(colormap_fragment, shader_objects.Shader)
_ = shader_objects.ShaderProgram(colormap_vertex, colormap_fragment)

0 comments on commit 470f096

Please sign in to comment.