Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify temporal and spatial taps #194

Merged
merged 1 commit into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions blade-helpers/src/hud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,14 @@ impl ExposeHud for blade_render::RayConfig {
&mut self.environment_importance_sampling,
"Env importance sampling",
);
ui.checkbox(&mut self.temporal_tap, "Temporal tap");
ui.add(egui::widgets::Slider::new(&mut self.tap_count, 0..=10).text("Tap count"));
ui.add(egui::widgets::Slider::new(&mut self.tap_radius, 1..=50).text("Tap radius (px)"));
ui.add(
egui::widgets::Slider::new(&mut self.temporal_history, 0..=50).text("Temporal history"),
egui::widgets::Slider::new(&mut self.tap_confidence_near, 1..=50)
.text("Max confidence"),
);
ui.add(egui::widgets::Slider::new(&mut self.spatial_taps, 0..=10).text("Spatial taps"));
ui.add(
egui::widgets::Slider::new(&mut self.spatial_tap_history, 0..=50)
.text("Spatial tap history"),
);
ui.add(
egui::widgets::Slider::new(&mut self.spatial_radius, 1..=50)
.text("Spatial radius (px)"),
egui::widgets::Slider::new(&mut self.tap_confidence_far, 1..=50).text("Min confidence"),
);
ui.add(
egui::widgets::Slider::new(&mut self.t_start, 0.001..=0.5)
Expand Down
9 changes: 4 additions & 5 deletions blade-helpers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ pub fn default_ray_config() -> blade_render::RayConfig {
blade_render::RayConfig {
num_environment_samples: 1,
environment_importance_sampling: false,
temporal_tap: true,
temporal_history: 10,
spatial_taps: 1,
spatial_tap_history: 10,
spatial_radius: 20,
tap_count: 2,
tap_radius: 20,
tap_confidence_near: 15,
tap_confidence_far: 10,
t_start: 0.01,
pairwise_mis: true,
defensive_mis: 0.1,
Expand Down
89 changes: 33 additions & 56 deletions blade-render/code/ray-trace.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,22 @@
const RAY_FLAG_CULL_NO_OPAQUE: u32 = 0x80u;

const PI: f32 = 3.1415926;
const MAX_RESERVOIRS: u32 = 2u;
const MAX_RESERVOIRS: u32 = 4u;
// See "DECOUPLING SHADING AND REUSE" in
// "Rearchitecting Spatiotemporal Resampling for Production"
const DECOUPLED_SHADING: bool = false;

// We are considering 2x2 grid, so must be <= 4
const FACTOR_TEMPORAL_CANDIDATES: u32 = 1u;
// How many more candidates to consder than the taps we need
const FACTOR_SPATIAL_CANDIDATES: u32 = 3u;
// Has to be at least discarding the 2x2 block
const MIN_SPATIAL_REUSE_DISTANCE: i32 = 7;
const FACTOR_CANDIDATES: u32 = 3u;

struct MainParams {
frame_index: u32,
num_environment_samples: u32,
environment_importance_sampling: u32,
temporal_tap: u32,
temporal_history: u32,
spatial_taps: u32,
spatial_tap_history: u32,
spatial_radius: i32,
tap_count: u32,
tap_radius: f32,
tap_confidence_near: f32,
tap_confidence_far: f32,
t_start: f32,
use_pairwise_mis: u32,
defensive_mis: f32,
Expand Down Expand Up @@ -124,13 +119,13 @@ fn normalize_reservoir(r: ptr<function, LiveReservoir>, history: f32) {
(*r).history = history;
}
}
fn unpack_reservoir(f: StoredReservoir, max_history: u32, radiance: vec3<f32>) -> LiveReservoir {
fn unpack_reservoir(f: StoredReservoir, max_confidence: f32, radiance: vec3<f32>) -> LiveReservoir {
var r: LiveReservoir;
r.selected_light_index = f.light_index;
r.selected_uv = f.light_uv;
r.selected_target_score = f.target_score;
r.selected_radiance = radiance;
let history = min(f.confidence, f32(max_history));
let history = min(f.confidence, max_confidence);
r.weight_sum = f.contribution_weight * f.target_score * history;
r.history = history;
return r;
Expand Down Expand Up @@ -234,7 +229,9 @@ fn evaluate_brdf(surface: Surface, dir: vec3<f32>) -> f32 {
return lambert_brdf * max(0.0, lambert_term);
}

fn check_ray_occluded(acs: acceleration_structure, position: vec3<f32>, direction: vec3<f32>, debug_len: f32) -> bool {
var<private> debug_len: f32;

fn check_ray_occluded(acs: acceleration_structure, position: vec3<f32>, direction: vec3<f32>, debug_len: f32, debug_color: u32) -> bool {
var rq: ray_query;
let flags = RAY_FLAG_TERMINATE_ON_FIRST_HIT | RAY_FLAG_CULL_NO_OPAQUE;
rayQueryInitialize(&rq, acs,
Expand All @@ -244,8 +241,8 @@ fn check_ray_occluded(acs: acceleration_structure, position: vec3<f32>, directio
let intersection = rayQueryGetCommittedIntersection(&rq);

let occluded = intersection.kind != RAY_QUERY_INTERSECTION_NONE;
if (debug_len != 0.0) {
let color = select(0xFFFFFFu, 0x0000FFu, occluded);
if (DEBUG_MODE && debug_len > 0.0) {
let color = select(0xFFFFFFu, 0x808080u, occluded) & debug_color;
debug_line(position, position + debug_len * direction, color);
}
return occluded;
Expand Down Expand Up @@ -284,7 +281,8 @@ fn make_target_score(color: vec3<f32>) -> TargetScore {
}

fn estimate_target_score_with_occlusion(
surface: Surface, position: vec3<f32>, light_index: u32, light_uv: vec2<f32>, acs: acceleration_structure, debug_len: f32
surface: Surface, position: vec3<f32>, light_index: u32, light_uv: vec2<f32>, acs: acceleration_structure,
debug_len: f32, debug_color: u32,
) -> TargetScore {
if (light_index != 0u) {
return TargetScore();
Expand All @@ -298,7 +296,7 @@ fn estimate_target_score_with_occlusion(
return TargetScore();
}

if (check_ray_occluded(acs, position, direction, debug_len)) {
if (check_ray_occluded(acs, position, direction, debug_len, debug_color)) {
return TargetScore();
} else {
//Note: same as `evaluate_reflected_light`
Expand All @@ -307,7 +305,7 @@ fn estimate_target_score_with_occlusion(
}
}

fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debug_len: f32) -> f32 {
fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debug_len: f32, debug_color: u32) -> f32 {
let dir = map_equirect_uv_to_dir(ls.uv);
if (dot(dir, surface.flat_normal) <= 0.0) {
return 0.0;
Expand All @@ -323,7 +321,7 @@ fn evaluate_sample(ls: LightSample, surface: Surface, start_pos: vec3<f32>, debu
return 0.0;
}

if (check_ray_occluded(acc_struct, start_pos, dir, debug_len)) {
if (check_ray_occluded(acc_struct, start_pos, dir, debug_len, debug_color)) {
return 0.0;
}

Expand All @@ -350,9 +348,9 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
if (WRITE_DEBUG_IMAGE && debug.view_mode == DebugMode_Depth) {
textureStore(out_debug, pixel, vec4<f32>(1.0 / surface.depth));
}
let debug_len = select(0.0, surface.depth * 0.2, enable_debug);
let position = camera.position + surface.depth * ray_dir;
let normal = qrot(surface.basis, vec3<f32>(0.0, 0.0, 1.0));
let debug_len = select(0.0, surface.depth * 0.2, enable_debug);

var canonical = LiveReservoir();
for (var i = 0u; i < parameters.num_environment_samples; i += 1u) {
Expand All @@ -363,7 +361,7 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
ls = sample_light_from_sphere(rng);
}

let brdf = evaluate_sample(ls, surface, position, debug_len);
let brdf = evaluate_sample(ls, surface, position, debug_len, 0x00FF00u);
if (brdf > 0.0) {
let other = make_reservoir(ls, 0u, vec3<f32>(brdf));
merge_reservoir(&canonical, other, random_gen(rng));
Expand All @@ -373,36 +371,17 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
}

let center_coord = get_prev_pixel(pixel, position);
let center_pixel = vec2<i32>(center_coord);
// Trick to start with closer pixels: we derive the "further"
// pixel in 2x2 grid by considering the sum.
let further_pixel = vec2<i32>(center_coord - 0.5) + vec2<i32>(center_coord + 0.5) - center_pixel;

// First, gather the list of reservoirs to merge with
var accepted_reservoir_indices = array<i32, MAX_RESERVOIRS>();
var accepted_count = 0u;
var temporal_index = ~0u;
let num_temporal_candidates = parameters.temporal_tap * FACTOR_TEMPORAL_CANDIDATES;
let num_candidates = num_temporal_candidates + parameters.spatial_taps * FACTOR_SPATIAL_CANDIDATES;
let max_samples = min(MAX_RESERVOIRS, 1u + parameters.spatial_taps);
let max_samples = min(MAX_RESERVOIRS, parameters.tap_count);
let num_candidates = max_samples * FACTOR_CANDIDATES;

for (var tap = 0u; tap < num_candidates && accepted_count < max_samples; tap += 1u) {
var other_pixel = center_pixel;
if (tap < num_temporal_candidates) {
if (temporal_index < tap) {
continue;
}
let mask = vec2<u32>(tap) & vec2<u32>(1u, 2u);
other_pixel = select(center_pixel, further_pixel, mask != vec2<u32>(0u));
} else {
let r0 = max(center_pixel - vec2<i32>(parameters.spatial_radius), vec2<i32>(0));
let r1 = min(center_pixel + vec2<i32>(parameters.spatial_radius + 1), vec2<i32>(prev_camera.target_size));
other_pixel = vec2<i32>(mix(vec2<f32>(r0), vec2<f32>(r1), vec2<f32>(random_gen(rng), random_gen(rng))));
let diff = other_pixel - center_pixel;
if (dot(diff, diff) < MIN_SPATIAL_REUSE_DISTANCE) {
continue;
}
}
let radius = parameters.tap_radius * random_gen(rng);
let offset = radius * sample_circle(random_gen(rng));
let other_pixel = vec2<i32>(center_coord + offset);

let other_index = get_reservoir_index(other_pixel, prev_camera);
if (other_index < 0) {
Expand All @@ -419,9 +398,6 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
continue;
}

if (tap < num_temporal_candidates) {
temporal_index = accepted_count;
}
accepted_reservoir_indices[accepted_count] = other_index;
accepted_count += 1u;
}
Expand All @@ -444,25 +420,26 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
for (var rid = 0u; rid < accepted_count; rid += 1u) {
let neighbor_index = accepted_reservoir_indices[rid];
let neighbor = prev_reservoirs[neighbor_index];
let neighbor_pixel = get_pixel_from_reservoir_index(neighbor_index, prev_camera);

let max_history = select(parameters.spatial_tap_history, parameters.temporal_history, rid == temporal_index);
let offset = vec2<f32>(neighbor_pixel) - center_coord;
let max_confidence = mix(parameters.tap_confidence_near, parameters.tap_confidence_far, length(offset) / parameters.tap_radius);
var other: LiveReservoir;
if (parameters.use_pairwise_mis != 0u) {
let neighbor_pixel = get_pixel_from_reservoir_index(neighbor_index, prev_camera);
let neighbor_history = min(neighbor.confidence, f32(max_history));
let neighbor_history = min(neighbor.confidence, max_confidence);
{ // scoping this to hint the register allocation
let neighbor_surface = read_prev_surface(neighbor_pixel);
let neighbor_dir = get_ray_direction(prev_camera, neighbor_pixel);
let neighbor_position = prev_camera.position + neighbor_surface.depth * neighbor_dir;

let t_canonical_at_neighbor = estimate_target_score_with_occlusion(
neighbor_surface, neighbor_position, canonical.selected_light_index, canonical.selected_uv, prev_acc_struct, debug_len);
neighbor_surface, neighbor_position, canonical.selected_light_index, canonical.selected_uv, prev_acc_struct, debug_len, 0xFF0000u);
let r_canonical = ratio(canonical.history * canonical.selected_target_score * inv_count, neighbor_history * t_canonical_at_neighbor.score);
mis_canonical += mis_scale * r_canonical;
}

let t_neighbor_at_canonical = estimate_target_score_with_occlusion(
surface, position, neighbor.light_index, neighbor.light_uv, acc_struct, debug_len);
surface, position, neighbor.light_index, neighbor.light_uv, acc_struct, debug_len, 0x0000FFu);
let r_neighbor = ratio(neighbor_history * neighbor.target_score, canonical.history * t_neighbor_at_canonical.score * inv_count);
let mis_neighbor = mis_scale * r_neighbor;

Expand All @@ -473,8 +450,8 @@ fn compute_restir(surface: Surface, pixel: vec2<i32>, rng: ptr<function, RandomS
other.selected_radiance = t_neighbor_at_canonical.color;
other.weight_sum = t_neighbor_at_canonical.score * neighbor.contribution_weight * mis_neighbor;
} else {
let radiance = evaluate_reflected_light(surface, other.selected_light_index, other.selected_uv);
other = unpack_reservoir(neighbor, max_history, radiance);
let radiance = evaluate_reflected_light(surface, neighbor.light_index, neighbor.light_uv);
other = unpack_reservoir(neighbor, max_confidence, radiance);
}

if (DECOUPLED_SHADING) {
Expand Down
27 changes: 12 additions & 15 deletions blade-render/src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,10 @@ pub struct DebugConfig {
pub struct RayConfig {
pub num_environment_samples: u32,
pub environment_importance_sampling: bool,
pub temporal_tap: bool,
pub temporal_history: u32,
pub spatial_taps: u32,
pub spatial_tap_history: u32,
pub spatial_radius: u32,
pub tap_count: u32,
pub tap_radius: u32,
pub tap_confidence_near: u32,
pub tap_confidence_far: u32,
pub t_start: f32,
/// Evaluate MIS factor for ReSTIR in a pair-wise fashion.
/// Adds 2 extra visibility rays per reused sample.
Expand Down Expand Up @@ -372,11 +371,10 @@ struct MainParams {
frame_index: u32,
num_environment_samples: u32,
environment_importance_sampling: u32,
temporal_tap: u32,
temporal_history: u32,
spatial_taps: u32,
spatial_tap_history: u32,
spatial_radius: u32,
tap_count: u32,
tap_radius: f32,
tap_confidence_near: f32,
tap_confidence_far: f32,
t_start: f32,
use_pairwise_mis: u32,
defensive_mis: f32,
Expand Down Expand Up @@ -1172,11 +1170,10 @@ impl Renderer {
num_environment_samples: ray_config.num_environment_samples,
environment_importance_sampling: ray_config.environment_importance_sampling
as u32,
temporal_tap: ray_config.temporal_tap as u32,
temporal_history: ray_config.temporal_history,
spatial_taps: ray_config.spatial_taps,
spatial_tap_history: ray_config.spatial_tap_history,
spatial_radius: ray_config.spatial_radius,
tap_count: ray_config.tap_count,
tap_radius: ray_config.tap_radius as f32,
tap_confidence_near: ray_config.tap_confidence_near as f32,
tap_confidence_far: ray_config.tap_confidence_far as f32,
t_start: ray_config.t_start,
use_pairwise_mis: ray_config.pairwise_mis as u32,
defensive_mis: ray_config.defensive_mis,
Expand Down
Loading