Skip to content

Commit cccf996

Browse files
Rollup merge of #155264 - sayantn:amx-autocast, r=dianqk
Add autocast support for `x86amx`

Builds on #140763 by further adding autocasts for `x86amx` from/to vectors of size 8192 bits.

This also disables SIMD vector ABI checks for the `"unadjusted"` ABI, because:

- This ABI is primarily used to link with LLVM intrinsics, which don't actually lower to function calls with vector arguments. This is true even with other codegen backends.
- This ABI is internal and perma-unstable (and also super specific), so it is very unlikely that this will cause breakages.
- (The primary reason) Without doing this, we can't actually use 8192-bit-long vectors to represent `x86amx`.

> Why do we need a bypass for `x86amx`? Can't we use a `#[lang_item]` or something?

If `x86amx` were a normal LLVM type, this approach would've worked, and I would also prefer it. But LLVM specifies that

> No instruction is allowed for this type. There are no arguments, arrays, pointers, vectors or constants of this type.

So we can't treat it like a normal type at all -- even if we add it like a lang-item, we would still have to special-case everywhere to check if we are passing to the correct LLVM intrinsic, and only then use the `x86amx` type. IMO this is needlessly complex, and way worse than this solution, which just adds it to the autocast list in `cg_llvm`.

r? codegen
2 parents ce91732 + 7e24cd8 commit cccf996

3 files changed

Lines changed: 62 additions & 4 deletions

File tree

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1033,6 +1033,24 @@ fn can_autocast<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll
10331033
}
10341034
}
10351035
TypeKind::BFloat => rust_ty == cx.type_i16(),
1036+
TypeKind::X86_AMX if cx.type_kind(rust_ty) == TypeKind::Vector => {
1037+
let element_ty = cx.element_type(rust_ty);
1038+
let element_count = cx.vector_length(rust_ty) as u64;
1039+
1040+
let element_size_bits = match cx.type_kind(element_ty) {
1041+
TypeKind::Half => 16,
1042+
TypeKind::Float => 32,
1043+
TypeKind::Double => 64,
1044+
TypeKind::FP128 => 128,
1045+
TypeKind::Integer => cx.int_width(element_ty),
1046+
TypeKind::Pointer => cx.int_width(cx.isize_ty),
1047+
_ => bug!(
1048+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
1049+
),
1050+
};
1051+
1052+
element_size_bits * element_count == 8192
1053+
}
10361054
_ => false,
10371055
}
10381056
}
@@ -1102,6 +1120,12 @@ fn autocast<'ll>(
11021120
)
11031121
}
11041122
}
1123+
(TypeKind::Vector, TypeKind::X86_AMX) => {
1124+
bx.call_intrinsic("llvm.x86.cast.vector.to.tile", &[src_ty], &[val])
1125+
}
1126+
(TypeKind::X86_AMX, TypeKind::Vector) => {
1127+
bx.call_intrinsic("llvm.x86.cast.tile.to.vector", &[dest_ty], &[val])
1128+
}
11051129
_ => bx.bitcast(val, dest_ty), // for `bf16(xN)` <-> `u16(xN)`
11061130
}
11071131
}

compiler/rustc_monomorphize/src/mono_checks/abi_check.rs

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
//! This module ensures that if a function's ABI requires a particular target feature,
22
//! that target feature is enabled both on the callee and all callers.
3-
use rustc_abi::{BackendRepr, CanonAbi, RegKind, X86Call};
3+
use rustc_abi::{BackendRepr, CanonAbi, ExternAbi, RegKind, X86Call};
44
use rustc_hir::{CRATE_HIR_ID, HirId};
55
use rustc_middle::mir::{self, Location, traversal};
66
use rustc_middle::ty::{self, Instance, InstanceKind, Ty, TyCtxt};
@@ -160,6 +160,12 @@ fn do_check_unsized_params<'tcx>(
160160
/// - the signature requires target features that are not enabled
161161
fn check_instance_abi<'tcx>(tcx: TyCtxt<'tcx>, instance: Instance<'tcx>) {
162162
let typing_env = ty::TypingEnv::fully_monomorphized();
163+
let ty = instance.ty(tcx, typing_env);
164+
if ty.is_fn() && ty.fn_sig(tcx).abi() == ExternAbi::Unadjusted {
165+
// We disable all checks for the unadjusted ABI to allow linking to arbitrary LLVM
166+
// intrinsics
167+
return;
168+
}
163169
let Ok(abi) = tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
164170
else {
165171
// An error will be reported during codegen if we cannot determine the ABI of this
@@ -194,9 +200,12 @@ fn check_call_site_abi<'tcx>(
194200
caller: InstanceKind<'tcx>,
195201
loc: impl Fn() -> (Span, HirId) + Copy,
196202
) {
197-
if callee.fn_sig(tcx).abi().is_rustic_abi() {
203+
let extern_abi = callee.fn_sig(tcx).abi();
204+
if extern_abi.is_rustic_abi() || extern_abi == ExternAbi::Unadjusted {
198205
// We directly handle the soundness of Rust ABIs -- so let's skip the majority of
199206
// call sites to avoid a perf regression.
207+
// We disable all checks for the unadjusted ABI to allow linking to arbitrary LLVM
208+
// intrinsics
200209
return;
201210
}
202211
let typing_env = ty::TypingEnv::fully_monomorphized();

tests/codegen-llvm/inject-autocast.rs

Lines changed: 27 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
1-
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avxneconvert
1+
//@ compile-flags: -C opt-level=0 -C target-feature=+kl,+avx512vp2intersect,+avx512vl,+avx512dq,+avxneconvert,+amx-int8
22
//@ only-x86_64
33

4-
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd)]
4+
#![feature(link_llvm_intrinsics, abi_unadjusted, simd_ffi, portable_simd, repr_simd)]
55
#![crate_type = "lib"]
66

77
use std::simd::{f32x4, i16x8, i64x2};
@@ -10,6 +10,9 @@ use std::simd::{f32x4, i16x8, i64x2};
1010
pub struct Bar(u32, i64x2, i64x2, i64x2, i64x2, i64x2, i64x2);
1111
// CHECK: %Bar = type <{ i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> }>
1212

13+
#[repr(simd)]
14+
pub struct Tile([i8; 1024]);
15+
1316
// CHECK-LABEL: @struct_autocast
1417
#[no_mangle]
1518
pub unsafe fn struct_autocast(key_metadata: u32, key: i64x2) -> Bar {
@@ -84,10 +87,32 @@ pub unsafe fn bf16_vector_autocast(a: f32x4) -> i16x8 {
8487
foo(a)
8588
}
8689

90+
// CHECK-LABEL: @amx_autocast
91+
#[no_mangle]
92+
pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile {
93+
extern "unadjusted" {
94+
#[link_name = "llvm.x86.tdpbuud.internal"]
95+
fn foo(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) -> Tile;
96+
}
97+
98+
// CHECK: [[A:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
99+
// CHECK: [[B:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
100+
// CHECK: [[C:%[0-9]+]] = call x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8> {{.*}})
101+
// CHECK: [[D:%[0-9]+]] = call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx [[A]], x86_amx [[B]], x86_amx [[C]])
102+
// CHECK: call <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx [[D]])
103+
foo(m, n, k, a, b, c)
104+
}
105+
87106
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
88107

89108
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
90109

91110
// CHECK: declare <8 x i1> @llvm.x86.avx512.kadd.b(<8 x i1>, <8 x i1>)
92111

93112
// CHECK: declare <8 x bfloat> @llvm.x86.vcvtneps2bf16128(<4 x float>)
113+
114+
// CHECK: declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
115+
116+
// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)
117+
118+
// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)

0 commit comments

Comments
 (0)