Skip to content

Commit bff6d57

Browse files
nyunyunyunyu authored and jonathanpwang committed
[perf] Add CircuitVarTo64BitsF for split_32 (#80)
Reth-benchmark: https://github.com/axiom-crypto/openvm-reth-benchmark/actions/runs/13730024145 Static verifier performance is back to where it was before #5. Used cells: 57M (before #5) vs. 62M (after this PR). The number of columns appears to be the same, so the actual proof times are close. This also fixes a soundness issue: `CircuitNum2BitsV` cannot decompose a `Var` into unique limbs.
1 parent 1b83ba7 commit bff6d57

File tree

6 files changed

+126
-98
lines changed

6 files changed

+126
-98
lines changed

extensions/native/compiler/src/constraints/halo2/compiler.rs

Lines changed: 87 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ use openvm_stark_sdk::{p3_baby_bear::BabyBear, p3_bn254_fr::Bn254Fr};
1414
use snark_verifier_sdk::snark_verifier::{
1515
halo2_base::{
1616
gates::{
17-
circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip, RangeInstructions,
17+
circuit::builder::BaseCircuitBuilder, GateChip, GateInstructions, RangeChip,
18+
RangeInstructions,
1819
},
1920
halo2_proofs::halo2curves::bn256::Fr,
20-
utils::{biguint_to_fe, ScalarField},
21-
Context,
21+
utils::{biguint_to_fe, decompose_fe_to_u64_limbs, ScalarField},
22+
AssignedValue, Context, QuantumCell,
2223
},
23-
util::arithmetic::PrimeField as _,
24+
util::arithmetic::{Field as _, PrimeField as _},
2425
};
2526

2627
use super::stats::Halo2Stats;
@@ -314,15 +315,6 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
314315
let reduced_felt = f_chip.reduce(ctx, felt);
315316
vars.insert(a.0, reduced_felt.value);
316317
}
317-
DslIr::CircuitNum2BitsV(value, bits, output) => {
318-
let shortened_bits = bits.min(Fr::NUM_BITS as usize);
319-
let mut x = gate.num_to_bits(ctx, vars[&value.0], shortened_bits);
320-
let zero = ctx.load_zero();
321-
x.resize(bits, zero);
322-
for (o, x) in output.into_iter().zip_eq(x) {
323-
vars.insert(o.0, x);
324-
}
325-
}
326318
DslIr::CircuitNum2BitsF(value, output) => {
327319
let val = f_chip.reduce(ctx, felts[&value.0]);
328320
let x = gate.num_to_bits(ctx, val.value, 32); // C::F::bits());
@@ -331,6 +323,13 @@ impl<C: Config + Debug> Halo2ConstraintCompiler<C> {
331323
vars.insert(o.0, x);
332324
}
333325
}
326+
DslIr::CircuitVarTo64BitsF(value, output) => {
327+
let x = vars[&value.0];
328+
let limbs = var_to_u64_limbs(ctx, &range, gate, x);
329+
for (o, l) in output.into_iter().zip(limbs) {
330+
felts.insert(o.0, l);
331+
}
332+
}
334333
DslIr::CircuitPoseidon2Permute(state_vars) => {
335334
let mut state =
336335
Poseidon2State::<Fr, POSEIDON2_T>::new(state_vars.map(|x| vars[&x.0]));
@@ -526,10 +525,79 @@ fn is_babybear_ir<C: Config>(ir: &DslIr<C>) -> bool {
526525
)
527526
}
528527

529-
#[allow(dead_code)]
530-
fn is_num2bits_ir<C: Config>(ir: &DslIr<C>) -> bool {
531-
matches!(
532-
ir,
533-
DslIr::CircuitNum2BitsV(_, _, _) | DslIr::CircuitNum2BitsF(_, _)
534-
)
528+
fn fr_to_u64_limbs(fr: &Fr) -> [u64; 4] {
529+
// We need 64-bit limbs, but `decompose_fe_to_u64_limbs` only supports `bit_len < 64`.
530+
let limbs = decompose_fe_to_u64_limbs(fr, 8, 32);
531+
std::array::from_fn(|i| limbs[2 * i] + limbs[2 * i + 1] * (1 << 32))
535532
}
533+
534+
fn var_to_u64_limbs(
535+
ctx: &mut Context<Fr>,
536+
range: &RangeChip<Fr>,
537+
gate: &GateChip<Fr>,
538+
x: AssignedValue<Fr>,
539+
) -> [AssignedBabyBear; 4] {
540+
let limbs = fr_to_u64_limbs(x.value()).map(|limb| ctx.load_witness(Fr::from(limb)));
541+
let factors = [
542+
Fr::from([1, 0, 0, 0]),
543+
Fr::from([0, 1, 0, 0]),
544+
Fr::from([0, 0, 1, 0]),
545+
Fr::from([0, 0, 0, 1]),
546+
];
547+
let sum = gate.inner_product(ctx, limbs, factors.map(QuantumCell::Constant));
548+
ctx.constrain_equal(&sum, &x);
549+
let fr_bound_limbs = fr_to_u64_limbs(&(Fr::ZERO - Fr::ONE));
550+
let ret = std::array::from_fn(|i| {
551+
let limb = limbs[i];
552+
let bits = if i < 3 {
553+
range.range_check(ctx, limb, 64);
554+
64
555+
} else {
556+
range.check_less_than_safe(ctx, limbs[3], fr_bound_limbs[3] + 1);
557+
(Fr::NUM_BITS - 3 * 64) as usize
558+
};
559+
AssignedBabyBear {
560+
value: limb,
561+
max_bits: bits,
562+
}
563+
});
564+
// Constrain that the decomposition doesn't overflow.
565+
// Whether limbs[i] == fr_bound_limbs[i] so far
566+
let mut on_bound = gate.is_equal(
567+
ctx,
568+
limbs[3],
569+
QuantumCell::Constant(Fr::from(fr_bound_limbs[3])),
570+
);
571+
for i in (0..3).rev() {
572+
// limbs[i] > fr_bound_limbs[i]
573+
let li_gt_bd = range.is_less_than(
574+
ctx,
575+
QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
576+
limbs[i],
577+
64,
578+
);
579+
let li_out_bd = gate.add(ctx, on_bound, li_gt_bd);
580+
// on_bound li_gt_bd result
581+
// 1 1 fail
582+
// 1 0 pass
583+
// 0 1 pass
584+
// 0 0 pass
585+
gate.assert_bit(ctx, li_out_bd);
586+
// Update on_bound except the last limb
587+
if i > 0 {
588+
debug_assert_ne!(fr_bound_limbs[i], 0, "This should never happen for Bn254Fr");
589+
// on_bound && limbs[i] - fr_bound_limbs[i] == 0
590+
let diff = gate.sub_mul(
591+
ctx,
592+
QuantumCell::Constant(Fr::from(fr_bound_limbs[i])),
593+
on_bound,
594+
limbs[i],
595+
);
596+
on_bound = gate.is_zero(ctx, diff);
597+
}
598+
}
599+
ret
600+
}
601+
602+
#[test]
603+
fn test_var_to_u64_limbs() {}

extensions/native/compiler/src/constraints/mod.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,14 +236,6 @@ impl<C: Config + Debug> ConstraintCompiler<C> {
236236
opcode: ConstraintOpcode::NegE,
237237
args: vec![vec![a.id()], vec![b.id()]],
238238
}),
239-
DslIr::CircuitNum2BitsV(value, bits, output) => constraints.push(Constraint {
240-
opcode: ConstraintOpcode::Num2BitsV,
241-
args: vec![
242-
output.iter().map(|x| x.id()).collect(),
243-
vec![value.id()],
244-
vec![bits.to_string()],
245-
],
246-
}),
247239
DslIr::CircuitNum2BitsF(value, output) => constraints.push(Constraint {
248240
opcode: ConstraintOpcode::Num2BitsF,
249241
args: vec![output.iter().map(|x| x.id()).collect(), vec![value.id()]],

extensions/native/compiler/src/ir/bits.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,11 @@
1-
use std::any::TypeId;
1+
use std::{any::TypeId, array};
22

33
use openvm_stark_backend::p3_field::FieldAlgebra;
44
use openvm_stark_sdk::p3_baby_bear::BabyBear;
55

66
use super::{Array, Builder, Config, DslIr, Felt, MemIndex, Var};
77

88
impl<C: Config> Builder<C> {
9-
/// Converts a variable to bits inside a circuit.
10-
pub fn num2bits_v_circuit(&mut self, num: Var<C::N>, bits: usize) -> Vec<Var<C::N>> {
11-
let mut output = Vec::new();
12-
for _ in 0..bits {
13-
output.push(self.uninit());
14-
}
15-
16-
self.push(DslIr::CircuitNum2BitsV(num, bits, output.clone()));
17-
18-
output
19-
}
20-
219
/// Converts a felt to bits. Will result in a failed assertion if `num` has more than `num_bits` bits.
2210
/// Only works for C::F = BabyBear
2311
pub fn num2bits_f(&mut self, num: Felt<C::F>, num_bits: u32) -> Array<C, Var<C::N>> {
@@ -94,4 +82,11 @@ impl<C: Config> Builder<C> {
9482
}
9583
result
9684
}
85+
86+
/// Decompose a Var into 64-bit Felt limbs.
87+
pub fn var_to_64bits_f_circuit(&mut self, value: Var<C::N>) -> [Felt<C::F>; 4] {
88+
let ret = array::from_fn(|_| self.uninit());
89+
self.push(DslIr::CircuitVarTo64BitsF(value, ret));
90+
ret
91+
}
9792
}

extensions/native/compiler/src/ir/instructions.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ pub enum DslIr<C: Config> {
179179
StoreHeapPtr(Ptr<C::N>),
180180

181181
// Bits.
182-
/// Decompose a variable into size bits (bits = num2bits(var, size)). Should only be used when target is a circuit.
183-
CircuitNum2BitsV(Var<C::N>, usize, Vec<Var<C::N>>),
184182
/// Decompose a field element into bits (bits = num2bits(felt)). Should only be used when target is a circuit.
185183
CircuitNum2BitsF(Felt<C::F>, Vec<Var<C::N>>),
184+
/// Decompose a Var into four 64-bit Felt limbs.
185+
CircuitVarTo64BitsF(Var<C::N>, [Felt<C::F>; 4]),
186186

187187
// Hashing.
188188
/// Permutes an array of baby bear elements using Poseidon2 (output = p2_permute(array)).

extensions/native/recursion/src/halo2/tests/mod.rs

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -93,28 +93,6 @@ fn test_publish() {
9393
assert_eq!(pis, vec![vec![convert_fr(&value_fr)]]);
9494
}
9595

96-
#[test]
97-
fn test_num2bits_v_circuit() {
98-
let mut builder = Builder::<OuterConfig>::default();
99-
builder.flags.static_only = true;
100-
let mut value_u32 = 1345237507;
101-
let value = builder.eval(Bn254Fr::from_canonical_u32(value_u32));
102-
let result = builder.num2bits_v_circuit(value, 32);
103-
for r in result {
104-
builder.assert_var_eq(r, Bn254Fr::from_canonical_u32(value_u32 & 1));
105-
value_u32 >>= 1;
106-
}
107-
108-
Halo2Prover::mock::<OuterConfig>(
109-
10,
110-
DslOperations {
111-
operations: builder.operations,
112-
num_public_values: 0,
113-
},
114-
Witness::default(),
115-
);
116-
}
117-
11896
#[test]
11997
fn test_reduce_32() {
12098
let value_1 = BabyBear::from_canonical_u32(1345237507);
@@ -140,27 +118,32 @@ fn test_reduce_32() {
140118

141119
#[test]
142120
fn test_split_32() {
143-
let value = Bn254Fr::from_canonical_u32(1345237507);
144-
let gt: Vec<BabyBear> = split_32_gt(value, 3);
145-
dbg!(&gt);
146-
147-
let mut builder = Builder::<OuterConfig>::default();
148-
builder.flags.static_only = true;
149-
let value = builder.eval(value);
150-
let result = split_32(&mut builder, value, 3);
151-
152-
builder.assert_felt_eq(result[0], gt[0]);
153-
builder.assert_felt_eq(result[1], gt[1]);
154-
builder.assert_felt_eq(result[2], gt[2]);
155-
156-
Halo2Prover::mock::<OuterConfig>(
157-
10,
158-
DslOperations {
159-
operations: builder.operations,
160-
num_public_values: 0,
161-
},
162-
Witness::default(),
163-
);
121+
let f = |value| {
122+
let gt: Vec<BabyBear> = split_32_gt(value, 3);
123+
dbg!(&gt);
124+
125+
let mut builder = Builder::<OuterConfig>::default();
126+
builder.flags.static_only = true;
127+
let value = builder.eval(value);
128+
let result = split_32(&mut builder, value, 3);
129+
130+
builder.assert_felt_eq(result[0], gt[0]);
131+
builder.assert_felt_eq(result[1], gt[1]);
132+
builder.assert_felt_eq(result[2], gt[2]);
133+
134+
Halo2Prover::mock::<OuterConfig>(
135+
10,
136+
DslOperations {
137+
operations: builder.operations,
138+
num_public_values: 0,
139+
},
140+
Witness::default(),
141+
);
142+
};
143+
let modulus = Bn254Fr::ZERO - Bn254Fr::ONE;
144+
f(Bn254Fr::from_canonical_u32(1345237507));
145+
f(Bn254Fr::ZERO);
146+
f(modulus);
164147
}
165148

166149
#[test]

extensions/native/recursion/src/utils.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,9 @@ pub fn reduce_32<C: Config>(builder: &mut Builder<C>, vals: &[Felt<C::F>]) -> Va
5252

5353
/// Reference: <https://github.com/Plonky3/Plonky3/blob/622375885320ac6bf3c338001760ed8f2230e3cb/field/src/helpers.rs#L149>
5454
pub fn split_32<C: Config>(builder: &mut Builder<C>, val: Var<C::N>, n: usize) -> Vec<Felt<C::F>> {
55-
let bits = builder.num2bits_v_circuit(val, 256);
56-
let mut results = Vec::new();
57-
for i in 0..n {
58-
let result: Felt<C::F> = builder.eval(C::F::ZERO);
59-
for j in 0..64 {
60-
let bit = bits[i * 64 + j];
61-
let t = builder.eval(result + C::F::from_wrapped_u64(1 << j));
62-
let z = builder.select_f(bit, t, result);
63-
builder.assign(&result, z);
64-
}
65-
results.push(result);
66-
}
67-
results
55+
let felts = builder.var_to_64bits_f_circuit(val);
56+
assert!(n <= felts.len());
57+
felts[0..n].to_vec()
6858
}
6959

7060
/// Eval two expressions, return in the reversed order if cond == 1. Otherwise, return in the original order.

0 commit comments

Comments
 (0)