Skip to content

Commit

Permalink
Merge pull request yt-project#133 from chrishavlin/fail_pytest_on_sha…
Browse files Browse the repository at this point in the history
…der_errors

Add tests for shader program compilation
  • Loading branch information
chrishavlin authored Jul 1, 2024
2 parents 53c5ec4 + 60ef4e9 commit ec345f6
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 14 deletions.
43 changes: 30 additions & 13 deletions yt_idv/shader_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,16 @@ class Shader(traitlets.HasTraits):
This can either be a string containing a full source of a shader,
an absolute path to a source file or a filename of a shader
residing in the ./shaders/ directory.
allow_null : bool
If True (default) then shader compilation errors will be caught and
printed without raising exception (convenient for general use and for
developing new shaders). If False, any compilation errors will raise
a RunTimeError (useful for CI testing).
"""

_shader = None
allow_null = traitlets.Bool(True)
source = traitlets.Any()
shader_name = traitlets.CUnicode()
info = traitlets.CUnicode()
Expand All @@ -252,6 +258,7 @@ class Shader(traitlets.HasTraits):
)
blend_equation = GLValue("func add")
depth_test = GLValue("always")
_filename = traitlets.Unicode()

use_separate_blend = traitlets.Bool(False)
blend_equation_separate = traitlets.Tuple(
Expand Down Expand Up @@ -295,6 +302,7 @@ def _get_source(self, source):
fn = os.path.join(sh_directory, fn)
if not os.path.isfile(fn):
raise YTInvalidShaderType(fn)
self._filename = fn
full_source.append(open(fn).read())
return "\n\n".join(full_source)

Expand All @@ -320,7 +328,11 @@ def compile(self, source=None):
GL.glCompileShader(shader)
result = GL.glGetShaderiv(shader, GL.GL_COMPILE_STATUS)
if not (result):
raise RuntimeError(GL.glGetShaderInfoLog(shader))
msg = f"shader complilation error for {self.shader_name} {self.shader_type}"
if self._filename is not None:
msg += f" in {self._filename}."
msg += f"\n GL.glGetShaderInfoLog: \n {GL.glGetShaderInfoLog(shader)}"
raise RuntimeError(msg)
self._shader = shader

def setup_blend(self):
Expand All @@ -340,10 +352,13 @@ def shader(self):
try:
self.compile()
except RuntimeError as exc:
print(exc)
for line_num, line in enumerate(self.shader_source.split("\n")):
print(f"{line_num + 1:05}: {line}")
self._enable_null_shader()
if self.allow_null:
print(exc)
for line_num, line in enumerate(self.shader_source.split("\n")):
print(f"{line_num + 1:05}: {line}")
self._enable_null_shader()
else:
raise exc
return self._shader

def delete_shader(self):
Expand All @@ -356,6 +371,15 @@ def __del__(self):
self.delete_shader()


def _validate_shader(shader_type, value, allow_null=True):
shader_info = known_shaders[shader_type][value]
shader_info.setdefault("shader_type", shader_type)
shader_info["use_separate_blend"] = bool("blend_func_separate" in shader_info)
shader_info.setdefault("shader_name", value)
shader = Shader(allow_null=allow_null, **shader_info)
return shader


class ShaderTrait(traitlets.TraitType):
default_value = None
info_text = "A shader (vertex, fragment or geometry)"
Expand All @@ -364,14 +388,7 @@ def validate(self, obj, value):
if isinstance(value, str):
try:
shader_type = self.metadata.get("shader_type", "vertex")
shader_info = known_shaders[shader_type][value]
shader_info.setdefault("shader_type", shader_type)
shader_info["use_separate_blend"] = bool(
"blend_func_separate" in shader_info
)
shader_info.setdefault("shader_name", value)
shader = Shader(**shader_info)
return shader
return _validate_shader(shader_type, value)
except KeyError:
self.error(obj, value)
elif isinstance(value, Shader):
Expand Down
18 changes: 18 additions & 0 deletions yt_idv/shaders/bad_shader.vert.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
in vec4 model_vertex; // The location of the vertex in model space
in vec3 in_dx;
in vec3 in_left_edge;
in vec3 in_right_edge;
out vec4 v_model;
flat out vec3 v_camera_pos;
flat out mat4 inverse_proj;
flat out mat4 inverse_mvm;
flat out mat4 inverse_pmvm;
flat out vec3 dx;
flat out vec3 left_edge;
flat out vec3 right_edge;

void main()
{
this_is a very bad shader. do_not use it. it is for_testing that the
pytest tests_do indeed catch shader compilation errors.
}
64 changes: 63 additions & 1 deletion yt_idv/tests/test_yt_idv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytest_html import extras

import yt_idv
from yt_idv import shader_objects
from yt_idv.scene_components.curves import CurveCollectionRendering, CurveRendering
from yt_idv.scene_data.curve import CurveCollection, CurveData

Expand Down Expand Up @@ -151,10 +152,71 @@ def test_curves(osmesa_fake_amr, image_store):
curve_collection.add_data() # call add_data() after done adding curves

cc_render = CurveCollectionRendering(
data=curve_collection, curve_rgb=(0.2, 0.2, 0.2, 1.0), line_width=4
data=curve_collection, curve_rgba=(0.2, 0.2, 0.2, 1.0), line_width=4
)
cc_render.display_name = "multiple streamlines"
osmesa_fake_amr.scene.data_objects.append(curve_collection)
osmesa_fake_amr.scene.components.append(cc_render)

image_store(osmesa_fake_amr)


@pytest.fixture()
def set_very_bad_shader():
# this temporarily points the default vertex shader source file to a
# bad shader that will raise compilation errors.
known_shaders = shader_objects.known_shaders
good_shader = known_shaders["vertex"]["default"]["source"]
known_shaders["vertex"]["default"]["source"] = "bad_shader.vert.glsl"
yield known_shaders
known_shaders["vertex"]["default"]["source"] = good_shader


def test_bad_shader(osmesa_empty, set_very_bad_shader):
# this test is meant to check that a bad shader would indeed be caught
# by the subsequent test_shader_programs test.
shader_name = "box_outline"
program = shader_objects.component_shaders[shader_name]["default"]

vertex_shader = shader_objects._validate_shader(
"vertex", program["first_vertex"], allow_null=False
)
fragment_shader = shader_objects._validate_shader(
"fragment", program["first_fragment"], allow_null=False
)
with pytest.raises(RuntimeError, match="shader complilation error"):
_ = shader_objects.ShaderProgram(vertex_shader, fragment_shader, None)


@pytest.mark.parametrize("shader_name", list(shader_objects.component_shaders.keys()))
def test_shader_programs(osmesa_empty, shader_name):
for program in shader_objects.component_shaders[shader_name].values():

vertex_shader = shader_objects._validate_shader(
"vertex", program["first_vertex"], allow_null=False
)
assert isinstance(vertex_shader, shader_objects.Shader)
fragment_shader = shader_objects._validate_shader(
"fragment", program["first_fragment"], allow_null=False
)
assert isinstance(fragment_shader, shader_objects.Shader)
geometry_shader = program.get("first_geometry", None)
if geometry_shader is not None:
geometry_shader = shader_objects._validate_shader(
"geometry", geometry_shader, allow_null=False
)
assert isinstance(geometry_shader, shader_objects.Shader)

_ = shader_objects.ShaderProgram(
vertex_shader, fragment_shader, geometry_shader
)

colormap_vertex = shader_objects._validate_shader(
"vertex", program["second_vertex"], allow_null=False
)
assert isinstance(colormap_vertex, shader_objects.Shader)
colormap_fragment = shader_objects._validate_shader(
"fragment", program["second_fragment"], allow_null=False
)
assert isinstance(colormap_fragment, shader_objects.Shader)
_ = shader_objects.ShaderProgram(colormap_vertex, colormap_fragment)

0 comments on commit ec345f6

Please sign in to comment.