Skip to content

Commit 9d16b12

Browse files
committed
fix #55: support pos only and kw only args
1 parent cf265dd commit 9d16b12

File tree

4 files changed

+99
-12
lines changed

4 files changed

+99
-12
lines changed

src/pydsl/analysis/names.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,12 @@ def analyze(func: ast.FunctionDef | ast.Lambda) -> set[str]:
110110
ba = BoundAnalysis()
111111

112112
# names of arguments are also bound in the function
113-
ba.bound.update([a.arg for a in func.args.args])
113+
ba.bound.update([
114+
a.arg
115+
for a in func.args.args
116+
+ func.args.posonlyargs
117+
+ func.args.kwonlyargs
118+
])
114119

115120
body = func.body if isinstance(func.body, Iterable) else [func.body]
116121

src/pydsl/frontend.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import ctypes
33
import dataclasses
44
import inspect
5+
from inspect import BoundArguments
56
import logging
67
import re
78
import subprocess
@@ -762,8 +763,30 @@ def temp_file_to_path(tempf):
762763
self.ll_to_so,
763764
])(file)
764765

766+
def parse_args(self, signature, *args, **kwargs) -> BoundArguments:
767+
"""
768+
Try to bind arguments, apply defaults, and type cast. Arguments should
769+
already be compiled to PyDSL types by this point.
770+
"""
771+
params = signature.parameters
772+
773+
try:
774+
# Associate each argument value passed into the call with
775+
# the parameter of the function
776+
bound_args = signature.bind(*args, **kwargs)
777+
binding = bound_args.arguments
778+
except TypeError as e:
779+
raise TypeError(
780+
f"couldn't bind arguments when calling an inline function: {e}"
781+
) from e
782+
783+
# Apply defaults for unfilled arguments
784+
bound_args.apply_defaults()
785+
786+
return bound_args
787+
765788
# TODO: this function is too large. Should break it down a bit
766-
def call_function(self, fname: str, *args) -> Any:
789+
def call_function(self, fname: str, *args, **kwargs) -> Any:
767790
if not hasattr(self, "_so"):
768791
raise RuntimeError(
769792
f"function {fname} is called before it is compiled"
@@ -772,22 +795,26 @@ def call_function(self, fname: str, *args) -> Any:
772795
f = self.get_func(fname)
773796
sig = f.signature
774797
so_f = self.load_function(f)
775-
if not len(sig.parameters) == len(args):
798+
if not len(sig.parameters) == len(args) + len(kwargs):
776799
raise TypeError(
777-
f"{f.name} takes {len(sig.parameters)} positional "
800+
f"{f.name} takes {len(sig.parameters)} "
778801
f"argument{"s" if len(sig.parameters) > 1 else ""} "
779-
f"but {len(args)} were given"
802+
f"but {len(args) + len(kwargs)} were given"
780803
)
781804

782805
arg_cont = ArgContainer()
783806

807+
bound_args = self.parse_args(sig, *args, **kwargs)
808+
784809
mapped_args_ct = [
785810
(
786811
ct,
787812
self.val_to_CType(arg_cont, sig.parameters[key].annotation, a),
788813
)
789-
for ct, key, a in zip(
790-
self.get_args_ctypes(f), sig.parameters, args, strict=False
814+
for ct, (key, a) in zip(
815+
self.get_args_ctypes(f),
816+
bound_args.arguments.items(),
817+
strict=False,
791818
)
792819
]
793820

@@ -1324,8 +1351,8 @@ def emit_mlir(self) -> str:
13241351

13251352

13261353
class CompiledFunction(CompiledObject[Callable[..., Any]]):
1327-
def __call__(self, *args) -> Any:
1328-
return self._target.call_function(self._o.__name__, *args)
1354+
def __call__(self, *args, **kwargs) -> Any:
1355+
return self._target.call_function(self._o.__name__, *args, **kwargs)
13291356

13301357

13311358
class CompiledClass(CompiledObject[type]):
@@ -1416,8 +1443,8 @@ def selfattr(name: str):
14161443

14171444
if name in _target.funcs():
14181445

1419-
def make_call(*args) -> Any:
1420-
return _target.call_function(name, *args)
1446+
def make_call(*args, **kwargs) -> Any:
1447+
return _target.call_function(name, *args, **kwargs)
14211448

14221449
return make_call
14231450

src/pydsl/func.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,10 @@ def on_Call(
616616
) -> SubtreeOut:
617617
self = attr_chain[-1]
618618
prefix_args = [visitor.visit(a) for a in prefix_args]
619-
args = [visitor.visit(a) for a in node.args]
619+
args = [
620+
visitor.visit(a)
621+
for a in node.args + [kw.value for kw in node.keywords]
622+
]
620623
args = [
621624
a if isinstance(a, t) else t(a) for t, a in zip(self.argst, args)
622625
]

tests/e2e/test_syntax.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,55 @@ def g(m1: MemRef[UInt8, 2, 3], m2: MemRef[Index, 1]) -> None:
130130
assert (m1 == expected).all()
131131

132132

133+
def test_pos_only_args():
134+
@compile()
135+
def positional_only(
136+
a: UInt8, /, b: UInt8, c: UInt8
137+
) -> Tuple[UInt8, UInt8, UInt8]:
138+
return a, b, c
139+
140+
assert positional_only(2, c=4, b=3) == (2, 3, 4)
141+
142+
143+
def test_kw_only_args():
144+
@compile()
145+
def keyword_only(
146+
a: UInt8, *, b: UInt8, c: UInt8
147+
) -> Tuple[UInt8, UInt8, UInt8]:
148+
return a, b, c
149+
150+
assert keyword_only(2, c=4, b=3) == (2, 3, 4)
151+
152+
153+
def test_pos_or_kw_args():
154+
@compile()
155+
def pos_or_kw(a: UInt8, b: UInt8, c: UInt8) -> Tuple[UInt8, UInt8, UInt8]:
156+
return a, b, c
157+
158+
assert pos_or_kw(2, c=4, b=3) == (2, 3, 4)
159+
160+
161+
def test_fun_args():
162+
@compile()
163+
class Mod:
164+
def pos_or_kw(
165+
a: UInt8, b: UInt8, c: UInt8
166+
) -> Tuple[UInt8, UInt8, UInt8]:
167+
return a, b, c
168+
169+
def keyword_only(
170+
a: UInt8, *, b: UInt8, c: UInt8
171+
) -> Tuple[UInt8, UInt8, UInt8]:
172+
return pos_or_kw(a, c=c, b=b)
173+
174+
def positional_only(
175+
a: UInt8, /, b: UInt8, c: UInt8
176+
) -> Tuple[UInt8, UInt8, UInt8]:
177+
return keyword_only(2, c=4, b=3)
178+
179+
assert Mod.positional_only(2, c=4, b=3) == (2, 3, 4)
180+
181+
133182
if __name__ == "__main__":
134183
run(test_annassign)
135184
run(test_illegal_annassign)
@@ -142,3 +191,6 @@ def g(m1: MemRef[UInt8, 2, 3], m2: MemRef[Index, 1]) -> None:
142191
run(test_minus_eq)
143192
run(test_plus_eq_memref)
144193
run(test_plus_eq_side_effect)
194+
run(test_pos_only_args)
195+
run(test_kw_only_args)
196+
run(test_pos_or_kw_args)

0 commit comments

Comments
 (0)