From f39c96a4c12c4363df2e5c739bbc484c3b224f82 Mon Sep 17 00:00:00 2001
From: Smit-create
Date: Tue, 15 Aug 2023 12:37:08 +0530
Subject: [PATCH 1/3] PASS: Fix fma declaration

---
 src/libasr/pass/fma.cpp      | 5 +++--
 src/libasr/pass/pass_utils.h | 4 +---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/libasr/pass/fma.cpp b/src/libasr/pass/fma.cpp
index ded6561ba5..ae1f49b8ec 100644
--- a/src/libasr/pass/fma.cpp
+++ b/src/libasr/pass/fma.cpp
@@ -118,8 +118,7 @@ class FMAVisitor : public PassUtils::SkipOptimizationFunctionVisitor
         }
 
         fma_var = PassUtils::get_fma(other_expr, first_arg, second_arg,
-            al, unit, pass_options, current_scope, x.base.base.loc,
-            [&](const std::string &msg, const Location &) { throw LCompilersException(msg); });
+            al, unit, x.base.base.loc);
         from_fma = false;
     }
 
@@ -170,6 +169,8 @@ void pass_replace_fma(Allocator &al, ASR::TranslationUnit_t &unit,
                       const LCompilers::PassOptions& pass_options) {
     FMAVisitor v(al, unit, pass_options);
     v.visit_TranslationUnit(unit);
+    PassUtils::UpdateDependenciesVisitor u(al);
+    u.visit_TranslationUnit(unit);
 }
 
 
diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h
index e0f0cf0083..c8bf786b99 100644
--- a/src/libasr/pass/pass_utils.h
+++ b/src/libasr/pass/pass_utils.h
@@ -90,9 +90,7 @@ namespace LCompilers {
             ASR::intentType var_intent=ASR::intentType::Local);
 
         ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
-            Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
-            SymbolTable*& current_scope,Location& loc,
-            const std::function<void (const std::string &, const Location &)> err);
+            Allocator& al, ASR::TranslationUnit_t& unit, Location& loc);
 
         ASR::expr_t* get_sign_from_value(ASR::expr_t* arg0, ASR::expr_t* arg1,
             Allocator& al, ASR::TranslationUnit_t& unit,

From 86293d0bb2f330ec813ae5be2f4988861678d03d Mon Sep 17 00:00:00 2001
From: Smit-create
Date: Tue, 15 Aug 2023 12:37:22 +0530
Subject: [PATCH 2/3] Generate fma using ASR

---
 src/libasr/pass/pass_utils.cpp | 102 +++++++++++++++++++++++++++++++---
 1 file changed, 95 insertions(+), 7 deletions(-)

diff --git a/src/libasr/pass/pass_utils.cpp b/src/libasr/pass/pass_utils.cpp
index 526746a540..f1971f1355 100644
--- a/src/libasr/pass/pass_utils.cpp
+++ b/src/libasr/pass/pass_utils.cpp
@@ -665,12 +665,98 @@ namespace LCompilers {
             return var;
         }
 
+        // Declares an intent(in) argument named `x` of `type` in `symtab` and appends
+        // its Var to `arg_exprs`; expects `al`, `loc`, `arg_exprs` at the expansion site.
+        #define create_args(x, type, symtab) { \
+            ASR::symbol_t* arg = ASR::down_cast<ASR::symbol_t>( \
+                ASR::make_Variable_t(al, loc, symtab, \
+                s2c(al, x), nullptr, 0, ASR::intentType::In, nullptr, nullptr, \
+                ASR::storage_typeType::Default, type, nullptr, \
+                ASR::abiType::Source, ASR::accessType::Public, \
+                ASR::presenceType::Required, false)); \
+            ASR::expr_t* arg_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg)); \
+            arg_exprs.push_back(al, arg_expr); \
+            symtab->add_symbol(x, arg); \
+        }
+
+        ASR::symbol_t* create_fma_func(Allocator& al, Location& loc,
+                SymbolTable*& global_scope, ASR::ttype_t* type) {
+            /*
+                Returns the helper for `type` from `global_scope`, creating it on first
+                use. Illustrated for real32 (the name suffix comes from get_type_code):
+                elemental real(real32) function _lcompilers_optimization__function__fma__real32(a, b, c) result(d)
+                    real(real32), intent(in) :: a, b, c
+                    d = a + b * c
+                end function
+            */
+
+            std::string type_name = ASRUtils::get_type_code(type, true);
+            std::string func_name = "_lcompilers_optimization__function__fma__" + type_name;
+            if (global_scope->get_symbol(func_name) != nullptr) {
+                return global_scope->get_symbol(func_name);
+            }
+            SymbolTable* fma_func = al.make_new<SymbolTable>(global_scope);
+            Vec<ASR::expr_t*> arg_exprs;
+            arg_exprs.reserve(al, 3);
+
+            Vec<ASR::stmt_t*> body;
+            body.reserve(al, 1);
+
+            // Declare the three input arguments `a`, `b` and `c`
+            create_args("a", type, fma_func)
+            create_args("b", type, fma_func)
+            create_args("c", type, fma_func)
+
+            ASR::symbol_t* result_var = ASR::down_cast<ASR::symbol_t>(
+                ASR::make_Variable_t(al, loc, fma_func,
+                s2c(al, "result_var"), nullptr, 0, ASR::intentType::Local, nullptr, nullptr,
+                ASR::storage_typeType::Default, type, nullptr,
+                ASR::abiType::Source, ASR::accessType::Public,
+                ASR::presenceType::Required, false));
+            ASR::expr_t* result = ASRUtils::EXPR(ASR::make_Var_t(al, loc, result_var));
+            fma_func->add_symbol("result_var", result_var);
+
+            // result_var = a + (b * c)
+            ASR::expr_t* b_c = ASRUtils::EXPR(ASR::make_RealBinOp_t(al, loc, arg_exprs[1],
+                ASR::binopType::Mul, arg_exprs[2], type, nullptr));
+            ASR::expr_t* a_b_c = ASRUtils::EXPR(ASR::make_RealBinOp_t(al, loc, arg_exprs[0],
+                ASR::binopType::Add, b_c, type, nullptr));
+
+            ASR::stmt_t* res_stmt = ASRUtils::STMT(ASR::make_Assignment_t(
+                al, loc, result, a_b_c, nullptr));
+            body.push_back(al, res_stmt);
+
+            // Return
+            res_stmt = ASRUtils::STMT(ASR::make_Return_t(al, loc));
+            body.push_back(al, res_stmt);
+
+            ASR::asr_t *fn = ASRUtils::make_Function_t_util(
+                al, loc,
+                /* a_symtab */ fma_func,
+                /* a_name */ s2c(al, func_name),
+                nullptr, 0,
+                /* a_args */ arg_exprs.p,
+                /* n_args */ arg_exprs.n,
+                /* a_body */ body.p,
+                /* n_body */ body.n,
+                /* a_return_var */ result,
+                ASR::abiType::Source,
+                ASR::accessType::Public, ASR::deftypeType::Implementation,
+                nullptr,
+                false, true, false, false, false,
+                nullptr, 0,
+                false, false, false);
+            ASR::symbol_t *fn_sym = ASR::down_cast<ASR::symbol_t>(fn);
+            global_scope->add_symbol(func_name, fn_sym);
+            return fn_sym;
+        }
+
+        // Builds the call `fma(arg0, arg1, arg2)`, i.e. `arg0 + arg1 * arg2`, to the
+        // helper that create_fma_func generates for the type of `arg0`.
         ASR::expr_t* get_fma(ASR::expr_t* arg0, ASR::expr_t* arg1, ASR::expr_t* arg2,
-            Allocator& al, ASR::TranslationUnit_t& unit, LCompilers::PassOptions& pass_options,
-            SymbolTable*& current_scope, Location& loc,
-            const std::function<void (const std::string &, const Location &)> err) {
-            ASR::symbol_t *v = import_generic_procedure("fma", "lfortran_intrinsic_optimization",
-                                                al, unit, pass_options, current_scope, arg0->base.loc);
+            Allocator& al, ASR::TranslationUnit_t& unit, Location& loc) {
+            ASR::ttype_t *t = ASRUtils::expr_type(arg0);
+            ASR::symbol_t *v = create_fma_func(al, loc, unit.m_global_scope, t);
             Vec<ASR::call_arg_t> args;
             args.reserve(al, 3);
             ASR::call_arg_t arg0_, arg1_, arg2_;
@@ -681,8 +767,10 @@ namespace LCompilers {
             arg2_.loc = arg2->base.loc, arg2_.m_value = arg2;
             args.push_back(al, arg2_);
             return ASRUtils::EXPR(
-                ASRUtils::symbol_resolve_external_generic_procedure_without_eval(
-                    loc, v, args, current_scope, al, err));
+                ASRUtils::make_FunctionCall_t_util(al, loc, v,
+                    v, args.p, args.size(),
+                    t,
+                    nullptr, nullptr));
         }
 
         ASR::symbol_t* insert_fallback_vector_copy(Allocator& al, ASR::TranslationUnit_t& unit,

From f9476bfaa3d1fbea871be47f78cf82110ebaf3f8 Mon Sep 17 00:00:00 2001
From: Smit-create
Date: Tue, 15 Aug 2023 12:39:48 +0530
Subject: [PATCH 3/3] Add tests

---
 integration_tests/CMakeLists.txt | 1 +
 integration_tests/expr_22.py     | 9 +++++++++
 2 files changed, 10 insertions(+)
 create mode 100644 integration_tests/expr_22.py

diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt
index a673e0f515..cdb0274595 100644
--- a/integration_tests/CMakeLists.txt
+++ b/integration_tests/CMakeLists.txt
@@ -475,6 +475,7 @@ RUN(NAME expr_18 FAIL LABELS cpython llvm c)
 RUN(NAME expr_19 LABELS cpython llvm c)
 RUN(NAME expr_20 LABELS cpython llvm c)
 RUN(NAME expr_21 LABELS cpython llvm c)
+RUN(NAME expr_22 LABELS cpython llvm c)
 
 RUN(NAME expr_01u LABELS cpython llvm c NOFAST)
 RUN(NAME expr_02u LABELS cpython llvm c NOFAST)
diff --git a/integration_tests/expr_22.py b/integration_tests/expr_22.py
new file mode 100644
index 0000000000..263c014867
--- /dev/null
+++ b/integration_tests/expr_22.py
@@ -0,0 +1,9 @@
+from lpython import f64
+
+# test issue 1671
+def test_fast_fma() -> f64:
+    a : f64 = 5.00
+    a = a + a * 10.00
+    return a
+
+print(test_fast_fma())