From d31e8a499bdab8b70295580c288e5aa8d7f080dc Mon Sep 17 00:00:00 2001
From: Ezra Shaw <ezrasure@outlook.com>
Date: Sat, 22 Apr 2023 16:55:59 +1200
Subject: [PATCH] allow array-style simd in inline asm

---
 .../src/check/intrinsicck.rs                  | 38 ++++++++++++-------
 tests/assembly/asm/inline-asm-avx.rs          | 25 ++++++++++++
 2 files changed, 50 insertions(+), 13 deletions(-)
 create mode 100644 tests/assembly/asm/inline-asm-avx.rs

diff --git a/compiler/rustc_hir_analysis/src/check/intrinsicck.rs b/compiler/rustc_hir_analysis/src/check/intrinsicck.rs
index 0d482b53afef8..a28814681dbf6 100644
--- a/compiler/rustc_hir_analysis/src/check/intrinsicck.rs
+++ b/compiler/rustc_hir_analysis/src/check/intrinsicck.rs
@@ -84,33 +84,45 @@ impl<'a, 'tcx> InlineAsmCtxt<'a, 'tcx> {
             ty::Adt(adt, substs) if adt.repr().simd() => {
                 let fields = &adt.non_enum_variant().fields;
                 let elem_ty = fields[FieldIdx::from_u32(0)].ty(self.tcx, substs);
-                match elem_ty.kind() {
-                    ty::Never | ty::Error(_) => return None,
-                    ty::Int(IntTy::I8) | ty::Uint(UintTy::U8) => {
-                        Some(InlineAsmType::VecI8(fields.len() as u64))
+
+                let (size, ty) = match elem_ty.kind() {
+                    ty::Array(ty, len) => {
+                        if let Some(len) =
+                            len.try_eval_target_usize(self.tcx, self.tcx.param_env(adt.did()))
+                        {
+                            (len, *ty)
+                        } else {
+                            return None;
+                        }
                     }
+                    _ => (fields.len() as u64, elem_ty),
+                };
+
+                match ty.kind() {
+                    ty::Never | ty::Error(_) => return None,
+                    ty::Int(IntTy::I8) | ty::Uint(UintTy::U8) => Some(InlineAsmType::VecI8(size)),
                     ty::Int(IntTy::I16) | ty::Uint(UintTy::U16) => {
-                        Some(InlineAsmType::VecI16(fields.len() as u64))
+                        Some(InlineAsmType::VecI16(size))
                     }
                     ty::Int(IntTy::I32) | ty::Uint(UintTy::U32) => {
-                        Some(InlineAsmType::VecI32(fields.len() as u64))
+                        Some(InlineAsmType::VecI32(size))
                     }
                     ty::Int(IntTy::I64) | ty::Uint(UintTy::U64) => {
-                        Some(InlineAsmType::VecI64(fields.len() as u64))
+                        Some(InlineAsmType::VecI64(size))
                     }
                     ty::Int(IntTy::I128) | ty::Uint(UintTy::U128) => {
-                        Some(InlineAsmType::VecI128(fields.len() as u64))
+                        Some(InlineAsmType::VecI128(size))
                     }
                     ty::Int(IntTy::Isize) | ty::Uint(UintTy::Usize) => {
                         Some(match self.tcx.sess.target.pointer_width {
-                            16 => InlineAsmType::VecI16(fields.len() as u64),
-                            32 => InlineAsmType::VecI32(fields.len() as u64),
-                            64 => InlineAsmType::VecI64(fields.len() as u64),
+                            16 => InlineAsmType::VecI16(size),
+                            32 => InlineAsmType::VecI32(size),
+                            64 => InlineAsmType::VecI64(size),
                             _ => unreachable!(),
                         })
                     }
-                    ty::Float(FloatTy::F32) => Some(InlineAsmType::VecF32(fields.len() as u64)),
-                    ty::Float(FloatTy::F64) => Some(InlineAsmType::VecF64(fields.len() as u64)),
+                    ty::Float(FloatTy::F32) => Some(InlineAsmType::VecF32(size)),
+                    ty::Float(FloatTy::F64) => Some(InlineAsmType::VecF64(size)),
                     _ => None,
                 }
             }
diff --git a/tests/assembly/asm/inline-asm-avx.rs b/tests/assembly/asm/inline-asm-avx.rs
new file mode 100644
index 0000000000000..c2875f3e0a444
--- /dev/null
+++ b/tests/assembly/asm/inline-asm-avx.rs
@@ -0,0 +1,25 @@
+// assembly-output: emit-asm
+// compile-flags: --crate-type=lib
+// only-x86_64
+// ignore-sgx
+
+#![feature(portable_simd)]
+
+use std::simd::Simd;
+use std::arch::asm;
+
+#[target_feature(enable = "avx")]
+#[no_mangle]
+// CHECK-LABEL: convert:
+pub unsafe fn convert(a: *const f32) -> Simd<f32, 8> {
+    // CHECK: vbroadcastss (%{{[er][a-ds0-9][xpi0-9]?}}), {{%ymm[0-7]}}
+    let b: Simd<f32, 8>;
+    unsafe {
+        asm!(
+            "vbroadcastss {b}, [{a}]",
+            a = in(reg) a,
+            b = out(ymm_reg) b,
+        );
+    }
+    b
+}