Skip to content

Commit

Permalink
Fixed a-trous kernel and variance guidance
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Oct 28, 2023
1 parent 61db6b9 commit d14917d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 24 deletions.
57 changes: 40 additions & 17 deletions blade-render/code/blur.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
// Spatio-temporal variance-guided filtering
// https://research.nvidia.com/sites/default/files/pubs/2017-07_Spatiotemporal-Variance-Guided-Filtering%3A//svgf_preprint.pdf

// Note: using "ilm" in place of "illumination and the 2nd moment of its luminance"

// Parameters read by the blur entry points in this file.
struct Params {
// Image size in pixels; pixel coordinates at or beyond it are treated as out of bounds.
extent: vec2<i32>,
// Controls the history blend in `temporal_accum`: the previous frame is mixed in
// with factor `best_weight * (1.0 - temporal_weight)`.
temporal_weight: f32,
// Index of the current a-trous pass; filter taps are spaced `1 << iteration` pixels apart.
iteration: u32,
}

var<uniform> camera: CameraParams;
Expand All @@ -21,6 +24,8 @@ var input: texture_2d<f32>;
var prev_input: texture_2d<f32>;
var output: texture_storage_2d<rgba16float, write>;

const LUMA: vec3<f32> = vec3<f32>(0.2126, 0.7152, 0.0722);

fn get_projected_pixel_quad(cp: CameraParams, point: vec3<f32>) -> array<vec2<i32>, 4> {
let pixel = get_projected_pixel_float(cp, point);
return array<vec2<i32>, 4>(
Expand Down Expand Up @@ -51,13 +56,14 @@ fn temporal_accum(@builtin(global_invocation_id) global_id: vec3<u32>) {
return;
}
//TODO: use motion vectors
let cur_radiance = textureLoad(input, pixel, 0).xyz;
let cur_illumination = textureLoad(input, pixel, 0).xyz;
let surface = read_surface(pixel);
let pos_world = camera.position + surface.depth * get_ray_direction(camera, pixel);
// considering all samples in 2x2 quad, to help with edges
var prev_pixels = get_projected_pixel_quad(prev_camera, pos_world);
var best_index = 0;
var best_weight = 0.0;
//TODO: optimize depth load with a gather operation
for (var i = 0; i < 4; i += 1) {
let prev_pixel = prev_pixels[i];
if (all(prev_pixel >= vec2<i32>(0)) && all(prev_pixel < params.extent)) {
Expand All @@ -72,45 +78,62 @@ fn temporal_accum(@builtin(global_invocation_id) global_id: vec3<u32>) {
}
}

var prev_radiance = cur_radiance;
let luminocity = dot(cur_illumination, LUMA);
var mixed_ilm = vec4<f32>(cur_illumination, luminocity * luminocity);
if (best_weight > 0.01) {
prev_radiance = textureLoad(prev_input, prev_pixels[best_index], 0).xyz;
let prev_ilm = textureLoad(prev_input, prev_pixels[best_index], 0);
mixed_ilm = mix(mixed_ilm, prev_ilm, best_weight * (1.0 - params.temporal_weight));
}
let radiance = mix(cur_radiance, prev_radiance, best_weight * (1.0 - params.temporal_weight));
textureStore(output, global_id.xy, vec4<f32>(radiance, 0.0));
textureStore(output, global_id.xy, mixed_ilm);
}

const GAUSSIAN_WEIGHTS = vec2<f32>(0.44198, 0.27901);
const SIGMA_L: f32 = 4.0;
const EPSILON: f32 = 0.001;

// Luminance edge-stopping weight: 1.0 for identical luminances, falling off
// exponentially with their absolute difference.
// `variance` widens the tolerance (scaled by SIGMA_L); EPSILON keeps the
// denominator non-zero when `variance` is 0.
// NOTE(review): the caller in this file passes the square root of the stored
// 2nd moment of luminance rather than a true variance - confirm the intended quantity.
fn compare_luminance(a_lum: f32, b_lum: f32, variance: f32) -> f32 {
    let difference = abs(a_lum - b_lum);
    let tolerance = SIGMA_L * variance + EPSILON;
    return exp(-difference / tolerance);
}

const gaussian_weights = vec2<f32>(0.44198, 0.27901);
// Expands a scalar filter weight into a per-channel weight for an "ilm" value:
// the xyz channels (illumination) get `w`, the w channel (2nd moment of
// luminance) gets `w * w`.
fn w4(w: f32) -> vec4<f32> {
    let squared = w * w;
    return vec4<f32>(w, w, w, squared);
}

@compute @workgroup_size(8, 8)
fn atrous(@builtin(global_invocation_id) global_id: vec3<u32>) {
fn atrous3x3(@builtin(global_invocation_id) global_id: vec3<u32>) {
let center = vec2<i32>(global_id.xy);
if (any(center >= params.extent)) {
return;
}

let center_radiance = textureLoad(input, center, 0).xyz;
let center_ilm = textureLoad(input, center, 0);
let center_luma = dot(center_ilm.xyz, LUMA);
let variance = sqrt(center_ilm.w);
let center_suf = read_surface(center);
var sum_weight = gaussian_weights[0] * gaussian_weights[0];
var sum_radiance = center_radiance * sum_weight;
var sum_weight = GAUSSIAN_WEIGHTS[0] * GAUSSIAN_WEIGHTS[0];
var sum_ilm = w4(sum_weight) * center_ilm;

for (var yy=-1; yy<=1; yy+=1) {
for (var xx=-1; xx<=1; xx+=1) {
let p = center + vec2<i32>(xx, yy);
let p = center + vec2<i32>(xx, yy) * (1 << params.iteration);
if (all(p == center) || any(p < vec2<i32>(0)) || any(p >= params.extent)) {
continue;
}

//TODO: store in group-shared memory
let surface = read_surface(p);
var weight = gaussian_weights[abs(xx)] * gaussian_weights[abs(yy)];
//weight *= compare_surfaces(center_suf, surface);
let radiance = textureLoad(input, p, 0).xyz;
sum_radiance += weight * radiance;
var weight = GAUSSIAN_WEIGHTS[abs(xx)] * GAUSSIAN_WEIGHTS[abs(yy)];
//TODO: make it stricter on higher iterations
weight *= compare_flat_normals(surface.flat_normal, center_suf.flat_normal);
//Note: should we use a projected depth instead of the surface one?
weight *= compare_depths(surface.depth, center_suf.depth);
let other_ilm = textureLoad(input, p, 0);
weight *= compare_luminance(center_luma, dot(other_ilm.xyz, LUMA), variance);
sum_ilm += w4(weight) * other_ilm;
sum_weight += weight;
}
}

let radiance = sum_radiance / sum_weight;
textureStore(output, global_id.xy, vec4<f32>(radiance, 0.0));
let filtered_ilm = sum_ilm / w4(sum_weight);
textureStore(output, global_id.xy, filtered_ilm);
}
4 changes: 3 additions & 1 deletion blade-render/code/surface.inc.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ struct Surface {
depth: f32,
}

const SIGMA_N: f32 = 4.0;

fn compare_flat_normals(a: vec3<f32>, b: vec3<f32>) -> f32 {
return smoothstep(0.4, 0.9, dot(a, b));
return pow(max(0.0, dot(a, b)), SIGMA_N);
}

fn compare_depths(a: f32, b: f32) -> f32 {
Expand Down
9 changes: 5 additions & 4 deletions blade-render/src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ struct MainData {
struct BlurParams {
extent: [u32; 2],
temporal_weight: f32,
pad: f32,
iteration: i32,
}

#[derive(blade_macros::ShaderData)]
Expand Down Expand Up @@ -624,7 +624,7 @@ impl ShaderPipelines {
gpu.create_compute_pipeline(blade_graphics::ComputePipelineDesc {
name: "atrous",
data_layouts: &[&layout],
compute: shader.at("atrous"),
compute: shader.at("atrous3x3"),
})
}

Expand Down Expand Up @@ -1342,10 +1342,10 @@ impl Renderer {
command_encoder: &mut blade_graphics::CommandEncoder,
denoiser_config: DenoiserConfig,
) {
let params = BlurParams {
let mut params = BlurParams {
extent: [self.screen_size.width, self.screen_size.height],
temporal_weight: denoiser_config.temporal_weight,
pad: 0.0,
iteration: 0,
};
if denoiser_config.temporal_weight < 1.0 {
let cur = self.frame_data.first().unwrap();
Expand Down Expand Up @@ -1400,6 +1400,7 @@ impl Renderer {
pc.dispatch(groups);
self.post_proc_input = targets[0];
targets.swap(0, 1); // rotate the views
params.iteration += 1;
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions examples/scene/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ impl Example {
},
denoiser_enabled: true,
denoiser_config: blade_render::DenoiserConfig {
num_passes: 5,
num_passes: 2,
temporal_weight: 0.1,
},
debug_blit: None,
Expand Down Expand Up @@ -613,7 +613,7 @@ impl Example {
egui::Slider::new(&mut dc.temporal_weight, 0.0..=1.0f32)
.text("Temporal weight"),
);
ui.add(egui::Slider::new(&mut dc.num_passes, 0..=15u32).text("A-trous passes"));
ui.add(egui::Slider::new(&mut dc.num_passes, 0..=5u32).text("A-trous passes"));
});

egui::CollapsingHeader::new("Tone Map").show(ui, |ui| {
Expand Down

0 comments on commit d14917d

Please sign in to comment.