Skip to content

Commit f9b09dd

Browse files
authored
Merge pull request #2331 from anutosh491/Fixing_symbolic_parameters
Added support for functions to accept symbolic variables
2 parents 2293972 + be85f02 commit f9b09dd

File tree

3 files changed

+165
-113
lines changed

3 files changed

+165
-113
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ RUN(NAME symbolics_05 LABELS cpython_sym c_sym llvm_sym NOFAST)
711711
RUN(NAME symbolics_06 LABELS cpython_sym c_sym llvm_sym NOFAST)
712712
RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST)
713713
RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym)
714+
RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST)
714715

715716
RUN(NAME sizeof_01 LABELS llvm c
716717
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_09.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from sympy import Symbol, pi, S
2+
from lpython import S, i32
3+
4+
def addInteger(x: S, y: S, z: S, i: i32):
5+
_i: S = S(i)
6+
print(x + y + z + _i)
7+
8+
def call_addInteger():
9+
a: S = Symbol("x")
10+
b: S = Symbol("y")
11+
c: S = pi
12+
addInteger(a, b, c, 2)
13+
14+
def main0():
15+
call_addInteger()
16+
17+
main0()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 147 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
4545
pass_result.reserve(al, 1);
4646
}
4747
std::vector<std::string> symbolic_dependencies;
48-
std::set<ASR::symbol_t*> symbolic_vars;
48+
std::set<ASR::symbol_t*> symbolic_vars_to_free;
49+
std::set<ASR::symbol_t*> symbolic_vars_to_omit;
4950
SymEngine_Stack symengine_stack;
5051

5152
void visit_Function(const ASR::Function_t &x) {
@@ -55,6 +56,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
5556
SymbolTable* current_scope_copy = this->current_scope;
5657
this->current_scope = xx.m_symtab;
5758
SymbolTable* module_scope = this->current_scope->parent;
59+
60+
ASR::ttype_t* f_signature= xx.m_function_signature;
61+
ASR::FunctionType_t *f_type = ASR::down_cast<ASR::FunctionType_t>(f_signature);
62+
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
63+
for (size_t i = 0; i < f_type->n_arg_types; ++i) {
64+
if (f_type->m_arg_types[i]->type == ASR::ttypeType::SymbolicExpression) {
65+
f_type->m_arg_types[i] = type1;
66+
}
67+
}
68+
5869
for (auto &item : x.m_symtab->get_scope()) {
5970
if (ASR::is_a<ASR::Variable_t>(*item.second)) {
6071
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second);
@@ -83,7 +94,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
8394
Vec<ASR::stmt_t*> func_body;
8495
func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body);
8596

86-
for (ASR::symbol_t* symbol : symbolic_vars) {
97+
for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
98+
if (symbolic_vars_to_omit.find(symbol) != symbolic_vars_to_omit.end()) continue;
8799
Vec<ASR::call_arg_t> call_args;
88100
call_args.reserve(al, 1);
89101
ASR::call_arg_t call_arg;
@@ -97,7 +109,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
97109

98110
xx.n_body = func_body.size();
99111
xx.m_body = func_body.p;
100-
symbolic_vars.clear();
112+
symbolic_vars_to_free.clear();
101113
}
102114

103115
void visit_Variable(const ASR::Variable_t& x) {
@@ -109,125 +121,130 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
109121

110122
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
111123
xx.m_type = type1;
112-
symbolic_vars.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
113-
114-
ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8));
115-
ASR::symbol_t* sym2 = ASR::down_cast<ASR::symbol_t>(
116-
ASR::make_Variable_t(al, xx.base.base.loc, current_scope,
117-
s2c(al, placeholder), nullptr, 0,
118-
xx.m_intent, nullptr,
119-
nullptr, xx.m_storage,
120-
type2, nullptr, xx.m_abi,
121-
xx.m_access, xx.m_presence,
122-
xx.m_value_attr));
123-
124-
current_scope->add_symbol(s2c(al, placeholder), sym2);
125-
126-
std::string new_name = "basic_new_stack";
127-
symbolic_dependencies.push_back(new_name);
128-
if (!module_scope->get_symbol(new_name)) {
129-
std::string header = "symengine/cwrapper.h";
130-
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);
131-
132-
Vec<ASR::expr_t*> args;
133-
{
134-
args.reserve(al, 1);
135-
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
136-
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
137-
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
138-
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
139-
fn_symtab->add_symbol(s2c(al, "x"), arg);
140-
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
141-
}
124+
symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
125+
if(xx.m_intent == ASR::intentType::In){
126+
symbolic_vars_to_omit.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
127+
}
142128

143-
Vec<ASR::stmt_t*> body;
144-
body.reserve(al, 1);
129+
if(xx.m_intent == ASR::intentType::Local){
130+
ASR::ttype_t *type2 = ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 8));
131+
ASR::symbol_t* sym2 = ASR::down_cast<ASR::symbol_t>(
132+
ASR::make_Variable_t(al, xx.base.base.loc, current_scope,
133+
s2c(al, placeholder), nullptr, 0,
134+
xx.m_intent, nullptr,
135+
nullptr, xx.m_storage,
136+
type2, nullptr, xx.m_abi,
137+
xx.m_access, xx.m_presence,
138+
xx.m_value_attr));
145139

146-
Vec<char *> dep;
147-
dep.reserve(al, 1);
140+
current_scope->add_symbol(s2c(al, placeholder), sym2);
148141

149-
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
150-
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
151-
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
152-
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
153-
false, false, nullptr, 0, false, false, false, s2c(al, header));
154-
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
155-
module_scope->add_symbol(new_name, new_symbol);
156-
}
142+
std::string new_name = "basic_new_stack";
143+
symbolic_dependencies.push_back(new_name);
144+
if (!module_scope->get_symbol(new_name)) {
145+
std::string header = "symengine/cwrapper.h";
146+
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);
147+
148+
Vec<ASR::expr_t*> args;
149+
{
150+
args.reserve(al, 1);
151+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
152+
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
153+
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
154+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
155+
fn_symtab->add_symbol(s2c(al, "x"), arg);
156+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
157+
}
157158

158-
new_name = "basic_free_stack";
159-
symbolic_dependencies.push_back(new_name);
160-
if (!module_scope->get_symbol(new_name)) {
161-
std::string header = "symengine/cwrapper.h";
162-
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);
159+
Vec<ASR::stmt_t*> body;
160+
body.reserve(al, 1);
163161

164-
Vec<ASR::expr_t*> args;
165-
{
166-
args.reserve(al, 1);
167-
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
168-
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
169-
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
170-
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
171-
fn_symtab->add_symbol(s2c(al, "x"), arg);
172-
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
162+
Vec<char *> dep;
163+
dep.reserve(al, 1);
164+
165+
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
166+
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
167+
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
168+
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
169+
false, false, nullptr, 0, false, false, false, s2c(al, header));
170+
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
171+
module_scope->add_symbol(new_name, new_symbol);
173172
}
174173

175-
Vec<ASR::stmt_t*> body;
176-
body.reserve(al, 1);
174+
new_name = "basic_free_stack";
175+
symbolic_dependencies.push_back(new_name);
176+
if (!module_scope->get_symbol(new_name)) {
177+
std::string header = "symengine/cwrapper.h";
178+
SymbolTable *fn_symtab = al.make_new<SymbolTable>(module_scope);
177179

178-
Vec<char *> dep;
179-
dep.reserve(al, 1);
180+
Vec<ASR::expr_t*> args;
181+
{
182+
args.reserve(al, 1);
183+
ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(
184+
al, xx.base.base.loc, fn_symtab, s2c(al, "x"), nullptr, 0, ASR::intentType::In,
185+
nullptr, nullptr, ASR::storage_typeType::Default, type1, nullptr,
186+
ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true));
187+
fn_symtab->add_symbol(s2c(al, "x"), arg);
188+
args.push_back(al, ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, arg)));
189+
}
180190

181-
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
182-
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
183-
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
184-
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
185-
false, false, nullptr, 0, false, false, false, s2c(al, header));
186-
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
187-
module_scope->add_symbol(new_name, new_symbol);
188-
}
191+
Vec<ASR::stmt_t*> body;
192+
body.reserve(al, 1);
189193

190-
ASR::symbol_t* var_sym = current_scope->get_symbol(var_name);
191-
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
192-
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
193-
ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym));
194-
195-
// statement 1
196-
ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc,
197-
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0,
198-
ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4)))),
199-
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2,
200-
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2))));
201-
202-
// statement 2
203-
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1));
204-
205-
// statement 3
206-
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc,
207-
target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr));
208-
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
209-
type1, nullptr));
210-
211-
// statement 4
212-
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack");
213-
Vec<ASR::call_arg_t> call_args;
214-
call_args.reserve(al, 1);
215-
ASR::call_arg_t call_arg;
216-
call_arg.loc = xx.base.base.loc;
217-
call_arg.m_value = target2;
218-
call_args.push_back(al, call_arg);
194+
Vec<char *> dep;
195+
dep.reserve(al, 1);
219196

220-
// defining the assignment statement
221-
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
222-
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
223-
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
224-
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym,
225-
basic_new_stack_sym, call_args.p, call_args.n, nullptr));
226-
227-
pass_result.push_back(al, stmt1);
228-
pass_result.push_back(al, stmt2);
229-
pass_result.push_back(al, stmt3);
230-
pass_result.push_back(al, stmt4);
197+
ASR::asr_t* new_subrout = ASRUtils::make_Function_t_util(al, xx.base.base.loc,
198+
fn_symtab, s2c(al, new_name), dep.p, dep.n, args.p, args.n, body.p, body.n,
199+
nullptr, ASR::abiType::BindC, ASR::accessType::Public,
200+
ASR::deftypeType::Interface, s2c(al, new_name), false, false, false,
201+
false, false, nullptr, 0, false, false, false, s2c(al, header));
202+
ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t>(new_subrout);
203+
module_scope->add_symbol(new_name, new_symbol);
204+
}
205+
206+
ASR::symbol_t* var_sym = current_scope->get_symbol(var_name);
207+
ASR::symbol_t* placeholder_sym = current_scope->get_symbol(placeholder);
208+
ASR::expr_t* target1 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, placeholder_sym));
209+
ASR::expr_t* target2 = ASRUtils::EXPR(ASR::make_Var_t(al, xx.base.base.loc, var_sym));
210+
211+
// statement 1
212+
ASR::expr_t* value1 = ASRUtils::EXPR(ASR::make_Cast_t(al, xx.base.base.loc,
213+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0,
214+
ASRUtils::TYPE(ASR::make_Integer_t(al, xx.base.base.loc, 4)))),
215+
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2,
216+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2))));
217+
218+
// statement 2
219+
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1));
220+
221+
// statement 3
222+
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc,
223+
target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr));
224+
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
225+
type1, nullptr));
226+
227+
// statement 4
228+
ASR::symbol_t* basic_new_stack_sym = module_scope->get_symbol("basic_new_stack");
229+
Vec<ASR::call_arg_t> call_args;
230+
call_args.reserve(al, 1);
231+
ASR::call_arg_t call_arg;
232+
call_arg.loc = xx.base.base.loc;
233+
call_arg.m_value = target2;
234+
call_args.push_back(al, call_arg);
235+
236+
// defining the assignment statement
237+
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
238+
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value2, nullptr));
239+
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target2, value3, nullptr));
240+
ASR::stmt_t* stmt4 = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, xx.base.base.loc, basic_new_stack_sym,
241+
basic_new_stack_sym, call_args.p, call_args.n, nullptr));
242+
243+
pass_result.push_back(al, stmt1);
244+
pass_result.push_back(al, stmt2);
245+
pass_result.push_back(al, stmt3);
246+
pass_result.push_back(al, stmt4);
247+
}
231248
}
232249
}
233250

@@ -621,7 +638,24 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
621638
if (cast_t->m_kind == ASR::cast_kindType::IntegerToSymbolicExpression) {
622639
ASR::expr_t* cast_arg = cast_t->m_arg;
623640
ASR::expr_t* cast_value = cast_t->m_value;
624-
if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
641+
if (ASR::is_a<ASR::Var_t>(*cast_arg)) {
642+
ASR::symbol_t* integer_set_sym = declare_integer_set_si_function(al, x.base.base.loc, module_scope);
643+
ASR::ttype_t* cast_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 8));
644+
ASR::expr_t* value = ASRUtils::EXPR(ASR::make_Cast_t(al, x.base.base.loc, cast_arg,
645+
(ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, nullptr));
646+
Vec<ASR::call_arg_t> call_args;
647+
call_args.reserve(al, 2);
648+
ASR::call_arg_t call_arg1, call_arg2;
649+
call_arg1.loc = x.base.base.loc;
650+
call_arg1.m_value = x.m_target;
651+
call_arg2.loc = x.base.base.loc;
652+
call_arg2.m_value = value;
653+
call_args.push_back(al, call_arg1);
654+
call_args.push_back(al, call_arg2);
655+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, integer_set_sym,
656+
integer_set_sym, call_args.p, call_args.n, nullptr));
657+
pass_result.push_back(al, stmt);
658+
} else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
625659
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(cast_value);
626660
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id;
627661
if (static_cast<LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id) ==
@@ -668,7 +702,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
668702
ASR::expr_t* val = x.m_values[i];
669703
if (ASR::is_a<ASR::Var_t>(*val) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type(val))) {
670704
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(val)->m_v;
671-
if (symbolic_vars.find(v) == symbolic_vars.end()) return;
705+
if (symbolic_vars_to_free.find(v) == symbolic_vars_to_free.end()) return;
672706
ASR::symbol_t* basic_str_sym = declare_basic_str_function(al, x.base.base.loc, module_scope);
673707

674708
// Extract the symbol from value (Var)

0 commit comments

Comments
 (0)