Skip to content

Commit 7662148

Browse files
authored
Merge pull request #4912 from folkertdev/aarch64-pairwise-add
aarch64: add shims for pairwise widening/wrapping addition
2 parents a75b877 + 0fed718 commit 7662148

2 files changed

Lines changed: 179 additions & 0 deletions

File tree

src/shims/aarch64.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,93 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
5959
this.write_immediate(*res_lane, &dest)?;
6060
}
6161
}
62+
63+
// Wrapping pairwise addition.
64+
//
65+
// Concatenates the two input vectors and adds adjacent elements. For input vectors `v`
66+
// and `w` this computes `[v0 + v1, v2 + v3, ..., w0 + w1, w2 + w3, ...]`, using
67+
// wrapping addition for `+`.
68+
//
69+
// Used by `vpadd_{s8, u8, s16, u16, s32, u32}`.
70+
name if name.starts_with("neon.addp.") => {
71+
let [left, right] =
72+
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
73+
74+
let (left, left_len) = this.project_to_simd(left)?;
75+
let (right, right_len) = this.project_to_simd(right)?;
76+
let (dest, dest_len) = this.project_to_simd(dest)?;
77+
78+
assert_eq!(left_len, right_len);
79+
assert_eq!(left_len, dest_len);
80+
81+
assert_eq!(left.layout, right.layout);
82+
assert_eq!(left.layout, dest.layout);
83+
84+
assert!(dest_len.is_multiple_of(2));
85+
let half_len = dest_len.strict_div(2);
86+
87+
for lane_idx in 0..dest_len {
88+
// The left and right vectors are concatenated.
89+
let (src, src_pair_idx) = if lane_idx < half_len {
90+
(&left, lane_idx)
91+
} else {
92+
(&right, lane_idx.strict_sub(half_len))
93+
};
94+
// Convert "pair index" into "index of first element of the pair".
95+
let i = src_pair_idx.strict_mul(2);
96+
97+
let lhs = this.read_immediate(&this.project_index(src, i)?)?;
98+
let rhs = this.read_immediate(&this.project_index(src, i.strict_add(1))?)?;
99+
100+
// Wrapping addition on the element type.
101+
let sum = this.binary_op(BinOp::Add, &lhs, &rhs)?;
102+
103+
let dst_lane = this.project_index(&dest, lane_idx)?;
104+
this.write_immediate(*sum, &dst_lane)?;
105+
}
106+
}
107+
108+
// Widening pairwise addition.
109+
//
110+
// Takes a single input vector, and an output vector with half as many lanes and double
111+
// the element width. Takes adjacent pairs of elements, widens both, and then adds them
112+
// together.
113+
//
114+
// Used by `vpaddl_{u8, u16, u32}` and `vpaddlq_{u8, u16, u32}`.
115+
name if name.starts_with("neon.uaddlp.") => {
116+
let [src] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
117+
118+
let (src, src_len) = this.project_to_simd(src)?;
119+
let (dest, dest_len) = this.project_to_simd(dest)?;
120+
121+
// Operates pairwise, so src has twice as many lanes.
122+
assert_eq!(src_len, dest_len.strict_mul(2));
123+
124+
let src_elem_size = src.layout.field(this, 0).size;
125+
let dest_elem_size = dest.layout.field(this, 0).size;
126+
127+
// Widens, so dest elements must be exactly twice as wide.
128+
assert_eq!(dest_elem_size.bytes(), src_elem_size.bytes().strict_mul(2));
129+
130+
for dest_idx in 0..dest_len {
131+
let src_idx = dest_idx.strict_mul(2);
132+
133+
let a_scalar = this.read_scalar(&this.project_index(&src, src_idx)?)?;
134+
let b_scalar =
135+
this.read_scalar(&this.project_index(&src, src_idx.strict_add(1))?)?;
136+
137+
let a_val = a_scalar.to_uint(src_elem_size)?;
138+
let b_val = b_scalar.to_uint(src_elem_size)?;
139+
140+
// Use addition on u128 to simulate widening addition for the destination type.
141+
// This cannot wrap since the element type is at most u64.
142+
let sum = a_val.strict_add(b_val);
143+
144+
let dst_lane = this.project_index(&dest, dest_idx)?;
145+
this.write_scalar(Scalar::from_uint(sum, dest_elem_size), &dst_lane)?;
146+
}
147+
}
148+
62149
// Vector table lookup: each index selects a byte from the 16-byte table, out-of-range -> 0.
63150
// Used to implement the `vtbl1_u8` function.
64151
// LLVM does not have a portable shuffle that takes non-const indices

tests/pass/shims/aarch64/intrinsics-aarch64-neon.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ fn main() {
1212
unsafe {
1313
test_vpmaxq_u8();
1414
test_tbl1_v16i8_basic();
15+
test_vpadd();
16+
test_vpaddl();
1517
}
1618
}
1719

@@ -65,3 +67,93 @@ fn test_tbl1_v16i8_basic() {
6567
assert_eq!(&got2_arr[3..16], &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12][..]);
6668
}
6769
}
70+
#[target_feature(enable = "neon")]
71+
unsafe fn test_vpadd() {
72+
let a = vld1_s8([1, 2, 3, 4, 5, 6, 7, 8].as_ptr());
73+
let b = vld1_s8([9, 10, -1, 2, i8::MIN, i8::MIN, i8::MAX, i8::MAX].as_ptr());
74+
let e =
75+
[3i8, 7, 11, 15, 19, -1 + 2, i8::MIN.wrapping_add(i8::MIN), i8::MAX.wrapping_add(i8::MAX)];
76+
let mut r = [0i8; 8];
77+
vst1_s8(r.as_mut_ptr(), vpadd_s8(a, b));
78+
assert_eq!(r, e);
79+
80+
let a = vld1_s16([1, 2, 3, 4].as_ptr());
81+
let b = vld1_s16([-1, 2, i16::MAX, i16::MAX].as_ptr());
82+
let e = [3i16, 7, -1 + 2, i16::MAX.wrapping_add(i16::MAX)];
83+
let mut r = [0i16; 4];
84+
vst1_s16(r.as_mut_ptr(), vpadd_s16(a, b));
85+
assert_eq!(r, e);
86+
87+
let a = vld1_s32([1, 2].as_ptr());
88+
let b = vld1_s32([i32::MAX, i32::MAX].as_ptr());
89+
let e = [3i32, i32::MAX.wrapping_add(i32::MAX)];
90+
let mut r = [0i32; 2];
91+
vst1_s32(r.as_mut_ptr(), vpadd_s32(a, b));
92+
assert_eq!(r, e);
93+
94+
let a = vld1_u8([1, 2, 3, 4, 5, 6, 7, 8].as_ptr());
95+
let b = vld1_u8([9, 10, 11, 12, 13, 14, u8::MAX, u8::MAX].as_ptr());
96+
let e = [3u8, 7, 11, 15, 19, 23, 27, 254];
97+
let mut r = [0u8; 8];
98+
vst1_u8(r.as_mut_ptr(), vpadd_u8(a, b));
99+
assert_eq!(r, e);
100+
101+
let a = vld1_u16([1, 2, 3, 4].as_ptr());
102+
let b = vld1_u16([5, 6, u16::MAX, u16::MAX].as_ptr());
103+
let e = [3u16, 7, 11, 65534];
104+
let mut r = [0u16; 4];
105+
vst1_u16(r.as_mut_ptr(), vpadd_u16(a, b));
106+
assert_eq!(r, e);
107+
108+
let a = vld1_u32([1, 2].as_ptr());
109+
let b = vld1_u32([u32::MAX, u32::MAX].as_ptr());
110+
let e = [3u32, u32::MAX.wrapping_add(u32::MAX)];
111+
let mut r = [0u32; 2];
112+
vst1_u32(r.as_mut_ptr(), vpadd_u32(a, b));
113+
assert_eq!(r, e);
114+
}
115+
116+
#[target_feature(enable = "neon")]
117+
unsafe fn test_vpaddl() {
118+
let a = vld1_u8([1, 2, 3, 4, 5, 6, u8::MAX, u8::MAX].as_ptr());
119+
let e = [3u16, 7, 11, 510];
120+
let mut r = [0u16; 4];
121+
vst1_u16(r.as_mut_ptr(), vpaddl_u8(a));
122+
assert_eq!(r, e);
123+
124+
let a = vld1q_u8([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, u8::MAX, u8::MAX].as_ptr());
125+
let e = [3u16, 7, 11, 15, 19, 23, 27, 510];
126+
let mut r = [0u16; 8];
127+
vst1q_u16(r.as_mut_ptr(), vpaddlq_u8(a));
128+
assert_eq!(r, e);
129+
130+
let a = vld1_u16([1, 2, u16::MAX, u16::MAX].as_ptr());
131+
let e = [3u32, 131070];
132+
let mut r = [0u32; 2];
133+
vst1_u32(r.as_mut_ptr(), vpaddl_u16(a));
134+
assert_eq!(r, e);
135+
136+
let a = vld1q_u16([1, 2, 3, 4, 5, 6, u16::MAX, u16::MAX].as_ptr());
137+
let e = [3u32, 7, 11, 131070];
138+
let mut r = [0u32; 4];
139+
vst1q_u32(r.as_mut_ptr(), vpaddlq_u16(a));
140+
assert_eq!(r, e);
141+
142+
let a = vld1_u32([1, 2].as_ptr());
143+
let e = [3u64];
144+
let mut r = [0u64; 1];
145+
vst1_u64(r.as_mut_ptr(), vpaddl_u32(a));
146+
assert_eq!(r, e);
147+
148+
let a = vld1_u32([u32::MAX, u32::MAX].as_ptr());
149+
let e = [8589934590];
150+
let mut r = [0u64; 1];
151+
vst1_u64(r.as_mut_ptr(), vpaddl_u32(a));
152+
assert_eq!(r, e);
153+
154+
let a = vld1q_u32([1, 2, u32::MAX, u32::MAX].as_ptr());
155+
let e = [3u64, 8589934590];
156+
let mut r = [0u64; 2];
157+
vst1q_u64(r.as_mut_ptr(), vpaddlq_u32(a));
158+
assert_eq!(r, e);
159+
}

0 commit comments

Comments
 (0)