Skip to content

Commit 3927b92

Browse files
authored
Separate chaining collision resolution for set (#2198)
1 parent 0eb1482 commit 3927b92

File tree

4 files changed

+1042
-102
lines changed

4 files changed

+1042
-102
lines changed

integration_tests/test_set_len.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
def test_set():
44
s: set[i32]
55
s = {1, 2, 22, 2, -1, 1}
6-
assert len(s) == 4
6+
s2: set[str]
7+
s2 = {'a', 'b', 'cd', 'b', 'abc', 'a'}
8+
assert len(s2) == 4
79

8-
test_set()
10+
test_set()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
175175
std::unique_ptr<LLVMTuple> tuple_api;
176176
std::unique_ptr<LLVMDictInterface> dict_api_lp;
177177
std::unique_ptr<LLVMDictInterface> dict_api_sc;
178-
std::unique_ptr<LLVMSetInterface> set_api; // linear probing
178+
std::unique_ptr<LLVMSetInterface> set_api_lp;
179+
std::unique_ptr<LLVMSetInterface> set_api_sc;
179180
std::unique_ptr<LLVMArrUtils::Descriptor> arr_descr;
180181

181182
ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile,
@@ -200,18 +201,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
200201
tuple_api(std::make_unique<LLVMTuple>(context, llvm_utils.get(), builder.get())),
201202
dict_api_lp(std::make_unique<LLVMDictOptimizedLinearProbing>(context, llvm_utils.get(), builder.get())),
202203
dict_api_sc(std::make_unique<LLVMDictSeparateChaining>(context, llvm_utils.get(), builder.get())),
203-
set_api(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
204+
set_api_lp(std::make_unique<LLVMSetLinearProbing>(context, llvm_utils.get(), builder.get())),
205+
set_api_sc(std::make_unique<LLVMSetSeparateChaining>(context, llvm_utils.get(), builder.get())),
204206
arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context,
205207
builder.get(), llvm_utils.get(),
206208
LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor))
207209
{
208210
llvm_utils->tuple_api = tuple_api.get();
209211
llvm_utils->list_api = list_api.get();
210212
llvm_utils->dict_api = nullptr;
211-
llvm_utils->set_api = set_api.get();
213+
llvm_utils->set_api = nullptr;
212214
llvm_utils->arr_api = arr_descr.get();
213215
llvm_utils->dict_api_lp = dict_api_lp.get();
214216
llvm_utils->dict_api_sc = dict_api_sc.get();
217+
llvm_utils->set_api_lp = set_api_lp.get();
218+
llvm_utils->set_api_sc = set_api_sc.get();
215219
}
216220

217221
llvm::Value* CreateLoad(llvm::Value *x) {
@@ -1152,12 +1156,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
11521156
llvm::Type* const_set_type = llvm_utils->get_set_type(x.m_type, module.get());
11531157
llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set");
11541158
ASR::Set_t* x_set = ASR::down_cast<ASR::Set_t>(x.m_type);
1159+
llvm_utils->set_set_api(x_set);
11551160
std::string el_type_code = ASRUtils::get_type_code(x_set->m_type);
11561161
llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements);
11571162
int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type);
11581163
int64_t ptr_loads_copy = ptr_loads;
1164+
ptr_loads = ptr_loads_el;
11591165
for( size_t i = 0; i < x.n_elements; i++ ) {
1160-
ptr_loads = ptr_loads_el;
11611166
visit_expr_wrapper(x.m_elements[i], true);
11621167
llvm::Value* element = tmp;
11631168
llvm_utils->set_api->write_item(const_set, element, module.get(),
@@ -1516,6 +1521,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15161521
this->visit_expr(*x.m_arg);
15171522
ptr_loads = ptr_loads_copy;
15181523
llvm::Value* pset = tmp;
1524+
ASR::Set_t* x_set = ASR::down_cast<ASR::Set_t>(ASRUtils::expr_type(x.m_arg));
1525+
llvm_utils->set_set_api(x_set);
15191526
tmp = llvm_utils->set_api->len(pset);
15201527
}
15211528

@@ -1724,6 +1731,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17241731
}
17251732

17261733
void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
1734+
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
1735+
ASRUtils::expr_type(m_arg));
17271736
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
17281737
int64_t ptr_loads_copy = ptr_loads;
17291738
ptr_loads = 0;
@@ -1734,10 +1743,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17341743
this->visit_expr_wrapper(m_ele, true);
17351744
ptr_loads = ptr_loads_copy;
17361745
llvm::Value *el = tmp;
1737-
set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
1746+
llvm_utils->set_set_api(set_type);
1747+
llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
17381748
}
17391749

17401750
void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
1751+
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
1752+
ASRUtils::expr_type(m_arg));
17411753
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
17421754
int64_t ptr_loads_copy = ptr_loads;
17431755
ptr_loads = 0;
@@ -1748,7 +1760,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
17481760
this->visit_expr_wrapper(m_ele, true);
17491761
ptr_loads = ptr_loads_copy;
17501762
llvm::Value *el = tmp;
1751-
set_api->remove_item(pset, el, *module, asr_el_type);
1763+
llvm_utils->set_set_api(set_type);
1764+
llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type);
17521765
}
17531766

17541767
void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) {
@@ -2773,6 +2786,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
27732786
bool is_dict_present_copy_sc = dict_api_sc->is_dict_present();
27742787
dict_api_lp->set_is_dict_present(false);
27752788
dict_api_sc->set_is_dict_present(false);
2789+
bool is_set_present_copy_lp = set_api_lp->is_set_present();
2790+
bool is_set_present_copy_sc = set_api_sc->is_set_present();
2791+
set_api_lp->set_is_set_present(false);
2792+
set_api_sc->set_is_set_present(false);
27762793
llvm_goto_targets.clear();
27772794
// Generate code for nested subroutines and functions first:
27782795
for (auto &item : x.m_symtab->get_scope()) {
@@ -2832,6 +2849,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
28322849
builder->CreateRet(ret_val2);
28332850
dict_api_lp->set_is_dict_present(is_dict_present_copy_lp);
28342851
dict_api_sc->set_is_dict_present(is_dict_present_copy_sc);
2852+
set_api_lp->set_is_set_present(is_set_present_copy_lp);
2853+
set_api_sc->set_is_set_present(is_set_present_copy_sc);
28352854

28362855
// Finalize the debug info.
28372856
if (compiler_options.emit_debug_info) DBuilder->finalize();
@@ -3323,6 +3342,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
33233342
bool is_dict_present_copy_sc = dict_api_sc->is_dict_present();
33243343
dict_api_lp->set_is_dict_present(false);
33253344
dict_api_sc->set_is_dict_present(false);
3345+
bool is_set_present_copy_lp = set_api_lp->is_set_present();
3346+
bool is_set_present_copy_sc = set_api_sc->is_set_present();
3347+
set_api_lp->set_is_set_present(false);
3348+
set_api_sc->set_is_set_present(false);
33263349
llvm_goto_targets.clear();
33273350
instantiate_function(x);
33283351
if (ASRUtils::get_FunctionType(x)->m_deftype == ASR::deftypeType::Interface) {
@@ -3335,6 +3358,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
33353358
parent_function = nullptr;
33363359
dict_api_lp->set_is_dict_present(is_dict_present_copy_lp);
33373360
dict_api_sc->set_is_dict_present(is_dict_present_copy_sc);
3361+
set_api_lp->set_is_set_present(is_set_present_copy_lp);
3362+
set_api_sc->set_is_set_present(is_set_present_copy_sc);
33383363

33393364
// Finalize the debug info.
33403365
if (compiler_options.emit_debug_info) DBuilder->finalize();
@@ -4187,6 +4212,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
41874212
llvm::Value* target_set = tmp;
41884213
ptr_loads = ptr_loads_copy;
41894214
ASR::Set_t* value_set_type = ASR::down_cast<ASR::Set_t>(asr_value_type);
4215+
llvm_utils->set_set_api(value_set_type);
41904216
llvm_utils->set_api->set_deepcopy(value_set, target_set,
41914217
value_set_type, module.get(), name2memidx);
41924218
return ;

0 commit comments

Comments
 (0)