Skip to content

Commit a6b9256

Browse files
authored
Merge pull request #2477 from anutosh491/fix_symbolic_list
Fixing Symbolic List assignment
2 parents 467081e + aef00c0 commit a6b9256

File tree

4 files changed

+123
-24
lines changed

4 files changed

+123
-24
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,9 +718,10 @@ RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST)
718718
RUN(NAME symbolics_11 LABELS cpython_sym c_sym llvm_sym NOFAST)
719719
RUN(NAME symbolics_12 LABELS cpython_sym c_sym llvm_sym NOFAST)
720720
RUN(NAME symbolics_13 LABELS cpython_sym c_sym llvm_sym NOFAST)
721-
RUN(NAME symbolics_14 LABELS cpython_sym llvm_sym NOFAST)
721+
RUN(NAME symbolics_14 LABELS cpython_sym c_sym llvm_sym NOFAST)
722722
RUN(NAME test_gruntz LABELS cpython_sym c_sym llvm_sym NOFAST)
723723
RUN(NAME symbolics_15 LABELS c_sym llvm_sym NOFAST)
724+
RUN(NAME symbolics_16 LABELS cpython_sym c_sym llvm_sym NOFAST)
724725

725726
RUN(NAME sizeof_01 LABELS llvm c
726727
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_15.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,12 @@ def mmrv(r: Out[list[CPtr]]) -> None:
3333
basic_new_stack(x)
3434
basic_const_pi(x)
3535

36-
# l1: list[S]
36+
# l1: list[S] = [x]
37+
_l1: list[CPtr] = [x]
3738
l1: list[CPtr] = []
3839

39-
# l1 = [x]
4040
i: i32 = 0
41-
Len: i32 = 1
42-
for i in range(Len):
41+
for i in range(len(_l1)):
4342
tmp: CPtr = basic_new_heap()
4443
l1.append(tmp)
4544
basic_assign(l1[0], x)
@@ -57,8 +56,8 @@ def mmrv(r: Out[list[CPtr]]) -> None:
5756
def test_mrv():
5857
# ans : list[S]
5958
# temp : list[S]
60-
ans: list[CPtr] = []
61-
temp: list[CPtr] = []
59+
ans: list[CPtr]
60+
temp: list[CPtr]
6261

6362
# mmrv(ans)
6463
# temp = ans

integration_tests/symbolics_16.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from lpython import S
2+
from sympy import Symbol, pi, sin
3+
4+
def mmrv() -> list[S]:
5+
x: S = Symbol('x')
6+
l1: list[S] = [pi, sin(x)]
7+
return l1
8+
9+
def test_mrv1():
10+
ans: list[S] = mmrv()
11+
element_1: S = ans[0]
12+
element_2: S = ans[1]
13+
assert element_1 == pi
14+
assert element_2 == sin(Symbol('x'))
15+
print(element_1, element_2)
16+
17+
18+
test_mrv1()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 98 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,14 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
147147
return SubroutineCall(loc, basic_free_stack_sym, {x});
148148
}
149149

150+
ASR::expr_t *basic_new_heap(const Location& loc) {
151+
ASR::symbol_t* basic_new_heap_sym = create_bindc_function(loc,
152+
"basic_new_heap", {}, ASRUtils::TYPE((ASR::make_CPtr_t(al, loc))));
153+
Vec<ASR::call_arg_t> call_args; call_args.reserve(al, 1);
154+
return FunctionCall(loc, basic_new_heap_sym, {},
155+
ASRUtils::TYPE(ASR::make_CPtr_t(al, loc)));
156+
}
157+
150158
ASR::stmt_t* basic_get_args(const Location& loc, ASR::expr_t *x, ASR::expr_t *y) {
151159
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_CPtr_t(al, loc));
152160
ASR::symbol_t* basic_get_args_sym = create_bindc_function(loc,
@@ -323,8 +331,8 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
323331
std::string var_name = xx.m_name;
324332
std::string placeholder = "_" + std::string(var_name);
325333

326-
ASR::ttype_t *type1 = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
327-
xx.m_type = type1;
334+
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
335+
xx.m_type = CPtr_type;
328336
if (var_name != "_lpython_return_variable" && xx.m_intent != ASR::intentType::Out) {
329337
symbolic_vars_to_free.insert(ASR::down_cast<ASR::symbol_t>((ASR::asr_t*)&xx));
330338
}
@@ -357,13 +365,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
357365
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, xx.base.base.loc, 0, type2))));
358366

359367
// statement 2
360-
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, type1));
368+
ASR::expr_t* value2 = ASRUtils::EXPR(ASR::make_PointerNullConstant_t(al, xx.base.base.loc, CPtr_type));
361369

362370
// statement 3
363371
ASR::expr_t* get_pointer_node = ASRUtils::EXPR(ASR::make_GetPointer_t(al, xx.base.base.loc,
364372
target1, ASRUtils::TYPE(ASR::make_Pointer_t(al, xx.base.base.loc, type2)), nullptr));
365373
ASR::expr_t* value3 = ASRUtils::EXPR(ASR::make_PointerToCPtr_t(al, xx.base.base.loc, get_pointer_node,
366-
type1, nullptr));
374+
CPtr_type, nullptr));
367375

368376
// defining the assignment statement
369377
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, xx.base.base.loc, target1, value1, nullptr));
@@ -548,21 +556,94 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
548556
ASR::ListConstant_t* list_constant = ASR::down_cast<ASR::ListConstant_t>(x.m_value);
549557
if (list_constant->m_type->type == ASR::ttypeType::List) {
550558
ASR::List_t* list = ASR::down_cast<ASR::List_t>(list_constant->m_type);
551-
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
552-
Vec<ASR::expr_t*> temp_list;
553-
temp_list.reserve(al, list_constant->n_args + 1);
554559

555-
for (size_t i = 0; i < list_constant->n_args; ++i) {
556-
ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]);
557-
temp_list.push_back(al, value);
560+
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
561+
if(ASR::is_a<ASR::Var_t>(*x.m_target)) {
562+
ASR::symbol_t *v = ASR::down_cast<ASR::Var_t>(x.m_target)->m_v;
563+
if (ASR::is_a<ASR::Variable_t>(*v)) {
564+
// Step1: Add the placeholder for the list variable to the scope
565+
ASRUtils::ASRBuilder b(al, x.base.base.loc);
566+
ASR::ttype_t* CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
567+
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, CPtr_type));
568+
ASR::Variable_t *list_variable = ASR::down_cast<ASR::Variable_t>(v);
569+
std::string list_name = list_variable->m_name;
570+
std::string placeholder = "_" + std::string(list_name);
571+
572+
ASR::symbol_t* placeholder_sym = ASR::down_cast<ASR::symbol_t>(
573+
ASR::make_Variable_t(al, list_variable->base.base.loc, current_scope,
574+
s2c(al, placeholder), nullptr, 0,
575+
list_variable->m_intent, nullptr,
576+
nullptr, list_variable->m_storage,
577+
list_type, nullptr, list_variable->m_abi,
578+
list_variable->m_access, list_variable->m_presence,
579+
list_variable->m_value_attr));
580+
581+
current_scope->add_symbol(s2c(al, placeholder), placeholder_sym);
582+
ASR::expr_t* placeholder_target = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, placeholder_sym));
583+
584+
Vec<ASR::expr_t*> temp_list1, temp_list2;
585+
temp_list1.reserve(al, list_constant->n_args + 1);
586+
temp_list2.reserve(al, list_constant->n_args + 1);
587+
588+
for (size_t i = 0; i < list_constant->n_args; ++i) {
589+
ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]);
590+
temp_list1.push_back(al, value);
591+
}
592+
593+
ASR::expr_t* temp_list_const1 = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list1.p,
594+
temp_list1.size(), list_type));
595+
ASR::stmt_t* stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, placeholder_target, temp_list_const1, nullptr));
596+
pass_result.push_back(al, stmt1);
597+
598+
// Step2: Add the empty list variable
599+
ASR::expr_t* temp_list_const2 = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list2.p,
600+
temp_list2.size(), list_type));
601+
ASR::stmt_t* stmt2 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const2, nullptr));
602+
pass_result.push_back(al, stmt2);
603+
604+
// Step3: Add the list index to the function scope
605+
std::string symbolic_list_index = current_scope->get_unique_name("symbolic_list_index");
606+
ASR::ttype_t* int32_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4));
607+
ASR::symbol_t* index_sym = ASR::down_cast<ASR::symbol_t>(
608+
ASR::make_Variable_t(al, x.base.base.loc, current_scope, s2c(al, symbolic_list_index),
609+
nullptr, 0, ASR::intentType::Local, nullptr, nullptr, ASR::storage_typeType::Default,
610+
int32_type, nullptr, ASR::abiType::Source, ASR::Public, ASR::presenceType::Required, false));
611+
current_scope->add_symbol(symbolic_list_index, index_sym);
612+
ASR::expr_t* index = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, index_sym));
613+
ASR::stmt_t* stmt3 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, index,
614+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int32_type)), nullptr));
615+
pass_result.push_back(al, stmt3);
616+
617+
// Step4: Add the DoLoop for appending elements into the list
618+
std::string block_name = current_scope->get_unique_name("block");
619+
SymbolTable* block_symtab = al.make_new<SymbolTable>(current_scope);
620+
char *tmp_var_name = s2c(al, "tmp");
621+
ASR::expr_t* tmp_var = b.Variable(block_symtab, tmp_var_name, CPtr_type,
622+
ASR::intentType::Local, ASR::abiType::Source, false);
623+
Vec<ASR::stmt_t*> block_body; block_body.reserve(al, 1);
624+
ASR::stmt_t* block_stmt1 = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, tmp_var,
625+
basic_new_heap(x.base.base.loc), nullptr));
626+
block_body.push_back(al, block_stmt1);
627+
ASR::stmt_t* block_stmt2 = ASRUtils::STMT(ASR::make_ListAppend_t(al, x.base.base.loc, x.m_target, tmp_var));
628+
block_body.push_back(al, block_stmt2);
629+
block_body.push_back(al, basic_assign(x.base.base.loc, ASRUtils::EXPR(ASR::make_ListItem_t(al,
630+
x.base.base.loc, x.m_target, index, CPtr_type, nullptr)), ASRUtils::EXPR(ASR::make_ListItem_t(al,
631+
x.base.base.loc, placeholder_target, index, CPtr_type, nullptr))));
632+
ASR::symbol_t* block = ASR::down_cast<ASR::symbol_t>(ASR::make_Block_t(al, x.base.base.loc,
633+
block_symtab, s2c(al, block_name), block_body.p, block_body.n));
634+
current_scope->add_symbol(block_name, block);
635+
ASR::stmt_t* block_call = ASRUtils::STMT(ASR::make_BlockCall_t(
636+
al, x.base.base.loc, -1, block));
637+
std::vector<ASR::stmt_t*> do_loop_body;
638+
do_loop_body.push_back(block_call);
639+
ASR::stmt_t* stmt4 = b.DoLoop(index, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int32_type)),
640+
ASRUtils::EXPR(ASR::make_IntegerBinOp_t(al, x.base.base.loc,
641+
ASRUtils::EXPR(ASR::make_ListLen_t(al, x.base.base.loc, placeholder_target, int32_type, nullptr)), ASR::binopType::Sub,
642+
ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)), int32_type, nullptr)),
643+
do_loop_body, ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 1, int32_type)));
644+
pass_result.push_back(al, stmt4);
645+
}
558646
}
559-
560-
ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
561-
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, type));
562-
ASR::expr_t* temp_list_const = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list.p,
563-
temp_list.size(), list_type));
564-
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const, nullptr));
565-
pass_result.push_back(al, stmt);
566647
}
567648
}
568649
} else if (ASR::is_a<ASR::ListItem_t>(*x.m_value)) {

0 commit comments

Comments
 (0)