Skip to content

Commit e0c3426

Browse files
committed
bigint/x86_64: Eliminate most unnecessary chunking.
We don't actually need to pass arguments as `&[[Limb; 8]]` instead of `&[Limb]` in most cases. As long as one argument has a type like that, then we've ensured that the length is a multiple of 512 bits. In most cases we were calling `as_chunks[_mut]` on something and then immediately calling `as_flattened[_mut]` on it.
1 parent 4dae629 commit e0c3426

File tree

4 files changed

+21
-32
lines changed

4 files changed

+21
-32
lines changed

src/arithmetic/bigint.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -637,9 +637,7 @@ fn elem_exp_consttime_inner<N, M, const STORAGE_LIMBS: usize>(
637637
.ok_or_else(|| LenMismatchError::new(m.limbs().len()))?;
638638
let cpe = m_original.len(); // 512-bit chunks per entry
639639

640-
let oneRRR = &oneRRR.as_ref().limbs;
641-
let oneRRR =
642-
as_chunks_exact(oneRRR.as_ref()).ok_or_else(|| LenMismatchError::new(oneRRR.len()))?;
640+
let oneRRR = oneRRR.as_ref().limbs.as_ref();
643641

644642
// The x86_64 assembly was written under the assumption that the input data
645643
// is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL.
@@ -669,15 +667,16 @@ fn elem_exp_consttime_inner<N, M, const STORAGE_LIMBS: usize>(
669667
// These are named `(tmp, am, np)` in BoringSSL.
670668
let (acc, rest) = state.split_at_mut(cpe);
671669
let (base_cached, m_cached) = rest.split_at_mut(cpe);
670+
let base_cached = base_cached.as_flattened_mut();
671+
let acc = acc.as_flattened_mut();
672672

673673
// "To improve cache locality" according to upstream.
674674
m_cached
675675
.as_flattened_mut()
676676
.copy_from_slice(m_original.as_flattened());
677677

678678
let out: Elem<M, RInverse> = elem_reduced(out, base_mod_n, m, other_prime_len_bits);
679-
let base_rinverse = as_chunks_exact(out.limbs.as_ref())
680-
.ok_or_else(|| LenMismatchError::new(out.limbs.len()))?;
679+
let base_rinverse = out.limbs.as_ref();
681680

682681
// base_cached = base*R == (base/R * RRR)/R
683682
let _: &[Limb] = mul_mont5(base_cached, base_rinverse, oneRRR, m_cached, n0, cpu2)?;
@@ -687,7 +686,7 @@ fn elem_exp_consttime_inner<N, M, const STORAGE_LIMBS: usize>(
687686
// gathering, storing the last calculated power into `acc`.
688687
fn scatter_powers_of_2(
689688
table: &mut [[Limb; 8]],
690-
acc: &mut [[Limb; 8]],
689+
mut acc: &mut [Limb],
691690
m_cached: &[[Limb; 8]],
692691
n0: &N0,
693692
mut i: LeakyWindow5,
@@ -699,20 +698,19 @@ fn elem_exp_consttime_inner<N, M, const STORAGE_LIMBS: usize>(
699698
Some(i) => i,
700699
None => break,
701700
};
702-
let _: &[Limb] = sqr_mont5(acc.as_flattened_mut(), m_cached, n0, cpu)?;
701+
acc = sqr_mont5(acc, m_cached, n0, cpu)?;
703702
}
704703
Ok(())
705704
}
706705

707706
// All entries in `table` will be Montgomery encoded.
708707

709708
// acc = table[0] = base**0 (i.e. 1).
710-
let _: &[Limb] = m.oneR(polyfill::slice::Uninit::from_mut(acc.as_flattened_mut()))?;
709+
let _: &[Limb] = m.oneR(polyfill::slice::Uninit::from_mut(acc))?;
711710
scatter5(acc, table, LeakyWindow5::_0)?;
712711

713712
// acc = base**1 (i.e. base).
714-
acc.as_flattened_mut()
715-
.copy_from_slice(base_cached.as_flattened());
713+
acc.copy_from_slice(base_cached);
716714

717715
// Fill in entries 1, 2, 4, 8, 16.
718716
scatter_powers_of_2(table, acc, m_cached, n0, LeakyWindow5::_1, cpu2)?;
@@ -740,7 +738,7 @@ fn elem_exp_consttime_inner<N, M, const STORAGE_LIMBS: usize>(
740738
},
741739
);
742740

743-
out.as_mut().copy_from_slice(acc.as_flattened());
741+
out.as_mut().copy_from_slice(acc);
744742
Ok(from_montgomery_amm(out, m))
745743
}
746744

src/arithmetic/limbs/x86_64/mont.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,15 @@ const _512_IS_LIMB_BITS_TIMES_8: () = assert!(8 * Limb::BITS == 512);
3838

3939
#[inline]
4040
pub(in super::super::super) fn mul_mont5<'o>(
41-
r: &'o mut [[Limb; 8]],
42-
a: &[[Limb; 8]],
43-
b: &[[Limb; 8]],
41+
r: &'o mut [Limb],
42+
a: &[Limb],
43+
b: &[Limb],
4444
m: &[[Limb; 8]],
4545
n0: &N0,
4646
maybe_adx_bmi2: Option<(Adx, Bmi2)>,
4747
) -> Result<&'o mut [Limb], LimbSliceError> {
4848
mul_mont5_4x(
49-
(
50-
Uninit::from_mut(r.as_flattened_mut()),
51-
a.as_flattened(),
52-
b.as_flattened(),
53-
),
49+
(Uninit::from_mut(r), a, b),
5450
SmallerChunks::as_smaller_chunks(m),
5551
n0,
5652
maybe_adx_bmi2,
@@ -125,7 +121,7 @@ pub(in super::super::super) fn sqr_mont5<'o>(
125121

126122
#[inline(always)]
127123
pub(in super::super::super) fn gather5(
128-
r: &mut [[Limb; 8]],
124+
r: &mut [Limb],
129125
table: &[[Limb; 8]],
130126
power: Window5,
131127
) -> Result<(), LimbSliceError> {
@@ -139,16 +135,15 @@ pub(in super::super::super) fn gather5(
139135
power: Window5);
140136
}
141137
let num_limbs = check_common(r, table)?;
142-
let r = r.as_flattened_mut();
143138
let table = table.as_flattened();
144139
unsafe { bn_gather5(r.as_mut_ptr(), num_limbs, table.as_ptr(), power) };
145140
Ok(())
146141
}
147142

148143
#[inline(always)]
149144
pub(in super::super::super) fn mul_mont_gather5_amm(
150-
r: &mut [[Limb; 8]],
151-
a: &[[Limb; 8]],
145+
r: &mut [Limb],
146+
a: &[Limb],
152147
table: &[[Limb; 8]],
153148
n: &[[Limb; 8]],
154149
n0: &N0,
@@ -180,11 +175,9 @@ pub(in super::super::super) fn mul_mont_gather5_amm(
180175
);
181176
}
182177
let num_limbs = check_common_with_n(r, table, n)?;
183-
let a = a.as_flattened();
184178
if a.len() != num_limbs.get() {
185179
Err(LenMismatchError::new(a.len()))?;
186180
}
187-
let r = r.as_flattened_mut();
188181
let r = r.as_mut_ptr();
189182
let a = a.as_ptr();
190183
let table = table.as_flattened();
@@ -202,7 +195,7 @@ pub(in super::super::super) fn mul_mont_gather5_amm(
202195
// SAFETY: `power` must be less than 32.
203196
#[inline(always)]
204197
pub(in super::super::super) fn power5_amm(
205-
in_out: &mut [[Limb; 8]],
198+
in_out: &mut [Limb],
206199
table: &[[Limb; 8]],
207200
n: &[[Limb; 8]],
208201
n0: &N0,
@@ -234,7 +227,6 @@ pub(in super::super::super) fn power5_amm(
234227
);
235228
}
236229
let num_limbs = check_common_with_n(in_out, table, n)?;
237-
let in_out = in_out.as_flattened_mut();
238230
let r = in_out.as_mut_ptr();
239231
let a = in_out.as_ptr();
240232
let table = table.as_flattened();

src/arithmetic/limbs512/scatter.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use {
1010
// `table` has space for 32 entries the same size as `a`. Instead of storing
1111
// entries consecutively row-wise, instead store them column-wise.
1212
pub(in super::super::super) fn scatter5(
13-
a: &[[Limb; LIMBS_PER_CHUNK]],
13+
a: &[Limb],
1414
table: &mut [[Limb; LIMBS_PER_CHUNK]],
1515
i: LeakyWindow5,
1616
) -> Result<(), LimbSliceError> {
@@ -22,7 +22,7 @@ pub(in super::super::super) fn scatter5(
2222
.iter_mut()
2323
.skip(i)
2424
.step_by(32)
25-
.zip(a.as_flattened())
25+
.zip(a)
2626
.for_each(|(t, &a)| {
2727
*t = a;
2828
});

src/arithmetic/limbs512/storage.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,10 @@ impl<const N: usize> AlignedStorage<N> {
7171
// callers.
7272
#[inline(always)]
7373
pub(crate) fn check_common(
74-
a: &[[Limb; LIMBS_PER_CHUNK]],
74+
a: &[Limb],
7575
table: &[[Limb; LIMBS_PER_CHUNK]],
7676
) -> Result<NonZeroUsize, LimbSliceError> {
7777
assert_eq!((table.as_ptr() as usize) % 16, 0); // According to BoringSSL.
78-
let a = a.as_flattened();
7978
let table = table.as_flattened();
8079
let num_limbs = NonZeroUsize::new(a.len()).ok_or_else(|| LimbSliceError::too_short(a.len()))?;
8180
if num_limbs.get() > MAX_LIMBS {
@@ -90,7 +89,7 @@ pub(crate) fn check_common(
9089
#[cfg(target_arch = "x86_64")]
9190
#[inline(always)]
9291
pub(crate) fn check_common_with_n(
93-
a: &[[Limb; LIMBS_PER_CHUNK]],
92+
a: &[Limb],
9493
table: &[[Limb; LIMBS_PER_CHUNK]],
9594
n: &[[Limb; LIMBS_PER_CHUNK]],
9695
) -> Result<NonZeroUsize, LimbSliceError> {

0 commit comments

Comments
 (0)