diff --git a/integration_tests/symbolics_08.py b/integration_tests/symbolics_08.py index c360b60f37..15e010d160 100644 --- a/integration_tests/symbolics_08.py +++ b/integration_tests/symbolics_08.py @@ -5,6 +5,10 @@ def basic_new_stack(x: CPtr) -> None: pass +@ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib") +def basic_free_stack(x: CPtr) -> None: + pass + @ccall(header="symengine/cwrapper.h", c_shared_lib="symengine", c_shared_lib_path=f"{os.environ['CONDA_PREFIX']}/lib") def basic_const_pi(x: CPtr) -> None: pass @@ -22,5 +26,6 @@ def main0(): s: str = basic_str(x) print(s) assert s == "pi" + basic_free_stack(x) main0() \ No newline at end of file diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index b30fca76a3..f0cda46468 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -54,6 +54,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); SymbolTable* current_scope_copy = this->current_scope; this->current_scope = xx.m_symtab; + SymbolTable* module_scope = this->current_scope->parent; for (auto &item : x.m_symtab->get_scope()) { if (ASR::is_a(*item.second)) { ASR::Variable_t *s = ASR::down_cast(item.second); @@ -75,6 +76,28 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorcurrent_scope = current_scope_copy; + + // freeing out variables + std::string new_name = "basic_free_stack"; + ASR::symbol_t* basic_free_stack_sym = module_scope->get_symbol(new_name); + Vec func_body; + func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body); + + for (ASR::symbol_t* symbol : symbolic_vars) { + Vec call_args; + call_args.reserve(al, 1); + ASR::call_arg_t call_arg; + call_arg.loc = xx.base.base.loc; + call_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, symbol)); + call_args.push_back(al, call_arg); + ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_free_stack_sym, + basic_free_stack_sym, call_args.p, call_args.n, nullptr)); + func_body.push_back(al, stmt); + } + + xx.n_body = func_body.size(); + xx.m_body = func_body.p; + symbolic_vars.clear(); } void visit_Variable(const ASR::Variable_t& x) { @@ -132,6 +155,38 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitoradd_symbol(new_name, new_symbol); } + new_name = "basic_free_stack"; + symbolic_dependencies.push_back(new_name); + if (!module_scope->get_symbol(new_name)) { + std::string header = "symengine/cwrapper.h"; + SymbolTable *fn_symtab = al.make_new(module_scope); + + Vec args; + { + args.reserve(al, 1); + ASR::symbol_t *arg = ASR::down_cast(ASR::make_Variable_t( + al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In, + nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr, + ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true)); + fn_symtab->add_symbol(s2c(al, "x"), arg); + args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg))); + } + + Vec body; + body.reserve(al, 1); + + Vec dep; + dep.reserve(al, 1); + + ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc, + fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n, + nullptr, ASR::abiType::BindC, ASR::accessType::Public, + ASR::deftypeType::Interface, s2c(al, new_name), false, false, false, + false, false, nullptr, 0, false, false, false, s2c(al, header)); + ASR::symbol_t *new_symbol = ASR::down_cast(new_subrout); + module_scope->add_symbol(new_name, new_symbol); + } + ASR::symbol_t* var_sym = current_scope->get_symbol(var_name); ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder); ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym)); @@ -154,7 +209,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitorget_symbol(new_name); + ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack"); Vec call_args; call_args.reserve(al, 1); ASR::call_arg_t call_arg;