Skip to content

Commit 2a0ef4c

Browse files
committed
experiment checkpoint
1 parent 739a2d9 commit 2a0ef4c

File tree

8 files changed

+136
-32
lines changed

8 files changed

+136
-32
lines changed

frontend/api.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ type API interface {
3030
// Add returns res = i1+i2+...in
3131
Add(i1, i2 Variable, in ...Variable) Variable
3232

33+
// MAC sets and return a = a + (b*c)
34+
// ! may mutate a without allocating a new result
35+
// ! but behavior is not stable, use with caution
36+
// ! always use MAC(...) result for correctness
37+
MAC(a Variable, b, c Variable) Variable
38+
3339
// Neg returns -i
3440
Neg(i1 Variable) Variable
3541

frontend/cs/r1cs/api.go

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,85 @@ import (
4040
func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
4141
// extract frontend.Variables from input
4242
vars, s := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...)
43-
return builder.add(vars, false, s)
43+
return builder.add(vars, false, s, nil)
4444

4545
}
4646

47+
func (builder *builder) MAC(a frontend.Variable, b, c frontend.Variable) frontend.Variable {
48+
builder.macBuffer1 = builder.macBuffer1[:0]
49+
// do the multiplication
50+
mul := func() {
51+
n1, v1Constant := builder.constantValue(b)
52+
n2, v2Constant := builder.constantValue(c)
53+
54+
// v1 and v2 are both unknown, this is the only case we add a constraint
55+
if !v1Constant && !v2Constant {
56+
res := builder.newInternalVariable()
57+
builder.cs.AddConstraint(builder.newR1C(b, c, res))
58+
builder.macBuffer1 = append(builder.macBuffer1, res...)
59+
return
60+
}
61+
62+
// v1 and v2 are constants, we multiply big.Int values and return resulting constant
63+
if v1Constant && v2Constant {
64+
builder.cs.Mul(&n1, &n2)
65+
builder.macBuffer1 = append(builder.macBuffer1, expr.NewTerm(0, n1))
66+
return
67+
}
68+
69+
if v1Constant {
70+
builder.macBuffer1 = append(builder.macBuffer1, builder.toVariable(c)...)
71+
builder.mulConstant(builder.macBuffer1, n1, true)
72+
return
73+
}
74+
builder.macBuffer1 = append(builder.macBuffer1, builder.toVariable(b)...)
75+
builder.mulConstant(builder.macBuffer1, n2, true)
76+
}
77+
mul()
78+
79+
// we can't mutate a, but we return an address to mutate it in subsequent calls.
80+
_a := builder.toVariable(a)
81+
builder.macBuffer2 = builder.macBuffer2[:0]
82+
builder.add([]expr.LinearExpression{_a, builder.macBuffer1}, false, 0, &builder.macBuffer2)
83+
if cap(_a) >= len(builder.macBuffer2) {
84+
_a = _a[:0]
85+
_a = append(_a, builder.macBuffer2...)
86+
} else {
87+
_a = make(expr.LinearExpression, len(builder.macBuffer2))
88+
copy(_a, builder.macBuffer2)
89+
}
90+
return _a
91+
92+
// if _a, ok := a.(*expr.LinearExpression); ok {
93+
// // we can mutate a
94+
// builder.macBuffer2 = builder.macBuffer2[:0]
95+
// builder.macBuffer2 = append(builder.macBuffer2, *_a...)
96+
97+
// builder.add([]expr.LinearExpression{builder.macBuffer2, builder.macBuffer1}, false, 0, _a)
98+
// return _a
99+
// } else {
100+
// // we can't mutate a, but we return an address to mutate it in subsequent calls.
101+
// _a := builder.toVariable(a)
102+
// r := builder.add([]expr.LinearExpression{_a, builder.macBuffer1}, false, len(_a)+len(builder.macBuffer1), nil).(expr.LinearExpression)
103+
// return &r
104+
// }
105+
// // fmt.Println(reflect.TypeOf(a))
106+
// // if reflect.TypeOf(a) == tPointerVariable {
107+
// // fmt.Println("HERE")
108+
// // a = *(reflect.ValueOf(a).Interface().(*frontend.Variable))
109+
// // }
110+
// return builder.Add(a, builder.Mul(b, c))
111+
}
112+
47113
// Sub returns res = i1 - i2
48114
func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
49115
// extract frontend.Variables from input
50116
vars, s := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...)
51-
return builder.add(vars, true, s)
117+
return builder.add(vars, true, s, nil)
52118
}
53119

54120
// returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true.
55-
func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int) frontend.Variable {
121+
func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int, res *expr.LinearExpression) frontend.Variable {
56122
// we want to merge all terms from input linear expressions
57123
// if they are duplicate, we reduce; that is, if multiple terms in different vars have the
58124
// same variable id.
@@ -68,7 +134,11 @@ func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int
68134
}
69135
builder.heap.heapify()
70136

71-
res := make(expr.LinearExpression, 0, capacity)
137+
if res == nil {
138+
t := make(expr.LinearExpression, 0, capacity)
139+
res = &t
140+
}
141+
// res := make(expr.LinearExpression, 0, capacity)
72142
curr := -1
73143

74144
// process all the terms from all the inputs, in sorted order
@@ -87,37 +157,37 @@ func (builder *builder) add(vars []expr.LinearExpression, sub bool, capacity int
87157
if t.Coeff.IsZero() {
88158
continue // is this really needed?
89159
}
90-
if curr != -1 && t.VID == res[curr].VID {
160+
if curr != -1 && t.VID == (*res)[curr].VID {
91161
// accumulate, it's the same variable ID
92162
if sub && lID != 0 {
93-
builder.cs.Sub(&res[curr].Coeff, &t.Coeff)
163+
builder.cs.Sub(&(*res)[curr].Coeff, &t.Coeff)
94164
} else {
95-
builder.cs.Add(&res[curr].Coeff, &t.Coeff)
165+
builder.cs.Add(&(*res)[curr].Coeff, &t.Coeff)
96166
}
97-
if res[curr].Coeff.IsZero() {
167+
if (*res)[curr].Coeff.IsZero() {
98168
// remove self.
99-
res = res[:curr]
169+
(*res) = (*res)[:curr]
100170
curr--
101171
}
102172
} else {
103173
// append, it's a new variable ID
104-
res = append(res, *t)
174+
(*res) = append((*res), *t)
105175
curr++
106176
if sub && lID != 0 {
107-
builder.cs.Neg(&res[curr].Coeff)
177+
builder.cs.Neg(&(*res)[curr].Coeff)
108178
}
109179
}
110180
}
111181

112-
if len(res) == 0 {
182+
if len((*res)) == 0 {
113183
// keep the linear expression valid (assertIsSet)
114-
res = expr.NewLinearExpression(0, constraint.Coeff{})
184+
(*res) = append((*res), expr.NewTerm(0, constraint.Coeff{}))
115185
}
116186
// if the linear expression LE is too long then record an equality
117187
// constraint LE * 1 = t and return short linear expression instead.
118-
res = builder.compress(res)
188+
(*res) = builder.compress((*res))
119189

120-
return res
190+
return *res
121191
}
122192

123193
// Neg returns -i
@@ -151,7 +221,6 @@ func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) f
151221
// v1 and v2 are constants, we multiply big.Int values and return resulting constant
152222
if v1Constant && v2Constant {
153223
builder.cs.Mul(&n1, &n2)
154-
// n1.Mul(n1, n2).Mod(n1, builder.q)
155224
return expr.NewLinearExpression(0, n1)
156225
}
157226

@@ -547,7 +616,7 @@ func (builder *builder) Println(a ...frontend.Variable) {
547616
if i > 0 {
548617
sbb.WriteByte(' ')
549618
}
550-
if v, ok := arg.(expr.LinearExpression); ok {
619+
if v, ok := builder.isLinearExpression(arg); ok {
551620
assertIsSet(v)
552621

553622
sbb.WriteString("%s")

frontend/cs/r1cs/api_assertions.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,12 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) {
8383
func (builder *builder) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) {
8484
v := builder.toVariable(_v)
8585

86-
switch b := bound.(type) {
87-
case expr.LinearExpression:
86+
if b, ok := builder.isLinearExpression(bound); ok {
8887
assertIsSet(b)
8988
builder.mustBeLessOrEqVar(v, b)
90-
default:
91-
builder.mustBeLessOrEqCst(v, utils.FromInterface(b))
89+
} else {
90+
builder.mustBeLessOrEqCst(v, utils.FromInterface(bound))
9291
}
93-
9492
}
9593

9694
func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) {

frontend/cs/r1cs/builder.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ type builder struct {
5959
q *big.Int
6060
tOne constraint.Coeff
6161
heap minHeap // helps merge k sorted linear expressions
62+
63+
macBuffer1 expr.LinearExpression
64+
macBuffer2 expr.LinearExpression
6265
}
6366

6467
// initialCapacity has quite some impact on frontend performance, especially on large circuits size
@@ -68,6 +71,8 @@ func newBuilder(field *big.Int, config frontend.CompileConfig) *builder {
6871
mtBooleans: make(map[uint64][]expr.LinearExpression, config.Capacity/10),
6972
config: config,
7073
heap: make(minHeap, 0, 100),
74+
macBuffer1: make(expr.LinearExpression, 0, 100),
75+
macBuffer2: make(expr.LinearExpression, 0, 100),
7176
}
7277

7378
// by default the circuit is given a public wire equal to 1
@@ -209,7 +214,7 @@ func (builder *builder) MarkBoolean(v frontend.Variable) {
209214
return
210215
}
211216
// v is a linear expression
212-
l := v.(expr.LinearExpression)
217+
l, _ := builder.isLinearExpression(v)
213218
sort.Sort(l)
214219

215220
key := l.HashCode()
@@ -226,7 +231,7 @@ func (builder *builder) IsBoolean(v frontend.Variable) bool {
226231
return (builder.isCstZero(&b) || builder.isCstOne(&b))
227232
}
228233
// v is a linear expression
229-
l := v.(expr.LinearExpression)
234+
l, _ := builder.isLinearExpression(v)
230235
sort.Sort(l)
231236

232237
key := l.HashCode()
@@ -279,7 +284,7 @@ func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) {
279284
}
280285

281286
func (builder *builder) constantValue(v frontend.Variable) (constraint.Coeff, bool) {
282-
if _v, ok := v.(expr.LinearExpression); ok {
287+
if _v, ok := builder.isLinearExpression(v); ok {
283288
assertIsSet(_v)
284289

285290
if len(_v) != 1 {
@@ -306,6 +311,9 @@ func (builder *builder) toVariable(input interface{}) expr.LinearExpression {
306311
// this is already a "kwown" variable
307312
assertIsSet(t)
308313
return t
314+
case *expr.LinearExpression:
315+
assertIsSet(*t)
316+
return *t
309317
case constraint.Coeff:
310318
return expr.NewLinearExpression(0, t)
311319
case *constraint.Coeff:
@@ -352,14 +360,11 @@ func (builder *builder) NewHint(f hint.Function, nbOutputs int, inputs ...fronte
352360
// TODO @gbotrel hint input pass
353361
// ensure inputs are set and pack them in a []uint64
354362
for i, in := range inputs {
355-
switch t := in.(type) {
356-
case expr.LinearExpression:
363+
if t, ok := builder.isLinearExpression(in); ok {
357364
assertIsSet(t)
358365
hintInputs[i] = builder.getLinearExpression(t)
359-
default:
360-
// make a term
361-
// c := utils.FromInterface(t)
362-
c := builder.cs.FromInterface(t)
366+
} else {
367+
c := builder.cs.FromInterface(in)
363368
term := builder.cs.MakeTerm(&c, 0)
364369
term.MarkConstant()
365370
hintInputs[i] = constraint.LinearExpression{term}
@@ -443,3 +448,13 @@ func (builder *builder) compress(le expr.LinearExpression) expr.LinearExpression
443448
builder.cs.AddConstraint(builder.newR1C(le, one, t))
444449
return t
445450
}
451+
452+
func (builder *builder) isLinearExpression(v frontend.Variable) (expr.LinearExpression, bool) {
453+
switch t := v.(type) {
454+
case expr.LinearExpression:
455+
return t, true
456+
case *expr.LinearExpression:
457+
return *t, true
458+
}
459+
return nil, false
460+
}

frontend/cs/scs/api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ func (builder *scs) Add(i1, i2 frontend.Variable, in ...frontend.Variable) front
5252

5353
}
5454

55+
func (builder *scs) MAC(a frontend.Variable, b, c frontend.Variable) frontend.Variable {
56+
// TODO can we do better here to limit allocations?
57+
return builder.Add(a, builder.Mul(b, c))
58+
}
59+
5560
// neg returns -in
5661
func (builder *scs) neg(in []frontend.Variable) []frontend.Variable {
5762

std/math/emulated/field_assert.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ func rsh(api frontend.API, v frontend.Variable, startDigit, endDigit int) fronte
6767
c.Lsh(c, uint(startDigit))
6868

6969
for i := 0; i < len(bits); i++ {
70-
Σbi = api.Add(Σbi, api.Mul(bits[i], c))
70+
// Σbi = api.Add(Σbi, api.Mul(bits[i], c))
71+
Σbi = api.MAC(Σbi, bits[i], c)
7172
ΣbiRShift = api.Add(ΣbiRShift, api.Mul(bits[i], cRShift))
7273
c.Lsh(c, 1)
7374
cRShift.Lsh(cRShift, 1)

std/math/emulated/wrapped_api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ func (w *FieldAPI[T]) Add(i1 frontend.Variable, i2 frontend.Variable, in ...fron
7676
return res
7777
}
7878

79+
func (w *FieldAPI[T]) MAC(a frontend.Variable, b, c frontend.Variable) frontend.Variable {
80+
// TODO can we do better here to limit allocations?
81+
return w.Add(a, w.Mul(b, c))
82+
}
83+
7984
func (w *FieldAPI[T]) Neg(i1 frontend.Variable) frontend.Variable {
8085
el := w.varToElement(i1)
8186
return w.f.Neg(el)

test/engine.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend
156156
return res
157157
}
158158

159+
func (e *engine) MAC(a frontend.Variable, b, c frontend.Variable) frontend.Variable {
160+
// TODO can we do better here to limit allocations?
161+
return e.Add(a, e.Mul(b, c))
162+
}
163+
159164
func (e *engine) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable {
160165
cptSub++
161166
res := new(big.Int)

0 commit comments

Comments
 (0)