From 7c794fb1ff79b2559b8be31e5b088333f1571331 Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Thu, 17 Aug 2023 01:43:54 +0530 Subject: [PATCH 1/5] ASR: Support lambda functions --- src/lpython/semantics/python_ast_to_asr.cpp | 90 +++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 4d10ff965c..72a88528e0 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -2873,6 +2873,91 @@ class CommonVisitor : public AST::BaseVisitor { return nullptr; } + void handle_lambda_function_declaration(std::string &var_name, ASR::FunctionType_t* fn_type, AST::expr_t* value, const Location &loc) { + if (value == nullptr) { + throw SemanticError("Callback functions must have a value", loc); + } + + if (!AST::is_a(*value)) { + throw SemanticError("Callback functions supports only lambda expressions as value", value->base.loc); + } + + const AST::Lambda_t &x = *AST::down_cast(value); + if (fn_type->n_arg_types != x.m_args.n_args) { + diag.add(diag::Diagnostic( + "The number of args to lambda function much match the number of args declared in function type", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("", + {fn_type->base.base.loc, x.m_args.loc}) + }) + ); + throw SemanticAbort(); + } + + // Add the lambda function to the current scope + SymbolTable *parent_scope = current_scope; + current_scope = al.make_new(parent_scope); + + Vec args; + args.reserve(al, fn_type->n_arg_types); + for (size_t i=0; in_arg_types; i++) { + std::string arg_name = x.m_args.m_args[i].m_arg; + ASR::symbol_t *v; + SetChar variable_dependencies_vec; + variable_dependencies_vec.reserve(al, 1); + ASRUtils::collect_variable_dependencies(al, variable_dependencies_vec, + fn_type->m_arg_types[i]); + v = ASR::down_cast( + ASR::make_Variable_t(al, x.m_args.m_args[i].loc, + current_scope, s2c(al, arg_name), variable_dependencies_vec.p, + variable_dependencies_vec.size(), ASRUtils::intent_unspecified, + nullptr, nullptr, ASR::storage_typeType::Default, fn_type->m_arg_types[i], + nullptr, ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, + false)); + current_scope->add_symbol(arg_name, v); + LCOMPILERS_ASSERT(v != nullptr) + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, x.m_args.m_args[i].loc, v))); + } + + this->visit_expr(*x.m_body); + ASR::asr_t* return_var_assign_stmt = make_dummy_assignment(ASRUtils::EXPR(tmp)); + ASR::expr_t *return_var = ASR::down_cast2(return_var_assign_stmt)->m_target; + + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(return_var), fn_type->m_return_var_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(return_var)); + std::string rtype = ASRUtils::type_to_str_python(fn_type->m_return_var_type); + diag.add(diag::Diagnostic( + "Type mismatch in lambda expression return value", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {ASRUtils::expr_type(return_var)->base.loc, fn_type->m_return_var_type->base.loc}) + }) + ); + throw SemanticAbort(); + } + + Vec body; + body.reserve(al, 0); + body.push_back(al, ASRUtils::STMT(return_var_assign_stmt)); + + ASR::asr_t* fn_sym_util = ASRUtils::make_Function_t_util( + al, x.base.base.loc, + /* a_symtab */ current_scope, + /* a_name */ s2c(al, var_name), + nullptr, 0, + /* a_args */ args.p, + /* n_args */ args.size(), + /* a_body */ body.p, + /* n_body */ body.size(), + /* a_return_var */ return_var, + ASR::abiType::BindC, ASR::accessType::Public, ASR::deftypeType::Implementation, + nullptr, false, false, false, false, false, nullptr, 0, false, false, false); + current_scope = parent_scope; + ASR::symbol_t* fn_sym = ASR::down_cast(fn_sym_util); + current_scope->add_symbol(var_name, fn_sym); + tmp = nullptr; + } + void visit_AnnAssignUtil(const AST::AnnAssign_t& x, std::string& var_name, ASR::expr_t* &init_expr, bool wrap_derived_type_in_pointer=false, @@ -2885,6 +2970,11 @@ class CommonVisitor : public AST::BaseVisitor { } else { type = ast_expr_to_asr_type(x.base.base.loc, *x.m_annotation, is_allocatable, true, abi); } + if (ASR::is_a(*type)) { + ASR::FunctionType_t* fn_type = ASR::down_cast(type); + handle_lambda_function_declaration(var_name, fn_type, x.m_value, x.base.base.loc); + return; + } ASR::ttype_t* ann_assign_target_type_copy = ann_assign_target_type; ann_assign_target_type = type; if( ASR::is_a(*type) && From 279fe0a2b35c38827973768fcbc7c3b6a453fdbc Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Thu, 17 Aug 2023 03:53:40 +0530 Subject: [PATCH 2/5] TEST: For lambda expression --- integration_tests/CMakeLists.txt | 1 + integration_tests/lambda_01.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 integration_tests/lambda_01.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 88aaf52961..580b9bbb9f 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -770,6 +770,7 @@ RUN(NAME callback_01 LABELS cpython llvm c) RUN(NAME callback_02 LABELS cpython llvm c) RUN(NAME callback_03 LABELS cpython llvm c) +RUN(NAME lambda_01 LABELS cpython llvm) # callback_04 is to test emulation. So just run with cpython RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython) diff --git a/integration_tests/lambda_01.py b/integration_tests/lambda_01.py new file mode 100644 index 0000000000..7e49e4884c --- /dev/null +++ b/integration_tests/lambda_01.py @@ -0,0 +1,16 @@ +from lpython import i32, Callable + +def main0(): + x: Callable[[i32, i32, i32], i32] = lambda p, q, r: p + q + r + + a123: i32 = x(1, 2, 3) + a456: i32 = x(4, 5, 6) + a_1_2_3: i32 = x(-1, -2, -3) + + print(a123, a456, a_1_2_3) + + assert a123 == 6 + assert a456 == 15 + assert a_1_2_3 == -6 + +main0() From 82cabb026399a88a2ab446543e0bc089446cbf57 Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Thu, 17 Aug 2023 04:03:12 +0530 Subject: [PATCH 3/5] ASR: Fix location to annotation type --- src/lpython/semantics/python_ast_to_asr.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 72a88528e0..0ce768f9d2 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -2966,9 +2966,9 @@ class CommonVisitor : public AST::BaseVisitor { bool is_allocatable = false; ASR::ttype_t *type = nullptr; if( inside_struct ) { - type = ast_expr_to_asr_type(x.base.base.loc, *x.m_annotation, is_allocatable, true); + type = ast_expr_to_asr_type(x.m_annotation->base.loc, *x.m_annotation, is_allocatable, true); } else { - type = ast_expr_to_asr_type(x.base.base.loc, *x.m_annotation, is_allocatable, true, abi); + type = ast_expr_to_asr_type(x.m_annotation->base.loc, *x.m_annotation, is_allocatable, true, abi); } if (ASR::is_a(*type)) { ASR::FunctionType_t* fn_type = ASR::down_cast(type); From 7c2bbac2e99c368383509de975a937ee813857ef Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Thu, 17 Aug 2023 04:05:05 +0530 Subject: [PATCH 4/5] TEST: Add error test for mismatch arg count for lambda function --- tests/errors/lambda_01.py | 9 +++++++++ tests/tests.toml | 4 ++++ 2 files changed, 13 insertions(+) create mode 100644 tests/errors/lambda_01.py diff --git a/tests/errors/lambda_01.py b/tests/errors/lambda_01.py new file mode 100644 index 0000000000..8a28334447 --- /dev/null +++ b/tests/errors/lambda_01.py @@ -0,0 +1,9 @@ + +from lpython import i32, Callable + +def main0(): + x: Callable[[i32, i32, i32], i32] = lambda p, q, r, s: p + q + r + s + + a123 = x(1, 2, 3) + +main0() diff --git a/tests/tests.toml b/tests/tests.toml index 19cec7d482..ba4de05588 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -638,6 +638,10 @@ ast_new = true filename = "parser/tuple1.py" ast_new = true +[[test]] +filename = "errors/lambda_01.py" +asr = true + [[test]] filename = "errors/test_bit_length.py" asr = true From c8d07f573691f2711f54ffc3d45ea49f696895af Mon Sep 17 00:00:00 2001 From: Shaikh Ubaid Date: Thu, 17 Aug 2023 04:05:59 +0530 Subject: [PATCH 5/5] TEST: Update reference tests --- tests/reference/asr-lambda_01-1ec3e01.json | 13 +++++++++++++ tests/reference/asr-lambda_01-1ec3e01.stderr | 5 +++++ tests/reference/asr-test_dict10-8c0beff.json | 2 +- tests/reference/asr-test_dict10-8c0beff.stderr | 4 ++-- tests/reference/asr-test_dict11-2ab4e6c.json | 2 +- tests/reference/asr-test_dict11-2ab4e6c.stderr | 4 ++-- tests/reference/asr-test_dict8-d960ce0.json | 2 +- tests/reference/asr-test_dict8-d960ce0.stderr | 4 ++-- tests/reference/asr-test_dict9-907bda7.json | 2 +- tests/reference/asr-test_dict9-907bda7.stderr | 4 ++-- 10 files changed, 30 insertions(+), 12 deletions(-) create mode 100644 tests/reference/asr-lambda_01-1ec3e01.json create mode 100644 tests/reference/asr-lambda_01-1ec3e01.stderr diff --git a/tests/reference/asr-lambda_01-1ec3e01.json b/tests/reference/asr-lambda_01-1ec3e01.json new file mode 100644 index 0000000000..31b51267b2 --- /dev/null +++ b/tests/reference/asr-lambda_01-1ec3e01.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-lambda_01-1ec3e01", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/errors/lambda_01.py", + "infile_hash": "0a22dc5de76f7c3f4f97dc4349f62e51261c0a9b3fc5e932926d438e", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "asr-lambda_01-1ec3e01.stderr", + "stderr_hash": "99ca916bd82540da6812ad3149c0026c812efdbc777dbb5fb465e868", + "returncode": 2 +} \ No newline at end of file diff --git a/tests/reference/asr-lambda_01-1ec3e01.stderr b/tests/reference/asr-lambda_01-1ec3e01.stderr new file mode 100644 index 0000000000..482e01b32e --- /dev/null +++ b/tests/reference/asr-lambda_01-1ec3e01.stderr @@ -0,0 +1,5 @@ +semantic error: The number of args to lambda function much match the number of args declared in function type + --> tests/errors/lambda_01.py:5:8 + | +5 | x: Callable[[i32, i32, i32], i32] = lambda p, q, r, s: p + q + r + s + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^ diff --git a/tests/reference/asr-test_dict10-8c0beff.json b/tests/reference/asr-test_dict10-8c0beff.json index 0f7ae1272d..2b2342369e 100644 --- a/tests/reference/asr-test_dict10-8c0beff.json +++ b/tests/reference/asr-test_dict10-8c0beff.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_dict10-8c0beff.stderr", - "stderr_hash": "95d5b555fbf664cf7bc7735845c89acc77393a00ad44b42fcf7c8fe8", + "stderr_hash": "06772bed43d8fff0fb889a763afb49307005f50ce26c7a601652e258", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_dict10-8c0beff.stderr b/tests/reference/asr-test_dict10-8c0beff.stderr index 7e0d792c97..58c4edd7d3 100644 --- a/tests/reference/asr-test_dict10-8c0beff.stderr +++ b/tests/reference/asr-test_dict10-8c0beff.stderr @@ -1,5 +1,5 @@ semantic error: 'dict' key type cannot be float/complex because resolving collisions by exact comparison of float/complex values will result in unexpected behaviours. In addition fuzzy equality checks with a certain tolerance does not follow transitivity with float/complex values. - --> tests/errors/test_dict10.py:4:5 + --> tests/errors/test_dict10.py:4:8 | 4 | d: dict[c32, f64] = {} - | ^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^ diff --git a/tests/reference/asr-test_dict11-2ab4e6c.json b/tests/reference/asr-test_dict11-2ab4e6c.json index 89ed565509..c91886a137 100644 --- a/tests/reference/asr-test_dict11-2ab4e6c.json +++ b/tests/reference/asr-test_dict11-2ab4e6c.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_dict11-2ab4e6c.stderr", - "stderr_hash": "4944c96752dfe5fcfc190831966428e9568e9d4b8b03a553524df84b", + "stderr_hash": "6ef78d7738e0780fc0f9b9567390798b3d74374b95d0dd156ccbdab4", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_dict11-2ab4e6c.stderr b/tests/reference/asr-test_dict11-2ab4e6c.stderr index f4bae6f532..fcc460b76f 100644 --- a/tests/reference/asr-test_dict11-2ab4e6c.stderr +++ b/tests/reference/asr-test_dict11-2ab4e6c.stderr @@ -1,5 +1,5 @@ semantic error: 'dict' key type cannot be float/complex because resolving collisions by exact comparison of float/complex values will result in unexpected behaviours. In addition fuzzy equality checks with a certain tolerance does not follow transitivity with float/complex values. - --> tests/errors/test_dict11.py:4:5 + --> tests/errors/test_dict11.py:4:8 | 4 | d: dict[c64, f32] = {} - | ^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^ diff --git a/tests/reference/asr-test_dict8-d960ce0.json b/tests/reference/asr-test_dict8-d960ce0.json index 10fa72e28d..303c677d1d 100644 --- a/tests/reference/asr-test_dict8-d960ce0.json +++ b/tests/reference/asr-test_dict8-d960ce0.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_dict8-d960ce0.stderr", - "stderr_hash": "86744c3a768772a885a4cafef8973f69689fb2522aae6dfe486f7dcd", + "stderr_hash": "c2dcf3e38154f9a69328274fafd4940b8b6296d31f442c01c88eaa0e", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_dict8-d960ce0.stderr b/tests/reference/asr-test_dict8-d960ce0.stderr index 51374cc073..050138ffac 100644 --- a/tests/reference/asr-test_dict8-d960ce0.stderr +++ b/tests/reference/asr-test_dict8-d960ce0.stderr @@ -1,5 +1,5 @@ semantic error: 'dict' key type cannot be float/complex because resolving collisions by exact comparison of float/complex values will result in unexpected behaviours. In addition fuzzy equality checks with a certain tolerance does not follow transitivity with float/complex values. - --> tests/errors/test_dict8.py:4:5 + --> tests/errors/test_dict8.py:4:8 | 4 | d: dict[f64, f64] = {} - | ^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^ diff --git a/tests/reference/asr-test_dict9-907bda7.json b/tests/reference/asr-test_dict9-907bda7.json index 3883167972..3603e2ca62 100644 --- a/tests/reference/asr-test_dict9-907bda7.json +++ b/tests/reference/asr-test_dict9-907bda7.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_dict9-907bda7.stderr", - "stderr_hash": "14a0981e18ecf1948417be8e93c7956f82c76fcc5e84b1d428d525c0", + "stderr_hash": "3278571c4f1c492f88f33ca78dcf8fb5051f9e3ca89df7557b7881f6", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_dict9-907bda7.stderr b/tests/reference/asr-test_dict9-907bda7.stderr index e7dee1b91d..a1394398fa 100644 --- a/tests/reference/asr-test_dict9-907bda7.stderr +++ b/tests/reference/asr-test_dict9-907bda7.stderr @@ -1,5 +1,5 @@ semantic error: 'dict' key type cannot be float/complex because resolving collisions by exact comparison of float/complex values will result in unexpected behaviours. In addition fuzzy equality checks with a certain tolerance does not follow transitivity with float/complex values. - --> tests/errors/test_dict9.py:4:5 + --> tests/errors/test_dict9.py:4:8 | 4 | d: dict[f32, f64] = {} - | ^^^^^^^^^^^^^^^^^^^^^^ + | ^^^^^^^^^^^^^^