diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 0a33b45284..89f79379d6 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -713,6 +713,7 @@ RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym) RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST) +RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) diff --git a/integration_tests/symbolics_10.py b/integration_tests/symbolics_10.py new file mode 100644 index 0000000000..c833c1e59f --- /dev/null +++ b/integration_tests/symbolics_10.py @@ -0,0 +1,26 @@ +from sympy import Symbol, sin, pi +from lpython import S + +def test_attributes(): + w: S = pi + x: S = Symbol('x') + y: S = Symbol('y') + z: S = sin(x) + + # test has + assert(w.has(x) == False) + assert(y.has(x) == False) + assert(x.has(x) == True) + assert(x.has(x) == z.has(x)) + + # test has 2 + assert(sin(x).has(x) == True) + assert(sin(x).has(y) == False) + assert(sin(Symbol("x")).has(x) == True) + assert(sin(Symbol("x")).has(y) == False) + assert(sin(Symbol("x")).has(Symbol("x")) == True) + assert(sin(Symbol("x")).has(Symbol("y")) == False) + assert(sin(Symbol("x")).has(Symbol("x")) != sin(Symbol("x")).has(Symbol("y"))) + assert(sin(Symbol("x")).has(Symbol("x")) == sin(Symbol("y")).has(Symbol("y"))) + +test_attributes() \ No newline at end of file diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 4f9fcf1fe3..1be030ef7b 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -76,6 +76,7 @@ enum class IntrinsicScalarFunctions : int64_t { SymbolicLog, SymbolicExp, SymbolicAbs, + SymbolicHasSymbolQ, // ... }; @@ -135,6 +136,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(SymbolicLog) INTRINSIC_NAME_CASE(SymbolicExp) INTRINSIC_NAME_CASE(SymbolicAbs) + INTRINSIC_NAME_CASE(SymbolicHasSymbolQ) default : { throw LCompilersException("pickle: intrinsic_id not implemented"); } @@ -2908,6 +2910,56 @@ namespace SymbolicInteger { } // namespace SymbolicInteger +namespace SymbolicHasSymbolQ { + static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, + diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, "Intrinsic function SymbolicHasSymbolQ" + "accepts exactly 2 arguments", x.base.base.loc, diagnostics); + + ASR::ttype_t* left_type = ASRUtils::expr_type(x.m_args[0]); + ASR::ttype_t* right_type = ASRUtils::expr_type(x.m_args[1]); + + ASRUtils::require_impl(ASR::is_a(*left_type) && + ASR::is_a(*right_type), + "Both arguments of SymbolicHasSymbolQ must be of type SymbolicExpression", + x.base.base.loc, diagnostics); + } + + static inline ASR::expr_t* eval_SymbolicHasSymbolQ(Allocator &/*al*/, + const Location &/*loc*/, ASR::ttype_t *, Vec &/*args*/) { + /*TODO*/ + return nullptr; + } + + static inline ASR::asr_t* create_SymbolicHasSymbolQ(Allocator& al, + const Location& loc, Vec& args, + const std::function err) { + + if (args.size() != 2) { + err("Intrinsic function SymbolicHasSymbolQ accepts exactly 2 arguments", loc); + } + + for (size_t i = 0; i < args.size(); i++) { + ASR::ttype_t* argtype = ASRUtils::expr_type(args[i]); + if(!ASR::is_a(*argtype)) { + err("Arguments of SymbolicHasSymbolQ function must be of type SymbolicExpression", + args[i]->base.loc); + } + } + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + + ASR::expr_t* compile_time_value = eval_SymbolicHasSymbolQ(al, loc, logical, arg_values); + return ASR::make_IntrinsicScalarFunction_t(al, loc, + static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), + args.p, args.size(), 0, logical, compile_time_value); + } +} // namespace SymbolicHasSymbolQ + #define create_symbolic_unary_macro(X) \ namespace X { \ static inline void verify_args(const ASR::IntrinsicScalarFunction_t& x, \ @@ -3057,6 +3109,8 @@ namespace IntrinsicScalarFunctionRegistry { {nullptr, &SymbolicExp::verify_args}}, {static_cast(IntrinsicScalarFunctions::SymbolicAbs), {nullptr, &SymbolicAbs::verify_args}}, + {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), + {nullptr, &SymbolicHasSymbolQ::verify_args}}, }; static const std::map& intrinsic_function_id_to_name = { @@ -3157,6 +3211,8 @@ namespace IntrinsicScalarFunctionRegistry { "SymbolicExp"}, {static_cast(IntrinsicScalarFunctions::SymbolicAbs), "SymbolicAbs"}, + {static_cast(IntrinsicScalarFunctions::SymbolicHasSymbolQ), + "SymbolicHasSymbolQ"}, }; @@ -3210,6 +3266,7 @@ namespace IntrinsicScalarFunctionRegistry { {"SymbolicLog", {&SymbolicLog::create_SymbolicLog, &SymbolicLog::eval_SymbolicLog}}, {"SymbolicExp", {&SymbolicExp::create_SymbolicExp, &SymbolicExp::eval_SymbolicExp}}, {"SymbolicAbs", {&SymbolicAbs::create_SymbolicAbs, &SymbolicAbs::eval_SymbolicAbs}}, + {"has", {&SymbolicHasSymbolQ::create_SymbolicHasSymbolQ, &SymbolicHasSymbolQ::eval_SymbolicHasSymbolQ}}, }; static inline bool is_intrinsic_function(const std::string& name) { diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index 7531c1ffeb..044d944ecd 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -626,12 +626,92 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(name); } + ASR::expr_t* process_attributes(Allocator &al, const Location &loc, ASR::expr_t* expr, + SymbolTable* module_scope) { + if (ASR::is_a(*expr)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(expr); + int64_t intrinsic_id = intrinsic_func->m_intrinsic_id; + switch (static_cast(intrinsic_id)) { + case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicHasSymbolQ: { + std::string name = "basic_has_symbol"; + symbolic_dependencies.push_back(name); + if (!module_scope->get_symbol(name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable* fn_symtab = al.make_new(module_scope); + + Vec args; + args.reserve(al, 1); + ASR::symbol_t* arg1 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "_lpython_return_variable"), nullptr, 0, ASR::intentType::ReturnVar, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + fn_symtab->add_symbol(s2c(al, "_lpython_return_variable"), arg1); + ASR::symbol_t* arg2 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg2); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg2))); + ASR::symbol_t* arg3 = ASR::down_cast(ASR::make_Variable_t( + al, loc, fn_symtab, s2c(al, "y"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)), + nullptr, ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "y"), arg3); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, loc, arg3))); + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::expr_t* return_var = ASRUtils::EXPR(ASR::make_Var_t(al, loc, fn_symtab->get_symbol("_lpython_return_variable"))); + ASR::asr_t* subrout = ASRUtils::make_Function_t_util(al, loc, + fn_symtab, s2c(al, name), dep.p, dep.n, args.p, args.n, body.p, body.n, + return_var, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t* symbol = ASR::down_cast(subrout); + module_scope->add_symbol(s2c(al, name), symbol); + } + + ASR::symbol_t* basic_has_symbol = module_scope->get_symbol(name); + ASR::expr_t* value1 = handle_argument(al, loc, intrinsic_func->m_args[0]); + ASR::expr_t* value2 = handle_argument(al, loc, intrinsic_func->m_args[1]); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg1, call_arg2; + call_arg1.loc = loc; + call_arg1.m_value = value1; + call_args.push_back(al, call_arg1); + call_arg2.loc = loc; + call_arg2.m_value = value2; + call_args.push_back(al, call_arg2); + return ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, loc, + basic_has_symbol, basic_has_symbol, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4)), nullptr, nullptr)); + break; + } + default: { + throw LCompilersException("IntrinsicFunction: `" + + ASRUtils::get_intrinsic_name(intrinsic_id) + + "` is not implemented"); + } + } + } + return expr; + } + void visit_Assignment(const ASR::Assignment_t &x) { SymbolTable* module_scope = current_scope->parent; if (ASR::is_a(*x.m_value)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(x.m_value); if (intrinsic_func->m_type->type == ASR::ttypeType::SymbolicExpression) { process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, x.m_target); + } else if (intrinsic_func->m_type->type == ASR::ttypeType::Logical) { + ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, x.m_value, module_scope); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, function_call, nullptr)); + pass_result.push_back(al, stmt); } } else if (ASR::is_a(*x.m_value)) { ASR::Cast_t* cast_t = ASR::down_cast(x.m_value); @@ -770,37 +850,42 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*val) && ASR::is_a(*ASRUtils::expr_type(val))) { + } else if (ASR::is_a(*val)) { ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(val); - ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); - std::string symengine_var = symengine_stack.push(); - ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( - al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, - nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, - ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); - current_scope->add_symbol(s2c(al, symengine_var), arg); - for (auto &item : current_scope->get_scope()) { - if (ASR::is_a(*item.second)) { - ASR::Variable_t *s = ASR::down_cast(item.second); - this->visit_Variable(*s); + if (ASR::is_a(*ASRUtils::expr_type(val))) { + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_SymbolicExpression_t(al, x.base.base.loc)); + std::string symengine_var = symengine_stack.push(); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, x.base.base.loc, current_scope, s2c(al, symengine_var), nullptr, 0, ASR::intentType::Local, + nullptr, nullptr, ASR::storage_typeType::Default, type, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false)); + current_scope->add_symbol(s2c(al, symengine_var), arg); + for (auto &item : current_scope->get_scope()) { + if (ASR::is_a(*item.second)) { + ASR::Variable_t *s = ASR::down_cast(item.second); + this->visit_Variable(*s); + } } - } - ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); - process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target); + ASR::expr_t* target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, arg)); + process_intrinsic_function(al, x.base.base.loc, intrinsic_func, module_scope, target); - // Now create the FunctionCall node for basic_str - ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); - Vec call_args; - call_args.reserve(al, 1); - ASR::call_arg_t call_arg; - call_arg.loc = x.base.base.loc; - call_arg.m_value = target; - call_args.push_back(al, call_arg); - ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, - basic_str_sym, basic_str_sym, call_args.p, call_args.n, - ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); - print_tmp.push_back(function_call); + // Now create the FunctionCall node for basic_str + ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = x.base.base.loc; + call_arg.m_value = target; + call_args.push_back(al, call_arg); + ASR::expr_t* function_call = ASRUtils::EXPR(ASRUtils::make_FunctionCall_t_util(al, x.base.base.loc, + basic_str_sym, basic_str_sym, call_args.p, call_args.n, + ASRUtils::TYPE(ASR::make_Character_t(al, x.base.base.loc, 1, -2, nullptr)), nullptr, nullptr)); + print_tmp.push_back(function_call); + } else if (ASR::is_a(*ASRUtils::expr_type(val))) { + ASR::expr_t* function_call = process_attributes(al, x.base.base.loc, val, module_scope); + print_tmp.push_back(function_call); + } } else if (ASR::is_a(*val)) { ASR::Cast_t* cast_t = ASR::down_cast(val); if(cast_t->m_kind != ASR::cast_kindType::IntegerToSymbolicExpression) return; @@ -951,20 +1036,34 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*x.m_test)) return; - ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); SymbolTable* module_scope = current_scope->parent; ASR::expr_t* left_tmp = nullptr; ASR::expr_t* right_tmp = nullptr; + if (ASR::is_a(*x.m_test)) { + ASR::LogicalCompare_t *l = ASR::down_cast(x.m_test); + + left_tmp = process_attributes(al, x.base.base.loc, l->m_left, module_scope); + right_tmp = process_attributes(al, x.base.base.loc, l->m_right, module_scope); + ASR::expr_t* test = ASRUtils::EXPR(ASR::make_LogicalCompare_t(al, x.base.base.loc, left_tmp, + l->m_op, right_tmp, l->m_type, l->m_value)); + + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } else if(ASR::is_a(*x.m_test)) { + ASR::SymbolicCompare_t *s = ASR::down_cast(x.m_test); + SymbolTable* module_scope = current_scope->parent; + ASR::expr_t* left_tmp = nullptr; + ASR::expr_t* right_tmp = nullptr; - ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); - left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym); - right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym); - ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, - s->m_op, right_tmp, s->m_type, s->m_value)); + ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope); + left_tmp = process_with_basic_str(al, x.base.base.loc, s->m_left, basic_str_sym); + right_tmp = process_with_basic_str(al, x.base.base.loc, s->m_right, basic_str_sym); + ASR::expr_t* test = ASRUtils::EXPR(ASR::make_StringCompare_t(al, x.base.base.loc, left_tmp, + s->m_op, right_tmp, s->m_type, s->m_value)); - ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); - pass_result.push_back(al, assert_stmt); + ASR::stmt_t *assert_stmt = ASRUtils::STMT(ASR::make_Assert_t(al, x.base.base.loc, test, x.m_msg)); + pass_result.push_back(al, assert_stmt); + } } }; diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 6df1a1bb80..31e59c1089 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -7189,7 +7189,7 @@ class BodyVisitor : public CommonVisitor { } else { st = current_scope->resolve_symbol(mod_name); std::set symbolic_attributes = { - "diff", "expand" + "diff", "expand", "has" }; std::set symbolic_constants = { "pi" @@ -7266,7 +7266,7 @@ class BodyVisitor : public CommonVisitor { } else if (AST::is_a(*at->m_value)) { AST::BinOp_t* bop = AST::down_cast(at->m_value); std::set symbolic_attributes = { - "diff", "expand" + "diff", "expand", "has" }; if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){ switch (bop->m_op) { @@ -7313,7 +7313,7 @@ class BodyVisitor : public CommonVisitor { } else if (AST::is_a(*at->m_value)) { AST::Call_t* call = AST::down_cast(at->m_value); std::set symbolic_attributes = { - "diff", "expand" + "diff", "expand", "has" }; if (symbolic_attributes.find(at->m_attr) != symbolic_attributes.end()){ std::set symbolic_functions = { diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index d6491dff80..1c8256cfc2 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -45,7 +45,8 @@ struct AttributeHandler { symbolic_attribute_map = { {"diff", &eval_symbolic_diff}, - {"expand", &eval_symbolic_expand} + {"expand", &eval_symbolic_expand}, + {"has", &eval_symbolic_has_symbol} }; } @@ -442,6 +443,20 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_symbolic_has_symbol(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_list; + args_with_list.reserve(al, args.size() + 1); + args_with_list.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_list.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicScalarFunctionRegistry::get_create_function("has"); + return create_function(al, loc, args_with_list, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython