Skip to content

Add set.pop method #2749

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ RUN(NAME test_set_add LABELS cpython llvm llvm_jit)
RUN(NAME test_set_remove LABELS cpython llvm llvm_jit)
RUN(NAME test_set_discard LABELS cpython llvm llvm_jit)
RUN(NAME test_set_clear LABELS cpython llvm)
RUN(NAME test_set_pop LABELS cpython llvm)
RUN(NAME test_global_set LABELS cpython llvm llvm_jit)
RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c)
RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64)
Expand Down
26 changes: 26 additions & 0 deletions integration_tests/test_set_pop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
def set_pop_str():
s: set[str] = {'a', 'b', 'c'}

assert s.pop() in {'a', 'b', 'c'}
assert len(s) == 2
assert s.pop() in {'a', 'b', 'c'}
assert s.pop() in {'a', 'b', 'c'}
assert len(s) == 0

s.add('d')
assert s.pop() == 'd'

def set_pop_int():
s: set[i32] = {1, 2, 3}

assert s.pop() in {1, 2, 3}
assert len(s) == 2
assert s.pop() in {1, 2, 3}
assert s.pop() in {1, 2, 3}
assert len(s) == 0

s.add(4)
assert s.pop() == 4

set_pop_str()
set_pop_int()
15 changes: 15 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
LLVM::is_llvm_struct(dict_type->m_value_type));
}

void visit_SetPop(const ASR::SetPop_t& x) {
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
ASRUtils::expr_type(x.m_a));
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_a);
llvm::Value* pset = tmp;

ptr_loads = ptr_loads_copy;

llvm_utils->set_set_api(set_type);
tmp = llvm_utils->set_api->pop_item(pset, *module, set_type->m_type);
}


void visit_ListLen(const ASR::ListLen_t& x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
Expand Down
174 changes: 174 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6714,6 +6714,180 @@ namespace LCompilers {
LLVM::CreateStore(*builder, occupancy, occupancy_ptr);
}

llvm::Value* LLVMSetLinearProbing::pop_item(llvm::Value *set, llvm::Module &module,
ASR::ttype_t *el_asr_type) {
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));

llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set);
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), [=]() {}, [&]() {
std::string message = "The set is empty";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
});
get_builder0();
llvm::AllocaInst *pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

llvm::Value *el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
llvm::Value *el_list = get_el_list(set);

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value *cond = builder->CreateICmpSGT(
current_capacity,
LLVM::CreateLoad(*builder, pos_ptr)
);
builder->CreateCondBr(cond, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
llvm_utils->create_ptr_gep(el_mask, pos));
llvm::Value* is_el_skip = builder->CreateICmpEQ(el_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)));
llvm::Value* is_el_set = builder->CreateICmpNE(el_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));
llvm::Value* is_el = builder->CreateAnd(is_el_set,
builder->CreateNot(is_el_skip));

llvm_utils->create_if_else(is_el, [&]() {
llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos);
llvm::Value* tombstone_marker = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3));
LLVM::CreateStore(*builder, tombstone_marker, el_mask_i);
occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, occupancy, occupancy_ptr);
}, [=]() {
LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr);
});
builder->CreateCondBr(is_el, loopend, loophead);
}

// end
llvm_utils->start_new_block(loopend);

llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
llvm::Value *el = llvm_utils->list_api->read_item(el_list, pos, false, module,
LLVM::is_llvm_struct(el_asr_type));
return el;
}

llvm::Value* LLVMSetSeparateChaining::pop_item(llvm::Value *set, llvm::Module &module,
ASR::ttype_t *el_asr_type) {
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set);
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), []() {}, [&]() {
std::string message = "The set is empty";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
});

get_builder0();
llvm::AllocaInst* chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
llvm::AllocaInst* found_ptr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
llvm::AllocaInst* pos = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
LLVM::CreateStore(*builder,
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos);

llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");

llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set));
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));

// head
llvm_utils->start_new_block(loophead);
{
llvm::Value *cond = builder->CreateICmpSGT(
current_capacity,
LLVM::CreateLoad(*builder, pos_ptr)
);
builder->CreateCondBr(cond, loopbody, loopend);
}

// body
llvm_utils->start_new_block(loopbody);
{
llvm::Value *el_mask_value = LLVM::CreateLoad(*builder,
llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos)));
llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, LLVM::CreateLoad(*builder, pos));

llvm::Value *is_el = builder->CreateICmpEQ(el_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
llvm_utils->create_if_else(is_el, [&]() {
llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context));
LLVM::CreateStore(*builder, el_ll_i8, chain_itr);
llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)];
llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo());
llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1));
llvm::Value *cond = builder->CreateICmpNE(
next_el_struct,
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
);

llvm_utils->create_if_else(cond, [&](){
llvm::Value *found = LLVM::CreateLoad(*builder, next_el_struct);
llvm::Value *prev = LLVM::CreateLoad(*builder, chain_itr);
found = builder->CreateBitCast(found, el_struct_type->getPointerTo());
llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 1));
prev = builder->CreateBitCast(prev, el_struct_type->getPointerTo());
LLVM::CreateStore(*builder, found_next, llvm_utils->create_gep(prev, 1));
LLVM::CreateStore(*builder, found, found_ptr);
}, [&](){
llvm::Value *found = LLVM::CreateLoad(*builder, chain_itr);
llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)];
found = builder->CreateBitCast(found, el_struct_type->getPointerTo());
LLVM::CreateStore(*builder, found, found_ptr);
LLVM::CreateStore(
*builder,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)),
llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos))
);
llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set);
llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr);
num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr);
});
}, [&]() {
});
LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr);
builder->CreateCondBr(is_el, loopend, loophead);
}

llvm::Value *el = llvm_utils->create_ptr_gep(LLVM::CreateLoad(*builder, pos_ptr), 0);

if (LLVM::is_llvm_struct(el_asr_type)) {
return el;
} else {
return LLVM::CreateLoad(*builder, el);
}
}

void LLVMSetLinearProbing::set_deepcopy(
llvm::Value* src, llvm::Value* dest,
ASR::Set_t* set_type, llvm::Module* module,
Expand Down
7 changes: 7 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,9 @@ namespace LCompilers {
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;

virtual
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;

virtual
void set_deepcopy(
llvm::Value* src, llvm::Value* dest,
Expand Down Expand Up @@ -1077,6 +1080,8 @@ namespace LCompilers {
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);

llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type);

void set_deepcopy(
llvm::Value* src, llvm::Value* dest,
ASR::Set_t* set_type, llvm::Module* module,
Expand Down Expand Up @@ -1160,6 +1165,8 @@ namespace LCompilers {
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);

llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type);

void set_deepcopy(
llvm::Value* src, llvm::Value* dest,
ASR::Set_t* set_type, llvm::Module* module,
Expand Down
Loading