Skip to content

Conversation

@Ritsuka314
Copy link
Collaborator

fix #54: implement linalg.batch_matmul

@Ritsuka314 Ritsuka314 force-pushed the fix-issue-54 branch 3 times, most recently from eb14a12 to 54b5eb7 Compare August 22, 2025 20:03
@Ritsuka314 Ritsuka314 marked this pull request as ready for review August 22, 2025 20:05
@Ritsuka314 Ritsuka314 changed the title [WIP] fix #54: implement linalg.batch_matmul fix #54: implement linalg.batch_matmul Aug 22, 2025
@Balint-R Balint-R added the feature New feature or request label Aug 25, 2025
@Ritsuka314 Ritsuka314 force-pushed the fix-issue-54 branch 2 times, most recently from 1da164c to 71a326d Compare August 27, 2025 17:30
Comment on lines 250 to 257
if (
len({
x.element_type,
y.element_type,
*([init.element_type] if init else []),
})
!= 1
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surely x.element_type is not y.element_type or x.element_type is not init.element_type is simpler. Just move the code for setting init above this check?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

went with len({x.element_type, y.element_type, init.element_type}) != 1 because if it's shorter, if you are OK with that.

Comment on lines +276 to +278
if not (
are_dims_compatible(x.shape[1], y.shape[0])
and are_dims_compatible(x.shape[0], init.shape[0])
and are_dims_compatible(y.shape[1], init.shape[1])
):
raise TypeError(
"operands need to be in matmul form. e.g (m,k)x(k,n) = (m,n)"
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this might need to be a strict check: x.shape[1] == y.shape[0] not are_dims_compatible. IIRC the last time I tried, linalg.add on memref<10x?xf32> and memref<10x20xf32> didn't work, since linalg wants the shapes to be exactly the same, not just compatible. If compatible but not identical dimensions do work for matmul, add a testcase for it.

Copy link
Collaborator Author

@Ritsuka314 Ritsuka314 Aug 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my matmul and batch_matmul tests:

        TensorTD = TensorFactory((DYNAMIC, DYNAMIC), SInt64)

            assert (
                f[TensorTD, TensorTD, TensorTD](A.copy(), B.copy(), C.copy())
                == cor_res
            ).all()

As for add, the following test passes

def test_linalg_add_dynamic():
    @compile(dump_mlir=True, dump_mlir_passes=True)
    def f(
        m1: MemRef[UInt32, DYNAMIC, DYNAMIC],
        m2: MemRef[UInt32, DYNAMIC, DYNAMIC],
        m3: MemRef[UInt32, DYNAMIC, DYNAMIC],
    ) -> MemRef[UInt32, DYNAMIC, DYNAMIC]:
        return linalg.add(m1, m2, out=m3)

    n1 = multi_arange((10, 10), np.uint32)
    n2 = multi_arange((10, 10), np.uint32) + 1000
    n3 = multi_arange((10, 10), np.uint32) + 2000
    cor_res = n1 + n2
    res = f(n1, n2, n3)
    assert (res == cor_res).all()
    assert (n3 == cor_res).all()

The MLIR is

module {
  func.func public @f(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>, %arg2: memref<?x?xi32>) -> memref<?x?xi32> {
    linalg.elemwise_binary {cast = #linalg.type_fn<cast_unsigned>, fun = #linalg.binary_fn<add>} ins(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xi32>) outs(%arg2 : memref<?x?xi32>)
    return %arg2 : memref<?x?xi32>
  }
}

after ConvertLinalgToLoopPass

module {
  func.func public @f(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>, %arg2: memref<?x?xi32>) -> memref<?x?xi32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = memref.dim %arg0, %c0 : memref<?x?xi32>
    %dim_0 = memref.dim %arg0, %c1 : memref<?x?xi32>
    scf.for %arg3 = %c0 to %dim step %c1 {
      scf.for %arg4 = %c0 to %dim_0 step %c1 {
        %0 = memref.load %arg0[%arg3, %arg4] : memref<?x?xi32>
        %1 = memref.load %arg1[%arg3, %arg4] : memref<?x?xi32>
        %2 = arith.addi %0, %1 : i32
        memref.store %2, %arg2[%arg3, %arg4] : memref<?x?xi32>
      }
    }
    return %arg2 : memref<?x?xi32>
  }
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant something like the following:

module {
  func.func public @f(%arg0: memref<?x6xi32>, %arg1: memref<?x?xi32>, %arg2: memref<7x?xi32>){
    linalg.add ins(%arg0, %arg1 : memref<?x6xi32>, memref<?x?xi32>) outs(%arg2 : memref<7x?xi32>)
    return
  }
}

Where a static dimension is compared to a dynamic dimension. But this also seems to work now... Let me open an issue to make the existing linalg ops more lenient.

So the only change you should do is add a testcase with both static and dynamic dimensions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added test

@Ritsuka314 Ritsuka314 changed the title fix #54: implement linalg.batch_matmul fix #54: implement linalg matmul and batch_matmul Aug 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Finish linalg.batch_matmul

2 participants