diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 6bde8f5d6c..cb32b652e0 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -590,6 +590,7 @@ RUN(NAME test_set_add LABELS cpython llvm llvm_jit) RUN(NAME test_set_remove LABELS cpython llvm llvm_jit) RUN(NAME test_set_discard LABELS cpython llvm llvm_jit) RUN(NAME test_set_clear LABELS cpython llvm) +RUN(NAME test_set_pop LABELS cpython llvm) RUN(NAME test_global_set LABELS cpython llvm llvm_jit) RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c) RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_set_pop.py b/integration_tests/test_set_pop.py new file mode 100644 index 0000000000..af4500e236 --- /dev/null +++ b/integration_tests/test_set_pop.py @@ -0,0 +1,26 @@ +def set_pop_str(): + s: set[str] = {'a', 'b', 'c'} + + assert s.pop() in {'a', 'b', 'c'} + assert len(s) == 2 + assert s.pop() in {'a', 'b', 'c'} + assert s.pop() in {'a', 'b', 'c'} + assert len(s) == 0 + + s.add('d') + assert s.pop() == 'd' + +def set_pop_int(): + s: set[i32] = {1, 2, 3} + + assert s.pop() in {1, 2, 3} + assert len(s) == 2 + assert s.pop() in {1, 2, 3} + assert s.pop() in {1, 2, 3} + assert len(s) == 0 + + s.add(4) + assert s.pop() == 4 + +set_pop_str() +set_pop_int() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 1a1d3a597c..fc00b0b142 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1588,6 +1588,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor LLVM::is_llvm_struct(dict_type->m_value_type)); } + void visit_SetPop(const ASR::SetPop_t& x) { + ASR::Set_t* set_type = ASR::down_cast( + ASRUtils::expr_type(x.m_a)); + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pset = tmp; + + ptr_loads = ptr_loads_copy; + + llvm_utils->set_set_api(set_type); + tmp = llvm_utils->set_api->pop_item(pset, *module, set_type->m_type); + } + + void visit_ListLen(const ASR::ListLen_t& x) { if (x.m_value) { this->visit_expr(*x.m_value); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 7ae03d5fa4..e5cd47ace4 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -6714,6 +6714,180 @@ namespace LCompilers { LLVM::CreateStore(*builder, occupancy, occupancy_ptr); } + llvm::Value* LLVMSetLinearProbing::pop_item(llvm::Value *set, llvm::Module &module, + ASR::ttype_t *el_asr_type) { + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), [=]() {}, [&]() { + std::string message = "The set is empty"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + get_builder0(); + llvm::AllocaInst *pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + llvm::Value *el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value *el_list = get_el_list(set); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + current_capacity, + LLVM::CreateLoad(*builder, pos_ptr) + ); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, pos)); + llvm::Value* is_el_skip = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3))); + llvm::Value* is_el_set = builder->CreateICmpNE(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* is_el = builder->CreateAnd(is_el_set, + builder->CreateNot(is_el_skip)); + + llvm_utils->create_if_else(is_el, [&]() { + llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos); + llvm::Value* tombstone_marker = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)); + LLVM::CreateStore(*builder, tombstone_marker, el_mask_i); + occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); + }, [=]() { + LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr); + }); + builder->CreateCondBr(is_el, loopend, loophead); + } + + // end + llvm_utils->start_new_block(loopend); + + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value *el = llvm_utils->list_api->read_item(el_list, pos, false, module, + LLVM::is_llvm_struct(el_asr_type)); + return el; + } + + llvm::Value* LLVMSetSeparateChaining::pop_item(llvm::Value *set, llvm::Module &module, + ASR::ttype_t *el_asr_type) { + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), []() {}, [&]() { + std::string message = "The set is empty"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + + get_builder0(); + llvm::AllocaInst* chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + llvm::AllocaInst* found_ptr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + llvm::AllocaInst* pos = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + current_capacity, + LLVM::CreateLoad(*builder, pos_ptr) + ); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value *el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos))); + llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, LLVM::CreateLoad(*builder, pos)); + + llvm::Value *is_el = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + llvm_utils->create_if_else(is_el, [&]() { + llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, el_ll_i8, chain_itr); + llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]; + llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo()); + llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1)); + llvm::Value *cond = builder->CreateICmpNE( + next_el_struct, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + + llvm_utils->create_if_else(cond, [&](){ + llvm::Value *found = LLVM::CreateLoad(*builder, next_el_struct); + llvm::Value *prev = LLVM::CreateLoad(*builder, chain_itr); + found = builder->CreateBitCast(found, el_struct_type->getPointerTo()); + llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 1)); + prev = builder->CreateBitCast(prev, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, found_next, llvm_utils->create_gep(prev, 1)); + LLVM::CreateStore(*builder, found, found_ptr); + }, [&](){ + llvm::Value *found = LLVM::CreateLoad(*builder, chain_itr); + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]; + found = builder->CreateBitCast(found, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, found, found_ptr); + LLVM::CreateStore( + *builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)), + llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos)) + ); + llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr); + num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr); + }); + }, [&]() { + }); + LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr); + builder->CreateCondBr(is_el, loopend, loophead); + } + + llvm::Value *el = llvm_utils->create_ptr_gep(LLVM::CreateLoad(*builder, pos_ptr), 0); + + if (LLVM::is_llvm_struct(el_asr_type)) { + return el; + } else { + return LLVM::CreateLoad(*builder, el); + } + } + void LLVMSetLinearProbing::set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module, diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 2346d65088..16ba263769 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -1004,6 +1004,9 @@ namespace LCompilers { llvm::Value* set, llvm::Value* el, llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0; + virtual + llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type) = 0; + virtual void set_deepcopy( llvm::Value* src, llvm::Value* dest, @@ -1077,6 +1080,8 @@ namespace LCompilers { llvm::Value* set, llvm::Value* el, llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); + llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type); + void set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module, @@ -1160,6 +1165,8 @@ namespace LCompilers { llvm::Value* set, llvm::Value* el, llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); + llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type); + void set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module,