diff --git a/integration_tests/test_set_len.py b/integration_tests/test_set_len.py index 33d252a0fe..8e66064dd3 100644 --- a/integration_tests/test_set_len.py +++ b/integration_tests/test_set_len.py @@ -3,6 +3,8 @@ def test_set(): s: set[i32] s = {1, 2, 22, 2, -1, 1} - assert len(s) == 4 + s2: set[str] + s2 = {'a', 'b', 'cd', 'b', 'abc', 'a'} + assert len(s2) == 4 -test_set() \ No newline at end of file +test_set() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 17a6705dde..cbcaaf2e19 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -175,7 +175,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::unique_ptr tuple_api; std::unique_ptr dict_api_lp; std::unique_ptr dict_api_sc; - std::unique_ptr set_api; // linear probing + std::unique_ptr set_api_lp; + std::unique_ptr set_api_sc; std::unique_ptr arr_descr; ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile, @@ -200,7 +201,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tuple_api(std::make_unique(context, llvm_utils.get(), builder.get())), dict_api_lp(std::make_unique(context, llvm_utils.get(), builder.get())), dict_api_sc(std::make_unique(context, llvm_utils.get(), builder.get())), - set_api(std::make_unique(context, llvm_utils.get(), builder.get())), + set_api_lp(std::make_unique(context, llvm_utils.get(), builder.get())), + set_api_sc(std::make_unique(context, llvm_utils.get(), builder.get())), arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context, builder.get(), llvm_utils.get(), LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor)) @@ -208,10 +210,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->tuple_api = tuple_api.get(); llvm_utils->list_api = list_api.get(); llvm_utils->dict_api = nullptr; - llvm_utils->set_api = set_api.get(); + llvm_utils->set_api = nullptr; llvm_utils->arr_api = arr_descr.get(); llvm_utils->dict_api_lp = dict_api_lp.get(); llvm_utils->dict_api_sc = 
dict_api_sc.get(); + llvm_utils->set_api_lp = set_api_lp.get(); + llvm_utils->set_api_sc = set_api_sc.get(); } llvm::Value* CreateLoad(llvm::Value *x) { @@ -1152,12 +1156,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Type* const_set_type = llvm_utils->get_set_type(x.m_type, module.get()); llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set"); ASR::Set_t* x_set = ASR::down_cast(x.m_type); + llvm_utils->set_set_api(x_set); std::string el_type_code = ASRUtils::get_type_code(x_set->m_type); llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements); int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type); int64_t ptr_loads_copy = ptr_loads; + ptr_loads = ptr_loads_el; for( size_t i = 0; i < x.n_elements; i++ ) { - ptr_loads = ptr_loads_el; visit_expr_wrapper(x.m_elements[i], true); llvm::Value* element = tmp; llvm_utils->set_api->write_item(const_set, element, module.get(), @@ -1516,6 +1521,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_arg); ptr_loads = ptr_loads_copy; llvm::Value* pset = tmp; + ASR::Set_t* x_set = ASR::down_cast(ASRUtils::expr_type(x.m_arg)); + llvm_utils->set_set_api(x_set); tmp = llvm_utils->set_api->len(pset); } @@ -1724,6 +1731,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + ASR::Set_t* set_type = ASR::down_cast( + ASRUtils::expr_type(m_arg)); ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; @@ -1734,10 +1743,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(m_ele, true); ptr_loads = ptr_loads_copy; llvm::Value *el = tmp; - set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx); + llvm_utils->set_set_api(set_type); + llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx); } void generate_SetRemove(ASR::expr_t* m_arg, 
ASR::expr_t* m_ele) { + ASR::Set_t* set_type = ASR::down_cast( + ASRUtils::expr_type(m_arg)); ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; @@ -1748,7 +1760,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(m_ele, true); ptr_loads = ptr_loads_copy; llvm::Value *el = tmp; - set_api->remove_item(pset, el, *module, asr_el_type); + llvm_utils->set_set_api(set_type); + llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type); } void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) { @@ -2773,6 +2786,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_dict_present_copy_sc = dict_api_sc->is_dict_present(); dict_api_lp->set_is_dict_present(false); dict_api_sc->set_is_dict_present(false); + bool is_set_present_copy_lp = set_api_lp->is_set_present(); + bool is_set_present_copy_sc = set_api_sc->is_set_present(); + set_api_lp->set_is_set_present(false); + set_api_sc->set_is_set_present(false); llvm_goto_targets.clear(); // Generate code for nested subroutines and functions first: for (auto &item : x.m_symtab->get_scope()) { @@ -2832,6 +2849,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor builder->CreateRet(ret_val2); dict_api_lp->set_is_dict_present(is_dict_present_copy_lp); dict_api_sc->set_is_dict_present(is_dict_present_copy_sc); + set_api_lp->set_is_set_present(is_set_present_copy_lp); + set_api_sc->set_is_set_present(is_set_present_copy_sc); // Finalize the debug info. 
if (compiler_options.emit_debug_info) DBuilder->finalize(); @@ -3323,6 +3342,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_dict_present_copy_sc = dict_api_sc->is_dict_present(); dict_api_lp->set_is_dict_present(false); dict_api_sc->set_is_dict_present(false); + bool is_set_present_copy_lp = set_api_lp->is_set_present(); + bool is_set_present_copy_sc = set_api_sc->is_set_present(); + set_api_lp->set_is_set_present(false); + set_api_sc->set_is_set_present(false); llvm_goto_targets.clear(); instantiate_function(x); if (ASRUtils::get_FunctionType(x)->m_deftype == ASR::deftypeType::Interface) { @@ -3335,6 +3358,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor parent_function = nullptr; dict_api_lp->set_is_dict_present(is_dict_present_copy_lp); dict_api_sc->set_is_dict_present(is_dict_present_copy_sc); + set_api_lp->set_is_set_present(is_set_present_copy_lp); + set_api_sc->set_is_set_present(is_set_present_copy_sc); // Finalize the debug info. if (compiler_options.emit_debug_info) DBuilder->finalize(); @@ -4187,6 +4212,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* target_set = tmp; ptr_loads = ptr_loads_copy; ASR::Set_t* value_set_type = ASR::down_cast(asr_value_type); + llvm_utils->set_set_api(value_set_type); llvm_utils->set_api->set_deepcopy(value_set, target_set, value_set_type, module.get(), name2memidx); return ; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 07c4649b83..7fe6d28bf6 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -120,7 +120,8 @@ namespace LCompilers { name2dertype(name2dertype_), name2dercontext(name2dercontext_), struct_type_stack(struct_type_stack_), dertype2parent(dertype2parent_), name2memidx(name2memidx_), arr_arg_type_cache(arr_arg_type_cache_), fname2arg_type(fname2arg_type_), - dict_api_lp(nullptr), dict_api_sc(nullptr), compiler_options(compiler_options_) { + dict_api_lp(nullptr), dict_api_sc(nullptr), + 
set_api_lp(nullptr), set_api_sc(nullptr), compiler_options(compiler_options_) { std::vector els_4 = { llvm::Type::getFloatTy(context), llvm::Type::getFloatTy(context)}; @@ -599,6 +600,7 @@ namespace LCompilers { local_a_kind, module); int32_t el_type_size = get_type_size(asr_set->m_type, el_llvm_type, local_a_kind, module); std::string el_type_code = ASRUtils::get_type_code(asr_set->m_type); + set_set_api(asr_set); return set_api->get_set_type(el_type_code, el_type_size, el_llvm_type); } @@ -850,6 +852,7 @@ namespace LCompilers { is_list, m_dims, n_dims, a_kind, module, m_abi); int32_t el_type_size = get_type_size(asr_set->m_type, el_llvm_type, a_kind, module); + set_set_api(asr_set); type = set_api->get_set_type(el_type_code, el_type_size, el_llvm_type)->getPointerTo(); break; } @@ -873,6 +876,13 @@ namespace LCompilers { } } + void LLVMUtils::set_set_api(ASR::Set_t* /*set_type*/) { + // As per benchmarks, separate chaining + // does not provide significant gains over + // linear probing. 
+ set_api = set_api_lp; + } + std::vector LLVMUtils::convert_args(const ASR::Function_t& x, llvm::Module* module) { std::vector args; for (size_t i=0; im_type, el_llvm_type, local_a_kind, module); + set_set_api(asr_set); + return_type = set_api->get_set_type(el_type_code, el_type_size, el_llvm_type); break; } @@ -1755,8 +1767,7 @@ namespace LCompilers { hash_value(nullptr), polynomial_powers(nullptr), chain_itr(nullptr), chain_itr_prev(nullptr), old_capacity(nullptr), old_key_value_pairs(nullptr), - old_key_mask(nullptr), are_iterators_set(false), - is_dict_present_(false) { + old_key_mask(nullptr), is_dict_present_(false) { } LLVMDict::LLVMDict(llvm::LLVMContext& context_, @@ -2577,6 +2588,20 @@ namespace LCompilers { llvm::Value* key, llvm::Value* key_value_pair_linked_list, llvm::Type* kv_pair_type, llvm::Value* key_mask, llvm::Module& module, ASR::ttype_t* key_asr_type) { + /** + * C++ equivalent: + * + * is_key_matching = 1; + * + * while( chain_itr != nullptr && is_key_matching ) { + * break_signal = key != kv_key; + * is_key_matching = break_signal; // 1 means not matching + * if( break_signal ) { + * chain_itr = next_kv_struct; + * } + * } + * + */ get_builder0() chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); chain_itr_prev = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); @@ -3787,9 +3812,8 @@ namespace LCompilers { llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict); ASR::ttype_t* el_asr_type = key_or_value == 0 ? 
key_asr_type : value_asr_type; - if( !are_iterators_set ) { - idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); - } + get_builder0(); + idx_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr); @@ -3841,10 +3865,9 @@ namespace LCompilers { ASR::ttype_t* value_asr_type, llvm::Module& module, std::map>& name2memidx, bool key_or_value) { - if( !are_iterators_set ) { - idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); - chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); - } + get_builder0() + idx_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr); @@ -4886,8 +4909,20 @@ namespace LCompilers { context(context_), llvm_utils(std::move(llvm_utils_)), builder(std::move(builder_)), - pos_ptr(nullptr), are_iterators_set(false), - is_set_present_(false) { + pos_ptr(nullptr), is_el_matching_var(nullptr), + idx_ptr(nullptr), hash_iter(nullptr), + hash_value(nullptr), polynomial_powers(nullptr), + chain_itr(nullptr), chain_itr_prev(nullptr), + old_capacity(nullptr), old_elems(nullptr), + old_el_mask(nullptr), is_set_present_(false) { + } + + bool LLVMSetInterface::is_set_present() { + return is_set_present_; + } + + void LLVMSetInterface::set_is_set_present(bool value) { + is_set_present_ = value; } LLVMSetLinearProbing::LLVMSetLinearProbing(llvm::LLVMContext& context_, @@ -4896,6 +4931,13 @@ namespace LCompilers { LLVMSetInterface(context_, llvm_utils_, builder_) { } + LLVMSetSeparateChaining::LLVMSetSeparateChaining( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils_, + llvm::IRBuilder<>* builder_): + LLVMSetInterface(context_, llvm_utils_, builder_) { + } 
+ LLVMSetInterface::~LLVMSetInterface() { typecode2settype.clear(); } @@ -4903,6 +4945,9 @@ namespace LCompilers { LLVMSetLinearProbing::~LLVMSetLinearProbing() { } + LLVMSetSeparateChaining::~LLVMSetSeparateChaining() { + } + llvm::Value* LLVMSetLinearProbing::get_pointer_to_occupancy(llvm::Value* set) { return llvm_utils->create_gep(set, 0); } @@ -4920,6 +4965,34 @@ namespace LCompilers { return llvm_utils->create_gep(set, 2); } + llvm::Value* LLVMSetSeparateChaining::get_el_list(llvm::Value* /*set*/) { + return nullptr; + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_occupancy(llvm::Value* set) { + return llvm_utils->create_gep(set, 0); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_number_of_filled_buckets(llvm::Value* set) { + return llvm_utils->create_gep(set, 1); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_capacity(llvm::Value* set) { + return llvm_utils->create_gep(set, 2); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_elems(llvm::Value* set) { + return llvm_utils->create_gep(set, 3); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_mask(llvm::Value* set) { + return llvm_utils->create_gep(set, 4); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_rehash_flag(llvm::Value* set) { + return llvm_utils->create_gep(set, 5); + } + llvm::Type* LLVMSetLinearProbing::get_set_type(std::string type_code, int32_t type_size, llvm::Type* el_type) { is_set_present_ = true; @@ -4937,6 +5010,27 @@ namespace LCompilers { return set_desc; } + llvm::Type* LLVMSetSeparateChaining::get_set_type( + std::string el_type_code, int32_t el_type_size, llvm::Type* el_type) { + is_set_present_ = true; + if( typecode2settype.find(el_type_code) != typecode2settype.end() ) { + return std::get<0>(typecode2settype[el_type_code]); + } + + std::vector el_vec = {el_type, llvm::Type::getInt8PtrTy(context)}; + llvm::Type* elstruct = llvm::StructType::create(context, el_vec, "el"); + std::vector set_type_vec = 
{llvm::Type::getInt32Ty(context), + llvm::Type::getInt32Ty(context), + llvm::Type::getInt32Ty(context), + elstruct->getPointerTo(), + llvm::Type::getInt8PtrTy(context), + llvm::Type::getInt1Ty(context)}; + llvm::Type* set_desc = llvm::StructType::create(context, set_type_vec, "set"); + typecode2settype[el_type_code] = std::make_tuple(set_desc, el_type_size, el_type); + typecode2elstruct[el_type_code] = elstruct; + return set_desc; + } + void LLVMSetLinearProbing::set_init(std::string type_code, llvm::Value* set, llvm::Module* module, size_t initial_capacity) { llvm::Value* n_ptr = get_pointer_to_occupancy(set); @@ -4956,6 +5050,57 @@ namespace LCompilers { LLVM::CreateStore(*builder, el_mask, get_pointer_to_mask(set)); } + void LLVMSetSeparateChaining::set_init( + std::string el_type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity) { + llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, initial_capacity)); + llvm::Value* rehash_flag_ptr = get_pointer_to_rehash_flag(set); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 1)), rehash_flag_ptr); + set_init_given_initial_capacity(el_type_code, set, module, llvm_capacity); + } + + void LLVMSetSeparateChaining::set_init_given_initial_capacity( + std::string el_type_code, llvm::Value* set, + llvm::Module* module, llvm::Value* llvm_capacity) { + llvm::Value* rehash_flag_ptr = get_pointer_to_rehash_flag(set); + llvm::Value* rehash_flag = LLVM::CreateLoad(*builder, rehash_flag_ptr); + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + LLVM::CreateStore(*builder, llvm_zero, occupancy_ptr); + llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + LLVM::CreateStore(*builder, llvm_zero, num_buckets_filled_ptr); + + llvm::DataLayout data_layout(module); + 
llvm::Type* el_type = typecode2elstruct[el_type_code]; + size_t el_type_size = data_layout.getTypeAllocSize(el_type); + llvm::Value* llvm_el_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, el_type_size)); + llvm::Value* malloc_size = builder->CreateMul(llvm_capacity, llvm_el_size); + llvm::Value* el_ptr = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + rehash_flag = builder->CreateAnd(rehash_flag, + builder->CreateICmpNE(el_ptr, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + el_ptr = builder->CreateBitCast(el_ptr, el_type->getPointerTo()); + LLVM::CreateStore(*builder, el_ptr, get_pointer_to_elems(set)); + + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* el_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_capacity, + llvm_mask_size); + rehash_flag = builder->CreateAnd(rehash_flag, + builder->CreateICmpNE(el_mask, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + LLVM::CreateStore(*builder, el_mask, get_pointer_to_mask(set)); + + llvm::Value* capacity_ptr = get_pointer_to_capacity(set); + LLVM::CreateStore(*builder, llvm_capacity, capacity_ptr); + LLVM::CreateStore(*builder, rehash_flag, rehash_flag_ptr); + } + llvm::Value* LLVMSetInterface::get_el_hash( llvm::Value* capacity, llvm::Value* el, ASR::ttype_t* el_asr_type, llvm::Module& module) { @@ -5175,6 +5320,96 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMSetSeparateChaining::resolve_collision( + llvm::Value* el_hash, llvm::Value* el, llvm::Value* el_linked_list, + llvm::Type* el_struct_type, llvm::Value* el_mask, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * ll_exists = el_mask_value == 1; + * if( ll_exists ) { + * chain_itr = ll_head; + * } + * else { + * 
chain_itr = nullptr; + * } + * is_el_matching = 0; + * + * while( chain_itr != nullptr && !is_el_matching ) { + * chain_itr_prev = chain_itr; + * is_el_matching = (el == el_struct_el); + * if( !is_el_matching ) { + * chain_itr = next_el_struct; // (*chain_itr)[1] + * } + * } + * + * // now, chain_itr either points to element or is nullptr + * + */ + + get_builder0() + chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + chain_itr_prev = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + is_el_matching_var = builder0.CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr_prev); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm_utils->create_if_else(builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))), [&]() { + llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, el_ll_i8, chain_itr); + }, [&]() { + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr); + }); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), + is_el_matching_var + ); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + cond = builder->CreateAnd(cond, builder->CreateNot(LLVM::CreateLoad( + *builder, is_el_matching_var))); + 
builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + LLVM::CreateStore(*builder, el_struct_i8, chain_itr_prev); + llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo()); + llvm::Value* el_struct_el = llvm_utils->create_gep(el_struct, 0); + if( !LLVM::is_llvm_struct(el_asr_type) ) { + el_struct_el = LLVM::CreateLoad(*builder, el_struct_el); + } + LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(el, el_struct_el, + module, el_asr_type), is_el_matching_var); + llvm_utils->create_if_else(builder->CreateNot(LLVM::CreateLoad(*builder, is_el_matching_var)), [&]() { + llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1)); + LLVM::CreateStore(*builder, next_el_struct, chain_itr); + }, []() {}); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + + } + void LLVMSetLinearProbing::resolve_collision_for_write( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, llvm::Module* module, ASR::ttype_t* el_asr_type, @@ -5229,6 +5464,109 @@ namespace LCompilers { LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(el_mask, pos)); } + void LLVMSetSeparateChaining::resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * el_linked_list = elems[el_hash]; + * resolve_collision(el); // modifies chain_itr + * do_insert = chain_itr == nullptr; + * + * if( do_insert ) { + * if( chain_itr_prev != nullptr ) { + * new_el_struct = malloc(el_struct_size); + * new_el_struct[0] = el; + * new_el_struct[1] = nullptr; + * chain_itr_prev[1] = new_el_struct; + * } + * else { + * el_linked_list[0] = el; + * el_linked_list[1] = nullptr; + * } + * occupancy += 1; + * } + * else { + * el_struct[0] = 
el; + * } + * + * buckets_filled_delta = el_mask[el_hash] == 0; + * buckets_filled += buckets_filled_delta; + * el_mask[el_hash] = 1; + * + */ + + llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]; + this->resolve_collision(el_hash, el, el_linked_list, el_struct_type, + el_mask, *module, el_asr_type); + llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* do_insert = builder->CreateICmpEQ(el_struct_i8, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))); + builder->CreateCondBr(do_insert, thenBB, elseBB); + + builder->SetInsertPoint(thenBB); + { + llvm_utils->create_if_else(builder->CreateICmpNE( + LLVM::CreateLoad(*builder, chain_itr_prev), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), [&]() { + llvm::DataLayout data_layout(module); + size_t el_struct_size = data_layout.getTypeAllocSize(el_struct_type); + llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), el_struct_size); + llvm::Value* new_el_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + llvm::Value* new_el_struct = builder->CreateBitCast(new_el_struct_i8, el_struct_type->getPointerTo()); + llvm_utils->deepcopy(el, llvm_utils->create_gep(new_el_struct, 0), el_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(new_el_struct, 1)); + 
llvm::Value* el_struct_prev_i8 = LLVM::CreateLoad(*builder, chain_itr_prev); + llvm::Value* el_struct_prev = builder->CreateBitCast(el_struct_prev_i8, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, new_el_struct_i8, llvm_utils->create_gep(el_struct_prev, 1)); + }, [&]() { + llvm_utils->deepcopy(el, llvm_utils->create_gep(el_linked_list, 0), el_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(el_linked_list, 1)); + }); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateAdd(occupancy, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo()); + llvm_utils->deepcopy(el, llvm_utils->create_gep(el_struct, 0), el_asr_type, module, name2memidx); + } + llvm_utils->start_new_block(mergeBB); + llvm::Value* buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + llvm::Value* el_mask_value_ptr = llvm_utils->create_ptr_gep(el_mask, el_hash); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, el_mask_value_ptr); + llvm::Value* buckets_filled_delta = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, buckets_filled_ptr); + buckets_filled = builder->CreateAdd( + buckets_filled, + builder->CreateZExt(buckets_filled_delta, llvm::Type::getInt32Ty(context)) + ); + LLVM::CreateStore(*builder, buckets_filled, buckets_filled_ptr); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)), + el_mask_value_ptr); + } + void 
LLVMSetLinearProbing::rehash( llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx) { @@ -5355,84 +5693,298 @@ namespace LCompilers { LLVM::CreateStore(*builder, new_el_mask, get_pointer_to_mask(set)); } - void LLVMSetLinearProbing::rehash_all_at_once_if_needed( + void LLVMSetSeparateChaining::rehash( llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx) { - /** * C++ equivalent: * - * occupancy += 1; - * load_factor = occupancy / capacity; - * load_factor_threshold = 0.6; - * rehash_condition = (capacity == 0) || (load_factor >= load_factor_threshold); - * if( rehash_condition ) { - * rehash(); - * } - * - */ - - llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); - llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity, - llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))); - occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, 1))); - occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context)); - capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context)); - llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity); - // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor - llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), - llvm::APFloat((float) 0.6)); - rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold)); - llvm_utils->create_if_else(rehash_condition, [&]() { - rehash(set, module, el_asr_type, name2memidx); - }, [=]() { - }); - } - - void LLVMSetLinearProbing::write_item( - llvm::Value* set, llvm::Value* el, - llvm::Module* module, ASR::ttype_t* el_asr_type, - std::map>& name2memidx) { - 
rehash_all_at_once_if_needed(set, module, el_asr_type, name2memidx); - llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, *module); - this->resolve_collision_for_write(set, el_hash, el, module, - el_asr_type, name2memidx); - } - - void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( - llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) { - - /** - * C++ equivalent: + * capacity = 3 * capacity + 1; * - * el_mask_value = el_mask[el_hash]; - * is_prob_needed = el_mask_value == 1; - * if( is_prob_needed ) { - * is_el_matching = el == el_list[el_hash]; - * if( is_el_matching ) { - * pos = el_hash; - * } - * else { - * exit(1); // el not present + * if( rehash_flag ) { + * while( old_capacity > idx ) { + * if( el_mask[el_hash] == 1 ) { + * write_el_linked_list(old_elems_value[idx]); + * } + * idx++; * } * } * else { - * resolve_collision(el, for_read=true); // modifies pos - * } - * - * is_el_matching = el == el_list[pos]; - * if( !is_el_matching ) { - * exit(1); // el not present + * // set to old values * } * */ get_builder0() - llvm::Value* el_list = get_el_list(set); + old_capacity = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + old_occupancy = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + old_number_of_buckets_filled = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + idx_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + old_elems = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + old_el_mask = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + + llvm::Value* capacity_ptr = get_pointer_to_capacity(set); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* number_of_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + 
llvm::Value* old_capacity_value = LLVM::CreateLoad(*builder, capacity_ptr); + LLVM::CreateStore(*builder, old_capacity_value, old_capacity); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, occupancy_ptr), + old_occupancy + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, number_of_buckets_filled_ptr), + old_number_of_buckets_filled + ); + llvm::Value* old_el_mask_value = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value* old_elems_value = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + old_elems_value = builder->CreateBitCast(old_elems_value, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, old_el_mask_value, old_el_mask); + LLVM::CreateStore(*builder, old_elems_value, old_elems); + + llvm::Value* capacity = builder->CreateMul(old_capacity_value, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 3))); + capacity = builder->CreateAdd(capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + set_init_given_initial_capacity(ASRUtils::get_type_code(el_asr_type), + set, module, capacity); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB_rehash = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB_rehash = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB_rehash = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* rehash_flag = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(set)); + builder->CreateCondBr(rehash_flag, thenBB_rehash, elseBB_rehash); + + builder->SetInsertPoint(thenBB_rehash); + old_elems_value = LLVM::CreateLoad(*builder, old_elems); + old_elems_value = builder->CreateBitCast(old_elems_value, + typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]->getPointerTo()); + old_el_mask_value = LLVM::CreateLoad(*builder, old_el_mask); + old_capacity_value = LLVM::CreateLoad(*builder, old_capacity); + capacity = LLVM::CreateLoad(*builder, 
get_pointer_to_capacity(set)); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + old_capacity_value, + LLVM::CreateLoad(*builder, idx_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* itr = LLVM::CreateLoad(*builder, idx_ptr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(old_el_mask_value, itr)); + llvm::Value* is_el_set = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + + llvm_utils->create_if_else(is_el_set, [&]() { + llvm::Value* srci = llvm_utils->create_ptr_gep(old_elems_value, itr); + write_el_linked_list(srci, set, capacity, el_asr_type, module, name2memidx); + }, [=]() { + }); + llvm::Value* tmp = builder->CreateAdd( + itr, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, idx_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + builder->CreateBr(mergeBB_rehash); + llvm_utils->start_new_block(elseBB_rehash); + { + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_capacity), + get_pointer_to_capacity(set) + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_occupancy), + get_pointer_to_occupancy(set) + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_number_of_buckets_filled), + get_pointer_to_number_of_filled_buckets(set) + ); + LLVM::CreateStore(*builder, + builder->CreateBitCast( + LLVM::CreateLoad(*builder, old_elems), + 
typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]->getPointerTo() + ), + get_pointer_to_elems(set) + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_el_mask), + get_pointer_to_mask(set) + ); + } + llvm_utils->start_new_block(mergeBB_rehash); + } + + void LLVMSetSeparateChaining::write_el_linked_list( + llvm::Value* el_ll, llvm::Value* set, llvm::Value* capacity, + ASR::ttype_t* m_el_type, llvm::Module* module, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * while( src_itr != nullptr ) { + * resolve_collision_for_write(el_struct[0]); + * src_itr = el_struct[1]; + * } + * + */ + + get_builder0() + src_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(m_el_type)]->getPointerTo(); + LLVM::CreateStore(*builder, + builder->CreateBitCast(el_ll, llvm::Type::getInt8PtrTy(context)), + src_itr); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, src_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* curr_src = builder->CreateBitCast(LLVM::CreateLoad(*builder, src_itr), + el_struct_type); + llvm::Value* src_el_ptr = llvm_utils->create_gep(curr_src, 0); + llvm::Value* src_el = src_el_ptr; + if( !LLVM::is_llvm_struct(m_el_type) ) { + src_el = LLVM::CreateLoad(*builder, src_el_ptr); + } + llvm::Value* el_hash = get_el_hash(capacity, src_el, m_el_type, *module); + resolve_collision_for_write( + set, el_hash, src_el, module, + m_el_type, name2memidx); + + llvm::Value* src_next_ptr = 
LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 1)); + LLVM::CreateStore(*builder, src_next_ptr, src_itr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + void LLVMSetLinearProbing::rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * // this condition will be true with 0 capacity too + * rehash_condition = 5 * occupancy >= 3 * capacity; + * if( rehash_condition ) { + * rehash(); + * } + * + */ + + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor + // occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity + llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 5))); + llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 3))); + llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5, + capacity_times_3), [&]() { + rehash(set, module, el_asr_type, name2memidx); + }, []() {}); + } + + void LLVMSetSeparateChaining::rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * rehash_condition = rehash_flag && occupancy >= 2 * buckets_filled; + * if( rehash_condition ) { + * rehash(); + * } + * + */ + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); + llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(set)); + llvm::Value* rehash_condition = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(set)); + llvm::Value* 
buckets_filled_times_2 = builder->CreateMul(buckets_filled, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 2))); + rehash_condition = builder->CreateAnd(rehash_condition, + builder->CreateICmpSGE(occupancy, buckets_filled_times_2)); + llvm_utils->create_if_else(rehash_condition, [&]() { + rehash(set, module, el_asr_type, name2memidx); + }, []() {}); + } + + void LLVMSetInterface::write_item( + llvm::Value* set, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + rehash_all_at_once_if_needed(set, module, el_asr_type, name2memidx); + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, *module); + this->resolve_collision_for_write(set, el_hash, el, module, + el_asr_type, name2memidx); + } + + void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + + /** + * C++ equivalent: + * + * el_mask_value = el_mask[el_hash]; + * is_prob_needed = el_mask_value == 1; + * if( is_prob_needed ) { + * is_el_matching = el == el_list[el_hash]; + * if( is_el_matching ) { + * pos = el_hash; + * } + * else { + * exit(1); // el not present + * } + * } + * else { + * resolve_collision(el, for_read=true); // modifies pos + * } + * + * is_el_matching = el == el_list[pos]; + * if( !is_el_matching ) { + * exit(1); // el not present + * } + * + */ + + get_builder0() + llvm::Value* el_list = get_el_list(set); llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); @@ -5479,8 +6031,48 @@ namespace LCompilers { llvm_utils->list_api->read_item(el_list, pos, false, module, LLVM::is_llvm_struct(el_asr_type)), module, 
el_asr_type); - llvm_utils->create_if_else(is_el_matching, [&]() { - }, [&]() { + llvm_utils->create_if_else(is_el_matching, []() {}, [&]() { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + } + + void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * resolve_collision(el); // modified chain_itr + * does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr; + * if( !does_el_exist ) { + * exit(1); // KeyError + * } + * + */ + llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + std::string el_type_code = ASRUtils::get_type_code(el_asr_type); + llvm::Type* el_struct_type = typecode2elstruct[el_type_code]; + this->resolve_collision(el_hash, el, el_linked_list, + el_struct_type, el_mask, module, el_asr_type); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm::Value* does_el_exist = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + does_el_exist = builder->CreateAnd(does_el_exist, + builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + + llvm_utils->create_if_else(does_el_exist, []() {}, [&]() { std::string message = "The 
set does not contain the specified element"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -5518,6 +6110,75 @@ namespace LCompilers { LLVM::CreateStore(*builder, occupancy, occupancy_ptr); } + void LLVMSetSeparateChaining::remove_item( + llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * // modifies chain_itr and chain_itr_prev + * resolve_collision_for_read_with_bound_check(el); + * + * if(chain_itr_prev != nullptr) { + * chain_itr_prev[1] = chain_itr[1]; // next + * } + * else { + * // this linked list is now empty + * el_mask[el_hash] = 0; + * num_buckets_filled--; + * } + * + * occupancy--; + * + */ + + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module); + this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type); + llvm::Value* prev = LLVM::CreateLoad(*builder, chain_itr_prev); + llvm::Value* found = LLVM::CreateLoad(*builder, chain_itr); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + + builder->CreateCondBr( + builder->CreateICmpNE(prev, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), + thenBB, elseBB + ); + builder->SetInsertPoint(thenBB); + { + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]; + found = builder->CreateBitCast(found, el_struct_type->getPointerTo()); + llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 1)); + prev = builder->CreateBitCast(prev, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, 
found_next, llvm_utils->create_gep(prev, 1)); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + LLVM::CreateStore( + *builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)), + llvm_utils->create_ptr_gep(el_mask, el_hash) + ); + llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr); + num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr); + } + llvm_utils->start_new_block(mergeBB); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); + } + void LLVMSetLinearProbing::set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module, @@ -5547,7 +6208,179 @@ namespace LCompilers { LLVM::CreateStore(*builder, dest_el_mask, dest_el_mask_ptr); } - llvm::Value* LLVMSetLinearProbing::len(llvm::Value* set) { + void LLVMSetSeparateChaining::set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx) { + llvm::Value* src_occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(src)); + llvm::Value* src_filled_buckets = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(src)); + llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(src)); + llvm::Value* src_el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(src)); + llvm::Value* src_rehash_flag = 
LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(src)); + LLVM::CreateStore(*builder, src_occupancy, get_pointer_to_occupancy(dest)); + LLVM::CreateStore(*builder, src_filled_buckets, get_pointer_to_number_of_filled_buckets(dest)); + LLVM::CreateStore(*builder, src_capacity, get_pointer_to_capacity(dest)); + LLVM::CreateStore(*builder, src_rehash_flag, get_pointer_to_rehash_flag(dest)); + llvm::DataLayout data_layout(module); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* malloc_size = builder->CreateMul(src_capacity, llvm_mask_size); + llvm::Value* dest_el_mask = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + LLVM::CreateStore(*builder, dest_el_mask, get_pointer_to_mask(dest)); + + // number of elements to be copied = capacity + (occupancy - filled_buckets) + malloc_size = builder->CreateSub(src_occupancy, src_filled_buckets); + malloc_size = builder->CreateAdd(src_capacity, malloc_size); + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(set_type->m_type)]; + size_t el_struct_size = data_layout.getTypeAllocSize(el_struct_type); + llvm::Value* llvm_el_struct_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, el_struct_size)); + malloc_size = builder->CreateMul(malloc_size, llvm_el_struct_size); + llvm::Value* dest_elems = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + dest_elems = builder->CreateBitCast(dest_elems, el_struct_type->getPointerTo()); + get_builder0() + copy_itr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + next_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + LLVM::CreateStore(*builder, llvm_zero, copy_itr); + LLVM::CreateStore(*builder, 
src_capacity, next_ptr); + + llvm::Value* src_elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(src)); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + src_capacity, + LLVM::CreateLoad(*builder, copy_itr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* itr = LLVM::CreateLoad(*builder, copy_itr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(src_el_mask, itr)); + LLVM::CreateStore(*builder, el_mask_value, + llvm_utils->create_ptr_gep(dest_el_mask, itr)); + llvm::Value* is_el_set = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + + llvm_utils->create_if_else(is_el_set, [&]() { + llvm::Value* srci = llvm_utils->create_ptr_gep(src_elems, itr); + llvm::Value* desti = llvm_utils->create_ptr_gep(dest_elems, itr); + deepcopy_el_linked_list(srci, desti, dest_elems, + set_type, module, name2memidx); + }, []() {}); + llvm::Value* tmp = builder->CreateAdd( + itr, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, copy_itr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + LLVM::CreateStore(*builder, dest_elems, get_pointer_to_elems(dest)); + } + + void LLVMSetSeparateChaining::deepcopy_el_linked_list( + llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_elems, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * // memory allocation done before calling this function + * + * while( src_itr != nullptr ) { + * deepcopy(src_el, curr_dest_ptr); + * src_itr = 
src_itr_next; + * if( src_next_exists ) { + * *next_ptr = *next_ptr + 1; + * curr_dest[1] = &dest_elems[*next_ptr]; + * curr_dest = *curr_dest[1]; + * } + * else { + * curr_dest[1] = nullptr; + * } + * } + * + */ + get_builder0() + src_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + dest_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(set_type->m_type)]->getPointerTo(); + LLVM::CreateStore(*builder, + builder->CreateBitCast(srci, llvm::Type::getInt8PtrTy(context)), + src_itr); + LLVM::CreateStore(*builder, + builder->CreateBitCast(desti, llvm::Type::getInt8PtrTy(context)), + dest_itr); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, src_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* curr_src = builder->CreateBitCast(LLVM::CreateLoad(*builder, src_itr), + el_struct_type); + llvm::Value* curr_dest = builder->CreateBitCast(LLVM::CreateLoad(*builder, dest_itr), + el_struct_type); + llvm::Value* src_el_ptr = llvm_utils->create_gep(curr_src, 0); + llvm::Value *src_el = src_el_ptr; + if( !LLVM::is_llvm_struct(set_type->m_type) ) { + src_el = LLVM::CreateLoad(*builder, src_el_ptr); + } + llvm::Value* dest_el_ptr = llvm_utils->create_gep(curr_dest, 0); + llvm_utils->deepcopy(src_el, dest_el_ptr, set_type->m_type, module, name2memidx); + + llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 1)); + llvm::Value* curr_dest_next_ptr = 
llvm_utils->create_gep(curr_dest, 1); + LLVM::CreateStore(*builder, src_next_ptr, src_itr); + + llvm::Value* src_next_exists = builder->CreateICmpNE(src_next_ptr, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))); + llvm_utils->create_if_else(src_next_exists, [&]() { + llvm::Value* next_idx = LLVM::CreateLoad(*builder, next_ptr); + llvm::Value* dest_next_ptr = llvm_utils->create_ptr_gep(dest_elems, next_idx); + dest_next_ptr = builder->CreateBitCast(dest_next_ptr, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, dest_next_ptr, curr_dest_next_ptr); + LLVM::CreateStore(*builder, dest_next_ptr, dest_itr); + next_idx = builder->CreateAdd(next_idx, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, next_idx, next_ptr); + }, [&]() { + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + curr_dest_next_ptr + ); + }); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + llvm::Value* LLVMSetInterface::len(llvm::Value* set) { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 866cd05b68..d5d1264c8d 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -202,6 +202,8 @@ namespace LCompilers { LLVMDictInterface* dict_api_lp; LLVMDictInterface* dict_api_sc; + LLVMSetInterface* set_api_lp; + LLVMSetInterface* set_api_sc; CompilerOptions &compiler_options; @@ -296,6 +298,8 @@ namespace LCompilers { void set_dict_api(ASR::Dict_t* dict_type); + void set_set_api(ASR::Set_t* set_type); + void deepcopy(llvm::Value* src, llvm::Value* dest, ASR::ttype_t* asr_type, llvm::Module* module, std::map>& name2memidx); @@ -504,7 +508,6 @@ namespace LCompilers { llvm::AllocaInst *old_occupancy, *old_number_of_buckets_filled; llvm::AllocaInst *src_itr, *dest_itr, *next_ptr, *copy_itr; 
llvm::Value *tmp_value_ptr; - bool are_iterators_set; std::map, std::tuple, @@ -886,7 +889,10 @@ namespace LCompilers { llvm::AllocaInst *pos_ptr, *is_el_matching_var; llvm::AllocaInst *idx_ptr, *hash_iter, *hash_value; llvm::AllocaInst *polynomial_powers; - bool are_iterators_set; + llvm::AllocaInst *chain_itr, *chain_itr_prev; + llvm::AllocaInst *old_capacity, *old_elems, *old_el_mask; + llvm::AllocaInst *old_occupancy, *old_number_of_buckets_filled; + llvm::AllocaInst *src_itr, *dest_itr, *next_ptr, *copy_itr; std::map> typecode2settype; @@ -919,13 +925,6 @@ namespace LCompilers { llvm::Value* get_el_hash(llvm::Value* capacity, llvm::Value* el, ASR::ttype_t* el_asr_type, llvm::Module& module); - virtual - void resolve_collision( - llvm::Value* capacity, llvm::Value* el_hash, - llvm::Value* el, llvm::Value* el_list, - llvm::Value* el_mask, llvm::Module& module, - ASR::ttype_t* el_asr_type, bool for_read=false) = 0; - virtual void resolve_collision_for_write( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, @@ -946,7 +945,7 @@ namespace LCompilers { void write_item( llvm::Value* set, llvm::Value* el, llvm::Module* module, ASR::ttype_t* el_asr_type, - std::map>& name2memidx) = 0; + std::map>& name2memidx); virtual void resolve_collision_for_read_with_bound_check( @@ -965,7 +964,13 @@ namespace LCompilers { std::map>& name2memidx) = 0; virtual - llvm::Value* len(llvm::Value* set) = 0; + llvm::Value* len(llvm::Value* set); + + virtual + bool is_set_present(); + + virtual + void set_is_set_present(bool value); virtual ~LLVMSetInterface() = 0; @@ -1014,11 +1019,87 @@ namespace LCompilers { llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); - void write_item( + void resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type); + + void remove_item( llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* 
el_asr_type); + + void set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx); + + ~LLVMSetLinearProbing(); + }; + + class LLVMSetSeparateChaining: public LLVMSetInterface { + + protected: + + std::map typecode2elstruct; + + llvm::Value* get_pointer_to_number_of_filled_buckets(llvm::Value* set); + + llvm::Value* get_pointer_to_elems(llvm::Value* set); + + llvm::Value* get_pointer_to_rehash_flag(llvm::Value* set); + + void set_init_given_initial_capacity(std::string el_type_code, + llvm::Value* set, llvm::Module* module, llvm::Value* initial_capacity); + + void resolve_collision( + llvm::Value* el_hash, llvm::Value* el, llvm::Value* el_linked_list, + llvm::Type* el_struct_type, llvm::Value* el_mask, + llvm::Module& module, ASR::ttype_t* el_asr_type); + + void write_el_linked_list( + llvm::Value* el_ll, llvm::Value* set, llvm::Value* capacity, + ASR::ttype_t* m_el_type, llvm::Module* module, + std::map>& name2memidx); + + void deepcopy_el_linked_list( + llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_elems, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx); + + public: + + LLVMSetSeparateChaining( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); + + llvm::Type* get_set_type( + std::string type_code, + int32_t type_size, llvm::Type* el_type); + + void set_init(std::string type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity); + + llvm::Value* get_el_list(llvm::Value* set); + + llvm::Value* get_pointer_to_occupancy(llvm::Value* set); + + llvm::Value* get_pointer_to_capacity(llvm::Value* set); + + llvm::Value* get_pointer_to_mask(llvm::Value* set); + + void resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); + void rehash( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + 
std::map>& name2memidx); + + void rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + void resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, llvm::Module& module, ASR::ttype_t* el_asr_type); @@ -1032,9 +1113,7 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); - llvm::Value* len(llvm::Value* set); - - ~LLVMSetLinearProbing(); + ~LLVMSetSeparateChaining(); }; } // namespace LCompilers