Skip to content

[mlir][vector] LoadOp/StoreOp: Allow 0-D vectors #76134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 22, 2023

Conversation

matthias-springer
Copy link
Member

Similar to vector.transfer_read/vector.transfer_write, allow 0-D vectors.

This commit fixes mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir when verifying the IR after each pattern (#74270). That test produces a temporary 0-D load/store op.

Similar to `vector.transfer_read`/`vector.transfer_write`, allow 0-D vectors.

This commit fixes `mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir` when verifying the IR after each pattern (llvm#74270). That test produces a temporary 0-D load/store op.
@llvmbot
Copy link
Member

llvmbot commented Dec 21, 2023

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Similar to vector.transfer_read/vector.transfer_write, allow 0-D vectors.

This commit fixes mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir when verifying the IR after each pattern (#74270). That test produces a temporary 0-D load/store op.


Full diff: https://github.com/llvm/llvm-project/pull/76134.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+27-15)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+30)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 423118f79e733d..40d874dc99dd90 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1582,22 +1582,27 @@ def Vector_LoadOp : Vector_Op<"load"> {
     vector. If the memref element type is vector, it should match the result
     vector type.
 
-    Example 1: 1-D vector load on a scalar memref.
+    Example: 0-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<f32>
+    ```
+
+    Example: 1-D vector load on a scalar memref.
     ```mlir
     %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
     ```
 
-    Example 2: 1-D vector load on a vector memref.
+    Example: 1-D vector load on a vector memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
     ```
 
-    Example 3:  2-D vector load on a scalar memref.
+    Example:  2-D vector load on a scalar memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
     ```
 
-    Example 4:  2-D vector load on a vector memref.
+    Example:  2-D vector load on a vector memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
     ```
@@ -1608,12 +1613,12 @@ def Vector_LoadOp : Vector_Op<"load"> {
     loaded out of bounds. Not all targets may support out-of-bounds vector
     loads.
 
-    Example 5:  Potential out-of-bound vector load.
+    Example:  Potential out-of-bound vector load.
     ```mlir
     %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
     ```
 
-    Example 6:  Explicit out-of-bound vector load.
+    Example:  Explicit out-of-bound vector load.
     ```mlir
     %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
@@ -1622,7 +1627,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices);
-  let results = (outs AnyVector:$result);
+  let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1660,22 +1665,27 @@ def Vector_StoreOp : Vector_Op<"store"> {
     to store. If the memref element type is vector, it should match the type
     of the value to store.
 
-    Example 1: 1-D vector store on a scalar memref.
+    Example: 0-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+    ```
+
+    Example: 1-D vector store on a scalar memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
     ```
 
-    Example 2: 1-D vector store on a vector memref.
+    Example: 1-D vector store on a vector memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
     ```
 
-    Example 3:  2-D vector store on a scalar memref.
+    Example:  2-D vector store on a scalar memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
     ```
 
-    Example 4:  2-D vector store on a vector memref.
+    Example:  2-D vector store on a vector memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
     ```
@@ -1685,21 +1695,23 @@ def Vector_StoreOp : Vector_Op<"store"> {
     target-specific. No assumptions should be made on the memory written out of
     bounds. Not all targets may support out-of-bounds vector stores.
 
-    Example 5:  Potential out-of-bounds vector store.
+    Example:  Potential out-of-bounds vector store.
     ```mlir
     vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
     ```
 
-    Example 6:  Explicit out-of-bounds vector store.
+    Example:  Explicit out-of-bounds vector store.
     ```mlir
     vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
   }];
 
-  let arguments = (ins AnyVector:$valueToStore,
+  let arguments = (ins
+      AnyVectorOfAnyRank:$valueToStore,
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
-      Variadic<Index>:$indices);
+      Variadic<Index>:$indices
+  );
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d80392ebd87b03..7ea0197bdecb36 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2059,6 +2059,36 @@ func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j
 
 // -----
 
+func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<f32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: func @vector_load_op_0d
+// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[cast]] : vector<f32>
+
+// -----
+
+func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<f32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op_0d
+// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
+// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+
+// -----
+
 func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c1ef8f2c30c05c..49f0af5d81e45e 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -714,6 +714,16 @@ func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
+func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
+                                                  %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                              %i : index, %j : index) {

@llvmbot
Copy link
Member

llvmbot commented Dec 21, 2023

@llvm/pr-subscribers-mlir-vector

Author: Matthias Springer (matthias-springer)

Changes

Similar to vector.transfer_read/vector.transfer_write, allow 0-D vectors.

This commit fixes mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir when verifying the IR after each pattern (#74270). That test produces a temporary 0-D load/store op.


Full diff: https://github.com/llvm/llvm-project/pull/76134.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+27-15)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+30)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 423118f79e733d..40d874dc99dd90 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1582,22 +1582,27 @@ def Vector_LoadOp : Vector_Op<"load"> {
     vector. If the memref element type is vector, it should match the result
     vector type.
 
-    Example 1: 1-D vector load on a scalar memref.
+    Example: 0-D vector load on a scalar memref.
+    ```mlir
+    %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<f32>
+    ```
+
+    Example: 1-D vector load on a scalar memref.
     ```mlir
     %result = vector.load %base[%i, %j] : memref<100x100xf32>, vector<8xf32>
     ```
 
-    Example 2: 1-D vector load on a vector memref.
+    Example: 1-D vector load on a vector memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
     ```
 
-    Example 3:  2-D vector load on a scalar memref.
+    Example:  2-D vector load on a scalar memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
     ```
 
-    Example 4:  2-D vector load on a vector memref.
+    Example:  2-D vector load on a vector memref.
     ```mlir
     %result = vector.load %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
     ```
@@ -1608,12 +1613,12 @@ def Vector_LoadOp : Vector_Op<"load"> {
     loaded out of bounds. Not all targets may support out-of-bounds vector
     loads.
 
-    Example 5:  Potential out-of-bound vector load.
+    Example:  Potential out-of-bound vector load.
     ```mlir
     %result = vector.load %memref[%index] : memref<?xf32>, vector<8xf32>
     ```
 
-    Example 6:  Explicit out-of-bound vector load.
+    Example:  Explicit out-of-bound vector load.
     ```mlir
     %result = vector.load %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
@@ -1622,7 +1627,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$base,
       Variadic<Index>:$indices);
-  let results = (outs AnyVector:$result);
+  let results = (outs AnyVectorOfAnyRank:$result);
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -1660,22 +1665,27 @@ def Vector_StoreOp : Vector_Op<"store"> {
     to store. If the memref element type is vector, it should match the type
     of the value to store.
 
-    Example 1: 1-D vector store on a scalar memref.
+    Example: 0-D vector store on a scalar memref.
+    ```mlir
+    vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+    ```
+
+    Example: 1-D vector store on a scalar memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<8xf32>
     ```
 
-    Example 2: 1-D vector store on a vector memref.
+    Example: 1-D vector store on a vector memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<8xf32>>, vector<8xf32>
     ```
 
-    Example 3:  2-D vector store on a scalar memref.
+    Example:  2-D vector store on a scalar memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xf32>, vector<4x8xf32>
     ```
 
-    Example 4:  2-D vector store on a vector memref.
+    Example:  2-D vector store on a vector memref.
     ```mlir
     vector.store %valueToStore, %memref[%i, %j] : memref<200x100xvector<4x8xf32>>, vector<4x8xf32>
     ```
@@ -1685,21 +1695,23 @@ def Vector_StoreOp : Vector_Op<"store"> {
     target-specific. No assumptions should be made on the memory written out of
     bounds. Not all targets may support out-of-bounds vector stores.
 
-    Example 5:  Potential out-of-bounds vector store.
+    Example:  Potential out-of-bounds vector store.
     ```mlir
     vector.store %valueToStore, %memref[%index] : memref<?xf32>, vector<8xf32>
     ```
 
-    Example 6:  Explicit out-of-bounds vector store.
+    Example:  Explicit out-of-bounds vector store.
     ```mlir
     vector.store %valueToStore, %memref[%c0] : memref<7xf32>, vector<8xf32>
     ```
   }];
 
-  let arguments = (ins AnyVector:$valueToStore,
+  let arguments = (ins
+      AnyVectorOfAnyRank:$valueToStore,
       Arg<AnyMemRef, "the reference to store to",
       [MemWrite]>:$base,
-      Variadic<Index>:$indices);
+      Variadic<Index>:$indices
+  );
 
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index d80392ebd87b03..7ea0197bdecb36 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2059,6 +2059,36 @@ func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j
 
 // -----
 
+func.func @vector_load_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<f32> {
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: func @vector_load_op_0d
+// CHECK: %[[load:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[vec:.*]] = llvm.mlir.undef : vector<1xf32>
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[inserted:.*]] = llvm.insertelement %[[load]], %[[vec]][%[[c0]] : i32] : vector<1xf32>
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[inserted]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[cast]] : vector<f32>
+
+// -----
+
+func.func @vector_store_op_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
+  %val = arith.constant dense<11.0> : vector<f32>
+  vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: func @vector_store_op_0d
+// CHECK: %[[val:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[val]] : vector<f32> to vector<1xf32>
+// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[extracted:.*]] = llvm.extractelement %[[cast]][%[[c0]] : i64] : vector<1xf32>
+// CHECK: memref.store %[[extracted]], %{{.*}}[%{{.*}}, %{{.*}}]
+
+// -----
+
 func.func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
   %c0 = arith.constant 0: index
   %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index c1ef8f2c30c05c..49f0af5d81e45e 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -714,6 +714,16 @@ func.func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
   return %0 : vector<16xi32>
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_memref
+func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
+                                                  %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                              %i : index, %j : index) {

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but I'm not a code owner so probably best to wait for others to also have a look

// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
%0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<f32>
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32>, vector<f32>
vector.store %0, %memref[%i, %j] : memref<200x100xf32>, vector<f32>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated to this change, but this reminded me the types are the wrong way round for vector.store

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was told that that's intentional (so that vector.load and vector.store look similar). Also not a fan, but that's tangential to this PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was told that that's intentional (so that vector.load and vector.store look similar).

I see, it's a bit inconsistent w.r.t to transfer_read / transfer_write where the types order matches the inputs.

Also not a fan, but that's tangential to this PR.

Of course

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@matthias-springer matthias-springer merged commit c99670b into llvm:main Dec 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants