From d14917d42567ae868d43821154b7df5b869fa0c3 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 27 Oct 2023 22:58:01 -0700 Subject: [PATCH] Fixed a-trous kernel and variance guidance --- blade-render/code/blur.wgsl | 57 +++++++++++++++++++++--------- blade-render/code/surface.inc.wgsl | 4 ++- blade-render/src/render/mod.rs | 9 ++--- examples/scene/main.rs | 4 +-- 4 files changed, 50 insertions(+), 24 deletions(-) diff --git a/blade-render/code/blur.wgsl b/blade-render/code/blur.wgsl index 053fba62..dd981550 100644 --- a/blade-render/code/blur.wgsl +++ b/blade-render/code/blur.wgsl @@ -5,9 +5,12 @@ // Spatio-temporal variance-guided filtering // https://research.nvidia.com/sites/default/files/pubs/2017-07_Spatiotemporal-Variance-Guided-Filtering%3A//svgf_preprint.pdf +// Note: using "ilm" in place of "illumination and the 2nd moment of its luminanc" + struct Params { extent: vec2, temporal_weight: f32, + iteration: u32, } var camera: CameraParams; @@ -21,6 +24,8 @@ var input: texture_2d; var prev_input: texture_2d; var output: texture_storage_2d; +const LUMA: vec3 = vec3(0.2126, 0.7152, 0.0722); + fn get_projected_pixel_quad(cp: CameraParams, point: vec3) -> array, 4> { let pixel = get_projected_pixel_float(cp, point); return array, 4>( @@ -51,13 +56,14 @@ fn temporal_accum(@builtin(global_invocation_id) global_id: vec3) { return; } //TODO: use motion vectors - let cur_radiance = textureLoad(input, pixel, 0).xyz; + let cur_illumination = textureLoad(input, pixel, 0).xyz; let surface = read_surface(pixel); let pos_world = camera.position + surface.depth * get_ray_direction(camera, pixel); // considering all samples in 2x2 quad, to help with edges var prev_pixels = get_projected_pixel_quad(prev_camera, pos_world); var best_index = 0; var best_weight = 0.0; + //TODO: optimize depth load with a gather operation for (var i = 0; i < 4; i += 1) { let prev_pixel = prev_pixels[i]; if (all(prev_pixel >= vec2(0)) && all(prev_pixel < params.extent)) { @@ -72,45 +78,62 @@ fn temporal_accum(@builtin(global_invocation_id) global_id: vec3) { } } - var prev_radiance = cur_radiance; + let luminocity = dot(cur_illumination, LUMA); + var mixed_ilm = vec4(cur_illumination, luminocity * luminocity); if (best_weight > 0.01) { - prev_radiance = textureLoad(prev_input, prev_pixels[best_index], 0).xyz; + let prev_ilm = textureLoad(prev_input, prev_pixels[best_index], 0); + mixed_ilm = mix(mixed_ilm, prev_ilm, best_weight * (1.0 - params.temporal_weight)); } - let radiance = mix(cur_radiance, prev_radiance, best_weight * (1.0 - params.temporal_weight)); - textureStore(output, global_id.xy, vec4(radiance, 0.0)); + textureStore(output, global_id.xy, mixed_ilm); +} + +const GAUSSIAN_WEIGHTS = vec2(0.44198, 0.27901); +const SIGMA_L: f32 = 4.0; +const EPSILON: f32 = 0.001; + +fn compare_luminance(a_lum: f32, b_lum: f32, variance: f32) -> f32 { + return exp(-abs(a_lum - b_lum) / (SIGMA_L * variance + EPSILON)); } -const gaussian_weights = vec2(0.44198, 0.27901); +fn w4(w: f32) -> vec4 { + return vec4(vec3(w), w * w); +} @compute @workgroup_size(8, 8) -fn atrous(@builtin(global_invocation_id) global_id: vec3) { +fn atrous3x3(@builtin(global_invocation_id) global_id: vec3) { let center = vec2(global_id.xy); if (any(center >= params.extent)) { return; } - let center_radiance = textureLoad(input, center, 0).xyz; + let center_ilm = textureLoad(input, center, 0); + let center_luma = dot(center_ilm.xyz, LUMA); + let variance = sqrt(center_ilm.w); let center_suf = read_surface(center); - var sum_weight = gaussian_weights[0] * gaussian_weights[0]; - var sum_radiance = center_radiance * sum_weight; + var sum_weight = GAUSSIAN_WEIGHTS[0] * GAUSSIAN_WEIGHTS[0]; + var sum_ilm = w4(sum_weight) * center_ilm; for (var yy=-1; yy<=1; yy+=1) { for (var xx=-1; xx<=1; xx+=1) { - let p = center + vec2(xx, yy); + let p = center + vec2(xx, yy) * (1 << params.iteration); if (all(p == center) || any(p < vec2(0)) || any(p >= params.extent)) { continue; } //TODO: store in group-shared memory let surface = read_surface(p); - var weight = gaussian_weights[abs(xx)] * gaussian_weights[abs(yy)]; - //weight *= compare_surfaces(center_suf, surface); - let radiance = textureLoad(input, p, 0).xyz; - sum_radiance += weight * radiance; + var weight = GAUSSIAN_WEIGHTS[abs(xx)] * GAUSSIAN_WEIGHTS[abs(yy)]; + //TODO: make it stricter on higher iterations + weight *= compare_flat_normals(surface.flat_normal, center_suf.flat_normal); + //Note: should we use a projected depth instead of the surface one? + weight *= compare_depths(surface.depth, center_suf.depth); + let other_ilm = textureLoad(input, p, 0); + weight *= compare_luminance(center_luma, dot(other_ilm.xyz, LUMA), variance); + sum_ilm += w4(weight) * other_ilm; sum_weight += weight; } } - let radiance = sum_radiance / sum_weight; - textureStore(output, global_id.xy, vec4(radiance, 0.0)); + let filtered_ilm = sum_ilm / w4(sum_weight); + textureStore(output, global_id.xy, filtered_ilm); } diff --git a/blade-render/code/surface.inc.wgsl b/blade-render/code/surface.inc.wgsl index d6c70fcc..c8327777 100644 --- a/blade-render/code/surface.inc.wgsl +++ b/blade-render/code/surface.inc.wgsl @@ -4,8 +4,10 @@ struct Surface { depth: f32, } +const SIGMA_N: f32 = 4.0; + fn compare_flat_normals(a: vec3, b: vec3) -> f32 { - return smoothstep(0.4, 0.9, dot(a, b)); + return pow(max(0.0, dot(a, b)), SIGMA_N); } fn compare_depths(a: f32, b: f32) -> f32 { diff --git a/blade-render/src/render/mod.rs b/blade-render/src/render/mod.rs index 1a99f59d..11059cc4 100644 --- a/blade-render/src/render/mod.rs +++ b/blade-render/src/render/mod.rs @@ -451,7 +451,7 @@ struct MainData { struct BlurParams { extent: [u32; 2], temporal_weight: f32, - pad: f32, + iteration: i32, } #[derive(blade_macros::ShaderData)] @@ -624,7 +624,7 @@ impl ShaderPipelines { gpu.create_compute_pipeline(blade_graphics::ComputePipelineDesc { name: "atrous", data_layouts: &[&layout], - compute: shader.at("atrous"), + compute: shader.at("atrous3x3"), }) } @@ -1342,10 +1342,10 @@ impl Renderer { command_encoder: &mut blade_graphics::CommandEncoder, denoiser_config: DenoiserConfig, ) { - let params = BlurParams { + let mut params = BlurParams { extent: [self.screen_size.width, self.screen_size.height], temporal_weight: denoiser_config.temporal_weight, - pad: 0.0, + iteration: 0, }; if denoiser_config.temporal_weight < 1.0 { let cur = self.frame_data.first().unwrap(); @@ -1400,6 +1400,7 @@ impl Renderer { pc.dispatch(groups); self.post_proc_input = targets[0]; targets.swap(0, 1); // rotate the views + params.iteration += 1; } } } diff --git a/examples/scene/main.rs b/examples/scene/main.rs index ccd0d4d6..55e11717 100644 --- a/examples/scene/main.rs +++ b/examples/scene/main.rs @@ -234,7 +234,7 @@ impl Example { }, denoiser_enabled: true, denoiser_config: blade_render::DenoiserConfig { - num_passes: 5, + num_passes: 2, temporal_weight: 0.1, }, debug_blit: None, @@ -613,7 +613,7 @@ impl Example { egui::Slider::new(&mut dc.temporal_weight, 0.0..=1.0f32) .text("Temporal weight"), ); - ui.add(egui::Slider::new(&mut dc.num_passes, 0..=15u32).text("A-trous passes")); + ui.add(egui::Slider::new(&mut dc.num_passes, 0..=5u32).text("A-trous passes")); }); egui::CollapsingHeader::new("Tone Map").show(ui, |ui| {