Skip to content

Commit 9b5f52f

Browse files
authored
Adding Support for symbolics in the list data structure (#2368)
1 parent ac65227 commit 9b5f52f

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,7 @@ RUN(NAME symbolics_07 LABELS cpython_sym c_sym llvm_sym NOFAST)
714714
RUN(NAME symbolics_08 LABELS cpython_sym c_sym llvm_sym)
715715
RUN(NAME symbolics_09 LABELS cpython_sym c_sym llvm_sym NOFAST)
716716
RUN(NAME symbolics_10 LABELS cpython_sym c_sym llvm_sym NOFAST)
717+
RUN(NAME symbolics_11 LABELS cpython_sym c_sym NOFAST)
717718

718719
RUN(NAME sizeof_01 LABELS llvm c
719720
EXTRAFILES sizeof_01b.c)

integration_tests/symbolics_11.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from sympy import Symbol, sin, pi
2+
from lpython import S
3+
4+
def test_extraction_of_elements():
5+
x: S = Symbol("x")
6+
l1: list[S] = [x, pi, sin(x), Symbol("y")]
7+
ele1: S = l1[0]
8+
ele2: S = l1[1]
9+
ele3: S = l1[2]
10+
ele4: S = l1[3]
11+
12+
assert(ele1 == x)
13+
assert(ele2 == pi)
14+
assert(ele3 == sin(x))
15+
assert(ele4 == Symbol("y"))
16+
print(ele1, ele2, ele3, ele4)
17+
18+
test_extraction_of_elements()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,13 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
245245
pass_result.push_back(al, stmt3);
246246
pass_result.push_back(al, stmt4);
247247
}
248+
} else if (xx.m_type->type == ASR::ttypeType::List) {
249+
ASR::List_t* list = ASR::down_cast<ASR::List_t>(xx.m_type);
250+
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
251+
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, xx.base.base.loc));
252+
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, xx.base.base.loc, CPtr_type));
253+
xx.m_type = list_type;
254+
}
248255
}
249256
}
250257

@@ -920,6 +927,47 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
920927
}
921928
}
922929
}
930+
} else if (ASR::is_a<ASR::ListConstant_t>(*x.m_value)) {
931+
ASR::ListConstant_t* list_constant = ASR::down_cast<ASR::ListConstant_t>(x.m_value);
932+
if (list_constant->m_type->type == ASR::ttypeType::List) {
933+
ASR::List_t* list = ASR::down_cast<ASR::List_t>(list_constant->m_type);
934+
if (list->m_type->type == ASR::ttypeType::SymbolicExpression){
935+
Vec<ASR::expr_t*> temp_list;
936+
temp_list.reserve(al, list_constant->n_args + 1);
937+
938+
for (size_t i = 0; i < list_constant->n_args; ++i) {
939+
ASR::expr_t* value = handle_argument(al, x.base.base.loc, list_constant->m_args[i]);
940+
temp_list.push_back(al, value);
941+
}
942+
943+
ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
944+
ASR::ttype_t* list_type = ASRUtils::TYPE(ASR::make_List_t(al, x.base.base.loc, type));
945+
ASR::expr_t* temp_list_const = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, temp_list.p,
946+
temp_list.size(), list_type));
947+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, x.m_target, temp_list_const, nullptr));
948+
pass_result.push_back(al, stmt);
949+
}
950+
}
951+
} else if (ASR::is_a<ASR::ListItem_t>(*x.m_value)) {
952+
ASR::ListItem_t* list_item = ASR::down_cast<ASR::ListItem_t>(x.m_value);
953+
if (list_item->m_type->type == ASR::ttypeType::SymbolicExpression) {
954+
ASR::ttype_t *CPtr_type = ASRUtils::TYPE(ASR::make_CPtr_t(al, x.base.base.loc));
955+
ASR::symbol_t* basic_assign_sym = declare_basic_assign_function(al, x.base.base.loc, module_scope);
956+
957+
Vec<ASR::call_arg_t> call_args;
958+
call_args.reserve(al, 2);
959+
ASR::call_arg_t call_arg1, call_arg2;
960+
call_arg1.loc = x.base.base.loc;
961+
call_arg1.m_value = x.m_target;
962+
call_arg2.loc = x.base.base.loc;
963+
call_arg2.m_value = ASRUtils::EXPR(ASR::make_ListItem_t(al, x.base.base.loc, list_item->m_a,
964+
list_item->m_pos, CPtr_type, nullptr));
965+
call_args.push_back(al, call_arg1);
966+
call_args.push_back(al, call_arg2);
967+
ASR::stmt_t* stmt = ASRUtils::STMT(ASR::make_SubroutineCall_t(al, x.base.base.loc, basic_assign_sym,
968+
basic_assign_sym, call_args.p, call_args.n, nullptr));
969+
pass_result.push_back(al, stmt);
970+
}
923971
} else if (ASR::is_a<ASR::SymbolicCompare_t>(*x.m_value)) {
924972
ASR::SymbolicCompare_t *s = ASR::down_cast<ASR::SymbolicCompare_t>(x.m_value);
925973
if (s->m_op == ASR::cmpopType::Eq || s->m_op == ASR::cmpopType::NotEq) {

0 commit comments

Comments
 (0)