Skip to content

Commit 3fc1ce9

Browse files
advikkabraczgdp1807
authored andcommitted
Add set.pop method
1 parent f61b952 commit 3fc1ce9

File tree

5 files changed

+223
-0
lines changed

5 files changed

+223
-0
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,7 @@ RUN(NAME test_set_add LABELS cpython llvm llvm_jit)
590590
RUN(NAME test_set_remove LABELS cpython llvm llvm_jit)
591591
RUN(NAME test_set_discard LABELS cpython llvm llvm_jit)
592592
RUN(NAME test_set_clear LABELS cpython llvm)
593+
RUN(NAME test_set_pop LABELS cpython llvm)
593594
RUN(NAME test_global_set LABELS cpython llvm llvm_jit)
594595
RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c)
595596
RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64)

integration_tests/test_set_pop.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
def set_pop_str():
2+
s: set[str] = {'a', 'b', 'c'}
3+
4+
assert s.pop() in {'a', 'b', 'c'}
5+
assert len(s) == 2
6+
assert s.pop() in {'a', 'b', 'c'}
7+
assert s.pop() in {'a', 'b', 'c'}
8+
assert len(s) == 0
9+
10+
s.add('d')
11+
assert s.pop() == 'd'
12+
13+
def set_pop_int():
14+
s: set[i32] = {1, 2, 3}
15+
16+
assert s.pop() in {1, 2, 3}
17+
assert len(s) == 2
18+
assert s.pop() in {1, 2, 3}
19+
assert s.pop() in {1, 2, 3}
20+
assert len(s) == 0
21+
22+
s.add(4)
23+
assert s.pop() == 4
24+
25+
set_pop_str()
26+
set_pop_int()

src/libasr/codegen/asr_to_llvm.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,6 +1588,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
15881588
LLVM::is_llvm_struct(dict_type->m_value_type));
15891589
}
15901590

1591+
void visit_SetPop(const ASR::SetPop_t& x) {
1592+
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
1593+
ASRUtils::expr_type(x.m_a));
1594+
int64_t ptr_loads_copy = ptr_loads;
1595+
ptr_loads = 0;
1596+
this->visit_expr(*x.m_a);
1597+
llvm::Value* pset = tmp;
1598+
1599+
ptr_loads = ptr_loads_copy;
1600+
1601+
llvm_utils->set_set_api(set_type);
1602+
tmp = llvm_utils->set_api->pop_item(pset, *module, set_type->m_type);
1603+
}
1604+
1605+
15911606
void visit_ListLen(const ASR::ListLen_t& x) {
15921607
if (x.m_value) {
15931608
this->visit_expr(*x.m_value);

src/libasr/codegen/llvm_utils.cpp

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6714,6 +6714,180 @@ namespace LCompilers {
67146714
LLVM::CreateStore(*builder, occupancy, occupancy_ptr);
67156715
}
67166716

6717+
llvm::Value* LLVMSetLinearProbing::pop_item(llvm::Value *set, llvm::Module &module,
6718+
ASR::ttype_t *el_asr_type) {
6719+
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
6720+
6721+
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set);
6722+
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
6723+
llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), [=]() {}, [&]() {
6724+
std::string message = "The set is empty";
6725+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6726+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6727+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6728+
int exit_code_int = 1;
6729+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6730+
llvm::APInt(32, exit_code_int));
6731+
exit(context, module, *builder, exit_code);
6732+
});
6733+
get_builder0();
6734+
llvm::AllocaInst *pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
6735+
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);
6736+
6737+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
6738+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
6739+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
6740+
6741+
llvm::Value *el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
6742+
llvm::Value *el_list = get_el_list(set);
6743+
6744+
// head
6745+
llvm_utils->start_new_block(loophead);
6746+
{
6747+
llvm::Value *cond = builder->CreateICmpSGT(
6748+
current_capacity,
6749+
LLVM::CreateLoad(*builder, pos_ptr)
6750+
);
6751+
builder->CreateCondBr(cond, loopbody, loopend);
6752+
}
6753+
6754+
// body
6755+
llvm_utils->start_new_block(loopbody);
6756+
{
6757+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
6758+
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
6759+
llvm_utils->create_ptr_gep(el_mask, pos));
6760+
llvm::Value* is_el_skip = builder->CreateICmpEQ(el_mask_value,
6761+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)));
6762+
llvm::Value* is_el_set = builder->CreateICmpNE(el_mask_value,
6763+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)));
6764+
llvm::Value* is_el = builder->CreateAnd(is_el_set,
6765+
builder->CreateNot(is_el_skip));
6766+
6767+
llvm_utils->create_if_else(is_el, [&]() {
6768+
llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos);
6769+
llvm::Value* tombstone_marker = llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3));
6770+
LLVM::CreateStore(*builder, tombstone_marker, el_mask_i);
6771+
occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get(
6772+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
6773+
LLVM::CreateStore(*builder, occupancy, occupancy_ptr);
6774+
}, [=]() {
6775+
LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get(
6776+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr);
6777+
});
6778+
builder->CreateCondBr(is_el, loopend, loophead);
6779+
}
6780+
6781+
// end
6782+
llvm_utils->start_new_block(loopend);
6783+
6784+
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
6785+
llvm::Value *el = llvm_utils->list_api->read_item(el_list, pos, false, module,
6786+
LLVM::is_llvm_struct(el_asr_type));
6787+
return el;
6788+
}
6789+
6790+
llvm::Value* LLVMSetSeparateChaining::pop_item(llvm::Value *set, llvm::Module &module,
6791+
ASR::ttype_t *el_asr_type) {
6792+
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
6793+
llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set);
6794+
llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr);
6795+
llvm_utils->create_if_else(builder->CreateICmpNE(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0)), []() {}, [&]() {
6796+
std::string message = "The set is empty";
6797+
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
6798+
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
6799+
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
6800+
int exit_code_int = 1;
6801+
llvm::Value *exit_code = llvm::ConstantInt::get(context,
6802+
llvm::APInt(32, exit_code_int));
6803+
exit(context, module, *builder, exit_code);
6804+
});
6805+
6806+
get_builder0();
6807+
llvm::AllocaInst* chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
6808+
llvm::AllocaInst* found_ptr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr);
6809+
llvm::AllocaInst* pos = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
6810+
LLVM::CreateStore(*builder,
6811+
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos);
6812+
6813+
llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head");
6814+
llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body");
6815+
llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end");
6816+
6817+
llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set));
6818+
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
6819+
6820+
// head
6821+
llvm_utils->start_new_block(loophead);
6822+
{
6823+
llvm::Value *cond = builder->CreateICmpSGT(
6824+
current_capacity,
6825+
LLVM::CreateLoad(*builder, pos_ptr)
6826+
);
6827+
builder->CreateCondBr(cond, loopbody, loopend);
6828+
}
6829+
6830+
// body
6831+
llvm_utils->start_new_block(loopbody);
6832+
{
6833+
llvm::Value *el_mask_value = LLVM::CreateLoad(*builder,
6834+
llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos)));
6835+
llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, LLVM::CreateLoad(*builder, pos));
6836+
6837+
llvm::Value *is_el = builder->CreateICmpEQ(el_mask_value,
6838+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
6839+
llvm_utils->create_if_else(is_el, [&]() {
6840+
llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context));
6841+
LLVM::CreateStore(*builder, el_ll_i8, chain_itr);
6842+
llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
6843+
llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)];
6844+
llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo());
6845+
llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1));
6846+
llvm::Value *cond = builder->CreateICmpNE(
6847+
next_el_struct,
6848+
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))
6849+
);
6850+
6851+
llvm_utils->create_if_else(cond, [&](){
6852+
llvm::Value *found = LLVM::CreateLoad(*builder, next_el_struct);
6853+
llvm::Value *prev = LLVM::CreateLoad(*builder, chain_itr);
6854+
found = builder->CreateBitCast(found, el_struct_type->getPointerTo());
6855+
llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 1));
6856+
prev = builder->CreateBitCast(prev, el_struct_type->getPointerTo());
6857+
LLVM::CreateStore(*builder, found_next, llvm_utils->create_gep(prev, 1));
6858+
LLVM::CreateStore(*builder, found, found_ptr);
6859+
}, [&](){
6860+
llvm::Value *found = LLVM::CreateLoad(*builder, chain_itr);
6861+
llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)];
6862+
found = builder->CreateBitCast(found, el_struct_type->getPointerTo());
6863+
LLVM::CreateStore(*builder, found, found_ptr);
6864+
LLVM::CreateStore(
6865+
*builder,
6866+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)),
6867+
llvm_utils->create_ptr_gep(el_mask, LLVM::CreateLoad(*builder, pos))
6868+
);
6869+
llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set);
6870+
llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr);
6871+
num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get(
6872+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1)));
6873+
LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr);
6874+
});
6875+
}, [&]() {
6876+
});
6877+
LLVM::CreateStore(*builder, builder->CreateAdd(pos, llvm::ConstantInt::get(
6878+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))), pos_ptr);
6879+
builder->CreateCondBr(is_el, loopend, loophead);
6880+
}
6881+
6882+
llvm::Value *el = llvm_utils->create_ptr_gep(LLVM::CreateLoad(*builder, pos_ptr), 0);
6883+
6884+
if (LLVM::is_llvm_struct(el_asr_type)) {
6885+
return el;
6886+
} else {
6887+
return LLVM::CreateLoad(*builder, el);
6888+
}
6889+
}
6890+
67176891
void LLVMSetLinearProbing::set_deepcopy(
67186892
llvm::Value* src, llvm::Value* dest,
67196893
ASR::Set_t* set_type, llvm::Module* module,

src/libasr/codegen/llvm_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,9 @@ namespace LCompilers {
10041004
llvm::Value* set, llvm::Value* el,
10051005
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;
10061006

1007+
virtual
1008+
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;
1009+
10071010
virtual
10081011
void set_deepcopy(
10091012
llvm::Value* src, llvm::Value* dest,
@@ -1077,6 +1080,8 @@ namespace LCompilers {
10771080
llvm::Value* set, llvm::Value* el,
10781081
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
10791082

1083+
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type);
1084+
10801085
void set_deepcopy(
10811086
llvm::Value* src, llvm::Value* dest,
10821087
ASR::Set_t* set_type, llvm::Module* module,
@@ -1160,6 +1165,8 @@ namespace LCompilers {
11601165
llvm::Value* set, llvm::Value* el,
11611166
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);
11621167

1168+
llvm::Value* pop_item(llvm::Value* set, llvm::Module& module, ASR::ttype_t* el_asr_type);
1169+
11631170
void set_deepcopy(
11641171
llvm::Value* src, llvm::Value* dest,
11651172
ASR::Set_t* set_type, llvm::Module* module,

0 commit comments

Comments
 (0)