@@ -45,7 +45,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
45
45
pass_result.reserve (al, 1 );
46
46
}
47
47
std::vector<std::string> symbolic_dependencies;
48
- std::set<ASR::symbol_t *> symbolic_vars;
48
+ std::set<ASR::symbol_t *> symbolic_vars_to_free;
49
+ std::set<ASR::symbol_t *> symbolic_vars_to_omit;
49
50
SymEngine_Stack symengine_stack;
50
51
51
52
void visit_Function (const ASR::Function_t &x) {
@@ -55,6 +56,16 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
55
56
SymbolTable* current_scope_copy = this ->current_scope ;
56
57
this ->current_scope = xx.m_symtab ;
57
58
SymbolTable* module_scope = this ->current_scope ->parent ;
59
+
60
+ ASR::ttype_t * f_signature= xx.m_function_signature ;
61
+ ASR::FunctionType_t *f_type = ASR::down_cast<ASR::FunctionType_t>(f_signature);
62
+ ASR::ttype_t *type1 = ASRUtils::TYPE (ASR::make_CPtr_t (al, xx.base .base .loc ));
63
+ for (size_t i = 0 ; i < f_type->n_arg_types ; ++i) {
64
+ if (f_type->m_arg_types [i]->type == ASR::ttypeType::SymbolicExpression) {
65
+ f_type->m_arg_types [i] = type1;
66
+ }
67
+ }
68
+
58
69
for (auto &item : x.m_symtab ->get_scope ()) {
59
70
if (ASR::is_a<ASR::Variable_t>(*item.second )) {
60
71
ASR::Variable_t *s = ASR::down_cast<ASR::Variable_t>(item.second );
@@ -83,7 +94,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
83
94
Vec<ASR::stmt_t *> func_body;
84
95
func_body.from_pointer_n_copy (al, xx.m_body , xx.n_body );
85
96
86
- for (ASR::symbol_t * symbol : symbolic_vars) {
97
+ for (ASR::symbol_t * symbol : symbolic_vars_to_free) {
98
+ if (symbolic_vars_to_omit.find (symbol) != symbolic_vars_to_omit.end ()) continue ;
87
99
Vec<ASR::call_arg_t > call_args;
88
100
call_args.reserve (al, 1 );
89
101
ASR::call_arg_t call_arg;
@@ -97,7 +109,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
97
109
98
110
xx.n_body = func_body.size ();
99
111
xx.m_body = func_body.p ;
100
- symbolic_vars .clear ();
112
+ symbolic_vars_to_free .clear ();
101
113
}
102
114
103
115
void visit_Variable (const ASR::Variable_t& x) {
@@ -109,125 +121,130 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
109
121
110
122
ASR::ttype_t *type1 = ASRUtils::TYPE (ASR::make_CPtr_t (al, xx.base .base .loc ));
111
123
xx.m_type = type1;
112
- symbolic_vars.insert (ASR::down_cast<ASR::symbol_t >((ASR::asr_t *)&xx));
113
-
114
- ASR::ttype_t *type2 = ASRUtils::TYPE (ASR::make_Integer_t (al, xx.base .base .loc , 8 ));
115
- ASR::symbol_t * sym2 = ASR::down_cast<ASR::symbol_t >(
116
- ASR::make_Variable_t (al, xx.base .base .loc , current_scope,
117
- s2c (al, placeholder), nullptr , 0 ,
118
- xx.m_intent , nullptr ,
119
- nullptr , xx.m_storage ,
120
- type2, nullptr , xx.m_abi ,
121
- xx.m_access , xx.m_presence ,
122
- xx.m_value_attr ));
123
-
124
- current_scope->add_symbol (s2c (al, placeholder), sym2);
125
-
126
- std::string new_name = " basic_new_stack" ;
127
- symbolic_dependencies.push_back (new_name);
128
- if (!module_scope->get_symbol (new_name)) {
129
- std::string header = " symengine/cwrapper.h" ;
130
- SymbolTable *fn_symtab = al.make_new <SymbolTable>(module_scope);
131
-
132
- Vec<ASR::expr_t *> args;
133
- {
134
- args.reserve (al, 1 );
135
- ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
136
- al, xx.base .base .loc , fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
137
- nullptr , nullptr , ASR::storage_typeType::Default, type1, nullptr ,
138
- ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
139
- fn_symtab->add_symbol (s2c (al, " x" ), arg);
140
- args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , arg)));
141
- }
124
+ symbolic_vars_to_free.insert (ASR::down_cast<ASR::symbol_t >((ASR::asr_t *)&xx));
125
+ if (xx.m_intent == ASR::intentType::In){
126
+ symbolic_vars_to_omit.insert (ASR::down_cast<ASR::symbol_t >((ASR::asr_t *)&xx));
127
+ }
142
128
143
- Vec<ASR::stmt_t *> body;
144
- body.reserve (al, 1 );
129
+ if (xx.m_intent == ASR::intentType::Local){
130
+ ASR::ttype_t *type2 = ASRUtils::TYPE (ASR::make_Integer_t (al, xx.base .base .loc , 8 ));
131
+ ASR::symbol_t * sym2 = ASR::down_cast<ASR::symbol_t >(
132
+ ASR::make_Variable_t (al, xx.base .base .loc , current_scope,
133
+ s2c (al, placeholder), nullptr , 0 ,
134
+ xx.m_intent , nullptr ,
135
+ nullptr , xx.m_storage ,
136
+ type2, nullptr , xx.m_abi ,
137
+ xx.m_access , xx.m_presence ,
138
+ xx.m_value_attr ));
145
139
146
- Vec<char *> dep;
147
- dep.reserve (al, 1 );
140
+ current_scope->add_symbol (s2c (al, placeholder), sym2);
148
141
149
- ASR::asr_t * new_subrout = ASRUtils::make_Function_t_util (al, xx.base .base .loc ,
150
- fn_symtab, s2c (al, new_name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
151
- nullptr , ASR::abiType::BindC, ASR::accessType::Public,
152
- ASR::deftypeType::Interface, s2c (al, new_name), false , false , false ,
153
- false , false , nullptr , 0 , false , false , false , s2c (al, header));
154
- ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t >(new_subrout);
155
- module_scope->add_symbol (new_name, new_symbol);
156
- }
142
+ std::string new_name = " basic_new_stack" ;
143
+ symbolic_dependencies.push_back (new_name);
144
+ if (!module_scope->get_symbol (new_name)) {
145
+ std::string header = " symengine/cwrapper.h" ;
146
+ SymbolTable *fn_symtab = al.make_new <SymbolTable>(module_scope);
147
+
148
+ Vec<ASR::expr_t *> args;
149
+ {
150
+ args.reserve (al, 1 );
151
+ ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
152
+ al, xx.base .base .loc , fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
153
+ nullptr , nullptr , ASR::storage_typeType::Default, type1, nullptr ,
154
+ ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
155
+ fn_symtab->add_symbol (s2c (al, " x" ), arg);
156
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , arg)));
157
+ }
157
158
158
- new_name = " basic_free_stack" ;
159
- symbolic_dependencies.push_back (new_name);
160
- if (!module_scope->get_symbol (new_name)) {
161
- std::string header = " symengine/cwrapper.h" ;
162
- SymbolTable *fn_symtab = al.make_new <SymbolTable>(module_scope);
159
+ Vec<ASR::stmt_t *> body;
160
+ body.reserve (al, 1 );
163
161
164
- Vec<ASR::expr_t *> args;
165
- {
166
- args.reserve (al, 1 );
167
- ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
168
- al, xx.base .base .loc , fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
169
- nullptr , nullptr , ASR::storage_typeType::Default, type1, nullptr ,
170
- ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
171
- fn_symtab->add_symbol (s2c (al, " x" ), arg);
172
- args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , arg)));
162
+ Vec<char *> dep;
163
+ dep.reserve (al, 1 );
164
+
165
+ ASR::asr_t * new_subrout = ASRUtils::make_Function_t_util (al, xx.base .base .loc ,
166
+ fn_symtab, s2c (al, new_name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
167
+ nullptr , ASR::abiType::BindC, ASR::accessType::Public,
168
+ ASR::deftypeType::Interface, s2c (al, new_name), false , false , false ,
169
+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
170
+ ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t >(new_subrout);
171
+ module_scope->add_symbol (new_name, new_symbol);
173
172
}
174
173
175
- Vec<ASR::stmt_t *> body;
176
- body.reserve (al, 1 );
174
+ new_name = " basic_free_stack" ;
175
+ symbolic_dependencies.push_back (new_name);
176
+ if (!module_scope->get_symbol (new_name)) {
177
+ std::string header = " symengine/cwrapper.h" ;
178
+ SymbolTable *fn_symtab = al.make_new <SymbolTable>(module_scope);
177
179
178
- Vec<char *> dep;
179
- dep.reserve (al, 1 );
180
+ Vec<ASR::expr_t *> args;
181
+ {
182
+ args.reserve (al, 1 );
183
+ ASR::symbol_t *arg = ASR::down_cast<ASR::symbol_t >(ASR::make_Variable_t (
184
+ al, xx.base .base .loc , fn_symtab, s2c (al, " x" ), nullptr , 0 , ASR::intentType::In,
185
+ nullptr , nullptr , ASR::storage_typeType::Default, type1, nullptr ,
186
+ ASR::abiType::BindC, ASR::Public, ASR::presenceType::Required, true ));
187
+ fn_symtab->add_symbol (s2c (al, " x" ), arg);
188
+ args.push_back (al, ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , arg)));
189
+ }
180
190
181
- ASR::asr_t * new_subrout = ASRUtils::make_Function_t_util (al, xx.base .base .loc ,
182
- fn_symtab, s2c (al, new_name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
183
- nullptr , ASR::abiType::BindC, ASR::accessType::Public,
184
- ASR::deftypeType::Interface, s2c (al, new_name), false , false , false ,
185
- false , false , nullptr , 0 , false , false , false , s2c (al, header));
186
- ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t >(new_subrout);
187
- module_scope->add_symbol (new_name, new_symbol);
188
- }
191
+ Vec<ASR::stmt_t *> body;
192
+ body.reserve (al, 1 );
189
193
190
- ASR::symbol_t * var_sym = current_scope->get_symbol (var_name);
191
- ASR::symbol_t * placeholder_sym = current_scope->get_symbol (placeholder);
192
- ASR::expr_t * target1 = ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , placeholder_sym));
193
- ASR::expr_t * target2 = ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , var_sym));
194
-
195
- // statement 1
196
- ASR::expr_t * value1 = ASRUtils::EXPR (ASR::make_Cast_t (al, xx.base .base .loc ,
197
- ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, xx.base .base .loc , 0 ,
198
- ASRUtils::TYPE (ASR::make_Integer_t (al, xx.base .base .loc , 4 )))),
199
- (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2,
200
- ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, xx.base .base .loc , 0 , type2))));
201
-
202
- // statement 2
203
- ASR::expr_t * value2 = ASRUtils::EXPR (ASR::make_PointerNullConstant_t (al, xx.base .base .loc , type1));
204
-
205
- // statement 3
206
- ASR::expr_t * get_pointer_node = ASRUtils::EXPR (ASR::make_GetPointer_t (al, xx.base .base .loc ,
207
- target1, ASRUtils::TYPE (ASR::make_Pointer_t (al, xx.base .base .loc , type2)), nullptr ));
208
- ASR::expr_t * value3 = ASRUtils::EXPR (ASR::make_PointerToCPtr_t (al, xx.base .base .loc , get_pointer_node,
209
- type1, nullptr ));
210
-
211
- // statement 4
212
- ASR::symbol_t * basic_new_stack_sym = module_scope->get_symbol (" basic_new_stack" );
213
- Vec<ASR::call_arg_t > call_args;
214
- call_args.reserve (al, 1 );
215
- ASR::call_arg_t call_arg;
216
- call_arg.loc = xx.base .base .loc ;
217
- call_arg.m_value = target2;
218
- call_args.push_back (al, call_arg);
194
+ Vec<char *> dep;
195
+ dep.reserve (al, 1 );
219
196
220
- // defining the assignment statement
221
- ASR::stmt_t * stmt1 = ASRUtils::STMT (ASR::make_Assignment_t (al, xx.base .base .loc , target1, value1, nullptr ));
222
- ASR::stmt_t * stmt2 = ASRUtils::STMT (ASR::make_Assignment_t (al, xx.base .base .loc , target2, value2, nullptr ));
223
- ASR::stmt_t * stmt3 = ASRUtils::STMT (ASR::make_Assignment_t (al, xx.base .base .loc , target2, value3, nullptr ));
224
- ASR::stmt_t * stmt4 = ASRUtils::STMT (ASR::make_SubroutineCall_t (al, xx.base .base .loc , basic_new_stack_sym,
225
- basic_new_stack_sym, call_args.p , call_args.n , nullptr ));
226
-
227
- pass_result.push_back (al, stmt1);
228
- pass_result.push_back (al, stmt2);
229
- pass_result.push_back (al, stmt3);
230
- pass_result.push_back (al, stmt4);
197
+ ASR::asr_t * new_subrout = ASRUtils::make_Function_t_util (al, xx.base .base .loc ,
198
+ fn_symtab, s2c (al, new_name), dep.p , dep.n , args.p , args.n , body.p , body.n ,
199
+ nullptr , ASR::abiType::BindC, ASR::accessType::Public,
200
+ ASR::deftypeType::Interface, s2c (al, new_name), false , false , false ,
201
+ false , false , nullptr , 0 , false , false , false , s2c (al, header));
202
+ ASR::symbol_t *new_symbol = ASR::down_cast<ASR::symbol_t >(new_subrout);
203
+ module_scope->add_symbol (new_name, new_symbol);
204
+ }
205
+
206
+ ASR::symbol_t * var_sym = current_scope->get_symbol (var_name);
207
+ ASR::symbol_t * placeholder_sym = current_scope->get_symbol (placeholder);
208
+ ASR::expr_t * target1 = ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , placeholder_sym));
209
+ ASR::expr_t * target2 = ASRUtils::EXPR (ASR::make_Var_t (al, xx.base .base .loc , var_sym));
210
+
211
+ // statement 1
212
+ ASR::expr_t * value1 = ASRUtils::EXPR (ASR::make_Cast_t (al, xx.base .base .loc ,
213
+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, xx.base .base .loc , 0 ,
214
+ ASRUtils::TYPE (ASR::make_Integer_t (al, xx.base .base .loc , 4 )))),
215
+ (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, type2,
216
+ ASRUtils::EXPR (ASR::make_IntegerConstant_t (al, xx.base .base .loc , 0 , type2))));
217
+
218
+ // statement 2
219
+ ASR::expr_t * value2 = ASRUtils::EXPR (ASR::make_PointerNullConstant_t (al, xx.base .base .loc , type1));
220
+
221
+ // statement 3
222
+ ASR::expr_t * get_pointer_node = ASRUtils::EXPR (ASR::make_GetPointer_t (al, xx.base .base .loc ,
223
+ target1, ASRUtils::TYPE (ASR::make_Pointer_t (al, xx.base .base .loc , type2)), nullptr ));
224
+ ASR::expr_t * value3 = ASRUtils::EXPR (ASR::make_PointerToCPtr_t (al, xx.base .base .loc , get_pointer_node,
225
+ type1, nullptr ));
226
+
227
+ // statement 4
228
+ ASR::symbol_t * basic_new_stack_sym = module_scope->get_symbol (" basic_new_stack" );
229
+ Vec<ASR::call_arg_t > call_args;
230
+ call_args.reserve (al, 1 );
231
+ ASR::call_arg_t call_arg;
232
+ call_arg.loc = xx.base .base .loc ;
233
+ call_arg.m_value = target2;
234
+ call_args.push_back (al, call_arg);
235
+
236
+ // defining the assignment statement
237
+ ASR::stmt_t * stmt1 = ASRUtils::STMT (ASR::make_Assignment_t (al, xx.base .base .loc , target1, value1, nullptr ));
238
+ ASR::stmt_t * stmt2 = ASRUtils::STMT (ASR::make_Assignment_t (al, xx.base .base .loc , target2, value2, nullptr ));
239
+ ASR::stmt_t * stmt3 = ASRUtils::STMT (ASR::make_Assignment_t (al, xx.base .base .loc , target2, value3, nullptr ));
240
+ ASR::stmt_t * stmt4 = ASRUtils::STMT (ASR::make_SubroutineCall_t (al, xx.base .base .loc , basic_new_stack_sym,
241
+ basic_new_stack_sym, call_args.p , call_args.n , nullptr ));
242
+
243
+ pass_result.push_back (al, stmt1);
244
+ pass_result.push_back (al, stmt2);
245
+ pass_result.push_back (al, stmt3);
246
+ pass_result.push_back (al, stmt4);
247
+ }
231
248
}
232
249
}
233
250
@@ -621,7 +638,24 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
621
638
if (cast_t ->m_kind == ASR::cast_kindType::IntegerToSymbolicExpression) {
622
639
ASR::expr_t * cast_arg = cast_t ->m_arg ;
623
640
ASR::expr_t * cast_value = cast_t ->m_value ;
624
- if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
641
+ if (ASR::is_a<ASR::Var_t>(*cast_arg)) {
642
+ ASR::symbol_t * integer_set_sym = declare_integer_set_si_function (al, x.base .base .loc , module_scope);
643
+ ASR::ttype_t * cast_type = ASRUtils::TYPE (ASR::make_Integer_t (al, x.base .base .loc , 8 ));
644
+ ASR::expr_t * value = ASRUtils::EXPR (ASR::make_Cast_t (al, x.base .base .loc , cast_arg,
645
+ (ASR::cast_kindType)ASR::cast_kindType::IntegerToInteger, cast_type, nullptr ));
646
+ Vec<ASR::call_arg_t > call_args;
647
+ call_args.reserve (al, 2 );
648
+ ASR::call_arg_t call_arg1, call_arg2;
649
+ call_arg1.loc = x.base .base .loc ;
650
+ call_arg1.m_value = x.m_target ;
651
+ call_arg2.loc = x.base .base .loc ;
652
+ call_arg2.m_value = value;
653
+ call_args.push_back (al, call_arg1);
654
+ call_args.push_back (al, call_arg2);
655
+ ASR::stmt_t * stmt = ASRUtils::STMT (ASR::make_SubroutineCall_t (al, x.base .base .loc , integer_set_sym,
656
+ integer_set_sym, call_args.p , call_args.n , nullptr ));
657
+ pass_result.push_back (al, stmt);
658
+ } else if (ASR::is_a<ASR::IntrinsicScalarFunction_t>(*cast_value)) {
625
659
ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast<ASR::IntrinsicScalarFunction_t>(cast_value);
626
660
int64_t intrinsic_id = intrinsic_func->m_intrinsic_id ;
627
661
if (static_cast <LCompilers::ASRUtils::IntrinsicScalarFunctions>(intrinsic_id) ==
@@ -668,7 +702,7 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
668
702
ASR::expr_t * val = x.m_values [i];
669
703
if (ASR::is_a<ASR::Var_t>(*val) && ASR::is_a<ASR::CPtr_t>(*ASRUtils::expr_type (val))) {
670
704
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(val)->m_v ;
671
- if (symbolic_vars .find (v) == symbolic_vars .end ()) return ;
705
+ if (symbolic_vars_to_free .find (v) == symbolic_vars_to_free .end ()) return ;
672
706
ASR::symbol_t * basic_str_sym = declare_basic_str_function (al, x.base .base .loc , module_scope);
673
707
674
708
// Extract the symbol from value (Var)
0 commit comments