Skip to content

Commit f8663d3

Browse files
authored
elliptic-curve: impl BatchInvert for NonZeroScalar (#1890)
This implements `BatchInvert` for `NonZeroScalar`. To accomplish this, I did the following notable things: - Remove all trait bounds on `BatchInvert` itself. - Remove `CtOption` from `BatchInvert::batch_invert`s return type. - Expose `invert_batch_internal()` internally, which now takes an `invert` function instead of requiring `trait Invert`. - Implement `MulAssign` for `NonZeroScalar`. Things that could still be added: - Mirror all `BatchInvert` implementations on `Scalar`, I left the ones taking a reference to a slice out. - Implement `MulAssign` for `NonZeroScalar` on more combinations. It might be a tiny bit simpler if I implemented `Default` on `NonZeroScalar`, but that seemed like a footgun to me.
1 parent 284a928 commit f8663d3

File tree

2 files changed

+104
-43
lines changed

2 files changed

+104
-43
lines changed

elliptic-curve/src/ops.rs

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Traits for arithmetic operations on elliptic curve field elements.
22
33
use core::iter;
4-
pub use core::ops::{Add, AddAssign, Mul, Neg, Shr, ShrAssign, Sub, SubAssign};
4+
pub use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Shr, ShrAssign, Sub, SubAssign};
55
pub use crypto_bigint::Invert;
66

77
use crypto_bigint::Integer;
@@ -13,26 +13,24 @@ use alloc::{borrow::ToOwned, vec::Vec};
1313

1414
/// Perform a batched inversion on a sequence of field elements (i.e. base field elements or scalars)
1515
/// at an amortized cost that should be practically as efficient as a single inversion.
16-
pub trait BatchInvert<FieldElements: ?Sized>: Field + Sized {
16+
pub trait BatchInvert<FieldElements: ?Sized> {
1717
/// The output of batch inversion. A container of field elements.
18-
type Output: AsRef<[Self]>;
18+
type Output;
1919

2020
/// Invert a batch of field elements.
21-
fn batch_invert(
22-
field_elements: FieldElements,
23-
) -> CtOption<<Self as BatchInvert<FieldElements>>::Output>;
21+
fn batch_invert(field_elements: FieldElements) -> <Self as BatchInvert<FieldElements>>::Output;
2422
}
2523

2624
impl<const N: usize, T> BatchInvert<[T; N]> for T
2725
where
2826
T: Field,
2927
{
30-
type Output = [Self; N];
28+
type Output = CtOption<[Self; N]>;
3129

3230
fn batch_invert(mut field_elements: [Self; N]) -> CtOption<[Self; N]> {
3331
let mut field_elements_pad = [Self::default(); N];
3432
let inversion_succeeded =
35-
invert_batch_internal(&mut field_elements, &mut field_elements_pad);
33+
invert_batch_internal(&mut field_elements, &mut field_elements_pad, invert);
3634

3735
CtOption::new(field_elements, inversion_succeeded)
3836
}
@@ -43,11 +41,12 @@ impl<'this, T> BatchInvert<&'this mut [Self]> for T
4341
where
4442
T: Field,
4543
{
46-
type Output = &'this mut [Self];
44+
type Output = CtOption<&'this mut [Self]>;
4745

4846
fn batch_invert(field_elements: &'this mut [Self]) -> CtOption<&'this mut [Self]> {
4947
let mut field_elements_pad: Vec<Self> = vec![Self::default(); field_elements.len()];
50-
let inversion_succeeded = invert_batch_internal(field_elements, &mut field_elements_pad);
48+
let inversion_succeeded =
49+
invert_batch_internal(field_elements, &mut field_elements_pad, invert);
5150

5251
CtOption::new(field_elements, inversion_succeeded)
5352
}
@@ -58,13 +57,13 @@ impl<T> BatchInvert<&[Self]> for T
5857
where
5958
T: Field,
6059
{
61-
type Output = Vec<Self>;
60+
type Output = CtOption<Vec<Self>>;
6261

6362
fn batch_invert(field_elements: &[Self]) -> CtOption<Vec<Self>> {
6463
let mut field_elements: Vec<Self> = field_elements.to_owned();
6564
let mut field_elements_pad: Vec<Self> = vec![Self::default(); field_elements.len()];
6665
let inversion_succeeded =
67-
invert_batch_internal(&mut field_elements, &mut field_elements_pad);
66+
invert_batch_internal(&mut field_elements, &mut field_elements_pad, invert);
6867

6968
CtOption::new(field_elements, inversion_succeeded)
7069
}
@@ -75,26 +74,35 @@ impl<T> BatchInvert<Vec<Self>> for T
7574
where
7675
T: Field,
7776
{
78-
type Output = Vec<Self>;
77+
type Output = CtOption<Vec<Self>>;
7978

8079
fn batch_invert(mut field_elements: Vec<Self>) -> CtOption<Vec<Self>> {
8180
let mut field_elements_pad: Vec<Self> = vec![Self::default(); field_elements.len()];
8281
let inversion_succeeded =
83-
invert_batch_internal(&mut field_elements, &mut field_elements_pad);
82+
invert_batch_internal(&mut field_elements, &mut field_elements_pad, invert);
8483

8584
CtOption::new(field_elements, inversion_succeeded)
8685
}
8786
}
8887

88+
fn invert<T: Field>(scalar: T) -> (T, Choice) {
89+
let scalar = scalar.invert();
90+
let choice = scalar.is_some();
91+
let scalar = scalar.unwrap_or(T::default());
92+
93+
(scalar, choice)
94+
}
95+
8996
/// Implements "Montgomery's trick", a trick for computing many modular inverses at once.
9097
///
9198
/// "Montgomery's trick" works by reducing the problem of computing `n` inverses
9299
/// to computing a single inversion, plus some storage and `O(n)` extra multiplications.
93100
///
94101
/// See: https://iacr.org/archive/pkc2004/29470042/29470042.pdf section 2.2.
95-
fn invert_batch_internal<T: Field>(
102+
pub(crate) fn invert_batch_internal<T: Copy + Mul<Output = T> + MulAssign>(
96103
field_elements: &mut [T],
97104
field_elements_pad: &mut [T],
105+
invert: fn(T) -> (T, Choice),
98106
) -> Choice {
99107
let batch_size = field_elements.len();
100108
if batch_size != field_elements_pad.len() {
@@ -117,32 +125,32 @@ fn invert_batch_internal<T: Field>(
117125
*field_element_pad = acc;
118126
}
119127

120-
acc.invert()
121-
.map(|mut acc| {
122-
// Shift the iterator by one element back. The one we are skipping is served in `acc`.
123-
let field_elements_pad = field_elements_pad
124-
.iter()
125-
.rev()
126-
.skip(1)
127-
.map(Some)
128-
.chain(iter::once(None));
129-
130-
for (field_element, field_element_pad) in
131-
field_elements.iter_mut().rev().zip(field_elements_pad)
132-
{
133-
if let Some(field_element_pad) = field_element_pad {
134-
// Store in a temporary so we can overwrite `field_element`.
135-
// $ a_{n-1} = {a_n}^{-1}*x_n $
136-
let tmp = acc * *field_element;
137-
// $ {x_n}^{-1} = a_{n}^{-1}*a_{n-1} $
138-
*field_element = acc * *field_element_pad;
139-
acc = tmp;
140-
} else {
141-
*field_element = acc;
142-
}
143-
}
144-
})
145-
.is_some()
128+
let (mut acc, choice) = invert(acc);
129+
130+
// Shift the iterator by one element back. The one we are skipping is served in `acc`.
131+
let field_elements_pad = field_elements_pad
132+
.iter()
133+
.rev()
134+
.skip(1)
135+
.map(Some)
136+
.chain(iter::once(None));
137+
138+
for (field_element, field_element_pad) in
139+
field_elements.iter_mut().rev().zip(field_elements_pad)
140+
{
141+
if let Some(field_element_pad) = field_element_pad {
142+
// Store in a temporary so we can overwrite `field_element`.
143+
// $ a_{n-1} = {a_n}^{-1}*x_n $
144+
let tmp = acc * *field_element;
145+
// $ {x_n}^{-1} = a_{n}^{-1}*a_{n-1} $
146+
*field_element = acc * *field_element_pad;
147+
acc = tmp;
148+
} else {
149+
*field_element = acc;
150+
}
151+
}
152+
153+
choice
146154
}
147155

148156
/// Linear combination.

elliptic-curve/src/scalar/nonzero.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
33
use crate::{
44
CurveArithmetic, Error, FieldBytes, PrimeCurve, Scalar, ScalarPrimitive, SecretKey,
5-
ops::{Invert, Reduce, ReduceNonZero},
5+
ops::{self, BatchInvert, Invert, Reduce, ReduceNonZero},
66
point::NonIdentity,
77
scalar::IsHigh,
88
};
99
use base16ct::HexDisplay;
1010
use core::{
1111
fmt,
12-
ops::{Deref, Mul, Neg},
12+
ops::{Deref, Mul, MulAssign, Neg},
1313
str,
1414
};
1515
use crypto_bigint::{ArrayEncoding, Integer};
@@ -18,6 +18,9 @@ use rand_core::{CryptoRng, TryCryptoRng};
1818
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
1919
use zeroize::Zeroize;
2020

21+
#[cfg(feature = "alloc")]
22+
use alloc::vec::Vec;
23+
2124
#[cfg(feature = "serde")]
2225
use serdect::serde::{Deserialize, Serialize, de, ser};
2326

@@ -96,6 +99,47 @@ where
9699
}
97100
}
98101

102+
impl<const N: usize, C> BatchInvert<[Self; N]> for NonZeroScalar<C>
103+
where
104+
C: CurveArithmetic + PrimeCurve,
105+
{
106+
type Output = [Self; N];
107+
108+
fn batch_invert(mut field_elements: [Self; N]) -> [Self; N] {
109+
let mut field_elements_pad = [Self {
110+
scalar: Scalar::<C>::ONE,
111+
}; N];
112+
ops::invert_batch_internal(&mut field_elements, &mut field_elements_pad, |scalar| {
113+
(scalar.invert(), Choice::from(1))
114+
});
115+
116+
field_elements
117+
}
118+
}
119+
120+
#[cfg(feature = "alloc")]
121+
impl<C> BatchInvert<Vec<Self>> for NonZeroScalar<C>
122+
where
123+
C: CurveArithmetic + PrimeCurve,
124+
{
125+
type Output = Vec<Self>;
126+
127+
fn batch_invert(mut field_elements: Vec<Self>) -> Vec<Self> {
128+
let mut field_elements_pad: Vec<Self> = vec![
129+
Self {
130+
scalar: Scalar::<C>::ONE,
131+
};
132+
field_elements.len()
133+
];
134+
135+
ops::invert_batch_internal(&mut field_elements, &mut field_elements_pad, |scalar| {
136+
(scalar.invert(), Choice::from(1))
137+
});
138+
139+
field_elements
140+
}
141+
}
142+
99143
impl<C> ConditionallySelectable for NonZeroScalar<C>
100144
where
101145
C: CurveArithmetic,
@@ -307,6 +351,15 @@ where
307351
}
308352
}
309353

354+
impl<C> MulAssign for NonZeroScalar<C>
355+
where
356+
C: PrimeCurve + CurveArithmetic,
357+
{
358+
fn mul_assign(&mut self, rhs: Self) {
359+
*self = *self * rhs;
360+
}
361+
}
362+
310363
impl<C> PartialEq for NonZeroScalar<C>
311364
where
312365
C: CurveArithmetic,

0 commit comments

Comments
 (0)