From 936ab1ee3eb6169426e5dd5b36ca833dd05a3525 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Wed, 5 Jul 2023 09:26:06 +0530 Subject: [PATCH 1/6] init set --- src/libasr/codegen/llvm_utils.cpp | 333 ++++++++++++++++++++++++++++++ src/libasr/codegen/llvm_utils.h | 83 ++++++++ 2 files changed, 416 insertions(+) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index f965ac4538..571a6a8613 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -3271,4 +3271,337 @@ namespace LCompilers { &module, name2memidx); } + LLVMSetInterface::LLVMSetInterface(llvm::LLVMContext& context_, + LLVMUtils* llvm_utils_, + llvm::IRBuilder<>* builder_): + context(context_), + llvm_utils(std::move(llvm_utils_)), + builder(std::move(builder_)), + pos_ptr(nullptr), is_set_present_(false) { + } + + LLVMSetLinearProbing::LLVMSetLinearProbing(llvm::LLVMContext& context_, + LLVMUtils* llvm_utils_, + llvm::IRBuilder<>* builder_): + LLVMSetInterface(context_, llvm_utils_, builder_) { + } + + LLVMSetInterface::~LLVMSetInterface() { + typecode2settype.clear(); + } + + LLVMSetLinearProbing::~LLVMSetLinearProbing() { + } + + llvm::Value* LLVMSetLinearProbing::get_pointer_to_occupancy(llvm::Value* set) { + return llvm_utils->create_gep(set, 0); + } + + llvm::Value* LLVMSetLinearProbing::get_pointer_to_capacity(llvm::Value* set) { + return llvm_utils->list_api->get_pointer_to_current_capacity( + get_el_list(set)); + } + + llvm::Value* LLVMSetLinearProbing::get_el_list(llvm::Value* set) { + return llvm_utils->create_gep(set, 1); + } + + llvm::Value* LLVMSetLinearProbing::get_pointer_to_mask(llvm::Value* set) { + return llvm_utils->create_gep(set, 2); + } + + llvm::Type* LLVMSetLinearProbing::get_set_type(std::string type_code, int32_t type_size, + llvm::Type* el_type) { + is_set_present_ = true; + if( typecode2settype.find(type_code) != typecode2settype.end() ) { + return std::get<0>(typecode2settype[type_code]); + } + + llvm::Type* el_list_type = llvm_utils->list_api->get_list_type(el_type, + type_code, type_size); + std::vector set_type_vec = {llvm::Type::getInt32Ty(context), + el_list_type, + llvm::Type::getInt8PtrTy(context)}; + llvm::Type* set_desc = llvm::StructType::create(context, set_type_vec, "set"); + typecode2settype[type_code] = std::make_tuple(set_desc, type_size, el_type); + return set_desc; + } + + void LLVMSetLinearProbing::set_init(std::string type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity) { + llvm::Value* n_ptr = get_pointer_to_occupancy(set); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), n_ptr); + llvm::Value* el_list = get_el_list(set); + llvm_utils->list_api->list_init(type_code, el_list, *module, + initial_capacity, initial_capacity); + llvm::DataLayout data_layout(module); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, initial_capacity)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* el_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_capacity, + llvm_mask_size); + LLVM::CreateStore(*builder, el_mask, get_pointer_to_mask(set)); + } + + void LLVMSetInterface::set_iterators() { + if( are_iterators_set || !is_set_present_ ) { + return ; + } + llvm_utils->set_iterators(); + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr, "pos_ptr"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), pos_ptr); + is_el_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr, + "is_el_matching_var"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 0)), is_el_matching_var); + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr, "idx_ptr"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), idx_ptr); + hash_value = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_value"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), + llvm::APInt(64, 0)), hash_value); + hash_iter = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_iter"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), + llvm::APInt(64, 0)), hash_iter); + polynomial_powers = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), + llvm::APInt(64, 1)), polynomial_powers); + are_iterators_set = true; + } + + void LLVMSetInterface::reset_iterators() { + llvm_utils->reset_iterators(); + pos_ptr = nullptr; + is_el_matching_var = nullptr; + idx_ptr = nullptr; + hash_iter = nullptr; + hash_value = nullptr; + polynomial_powers = nullptr; + are_iterators_set = false; + } + + llvm::Value* LLVMSetInterface::get_el_hash(llvm::Value* capacity, llvm::Value* el, + ASR::ttype_t* el_asr_type, llvm::Module& module) { + // Write specialised hash functions for intrinsic types + // This is to avoid unnecessary calls to C-runtime and do + // as much as possible in LLVM directly. + switch( el_asr_type->type ) { + case ASR::ttypeType::Integer: { + // Simple modulo with the capacity of the set. + // We can update it later to do a better hash function + // which produces lesser collisions. + + llvm::Value* int_hash = builder->CreateZExtOrTrunc( + builder->CreateURem(el, + builder->CreateZExtOrTrunc(capacity, el->getType())), + capacity->getType() + ); + return int_hash; + } + case ASR::ttypeType::Character: { + // Polynomial rolling hash function for strings + llvm::Value* null_char = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), + llvm::APInt(8, '\0')); + llvm::Value* p = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 31)); + llvm::Value* m = llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 100000009)); + if( !are_iterators_set ) { + hash_value = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_value"); + hash_iter = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "hash_iter"); + polynomial_powers = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); + } + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_value); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1)), + polynomial_powers); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 0)), + hash_iter); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el, i)); + llvm::Value *cond = builder->CreateICmpNE(c, null_char); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + // for c in el: + // hash_value = (hash_value + (ord(c) + 1) * p_pow) % m + // p_pow = (p_pow * p) % m + llvm::Value* i = LLVM::CreateLoad(*builder, hash_iter); + llvm::Value* c = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el, i)); + llvm::Value* p_pow = LLVM::CreateLoad(*builder, polynomial_powers); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + c = builder->CreateZExt(c, llvm::Type::getInt64Ty(context)); + c = builder->CreateAdd(c, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + c = builder->CreateMul(c, p_pow); + c = builder->CreateSRem(c, m); + hash = builder->CreateAdd(hash, c); + hash = builder->CreateSRem(hash, m); + LLVM::CreateStore(*builder, hash, hash_value); + p_pow = builder->CreateMul(p_pow, p); + p_pow = builder->CreateSRem(p_pow, m); + LLVM::CreateStore(*builder, p_pow, polynomial_powers); + i = builder->CreateAdd(i, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1))); + LLVM::CreateStore(*builder, i, hash_iter); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + llvm::Value* hash = LLVM::CreateLoad(*builder, hash_value); + hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context)); + return builder->CreateSRem(hash, capacity); + } + case ASR::ttypeType::Tuple: { + llvm::Value* tuple_hash = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + ASR::Tuple_t* asr_tuple = ASR::down_cast(el_asr_type); + for( size_t i = 0; i < asr_tuple->n_type; i++ ) { + llvm::Value* llvm_tuple_i = llvm_utils->tuple_api->read_item(el, i, + LLVM::is_llvm_struct(asr_tuple->m_type[i])); + tuple_hash = builder->CreateAdd(tuple_hash, get_el_hash(capacity, llvm_tuple_i, + asr_tuple->m_type[i], module)); + tuple_hash = builder->CreateSRem(tuple_hash, capacity); + } + return tuple_hash; + } + case ASR::ttypeType::Logical: { + return builder->CreateZExt(el, llvm::Type::getInt32Ty(context)); + } + default: { + throw LCompilersException("Hashing " + ASRUtils::type_to_str_python(el_asr_type) + + " isn't implemented yet."); + } + } + } + + void LLVMSetLinearProbing::resolve_collision(llvm::Value* capacity, llvm::Value* el_hash, + llvm::Value* el, llvm::Value* el_list, + llvm::Value* el_mask, llvm::Module& module, + ASR::ttype_t* el_asr_type, bool for_read) { + + /** + * C++ equivalent: + * + * pos = el_hash; + * + * while( true ) { + * is_el_skip = el_mask_value == 3; // tombstone + * is_el_set = el_mask_value != 0; + * is_el_matching = 0; + * + * compare_elems = is_el_set && !is_el_skip; + * if( compare_elems ) { + * original_el = el_list[pos]; + * is_el_matching = el == original_el; + * } + * + * cond; + * if( for_read ) { + * // for reading, continue to next pos + * // even if current pos is tombstone + * cond = (is_el_set && !is_el_matching) || is_el_skip; + * } + * else { + * // for writing, do not continue + * // if current pos is tombstone + * cond = is_el_set && !is_el_matching && !is_el_skip; + * } + * + * if( cond ) { + * pos += 1; + * pos %= capacity; + * } + * else { + * break; + * } + * } + * + */ + + if( !are_iterators_set ) { + if( !for_read ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + is_el_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + } + + LLVM::CreateStore(*builder, el_hash, pos_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, pos)); + llvm::Value* is_el_skip = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3))); + llvm::Value* is_el_set = builder->CreateICmpNE(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* is_el_matching = llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 0)); + LLVM::CreateStore(*builder, is_el_matching, is_el_matching_var); + llvm::Value* compare_elems = builder->CreateAnd(is_el_set, + builder->CreateNot(is_el_skip)); + llvm_utils->create_if_else(compare_elems, [&]() { + llvm::Value* original_el = llvm_utils->list_api->read_item(el_list, pos, + false, module, LLVM::is_llvm_struct(el_asr_type)); + is_el_matching = llvm_utils->is_equal_by_value(el, original_el, module, + el_asr_type); + LLVM::CreateStore(*builder, is_el_matching, is_el_matching_var); + }, [=]() { + }); + // TODO: Allow safe exit if pos becomes el_hash again. + // Ideally should not happen as set will be resized once + // load factor touches a threshold (which will always be less than 1) + // so there will be some el which will not be set. However for safety + // we can add an exit from the loop with a error message. + llvm::Value *cond = nullptr; + if( for_read ) { + cond = builder->CreateAnd(is_el_set, builder->CreateNot( + LLVM::CreateLoad(*builder, is_el_matching_var))); + cond = builder->CreateOr(is_el_skip, cond); + } else { + cond = builder->CreateAnd(is_el_set, builder->CreateNot(is_el_skip)); + cond = builder->CreateAnd(cond, builder->CreateNot( + LLVM::CreateLoad(*builder, is_el_matching_var))); + } + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + pos = builder->CreateAdd(pos, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + pos = builder->CreateSRem(pos, capacity); + LLVM::CreateStore(*builder, pos, pos_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + } // namespace LCompilers diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index ffcfef16c6..4b68c23ac8 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -706,6 +706,89 @@ namespace LCompilers { }; + class LLVMSetInterface { + + protected: + + llvm::LLVMContext& context; + LLVMUtils* llvm_utils; + llvm::IRBuilder<>* builder; + llvm::AllocaInst *pos_ptr, *is_el_matching_var; + llvm::AllocaInst *idx_ptr, *hash_iter, *hash_value; + llvm::AllocaInst *polynomial_powers; + bool are_iterators_set; + + std::map> typecode2settype; + + public: + + bool is_set_present_; + + LLVMSetInterface( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); + + virtual + llvm::Type* get_set_type(std::string type_code, + int32_t type_size, llvm::Type* el_type) = 0; + + virtual + void set_init(std::string type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity) = 0; + + virtual + llvm::Value* get_el_list(llvm::Value* set) = 0; + + virtual + llvm::Value* get_pointer_to_occupancy(llvm::Value* set) = 0; + + virtual + llvm::Value* get_pointer_to_capacity(llvm::Value* set) = 0; + + virtual + void set_iterators(); + + virtual + void reset_iterators(); + + llvm::Value* get_el_hash(llvm::Value* capacity, llvm::Value* el, + ASR::ttype_t* el_asr_type, llvm::Module& module); + + virtual ~LLVMSetInterface() = 0; + + }; + + class LLVMSetLinearProbing: public LLVMSetInterface { + + public: + + LLVMSetLinearProbing(llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); + + llvm::Type* get_set_type(std::string type_code, + int32_t type_size, llvm::Type* el_type); + + void set_init(std::string type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity); + + llvm::Value* get_el_list(llvm::Value* set); + + llvm::Value* get_pointer_to_occupancy(llvm::Value* set); + + llvm::Value* get_pointer_to_capacity(llvm::Value* set); + + llvm::Value* get_pointer_to_mask(llvm::Value* set); + + void resolve_collision(llvm::Value* capacity, llvm::Value* el_hash, + llvm::Value* el, llvm::Value* el_list, + llvm::Value* el_mask, llvm::Module& module, + ASR::ttype_t* el_asr_type, bool for_read=false); + + ~LLVMSetLinearProbing(); + }; + } // namespace LCompilers #endif // LFORTRAN_LLVM_UTILS_H From f0a7e61bff53bf01a34922dff72bb33651d1bb6c Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Wed, 5 Jul 2023 22:34:44 +0530 Subject: [PATCH 2/6] set constant and len --- src/libasr/codegen/asr_to_llvm.cpp | 71 ++++++++ src/libasr/codegen/llvm_utils.cpp | 270 ++++++++++++++++++++++++++++- src/libasr/codegen/llvm_utils.h | 84 ++++++++- 3 files changed, 412 insertions(+), 13 deletions(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index a28d5237c7..f83b7139fb 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -203,6 +203,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::unique_ptr tuple_api; std::unique_ptr dict_api_lp; std::unique_ptr dict_api_sc; + std::unique_ptr set_api; // linear probing std::unique_ptr arr_descr; int64_t ptr_loads; @@ -241,6 +242,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tuple_api(std::make_unique(context, llvm_utils.get(), builder.get())), dict_api_lp(std::make_unique(context, llvm_utils.get(), builder.get())), dict_api_sc(std::make_unique(context, llvm_utils.get(), builder.get())), + set_api(std::make_unique(context, llvm_utils.get(), builder.get())), arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context, builder.get(), llvm_utils.get(), @@ -255,6 +257,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->tuple_api = tuple_api.get(); llvm_utils->list_api = list_api.get(); llvm_utils->dict_api = nullptr; + llvm_utils->set_api = set_api.get(); llvm_utils->arr_api = arr_descr.get(); } @@ -1606,6 +1609,25 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = const_dict; } + void visit_SetConstant(const ASR::SetConstant_t& x) { + llvm::Type* const_set_type = get_set_type(x.m_type); + llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set"); + ASR::Set_t* x_set = ASR::down_cast(x.m_type); + std::string el_type_code = ASRUtils::get_type_code(x_set->m_type); + llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements); + int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type); + int64_t ptr_loads_copy = ptr_loads; + for( size_t i = 0; i < x.n_elements; i++ ) { + ptr_loads = ptr_loads_el; + visit_expr_wrapper(x.m_elements[i], true); + llvm::Value* element = tmp; + llvm_utils->set_api->write_item(const_set, element, module.get(), + x_set->m_type, name2memidx); + } + ptr_loads = ptr_loads_copy; + tmp = const_set; + } + void visit_TupleConstant(const ASR::TupleConstant_t& x) { ASR::Tuple_t* tuple_type = ASR::down_cast(x.m_type); std::string type_code = ASRUtils::get_type_code(tuple_type->m_type, @@ -1911,6 +1933,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = llvm_utils->dict_api->len(pdict); } + void visit_SetLen(const ASR::SetLen_t& x) { + if (x.m_value) { + this->visit_expr(*x.m_value); + return ; + } + + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_arg); + ptr_loads = ptr_loads_copy; + llvm::Value* pset = tmp; + tmp = llvm_utils->set_api->len(pset); + } + void visit_ListInsert(const ASR::ListInsert_t& x) { ASR::List_t* asr_list = ASR::down_cast( ASRUtils::expr_type(x.m_a)); @@ -3104,6 +3140,23 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor value_type_size, key_llvm_type, value_llvm_type); } + llvm::Type* get_set_type(ASR::ttype_t* asr_type) { + ASR::Set_t* asr_set = ASR::down_cast(asr_type); + bool is_local_array_type = false, is_local_malloc_array_type = false; + bool is_local_list = false; + ASR::dimension_t* local_m_dims = nullptr; + int local_n_dims = 0; + int local_a_kind = -1; + ASR::storage_typeType local_m_storage = ASR::storage_typeType::Default; + llvm::Type* el_llvm_type = get_type_from_ttype_t(asr_set->m_type, nullptr, local_m_storage, + is_local_array_type, is_local_malloc_array_type, + is_local_list, local_m_dims, local_n_dims, + local_a_kind); + int32_t el_type_size = get_type_size(asr_set->m_type, el_llvm_type, local_a_kind); + std::string el_type_code = ASRUtils::get_type_code(asr_set->m_type); + return llvm_utils->set_api->get_set_type(el_type_code, el_type_size, el_llvm_type); + } + llvm::Type* get_type_from_ttype_t(ASR::ttype_t* asr_type, ASR::symbol_t *type_declaration, ASR::storage_typeType m_storage, @@ -3227,6 +3280,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_type = get_dict_type(asr_type); break; } + case (ASR::ttypeType::Set): { + llvm_type = get_set_type(asr_type); + break; + } case (ASR::ttypeType::Tuple) : { ASR::Tuple_t* asr_tuple = ASR::down_cast(asr_type); std::string type_code = ASRUtils::get_type_code(asr_tuple->m_type, @@ -4813,6 +4870,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_value_tuple = ASR::is_a(*asr_value_type); bool is_target_dict = ASR::is_a(*asr_target_type); bool is_value_dict = ASR::is_a(*asr_value_type); + bool is_target_set = ASR::is_a(*asr_target_type); + bool is_value_set = ASR::is_a(*asr_value_type); bool is_target_struct = ASR::is_a(*asr_target_type); bool is_value_struct = ASR::is_a(*asr_value_type); if (ASR::is_a(*x.m_target)) { @@ -4902,6 +4961,18 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->dict_api->dict_deepcopy(value_dict, target_dict, value_dict_type, module.get(), name2memidx); return ; + } else if( is_target_set && is_value_set ) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_value); + llvm::Value* value_set = tmp; + this->visit_expr(*x.m_target); + llvm::Value* target_set = tmp; + ptr_loads = ptr_loads_copy; + ASR::Set_t* value_set_type = ASR::down_cast(asr_value_type); + llvm_utils->set_api->set_deepcopy(value_set, target_set, + value_set_type, module.get(), name2memidx); + return ; } else if( is_target_struct && is_value_struct ) { int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 571a6a8613..f90649826f 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -3384,7 +3384,8 @@ namespace LCompilers { are_iterators_set = false; } - llvm::Value* LLVMSetInterface::get_el_hash(llvm::Value* capacity, llvm::Value* el, + llvm::Value* LLVMSetInterface::get_el_hash( + llvm::Value* capacity, llvm::Value* el, ASR::ttype_t* el_asr_type, llvm::Module& module) { // Write specialised hash functions for intrinsic types // This is to avoid unnecessary calls to C-runtime and do @@ -3489,10 +3490,11 @@ namespace LCompilers { } } - void LLVMSetLinearProbing::resolve_collision(llvm::Value* capacity, llvm::Value* el_hash, - llvm::Value* el, llvm::Value* el_list, - llvm::Value* el_mask, llvm::Module& module, - ASR::ttype_t* el_asr_type, bool for_read) { + void LLVMSetLinearProbing::resolve_collision( + llvm::Value* capacity, llvm::Value* el_hash, + llvm::Value* el, llvm::Value* el_list, + llvm::Value* el_mask, llvm::Module& module, + ASR::ttype_t* el_asr_type, bool for_read) { /** * C++ equivalent: @@ -3603,5 +3605,263 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMSetLinearProbing::resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * resolve_collision(); // modifies pos + * el_list[pos] = el; + * el_mask_value = el_mask[pos]; + * is_slot_empty = el_mask_value == 0; // el_list[pos] wasn't set before + * occupancy += is_slot_empty; + * linear_prob_happened = (el_hash != pos) || (el_mask[el_hash] == 2); + * set_max_2 = linear_prob_happened ? 2 : 1; + * el_mask[el_hash] = set_max_2; + * el_mask[pos] = set_max_2; + * + */ + + llvm::Value* el_list = get_el_list(set); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + this->resolve_collision(capacity, el_hash, el, el_list, el_mask, *module, el_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm_utils->list_api->write_item(el_list, pos, el, + el_asr_type, false, module, name2memidx); + + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, pos)); + llvm::Value* is_slot_empty = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + LLVM::CreateStore(*builder, builder->CreateAdd(occupancy, is_slot_empty), + occupancy_ptr); + + llvm::Value* linear_prob_happened = builder->CreateICmpNE(el_hash, pos); + linear_prob_happened = builder->CreateOr(linear_prob_happened, + builder->CreateICmpEQ( + LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el_mask, el_hash)), + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2) + )) + ); + llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2)), + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(el_mask, el_hash)); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(el_mask, pos)); + } + + void LLVMSetLinearProbing::rehash( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * old_capacity = capacity; + * capacity = 2 * capacity + 1; + * + * idx = 0; + * while( old_capacity > idx ) { + * is_el_set = el_mask[idx] != 0; + * if( is_el_set ) { + * el = el_list[idx]; + * el_hash = get_el_hash(); // with new capacity + * resolve_collision(); // with new_el_list; modifies pos + * new_el_list[pos] = el; + * linear_prob_happened = el_hash != pos; + * set_max_2 = linear_prob_happened ? 2 : 1; + * new_el_mask[el_hash] = set_max_2; + * new_el_mask[pos] = set_max_2; + * } + * idx += 1; + * } + * + * free(el_list); + * free(el_mask); + * el_list = new_el_list; + * el_mask = new_el_mask; + * + */ + + llvm::Value* capacity_ptr = get_pointer_to_capacity(set); + llvm::Value* old_capacity = LLVM::CreateLoad(*builder, capacity_ptr); + llvm::Value* capacity = builder->CreateMul(old_capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 2))); + capacity = builder->CreateAdd(capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, capacity, capacity_ptr); + + std::string el_type_code = ASRUtils::get_type_code(el_asr_type); + llvm::Type* el_llvm_type = std::get<2>(typecode2settype[el_type_code]); + int32_t el_type_size = std::get<1>(typecode2settype[el_type_code]); + + llvm::Value* el_list = get_el_list(set); + llvm::Value* new_el_list = builder->CreateAlloca(llvm_utils->list_api->get_list_type(el_llvm_type, + el_type_code, el_type_size), nullptr); + llvm_utils->list_api->list_init(el_type_code, new_el_list, *module, capacity, capacity); + + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::DataLayout data_layout(module); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* new_el_mask = LLVM::lfortran_calloc(context, *module, *builder, capacity, + llvm_mask_size); + + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + if( !are_iterators_set ) { + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), idx_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT(old_capacity, LLVM::CreateLoad(*builder, idx_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* is_el_set = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el_mask, idx)); + is_el_set = builder->CreateICmpNE(is_el_set, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + builder->CreateCondBr(is_el_set, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + llvm::Value* el = llvm_utils->list_api->read_item(el_list, idx, + false, *module, LLVM::is_llvm_struct(el_asr_type)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, *module); + this->resolve_collision(current_capacity, el_hash, el, new_el_list, + new_el_mask, *module, el_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* el_dest = llvm_utils->list_api->read_item( + new_el_list, pos, false, *module, true); + llvm_utils->deepcopy(el, el_dest, el_asr_type, module, name2memidx); + + llvm::Value* linear_prob_happened = builder->CreateICmpNE(el_hash, pos); + llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2)), + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(new_el_mask, el_hash)); + LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(new_el_mask, pos)); + } + builder->CreateBr(mergeBB); + + llvm_utils->start_new_block(elseBB); + llvm_utils->start_new_block(mergeBB); + idx = builder->CreateAdd(idx, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, idx, idx_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + + llvm_utils->list_api->free_data(el_list, *module); + LLVM::lfortran_free(context, *module, *builder, el_mask); + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_el_list), el_list); + LLVM::CreateStore(*builder, new_el_mask, get_pointer_to_mask(set)); + } + + void LLVMSetLinearProbing::rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * occupancy += 1; + * load_factor = occupancy / capacity; + * load_factor_threshold = 0.6; + * rehash_condition = (capacity == 0) || (load_factor >= load_factor_threshold); + * if( rehash_condition ) { + * rehash(); + * } + * + */ + + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))); + occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context)); + capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context)); + llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity); + // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor + llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), + llvm::APFloat((float) 0.6)); + rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold)); + llvm_utils->create_if_else(rehash_condition, [&]() { + rehash(set, module, el_asr_type, name2memidx); + }, [=]() { + }); + } + + void LLVMSetLinearProbing::write_item( + llvm::Value* set, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + rehash_all_at_once_if_needed(set, module, el_asr_type, name2memidx); + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, *module); + this->resolve_collision_for_write(set, el_hash, el, module, + el_asr_type, name2memidx); + } + + void LLVMSetLinearProbing::set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx) { + LCOMPILERS_ASSERT(src->getType() == dest->getType()); + llvm::Value* src_occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(src)); + llvm::Value* dest_occupancy_ptr = get_pointer_to_occupancy(dest); + LLVM::CreateStore(*builder, src_occupancy, dest_occupancy_ptr); + + llvm::Value* src_el_list = get_el_list(src); + llvm::Value* dest_el_list = get_el_list(dest); + llvm_utils->list_api->list_deepcopy(src_el_list, dest_el_list, + set_type->m_type, module, + name2memidx); + + llvm::Value* src_el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(src)); + llvm::Value* dest_el_mask_ptr = get_pointer_to_mask(dest); + llvm::DataLayout data_layout(module); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(src)); + llvm::Value* dest_el_mask = LLVM::lfortran_calloc(context, *module, *builder, src_capacity, + llvm_mask_size); + builder->CreateMemCpy(dest_el_mask, llvm::MaybeAlign(), src_el_mask, + llvm::MaybeAlign(), builder->CreateMul(src_capacity, llvm_mask_size)); + LLVM::CreateStore(*builder, dest_el_mask, dest_el_mask_ptr); + } + + llvm::Value* LLVMSetLinearProbing::len(llvm::Value* set) { + return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); + } } // namespace LCompilers diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 4b68c23ac8..b48c8e08bc 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -103,6 +103,7 @@ namespace LCompilers { class LLVMList; class LLVMTuple; class LLVMDictInterface; + class LLVMSetInterface; class LLVMUtils { @@ -119,6 +120,7 @@ namespace LCompilers { LLVMTuple* tuple_api; LLVMList* list_api; LLVMDictInterface* dict_api; + LLVMSetInterface* set_api; LLVMArrUtils::Descriptor* arr_api; llvm::Module* module; @@ -755,6 +757,44 @@ namespace LCompilers { llvm::Value* get_el_hash(llvm::Value* capacity, llvm::Value* el, ASR::ttype_t* el_asr_type, llvm::Module& module); + virtual + void resolve_collision( + llvm::Value* capacity, llvm::Value* el_hash, + llvm::Value* el, llvm::Value* el_list, + llvm::Value* el_mask, llvm::Module& module, + ASR::ttype_t* el_asr_type, bool for_read=false) = 0; + + virtual + void resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) = 0; + + virtual + void rehash( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) = 0; + + virtual + void rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) = 0; + + virtual + void write_item( + llvm::Value* set, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) = 0; + + virtual + void set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx) = 0; + + virtual + llvm::Value* len(llvm::Value* set) = 0; + virtual ~LLVMSetInterface() = 0; }; @@ -763,11 +803,13 @@ namespace LCompilers { public: - LLVMSetLinearProbing(llvm::LLVMContext& context_, - LLVMUtils* llvm_utils, - llvm::IRBuilder<>* builder); + LLVMSetLinearProbing( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); - llvm::Type* get_set_type(std::string type_code, + llvm::Type* get_set_type( + std::string type_code, int32_t type_size, llvm::Type* el_type); void set_init(std::string type_code, llvm::Value* set, @@ -781,10 +823,36 @@ namespace LCompilers { llvm::Value* get_pointer_to_mask(llvm::Value* set); - void resolve_collision(llvm::Value* capacity, llvm::Value* el_hash, - llvm::Value* el, llvm::Value* el_list, - llvm::Value* el_mask, llvm::Module& module, - ASR::ttype_t* el_asr_type, bool for_read=false); + void resolve_collision( + llvm::Value* capacity, llvm::Value* el_hash, + llvm::Value* el, llvm::Value* el_list, + llvm::Value* el_mask, llvm::Module& module, + ASR::ttype_t* el_asr_type, bool for_read=false); + + void resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + + void rehash( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + + void rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + + void write_item( + llvm::Value* set, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + + void set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx); + + llvm::Value* len(llvm::Value* set); ~LLVMSetLinearProbing(); }; From 31080dcdf263f332b0325e7befe1c7d6cd99df62 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Sat, 8 Jul 2023 11:52:26 +0530 Subject: [PATCH 3/6] set len test --- integration_tests/CMakeLists.txt | 1 + integration_tests/test_set_len.py | 8 ++++++++ 2 files changed, 9 insertions(+) create mode 100644 integration_tests/test_set_len.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 23175f5d47..d2e7160f23 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -477,6 +477,7 @@ RUN(NAME test_dict_12 LABELS cpython llvm c) RUN(NAME test_dict_13 LABELS cpython llvm c) RUN(NAME test_dict_bool LABELS cpython llvm) RUN(NAME test_dict_increment LABELS cpython llvm) +RUN(NAME test_set_len LABELS cpython llvm) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_set_len.py b/integration_tests/test_set_len.py new file mode 100644 index 0000000000..33d252a0fe --- /dev/null +++ b/integration_tests/test_set_len.py @@ -0,0 +1,8 @@ +from lpython import i32 + +def test_set(): + s: set[i32] + s = {1, 2, 22, 2, -1, 1} + assert len(s) == 4 + +test_set() \ No newline at end of file From 97e4f44af0753e961fd705409ea2ff585cad9b43 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Wed, 12 Jul 2023 16:36:24 +0530 Subject: [PATCH 4/6] set add, remove --- integration_tests/CMakeLists.txt | 2 + integration_tests/test_set_add.py | 34 ++++ integration_tests/test_set_remove.py | 47 ++++++ src/libasr/ASR.asdl | 2 - src/libasr/codegen/asr_to_llvm.cpp | 36 +++++ src/libasr/codegen/llvm_utils.cpp | 150 +++++++++++++++++- src/libasr/codegen/llvm_utils.h | 18 +++ src/libasr/pass/intrinsic_function_registry.h | 112 +++++++++++++ src/lpython/semantics/python_ast_to_asr.cpp | 2 +- src/lpython/semantics/python_attribute_eval.h | 62 +++----- tests/errors/test_set4.py | 6 + tests/reference/asr-set1-b7b913a.json | 2 +- tests/reference/asr-set1-b7b913a.stdout | 24 ++- tests/reference/asr-test_set1-11379c7.json | 2 +- tests/reference/asr-test_set1-11379c7.stderr | 6 +- tests/reference/asr-test_set2-d91a6f0.json | 2 +- tests/reference/asr-test_set2-d91a6f0.stderr | 2 +- tests/reference/asr-test_set4-53fea39.json | 13 ++ tests/reference/asr-test_set4-53fea39.stderr | 5 + tests/tests.toml | 4 + 20 files changed, 472 insertions(+), 59 deletions(-) create mode 100644 integration_tests/test_set_add.py create mode 100644 integration_tests/test_set_remove.py create mode 100644 tests/errors/test_set4.py create mode 100644 tests/reference/asr-test_set4-53fea39.json create mode 100644 tests/reference/asr-test_set4-53fea39.stderr diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index b2f4a18231..5cc23e81c9 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -483,6 +483,8 @@ RUN(NAME test_dict_13 LABELS cpython llvm c) RUN(NAME test_dict_bool LABELS cpython llvm) RUN(NAME test_dict_increment LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm) +RUN(NAME test_set_add LABELS cpython llvm) +RUN(NAME test_set_remove LABELS cpython llvm) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_set_add.py b/integration_tests/test_set_add.py new file mode 100644 index 0000000000..de8bbdd3d9 --- /dev/null +++ b/integration_tests/test_set_add.py @@ -0,0 +1,34 @@ +from lpython import i32 + +def test_set_add(): + s1: set[i32] + s2: set[tuple[i32, tuple[i32, i32], str]] + s3: set[str] + st1: str + i: i32 + j: i32 + + s1 = {0} + s2 = {(0, (1, 2), 'a')} + for i in range(20): + j = i % 10 + s1.add(j) + s2.add((j, (j + 1, j + 2), 'a')) + assert len(s1) == len(s2) + if i < 10: + assert len(s1) == i + 1 + else: + assert len(s1) == 10 + + st1 = 'a' + s3 = {st1} + for i in range(20): + s3.add(st1) + if i < 10: + if i > 0: + assert len(s3) == i + st1 += 'a' + else: + assert len(s3) == 10 + +test_set_add() \ No newline at end of file diff --git a/integration_tests/test_set_remove.py b/integration_tests/test_set_remove.py new file mode 100644 index 0000000000..ca93ec0a80 --- /dev/null +++ b/integration_tests/test_set_remove.py @@ -0,0 +1,47 @@ +from lpython import i32 + +def test_set_add(): + s1: set[i32] + s2: set[tuple[i32, tuple[i32, i32], str]] + s3: set[str] + st1: str + i: i32 + j: i32 + k: i32 + + for k in range(2): + s1 = {0} + s2 = {(0, (1, 2), 'a')} + for i in range(20): + j = i % 10 + s1.add(j) + s2.add((j, (j + 1, j + 2), 'a')) + + for i in range(10): + s1.remove(i) + s2.remove((i, (i + 1, i + 2), 'a')) + # assert len(s1) == 10 - 1 - i + # assert len(s1) == len(s2) + + st1 = 'a' + s3 = {st1} + for i in range(20): + s3.add(st1) + if i < 10: + if i > 0: + st1 += 'a' + + st1 = 'a' + for i in range(10): + s3.remove(st1) + assert len(s3) == 10 - 1 - i + if i < 10: + st1 += 'a' + + for i in range(20): + s1.add(i) + if i % 2 == 0: + s1.remove(i) + assert len(s1) == (i + 1) // 2 + +test_set_add() \ No newline at end of file diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 6703b0229d..7e6f53174a 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -219,8 +219,6 @@ stmt | SelectType(expr selector, type_stmt* body, stmt* default) | CPtrToPointer(expr cptr, expr ptr, expr? shape, expr? lower_bounds) | BlockCall(int label, symbol m) - | SetInsert(expr a, expr ele) - | SetRemove(expr a, expr ele) | ListInsert(expr a, expr pos, expr ele) | ListRemove(expr a, expr ele) | ListClear(expr a) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 2fc1934745..550d7524cd 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1668,6 +1668,34 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx); } + void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*m_arg); + llvm::Value* pset = tmp; + + ptr_loads = 2; + this->visit_expr_wrapper(m_ele, true); + ptr_loads = ptr_loads_copy; + llvm::Value *el = tmp; + set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx); + } + + void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*m_arg); + llvm::Value* pset = tmp; + + ptr_loads = 2; + this->visit_expr_wrapper(m_ele, true); + ptr_loads = ptr_loads_copy; + llvm::Value *el = tmp; + set_api->remove_item(pset, el, *module, asr_el_type); + } + void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) { switch (static_cast(x.m_intrinsic_id)) { case ASRUtils::IntrinsicFunctions::ListIndex: { @@ -1711,6 +1739,14 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } break; } + case ASRUtils::IntrinsicFunctions::SetAdd: { + generate_SetAdd(x.m_args[0], x.m_args[1]); + break; + } + case ASRUtils::IntrinsicFunctions::SetRemove: { + generate_SetRemove(x.m_args[0], x.m_args[1]); + break; + } case ASRUtils::IntrinsicFunctions::Exp: { switch (x.m_overload_id) { case 0: { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 9e8a5b481e..22279f5d92 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -2757,6 +2757,8 @@ namespace LCompilers { llvm_utils->create_ptr_gep(key_mask, pos)); llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + is_slot_empty = builder->CreateOr(is_slot_empty, builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)))); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); @@ -2959,6 +2961,33 @@ namespace LCompilers { llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + + /** + * C++ equivalent: + * + * key_mask_value = key_mask[key_hash]; + * is_prob_needed = key_mask_value == 1; + * if( is_prob_needed ) { + * is_key_matching = key == key_list[key_hash]; + * if( is_key_matching ) { + * pos = key_hash; + * } + * else { + * exit(1); // key not present + * } + * } + * else { + * resolve_collision(key, for_read=true); // modifies pos + * } + * + * is_key_matching = key == key_list[pos]; + * if( !is_key_matching ) { + * exit(1); // key not present + * } + * + * return value_list[pos]; + */ + llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -5133,7 +5162,7 @@ namespace LCompilers { * resolve_collision(); // modifies pos * el_list[pos] = el; * el_mask_value = el_mask[pos]; - * is_slot_empty = el_mask_value == 0; // el_list[pos] wasn't set before + * is_slot_empty = el_mask_value == 0 || el_mask_value == 3; * occupancy += is_slot_empty; * linear_prob_happened = (el_hash != pos) || (el_mask[el_hash] == 2); * set_max_2 = linear_prob_happened ? 2 : 1; @@ -5154,6 +5183,8 @@ namespace LCompilers { llvm_utils->create_ptr_gep(el_mask, pos)); llvm::Value* is_slot_empty = builder->CreateICmpEQ(el_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + is_slot_empty = builder->CreateOr(is_slot_empty, builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)))); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); @@ -5348,6 +5379,123 @@ namespace LCompilers { el_asr_type, name2memidx); } + void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + + /** + * C++ equivalent: + * + * el_mask_value = el_mask[el_hash]; + * is_prob_needed = el_mask_value == 1; + * if( is_prob_needed ) { + * is_el_matching = el == el_list[el_hash]; + * if( is_el_matching ) { + * pos = el_hash; + * } + * else { + * exit(1); // el not present + * } + * } + * else { + * resolve_collision(el, for_read=true); // modifies pos + * } + * + * is_el_matching = el == el_list[pos]; + * if( !is_el_matching ) { + * exit(1); // el not present + * } + * + */ + + llvm::Value* el_list = get_el_list(set); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + if( !are_iterators_set ) { + pos_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + // reasoning for this check explained in + // LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check + llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, + llvm_utils->list_api->read_item(el_list, el_hash, false, module, + LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); + + llvm_utils->create_if_else(is_el_matching, [=]() { + LLVM::CreateStore(*builder, el_hash, pos_ptr); + }, [&]() { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + this->resolve_collision(capacity, el_hash, el, el_list, el_mask, + module, el_asr_type, true); + } + llvm_utils->start_new_block(mergeBB); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + // Check if the actual element is present or not + llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, + llvm_utils->list_api->read_item(el_list, pos, false, module, + LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); + + llvm_utils->create_if_else(is_el_matching, [&]() { + }, [&]() { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + } + + void LLVMSetLinearProbing::remove_item( + llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * resolve_collision_for_read(el); // modifies pos + * el_mask[pos] = 3; // tombstone marker + * occupancy -= 1; + */ + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module); + this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos); + llvm::Value* tombstone_marker = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)); + LLVM::CreateStore(*builder, tombstone_marker, el_mask_i); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); + } + void LLVMSetLinearProbing::set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module, diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 72a42e1ba7..7ff0085fe3 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -928,6 +928,16 @@ namespace LCompilers { llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx) = 0; + virtual + void resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) = 0; + + virtual + void remove_item( + llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) = 0; + virtual void set_deepcopy( llvm::Value* src, llvm::Value* dest, @@ -989,6 +999,14 @@ namespace LCompilers { llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); + void resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type); + + void remove_item( + llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type); + void set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module, diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index cff08d2e69..a7b9a8dcdd 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -64,6 +64,8 @@ enum class IntrinsicFunctions : int64_t { Partition, ListReverse, ListPop, + SetAdd, + SetRemove, SymbolicSymbol, SymbolicAdd, SymbolicSub, @@ -1146,6 +1148,104 @@ static inline ASR::asr_t* create_ListPop(Allocator& al, const Location& loc, } // namespace ListPop +namespace SetAdd { + +static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, "Call to set.add must have exactly one argument", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), + "First argument to set.add must be of set type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASRUtils::check_equal_type(ASRUtils::expr_type(x.m_args[1]), + ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]))), + "Second argument to set.add must be of same type as set's element type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(x.m_type == nullptr, + "Return type of set.add must be empty", + x.base.base.loc, diagnostics); +} + +static inline ASR::expr_t *eval_set_add(Allocator &/*al*/, + const Location &/*loc*/, Vec& /*args*/) { + // TODO: To be implemented for SetConstant expression + return nullptr; +} + +static inline ASR::asr_t* create_SetAdd(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 2) { + err("Call to set.add must have exactly one argument", loc); + } + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(args[1]), + ASRUtils::get_contained_type(ASRUtils::expr_type(args[0])))) { + err("Argument to set.add must be of same type as set's " + "element type", loc); + } + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + ASR::expr_t* compile_time_value = eval_set_add(al, loc, arg_values); + return ASR::make_Expr_t(al, loc, + ASRUtils::EXPR(ASR::make_IntrinsicFunction_t(al, loc, + static_cast(ASRUtils::IntrinsicFunctions::SetAdd), + args.p, args.size(), 0, nullptr, compile_time_value))); +} + +} // namespace SetAdd + +namespace SetRemove { + +static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, "Call to set.remove must have exactly one argument", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), + "First argument to set.remove must be of set type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASRUtils::check_equal_type(ASRUtils::expr_type(x.m_args[1]), + ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]))), + "Second argument to set.remove must be of same type as set's element type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(x.m_type == nullptr, + "Return type of set.remove must be empty", + x.base.base.loc, diagnostics); +} + +static inline ASR::expr_t *eval_set_remove(Allocator &/*al*/, + const Location &/*loc*/, Vec& /*args*/) { + // TODO: To be implemented for SetConstant expression + return nullptr; +} + +static inline ASR::asr_t* create_SetRemove(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 2) { + err("Call to set.remove must have exactly one argument", loc); + } + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(args[1]), + ASRUtils::get_contained_type(ASRUtils::expr_type(args[0])))) { + err("Argument to set.remove must be of same type as set's " + "element type", loc); + } + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + ASR::expr_t* compile_time_value = eval_set_remove(al, loc, arg_values); + return ASR::make_Expr_t(al, loc, + ASRUtils::EXPR(ASR::make_IntrinsicFunction_t(al, loc, + static_cast(ASRUtils::IntrinsicFunctions::SetRemove), + args.p, args.size(), 0, nullptr, compile_time_value))); +} + +} // namespace SetRemove + namespace Any { static inline void verify_array(ASR::expr_t* array, ASR::ttype_t* return_type, @@ -2267,6 +2367,10 @@ namespace IntrinsicFunctionRegistry { {nullptr, &ListPop::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::ListReverse), {nullptr, &ListReverse::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SetAdd), + {nullptr, &SetAdd::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::SetRemove), + {nullptr, &SetRemove::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSymbol), {nullptr, &SymbolicSymbol::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd), @@ -2335,6 +2439,10 @@ namespace IntrinsicFunctionRegistry { "list.reverse"}, {static_cast(ASRUtils::IntrinsicFunctions::ListPop), "list.pop"}, + {static_cast(ASRUtils::IntrinsicFunctions::SetAdd), + "set.add"}, + {static_cast(ASRUtils::IntrinsicFunctions::SetRemove), + "set.remove"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSymbol), "Symbol"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd), @@ -2394,6 +2502,8 @@ namespace IntrinsicFunctionRegistry { {"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}}, {"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}}, {"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}}, + {"set.add", {&SetAdd::create_SetAdd, &SetAdd::eval_set_add}}, + {"set.remove", {&SetRemove::create_SetRemove, &SetRemove::eval_set_remove}}, {"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}}, {"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}}, {"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}}, @@ -2515,6 +2625,8 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(Partition) INTRINSIC_NAME_CASE(ListReverse) INTRINSIC_NAME_CASE(ListPop) + INTRINSIC_NAME_CASE(SetAdd) + INTRINSIC_NAME_CASE(SetRemove) INTRINSIC_NAME_CASE(SymbolicSymbol) INTRINSIC_NAME_CASE(SymbolicAdd) INTRINSIC_NAME_CASE(SymbolicSub) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 947d870e98..fb16de689f 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1729,7 +1729,7 @@ class CommonVisitor : public AST::BaseVisitor { false, false, false, nullptr, 0, nullptr, 0, false)); return type; } else if (var_annotation == "set") { - if (AST::is_a(*s->m_slice)) { + if (AST::is_a(*s->m_slice) || AST::is_a(*s->m_slice)) { ASR::ttype_t *type = ast_expr_to_asr_type(loc, *s->m_slice, is_allocatable, raise_error, abi, is_argument); return ASRUtils::TYPE(ASR::make_Set_t(al, loc, type)); diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 1b45558979..5f150f5ef4 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -294,53 +294,31 @@ struct AttributeHandler { } static ASR::asr_t* eval_set_add(ASR::expr_t *s, Allocator &al, const Location &loc, - Vec &args, diag::Diagnostics &diag) { - if (args.size() != 1) { - throw SemanticError("add() takes exactly one argument", loc); - } - - ASR::ttype_t *type = ASRUtils::expr_type(s); - ASR::ttype_t *set_type = ASR::down_cast(type)->m_type; - ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]); - if (!ASRUtils::check_equal_type(ele_type, set_type)) { - std::string fnd = ASRUtils::type_to_str_python(ele_type); - std::string org = ASRUtils::type_to_str_python(set_type); - diag.add(diag::Diagnostic( - "Type mismatch in 'add', the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')", - {args[0]->base.loc}) - }) - ); - throw SemanticAbort(); + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_set; + args_with_set.reserve(al, args.size() + 1); + args_with_set.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_set.push_back(al, args[i]); } - - return make_SetInsert_t(al, loc, s, args[0]); + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("set.add"); + return create_function(al, loc, args_with_set, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); } static ASR::asr_t* eval_set_remove(ASR::expr_t *s, Allocator &al, const Location &loc, - Vec &args, diag::Diagnostics &diag) { - if (args.size() != 1) { - throw SemanticError("remove() takes exactly one argument", loc); - } - - ASR::ttype_t *type = ASRUtils::expr_type(s); - ASR::ttype_t *set_type = ASR::down_cast(type)->m_type; - ASR::ttype_t *ele_type = ASRUtils::expr_type(args[0]); - if (!ASRUtils::check_equal_type(ele_type, set_type)) { - std::string fnd = ASRUtils::type_to_str_python(ele_type); - std::string org = ASRUtils::type_to_str_python(set_type); - diag.add(diag::Diagnostic( - "Type mismatch in 'remove', the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch (found: '" + fnd + "', expected: '" + org + "')", - {args[0]->base.loc}) - }) - ); - throw SemanticAbort(); + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_set; + args_with_set.reserve(al, args.size() + 1); + args_with_set.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_set.push_back(al, args[i]); } - - return make_SetRemove_t(al, loc, s, args[0]); + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("set.remove"); + return create_function(al, loc, args_with_set, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); } static ASR::asr_t* eval_dict_get(ASR::expr_t *s, Allocator &al, const Location &loc, diff --git a/tests/errors/test_set4.py b/tests/errors/test_set4.py new file mode 100644 index 0000000000..7f64375502 --- /dev/null +++ b/tests/errors/test_set4.py @@ -0,0 +1,6 @@ +from lpython import i32 + +def test4(): + a: set[i32] + a = {1, 2, 3} + a.add(3, 4) diff --git a/tests/reference/asr-set1-b7b913a.json b/tests/reference/asr-set1-b7b913a.json index 22e446ded3..cebcf642e9 100644 --- a/tests/reference/asr-set1-b7b913a.json +++ b/tests/reference/asr-set1-b7b913a.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "asr-set1-b7b913a.stdout", - "stdout_hash": "b1a0479fa02536b0cf53854d4978319fcdba1b01d0ef99eec861a5a9", + "stdout_hash": "ab49b2e02638f55c4a3ca26750f154602746cbc5096c4b75e62e0b2a", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/asr-set1-b7b913a.stdout b/tests/reference/asr-set1-b7b913a.stdout index 5bf13d0bde..5f57547d04 100644 --- a/tests/reference/asr-set1-b7b913a.stdout +++ b/tests/reference/asr-set1-b7b913a.stdout @@ -109,13 +109,25 @@ ) () ) - (SetInsert - (Var 2 a) - (IntegerConstant 9 (Integer 4)) + (Expr + (IntrinsicFunction + SetAdd + [(Var 2 a) + (IntegerConstant 9 (Integer 4))] + 0 + () + () + ) ) - (SetRemove - (Var 2 a) - (IntegerConstant 4 (Integer 4)) + (Expr + (IntrinsicFunction + SetRemove + [(Var 2 a) + (IntegerConstant 4 (Integer 4))] + 0 + () + () + ) ) (= (Var 2 b) diff --git a/tests/reference/asr-test_set1-11379c7.json b/tests/reference/asr-test_set1-11379c7.json index 09cc2515af..417bba61b8 100644 --- a/tests/reference/asr-test_set1-11379c7.json +++ b/tests/reference/asr-test_set1-11379c7.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_set1-11379c7.stderr", - "stderr_hash": "9dcd4fd9b8878cabe6559827531844364da8311d7c8f5f846b38620d", + "stderr_hash": "64dea3d94817d0666cf71481546f7ec61639f47a3b696fe96ae287c6", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_set1-11379c7.stderr b/tests/reference/asr-test_set1-11379c7.stderr index d153edc1b9..f5851bc83d 100644 --- a/tests/reference/asr-test_set1-11379c7.stderr +++ b/tests/reference/asr-test_set1-11379c7.stderr @@ -1,5 +1,5 @@ -semantic error: Type mismatch in 'add', the types must be compatible - --> tests/errors/test_set1.py:6:11 +semantic error: Argument to set.add must be of same type as set's element type + --> tests/errors/test_set1.py:6:5 | 6 | a.add('err') - | ^^^^^ type mismatch (found: 'str', expected: 'i32') + | ^^^^^^^^^^^^ diff --git a/tests/reference/asr-test_set2-d91a6f0.json b/tests/reference/asr-test_set2-d91a6f0.json index 8d33226ef5..4c1d7ad258 100644 --- a/tests/reference/asr-test_set2-d91a6f0.json +++ b/tests/reference/asr-test_set2-d91a6f0.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_set2-d91a6f0.stderr", - "stderr_hash": "5459ddb5148c630f9374c827aad9c37d25967248002dc0dff5314530", + "stderr_hash": "36a3e507b04f030fc4e281ffe82947765ef640b6c558030957bd3e90", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_set2-d91a6f0.stderr b/tests/reference/asr-test_set2-d91a6f0.stderr index 7c5dfa54d2..29a2683c11 100644 --- a/tests/reference/asr-test_set2-d91a6f0.stderr +++ b/tests/reference/asr-test_set2-d91a6f0.stderr @@ -1,4 +1,4 @@ -semantic error: remove() takes exactly one argument +semantic error: Call to set.remove must have exactly one argument --> tests/errors/test_set2.py:6:5 | 6 | a.remove('error', 'error2') diff --git a/tests/reference/asr-test_set4-53fea39.json b/tests/reference/asr-test_set4-53fea39.json new file mode 100644 index 0000000000..aad37eb089 --- /dev/null +++ b/tests/reference/asr-test_set4-53fea39.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-test_set4-53fea39", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/errors/test_set4.py", + "infile_hash": "3d78c7ad82aa32c3a4cc5f1a7d44e53b81639194f55672ddc99b4d2d", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "asr-test_set4-53fea39.stderr", + "stderr_hash": "d9646bd3609c55ff39f57ca435fedc7dabed530caf28caddc9e58a06", + "returncode": 2 +} \ No newline at end of file diff --git a/tests/reference/asr-test_set4-53fea39.stderr b/tests/reference/asr-test_set4-53fea39.stderr new file mode 100644 index 0000000000..9ce2e3863c --- /dev/null +++ b/tests/reference/asr-test_set4-53fea39.stderr @@ -0,0 +1,5 @@ +semantic error: Call to set.add must have exactly one argument + --> tests/errors/test_set4.py:6:5 + | +6 | a.add(3, 4) + | ^^^^^^^^^^^ diff --git a/tests/tests.toml b/tests/tests.toml index 13ac8257cd..987ba267fb 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -833,6 +833,10 @@ asr = true filename = "errors/test_set3.py" asr = true +[[test]] +filename = "errors/test_set4.py" +asr = true + [[test]] filename = "errors/test_pow.py" asr = true From 09ad3dbc579130a9500fa60dd554a263143e1124 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Wed, 12 Jul 2023 17:54:35 +0530 Subject: [PATCH 5/6] add nofast to set add and remove --- integration_tests/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index f891d02903..116623c425 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -484,8 +484,8 @@ RUN(NAME test_dict_13 LABELS cpython llvm c) RUN(NAME test_dict_bool LABELS cpython llvm) RUN(NAME test_dict_increment LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm) -RUN(NAME test_set_add LABELS cpython llvm) -RUN(NAME test_set_remove LABELS cpython llvm) +RUN(NAME test_set_add LABELS cpython llvm NOFAST) +RUN(NAME test_set_remove LABELS cpython llvm NOFAST) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) From ab441175a959a1830731cbfc25ce5221e2e6662e Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Sat, 15 Jul 2023 16:40:24 +0530 Subject: [PATCH 6/6] Initialise are_iterators_set to False --- integration_tests/CMakeLists.txt | 4 +-- integration_tests/test_set_add.py | 2 +- src/libasr/codegen/llvm_utils.cpp | 47 ++++++++++++++++--------------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 116623c425..f891d02903 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -484,8 +484,8 @@ RUN(NAME test_dict_13 LABELS cpython llvm c) RUN(NAME test_dict_bool LABELS cpython llvm) RUN(NAME test_dict_increment LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm) -RUN(NAME test_set_add LABELS cpython llvm NOFAST) -RUN(NAME test_set_remove LABELS cpython llvm NOFAST) +RUN(NAME test_set_add LABELS cpython llvm) +RUN(NAME test_set_remove LABELS cpython llvm) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_set_add.py b/integration_tests/test_set_add.py index de8bbdd3d9..699b1cfa58 100644 --- a/integration_tests/test_set_add.py +++ b/integration_tests/test_set_add.py @@ -31,4 +31,4 @@ def test_set_add(): else: assert len(s3) == 10 -test_set_add() \ No newline at end of file +test_set_add() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index cdf8b3459b..470cdacbbd 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -2961,7 +2961,7 @@ namespace LCompilers { /** * C++ equivalent: - * + * * key_mask_value = key_mask[key_hash]; * is_prob_needed = key_mask_value == 1; * if( is_prob_needed ) { @@ -2976,12 +2976,12 @@ namespace LCompilers { * else { * resolve_collision(key, for_read=true); // modifies pos * } - * + * * is_key_matching = key == key_list[pos]; * if( !is_key_matching ) { * exit(1); // key not present * } - * + * * return value_list[pos]; */ @@ -4820,7 +4820,8 @@ namespace LCompilers { context(context_), llvm_utils(std::move(llvm_utils_)), builder(std::move(builder_)), - pos_ptr(nullptr), is_set_present_(false) { + pos_ptr(nullptr), are_iterators_set(false), + is_set_present_(false) { } LLVMSetLinearProbing::LLVMSetLinearProbing(llvm::LLVMContext& context_, @@ -5041,20 +5042,20 @@ namespace LCompilers { /** * C++ equivalent: - * + * * pos = el_hash; - * + * * while( true ) { * is_el_skip = el_mask_value == 3; // tombstone * is_el_set = el_mask_value != 0; * is_el_matching = 0; - * + * * compare_elems = is_el_set && !is_el_skip; * if( compare_elems ) { * original_el = el_list[pos]; * is_el_matching = el == original_el; * } - * + * * cond; * if( for_read ) { * // for reading, continue to next pos @@ -5066,7 +5067,7 @@ namespace LCompilers { * // if current pos is tombstone * cond = is_el_set && !is_el_matching && !is_el_skip; * } - * + * * if( cond ) { * pos += 1; * pos %= capacity; @@ -5075,7 +5076,7 @@ namespace LCompilers { * break; * } * } - * + * */ if( !are_iterators_set ) { @@ -5155,7 +5156,7 @@ namespace LCompilers { /** * C++ equivalent: - * + * * resolve_collision(); // modifies pos * el_list[pos] = el; * el_mask_value = el_mask[pos]; @@ -5165,7 +5166,7 @@ namespace LCompilers { * set_max_2 = linear_prob_happened ? 2 : 1; * el_mask[el_hash] = set_max_2; * el_mask[pos] = set_max_2; - * + * */ llvm::Value* el_list = get_el_list(set); @@ -5208,10 +5209,10 @@ namespace LCompilers { /** * C++ equivalent: - * + * * old_capacity = capacity; * capacity = 2 * capacity + 1; - * + * * idx = 0; * while( old_capacity > idx ) { * is_el_set = el_mask[idx] != 0; @@ -5227,12 +5228,12 @@ namespace LCompilers { * } * idx += 1; * } - * + * * free(el_list); * free(el_mask); * el_list = new_el_list; * el_mask = new_el_mask; - * + * */ llvm::Value* capacity_ptr = get_pointer_to_capacity(set); @@ -5332,10 +5333,10 @@ namespace LCompilers { void LLVMSetLinearProbing::rehash_all_at_once_if_needed( llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx) { - + /** * C++ equivalent: - * + * * occupancy += 1; * load_factor = occupancy / capacity; * load_factor_threshold = 0.6; @@ -5343,7 +5344,7 @@ namespace LCompilers { * if( rehash_condition ) { * rehash(); * } - * + * */ llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); @@ -5382,7 +5383,7 @@ namespace LCompilers { /** * C++ equivalent: - * + * * el_mask_value = el_mask[el_hash]; * is_prob_needed = el_mask_value == 1; * if( is_prob_needed ) { @@ -5397,12 +5398,12 @@ namespace LCompilers { * else { * resolve_collision(el, for_read=true); // modifies pos * } - * + * * is_el_matching = el == el_list[pos]; * if( !is_el_matching ) { * exit(1); // el not present * } - * + * */ llvm::Value* el_list = get_el_list(set); @@ -5472,7 +5473,7 @@ namespace LCompilers { llvm::Module& module, ASR::ttype_t* el_asr_type) { /** * C++ equivalent: - * + * * resolve_collision_for_read(el); // modifies pos * el_mask[pos] = 3; // tombstone marker * occupancy -= 1;