Skip to content
This repository was archived by the owner on Apr 9, 2024. It is now read-only.

Commit 040369a

Browse files
feat(stdlib): Add fallback implementation of SHA256 black box function (#407)
Co-authored-by: kevaundray <[email protected]>
1 parent 967ec81 commit 040369a

File tree

12 files changed

+1360
-73
lines changed

12 files changed

+1360
-73
lines changed

acvm/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ async-trait = "0.1"
2828
default = ["bn254"]
2929
bn254 = ["acir/bn254", "stdlib/bn254", "brillig_vm/bn254", "blackbox_solver/bn254"]
3030
bls12_381 = ["acir/bls12_381", "stdlib/bls12_381", "brillig_vm/bls12_381", "blackbox_solver/bls12_381"]
31+
testing = ["stdlib/testing", "unstable-fallbacks"]
32+
unstable-fallbacks = []
3133

3234
[dev-dependencies]
3335
rand = "0.8.5"
36+
proptest = "1.2.0"

acvm/src/compiler/transformers/fallback.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ impl FallbackTransformer {
7575
lhs.num_bits, rhs.num_bits,
7676
"number of bits specified for each input must be the same"
7777
);
78-
stdlib::fallback::and(
78+
stdlib::blackbox_fallbacks::and(
7979
Expression::from(lhs.witness),
8080
Expression::from(rhs.witness),
8181
*output,
@@ -88,7 +88,7 @@ impl FallbackTransformer {
8888
lhs.num_bits, rhs.num_bits,
8989
"number of bits specified for each input must be the same"
9090
);
91-
stdlib::fallback::xor(
91+
stdlib::blackbox_fallbacks::xor(
9292
Expression::from(lhs.witness),
9393
Expression::from(rhs.witness),
9494
*output,
@@ -98,12 +98,26 @@ impl FallbackTransformer {
9898
}
9999
BlackBoxFuncCall::RANGE { input } => {
100100
// Note there are no outputs because range produces no outputs
101-
stdlib::fallback::range(
101+
stdlib::blackbox_fallbacks::range(
102102
Expression::from(input.witness),
103103
input.num_bits,
104104
current_witness_idx,
105105
)
106106
}
107+
#[cfg(feature = "unstable-fallbacks")]
108+
BlackBoxFuncCall::SHA256 { inputs, outputs } => {
109+
let mut sha256_inputs = Vec::new();
110+
for input in inputs.iter() {
111+
let witness_index = Expression::from(input.witness);
112+
let num_bits = input.num_bits;
113+
sha256_inputs.push((witness_index, num_bits));
114+
}
115+
stdlib::blackbox_fallbacks::sha256(
116+
sha256_inputs,
117+
outputs.to_vec(),
118+
current_witness_idx,
119+
)
120+
}
107121
_ => {
108122
return Err(CompileError::UnsupportedBlackBox(gc.get_black_box_func()));
109123
}

acvm/src/pwg/directives/sorting.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ pub(super) fn route(inputs: Vec<FieldElement>, outputs: Vec<FieldElement>) -> Ve
247247
mod tests {
248248
use super::route;
249249
use acir::FieldElement;
250+
use proptest as _;
250251
use rand::prelude::*;
251252

252253
fn execute_network(config: Vec<bool>, inputs: Vec<FieldElement>) -> Vec<FieldElement> {

acvm/tests/solver.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use acvm::{
1717
};
1818
use blackbox_solver::BlackBoxResolutionError;
1919

20-
struct StubbedBackend;
20+
pub(crate) struct StubbedBackend;
2121

2222
impl BlackBoxFunctionSolver for StubbedBackend {
2323
fn schnorr_verify(

acvm/tests/stdlib.rs

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
#![cfg(feature = "testing")]
2+
mod solver;
3+
use crate::solver::StubbedBackend;
4+
use acir::{
5+
circuit::{
6+
opcodes::{BlackBoxFuncCall, FunctionInput},
7+
Circuit, Opcode, PublicInputs,
8+
},
9+
native_types::Witness,
10+
FieldElement,
11+
};
12+
use acvm::{
13+
compiler::{compile, CircuitSimplifier},
14+
pwg::{ACVMStatus, ACVM},
15+
Language,
16+
};
17+
use proptest::prelude::*;
18+
use sha2::{Digest, Sha256};
19+
use std::collections::{BTreeMap, BTreeSet};
20+
use stdlib::blackbox_fallbacks::UInt32;
21+
22+
proptest! {
23+
#[test]
24+
fn test_uint32_ror(x in 0..u32::MAX, y in 0..32_u32) {
25+
let fe = FieldElement::from(x as u128);
26+
let w = Witness(1);
27+
let result = x.rotate_right(y);
28+
let sha256_u32 = UInt32::new(w);
29+
let (w, extra_gates, _) = sha256_u32.ror(y, 2);
30+
let witness_assignments = BTreeMap::from([(Witness(1), fe)]).into();
31+
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
32+
let solver_status = acvm.solve();
33+
34+
prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
35+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
36+
}
37+
38+
#[test]
39+
fn test_uint32_euclidean_division(x in 0..u32::MAX, y in 0..u32::MAX) {
40+
let lhs = FieldElement::from(x as u128);
41+
let rhs = FieldElement::from(y as u128);
42+
let w1 = Witness(1);
43+
let w2 = Witness(2);
44+
let q = x.div_euclid(y);
45+
let r = x.rem_euclid(y);
46+
let u32_1 = UInt32::new(w1);
47+
let u32_2 = UInt32::new(w2);
48+
let (q_w, r_w, extra_gates, _) = UInt32::euclidean_division(&u32_1, &u32_2, 3);
49+
let witness_assignments = BTreeMap::from([(Witness(1), lhs),(Witness(2), rhs)]).into();
50+
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
51+
let solver_status = acvm.solve();
52+
53+
prop_assert_eq!(acvm.witness_map().get(&q_w.get_inner()).unwrap(), &FieldElement::from(q as u128));
54+
prop_assert_eq!(acvm.witness_map().get(&r_w.get_inner()).unwrap(), &FieldElement::from(r as u128));
55+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
56+
}
57+
58+
#[test]
59+
fn test_uint32_add(x in 0..u32::MAX, y in 0..u32::MAX, z in 0..u32::MAX) {
60+
let lhs = FieldElement::from(x as u128);
61+
let rhs = FieldElement::from(y as u128);
62+
let rhs_z = FieldElement::from(z as u128);
63+
let result = FieldElement::from(((x as u128).wrapping_add(y as u128) % (1_u128 << 32)).wrapping_add(z as u128) % (1_u128 << 32));
64+
let w1 = Witness(1);
65+
let w2 = Witness(2);
66+
let w3 = Witness(3);
67+
let u32_1 = UInt32::new(w1);
68+
let u32_2 = UInt32::new(w2);
69+
let u32_3 = UInt32::new(w3);
70+
let mut gates = Vec::new();
71+
let (w, extra_gates, num_witness) = u32_1.add(&u32_2, 4);
72+
gates.extend(extra_gates);
73+
let (w2, extra_gates, _) = w.add(&u32_3, num_witness);
74+
gates.extend(extra_gates);
75+
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into();
76+
let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments);
77+
let solver_status = acvm.solve();
78+
79+
prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result);
80+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
81+
}
82+
83+
#[test]
84+
fn test_uint32_sub(x in 0..u32::MAX, y in 0..u32::MAX, z in 0..u32::MAX) {
85+
let lhs = FieldElement::from(x as u128);
86+
let rhs = FieldElement::from(y as u128);
87+
let rhs_z = FieldElement::from(z as u128);
88+
let result = FieldElement::from(((x as u128).wrapping_sub(y as u128) % (1_u128 << 32)).wrapping_sub(z as u128) % (1_u128 << 32));
89+
let w1 = Witness(1);
90+
let w2 = Witness(2);
91+
let w3 = Witness(3);
92+
let u32_1 = UInt32::new(w1);
93+
let u32_2 = UInt32::new(w2);
94+
let u32_3 = UInt32::new(w3);
95+
let mut gates = Vec::new();
96+
let (w, extra_gates, num_witness) = u32_1.sub(&u32_2, 4);
97+
gates.extend(extra_gates);
98+
let (w2, extra_gates, _) = w.sub(&u32_3, num_witness);
99+
gates.extend(extra_gates);
100+
let witness_assignments = BTreeMap::from([(Witness(1), lhs), (Witness(2), rhs), (Witness(3), rhs_z)]).into();
101+
let mut acvm = ACVM::new(StubbedBackend, gates, witness_assignments);
102+
let solver_status = acvm.solve();
103+
104+
prop_assert_eq!(acvm.witness_map().get(&w2.get_inner()).unwrap(), &result);
105+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
106+
}
107+
108+
#[test]
109+
fn test_uint32_left_shift(x in 0..u32::MAX, y in 0..32_u32) {
110+
let lhs = FieldElement::from(x as u128);
111+
let w1 = Witness(1);
112+
let result = x.overflowing_shl(y).0;
113+
let u32_1 = UInt32::new(w1);
114+
let (w, extra_gates, _) = u32_1.leftshift(y, 2);
115+
let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into();
116+
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
117+
let solver_status = acvm.solve();
118+
119+
prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
120+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
121+
}
122+
123+
#[test]
124+
fn test_uint32_right_shift(x in 0..u32::MAX, y in 0..32_u32) {
125+
let lhs = FieldElement::from(x as u128);
126+
let w1 = Witness(1);
127+
let result = x.overflowing_shr(y).0;
128+
let u32_1 = UInt32::new(w1);
129+
let (w, extra_gates, _) = u32_1.rightshift(y, 2);
130+
let witness_assignments = BTreeMap::from([(Witness(1), lhs)]).into();
131+
let mut acvm = ACVM::new(StubbedBackend, extra_gates, witness_assignments);
132+
let solver_status = acvm.solve();
133+
134+
prop_assert_eq!(acvm.witness_map().get(&w.get_inner()).unwrap(), &FieldElement::from(result as u128));
135+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
136+
}
137+
}
138+
139+
proptest! {
140+
#![proptest_config(ProptestConfig::with_cases(3))]
141+
#[test]
142+
fn test_sha256(input_values in proptest::collection::vec(0..u8::MAX, 1..50)) {
143+
let mut opcodes = Vec::new();
144+
let mut witness_assignments = BTreeMap::new();
145+
let mut sha256_input_witnesses: Vec<FunctionInput> = Vec::new();
146+
let mut correct_result_witnesses: Vec<Witness> = Vec::new();
147+
let mut output_witnesses: Vec<Witness> = Vec::new();
148+
149+
// prepare test data
150+
hash_witnesses!(input_values, witness_assignments, sha256_input_witnesses, correct_result_witnesses, output_witnesses, Sha256);
151+
let sha256_blackbox = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SHA256 { inputs: sha256_input_witnesses, outputs: output_witnesses });
152+
opcodes.push(sha256_blackbox);
153+
154+
// compile circuit
155+
let circuit_simplifier = CircuitSimplifier::new(witness_assignments.len() as u32 + 32);
156+
let circuit = Circuit {current_witness_index: witness_assignments.len() as u32 + 32,
157+
opcodes, public_parameters: PublicInputs(BTreeSet::new()), return_values: PublicInputs(BTreeSet::new()) };
158+
let circuit = compile(circuit, Language::PLONKCSat{ width: 3 }, does_not_support_sha256, &circuit_simplifier).unwrap().0;
159+
160+
// solve witnesses
161+
let mut acvm = ACVM::new(StubbedBackend, circuit.opcodes, witness_assignments.into());
162+
let solver_status = acvm.solve();
163+
164+
prop_assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved");
165+
}
166+
}
167+
168+
fn does_not_support_sha256(opcode: &Opcode) -> bool {
169+
!matches!(opcode, Opcode::BlackBoxFuncCall(BlackBoxFuncCall::SHA256 { .. }))
170+
}
171+
172+
#[macro_export]
173+
macro_rules! hash_witnesses {
174+
(
175+
$input_values:ident,
176+
$witness_assignments:ident,
177+
$input_witnesses: ident,
178+
$correct_result_witnesses:ident,
179+
$output_witnesses:ident,
180+
$hasher:ident
181+
) => {
182+
let mut counter = 0;
183+
let output = $hasher::digest($input_values.clone());
184+
for inp_v in $input_values {
185+
counter += 1;
186+
let function_input = FunctionInput { witness: Witness(counter), num_bits: 8 };
187+
$input_witnesses.push(function_input);
188+
$witness_assignments.insert(Witness(counter), FieldElement::from(inp_v as u128));
189+
}
190+
191+
for o_v in output {
192+
counter += 1;
193+
$correct_result_witnesses.push(Witness(counter));
194+
$witness_assignments.insert(Witness(counter), FieldElement::from(o_v as u128));
195+
}
196+
197+
for _ in 0..32 {
198+
counter += 1;
199+
$output_witnesses.push(Witness(counter));
200+
}
201+
};
202+
}

stdlib/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ acir.workspace = true
1717
default = ["bn254"]
1818
bn254 = ["acir/bn254"]
1919
bls12_381 = ["acir/bls12_381"]
20+
testing = ["bn254"]

stdlib/src/fallback.rs renamed to stdlib/src/blackbox_fallbacks/logic_fallbacks.rs

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,10 @@
1-
use crate::helpers::VariableStore;
1+
use super::utils::bit_decomposition;
22
use acir::{
33
acir_field::FieldElement,
4-
circuit::{directives::Directive, Opcode},
4+
circuit::Opcode,
55
native_types::{Expression, Witness},
66
};
77

8-
// Perform bit decomposition on the provided expression
9-
#[deprecated(note = "use bit_decomposition function instead")]
10-
pub fn split(
11-
gate: Expression,
12-
bit_size: u32,
13-
num_witness: u32,
14-
new_gates: &mut Vec<Opcode>,
15-
) -> Vec<Witness> {
16-
let (extra_gates, bits, _) = bit_decomposition(gate, bit_size, num_witness);
17-
new_gates.extend(extra_gates);
18-
bits
19-
}
20-
21-
// Generates opcodes and directives to bit decompose the input `gate`
22-
// Returns the bits and the updated witness counter
23-
// TODO:Ideally, we return the updated witness counter, or we require the input
24-
// TODO to be a VariableStore. We are not doing this because we want migration to
25-
// TODO be less painful
26-
pub(crate) fn bit_decomposition(
27-
gate: Expression,
28-
bit_size: u32,
29-
mut num_witness: u32,
30-
) -> (Vec<Opcode>, Vec<Witness>, u32) {
31-
let mut new_gates = Vec::new();
32-
let mut variables = VariableStore::new(&mut num_witness);
33-
34-
// First create a witness for each bit
35-
let mut bit_vector = Vec::with_capacity(bit_size as usize);
36-
for _ in 0..bit_size {
37-
bit_vector.push(variables.new_variable())
38-
}
39-
40-
// Next create a directive which computes those bits.
41-
new_gates.push(Opcode::Directive(Directive::ToLeRadix {
42-
a: gate.clone(),
43-
b: bit_vector.clone(),
44-
radix: 2,
45-
}));
46-
47-
// Now apply constraints to the bits such that they are the bit decomposition
48-
// of the input and each bit is actually a bit
49-
let mut binary_exprs = Vec::new();
50-
let mut bit_decomp_constraint = gate;
51-
let mut two_pow: FieldElement = FieldElement::one();
52-
let two = FieldElement::from(2_i128);
53-
for &bit in &bit_vector {
54-
// Bit constraint to ensure each bit is a zero or one; bit^2 - bit = 0
55-
let mut expr = Expression::default();
56-
expr.push_multiplication_term(FieldElement::one(), bit, bit);
57-
expr.push_addition_term(-FieldElement::one(), bit);
58-
binary_exprs.push(Opcode::Arithmetic(expr));
59-
60-
// Constraint to ensure that the bits are constrained to be a bit decomposition
61-
// of the input
62-
// ie \sum 2^i * x_i = input
63-
bit_decomp_constraint.push_addition_term(-two_pow, bit);
64-
two_pow = two * two_pow;
65-
}
66-
67-
new_gates.extend(binary_exprs);
68-
bit_decomp_constraint.sort(); // TODO: we have an issue open to check if this is needed. Ideally, we remove it.
69-
new_gates.push(Opcode::Arithmetic(bit_decomp_constraint));
70-
71-
(new_gates, bit_vector, variables.finalize())
72-
}
73-
748
// Range constraint
759
pub fn range(gate: Expression, bit_size: u32, num_witness: u32) -> (u32, Vec<Opcode>) {
7610
let (new_gates, _, updated_witness_counter) = bit_decomposition(gate, bit_size, num_witness);

stdlib/src/blackbox_fallbacks/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
mod logic_fallbacks;
2+
mod sha256;
3+
mod uint32;
4+
mod utils;
5+
pub use logic_fallbacks::{and, range, xor};
6+
pub use sha256::sha256;
7+
pub use uint32::UInt32;

0 commit comments

Comments
 (0)