Skip to content

Commit 03ddce7

Browse files
authored
Merge pull request #31 from Balint-R/macro-default-arg
Add CallMacro default argument testcase
2 parents 0e8d1d3 + 8c930e0 commit 03ddce7

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

src/pydsl/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,10 @@ def visit_Expr(self, node: ast.Expr) -> SubtreeOut:
358358
case _:
359359
self.visit(expr_val)
360360

361+
def visit_Expression(self, node: ast.Expression) -> SubtreeOut:
362+
# ast.Expression is output by ast.parse(mode="eval")
363+
return self.visit(node.body)
364+
361365
def visit_Constant(self, node: ast.Constant) -> SubtreeOut:
362366
# TODO: Constant may not always be Number. It may also be e.g. string.
363367
# When other forms of constants are supported this needs to be updated.

tests/e2e/test_macro.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,23 @@ def g(x: UInt32) -> UInt32:
182182
assert g(0) == 0 + 0
183183

184184

185+
def test_default_args():
186+
ast_123 = ast.parse("UInt32(123)", mode="eval")
187+
188+
@CallMacro.generate()
189+
def add_macro(
190+
visitor: ToMLIRBase, x: Uncompiled = ast_123, y: Evaluated = 456
191+
) -> UInt32:
192+
x = visitor.visit(x)
193+
return x.op_add(y)
194+
195+
@compile()
196+
def f() -> UInt32:
197+
return add_macro()
198+
199+
assert f() == 123 + 456
200+
201+
185202
if __name__ == "__main__":
186203
run(test_Compiled_ArgCompiler)
187204
run(test_Evaluated_ArgCompiler)
@@ -193,3 +210,4 @@ def g(x: UInt32) -> UInt32:
193210
run(test_class_method)
194211
run(test_class_only_method)
195212
run(test_static_method)
213+
run(test_default_args)

0 commit comments

Comments
 (0)