Skip to content

[NVPTX] Combine addressing-mode variants of ld, st, wmma #129102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
569 changes: 159 additions & 410 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Large diffs are not rendered by default.

12 changes: 1 addition & 11 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
}

// Match direct address complex pattern.
bool SelectDirectAddr(SDValue N, SDValue &Address);

void SelectADDRri_imp(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset, MVT VT);
bool SelectADDRri(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
bool SelectADDRri64(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
bool SelectADDRsi(SDNode *OpNode, SDValue Addr, SDValue &Base,
SDValue &Offset);
bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);

bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;

Expand Down
190 changes: 51 additions & 139 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1917,27 +1917,15 @@ defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
// Data Movement (Load / Store, Move)
//-----------------------------------

let WantsRoot = true in {
def ADDRri : ComplexPattern<i32, 2, "SelectADDRri", [frameindex]>;
def ADDRri64 : ComplexPattern<i64, 2, "SelectADDRri64", [frameindex]>;
}
def ADDRvar : ComplexPattern<iPTR, 1, "SelectDirectAddr", [], []>;
def addr : ComplexPattern<pAny, 2, "SelectADDR">;

def MEMri : Operand<i32> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops Int32Regs, i32imm);
}
def MEMri64 : Operand<i64> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops Int64Regs, i64imm);
}

def imem : Operand<iPTR> {
def ADDR_base : Operand<pAny> {
let PrintMethod = "printOperand";
}

def imemAny : Operand<pAny> {
let PrintMethod = "printOperand";
def ADDR : Operand<pAny> {
let PrintMethod = "printMemOperand";
let MIOperandInfo = (ops ADDR_base, i32imm);
}

def LdStCode : Operand<i32> {
Expand All @@ -1956,10 +1944,10 @@ def SDTWrapper : SDTypeProfile<1, 1, [SDTCisSameAs<0, 1>, SDTCisPtrTy<0>]>;
def Wrapper : SDNode<"NVPTXISD::Wrapper", SDTWrapper>;

// Load a memory address into a u32 or u64 register.
def MOV_ADDR : NVPTXInst<(outs Int32Regs:$dst), (ins imem:$a),
def MOV_ADDR : NVPTXInst<(outs Int32Regs:$dst), (ins ADDR_base:$a),
"mov.u32 \t$dst, $a;",
[(set i32:$dst, (Wrapper tglobaladdr:$a))]>;
def MOV_ADDR64 : NVPTXInst<(outs Int64Regs:$dst), (ins imem:$a),
def MOV_ADDR64 : NVPTXInst<(outs Int64Regs:$dst), (ins ADDR_base:$a),
"mov.u64 \t$dst, $a;",
[(set i64:$dst, (Wrapper tglobaladdr:$a))]>;

Expand Down Expand Up @@ -2021,12 +2009,17 @@ def : Pat<(i32 (Wrapper texternalsym:$dst)), (IMOV32ri texternalsym:$dst)>;
def : Pat<(i64 (Wrapper texternalsym:$dst)), (IMOV64ri texternalsym:$dst)>;

//---- Copy Frame Index ----
def LEA_ADDRi : NVPTXInst<(outs Int32Regs:$dst), (ins MEMri:$addr),
"add.u32 \t$dst, ${addr:add};",
[(set i32:$dst, ADDRri:$addr)]>;
def LEA_ADDRi64 : NVPTXInst<(outs Int64Regs:$dst), (ins MEMri64:$addr),
"add.u64 \t$dst, ${addr:add};",
[(set i64:$dst, ADDRri64:$addr)]>;
def LEA_ADDRi : NVPTXInst<(outs Int32Regs:$dst), (ins ADDR:$addr),
"add.u32 \t$dst, ${addr:add};", []>;
def LEA_ADDRi64 : NVPTXInst<(outs Int64Regs:$dst), (ins ADDR:$addr),
"add.u64 \t$dst, ${addr:add};", []>;

def to_tframeindex : SDNodeXForm<frameindex, [{
return CurDAG->getTargetFrameIndex(N->getIndex(), N->getValueType(0));
}]>;

def : Pat<(i32 frameindex:$fi), (LEA_ADDRi (to_tframeindex $fi), 0)>;
def : Pat<(i64 frameindex:$fi), (LEA_ADDRi64 (to_tframeindex $fi), 0)>;

//-----------------------------------
// Comparison and Selection
Expand Down Expand Up @@ -2660,7 +2653,7 @@ def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ",
def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a",
[(LastCallArg (i32 1), (i32 imm:$a))]>;

def CallVoidInst : NVPTXInst<(outs), (ins imem:$addr), "$addr, ",
def CallVoidInst : NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ",
[(CallVoid (Wrapper tglobaladdr:$addr))]>;
def CallVoidInstReg : NVPTXInst<(outs), (ins Int32Regs:$addr), "$addr, ",
[(CallVoid i32:$addr)]>;
Expand Down Expand Up @@ -2753,109 +2746,56 @@ foreach vt = [v2f16, v2bf16, v2i16, v4i8] in {
//
// Load / Store Handling
//
multiclass LD<NVPTXRegClass regclass> {
def _ari : NVPTXInst<
class LD<NVPTXRegClass regclass>
: NVPTXInst<
(outs regclass:$dst),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign,
i32imm:$fromWidth, Int32Regs:$addr, Offseti32imm:$offset),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t$dst, [$addr$offset];", []>;
def _ari_64 : NVPTXInst<
(outs regclass:$dst),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int64Regs:$addr, Offseti32imm:$offset),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t$dst, [$addr$offset];", []>;
def _asi : NVPTXInst<
(outs regclass:$dst),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, imem:$addr, Offseti32imm:$offset),
i32imm:$fromWidth, ADDR:$addr),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t$dst, [$addr$offset];", []>;
}
"\t$dst, [$addr];", []>;

let mayLoad=1, hasSideEffects=0 in {
defm LD_i8 : LD<Int16Regs>;
defm LD_i16 : LD<Int16Regs>;
defm LD_i32 : LD<Int32Regs>;
defm LD_i64 : LD<Int64Regs>;
defm LD_f32 : LD<Float32Regs>;
defm LD_f64 : LD<Float64Regs>;
def LD_i8 : LD<Int16Regs>;
def LD_i16 : LD<Int16Regs>;
def LD_i32 : LD<Int32Regs>;
def LD_i64 : LD<Int64Regs>;
def LD_f32 : LD<Float32Regs>;
def LD_f64 : LD<Float64Regs>;
}

multiclass ST<NVPTXRegClass regclass> {
def _ari : NVPTXInst<
class ST<NVPTXRegClass regclass>
: NVPTXInst<
(outs),
(ins regclass:$src, LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
LdStCode:$Vec, LdStCode:$Sign, i32imm:$toWidth, Int32Regs:$addr,
Offseti32imm:$offset),
LdStCode:$Vec, LdStCode:$Sign, i32imm:$toWidth, ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$toWidth"
" \t[$addr$offset], $src;", []>;
def _ari_64 : NVPTXInst<
(outs),
(ins regclass:$src, LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
LdStCode:$Vec, LdStCode:$Sign, i32imm:$toWidth, Int64Regs:$addr,
Offseti32imm:$offset),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$toWidth"
" \t[$addr$offset], $src;", []>;
def _asi : NVPTXInst<
(outs),
(ins regclass:$src, LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp,
LdStCode:$Vec, LdStCode:$Sign, i32imm:$toWidth, imem:$addr,
Offseti32imm:$offset),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$toWidth"
" \t[$addr$offset], $src;", []>;
}
" \t[$addr], $src;", []>;

let mayStore=1, hasSideEffects=0 in {
defm ST_i8 : ST<Int16Regs>;
defm ST_i16 : ST<Int16Regs>;
defm ST_i32 : ST<Int32Regs>;
defm ST_i64 : ST<Int64Regs>;
defm ST_f32 : ST<Float32Regs>;
defm ST_f64 : ST<Float64Regs>;
def ST_i8 : ST<Int16Regs>;
def ST_i16 : ST<Int16Regs>;
def ST_i32 : ST<Int32Regs>;
def ST_i64 : ST<Int64Regs>;
def ST_f32 : ST<Float32Regs>;
def ST_f64 : ST<Float64Regs>;
}

// The following is used only in and after vector elementizations. Vector
// elementization happens at the machine instruction level, so the following
// instructions never appear in the DAG.
multiclass LD_VEC<NVPTXRegClass regclass> {
def _v2_ari : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int32Regs:$addr, Offseti32imm:$offset),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2}}, [$addr$offset];", []>;
def _v2_ari_64 : NVPTXInst<
def _v2 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int64Regs:$addr, Offseti32imm:$offset),
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2}}, [$addr$offset];", []>;
def _v2_asi : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, imem:$addr, Offseti32imm:$offset),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2}}, [$addr$offset];", []>;
def _v4_ari : NVPTXInst<
"\t{{$dst1, $dst2}}, [$addr];", []>;
def _v4 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int32Regs:$addr, Offseti32imm:$offset),
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr$offset];", []>;
def _v4_ari_64 : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int64Regs:$addr, Offseti32imm:$offset),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr$offset];", []>;
def _v4_asi : NVPTXInst<
(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
(ins LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, imem:$addr, Offseti32imm:$offset),
"ld${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr$offset];", []>;
"\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];", []>;
}
let mayLoad=1, hasSideEffects=0 in {
defm LDV_i8 : LD_VEC<Int16Regs>;
Expand All @@ -2867,48 +2807,20 @@ let mayLoad=1, hasSideEffects=0 in {
}

multiclass ST_VEC<NVPTXRegClass regclass> {
def _v2_ari : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign, i32imm:$fromWidth,
Int32Regs:$addr, Offseti32imm:$offset),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr$offset], {{$src1, $src2}};", []>;
def _v2_ari_64 : NVPTXInst<
def _v2 : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign, i32imm:$fromWidth,
Int64Regs:$addr, Offseti32imm:$offset),
ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr$offset], {{$src1, $src2}};", []>;
def _v2_asi : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, LdStCode:$sem, LdStCode:$scope,
LdStCode:$addsp, LdStCode:$Vec, LdStCode:$Sign, i32imm:$fromWidth,
imem:$addr, Offseti32imm:$offset),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr$offset], {{$src1, $src2}};", []>;
def _v4_ari : NVPTXInst<
"\t[$addr], {{$src1, $src2}};", []>;
def _v4 : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int32Regs:$addr, Offseti32imm:$offset),
LdStCode:$Sign, i32imm:$fromWidth, ADDR:$addr),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr$offset], {{$src1, $src2, $src3, $src4}};", []>;
def _v4_ari_64 : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, Int64Regs:$addr, Offseti32imm:$offset),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}$fromWidth "
"\t[$addr$offset], {{$src1, $src2, $src3, $src4}};", []>;
def _v4_asi : NVPTXInst<
(outs),
(ins regclass:$src1, regclass:$src2, regclass:$src3, regclass:$src4,
LdStCode:$sem, LdStCode:$scope, LdStCode:$addsp, LdStCode:$Vec,
LdStCode:$Sign, i32imm:$fromWidth, imem:$addr, Offseti32imm:$offset),
"st${sem:sem}${scope:scope}${addsp:addsp}${Vec:vec}.${Sign:sign}"
"$fromWidth \t[$addr$offset], {{$src1, $src2, $src3, $src4}};", []>;
"\t[$addr], {{$src1, $src2, $src3, $src4}};", []>;
}

let mayStore=1, hasSideEffects=0 in {
Expand Down
Loading
Loading