Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 102 additions & 63 deletions gpu/circuit_prover/src/prover/trace/holder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@
gather_leaf_rows, gather_merkle_paths_device, gather_merkle_paths_from_rows,
};
use crate::ops::ntt::{
bitreversed_monomials_to_natural_evals_multi_coset, hypercube_x1_msb_evals_to_x1_msb_monomials,
log_size_supports_transposed_monomials, transform_whir_leaves_from_ntt_in_place_multi_coset,
bitreversed_monomials_to_natural_evals_multi_coset,
bitreversed_monomials_to_natural_evals_multi_coset_with_coset_range,
hypercube_x1_msb_evals_to_x1_msb_monomials, log_size_supports_transposed_monomials,
transform_whir_leaves_from_ntt_in_place_multi_coset,
};
use crate::primitives::context::DeviceAllocation;
#[cfg(test)]
use crate::primitives::context::HostAllocation;
#[cfg(test)]
use crate::primitives::device_structures::DeviceMatrix;
use crate::primitives::device_structures::{DeviceMatrixChunk, DeviceMatrixImpl, DeviceMatrixMut};
use crate::primitives::device_structures::{
DeviceMatrixChunk, DeviceMatrixImpl, DeviceMatrixMut, DeviceMatrixMutImpl,
};
use crate::primitives::field::BF;
use crate::prover::ProverContext;

Expand Down Expand Up @@ -215,23 +219,23 @@
}
}

/// Mutable shared-borrow of the full `CosetsHolder::Full` backing, intended
/// for callers that fill all cosets in one shot (multi-coset NTT writing
/// directly into the cosets backing). Asserts `!self.cosets_materialized`;
/// the caller is responsible for calling `mark_cosets_materialized` once the
/// fill completes.
pub(crate) fn get_uninit_consolidated_cosets_mut(&mut self) -> &mut DeviceSlice<T> {
assert!(
!self.cosets_materialized,
"get_uninit_consolidated_cosets_mut: cosets already materialized"
);
match &mut self.cosets {
CosetsHolder::Full(backing) => backing,
CosetsHolder::None(_) => {
panic!("cosets not allocated — call ensure_cosets_materialized first")
}
}
}
// /// Mutable shared-borrow of the full `CosetsHolder::Full` backing, intended
// /// for callers that fill all cosets in one shot (multi-coset NTT writing
// /// directly into the cosets backing). Asserts `!self.cosets_materialized`;
// /// the caller is responsible for calling `mark_cosets_materialized` once the
// /// fill completes.
// pub(crate) fn get_uninit_consolidated_cosets_mut(&mut self) -> &mut DeviceSlice<T> {
// assert!(
// !self.cosets_materialized,
// "get_uninit_consolidated_cosets_mut: cosets already materialized"
// );
// match &mut self.cosets {
// CosetsHolder::Full(backing) => backing,
// CosetsHolder::None(_) => {
// panic!("cosets not allocated — call ensure_cosets_materialized first")
// }
// }
// }

pub(crate) fn get_evaluations(&self) -> &DeviceSlice<T> {
self.get_coset_evaluations(0)
Expand Down Expand Up @@ -621,8 +625,9 @@
/// (the WHIR oracle's TraceHolder shape). The natural lde factor and
/// per-leaf size are passed as arguments — they live outside the TraceHolder
/// abstraction.
pub(crate) fn commit_all_into_from_ntt(
pub(crate) fn whir_lde_and_commit_all_into(
&mut self,
inputs_matrix: &DeviceMatrixChunk<BF>,
dst_u32: &mut DeviceSlice<u32>,
log_trace_len: u32,
natural_log_lde_factor: u32,
Expand All @@ -633,22 +638,22 @@
) -> CudaResult<()> {
assert_eq!(
self.log_lde_factor, 0,
"commit_all_into_from_ntt: TraceHolder must be the WHIR-oracle shape (log_lde_factor = 0)"
"whir_lde_and_commit_all_into: TraceHolder must be the WHIR-oracle shape (log_lde_factor = 0)"
);
assert_eq!(
self.log_rows_per_leaf, 0,
"commit_all_into_from_ntt: TraceHolder must be the WHIR-oracle shape (log_rows_per_leaf = 0)"
"whir_lde_and_commit_all_into: TraceHolder must be the WHIR-oracle shape (log_rows_per_leaf = 0)"
);
assert!(
self.cosets_materialized,
"commit_all_into_from_ntt: cosets backing must be materialized first"
!self.cosets_materialized,
"whir_lde_and_commit_all_into: cosets already materialized"
);
let stream = context.get_exec_stream();
let cap_size = 1usize << self.log_tree_cap_size;
assert_eq!(
dst_u32.len(),
cap_size * BLAKE2S_DIGEST_SIZE_U32_WORDS,
"commit_all_into_from_ntt dst_u32 length must match cap_size * DIGEST_U32_WORDS",
"whir_lde_and_commit_all_into dst_u32 length must match cap_size * DIGEST_U32_WORDS",
);

// Snapshot the cosets backing as a raw const slice so we can borrow
Expand Down Expand Up @@ -678,6 +683,7 @@
match &mut self.trees {
TreesHolder::Full(backing) => {
commit_trace_from_ntt_single_tree(
inputs_matrix,
ntt_output,
backing,
log_trace_len,
Expand All @@ -704,6 +710,7 @@
// = PARTIAL_TREE_REDUCTION_LAYERS.
let top_log_cap = total_leaf_count_log2 + 1 - PARTIAL_TREE_REDUCTION_LAYERS;
commit_trace_from_ntt_single_tree(
inputs_matrix,
ntt_output,
&mut tree_top[..],
log_trace_len,
Expand Down Expand Up @@ -733,7 +740,7 @@
transient_tree_tops = None; // partial path doesn't reuse via gather
}
TreesHolder::None => {
panic!("commit_all_into_from_ntt: TreesCacheMode::CacheNone is not supported; use CachePartial or CacheFull");
panic!("whir_lde_and_commit_all_into: TreesCacheMode::CacheNone is not supported; use CachePartial or CacheFull");
}
}

Expand Down Expand Up @@ -789,12 +796,13 @@
Ok(())
}

/// Wrapper around `commit_all_into_from_ntt` that allocates a private
/// Wrapper around `whir_lde_and_commit_all_into` that allocates a private
/// `unified_device_cap` (mirrors `commit_all`'s relationship to
/// `commit_all_into`). Used by `#[cfg(test)]` callers that don't have a slab
/// destination handy.
pub(crate) fn commit_all_from_ntt(
pub(crate) fn whir_lde_and_commit_all(
&mut self,
inputs_matrix: &DeviceMatrixChunk<BF>,
log_trace_len: u32,
natural_log_lde_factor: u32,
log_values_per_leaf: u32,
Expand All @@ -815,7 +823,8 @@
let dst_ptr = unified_cap_mut.as_mut_ptr() as *mut u32;
let dst_len = cap_size * BLAKE2S_DIGEST_SIZE_U32_WORDS;
let dst_u32 = unsafe { DeviceSlice::from_raw_parts_mut(dst_ptr, dst_len) };
self.commit_all_into_from_ntt(
self.whir_lde_and_commit_all_into(
inputs_matrix,
dst_u32,
log_trace_len,
natural_log_lde_factor,
Expand Down Expand Up @@ -1270,6 +1279,7 @@
/// (typically `whir_steps_lde_factors[i].trailing_zeros()`), NOT the
/// `TraceHolder`'s `log_lde_factor`.
pub(crate) fn commit_trace_from_ntt_single_tree(
inputs_matrix: &DeviceMatrixChunk<BF>,
ntt_output: &mut DeviceSlice<BF>,
trees_backing: &mut DeviceSlice<Digest>,
log_trace_len: u32,
Expand All @@ -1282,6 +1292,7 @@
) -> CudaResult<()> {
assert!(natural_log_lde_factor >= 1);
assert!(log_trace_len >= log_values_per_leaf);
let trace_len = 1 << log_trace_len;
let packed_leaf_count = 1usize << (log_trace_len - log_values_per_leaf);
let total_leaf_count = packed_leaf_count
.checked_mul(1 << natural_log_lde_factor)
Expand All @@ -1292,22 +1303,64 @@
let layers_count = total_leaf_count_log2 + 1 - log_tree_cap_size;
let (leaves, nodes) = trees_backing.split_at_mut(total_leaf_count);
let stream = context.get_exec_stream();
// TODO: Accept cosets_in_tile and coset_index_base as arguments for coset-based
// L2 chunking of the full NTT->transform->commit rows->commit nodes sequence.
if transform_leaves_to_multilinear_coeffs {
let l2_bytes = context.get_device_properties().l2_cache_size_bytes;
let single_bf_col_bytes = std::mem::size_of::<BF>() << log_trace_len;
let single_coset_bytes = src_cols_per_coset * single_bf_col_bytes;
let cosets_in_tile_chunk = if l2_bytes >= single_coset_bytes {
l2_bytes / single_coset_bytes

let device_properties = context.get_device_properties();
let ntt_ctx = context.ntt_device_context();
// Recursive WHIR folds to a small trace (trace_len_log2 <= 13), the DIT
// forward-NTT range, which needs a pooled d-table scratch (len >= N).
// Allocate from the stream-ordered pool so this stays enqueue-only per
// the GPU scheduling contract; the handle outlives the launches below.
// Outside the DIT range the compact path ignores the scratch, so the
// allocation is skipped entirely.
let mut d_scratch = if log_trace_len <= 13 {
Some(context.alloc::<BF>(trace_len, AllocationPlacement::BestFit)?)
} else {
None
};

let total_cosets = 1 << natural_log_lde_factor;
let l2_bytes = device_properties.l2_cache_size_bytes;
let single_bf_col_bytes = std::mem::size_of::<BF>() << log_trace_len;
let single_coset_bytes = src_cols_per_coset * single_bf_col_bytes;
let cosets_in_tile_chunk = if l2_bytes >= single_coset_bytes {
let nearest = l2_bytes / single_coset_bytes;
if nearest.is_power_of_two() {
nearest
} else {
1
};
let mut ntt_output_matrix = DeviceMatrixMut::new(ntt_output, 1 << log_trace_len);
let total_cosets = 1 << natural_log_lde_factor;
for coset_index_base in (0..total_cosets).step_by(cosets_in_tile_chunk) {
let cosets_in_tile =
std::cmp::min(cosets_in_tile_chunk, total_cosets - coset_index_base);
nearest.next_power_of_two() >> 1
}
} else {
// don't bother with chunking
total_cosets
};
if total_cosets > cosets_in_tile_chunk {
assert_eq!(total_cosets % cosets_in_tile_chunk, 0);
}

let mut ntt_output_matrix = DeviceMatrixMut::new(ntt_output, trace_len);

Check warning on line 1340 in gpu/circuit_prover/src/prover/trace/holder/mod.rs

View workflow job for this annotation

GitHub Actions / cargo fmt

Diff in /home/runner/work/zksync-airbender/zksync-airbender/gpu/circuit_prover/src/prover/trace/holder/mod.rs

for coset_index_base in (0..total_cosets).step_by(cosets_in_tile_chunk) {
let cosets_in_tile =
std::cmp::min(cosets_in_tile_chunk, total_cosets - coset_index_base);
// The NTT and hashing APIs don't internally apply an offset for coset_index_base,
// so for them we have to manually select the start in ntt_output.
let offset = src_cols_per_coset * trace_len * coset_index_base;
let scratch_opt = d_scratch.as_mut().map(|s| &mut s[..]);
bitreversed_monomials_to_natural_evals_multi_coset_with_coset_range(
inputs_matrix,
&mut (ntt_output_matrix.slice_mut())[offset..],
log_trace_len as usize,
natural_log_lde_factor as usize,
cosets_in_tile,
coset_index_base,
src_cols_per_coset,
false,
ntt_ctx,
scratch_opt,
stream,
device_properties,
)?;
if transform_leaves_to_multilinear_coeffs {
transform_whir_leaves_from_ntt_in_place_multi_coset(
&mut ntt_output_matrix,
log_trace_len,
Expand All @@ -1318,35 +1371,21 @@
src_cols_per_coset as u32,
stream,
)?;
crate::ops::blake2s::launch_leaves_kernel_from_ntt_multi_coset(
ntt_output_matrix.slice(),
leaves,
log_values_per_leaf,
src_cols_per_coset as u32,
natural_log_lde_factor,
coset_index_base as u32,
cosets_in_tile,
packed_leaf_count,
1u32 << log_trace_len,
stream,
)?;
}
} else {
let coset_index_base = 0;
let cosets_in_tile = 1usize << natural_log_lde_factor;
crate::ops::blake2s::launch_leaves_kernel_from_ntt_multi_coset(
ntt_output,
&(ntt_output_matrix.slice())[offset..],
leaves,
log_values_per_leaf,
src_cols_per_coset as u32,
natural_log_lde_factor,
coset_index_base,
coset_index_base as u32,
cosets_in_tile,
packed_leaf_count,
1u32 << log_trace_len,
trace_len as u32,
stream,
)?;
}

// Single-tree node layers: build_merkle_tree_nodes operates on a flat
// `[leaves | nodes]` slab. `layers_count - 1` because the leaf layer is
// already written; the function builds the remaining `layers_count - 1`
Expand Down
Loading
Loading