Skip to content

Commit b1001b6

Browse files
authored
Merge pull request #2225 from Shaikh-Ubaid/c_support_list_repeat
C: Support ListRepeat
2 parents 7481e00 + 1cb3eb0 commit b1001b6

File tree

5 files changed

+86
-3
lines changed

5 files changed

+86
-3
lines changed

integration_tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ RUN(NAME test_list_section2 LABELS cpython llvm c NOFAST)
504504
RUN(NAME test_list_count LABELS cpython llvm)
505505
RUN(NAME test_list_index LABELS cpython llvm)
506506
RUN(NAME test_list_index2 LABELS cpython llvm)
507-
RUN(NAME test_list_repeat LABELS cpython llvm NOFAST)
507+
RUN(NAME test_list_repeat LABELS cpython llvm c NOFAST)
508+
RUN(NAME test_list_repeat2 LABELS cpython llvm c NOFAST)
508509
RUN(NAME test_list_reverse LABELS cpython llvm)
509510
RUN(NAME test_list_pop LABELS cpython llvm NOFAST) # TODO: Remove NOFAST from here.
510511
RUN(NAME test_list_pop2 LABELS cpython llvm NOFAST) # TODO: Remove NOFAST from here.

integration_tests/test_list_repeat.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@ def test_list_repeat():
2828
l_str_3 = l_str_1 * i
2929
assert l_str_3 == l_str_2
3030
l_str_2 += l_str_1
31-
31+
3232
for i in range(5):
3333
assert l_int_1 * i + l_int_1 * (i + 1) == l_int_1 * (2 * i + 1)
3434
assert l_tuple_1 * i + l_tuple_1 * (i + 1) == l_tuple_1 * (2 * i + 1)
3535
assert l_str_1 * i + l_str_1 * (i + 1) == l_str_1 * (2 * i + 1)
3636

37-
test_list_repeat()
37+
print(l_int_1)
38+
print(l_tuple_1)
39+
print(l_tuple_1)
40+
41+
test_list_repeat()
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from lpython import i32, f32
2+
3+
def add_list(x: list[f32]) -> f32:
4+
sum: f32 = f32(0.0)
5+
i: i32
6+
7+
for i in range(len(x)):
8+
sum = sum + f32(x[i])
9+
return sum
10+
11+
def create_list(n: i32) -> list[f32]:
12+
x: list[f32]
13+
i: i32
14+
15+
x = [f32(0.0)] * n
16+
for i in range(n):
17+
x[i] = f32(i)
18+
return x
19+
20+
def main0():
21+
x: list[f32] = create_list(i32(10))
22+
print(add_list(x))
23+
24+
main0()

src/libasr/codegen/asr_to_c_cpp.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1676,6 +1676,20 @@ PyMODINIT_FUNC PyInit_lpython_module_)" + fn_name + R"((void) {
16761676
src += indent + list_remove_func + "(&" + list_var + ", " + element + ");\n";
16771677
}
16781678

1679+
void visit_ListRepeat(const ASR::ListRepeat_t& x) {
1680+
CHECK_FAST_C_CPP(compiler_options, x)
1681+
ASR::List_t* t = ASR::down_cast<ASR::List_t>(x.m_type);
1682+
std::string list_repeat_func = c_ds_api->get_list_repeat_func(t);
1683+
bracket_open++;
1684+
self().visit_expr(*x.m_left);
1685+
std::string list_var = std::move(src);
1686+
self().visit_expr(*x.m_right);
1687+
std::string freq = std::move(src);
1688+
bracket_open--;
1689+
tmp_buffer_src.push_back(check_tmp_buffer());
1690+
src = "(*" + list_repeat_func + "(&" + list_var + ", " + freq + "))";
1691+
}
1692+
16791693
void visit_ListLen(const ASR::ListLen_t& x) {
16801694
CHECK_FAST_C_CPP(compiler_options, x)
16811695
self().visit_expr(*x.m_arg);

src/libasr/codegen/c_utils.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ class CCPPDSUtils {
594594
list_remove(list_struct_type, list_type_code, list_element_type, list_type->m_type);
595595
list_clear(list_struct_type, list_type_code, list_element_type);
596596
list_concat(list_struct_type, list_type_code, list_element_type, list_type->m_type);
597+
list_repeat(list_struct_type, list_type_code, list_element_type, list_type->m_type);
597598
list_section(list_struct_type, list_type_code);
598599
return list_struct_type;
599600
}
@@ -652,6 +653,11 @@ class CCPPDSUtils {
652653
return typecodeToDSfuncs[list_type_code]["list_concat"];
653654
}
654655

656+
std::string get_list_repeat_func(ASR::List_t* list_type) {
657+
std::string list_type_code = ASRUtils::get_type_code(list_type->m_type, true);
658+
return typecodeToDSfuncs[list_type_code]["list_repeat"];
659+
}
660+
655661
std::string get_list_find_item_position_function(std::string list_type_code) {
656662
return typecodeToDSfuncs[list_type_code]["list_find_item"];
657663
}
@@ -934,6 +940,40 @@ class CCPPDSUtils {
934940
generated_code += indent + "}\n\n";
935941
}
936942

943+
void list_repeat(std::string list_struct_type,
944+
std::string list_type_code,
945+
std::string list_element_type, ASR::ttype_t *m_type) {
946+
std::string indent(indentation_level * indentation_spaces, ' ');
947+
std::string tab(indentation_spaces, ' ');
948+
std::string list_con_func = global_scope->get_unique_name("list_repeat_" + list_type_code);
949+
typecodeToDSfuncs[list_type_code]["list_repeat"] = list_con_func;
950+
std::string init_func = typecodeToDSfuncs[list_type_code]["list_init"];
951+
std::string signature = list_struct_type + "* " + list_con_func + "("
952+
+ list_struct_type + "* x, "
953+
+ "int32_t freq)";
954+
func_decls += "inline " + signature + ";\n";
955+
generated_code += indent + signature + " {\n";
956+
generated_code += indent + tab + list_struct_type + " *result = (" + list_struct_type + "*)malloc(sizeof(" +
957+
list_struct_type + "));\n";
958+
generated_code += indent + tab + init_func + "(result, x->current_end_point * freq);\n";
959+
generated_code += indent + tab + "for (int i=0; i<freq; i++) {\n";
960+
961+
if (ASR::is_a<ASR::List_t>(*m_type)) {
962+
ASR::ttype_t *tt = ASR::down_cast<ASR::List_t>(m_type)->m_type;
963+
std::string deep_copy_func = typecodeToDSfuncs[ASRUtils::get_type_code(tt, true)]["list_deepcopy"];
964+
LCOMPILERS_ASSERT(deep_copy_func.size() > 0);
965+
generated_code += indent + tab + tab + "for(int j=0; j<x->current_end_point; j++)\n";
966+
generated_code += indent + tab + tab + tab + deep_copy_func + "(&x->data[j], &result->data[i*x->current_end_point+j]);\n";
967+
} else {
968+
generated_code += indent + tab + tab + "memcpy(&result->data[i*x->current_end_point], x->data, x->current_end_point * sizeof(" + list_element_type + "));\n";
969+
}
970+
971+
generated_code += indent + tab + "}\n";
972+
generated_code += indent + tab + "result->current_end_point = x->current_end_point * freq;\n";
973+
generated_code += indent + tab + "return result;\n";
974+
generated_code += indent + "}\n\n";
975+
}
976+
937977
void resize_if_needed(std::string list_struct_type,
938978
std::string list_type_code,
939979
std::string list_element_type) {

0 commit comments

Comments
 (0)