Skip to content

Commit 2ca57af

Browse files
committed
Add looping over dictionaries and sets
1 parent 8773842 commit 2ca57af

File tree

6 files changed

+294
-10
lines changed

6 files changed

+294
-10
lines changed

integration_tests/loop_12.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
# Integration test: `for` loops that iterate directly over dictionaries and
# sets. Iteration order over these containers is not relied upon: the string
# tests accept either order, and the integer tests only check order-independent
# sums.
#
# NOTE(review): `i32` is not imported in this file. CPython does not evaluate
# annotations of local variables, so the functions still run there; confirm
# whether the test harness expects `from lpython import i32`.


def test_for_dict_int():
    # Iterating a dict yields its keys; sum the keys and the mapped values.
    dict_int: dict[i32, i32] = {1:2, 2:3, 3:4}
    key: i32
    s1: i32 = 0
    s2: i32 = 0

    for key in dict_int:
        print(key)
        s1 += key
        s2 += dict_int[key]

    assert s1 == 6
    assert s2 == 9


def test_for_dict_str():
    # Concatenate keys and values; both visiting orders are accepted, but the
    # keys and values must have been visited in the same order.
    dict_str: dict[str, str] = {"a":"b", "c":"d"}
    key: str
    s1: str = ""
    s2: str = ""

    for key in dict_str:
        print(key)
        s1 += key
        s2 += dict_str[key]

    assert (s1 == "ac" or s1 == "ca")
    assert ((s1 == "ac" and s2 == "bd") or (s1 == "ca" and s2 == "db"))


def test_for_set_int():
    # Every element must be visited exactly once for the sum to be 6.
    set_int: set[i32] = {1, 2, 3}
    el: i32
    s: i32 = 0

    for el in set_int:
        print(el)
        s += el

    assert s == 6


def test_for_set_str():
    # Concatenate the elements; either visiting order is accepted.
    set_str: set[str] = {'a', 'b'}
    el: str
    s: str = ""

    for el in set_str:
        print(el)
        s += el

    assert (s == "ab" or s == "ba")


# Run the tests. Without these calls the file only defines the functions and
# none of the assertions above is ever executed.
test_for_dict_int()
test_for_dict_str()
test_for_set_int()
test_for_set_str()
51+
52+

src/libasr/ASR.asdl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ stmt
4040
| ErrorStop(expr? code)
4141
| Exit(identifier? stmt_name)
4242
| ForAllSingle(do_loop_head head, stmt assign_stmt)
43+
| ForEach(expr var, expr container, stmt* body)
4344
| GoTo(int target_id, identifier name)
4445
| GoToTarget(int id, identifier name)
4546
| If(expr test, stmt* body, stmt* orelse)

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5775,6 +5775,189 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
57755775
strings_to_be_deallocated.p = strings_to_be_deallocated_copy;
57765776
}
57775777

5778+
// Lowers an ASR `ForEach(var, container, body)` statement to LLVM IR.
//
// Supported containers are dictionaries (the loop variable receives each key)
// and sets (the loop variable receives each element). Any other container
// type raises CodeGenError.
//
// Three code paths are generated, all driven by an i32 slot index that runs
// from 0 up to the container's capacity:
//   * dict, when the active dict API is `dict_api_sc`: for every slot whose
//     key-mask byte equals 1, walk the chain of key-value structs starting at
//     that slot (field 0 is stored into the loop variable, field 2 is loaded
//     as the pointer to the next struct) until a null pointer is reached.
//   * dict, otherwise: for every slot whose key-mask byte is neither 0 nor 3,
//     read the key at that index from the key list.
//   * set: same as the previous case, using the set's mask and element list.
//
// NOTE(review): nothing in this function selects `llvm_utils->dict_api` /
// `llvm_utils->set_api` for the container's type; it relies on whatever API
// is currently active. Confirm that this is guaranteed by the caller.
// NOTE(review): the allocas below are emitted at the current insertion point.
// If this statement is itself nested in a loop they would execute on every
// iteration; confirm whether an entry-block alloca should be used instead.
// NOTE(review): the slot index is incremented at the end of the loop body. If
// a `continue` inside the body branches to the loop head (this depends on
// create_loop, which is not visible here), the increment would be skipped.
void visit_ForEach(const ASR::ForEach_t &x) {
    // Remember the current string-deallocation list so it can be restored
    // once the loop has been emitted.
    llvm::Value **strings_to_be_deallocated_copy = strings_to_be_deallocated.p;
    size_t n = strings_to_be_deallocated.n;
    strings_to_be_deallocated.reserve(al, 1);

    // Visit the container and the loop variable with ptr_loads == 0. The
    // results are used as addresses below (pvar is a store target), so
    // presumably this keeps visit_expr from loading through the pointer.
    int64_t ptr_loads_copy = ptr_loads;
    ptr_loads = 0;
    this->visit_expr(*x.m_container);
    llvm::Value *pcontainer = tmp;
    ptr_loads = 0;
    this->visit_expr(*x.m_var);
    llvm::Value *pvar = tmp;
    ptr_loads = ptr_loads_copy;

    if (ASR::is_a<ASR::Dict_t>(*ASRUtils::expr_type(x.m_container))) {
        ASR::Dict_t *dict_type = ASR::down_cast<ASR::Dict_t>(
            ASRUtils::expr_type(x.m_container));
        ASR::ttype_t *key_type = dict_type->m_key_type;
        llvm::Value *capacity = LLVM::CreateLoad(*builder,
            llvm_utils->dict_api->get_pointer_to_capacity(pcontainer));
        llvm::Value *key_mask = LLVM::CreateLoad(*builder,
            llvm_utils->dict_api->get_pointer_to_keymask(pcontainer));
        // Only read by the non-separate-chaining branch below.
        llvm::Value *key_list = llvm_utils->dict_api->get_key_list(pcontainer);
        // Slot index, starting at 0.
        llvm::AllocaInst *idx_ptr = builder->CreateAlloca(
            llvm::Type::getInt32Ty(context), nullptr);
        LLVM::CreateStore(*builder, llvm::ConstantInt::get(
            llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr);

        if (llvm_utils->dict_api == llvm_utils->dict_api_sc) {
            llvm::Value *key_value_pairs = LLVM::CreateLoad(*builder,
                llvm_utils->dict_api->get_pointer_to_key_value_pairs(pcontainer));
            llvm::Type* kv_pair_type =
                llvm_utils->dict_api->get_key_value_pair_type(key_type, dict_type->m_value_type);
            // Cursor into the current chain, kept as an i8* and cast back to
            // the key-value struct pointer type when dereferenced.
            llvm::AllocaInst *chain_itr = builder->CreateAlloca(
                llvm::Type::getInt8PtrTy(context), nullptr);

            create_loop(nullptr, [=](){
                call_lcompilers_free_strings();
                return builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr));
            }, [&](){
                llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr);
                llvm::Value* key_mask_value = LLVM::CreateLoad(*builder,
                    llvm_utils->create_ptr_gep(key_mask, idx));
                llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value,
                    llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));

                llvm_utils->create_if_else(is_key_set, [&]() {
                    // Point the cursor at the first struct of this slot.
                    llvm::Value* dict_i = llvm_utils->create_ptr_gep(key_value_pairs, idx);
                    llvm::Value* kv_ll_i8 = builder->CreateBitCast(dict_i, llvm::Type::getInt8PtrTy(context));
                    LLVM::CreateStore(*builder, kv_ll_i8, chain_itr);

                    llvm::BasicBlock *loop2head = llvm::BasicBlock::Create(context, "loop2.head");
                    llvm::BasicBlock *loop2body = llvm::BasicBlock::Create(context, "loop2.body");
                    llvm::BasicBlock *loop2end = llvm::BasicBlock::Create(context, "loop2.end");

                    // head: continue while the cursor is not null
                    llvm_utils->start_new_block(loop2head);
                    {
                        llvm::Value *cond = builder->CreateICmpNE(
                            LLVM::CreateLoad(*builder, chain_itr),
                            llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
                        );
                        builder->CreateCondBr(cond, loop2body, loop2end);
                    }

                    // body: bind the key, run the user body, advance the cursor
                    llvm_utils->start_new_block(loop2body);
                    {
                        llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
                        llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_pair_type->getPointerTo());
                        llvm::Value* kv_el = llvm_utils->create_gep(kv_struct, 0);
                        if( !LLVM::is_llvm_struct(key_type) ) {
                            kv_el = LLVM::CreateLoad(*builder, kv_el);
                        }
                        LLVM::CreateStore(*builder, kv_el, pvar);
                        for (size_t i=0; i<x.n_body; i++) {
                            this->visit_stmt(*x.m_body[i]);
                        }
                        call_lcompilers_free_strings();
                        llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2));
                        LLVM::CreateStore(*builder, next_kv_struct, chain_itr);
                    }

                    builder->CreateBr(loop2head);

                    // end
                    llvm_utils->start_new_block(loop2end);
                }, [=]() {
                });
                // Advance to the next slot. `idx` is reused, as in the other
                // two branches, instead of introducing a local named `tmp`
                // that would shadow the visitor's `tmp` member.
                idx = builder->CreateAdd(idx,
                    llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
                LLVM::CreateStore(*builder, idx, idx_ptr);

            });

        } else {
            create_loop(nullptr, [=](){
                call_lcompilers_free_strings();
                return builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr));
            }, [&](){
                llvm::Value *idx = LLVM::CreateLoad(*builder, idx_ptr);
                llvm::Value *key_mask_value = LLVM::CreateLoad(*builder,
                    llvm_utils->create_ptr_gep(key_mask, idx));
                // A slot holds a key when its mask byte is neither 0 nor 3.
                // (Presumably 0 means empty and 3 marks a removed entry;
                // confirm against the dict implementation.)
                llvm::Value *is_key_skip = builder->CreateICmpEQ(key_mask_value,
                    llvm::ConstantInt::get(llvm::Type::getInt8Ty(context),
                        llvm::APInt(8, 3)));
                llvm::Value *is_key_set = builder->CreateICmpNE(key_mask_value,
                    llvm::ConstantInt::get(llvm::Type::getInt8Ty(context),
                        llvm::APInt(8, 0)));

                llvm::Value *el_exists = builder->CreateAnd(is_key_set,
                    builder->CreateNot(is_key_skip));

                llvm_utils->create_if_else(el_exists, [&]() {
                    LLVM::CreateStore(*builder, llvm_utils->list_api->read_item(key_list, idx,
                        false, *module, LLVM::is_llvm_struct(key_type)), pvar);

                    for (size_t i=0; i<x.n_body; i++) {
                        this->visit_stmt(*x.m_body[i]);
                    }
                    call_lcompilers_free_strings();
                }, [=](){});

                // Advance to the next slot.
                idx = builder->CreateAdd(idx,
                    llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
                LLVM::CreateStore(*builder, idx, idx_ptr);
            });
        }
    } else if (ASR::is_a<ASR::Set_t>(*ASRUtils::expr_type(x.m_container))) {
        ASR::Set_t *set_type = ASR::down_cast<ASR::Set_t>(
            ASRUtils::expr_type(x.m_container));
        ASR::ttype_t *el_type = set_type->m_type;

        // Slot index, starting at 0.
        llvm::AllocaInst *idx_ptr = builder->CreateAlloca(
            llvm::Type::getInt32Ty(context), nullptr);
        LLVM::CreateStore(*builder, llvm::ConstantInt::get(
            llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr);

        llvm::Value *capacity = LLVM::CreateLoad(*builder,
            llvm_utils->set_api->get_pointer_to_capacity(pcontainer));
        llvm::Value *el_list = llvm_utils->set_api->get_el_list(pcontainer);
        llvm::Value *el_mask = LLVM::CreateLoad(*builder,
            llvm_utils->set_api->get_pointer_to_mask(pcontainer));

        create_loop(nullptr, [=](){
            call_lcompilers_free_strings();
            return builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr));
        }, [&](){
            llvm::Value *idx = LLVM::CreateLoad(*builder, idx_ptr);
            llvm::Value *el_mask_value = LLVM::CreateLoad(*builder,
                llvm_utils->create_ptr_gep(el_mask, idx));
            // Same occupancy test as the dict case: mask byte neither 0 nor 3.
            llvm::Value *is_el_skip = builder->CreateICmpEQ(el_mask_value,
                llvm::ConstantInt::get(llvm::Type::getInt8Ty(context),
                    llvm::APInt(8, 3)));
            llvm::Value *is_el_set = builder->CreateICmpNE(el_mask_value,
                llvm::ConstantInt::get(llvm::Type::getInt8Ty(context),
                    llvm::APInt(8, 0)));

            llvm::Value *el_exists = builder->CreateAnd(is_el_set,
                builder->CreateNot(is_el_skip));

            llvm_utils->create_if_else(el_exists, [&]() {
                LLVM::CreateStore(*builder, llvm_utils->list_api->read_item(el_list, idx,
                    false, *module, LLVM::is_llvm_struct(el_type)), pvar);

                for (size_t i=0; i<x.n_body; i++) {
                    this->visit_stmt(*x.m_body[i]);
                }
                call_lcompilers_free_strings();
            }, [=](){});

            // Advance to the next slot.
            idx = builder->CreateAdd(idx,
                llvm::ConstantInt::get(context, llvm::APInt(32, 1)));
            LLVM::CreateStore(*builder, idx, idx_ptr);
        });
    } else {
        throw CodeGenError("Only sets and dictionaries are supported with this loop for now.");
    }
    // Restore the string-deallocation list saved on entry.
    strings_to_be_deallocated.reserve(al, n);
    strings_to_be_deallocated.n = n;
    strings_to_be_deallocated.p = strings_to_be_deallocated_copy;
}
5960+
57785961
bool case_insensitive_string_compare(const std::string& str1, const std::string& str2) {
57795962
if (str1.size() != str2.size()) {
57805963
return false;

src/libasr/codegen/llvm_utils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,6 +2075,11 @@ namespace LCompilers {
20752075
return get_key_value_pair_type(key_type_code, value_type_code);
20762076
}
20772077

2078+
// Stub implementation of the interface method: this dictionary variant does
// not provide a key-value pair struct type, so both arguments are ignored and
// nullptr is always returned. Callers must not dereference the result; the
// ForEach lowering only requests this type when the separate-chaining
// dictionary API is the active one.
llvm::Type* LLVMDict::get_key_value_pair_type(ASR::ttype_t* /*key_asr_type*/,
                                              ASR::ttype_t* /*value_asr_type*/) {
    return nullptr;
}
2082+
20782083
llvm::Type* LLVMDictSeparateChaining::get_dict_type(
20792084
std::string key_type_code, std::string value_type_code,
20802085
int32_t key_type_size, int32_t value_type_size,
@@ -2156,6 +2161,10 @@ namespace LCompilers {
21562161
return llvm_utils->create_gep(dict, 1);
21572162
}
21582163

2164+
// Stub implementation of the interface method: the `dict` argument is ignored
// and nullptr is always returned. Callers must not use the result as an
// address; the ForEach lowering only calls this when the separate-chaining
// dictionary API is the active one.
llvm::Value* LLVMDict::get_pointer_to_key_value_pairs(llvm::Value* /*dict*/)
{
    return nullptr;
}
2167+
21592168
llvm::Value* LLVMDictSeparateChaining::get_pointer_to_key_value_pairs(llvm::Value* dict) {
21602169
return llvm_utils->create_gep(dict, 3);
21612170
}

src/libasr/codegen/llvm_utils.h

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ namespace LCompilers {
567567
virtual
568568
llvm::Value* get_pointer_to_occupancy(llvm::Value* dict) = 0;
569569

570+
virtual
571+
llvm::Value* get_pointer_to_keymask(llvm::Value* dict) = 0;
572+
570573
virtual
571574
llvm::Value* get_pointer_to_capacity(llvm::Value* dict) = 0;
572575

@@ -651,6 +654,13 @@ namespace LCompilers {
651654
std::map<std::string, std::map<std::string, int>>& name2memidx,
652655
bool key_or_value) = 0;
653656

657+
virtual
658+
llvm::Type* get_key_value_pair_type(ASR::ttype_t* key_asr_type, ASR::ttype_t* value_pair_type) = 0;
659+
660+
virtual
661+
llvm::Value* get_pointer_to_key_value_pairs(llvm::Value* dict) = 0;
662+
663+
654664
virtual ~LLVMDictInterface() = 0;
655665

656666
};
@@ -744,6 +754,10 @@ namespace LCompilers {
744754
std::map<std::string, std::map<std::string, int>>& name2memidx,
745755
bool key_or_value);
746756

757+
llvm::Type* get_key_value_pair_type(ASR::ttype_t* key_asr_type, ASR::ttype_t* value_pair_type);
758+
759+
llvm::Value* get_pointer_to_key_value_pairs(llvm::Value* dict);
760+
747761
virtual ~LLVMDict();
748762
};
749763

@@ -791,8 +805,6 @@ namespace LCompilers {
791805

792806
llvm::Value* get_pointer_to_number_of_filled_buckets(llvm::Value* dict);
793807

794-
llvm::Value* get_pointer_to_key_value_pairs(llvm::Value* dict);
795-
796808
llvm::Value* get_pointer_to_rehash_flag(llvm::Value* dict);
797809

798810
void deepcopy_key_value_pair_linked_list(llvm::Value* srci, llvm::Value* desti,
@@ -810,8 +822,6 @@ namespace LCompilers {
810822

811823
llvm::Type* get_key_value_pair_type(std::string key_type_code, std::string value_type_code);
812824

813-
llvm::Type* get_key_value_pair_type(ASR::ttype_t* key_asr_type, ASR::ttype_t* value_pair_type);
814-
815825
void dict_init_given_initial_capacity(std::string key_type_code, std::string value_type_code,
816826
llvm::Value* dict, llvm::Module* module, llvm::Value* initial_capacity);
817827

@@ -892,6 +902,10 @@ namespace LCompilers {
892902
std::map<std::string, std::map<std::string, int>>& name2memidx,
893903
bool key_or_value);
894904

905+
llvm::Type* get_key_value_pair_type(ASR::ttype_t* key_asr_type, ASR::ttype_t* value_pair_type);
906+
907+
llvm::Value* get_pointer_to_key_value_pairs(llvm::Value* dict);
908+
895909
virtual ~LLVMDictSeparateChaining();
896910

897911
};
@@ -939,6 +953,9 @@ namespace LCompilers {
939953
virtual
940954
llvm::Value* get_pointer_to_capacity(llvm::Value* set) = 0;
941955

956+
virtual
957+
llvm::Value* get_pointer_to_mask(llvm::Value* set) = 0;
958+
942959
llvm::Value* get_el_hash(llvm::Value* capacity, llvm::Value* el,
943960
ASR::ttype_t* el_asr_type, llvm::Module& module);
944961

0 commit comments

Comments
 (0)