diff --git a/integration_tests/symbolics_02.py b/integration_tests/symbolics_02.py index f22a432606..74f4a4af35 100644 --- a/integration_tests/symbolics_02.py +++ b/integration_tests/symbolics_02.py @@ -1,9 +1,11 @@ -from sympy import Symbol +from sympy import Symbol, pi from lpython import S def test_symbolic_operations(): x: S = Symbol('x') y: S = Symbol('y') + p1: S = pi + p2: S = pi # Addition z: S = x + y @@ -37,4 +39,19 @@ def test_symbolic_operations(): assert(c == S(0)) print(c) + # Comparison + b1: bool = p1 == p2 + print(b1) + assert(b1 == True) + b2: bool = p1 != pi + print(b2) + assert(b2 == False) + b3: bool = p1 != x + print(b3) + assert(b3 == True) + b4: bool = pi == Symbol("x") + print(b4) + assert(b4 == False) + + test_symbolic_operations() diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 044d944ecd..59bdf45f81 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -626,6 +626,96 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::symbol_t* declare_basic_eq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_eq"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + + ASR::symbol_t* declare_basic_neq_function(Allocator& al, const Location& loc, SymbolTable* module_scope) { + std::string name = "basic_neq"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + return module_scope->get_symbol(name); + } + ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, SymbolTable* module_scope) { if (ASR::is_a(*expr)) { @@ -772,6 +862,33 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_value)) { + ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_value); + if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { + ASR::symbol_t* sym = nullptr; + if (s->m_op == ASR::cmpopType::Eq) { + sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + } else { + sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + } + ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); + ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); + + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = value1; + call_args.push_back(al, call_arg1); + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg2); + + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); + pass_result.push_back(al, stmt); + } } } @@ -905,6 +1022,32 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val)) { + ASR::SymbolicCompare_t *s = ASR::down_cast(val); + if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) { + ASR::symbol_t* sym = nullptr; + if (s->m_op == ASR::cmpopType::Eq) { + sym = declare_basic_eq_function(al, x.base.base.loc, module_scope); + } else { + sym = declare_basic_neq_function(al, x.base.base.loc, module_scope); + } + ASR::expr_t* value1 = handle_argument(al, x.base.base.loc, s->m_left); + ASR::expr_t* value2 = handle_argument(al, x.base.base.loc, s->m_right); + + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = x.base.base.loc; + call_arg1.m_value = value1; + call_args.push_back(al, call_arg1); + call_arg2.loc = x.base.base.loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg2); + + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + sym, sym, call_args.p, call_args.n, ASRUtils::TYPE(ASR::make_Logical_t(al, x.base.base.loc, 4)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } } else { print_tmp.push_back(x.m_values[i]); }