Skip to content

Solari GI: Balance heuristic for spatial resampling #20259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions crates/bevy_solari/src/realtime/restir_di.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
textureStore(view_output, global_id.xy, vec4(pixel_color, 1.0));
}

fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, diffuse_brdf: vec3<f32>, workgroup_id: vec2<u32>, rng: ptr<function, u32>) -> Reservoir{
fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, diffuse_brdf: vec3<f32>, workgroup_id: vec2<u32>, rng: ptr<function, u32>) -> Reservoir {
var workgroup_rng = (workgroup_id.x * 5782582u) + workgroup_id.y;
let light_tile_start = rand_range_u(128u, &workgroup_rng) * 1024u;

Expand Down Expand Up @@ -266,7 +266,6 @@ fn merge_reservoirs(
diffuse_brdf: vec3<f32>,
rng: ptr<function, u32>,
) -> ReservoirMergeResult {
// TODO: Balance heuristic MIS weights
let mis_weight_denominator = 1.0 / (canonical_reservoir.confidence_weight + other_reservoir.confidence_weight);

let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator;
Expand Down
187 changes: 138 additions & 49 deletions crates/bevy_solari/src/realtime/restir_gi.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ fn initial_and_temporal(@builtin(global_invocation_id) global_id: vec3<u32>) {

let initial_reservoir = generate_initial_reservoir(world_position, world_normal, &rng);
let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal);
let merge_result = merge_reservoirs(initial_reservoir, temporal_reservoir, vec3(1.0), vec3(1.0), &rng);
let combined_reservoir = merge_reservoirs(initial_reservoir, temporal_reservoir, &rng);

gi_reservoirs_b[pixel_index] = merge_result.merged_reservoir;
gi_reservoirs_b[pixel_index] = combined_reservoir;
}

@compute @workgroup_size(8, 8, 1)
Expand All @@ -68,12 +68,9 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
let diffuse_brdf = base_color / PI;

let input_reservoir = gi_reservoirs_b[pixel_index];
let spatial_reservoir = load_spatial_reservoir(global_id.xy, depth, world_position, world_normal, &rng);

let input_factor = dot(normalize(input_reservoir.sample_point_world_position - world_position), world_normal) * diffuse_brdf;
let spatial_factor = dot(normalize(spatial_reservoir.sample_point_world_position - world_position), world_normal) * diffuse_brdf;

let merge_result = merge_reservoirs(input_reservoir, spatial_reservoir, input_factor, spatial_factor, &rng);
let spatial = load_spatial_reservoir(global_id.xy, depth, world_position, world_normal, &rng);
let merge_result = merge_reservoirs_spatial(input_reservoir, world_position, world_normal, diffuse_brdf,
spatial.reservoir, spatial.world_position, spatial.world_normal, spatial.diffuse_brdf, &rng);
let combined_reservoir = merge_result.merged_reservoir;

gi_reservoirs_a[pixel_index] = combined_reservoir;
Expand All @@ -83,7 +80,7 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
textureStore(view_output, global_id.xy, pixel_color);
}

fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir{
fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir {
var reservoir = empty_reservoir();

let ray_direction = sample_uniform_hemisphere(world_normal, rng);
Expand Down Expand Up @@ -141,34 +138,30 @@ fn load_temporal_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3
return temporal_reservoir;
}

fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir {
struct SpatialInfo {
reservoir: Reservoir,
world_position: vec3<f32>,
world_normal: vec3<f32>,
diffuse_brdf: vec3<f32>,
}

fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> SpatialInfo {
let spatial_pixel_id = get_neighbor_pixel_id(pixel_id, rng);

let spatial_depth = textureLoad(depth_buffer, spatial_pixel_id, 0);
let spatial_gpixel = textureLoad(gbuffer, spatial_pixel_id, 0);
let spatial_world_position = reconstruct_world_position(spatial_pixel_id, spatial_depth);
let spatial_world_normal = octahedral_decode(unpack_24bit_normal(spatial_gpixel.a));
let spatial_base_color = pow(unpack4x8unorm(spatial_gpixel.r).rgb, vec3(2.2));
let spatial_diffuse_brdf = spatial_base_color / PI;
if pixel_dissimilar(depth, world_position, spatial_world_position, world_normal, spatial_world_normal) {
return empty_reservoir();
return SpatialInfo(empty_reservoir(), spatial_world_position, spatial_world_normal, spatial_diffuse_brdf);
}

let spatial_pixel_index = spatial_pixel_id.x + spatial_pixel_id.y * u32(view.viewport.z);
var spatial_reservoir = gi_reservoirs_b[spatial_pixel_index];
let spatial_reservoir = gi_reservoirs_b[spatial_pixel_index];

var jacobian = jacobian(
world_position,
spatial_world_position,
spatial_reservoir.sample_point_world_position,
spatial_reservoir.sample_point_world_normal
);
if jacobian > 10.0 || jacobian < 0.1 {
return empty_reservoir();
}
spatial_reservoir.unbiased_contribution_weight *= jacobian;

spatial_reservoir.unbiased_contribution_weight *= trace_point_visibility(world_position, spatial_reservoir.sample_point_world_position);

return spatial_reservoir;
return SpatialInfo(spatial_reservoir, spatial_world_position, spatial_world_normal, spatial_diffuse_brdf);
}

fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) -> vec2<u32> {
Expand All @@ -178,13 +171,13 @@ fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) ->
}

fn jacobian(
world_position: vec3<f32>,
spatial_world_position: vec3<f32>,
new_world_position: vec3<f32>,
original_world_position: vec3<f32>,
sample_point_world_position: vec3<f32>,
sample_point_world_normal: vec3<f32>,
) -> f32 {
let r = world_position - sample_point_world_position;
let q = spatial_world_position - sample_point_world_position;
let r = new_world_position - sample_point_world_position;
let q = original_world_position - sample_point_world_position;
let rl = length(r);
let ql = length(q);
let phi_r = saturate(dot(r / rl, sample_point_world_normal));
Expand Down Expand Up @@ -256,34 +249,22 @@ fn empty_reservoir() -> Reservoir {
);
}

struct ReservoirMergeResult {
merged_reservoir: Reservoir,
selected_sample_radiance: vec3<f32>,
}

fn merge_reservoirs(
canonical_reservoir: Reservoir,
other_reservoir: Reservoir,
canonical_factor: vec3<f32>,
other_factor: vec3<f32>,
rng: ptr<function, u32>,
) -> ReservoirMergeResult {
) -> Reservoir {
var combined_reservoir = empty_reservoir();
combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight;

if combined_reservoir.confidence_weight == 0.0 { return ReservoirMergeResult(combined_reservoir, vec3(0.0)); }

// TODO: Balance heuristic MIS weights
let mis_weight_denominator = 1.0 / combined_reservoir.confidence_weight;
let mis_weight_denominator = select(0.0, 1.0 / combined_reservoir.confidence_weight, combined_reservoir.confidence_weight > 0.0);

let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator;
let canonical_radiance = canonical_reservoir.radiance * canonical_factor;
let canonical_target_function = luminance(canonical_radiance);
let canonical_target_function = luminance(canonical_reservoir.radiance);
let canonical_resampling_weight = canonical_mis_weight * (canonical_target_function * canonical_reservoir.unbiased_contribution_weight);

let other_mis_weight = other_reservoir.confidence_weight * mis_weight_denominator;
let other_radiance = other_reservoir.radiance * other_factor;
let other_target_function = luminance(other_radiance);
let other_target_function = luminance(other_reservoir.radiance);
let other_resampling_weight = other_mis_weight * (other_target_function * other_reservoir.unbiased_contribution_weight);

combined_reservoir.weight_sum = canonical_resampling_weight + other_resampling_weight;
Expand All @@ -295,16 +276,124 @@ fn merge_reservoirs(

let inverse_target_function = select(0.0, 1.0 / other_target_function, other_target_function > 0.0);
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;

return ReservoirMergeResult(combined_reservoir, other_radiance);
} else {
combined_reservoir.sample_point_world_position = canonical_reservoir.sample_point_world_position;
combined_reservoir.sample_point_world_normal = canonical_reservoir.sample_point_world_normal;
combined_reservoir.radiance = canonical_reservoir.radiance;

let inverse_target_function = select(0.0, 1.0 / canonical_target_function, canonical_target_function > 0.0);
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;
}

return combined_reservoir;
}

struct ReservoirMergeResult {
merged_reservoir: Reservoir,
selected_sample_radiance: vec3<f32>,
}

fn merge_reservoirs_spatial(
canonical_reservoir: Reservoir,
canonical_world_position: vec3<f32>,
canonical_world_normal: vec3<f32>,
canonical_diffuse_brdf: vec3<f32>,
other_reservoir: Reservoir,
other_world_position: vec3<f32>,
other_world_normal: vec3<f32>,
other_diffuse_brdf: vec3<f32>,
rng: ptr<function, u32>,
) -> ReservoirMergeResult {
// Radiances for resampling
let canonical_sample_radiance =
canonical_reservoir.radiance *
saturate(dot(normalize(canonical_reservoir.sample_point_world_position - canonical_world_position), canonical_world_normal)) *
canonical_diffuse_brdf;
let other_sample_radiance =
other_reservoir.radiance *
saturate(dot(normalize(other_reservoir.sample_point_world_position - canonical_world_position), canonical_world_normal)) *
canonical_diffuse_brdf *
trace_point_visibility(canonical_world_position, other_reservoir.sample_point_world_position);

// Target functions for resampling and MIS
let canonical_target_function_canonical_sample = luminance(canonical_sample_radiance);
let canonical_target_function_other_sample = luminance(other_sample_radiance);

// Extra target functions for MIS
let other_target_function_canonical_sample = luminance(
canonical_reservoir.radiance *
saturate(dot(normalize(canonical_reservoir.sample_point_world_position - other_world_position), other_world_normal)) *
other_diffuse_brdf
);
let other_target_function_other_sample = luminance(
other_reservoir.radiance *
saturate(dot(normalize(other_reservoir.sample_point_world_position - other_world_position), other_world_normal)) *
other_diffuse_brdf
);

// Jacobians for resampling and MIS
let canonical_target_function_other_sample_jacobian = jacobian(
canonical_world_position,
other_world_position,
other_reservoir.sample_point_world_position,
other_reservoir.sample_point_world_normal
);
let other_target_function_canonical_sample_jacobian = jacobian(
other_world_position,
canonical_world_position,
canonical_reservoir.sample_point_world_position,
canonical_reservoir.sample_point_world_normal
);

// Resampling weight for canonical sample
let canonical_sample_mis_weight = balance_heuristic(
canonical_reservoir.confidence_weight * canonical_target_function_canonical_sample,
other_reservoir.confidence_weight * other_target_function_canonical_sample * other_target_function_canonical_sample_jacobian,
);
let canonical_sample_resampling_weight = canonical_sample_mis_weight *
canonical_target_function_canonical_sample *
canonical_reservoir.unbiased_contribution_weight;

// Resampling weight for other sample
let other_sample_mis_weight = balance_heuristic(
other_reservoir.confidence_weight * other_target_function_other_sample,
canonical_reservoir.confidence_weight * canonical_target_function_other_sample * canonical_target_function_other_sample_jacobian,
);
let other_sample_resampling_weight = other_sample_mis_weight *
canonical_target_function_other_sample *
other_reservoir.unbiased_contribution_weight *
canonical_target_function_other_sample_jacobian;

// Perform resampling
var combined_reservoir = empty_reservoir();
combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight;
combined_reservoir.weight_sum = canonical_sample_resampling_weight + other_sample_resampling_weight;

if rand_f(rng) < other_sample_resampling_weight / combined_reservoir.weight_sum {
combined_reservoir.sample_point_world_position = other_reservoir.sample_point_world_position;
combined_reservoir.sample_point_world_normal = other_reservoir.sample_point_world_normal;
combined_reservoir.radiance = other_reservoir.radiance;

let inverse_target_function = select(0.0, 1.0 / canonical_target_function_other_sample, canonical_target_function_other_sample > 0.0);
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;

return ReservoirMergeResult(combined_reservoir, other_sample_radiance);
} else {
combined_reservoir.sample_point_world_position = canonical_reservoir.sample_point_world_position;
combined_reservoir.sample_point_world_normal = canonical_reservoir.sample_point_world_normal;
combined_reservoir.radiance = canonical_reservoir.radiance;

let inverse_target_function = select(0.0, 1.0 / canonical_target_function_canonical_sample, canonical_target_function_canonical_sample > 0.0);
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;

return ReservoirMergeResult(combined_reservoir, canonical_sample_radiance);
}
}

return ReservoirMergeResult(combined_reservoir, canonical_radiance);
fn balance_heuristic(x: f32, y: f32) -> f32 {
let sum = x + y;
if sum == 0.0 {
return 0.0;
}
return x / sum;
}
2 changes: 1 addition & 1 deletion crates/bevy_solari/src/scene/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ struct GpuLightSource {
impl GpuLightSource {
fn new_emissive_mesh_light(instance_id: u32, triangle_count: u32) -> GpuLightSource {
if triangle_count > u16::MAX as u32 {
panic!("Too triangles in an emissive mesh, maximum is 65535.");
panic!("Too many triangles ({triangle_count}) in an emissive mesh, maximum is 65535.");
}

Self {
Expand Down
2 changes: 1 addition & 1 deletion release-content/release-notes/bevy_solari.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
title: Initial raytraced lighting progress (bevy_solari)
authors: ["@JMS55"]
pull_requests: [19058, 19620, 19790, 20020, 20113, 20213]
pull_requests: [19058, 19620, 19790, 20020, 20113, 20213, 20259]
---

(TODO: Embed solari example screenshot here)
Expand Down
Loading