Skip to content

[NVPTX] Combine addressing-mode variants of ld, st, wmma #129102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

AlexMaclean
Copy link
Member

This change fold together the _ari, _ari64, and _asi variants of these instructions into a single instruction capable of holding any address. This allows for the removal of a lot of unnecessary code and moves us towards a standard way of representing an address in NVPTX.

@llvmbot
Copy link
Member

llvmbot commented Feb 27, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

This change fold together the _ari, _ari64, and _asi variants of these instructions into a single instruction capable of holding any address. This allows for the removal of a lot of unnecessary code and moves us towards a standard way of representing an address in NVPTX.


Patch is 58.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129102.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+157-410)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1-11)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+51-139)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+78-128)
  • (modified) llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp (+1-1)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 971a128aadfdb..08022104bfedf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -930,8 +930,6 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   if (canLowerToLDG(LD, *Subtarget, CodeAddrSpace, MF)) {
     return tryLDGLDU(N);
   }
-  unsigned int PointerSize =
-      CurDAG->getDataLayout().getPointerSizeInBits(LD->getAddressSpace());
 
   SDLoc DL(N);
   SDValue Chain = N->getOperand(0);
@@ -964,37 +962,24 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
     FromType = getLdStRegType(ScalarVT);
 
   // Create the machine instruction DAG
-  SDValue N1 = N->getOperand(1);
   SDValue Offset, Base;
-  std::optional<unsigned> Opcode;
-  MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
-
-  SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-                                getI32Imm(CodeAddrSpace, DL),
-                                getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                                getI32Imm(FromTypeWidth, DL)});
-
-  if (SelectADDRsi(N1.getNode(), N1, Base, Offset)) {
-    Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_asi, NVPTX::LD_i16_asi,
-                             NVPTX::LD_i32_asi, NVPTX::LD_i64_asi,
-                             NVPTX::LD_f32_asi, NVPTX::LD_f64_asi);
-  } else {
-    if (PointerSize == 64) {
-      SelectADDRri64(N1.getNode(), N1, Base, Offset);
-      Opcode =
-          pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari_64, NVPTX::LD_i16_ari_64,
-                          NVPTX::LD_i32_ari_64, NVPTX::LD_i64_ari_64,
-                          NVPTX::LD_f32_ari_64, NVPTX::LD_f64_ari_64);
-    } else {
-      SelectADDRri(N1.getNode(), N1, Base, Offset);
-      Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari, NVPTX::LD_i16_ari,
-                               NVPTX::LD_i32_ari, NVPTX::LD_i64_ari,
-                               NVPTX::LD_f32_ari, NVPTX::LD_f64_ari);
-    }
-  }
+  SelectADDR(N->getOperand(1), Base, Offset);
+  SDValue Ops[] = {getI32Imm(Ordering, DL),
+                   getI32Imm(Scope, DL),
+                   getI32Imm(CodeAddrSpace, DL),
+                   getI32Imm(VecType, DL),
+                   getI32Imm(FromType, DL),
+                   getI32Imm(FromTypeWidth, DL),
+                   Base,
+                   Offset,
+                   Chain};
+
+  const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
+  const std::optional<unsigned> Opcode =
+      pickOpcodeForVT(TargetVT, NVPTX::LD_i8, NVPTX::LD_i16, NVPTX::LD_i32,
+                      NVPTX::LD_i64, NVPTX::LD_f32, NVPTX::LD_f64);
   if (!Opcode)
     return false;
-  Ops.append({Base, Offset, Chain});
 
   SDNode *NVPTXLD =
       CurDAG->getMachineNode(*Opcode, DL, TargetVT, MVT::Other, Ops);
@@ -1030,8 +1015,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
     return tryLDGLDU(N);
   }
-  unsigned int PointerSize =
-      CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());
 
   SDLoc DL(N);
   SDValue Chain = N->getOperand(0);
@@ -1079,77 +1062,38 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     FromTypeWidth = 32;
   }
 
-  SDValue Op1 = N->getOperand(1);
   SDValue Offset, Base;
-  std::optional<unsigned> Opcode;
-  SDNode *LD;
+  SelectADDR(N->getOperand(1), Base, Offset);
+  SDValue Ops[] = {getI32Imm(Ordering, DL),
+                   getI32Imm(Scope, DL),
+                   getI32Imm(CodeAddrSpace, DL),
+                   getI32Imm(VecType, DL),
+                   getI32Imm(FromType, DL),
+                   getI32Imm(FromTypeWidth, DL),
+                   Base,
+                   Offset,
+                   Chain};
 
-  SmallVector<SDValue, 12> Ops({getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-                                getI32Imm(CodeAddrSpace, DL),
-                                getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                                getI32Imm(FromTypeWidth, DL)});
-
-  if (SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
-    switch (N->getOpcode()) {
-    default:
-      return false;
-    case NVPTXISD::LoadV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::LDV_i8_v2_asi, NVPTX::LDV_i16_v2_asi,
-                               NVPTX::LDV_i32_v2_asi, NVPTX::LDV_i64_v2_asi,
-                               NVPTX::LDV_f32_v2_asi, NVPTX::LDV_f64_v2_asi);
-      break;
-    case NVPTXISD::LoadV4:
-      Opcode =
-          pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_asi,
-                          NVPTX::LDV_i16_v4_asi, NVPTX::LDV_i32_v4_asi,
-                          std::nullopt, NVPTX::LDV_f32_v4_asi, std::nullopt);
-      break;
-    }
-  } else {
-    if (PointerSize == 64) {
-      SelectADDRri64(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case NVPTXISD::LoadV2:
-        Opcode =
-            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                            NVPTX::LDV_i8_v2_ari_64, NVPTX::LDV_i16_v2_ari_64,
-                            NVPTX::LDV_i32_v2_ari_64, NVPTX::LDV_i64_v2_ari_64,
-                            NVPTX::LDV_f32_v2_ari_64, NVPTX::LDV_f64_v2_ari_64);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari_64,
-            NVPTX::LDV_i16_v4_ari_64, NVPTX::LDV_i32_v4_ari_64, std::nullopt,
-            NVPTX::LDV_f32_v4_ari_64, std::nullopt);
-        break;
-      }
-    } else {
-      SelectADDRri(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::LDV_i8_v2_ari, NVPTX::LDV_i16_v2_ari,
-                                 NVPTX::LDV_i32_v2_ari, NVPTX::LDV_i64_v2_ari,
-                                 NVPTX::LDV_f32_v2_ari, NVPTX::LDV_f64_v2_ari);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode =
-            pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari,
-                            NVPTX::LDV_i16_v4_ari, NVPTX::LDV_i32_v4_ari,
-                            std::nullopt, NVPTX::LDV_f32_v4_ari, std::nullopt);
-        break;
-      }
-    }
+  std::optional<unsigned> Opcode;
+  switch (N->getOpcode()) {
+  default:
+    return false;
+  case NVPTXISD::LoadV2:
+    Opcode =
+        pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2,
+                        NVPTX::LDV_i16_v2, NVPTX::LDV_i32_v2, NVPTX::LDV_i64_v2,
+                        NVPTX::LDV_f32_v2, NVPTX::LDV_f64_v2);
+    break;
+  case NVPTXISD::LoadV4:
+    Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4,
+                             NVPTX::LDV_i16_v4, NVPTX::LDV_i32_v4, std::nullopt,
+                             NVPTX::LDV_f32_v4, std::nullopt);
+    break;
   }
   if (!Opcode)
     return false;
-  Ops.append({Base, Offset, Chain});
-  LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
+
+  SDNode *LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
 
   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
   CurDAG->setNodeMemRefs(cast<MachineSDNode>(LD), {MemRef});
@@ -1197,176 +1141,58 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
   SDValue Chain = N->getOperand(0);
 
   std::optional<unsigned> Opcode;
-  SDLoc DL(N);
-  SDNode *LD;
-  SDValue Base, Offset;
-
-  if (SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
-    switch (N->getOpcode()) {
-    default:
-      return false;
-    case ISD::LOAD:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_i16asi, NVPTX::INT_PTX_LDG_GLOBAL_i32asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_i64asi, NVPTX::INT_PTX_LDG_GLOBAL_f32asi,
-          NVPTX::INT_PTX_LDG_GLOBAL_f64asi);
-      break;
-    case ISD::INTRINSIC_W_CHAIN:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_i16asi, NVPTX::INT_PTX_LDU_GLOBAL_i32asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_i64asi, NVPTX::INT_PTX_LDU_GLOBAL_f32asi,
-          NVPTX::INT_PTX_LDU_GLOBAL_f64asi);
-      break;
-    case NVPTXISD::LoadV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::INT_PTX_LDG_G_v2i8_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i16_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i32_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2i64_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2f32_ELE_asi,
-                               NVPTX::INT_PTX_LDG_G_v2f64_ELE_asi);
-      break;
-    case NVPTXISD::LDUV2:
-      Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                               NVPTX::INT_PTX_LDU_G_v2i8_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i16_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i32_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2i64_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2f32_ELE_asi,
-                               NVPTX::INT_PTX_LDU_G_v2f64_ELE_asi);
-      break;
-    case NVPTXISD::LoadV4:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_asi,
-          NVPTX::INT_PTX_LDG_G_v4i16_ELE_asi,
-          NVPTX::INT_PTX_LDG_G_v4i32_ELE_asi, std::nullopt,
-          NVPTX::INT_PTX_LDG_G_v4f32_ELE_asi, std::nullopt);
-      break;
-    case NVPTXISD::LDUV4:
-      Opcode = pickOpcodeForVT(
-          EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_asi,
-          NVPTX::INT_PTX_LDU_G_v4i16_ELE_asi,
-          NVPTX::INT_PTX_LDU_G_v4i32_ELE_asi, std::nullopt,
-          NVPTX::INT_PTX_LDU_G_v4f32_ELE_asi, std::nullopt);
-      break;
-    }
-  } else {
-    if (TM.is64Bit()) {
-      SelectADDRri64(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case ISD::LOAD:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i8ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i16ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i32ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_i64ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_f32ari64,
-                                 NVPTX::INT_PTX_LDG_GLOBAL_f64ari64);
-        break;
-      case ISD::INTRINSIC_W_CHAIN:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i8ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i16ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i32ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_i64ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_f32ari64,
-                                 NVPTX::INT_PTX_LDU_GLOBAL_f64ari64);
-        break;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                     NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64);
-        break;
-      case NVPTXISD::LDUV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                     NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64,
-                                     NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64,
-            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64,
-            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64, std::nullopt,
-            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64, std::nullopt);
-        break;
-      case NVPTXISD::LDUV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64,
-            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64,
-            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64, std::nullopt,
-            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64, std::nullopt);
-        break;
-      }
-    } else {
-      SelectADDRri(Op1.getNode(), Op1, Base, Offset);
-      switch (N->getOpcode()) {
-      default:
-        return false;
-      case ISD::LOAD:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_i16ari, NVPTX::INT_PTX_LDG_GLOBAL_i32ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_i64ari, NVPTX::INT_PTX_LDG_GLOBAL_f32ari,
-            NVPTX::INT_PTX_LDG_GLOBAL_f64ari);
-        break;
-      case ISD::INTRINSIC_W_CHAIN:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_i16ari, NVPTX::INT_PTX_LDU_GLOBAL_i32ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_i64ari, NVPTX::INT_PTX_LDU_GLOBAL_f32ari,
-            NVPTX::INT_PTX_LDU_GLOBAL_f64ari);
-        break;
-      case NVPTXISD::LoadV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32);
-        break;
-      case NVPTXISD::LDUV2:
-        Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
-                                 NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32,
-                                 NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32);
-        break;
-      case NVPTXISD::LoadV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32,
-            NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32,
-            NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32, std::nullopt,
-            NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32, std::nullopt);
-        break;
-      case NVPTXISD::LDUV4:
-        Opcode = pickOpcodeForVT(
-            EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32,
-            NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32,
-            NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32, std::nullopt,
-            NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32, std::nullopt);
-        break;
-      }
-    }
+  switch (N->getOpcode()) {
+  default:
+    return false;
+  case ISD::LOAD:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
+        NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
+        NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
+        NVPTX::INT_PTX_LDG_GLOBAL_f64);
+    break;
+  case ISD::INTRINSIC_W_CHAIN:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
+        NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
+        NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
+        NVPTX::INT_PTX_LDU_GLOBAL_f64);
+    break;
+  case NVPTXISD::LoadV2:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
+        NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
+        NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
+        NVPTX::INT_PTX_LDG_G_v2f64_ELE);
+    break;
+  case NVPTXISD::LDUV2:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
+        NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
+        NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
+        NVPTX::INT_PTX_LDU_G_v2f64_ELE);
+    break;
+  case NVPTXISD::LoadV4:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
+        NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
+        std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
+    break;
+  case NVPTXISD::LDUV4:
+    Opcode = pickOpcodeForVT(
+        EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
+        NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
+        std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
+    break;
   }
   if (!Opcode)
     return false;
+
+  SDLoc DL(N);
+  SDValue Base, Offset;
+  SelectADDR(Op1, Base, Offset);
   SDValue Ops[] = {Base, Offset, Chain};
-  LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
+  SDNode *LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
 
   // For automatic generation of LDG (through SelectLoad[Vector], not the
   // intrinsics), we may have an extending load like:
@@ -1424,8 +1250,6 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
 
   // Address Space Setting
   unsigned int CodeAddrSpace = getCodeAddrSpace(ST);
-  unsigned int PointerSize =
-      CurDAG->getDataLayout().getPointerSizeInBits(ST->getAddressSpace());
 
   SDLoc DL(N);
   SDValue Chain = ST->getChain();
@@ -1450,38 +1274,28 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
 
   // Create the machine instruction DAG
   SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
-  SDValue BasePtr = ST->getBasePtr();
+
   SDValue Offset, Base;
-  std::optional<unsigned> Opcode;
-  MVT::SimpleValueType SourceVT =
+  SelectADDR(ST->getBasePtr(), Base, Offset);
+
+  SDValue Ops[] = {Value,
+                   getI32Imm(Ordering, DL),
+                   getI32Imm(Scope, DL),
+                   getI32Imm(CodeAddrSpace, DL),
+                   getI32Imm(VecType, DL),
+                   getI32Imm(ToType, DL),
+                   getI32Imm(ToTypeWidth, DL),
+                   Base,
+                   Offset,
+                   Chain};
+
+  const MVT::SimpleValueType SourceVT =
       Value.getNode()->getSimpleValueType(0).SimpleTy;
-
-  SmallVector<SDValue, 12> Ops(
-      {Value, getI32Imm(Ordering, DL), getI32Imm(Scope, DL),
-       getI32Imm(CodeAddrSpace, DL), getI32Imm(VecType, DL),
-       getI32Imm(ToType, DL), getI32Imm(ToTypeWidth, DL)});
-
-  if (SelectADDRsi(BasePtr.getNode(), BasePtr, Base, Offset)) {
-    Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_asi, NVPTX::ST_i16_asi,
-                             NVPTX::ST_i32_asi, NVPTX::ST_i64_asi,
-                             NVPTX::ST_f32_asi, NVPTX::ST_f64_asi);
-  } else {
-    if (PointerSize == 64) {
-      SelectADDRri64(BasePtr.getNode(), BasePtr, Base, Offset);
-      Opcode =
-          pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari_64, NVPTX::ST_i16_ari_64,
-                          NVPTX::ST_i32_ari_64, NVPTX::ST_i64_ari_64,
-                          NVPTX::ST_f32_ari_64, NVPTX::ST_f64_ari_64);
-    } else {
-      SelectADDRri(BasePtr.getNode(), BasePtr, Base, Offset);
-      Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari, NVPTX::ST_i16_ari,
-                      ...
[truncated]

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/combine-addr-variants branch from 5e423a9 to ba1129b Compare February 27, 2025 20:04
Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. LGTM.

Comment on lines +2084 to +2087
// - [var] - Offset is simply set to 0
// - [reg] - Offset is simply set to 0
// - [reg+immOff]
// - [var+immOff]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In other words it's [ var|reg + offset], where offset may be 0.

It may be worth mentioning the limits on offset value must fit in a signed i32.

Copy link
Contributor

@kalxr kalxr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr,
SDValue &Base, SDValue &Offset) {
SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
// Select a pair of operands which represnent a valid PTX address, this could be
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

represnent


SDLoc DL(N);
SDValue Base, Offset;
SelectADDR(Op1, Base, Offset);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tiny nit - seems like in all the other cases we call SelectADDR before the opcode switch, would prefer consistency unless there's a reason for it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously SelectADDR* functions were called in individual conditional branches, so each had to be done separately.
Now we have only one function to call and we can move it close to where it's actually used.
I believe this location is the right place for the call now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shoot, I already moved it to address @kalxr's request. I think either location is basically fine and I'll plan to leave it where it is now unless there are any strong reasons for the other placement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK.

@AlexMaclean AlexMaclean merged commit 85f8bd1 into llvm:main Feb 28, 2025
8 of 11 checks passed
cheezeburglar pushed a commit to cheezeburglar/llvm-project that referenced this pull request Feb 28, 2025
This change fold together the _ari, _ari64, and _asi variants of these
instructions into a single instruction capable of holding any address.
This allows for the removal of a lot of unnecessary code and moves us
towards a standard way of representing an address in NVPTX.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants