Skip to content

Commit 56c373c

Browse files
committed
fix #120: make linalg unary ops reject integer types
1 parent 3e56942 commit 56c373c

File tree

2 files changed

+51
-118
lines changed

2 files changed

+51
-118
lines changed

src/pydsl/linalg.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ def _gen_elementwise_unary_macro(op: DefinedOpCallable) -> CallMacro:
5353
def op_macro(visitor: ToMLIRBase, x: Compiled) -> Tensor | MemRef:
5454
verify_memref_tensor_types(x)
5555

56+
t = x.element_type
57+
if not issubclass(t, Float):
58+
raise TypeError(
59+
f"this linalg elementwise unary operation only supports "
60+
f"arguments with element type Float, got "
61+
f"{t.__qualname__}"
62+
)
63+
5664
if isinstance(x, Tensor):
5765
# Return a new tensor, since tensors are SSA in MLIR
5866
rep = op(lower_single(x), outs=[lower_single(x)])

tests/e2e/test_linalg.py

Lines changed: 43 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import typing
33

44
from collections.abc import Iterable
5+
6+
import pytest
57
import pydsl.arith as arith
68
from pydsl.frontend import compile
79
from pydsl.func import InlineFunction
810
from pydsl.memref import alloca, DYNAMIC, MemRef, MemRefFactory
911
from pydsl.tensor import Tensor, TensorFactory
10-
from pydsl.type import F32, F64, SInt32, SInt64, UInt8, UInt32, UInt64
12+
from pydsl.type import F32, F64, SInt8, SInt32, SInt64, UInt8, UInt32, UInt64
1113
import pydsl.linalg as linalg
1214
from helper import compilation_failed_from, multi_arange, run
1315

@@ -17,114 +19,48 @@
1719
MemRefUI64 = MemRefFactory((DYNAMIC,), UInt64)
1820

1921

20-
def test_linalg_exp():
21-
@compile()
22-
def f(t1: TensorF64) -> TensorF64:
23-
return linalg.exp(t1)
24-
25-
n1 = multi_arange((100,), np.float64) / 10 - 5
26-
assert np.allclose(f(n1.copy()), np.exp(n1))
27-
28-
29-
def test_linalg_log():
30-
@compile()
31-
def f(t1: TensorF64) -> TensorF64:
32-
return linalg.log(t1)
33-
34-
n1 = multi_arange((100,), np.float64) / 10 + 0.1
35-
assert np.allclose(f(n1.copy()), np.log(n1))
36-
37-
38-
def test_linalg_abs():
39-
@compile()
40-
def f(t1: TensorF64) -> TensorF64:
41-
return linalg.abs(t1)
42-
43-
n1 = multi_arange((100,), np.float64) / 10 - 5
44-
assert np.allclose(f(n1.copy()), np.abs(n1))
45-
46-
47-
def test_linalg_ceil():
48-
@compile()
49-
def f(t1: TensorF64) -> TensorF64:
50-
return linalg.ceil(t1)
51-
52-
n1 = multi_arange((100,), np.float64) / 10 - 5
53-
assert np.allclose(f(n1.copy()), np.ceil(n1))
54-
55-
56-
def test_linalg_floor():
57-
@compile()
58-
def f(t1: TensorF64) -> TensorF64:
59-
return linalg.floor(t1)
60-
61-
n1 = multi_arange((100,), np.float64) / 10 - 5
62-
assert np.allclose(f(n1.copy()), np.floor(n1))
63-
64-
65-
def test_linalg_negf():
66-
@compile()
67-
def f(t1: TensorF64) -> TensorF64:
68-
return linalg.negf(t1)
69-
70-
n1 = multi_arange((100,), np.float64) / 10 - 5
71-
assert np.allclose(f(n1.copy()), np.negative(n1))
72-
73-
74-
def test_linalg_round():
75-
@compile()
76-
def f(t1: TensorF64) -> TensorF64:
77-
return linalg.round(t1)
78-
22+
@pytest.mark.parametrize("dtype, Type", [
23+
(np.uint8, UInt8),
24+
(np.uint32, UInt32),
25+
(np.uint64, UInt64),
26+
(np.int8, SInt8),
27+
(np.int32, SInt32),
28+
(np.int64, SInt64),
29+
(np.float32, F32),
30+
(np.float64, F64),
31+
])
32+
@pytest.mark.parametrize("Container", [Tensor, MemRef])
33+
@pytest.mark.parametrize("linalg_op,np_op,input_gen",[
34+
(linalg.exp, np.exp, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
35+
(linalg.log, np.log, lambda dtype: multi_arange((100,), dtype) / 10 + 0.1),
36+
(linalg.abs, np.abs, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
37+
(linalg.ceil, np.ceil, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
38+
(linalg.floor, np.floor, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
39+
(linalg.negf, np.negative, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
40+
(linalg.round, np.round, lambda dtype: multi_arange((100,), dtype) / 9 - 5),
7941
# MLIR and numpy don't round 0.5 the same way
80-
n1 = multi_arange((100,), np.float64) / 9 - 5
81-
assert np.allclose(f(n1.copy()), np.round(n1))
82-
83-
84-
def test_linalg_sqrt():
85-
@compile()
86-
def f(t1: TensorF64) -> TensorF64:
87-
return linalg.sqrt(t1)
88-
89-
n1 = multi_arange((100,), np.float64) / 10
90-
assert np.allclose(f(n1.copy()), np.sqrt(n1))
91-
92-
93-
def test_linalg_rsqrt():
94-
@compile()
95-
def f(t1: TensorF64) -> TensorF64:
96-
return linalg.rsqrt(t1)
97-
98-
n1 = multi_arange((100,), np.float64) / 10 + 0.1
99-
assert np.allclose(f(n1.copy()), np.reciprocal(np.sqrt(n1)))
100-
101-
102-
def test_linalg_square():
103-
@compile()
104-
def f(t1: TensorF64) -> TensorF64:
105-
return linalg.square(t1)
106-
107-
n1 = multi_arange((100,), np.float64) / 10 - 5
108-
assert np.allclose(f(n1.copy()), np.square(n1))
109-
110-
111-
def test_linalg_tanh():
112-
@compile()
113-
def f(t1: TensorF64) -> TensorF64:
114-
return linalg.tanh(t1)
115-
116-
n1 = multi_arange((100,), np.float64) / 10 - 5
117-
assert np.allclose(f(n1.copy()), np.tanh(n1))
118-
119-
120-
# Numpy doesn't have erf. Scipy is needed.
121-
# def test_linalg_erf():
122-
# @compile()
123-
# def f(t1: TensorF64) -> TensorF64:
124-
# return linalg.erf(t1)
42+
(linalg.sqrt, np.sqrt, lambda dtype: multi_arange((100,), dtype) / 10),
43+
(linalg.rsqrt, lambda x: 1 / np.sqrt(x), lambda dtype: multi_arange((100,), dtype) / 10 + 0.1),
44+
(linalg.square, np.square, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
45+
(linalg.tanh, np.tanh, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
46+
# (linalg.erf, np.erf, lambda dtype: multi_arange((100,), dtype) / 10 - 5),
47+
# Numpy doesn't have erf. Scipy is needed.
48+
])
49+
def test_linalg_unary(linalg_op, np_op, input_gen, Container, dtype, Type):
50+
def make_func():
51+
@compile()
52+
def f(t1: Container[Type, DYNAMIC]) -> Container[Type, DYNAMIC]:
53+
return linalg_op(t1)
54+
return f
55+
56+
if np.issubdtype(dtype, np.integer):
57+
with compilation_failed_from(TypeError):
58+
make_func()
59+
return
12560

126-
# n1 = multi_arange((100,), np.float64) / 10 - 5
127-
# assert np.allclose(f(n1.copy()), np.erf(n1))
61+
f = make_func()
62+
n1 = input_gen(dtype)
63+
assert np.allclose(f(n1.copy()), np_op(n1))
12864

12965

13066
def test_multiple_unary():
@@ -461,17 +397,6 @@ def f(arr: MemRef[SInt32, 10], out: MemRef[SInt32, 4, 9]):
461397

462398

463399
if __name__ == "__main__":
464-
run(test_linalg_exp)
465-
run(test_linalg_log)
466-
run(test_linalg_abs)
467-
run(test_linalg_ceil)
468-
run(test_linalg_floor)
469-
run(test_linalg_negf)
470-
run(test_linalg_round)
471-
run(test_linalg_sqrt)
472-
run(test_linalg_rsqrt)
473-
run(test_linalg_square)
474-
run(test_linalg_tanh)
475400
run(test_multiple_unary)
476401
run(test_linalg_add)
477402
run(test_linalg_sub)

0 commit comments

Comments
 (0)