Skip to content

Fix codegen bug in "ptx-kernel" abi related to arg passing #94703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2568,6 +2568,22 @@ where

pointee_info
}

fn is_adt(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Adt(..))
}

fn is_never(this: TyAndLayout<'tcx>) -> bool {
this.ty.kind() == &ty::Never
}

fn is_tuple(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Tuple(..))
}

fn is_unit(this: TyAndLayout<'tcx>) -> bool {
matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0)
}
}

impl<'tcx> ty::Instance<'tcx> {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ impl<T> List<T> {
static EMPTY_SLICE: InOrder<usize, MaxAlign> = InOrder(0, MaxAlign);
unsafe { &*(&EMPTY_SLICE as *const _ as *const List<T>) }
}

pub fn len(&self) -> usize {
self.len
}
}

impl<T: Copy> List<T> {
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_target/src/abi/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> {
"sparc" => sparc::compute_abi_info(cx, self),
"sparc64" => sparc64::compute_abi_info(cx, self),
"nvptx" => nvptx::compute_abi_info(self),
"nvptx64" => nvptx64::compute_abi_info(self),
"nvptx64" => {
if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel {
nvptx64::compute_ptx_kernel_abi_info(cx, self)
} else {
nvptx64::compute_abi_info(self)
}
}
"hexagon" => hexagon::compute_abi_info(self),
"riscv32" | "riscv64" => riscv::compute_abi_info(cx, self),
"wasm32" | "wasm64" => {
Expand Down
47 changes: 39 additions & 8 deletions compiler/rustc_target/src/abi/call/nvptx64.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
// Reference: PTX Writer's Guide to Interoperability
// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability

use crate::abi::call::{ArgAbi, FnAbi};
use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
use crate::abi::{HasDataLayout, TyAbiInterface};

fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 {
ret.make_indirect();
} else {
ret.extend_integer_width_to(64);
}
}

fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 {
arg.make_indirect();
} else {
arg.extend_integer_width_to(64);
}
}

fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
let align_bytes = arg.layout.align.abi.bytes();

let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) });
}
}

Expand All @@ -31,3 +45,20 @@ pub fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
classify_arg(arg);
}
}

pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
panic!("Kernels should not return anything other than () or !");
}

for arg in &mut fn_abi.args {
if arg.is_ignore() {
continue;
}
classify_arg_kernel(cx, arg);
}
}
32 changes: 32 additions & 0 deletions compiler/rustc_target/src/abi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,10 @@ pub trait TyAbiInterface<'a, C>: Sized {
cx: &C,
offset: Size,
) -> Option<PointeeInfo>;
fn is_adt(this: TyAndLayout<'a, Self>) -> bool;
fn is_never(this: TyAndLayout<'a, Self>) -> bool;
fn is_tuple(this: TyAndLayout<'a, Self>) -> bool;
fn is_unit(this: TyAndLayout<'a, Self>) -> bool;
}

impl<'a, Ty> TyAndLayout<'a, Ty> {
Expand Down Expand Up @@ -1291,6 +1295,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> {
_ => false,
}
}

pub fn is_adt<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_adt(self)
}

pub fn is_never<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_never(self)
}

pub fn is_tuple<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_tuple(self)
}

pub fn is_unit<C>(self) -> bool
where
Ty: TyAbiInterface<'a, C>,
{
Ty::is_unit(self)
}
}

impl<'a, Ty> TyAndLayout<'a, Ty> {
Expand Down
Loading