@@ -672,6 +672,45 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
672672 return module_scope->get_symbol (name);
673673 }
674674
675+ ASR::symbol_t * declare_basic_get_type_function (Allocator& al, const Location& loc, SymbolTable* module_scope) {
676+ std::string name = " basic_get_type" ;
677+ symbolic_dependencies.push_back (name);
678+ if (!module_scope->get_symbol (name)) {
679+ std::string header = " symengine/cwrapper.h" ;
680+ SymbolTable* fn_symtab = al.make_new <SymbolTable>(module_scope);
681+
682+ Vec<ASR::expr_t *> args;
683+ args.reserve (al, 1 );
684+ ASR::symbol_t * arg1 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
685+ al, loc, fn_symtab, s2c (al, " _lpython_return_variable" ), nullptr , 0 , ASR::intentType::ReturnVar,
686+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )),
687+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, false ));
688+ fn_symtab->add_symbol (s2c (al, " _lpython_return_variable" ), arg1);
689+ ASR::symbol_t * arg2 = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
690+ al, loc, fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
691+ nullptr , nullptr , ASR::storage_typeType::Default, ASRUtils::TYPE (ASR::make_CPtr_t (al, loc)),
692+ nullptr , ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
693+ fn_symtab->add_symbol (s2c (al, " x" ), arg2);
694+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, loc, arg2)));
695+
696+ Vec<ASR::stmt_t *> body;
697+ body.reserve (al, 1 );
698+
699+ Vec<char *> dep;
700+ dep.reserve (al, 1 );
701+
702+ ASR::expr_t * return_var = ASRUtils::EXPR (ASR::make_Var_t (al, loc, fn_symtab->get_symbol (" _lpython_return_variable" )));
703+ ASR::asr_t * subrout = ASRUtils::make_Function_t_util (al, loc,
704+ fn_symtab, s2c (al, name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
705+ return_var, ASR::abiType::BindC, ASR::accessType::Public,
706+ ASR::deftypeType::Interface, s2c (al, name), false , false , false ,
707+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
708+ ASR::symbol_t * symbol = ASR::down_cast<ASR::symbol_t >(subrout);
709+ module_scope->add_symbol (s2c (al, name), symbol);
710+ }
711+ return module_scope->get_symbol (name);
712+ }
713+
675714 ASR::symbol_t * declare_basic_eq_function (Allocator& al, const Location& loc, SymbolTable* module_scope) {
676715 std::string name = " basic_eq" ;
677716 symbolic_dependencies.push_back (name);
@@ -828,6 +867,60 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
828867 ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr , nullptr ));
829868 break ;
830869 }
870+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicAddQ: {
871+ ASR::symbol_t * basic_get_type_sym = declare_basic_get_type_function (al, loc, module_scope);
872+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
873+ Vec<ASR::call_arg_t > call_args;
874+ call_args.reserve (al, 1 );
875+ ASR::call_arg_t call_arg;
876+ call_arg.loc = loc;
877+ call_arg.m_value = value1;
878+ call_args.push_back (al, call_arg);
879+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
880+ basic_get_type_sym, basic_get_type_sym, call_args.p , call_args.n ,
881+ ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )), nullptr , nullptr ));
882+ // Using 16 as the right value of the IntegerCompare node as it represents SYMENGINE_ADD through SYMENGINE_ENUM
883+ return ASRUtils::EXPR (ASR::make_IntegerCompare_t (al, loc, function_call, ASR::cmpopType::Eq,
884+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, loc, 16 , ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )))),
885+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr ));
886+ break ;
887+ }
888+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicMulQ: {
889+ ASR::symbol_t * basic_get_type_sym = declare_basic_get_type_function (al, loc, module_scope);
890+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
891+ Vec<ASR::call_arg_t > call_args;
892+ call_args.reserve (al, 1 );
893+ ASR::call_arg_t call_arg;
894+ call_arg.loc = loc;
895+ call_arg.m_value = value1;
896+ call_args.push_back (al, call_arg);
897+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
898+ basic_get_type_sym, basic_get_type_sym, call_args.p , call_args.n ,
899+ ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )), nullptr , nullptr ));
900+ // Using 15 as the right value of the IntegerCompare node as it represents SYMENGINE_MUL through SYMENGINE_ENUM
901+ return ASRUtils::EXPR (ASR::make_IntegerCompare_t (al, loc, function_call, ASR::cmpopType::Eq,
902+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, loc, 15 , ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )))),
903+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr ));
904+ break ;
905+ }
906+ case LCompilers::ASRUtils::IntrinsicScalarFunctions::SymbolicPowQ: {
907+ ASR::symbol_t * basic_get_type_sym = declare_basic_get_type_function (al, loc, module_scope);
908+ ASR::expr_t * value1 = handle_argument (al, loc, intrinsic_func->m_args [0 ]);
909+ Vec<ASR::call_arg_t > call_args;
910+ call_args.reserve (al, 1 );
911+ ASR::call_arg_t call_arg;
912+ call_arg.loc = loc;
913+ call_arg.m_value = value1;
914+ call_args.push_back (al, call_arg);
915+ ASR::expr_t * function_call = ASRUtils::EXPR (ASRUtils::make_FunctionCall_t_util (al, loc,
916+ basic_get_type_sym, basic_get_type_sym, call_args.p , call_args.n ,
917+ ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )), nullptr , nullptr ));
918+ // Using 17 as the right value of the IntegerCompare node as it represents SYMENGINE_POW through SYMENGINE_ENUM
919+ return ASRUtils::EXPR (ASR::make_IntegerCompare_t (al, loc, function_call, ASR::cmpopType::Eq,
920+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, loc, 17 , ASRUtils::TYPE (ASR::make_Integer_t (al, loc, 4 )))),
921+ ASRUtils::TYPE (ASR::make_Logical_t (al, loc, 4 )), nullptr ));
922+ break ;
923+ }
831924 default : {
832925 throw LCompilersException (" IntrinsicFunction: `"
833926 + ASRUtils::get_intrinsic_name (intrinsic_id)
@@ -998,6 +1091,20 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
9981091 }
9991092 }
10001093
1094+ void visit_If (const ASR::If_t& x) {
1095+ ASR::If_t& xx = const_cast <ASR::If_t&>(x);
1096+ transform_stmts (xx.m_body , xx.n_body );
1097+ transform_stmts (xx.m_orelse , xx.n_orelse );
1098+ SymbolTable* module_scope = current_scope->parent ;
1099+ if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*xx.m_test )) {
1100+ ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(xx.m_test );
1101+ if (intrinsic_func->m_type ->type == ASR::ttypeType::Logical) {
1102+ ASR::expr_t * function_call = process_attributes (al, xx.base .base .loc , xx.m_test , module_scope);
1103+ xx.m_test = function_call;
1104+ }
1105+ }
1106+ }
1107+
10011108 void visit_SubroutineCall (const ASR::SubroutineCall_t &x) {
10021109 SymbolTable* module_scope = current_scope->parent ;
10031110 Vec<ASR::call_arg_t > call_args;
@@ -1298,7 +1405,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
12981405
12991406 ASR::stmt_t *assert_stmt = ASRUtils::STMT (ASR::make_Assert_t (al, x.base .base .loc , test, x.m_msg ));
13001407 pass_result.push_back (al, assert_stmt);
1301- } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) {
1408+ } else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_test )) {
13021409 ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_test );
13031410 SymbolTable* module_scope = current_scope->parent ;
13041411 ASR::expr_t * left_tmp = nullptr ;
0 commit comments