Skip to content

Fix issues with nested dicts #2253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ RUN(NAME test_dict_13 LABELS cpython llvm c)
RUN(NAME test_dict_bool LABELS cpython llvm)
RUN(NAME test_dict_increment LABELS cpython llvm)
RUN(NAME test_dict_keys_values LABELS cpython llvm)
RUN(NAME test_dict_nested1 LABELS cpython llvm)
RUN(NAME test_set_len LABELS cpython llvm)
RUN(NAME test_set_add LABELS cpython llvm)
RUN(NAME test_set_remove LABELS cpython llvm)
Expand Down
9 changes: 9 additions & 0 deletions integration_tests/test_dict_nested1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from lpython import i32

# Regression check for dict-of-dict handling: overwrite one inner dict with
# another, then write and read back a key through the nested subscript.
def test_nested_dict():
    d: dict[i32, dict[i32, i32]] = {1001: {2002: 3003}, 1002: {101: 2}}
    # Replace the inner dict stored under 1001 with the one stored under 1002.
    d[1001] = d[1002]
    # Insert a key that neither inner dict contained originally ...
    d[1001][100] = 4005
    # ... and confirm it is visible when read back through the outer key.
    assert d[1001][100] == 4005

test_nested_dict()
99 changes: 80 additions & 19 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,9 @@ namespace LCompilers {
return builder->CreateICmpEQ(left, right);
}
case ASR::ttypeType::Logical: {
return builder->CreateICmpEQ(left, right);
llvm::Value* left_i32 = builder->CreateZExt(left, llvm::Type::getInt32Ty(context));
llvm::Value* right_i32 = builder->CreateZExt(right, llvm::Type::getInt32Ty(context));
return builder->CreateICmpEQ(left_i32, right_i32);
}
case ASR::ttypeType::Real: {
return builder->CreateFCmpOEQ(left, right);
Expand Down Expand Up @@ -1515,6 +1517,10 @@ namespace LCompilers {
switch( asr_type->type ) {
case ASR::ttypeType::Integer:
case ASR::ttypeType::Logical: {
if( asr_type->type == ASR::ttypeType::Logical ) {
left = builder->CreateZExt(left, llvm::Type::getInt32Ty(context));
right = builder->CreateZExt(right, llvm::Type::getInt32Ty(context));
}
switch( overload_id ) {
case 0: {
pred = llvm::CmpInst::Predicate::ICMP_SLT;
Expand Down Expand Up @@ -1640,7 +1646,7 @@ namespace LCompilers {
overload_id, int32_type);
}
default: {
throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " +
throw LCompilersException("LLVMUtils::is_ineq_by_value isn't implemented for " +
ASRUtils::type_to_str_python(asr_type));
}
}
Expand Down Expand Up @@ -1705,7 +1711,7 @@ namespace LCompilers {
}
case ASR::ttypeType::Dict: {
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
// set dict api here?
set_dict_api(dict_type);
dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx);
break ;
}
Expand Down Expand Up @@ -2514,6 +2520,46 @@ namespace LCompilers {
llvm::Value* key, llvm::Value* key_list,
llvm::Value* key_mask, llvm::Module& module,
ASR::ttype_t* key_asr_type, bool for_read) {

/**
* C++ equivalent:
*
* pos = key_hash;
*
* while( true ) {
* key_mask_value = key_mask[pos];
* is_key_skip = key_mask_value == 3; // tombstone
* is_key_set = key_mask_value != 0;
* is_key_matching = 0;
*
* compare_keys = is_key_set && !is_key_skip;
* if( compare_keys ) {
* original_key = key_list[pos];
* is_key_matching = key == original_key;
* }
*
* cond;
* if( for_read ) {
* // for reading, continue to next pos
* // even if current pos is tombstone
* cond = (is_key_set && !is_key_matching) || is_key_skip;
* }
* else {
* // for writing, do not continue
* // if current pos is tombstone
* cond = is_key_set && !is_key_matching && !is_key_skip;
* }
*
* if( cond ) {
* pos += 1;
* pos %= capacity;
* }
* else {
* break;
* }
* }
*
*/

get_builder0()
if( !for_read ) {
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
Expand Down Expand Up @@ -2889,8 +2935,8 @@ namespace LCompilers {
* C++ equivalent:
*
* key_mask_value = key_mask[key_hash];
* is_prob_needed = key_mask_value == 1;
* if( is_prob_needed ) {
* is_prob_not_needed = key_mask_value == 1;
* if( is_prob_not_needed ) {
* is_key_matching = key == key_list[key_hash];
* if( is_key_matching ) {
* pos = key_hash;
Expand Down Expand Up @@ -3290,7 +3336,15 @@ namespace LCompilers {
return tuple_hash;
}
case ASR::ttypeType::Logical: {
return builder->CreateZExt(key, llvm::Type::getInt32Ty(context));
// (int32_t)key % capacity
// modulo is required for the case when dict has a single key, `True`
llvm::Value* key_i32 = builder->CreateZExt(key, llvm::Type::getInt32Ty(context));
llvm::Value* logical_hash = builder->CreateZExtOrTrunc(
builder->CreateURem(key_i32,
builder->CreateZExtOrTrunc(capacity, key_i32->getType())),
capacity->getType()
);
return logical_hash;
}
default: {
throw LCompilersException("Hashing " + ASRUtils::type_to_str_python(key_asr_type) +
Expand Down Expand Up @@ -3536,23 +3590,29 @@ namespace LCompilers {
void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module,
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
std::map<std::string, std::map<std::string, int>>& name2memidx) {
/**
* C++ equivalent:
*
* // this condition will be true with 0 capacity too
* rehash_condition = 5 * occupancy >= 3 * capacity;
* if( rehash_condition ) {
* rehash();
* }
*
*/

llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict));
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity,
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)));
occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
llvm::APInt(32, 1)));
occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context));
capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context));
llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity);
// Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor
llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context),
llvm::APFloat((float) 0.6));
rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold));
llvm_utils->create_if_else(rehash_condition, [&]() {
// occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity
llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 5)));
llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get(
llvm::Type::getInt32Ty(context), llvm::APInt(32, 3)));
llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5,
capacity_times_3), [&]() {
rehash(dict, module, key_asr_type, value_asr_type, name2memidx);
}, [=]() {
});
}, []() {});
}

void LLVMDictSeparateChaining::rehash_all_at_once_if_needed(
Expand Down Expand Up @@ -3586,6 +3646,7 @@ namespace LCompilers {
llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module);
this->resolve_collision_for_write(dict, key_hash, key, value, module,
key_asr_type, value_asr_type, name2memidx);
rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx);
}

void LLVMDictSeparateChaining::write_item(llvm::Value* dict, llvm::Value* key,
Expand Down
Loading