fix #54: implement linalg matmul and batch_matmul #85
eb14a12 to 54b5eb7
1da164c to 71a326d
src/pydsl/linalg.py
Outdated
if (
    len({
        x.element_type,
        y.element_type,
        *([init.element_type] if init else []),
    })
    != 1
):
Surely a direct comparison, x.element_type is not y.element_type or x.element_type is not init.element_type, is simpler. Could you just move the code for setting init above this check?
Went with len({x.element_type, y.element_type, init.element_type}) != 1 because it's shorter, if you are OK with that.
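For reference, the two forms of the element-type check discussed here can be compared in isolation. This is a minimal sketch, not the code in src/pydsl/linalg.py: `Operand` is a hypothetical stand-in for a tensor or memref wrapper, and Python builtin types stand in for pydsl element types. One difference to keep in mind is that the set form relies on hashing and equality, while the chained form compares by identity.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Operand:
    """Hypothetical stand-in for an operand; only element_type matters here."""
    element_type: type


def same_type_set(x, y, init):
    # Set form: matching element types collapse into a single set entry.
    return len({x.element_type, y.element_type, init.element_type}) == 1


def same_type_chained(x, y, init):
    # Chained form from the review, negated so both helpers return True on a match.
    return not (
        x.element_type is not y.element_type
        or x.element_type is not init.element_type
    )
```

For plain classes the two helpers should agree, so the choice is mostly about readability.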
if not (
    are_dims_compatible(x.shape[1], y.shape[0])
    and are_dims_compatible(x.shape[0], init.shape[0])
    and are_dims_compatible(y.shape[1], init.shape[1])
):
    raise TypeError(
        "operands need to be in matmul form. e.g (m,k)x(k,n) = (m,n)"
    )
Actually this might need to be a strict check: x.shape[1] == y.shape[0] not are_dims_compatible. IIRC the last time I tried, linalg.add on memref<10x?xf32> and memref<10x20xf32> didn't work, since linalg wants the shapes to be exactly the same, not just compatible. If compatible but not identical dimensions do work for matmul, add a testcase for it.
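To make the distinction between a strict check and a compatibility check concrete, here is a small sketch. It assumes that "compatible" means "equal, or at least one side is dynamic"; the actual are_dims_compatible in pydsl is not shown in this thread and may differ. `DYNAMIC`, `dims_compatible`, and `check_matmul_shapes` below are hypothetical names used only for illustration.

```python
DYNAMIC = None  # hypothetical sentinel for a dimension unknown at compile time


def dims_compatible(a, b):
    # Assumed semantics: a dynamic dimension is compatible with anything.
    return a is DYNAMIC or b is DYNAMIC or a == b


def check_matmul_shapes(x_shape, y_shape, init_shape, strict=False):
    """Raise TypeError unless the shapes have the form (m,k) x (k,n) = (m,n)."""
    if strict:
        same = lambda a, b: a == b  # dimensions must be written identically
    else:
        same = dims_compatible
    if not (
        same(x_shape[1], y_shape[0])
        and same(x_shape[0], init_shape[0])
        and same(y_shape[1], init_shape[1])
    ):
        raise TypeError(
            "operands need to be in matmul form, e.g. (m,k)x(k,n) = (m,n)"
        )
```

Under these assumptions, shapes (?, 6), (6, ?), (7, ?) pass the lenient check but fail the strict one, because ? and 7 are compatible but not identical.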
In my matmul and batch_matmul tests, the operands are fully dynamic:
TensorTD = TensorFactory((DYNAMIC, DYNAMIC), SInt64)
assert (
    f[TensorTD, TensorTD, TensorTD](A.copy(), B.copy(), C.copy())
    == cor_res
).all()
As for add, the following test passes:
def test_linalg_add_dynamic():
    @compile(dump_mlir=True, dump_mlir_passes=True)
    def f(
        m1: MemRef[UInt32, DYNAMIC, DYNAMIC],
        m2: MemRef[UInt32, DYNAMIC, DYNAMIC],
        m3: MemRef[UInt32, DYNAMIC, DYNAMIC],
    ) -> MemRef[UInt32, DYNAMIC, DYNAMIC]:
        return linalg.add(m1, m2, out=m3)

    n1 = multi_arange((10, 10), np.uint32)
    n2 = multi_arange((10, 10), np.uint32) + 1000
    n3 = multi_arange((10, 10), np.uint32) + 2000
    cor_res = n1 + n2
    res = f(n1, n2, n3)
    assert (res == cor_res).all()
    assert (n3 == cor_res).all()
The MLIR is:
module {
  func.func public @f(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>, %arg2: memref<?x?xi32>) -> memref<?x?xi32> {
    linalg.elemwise_binary {cast = #linalg.type_fn<cast_unsigned>, fun = #linalg.binary_fn<add>} ins(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xi32>) outs(%arg2 : memref<?x?xi32>)
    return %arg2 : memref<?x?xi32>
  }
}
After ConvertLinalgToLoopPass:
module {
  func.func public @f(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>, %arg2: memref<?x?xi32>) -> memref<?x?xi32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = memref.dim %arg0, %c0 : memref<?x?xi32>
    %dim_0 = memref.dim %arg0, %c1 : memref<?x?xi32>
    scf.for %arg3 = %c0 to %dim step %c1 {
      scf.for %arg4 = %c0 to %dim_0 step %c1 {
        %0 = memref.load %arg0[%arg3, %arg4] : memref<?x?xi32>
        %1 = memref.load %arg1[%arg3, %arg4] : memref<?x?xi32>
        %2 = arith.addi %0, %1 : i32
        memref.store %2, %arg2[%arg3, %arg4] : memref<?x?xi32>
      }
    }
    return %arg2 : memref<?x?xi32>
  }
}
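As a reading aid, the lowered loop nest above can be paraphrased in plain Python. This is only a sketch of what the MLIR appears to do: nested lists stand in for memrefs, and `lowered_add` is a hypothetical name. Note that both loop bounds come from the first operand alone (the two memref.dim ops on %arg0), so nothing in the lowered code compares the runtime shapes of the three operands.

```python
def lowered_add(a, b, out):
    """Mimic the scf.for nest: out[i][j] = a[i][j] + b[i][j] for 32-bit integers."""
    dim0 = len(a)                    # memref.dim %arg0, %c0
    dim1 = len(a[0]) if dim0 else 0  # memref.dim %arg0, %c1
    for i in range(dim0):
        for j in range(dim1):
            # arith.addi on i32 wraps around, modeled here with a 32-bit mask.
            out[i][j] = (a[i][j] + b[i][j]) & 0xFFFFFFFF
    return out  # like the MLIR, the output buffer itself is returned
```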
I meant something like the following:
module {
  func.func public @f(%arg0: memref<?x6xi32>, %arg1: memref<?x?xi32>, %arg2: memref<7x?xi32>) {
    linalg.add ins(%arg0, %arg1 : memref<?x6xi32>, memref<?x?xi32>) outs(%arg2 : memref<7x?xi32>)
    return
  }
}
Where a static dimension is compared to a dynamic dimension. But this also seems to work now... Let me open an issue to make the existing linalg ops more lenient.
So the only change you should make is to add a test case with both static and dynamic dimensions.
added test
71a326d to 53724d4
53724d4 to e01ffb5
fix #54: implement linalg.batch_matmul