diff --git a/yt_idv/shader_objects.py b/yt_idv/shader_objects.py index 654707b..384cc83 100644 --- a/yt_idv/shader_objects.py +++ b/yt_idv/shader_objects.py @@ -243,6 +243,7 @@ class Shader(traitlets.HasTraits): """ _shader = None + allow_null = traitlets.Bool(True) source = traitlets.Any() shader_name = traitlets.CUnicode() info = traitlets.CUnicode() @@ -252,6 +253,7 @@ class Shader(traitlets.HasTraits): ) blend_equation = GLValue("func add") depth_test = GLValue("always") + _filename = traitlets.Unicode() use_separate_blend = traitlets.Bool(False) blend_equation_separate = traitlets.Tuple( @@ -295,6 +297,7 @@ def _get_source(self, source): fn = os.path.join(sh_directory, fn) if not os.path.isfile(fn): raise YTInvalidShaderType(fn) + self._filename = fn full_source.append(open(fn).read()) return "\n\n".join(full_source) @@ -320,7 +323,11 @@ def compile(self, source=None): GL.glCompileShader(shader) result = GL.glGetShaderiv(shader, GL.GL_COMPILE_STATUS) if not (result): - raise RuntimeError(GL.glGetShaderInfoLog(shader)) + msg = f"shader complilation error for {self.shader_name} {self.shader_type}" + if self._filename is not None: + msg += f" in {self._filename}." + msg += f"\n GL.glGetShaderInfoLog: \n {GL.glGetShaderInfoLog(shader)}" + raise RuntimeError(msg) self._shader = shader def setup_blend(self): @@ -340,10 +347,13 @@ def shader(self): try: self.compile() except RuntimeError as exc: - print(exc) - for line_num, line in enumerate(self.shader_source.split("\n")): - print(f"{line_num + 1:05}: {line}") - self._enable_null_shader() + if self.allow_null: + print(exc) + for line_num, line in enumerate(self.shader_source.split("\n")): + print(f"{line_num + 1:05}: {line}") + self._enable_null_shader() + else: + raise exc return self._shader def delete_shader(self): @@ -356,6 +366,15 @@ def __del__(self): self.delete_shader() +def _validate_shader(shader_type, value, allow_null=True): + shader_info = known_shaders[shader_type][value] + shader_info.setdefault("shader_type", shader_type) + shader_info["use_separate_blend"] = bool("blend_func_separate" in shader_info) + shader_info.setdefault("shader_name", value) + shader = Shader(allow_null=allow_null, **shader_info) + return shader + + class ShaderTrait(traitlets.TraitType): default_value = None info_text = "A shader (vertex, fragment or geometry)" @@ -364,14 +383,7 @@ def validate(self, obj, value): if isinstance(value, str): try: shader_type = self.metadata.get("shader_type", "vertex") - shader_info = known_shaders[shader_type][value] - shader_info.setdefault("shader_type", shader_type) - shader_info["use_separate_blend"] = bool( - "blend_func_separate" in shader_info - ) - shader_info.setdefault("shader_name", value) - shader = Shader(**shader_info) - return shader + return _validate_shader(shader_type, value) except KeyError: self.error(obj, value) elif isinstance(value, Shader): diff --git a/yt_idv/shaders/bad_shader.vert.glsl b/yt_idv/shaders/bad_shader.vert.glsl new file mode 100644 index 0000000..fab14de --- /dev/null +++ b/yt_idv/shaders/bad_shader.vert.glsl @@ -0,0 +1,18 @@ +in vec4 model_vertex; // The location of the vertex in model space +in vec3 in_dx; +in vec3 in_left_edge; +in vec3 in_right_edge; +out vec4 v_model; +flat out vec3 v_camera_pos; +flat out mat4 inverse_proj; +flat out mat4 inverse_mvm; +flat out mat4 inverse_pmvm; +flat out vec3 dx; +flat out vec3 left_edge; +flat out vec3 right_edge; + +void main() +{ + this_is a very bad shader. do_not use it. it is for_testing that the + pytest tests_do indeed catch shader compilation errors. +} diff --git a/yt_idv/tests/test_yt_idv.py b/yt_idv/tests/test_yt_idv.py index 9ebce3f..8a47cf1 100644 --- a/yt_idv/tests/test_yt_idv.py +++ b/yt_idv/tests/test_yt_idv.py @@ -9,6 +9,7 @@ from pytest_html import extras import yt_idv +from yt_idv import shader_objects from yt_idv.scene_components.curves import CurveCollectionRendering, CurveRendering from yt_idv.scene_data.curve import CurveCollection, CurveData @@ -151,10 +152,67 @@ def test_curves(osmesa_fake_amr, image_store): curve_collection.add_data() # call add_data() after done adding curves cc_render = CurveCollectionRendering( - data=curve_collection, curve_rgb=(0.2, 0.2, 0.2, 1.0), line_width=4 + data=curve_collection, curve_rgba=(0.2, 0.2, 0.2, 1.0), line_width=4 ) cc_render.display_name = "multiple streamlines" osmesa_fake_amr.scene.data_objects.append(curve_collection) osmesa_fake_amr.scene.components.append(cc_render) image_store(osmesa_fake_amr) + + +@pytest.fixture() +def set_very_bad_shader(): + known_shaders = shader_objects.known_shaders + good_shader = known_shaders["vertex"]["default"]["source"] + known_shaders["vertex"]["default"]["source"] = "bad_shader.vert.glsl" + yield known_shaders + known_shaders["vertex"]["default"]["source"] = good_shader + + +def test_bad_shader(osmesa_empty, set_very_bad_shader): + shader_name = "box_outline" + program = shader_objects.component_shaders[shader_name]["default"] + + vertex_shader = shader_objects._validate_shader( + "vertex", program["first_vertex"], allow_null=False + ) + fragment_shader = shader_objects._validate_shader( + "fragment", program["first_fragment"], allow_null=False + ) + with pytest.raises(RuntimeError, match="shader complilation error"): + _ = shader_objects.ShaderProgram(vertex_shader, fragment_shader, None) + + +@pytest.mark.parametrize("shader_name", list(shader_objects.component_shaders.keys())) +def test_shader_programs(osmesa_empty, shader_name): + for program in shader_objects.component_shaders[shader_name].values(): + + vertex_shader = shader_objects._validate_shader( + "vertex", program["first_vertex"], allow_null=False + ) + assert isinstance(vertex_shader, shader_objects.Shader) + fragment_shader = shader_objects._validate_shader( + "fragment", program["first_fragment"], allow_null=False + ) + assert isinstance(fragment_shader, shader_objects.Shader) + geometry_shader = program.get("first_geometry", None) + if geometry_shader is not None: + geometry_shader = shader_objects._validate_shader( + "geometry", geometry_shader, allow_null=False + ) + assert isinstance(geometry_shader, shader_objects.Shader) + + _ = shader_objects.ShaderProgram( + vertex_shader, fragment_shader, geometry_shader + ) + + colormap_vertex = shader_objects._validate_shader( + "vertex", program["second_vertex"], allow_null=False + ) + assert isinstance(colormap_vertex, shader_objects.Shader) + colormap_fragment = shader_objects._validate_shader( + "fragment", program["second_fragment"], allow_null=False + ) + assert isinstance(colormap_fragment, shader_objects.Shader) + _ = shader_objects.ShaderProgram(colormap_vertex, colormap_fragment)