From 85f12adb4601887506fc10ded28a6b7bc886ebc7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 3 Sep 2024 00:14:27 -0700 Subject: [PATCH] Merge temporal accumulation into the Main pass --- blade-helpers/src/hud.rs | 1 + blade-render/code/a-trous.wgsl | 83 ++++++++++++ blade-render/code/accum.inc.wgsl | 19 +++ blade-render/code/blur.wgsl | 162 ----------------------- blade-render/code/ray-trace.wgsl | 79 +++++++----- blade-render/src/render/mod.rs | 215 +++++++++++-------------------- examples/scene/main.rs | 17 +-- src/lib.rs | 17 +-- 8 files changed, 247 insertions(+), 346 deletions(-) create mode 100644 blade-render/code/a-trous.wgsl create mode 100644 blade-render/code/accum.inc.wgsl delete mode 100644 blade-render/code/blur.wgsl diff --git a/blade-helpers/src/hud.rs b/blade-helpers/src/hud.rs index ff4cf5a5..d8b9b279 100644 --- a/blade-helpers/src/hud.rs +++ b/blade-helpers/src/hud.rs @@ -37,6 +37,7 @@ impl ExposeHud for blade_render::RayConfig { impl ExposeHud for blade_render::DenoiserConfig { fn populate_hud(&mut self, ui: &mut egui::Ui) { + ui.checkbox(&mut self.enabled, "Enable denoiser"); ui.add(egui::Slider::new(&mut self.temporal_weight, 0.0..=1.0f32).text("Temporal weight")); ui.add(egui::Slider::new(&mut self.num_passes, 0..=5u32).text("A-trous passes")); } diff --git a/blade-render/code/a-trous.wgsl b/blade-render/code/a-trous.wgsl new file mode 100644 index 00000000..19277d20 --- /dev/null +++ b/blade-render/code/a-trous.wgsl @@ -0,0 +1,83 @@ +#include "camera.inc.wgsl" +#include "quaternion.inc.wgsl" +#include "surface.inc.wgsl" + +// Spatio-temporal variance-guided filtering +// https://research.nvidia.com/sites/default/files/pubs/2017-07_Spatiotemporal-Variance-Guided-Filtering%3A//svgf_preprint.pdf + +// Note: using "ilm" in place of "illumination and the 2nd moment of its luminance" + +struct Params { + extent: vec2, + iteration: u32, +} + +var camera: CameraParams; +var prev_camera: CameraParams; +var params: Params; +var t_depth: texture_2d; +var t_flat_normal: texture_2d; +var t_motion: texture_2d; +var input: texture_2d; +var output: texture_storage_2d; + +const LUMA: vec3 = vec3(0.2126, 0.7152, 0.0722); +const MIN_WEIGHT: f32 = 0.01; + +fn read_surface(pixel: vec2) -> Surface { + var surface = Surface(); + surface.flat_normal = normalize(textureLoad(t_flat_normal, pixel, 0).xyz); + surface.depth = textureLoad(t_depth, pixel, 0).x; + return surface; +} + +const GAUSSIAN_WEIGHTS = vec2(0.44198, 0.27901); +const SIGMA_L: f32 = 4.0; +const EPSILON: f32 = 0.001; + +fn compare_luminance(a_lum: f32, b_lum: f32, variance: f32) -> f32 { + return exp(-abs(a_lum - b_lum) / (SIGMA_L * variance + EPSILON)); +} + +fn w4(w: f32) -> vec4 { + return vec4(vec3(w), w * w); +} + +@compute @workgroup_size(8, 8) +fn atrous3x3(@builtin(global_invocation_id) global_id: vec3) { + let center = vec2(global_id.xy); + if (any(center >= params.extent)) { + return; + } + + let center_ilm = textureLoad(input, center, 0); + let center_luma = dot(center_ilm.xyz, LUMA); + let variance = sqrt(center_ilm.w); + let center_suf = read_surface(center); + var sum_weight = GAUSSIAN_WEIGHTS[0] * GAUSSIAN_WEIGHTS[0]; + var sum_ilm = w4(sum_weight) * center_ilm; + + for (var yy=-1; yy<=1; yy+=1) { + for (var xx=-1; xx<=1; xx+=1) { + let p = center + vec2(xx, yy) * (1i << params.iteration); + if (all(p == center) || any(p < vec2(0)) || any(p >= params.extent)) { + continue; + } + + //TODO: store in group-shared memory + let surface = read_surface(p); + var weight = GAUSSIAN_WEIGHTS[abs(xx)] * GAUSSIAN_WEIGHTS[abs(yy)]; + //TODO: make it stricter on higher iterations + weight *= compare_flat_normals(surface.flat_normal, center_suf.flat_normal); + //Note: should we use a projected depth instead of the surface one? + weight *= compare_depths(surface.depth, center_suf.depth); + let other_ilm = textureLoad(input, p, 0); + weight *= compare_luminance(center_luma, dot(other_ilm.xyz, LUMA), variance); + sum_ilm += w4(weight) * other_ilm; + sum_weight += weight; + } + } + + let filtered_ilm = select(center_ilm, sum_ilm / w4(sum_weight), sum_weight > MIN_WEIGHT); + textureStore(output, global_id.xy, filtered_ilm); +} diff --git a/blade-render/code/accum.inc.wgsl b/blade-render/code/accum.inc.wgsl new file mode 100644 index 00000000..cc599ee3 --- /dev/null +++ b/blade-render/code/accum.inc.wgsl @@ -0,0 +1,19 @@ +const LUMA: vec3 = vec3(0.2126, 0.7152, 0.0722); + +var inout_diffuse: texture_storage_2d; + +fn accumulate_temporal( + surface: Surface, position: vec3, pixel_coord: vec2, + cur_illumination: vec3, temporal_weight: f32, + prev_surface: Surface, prev_pixel: vec2, prev_valid: bool, +) { + let cur_luminocity = dot(cur_illumination, LUMA); + var ilm = vec4(cur_illumination, cur_luminocity * cur_luminocity); + if (prev_valid && temporal_weight < 1.0) { + let illumination = textureLoad(inout_diffuse, prev_pixel).xyz; + let luminocity = dot(illumination, LUMA); + let prev_ilm = vec4(illumination, luminocity * luminocity); + ilm = mix(prev_ilm, ilm, temporal_weight); + } + textureStore(inout_diffuse, pixel_coord, ilm); +} diff --git a/blade-render/code/blur.wgsl b/blade-render/code/blur.wgsl deleted file mode 100644 index 3aec665b..00000000 --- a/blade-render/code/blur.wgsl +++ /dev/null @@ -1,162 +0,0 @@ -#include "camera.inc.wgsl" -#include "motion.inc.wgsl" -#include "quaternion.inc.wgsl" -#include "surface.inc.wgsl" - -// Spatio-temporal variance-guided filtering -// https://research.nvidia.com/sites/default/files/pubs/2017-07_Spatiotemporal-Variance-Guided-Filtering%3A//svgf_preprint.pdf - -// Note: using "ilm" in place of "illumination and the 2nd moment of its luminance" - -struct Params { - extent: vec2, - temporal_weight: f32, - iteration: u32, - use_motion_vectors: u32, -} - -var camera: CameraParams; -var prev_camera: CameraParams; -var params: Params; -var t_depth: texture_2d; -var t_prev_depth: texture_2d; -var t_flat_normal: texture_2d; -var t_prev_flat_normal: texture_2d; -var t_motion: texture_2d; -var input: texture_2d; -var prev_input: texture_2d; -var output: texture_storage_2d; - -const LUMA: vec3 = vec3(0.2126, 0.7152, 0.0722); -const MIN_WEIGHT: f32 = 0.01; - -fn read_surface(pixel: vec2) -> Surface { - var surface = Surface(); - surface.flat_normal = normalize(textureLoad(t_flat_normal, pixel, 0).xyz); - surface.depth = textureLoad(t_depth, pixel, 0).x; - return surface; -} -fn read_prev_surface(pixel: vec2) -> Surface { - var surface = Surface(); - surface.flat_normal = normalize(textureLoad(t_prev_flat_normal, pixel, 0).xyz); - surface.depth = textureLoad(t_prev_depth, pixel, 0).x; - return surface; -} - -fn get_prev_pixel(pixel: vec2, pos_world: vec3) -> vec2 { - if (USE_MOTION_VECTORS && params.use_motion_vectors != 0u) { - let motion = textureLoad(t_motion, pixel, 0).xy / MOTION_SCALE; - return vec2(pixel) + 0.5 + motion; - } else { - return get_projected_pixel_float(prev_camera, pos_world); - } -} - -@compute @workgroup_size(8, 8) -fn temporal_accum(@builtin(global_invocation_id) global_id: vec3) { - let pixel = vec2(global_id.xy); - if (any(pixel >= params.extent)) { - return; - } - - let surface = read_surface(pixel); - let pos_world = camera.position + surface.depth * get_ray_direction(camera, pixel); - // considering all samples in 2x2 quad, to help with edges - var center_pixel = get_prev_pixel(pixel, pos_world); - var prev_pixels = array, 4>( - vec2(vec2(center_pixel.x - 0.5, center_pixel.y - 0.5)), - vec2(vec2(center_pixel.x + 0.5, center_pixel.y - 0.5)), - vec2(vec2(center_pixel.x + 0.5, center_pixel.y + 0.5)), - vec2(vec2(center_pixel.x - 0.5, center_pixel.y + 0.5)), - ); - //Note: careful about the pixel center when there is a perfect match - let w_bot_right = fract(center_pixel + vec2(0.5)); - var prev_weights = vec4( - (1.0 - w_bot_right.x) * (1.0 - w_bot_right.y), - w_bot_right.x * (1.0 - w_bot_right.y), - w_bot_right.x * w_bot_right.y, - (1.0 - w_bot_right.x) * w_bot_right.y, - ); - - var sum_weight = 0.0; - var sum_ilm = vec4(0.0); - //TODO: optimize depth load with a gather operation - for (var i = 0; i < 4; i += 1) { - let prev_pixel = prev_pixels[i]; - if (all(prev_pixel >= vec2(0)) && all(prev_pixel < params.extent)) { - let prev_surface = read_prev_surface(prev_pixel); - if (compare_flat_normals(surface.flat_normal, prev_surface.flat_normal) < 0.5) { - continue; - } - let projected_distance = length(pos_world - prev_camera.position); - if (compare_depths(prev_surface.depth, projected_distance) < 0.5) { - continue; - } - let w = prev_weights[i]; - sum_weight += w; - let illumination = w * textureLoad(prev_input, prev_pixel, 0).xyz; - let luminocity = dot(illumination, LUMA); - sum_ilm += vec4(illumination, luminocity * luminocity); - } - } - - let cur_illumination = textureLoad(input, pixel, 0).xyz; - let cur_luminocity = dot(cur_illumination, LUMA); - var mixed_ilm = vec4(cur_illumination, cur_luminocity * cur_luminocity); - if (sum_weight > MIN_WEIGHT) { - let prev_ilm = sum_ilm / vec4(vec3(sum_weight), max(0.001, sum_weight*sum_weight)); - mixed_ilm = mix(mixed_ilm, prev_ilm, sum_weight * (1.0 - params.temporal_weight)); - } - textureStore(output, global_id.xy, mixed_ilm); -} - -const GAUSSIAN_WEIGHTS = vec2(0.44198, 0.27901); -const SIGMA_L: f32 = 4.0; -const EPSILON: f32 = 0.001; - -fn compare_luminance(a_lum: f32, b_lum: f32, variance: f32) -> f32 { - return exp(-abs(a_lum - b_lum) / (SIGMA_L * variance + EPSILON)); -} - -fn w4(w: f32) -> vec4 { - return vec4(vec3(w), w * w); -} - -@compute @workgroup_size(8, 8) -fn atrous3x3(@builtin(global_invocation_id) global_id: vec3) { - let center = vec2(global_id.xy); - if (any(center >= params.extent)) { - return; - } - - let center_ilm = textureLoad(input, center, 0); - let center_luma = dot(center_ilm.xyz, LUMA); - let variance = sqrt(center_ilm.w); - let center_suf = read_surface(center); - var sum_weight = GAUSSIAN_WEIGHTS[0] * GAUSSIAN_WEIGHTS[0]; - var sum_ilm = w4(sum_weight) * center_ilm; - - for (var yy=-1; yy<=1; yy+=1) { - for (var xx=-1; xx<=1; xx+=1) { - let p = center + vec2(xx, yy) * (1i << params.iteration); - if (all(p == center) || any(p < vec2(0)) || any(p >= params.extent)) { - continue; - } - - //TODO: store in group-shared memory - let surface = read_surface(p); - var weight = GAUSSIAN_WEIGHTS[abs(xx)] * GAUSSIAN_WEIGHTS[abs(yy)]; - //TODO: make it stricter on higher iterations - weight *= compare_flat_normals(surface.flat_normal, center_suf.flat_normal); - //Note: should we use a projected depth instead of the surface one? - weight *= compare_depths(surface.depth, center_suf.depth); - let other_ilm = textureLoad(input, p, 0); - weight *= compare_luminance(center_luma, dot(other_ilm.xyz, LUMA), variance); - sum_ilm += w4(weight) * other_ilm; - sum_weight += weight; - } - } - - let filtered_ilm = select(center_ilm, sum_ilm / w4(sum_weight), sum_weight > MIN_WEIGHT); - textureStore(output, global_id.xy, filtered_ilm); -} diff --git a/blade-render/code/ray-trace.wgsl b/blade-render/code/ray-trace.wgsl index 2abf2147..fe1a0f19 100644 --- a/blade-render/code/ray-trace.wgsl +++ b/blade-render/code/ray-trace.wgsl @@ -8,6 +8,7 @@ #include "surface.inc.wgsl" #include "geometry.inc.wgsl" #include "motion.inc.wgsl" +#include "accum.inc.wgsl" const PI: f32 = 3.1415926; const MAX_RESAMPLE: u32 = 4u; @@ -34,6 +35,7 @@ struct MainParams { t_start: f32, use_motion_vectors: u32, grid_scale: vec2, + temporal_accumulation_weight: f32, } var camera: CameraParams; @@ -62,6 +64,12 @@ struct PixelCache { world_pos: vec3, } var pixel_cache: array; +struct ReprojectionCache { + surface: Surface, + pixel_coord: vec2, + is_valid: bool, +} +var reprojection_cache: array; struct LightSample { radiance: vec3, @@ -142,10 +150,11 @@ fn pack_reservoir(r: LiveReservoir) -> StoredReservoir { return pack_reservoir_detail(r, r.history); } -var t_prev_depth: texture_2d; -var t_prev_basis: texture_2d; -var t_prev_flat_normal: texture_2d; -var out_diffuse: texture_storage_2d; +var inout_depth: texture_storage_2d; +var inout_basis: texture_storage_2d; +var inout_flat_normal: texture_storage_2d; +var out_albedo: texture_storage_2d; +var out_motion: texture_storage_2d; var out_debug: texture_storage_2d; fn sample_circle(random: f32) -> vec2 { @@ -201,9 +210,9 @@ fn sample_light_from_environment(rng: ptr) -> LightSample fn read_prev_surface(pixel: vec2) -> Surface { var surface: Surface; - surface.basis = normalize(textureLoad(t_prev_basis, pixel, 0)); - surface.flat_normal = normalize(textureLoad(t_prev_flat_normal, pixel, 0).xyz); - surface.depth = textureLoad(t_prev_depth, pixel, 0).x; + surface.basis = normalize(textureLoad(inout_basis, pixel)); + surface.flat_normal = normalize(textureLoad(inout_flat_normal, pixel).xyz); + surface.depth = textureLoad(inout_depth, pixel).x; return surface; } @@ -458,8 +467,9 @@ fn finalize_resampling( fn resample_temporal( surface: Surface, motion: vec2, cur_pixel: vec2, position: vec3, - rng: ptr, debug_len: f32, + local_index: u32, rng: ptr, debug_len: f32, ) -> ResampleOutput { + reprojection_cache[local_index].is_valid = false; if (debug.view_mode == DebugMode_TemporalMatch || debug.view_mode == DebugMode_TemporalMisCanonical || debug.view_mode == DebugMode_TemporalMisError) { textureStore(out_debug, cur_pixel, vec4(0.0)); } @@ -469,27 +479,32 @@ fn resample_temporal( let canonical = produce_canonical(surface, position, rng, debug_len); //TODO: find best match in a 2x2 grid - let prev_pixel = vec2(get_prev_pixel(cur_pixel, position, motion)); + var prev = ReprojectionCache(); + prev.pixel_coord = vec2(get_prev_pixel(cur_pixel, position, motion)); - let prev_reservoir_index = get_reservoir_index(prev_pixel, prev_camera); + let prev_reservoir_index = get_reservoir_index(prev.pixel_coord, prev_camera); if (parameters.temporal_tap == 0u || prev_reservoir_index < 0) { return finalize_canonical(canonical); } let prev_reservoir = reservoirs[prev_reservoir_index]; - let prev_surface = read_prev_surface(prev_pixel); + prev.surface = read_prev_surface(prev.pixel_coord); + prev.is_valid = compare_surfaces(surface, prev.surface) > 0.1; // if the surfaces are too different, there is no trust in this sample - if (prev_reservoir.confidence == 0.0 || compare_surfaces(surface, prev_surface) < 0.1) { + if (prev_reservoir.confidence == 0.0 || !prev.is_valid) { return finalize_canonical(canonical); } + // Write down the reprojection cache, no need to carry this around + reprojection_cache[local_index] = prev; + var reservoir = LiveReservoir(); var color_and_weight = vec4(0.0); let base = ResampleBase(surface, canonical, position, 1.0); - let prev_dir = get_ray_direction(prev_camera, prev_pixel); - let prev_world_pos = prev_camera.position + prev_surface.depth * prev_dir; - let other = PixelCache(prev_surface, prev_reservoir, prev_world_pos); + let prev_dir = get_ray_direction(prev_camera, prev.pixel_coord); + let prev_world_pos = prev_camera.position + prev.surface.depth * prev_dir; + let other = PixelCache(prev.surface, prev_reservoir, prev_world_pos); let rr = resample(&reservoir, &color_and_weight, base, other, prev_acc_struct, parameters.temporal_history, rng, debug_len); let mis_canonical = 1.0 + rr.mis_canonical; @@ -503,6 +518,7 @@ fn resample_temporal( let total = mis_canonical + rr.mis_sample; textureStore(out_debug, cur_pixel, vec4(abs(total - 1.0 - base.accepted_count))); } + return finalize_resampling(&reservoir, &color_and_weight, base, mis_canonical, rng); } @@ -575,7 +591,7 @@ fn compute_restir( ) -> vec3 { let debug_len = select(0.0, rs.inner.depth * 0.2, enable_debug); - let temporal = resample_temporal(rs.inner, rs.motion, pixel, rs.position, rng, debug_len); + let temporal = resample_temporal(rs.inner, rs.motion, pixel, rs.position, local_index, rng, debug_len); pixel_cache[local_index] = PixelCache(rs.inner, temporal.reservoir, rs.position); // sync with the workgroup to ensure all reservoirs are available. @@ -586,15 +602,18 @@ fn compute_restir( let pixel_index = get_reservoir_index(pixel, camera); reservoirs[pixel_index] = spatial.reservoir; + + //Note: restoring it from the LDS allows to lower the register pressure during spatial re-use + let rc = reprojection_cache[local_index]; + accumulate_temporal( + rs.inner, rs.position, pixel, + spatial.color, parameters.temporal_accumulation_weight, + rc.surface, rc.pixel_coord, rc.is_valid, + ); + return spatial.color; } -var out_depth: texture_storage_2d; -var out_basis: texture_storage_2d; -var out_flat_normal: texture_storage_2d; -var out_albedo: texture_storage_2d; -var out_motion: texture_storage_2d; - @compute @workgroup_size(GROUP_SIZE.x, GROUP_SIZE.y) fn main( @builtin(workgroup_id) group_id: vec3, @@ -617,19 +636,19 @@ fn main( let enable_debug = all(pixel_coord == vec2(debug.mouse_pos)); let rs = fetch_geometry(pixel_coord, true, enable_debug); - // TODO: option to avoid writing data for the sky - textureStore(out_depth, pixel_coord, vec4(rs.inner.depth, 0.0, 0.0, 0.0)); - textureStore(out_basis, pixel_coord, rs.inner.basis); - textureStore(out_flat_normal, pixel_coord, vec4(rs.inner.flat_normal, 0.0)); - textureStore(out_albedo, pixel_coord, vec4(rs.albedo, 0.0)); - textureStore(out_motion, pixel_coord, vec4(rs.motion * MOTION_SCALE, 0.0, 0.0)); - let global_index = u32(pixel_coord.y) * camera.target_size.x + u32(pixel_coord.x); var rng = random_init(global_index, parameters.frame_index); let enable_restir_debug = (debug.draw_flags & DebugDrawFlags_RESTIR) != 0u && enable_debug; let color = compute_restir(rs, pixel_coord, local_index, group_id, &rng, enable_restir_debug); - textureStore(out_diffuse, pixel_coord, vec4(color, 1.0)); + + //Note: important to do this after the temporal pass specifically + // TODO: option to avoid writing data for the sky + textureStore(inout_depth, pixel_coord, vec4(rs.inner.depth, 0.0, 0.0, 0.0)); + textureStore(inout_basis, pixel_coord, rs.inner.basis); + textureStore(inout_flat_normal, pixel_coord, vec4(rs.inner.flat_normal, 0.0)); + textureStore(out_albedo, pixel_coord, vec4(rs.albedo, 0.0)); + textureStore(out_motion, pixel_coord, vec4(rs.motion * MOTION_SCALE, 0.0, 0.0)); if (enable_debug) { debug_buf.variance.color_sum += color; diff --git a/blade-render/src/render/mod.rs b/blade-render/src/render/mod.rs index e0e724c5..51d3f635 100644 --- a/blade-render/src/render/mod.rs +++ b/blade-render/src/render/mod.rs @@ -110,6 +110,7 @@ pub struct RayConfig { #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] pub struct DenoiserConfig { + pub enabled: bool, pub num_passes: u32, pub temporal_weight: f32, } @@ -212,11 +213,13 @@ impl RenderTarget { struct RestirTargets { reservoir_buf: blade_graphics::Buffer, debug: RenderTarget<1>, - depth: RenderTarget<2>, - basis: RenderTarget<2>, - flat_normal: RenderTarget<2>, + depth: RenderTarget<1>, + basis: RenderTarget<1>, + flat_normal: RenderTarget<1>, albedo: RenderTarget<1>, motion: RenderTarget<1>, + // One stores the ReSTIR output color, + // another 2 are used for a-trous ping-pong. light_diffuse: RenderTarget<3>, camera_params: [CameraParams; 2], } @@ -238,7 +241,7 @@ impl RestirTargets { Self { reservoir_buf, debug: RenderTarget::new( - "deubg", + "debug", blade_graphics::TextureFormat::Rgba8Unorm, size, encoder, @@ -297,7 +300,6 @@ impl RestirTargets { } struct Blur { - temporal_accum_pipeline: blade_graphics::ComputePipeline, atrous_pipeline: blade_graphics::ComputePipeline, } @@ -313,7 +315,6 @@ struct Blur { pub struct Renderer { shaders: Shaders, targets: RestirTargets, - post_proc_input_index: usize, main_pipeline: blade_graphics::ComputePipeline, post_proc_pipeline: blade_graphics::RenderPipeline, blur: Blur, @@ -372,6 +373,8 @@ struct MainParams { t_start: f32, use_motion_vectors: u32, grid_scale: [u32; 2], + temporal_accumulation_weight: f32, + pad: f32, } #[derive(blade_macros::ShaderData)] @@ -390,17 +393,14 @@ struct MainData<'a> { sampler_nearest: blade_graphics::Sampler, env_map: blade_graphics::TextureView, env_weights: blade_graphics::TextureView, - t_prev_depth: blade_graphics::TextureView, - t_prev_basis: blade_graphics::TextureView, - t_prev_flat_normal: blade_graphics::TextureView, debug_buf: blade_graphics::BufferPiece, reservoirs: blade_graphics::BufferPiece, - out_depth: blade_graphics::TextureView, - out_basis: blade_graphics::TextureView, - out_flat_normal: blade_graphics::TextureView, + inout_depth: blade_graphics::TextureView, + inout_basis: blade_graphics::TextureView, + inout_flat_normal: blade_graphics::TextureView, out_albedo: blade_graphics::TextureView, out_motion: blade_graphics::TextureView, - out_diffuse: blade_graphics::TextureView, + inout_diffuse: blade_graphics::TextureView, out_debug: blade_graphics::TextureView, } @@ -408,27 +408,10 @@ struct MainData<'a> { #[derive(Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)] struct BlurParams { extent: [u32; 2], - temporal_weight: f32, iteration: i32, - use_motion_vectors: u32, pad: u32, } -#[derive(blade_macros::ShaderData)] -struct TemporalAccumData { - camera: CameraParams, - prev_camera: CameraParams, - params: BlurParams, - input: blade_graphics::TextureView, - prev_input: blade_graphics::TextureView, - t_depth: blade_graphics::TextureView, - t_prev_depth: blade_graphics::TextureView, - t_flat_normal: blade_graphics::TextureView, - t_prev_flat_normal: blade_graphics::TextureView, - t_motion: blade_graphics::TextureView, - output: blade_graphics::TextureView, -} - #[derive(blade_macros::ShaderData)] struct AtrousData { params: BlurParams, @@ -478,7 +461,7 @@ struct HitEntry { pub struct Shaders { env_prepare: blade_asset::Handle, ray_trace: blade_asset::Handle, - blur: blade_asset::Handle, + a_trous: blade_asset::Handle, post_proc: blade_asset::Handle, debug_draw: blade_asset::Handle, debug_blit: blade_asset::Handle, @@ -490,7 +473,7 @@ impl Shaders { let shaders = Self { env_prepare: ctx.load_shader("env-prepare.wgsl"), ray_trace: ctx.load_shader("ray-trace.wgsl"), - blur: ctx.load_shader("blur.wgsl"), + a_trous: ctx.load_shader("a-trous.wgsl"), post_proc: ctx.load_shader("post-proc.wgsl"), debug_draw: ctx.load_shader("debug-draw.wgsl"), debug_blit: ctx.load_shader("debug-blit.wgsl"), @@ -501,8 +484,7 @@ impl Shaders { struct ShaderPipelines { main: blade_graphics::ComputePipeline, - temporal_accum: blade_graphics::ComputePipeline, - atrous: blade_graphics::ComputePipeline, + a_trous: blade_graphics::ComputePipeline, post_proc: blade_graphics::RenderPipeline, env_prepare: blade_graphics::ComputePipeline, reservoir_size: u32, @@ -532,25 +514,13 @@ impl ShaderPipelines { pipeline } - fn create_temporal_accum( - shader: &blade_graphics::Shader, - gpu: &blade_graphics::Context, - ) -> blade_graphics::ComputePipeline { - let layout = ::layout(); - gpu.create_compute_pipeline(blade_graphics::ComputePipelineDesc { - name: "temporal-accum", - data_layouts: &[&layout], - compute: shader.at("temporal_accum"), - }) - } - - fn create_atrous( + fn create_a_trous( shader: &blade_graphics::Shader, gpu: &blade_graphics::Context, ) -> blade_graphics::ComputePipeline { let layout = ::layout(); gpu.create_compute_pipeline(blade_graphics::ComputePipelineDesc { - name: "atrous", + name: "a-trous", data_layouts: &[&layout], compute: shader.at("atrous3x3"), }) @@ -584,11 +554,10 @@ impl ShaderPipelines { shader_man: &blade_asset::AssetManager, ) -> Result { let sh_main = shader_man[shaders.ray_trace].raw.as_ref().unwrap(); - let sh_blur = shader_man[shaders.blur].raw.as_ref().unwrap(); + let sh_atrous = shader_man[shaders.a_trous].raw.as_ref().unwrap(); Ok(Self { main: Self::create_ray_trace(sh_main, gpu), - temporal_accum: Self::create_temporal_accum(sh_blur, gpu), - atrous: Self::create_atrous(sh_blur, gpu), + a_trous: Self::create_a_trous(sh_atrous, gpu), post_proc: Self::create_post_proc( shader_man[shaders.post_proc].raw.as_ref().unwrap(), config.surface_info, @@ -618,6 +587,11 @@ pub struct FrameResources { pub acceleration_structures: Vec, } +#[derive(Debug, Default)] +pub struct FrameKey { + post_proc_input_index: usize, +} + impl Renderer { /// Create a new renderer with a given configuration. /// @@ -675,12 +649,10 @@ impl Renderer { Self { shaders, targets, - post_proc_input_index: 0, main_pipeline: sp.main, post_proc_pipeline: sp.post_proc, blur: Blur { - temporal_accum_pipeline: sp.temporal_accum, - atrous_pipeline: sp.atrous, + atrous_pipeline: sp.a_trous, }, acceleration_structure: blade_graphics::AccelerationStructure::default(), prev_acceleration_structure: blade_graphics::AccelerationStructure::default(), @@ -720,7 +692,6 @@ impl Renderer { gpu.destroy_sampler(self.samplers.nearest); gpu.destroy_sampler(self.samplers.linear); // pipelines - gpu.destroy_compute_pipeline(&mut self.blur.temporal_accum_pipeline); gpu.destroy_compute_pipeline(&mut self.blur.atrous_pipeline); gpu.destroy_compute_pipeline(&mut self.main_pipeline); gpu.destroy_render_pipeline(&mut self.post_proc_pipeline); @@ -737,7 +708,7 @@ impl Renderer { let old = self.shaders.clone(); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.ray_trace)); - tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.blur)); + tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.a_trous)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.post_proc)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.debug_draw)); tasks.extend(asset_hub.shaders.hot_reload(&mut self.shaders.debug_blit)); @@ -761,11 +732,9 @@ impl Renderer { self.main_pipeline = ShaderPipelines::create_ray_trace(shader, gpu); } } - if self.shaders.blur != old.blur { - if let Ok(ref shader) = asset_hub.shaders[self.shaders.blur].raw { - self.blur.temporal_accum_pipeline = - ShaderPipelines::create_temporal_accum(shader, gpu); - self.blur.atrous_pipeline = ShaderPipelines::create_atrous(shader, gpu); + if self.shaders.a_trous != old.a_trous { + if let Ok(ref shader) = asset_hub.shaders[self.shaders.a_trous].raw { + self.blur.atrous_pipeline = ShaderPipelines::create_a_trous(shader, gpu); } } if self.shaders.post_proc != old.post_proc { @@ -1071,7 +1040,6 @@ impl Renderer { self.frame_index += 1; } self.targets.camera_params[self.frame_index % 2] = self.make_camera_params(camera); - self.post_proc_input_index = self.frame_index % 2; } /// Ray trace the scene. @@ -1083,9 +1051,11 @@ impl Renderer { command_encoder: &mut blade_graphics::CommandEncoder, debug_config: DebugConfig, ray_config: RayConfig, - ) { + denoiser_config: DenoiserConfig, + ) -> FrameKey { let debug = self.make_debug_params(&debug_config); let (cur, prev) = self.work_indices(); + let mut post_proc_input_index = 0; if let mut pass = command_encoder.compute() { let grid_scale = { @@ -1127,6 +1097,12 @@ impl Renderer { t_start: ray_config.t_start, use_motion_vectors: (self.frame_scene_built == self.frame_index) as u32, grid_scale, + temporal_accumulation_weight: if denoiser_config.enabled { + denoiser_config.temporal_weight + } else { + 1.0 + }, + pad: 0.0, }, acc_struct: self.acceleration_structure, prev_acc_struct: if self.frame_scene_built < self.frame_index @@ -1145,91 +1121,53 @@ impl Renderer { sampler_nearest: self.samplers.nearest, env_map: self.env_map.main_view, env_weights: self.env_map.weight_view, - t_prev_depth: self.targets.depth.views[prev], - t_prev_basis: self.targets.basis.views[prev], - t_prev_flat_normal: self.targets.flat_normal.views[prev], debug_buf: self.debug.buffer_resource(), reservoirs: self.targets.reservoir_buf.into(), - out_depth: self.targets.depth.views[cur], - out_basis: self.targets.basis.views[cur], - out_flat_normal: self.targets.flat_normal.views[cur], + inout_depth: self.targets.depth.views[0], + inout_basis: self.targets.basis.views[0], + inout_flat_normal: self.targets.flat_normal.views[0], out_albedo: self.targets.albedo.views[0], out_motion: self.targets.motion.views[0], - out_diffuse: self.targets.light_diffuse.views[cur], + inout_diffuse: self.targets.light_diffuse.views[post_proc_input_index], out_debug: self.targets.debug.views[0], }, ); pc.dispatch(groups); } - } - /// Perform noise reduction using SVGF. - #[profiling::function] - pub fn denoise( - &mut self, //TODO: borrow immutably - command_encoder: &mut blade_graphics::CommandEncoder, - denoiser_config: DenoiserConfig, - ) { - let mut params = BlurParams { - extent: [self.surface_size.width, self.surface_size.height], - temporal_weight: denoiser_config.temporal_weight, - iteration: 0, - use_motion_vectors: (self.frame_scene_built == self.frame_index) as u32, - pad: 0, - }; - let (cur, prev) = self.work_indices(); - let temp = 2; - - if denoiser_config.temporal_weight < 1.0 { - let mut pass = command_encoder.compute(); - let mut pc = pass.with(&self.blur.temporal_accum_pipeline); - let groups = self - .blur - .atrous_pipeline - .get_dispatch_for(self.surface_size); - pc.bind( - 0, - &TemporalAccumData { - camera: self.targets.camera_params[cur], - prev_camera: self.targets.camera_params[prev], - params, - input: self.targets.light_diffuse.views[cur], - prev_input: self.targets.light_diffuse.views[prev], - t_depth: self.targets.depth.views[cur], - t_prev_depth: self.targets.depth.views[prev], - t_flat_normal: self.targets.flat_normal.views[cur], - t_prev_flat_normal: self.targets.flat_normal.views[prev], - t_motion: self.targets.motion.views[0], - output: self.targets.light_diffuse.views[temp], - }, - ); - pc.dispatch(groups); - //Note: making `cur` contain the latest reprojection output - self.targets.light_diffuse.views.swap(cur, temp); + if denoiser_config.enabled { + let mut params = BlurParams { + extent: [self.surface_size.width, self.surface_size.height], + iteration: 0, + pad: 0, + }; + let mut ping_pong = [1, 2]; + for _ in 0..denoiser_config.num_passes { + let mut pass = command_encoder.compute(); + let mut pc = pass.with(&self.blur.atrous_pipeline); + let groups = self + .blur + .atrous_pipeline + .get_dispatch_for(self.surface_size); + pc.bind( + 0, + &AtrousData { + params, + input: self.targets.light_diffuse.views[post_proc_input_index], + t_depth: self.targets.depth.views[0], + t_flat_normal: self.targets.flat_normal.views[0], + output: self.targets.light_diffuse.views[ping_pong[0]], + }, + ); + pc.dispatch(groups); + post_proc_input_index = ping_pong[0]; + ping_pong.swap(0, 1); + params.iteration += 1; + } } - let mut ping_pong = [temp, prev]; - for _ in 0..denoiser_config.num_passes { - let mut pass = command_encoder.compute(); - let mut pc = pass.with(&self.blur.atrous_pipeline); - let groups = self - .blur - .atrous_pipeline - .get_dispatch_for(self.surface_size); - pc.bind( - 0, - &AtrousData { - params, - input: self.targets.light_diffuse.views[self.post_proc_input_index], - t_depth: self.targets.depth.views[cur], - t_flat_normal: self.targets.flat_normal.views[cur], - output: self.targets.light_diffuse.views[ping_pong[0]], - }, - ); - pc.dispatch(groups); - self.post_proc_input_index = ping_pong[0]; - ping_pong.swap(0, 1); - params.iteration += 1; + FrameKey { + post_proc_input_index, } } @@ -1238,6 +1176,7 @@ impl Renderer { pub fn post_proc( &self, pass: &mut blade_graphics::RenderCommandEncoder, + key: FrameKey, debug_config: DebugConfig, pp_config: PostProcConfig, debug_lines: &[DebugLine], @@ -1250,7 +1189,7 @@ impl Renderer { 0, &PostProcData { t_albedo: self.targets.albedo.views[0], - light_diffuse: self.targets.light_diffuse.views[self.post_proc_input_index], + light_diffuse: self.targets.light_diffuse.views[key.post_proc_input_index], t_debug: self.targets.debug.views[0], tone_map_params: ToneMapParams { enabled: 1, @@ -1267,7 +1206,7 @@ impl Renderer { self.debug.render_lines( debug_lines, self.targets.camera_params[cur], - self.targets.depth.views[cur], + self.targets.depth.views[0], pass, ); self.debug diff --git a/examples/scene/main.rs b/examples/scene/main.rs index 903a30ca..1e9c6e57 100644 --- a/examples/scene/main.rs +++ b/examples/scene/main.rs @@ -159,7 +159,6 @@ struct Example { last_render_time: time::Instant, render_times: VecDeque, ray_config: blade_render::RayConfig, - denoiser_enabled: bool, denoiser_config: blade_render::DenoiserConfig, post_proc_config: blade_render::PostProcConfig, debug_blit: Option, @@ -268,8 +267,8 @@ impl Example { group_mixer: 10, t_start: 0.1, }, - denoiser_enabled: true, denoiser_config: blade_render::DenoiserConfig { + enabled: true, num_passes: 3, temporal_weight: 0.1, }, @@ -459,6 +458,7 @@ impl Example { // even while it's still being loaded. let do_render = self.scene_load_task.is_none() || (RENDER_WHILE_LOADING && self.scene_revision != 0); + let mut frame_key = blade_render::FrameKey::default(); if do_render { self.renderer.prepare( command_encoder, @@ -475,11 +475,12 @@ impl Example { //TODO: figure out why the main RT pipeline // causes a GPU crash when there are no objects if !self.objects.is_empty() { - self.renderer - .ray_trace(command_encoder, self.debug, self.ray_config); - if self.denoiser_enabled { - self.renderer.denoise(command_encoder, self.denoiser_config); - } + frame_key = self.renderer.ray_trace( + command_encoder, + self.debug, + self.ray_config, + self.denoiser_config, + ); } } @@ -509,6 +510,7 @@ impl Example { }; self.renderer.post_proc( &mut pass, + frame_key, self.debug, self.post_proc_config, &[], @@ -672,7 +674,6 @@ impl Example { egui::CollapsingHeader::new("Denoise") .default_open(false) .show(ui, |ui| { - ui.checkbox(&mut self.denoiser_enabled, "Enable"); self.denoiser_config.populate_hud(ui); }); diff --git a/src/lib.rs b/src/lib.rs index b14b8ebe..32f6e738 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -378,7 +378,6 @@ pub struct Engine { debug: blade_render::DebugConfig, pub frame_config: blade_render::FrameConfig, pub ray_config: blade_render::RayConfig, - pub denoiser_enabled: bool, pub denoiser_config: blade_render::DenoiserConfig, pub post_proc_config: blade_render::PostProcConfig, track_hot_reloads: bool, @@ -493,8 +492,8 @@ impl Engine { group_mixer: 10, t_start: 0.01, }, - denoiser_enabled: true, denoiser_config: blade_render::DenoiserConfig { + enabled: true, num_passes: 4, temporal_weight: 0.1, }, @@ -573,6 +572,7 @@ impl Engine { // We should be able to update TLAS and render content // even while it's still being loaded. + let mut frame_key = blade_render::FrameKey::default(); if self.load_tasks.is_empty() { self.render_objects.clear(); for (_, object) in self.objects.iter_mut() { @@ -628,11 +628,12 @@ impl Engine { self.frame_config.reset_reservoirs = false; if !self.render_objects.is_empty() { - self.renderer - .ray_trace(command_encoder, self.debug, self.ray_config); - if self.denoiser_enabled { - self.renderer.denoise(command_encoder, self.denoiser_config); - } + frame_key = self.renderer.ray_trace( + command_encoder, + self.debug, + self.ray_config, + self.denoiser_config, + ); } } @@ -702,6 +703,7 @@ impl Engine { if self.load_tasks.is_empty() { self.renderer.post_proc( &mut pass, + frame_key, self.debug, self.post_proc_config, &debug_lines, @@ -736,7 +738,6 @@ impl Engine { .show(ui, |ui| { self.ray_config.populate_hud(ui); self.frame_config.reset_reservoirs |= ui.button("Reset Accumulation").clicked(); - ui.checkbox(&mut self.denoiser_enabled, "Enable Denoiser"); self.denoiser_config.populate_hud(ui); self.post_proc_config.populate_hud(ui); });