Skip to content

Commit b5a0ce9

Browse files
authored
Merge pull request #2253 from kabra1110/nested_dict
Fix issues with nested `dict`
2 parents 93d36b0 + a7cdc18 commit b5a0ce9

File tree

8 files changed

+209
-87
lines changed

8 files changed

+209
-87
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ RUN(NAME test_dict_13 LABELS cpython llvm c)
543543
RUN(NAME test_dict_bool LABELS cpython llvm)
544544
RUN(NAME test_dict_increment LABELS cpython llvm)
545545
RUN(NAME test_dict_keys_values LABELS cpython llvm)
546+
RUN(NAME test_dict_nested1 LABELS cpython llvm)
546547
RUN(NAME test_set_len LABELS cpython llvm)
547548
RUN(NAME test_set_add LABELS cpython llvm)
548549
RUN(NAME test_set_remove LABELS cpython llvm)
integration_tests/test_dict_nested1.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from lpython import i32
2+
3+
def test_nested_dict():
4+
d: dict[i32, dict[i32, i32]] = {1001: {2002: 3003}, 1002: {101: 2}}
5+
d[1001] = d[1002]
6+
d[1001][100] = 4005
7+
assert d[1001][100] == 4005
8+
9+
test_nested_dict()

src/libasr/codegen/llvm_utils.cpp

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,9 @@ namespace LCompilers {
14331433
return builder->CreateICmpEQ(left, right);
14341434
}
14351435
case ASR::ttypeType::Logical: {
1436-
return builder->CreateICmpEQ(left, right);
1436+
llvm::Value* left_i32 = builder->CreateZExt(left, llvm::Type::getInt32Ty(context));
1437+
llvm::Value* right_i32 = builder->CreateZExt(right, llvm::Type::getInt32Ty(context));
1438+
return builder->CreateICmpEQ(left_i32, right_i32);
14371439
}
14381440
case ASR::ttypeType::Real: {
14391441
return builder->CreateFCmpOEQ(left, right);
@@ -1515,6 +1517,10 @@ namespace LCompilers {
15151517
switch( asr_type->type ) {
15161518
case ASR::ttypeType::Integer:
15171519
case ASR::ttypeType::Logical: {
1520+
if( asr_type->type == ASR::ttypeType::Logical ) {
1521+
left = builder->CreateZExt(left, llvm::Type::getInt32Ty(context));
1522+
right = builder->CreateZExt(right, llvm::Type::getInt32Ty(context));
1523+
}
15181524
switch( overload_id ) {
15191525
case 0: {
15201526
pred = llvm::CmpInst::Predicate::ICMP_SLT;
@@ -1640,7 +1646,7 @@ namespace LCompilers {
16401646
overload_id, int32_type);
16411647
}
16421648
default: {
1643-
throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " +
1649+
throw LCompilersException("LLVMUtils::is_ineq_by_value isn't implemented for " +
16441650
ASRUtils::type_to_str_python(asr_type));
16451651
}
16461652
}
@@ -1705,7 +1711,7 @@ namespace LCompilers {
17051711
}
17061712
case ASR::ttypeType::Dict: {
17071713
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(asr_type);
1708-
// set dict api here?
1714+
set_dict_api(dict_type);
17091715
dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx);
17101716
break ;
17111717
}
@@ -2514,6 +2520,46 @@ namespace LCompilers {
25142520
llvm::Value* key, llvm::Value* key_list,
25152521
llvm::Value* key_mask, llvm::Module& module,
25162522
ASR::ttype_t* key_asr_type, bool for_read) {
2523+
2524+
/**
2525+
* C++ equivalent:
2526+
*
2527+
* pos = key_hash;
2528+
*
2529+
* while( true ) {
2530+
* is_key_skip = key_mask_value == 3; // tombstone
2531+
* is_key_set = key_mask_value != 0;
2532+
* is_key_matching = 0;
2533+
*
2534+
* compare_keys = is_key_set && !is_key_skip;
2535+
* if( compare_keys ) {
2536+
* original_key = key_list[pos];
2537+
* is_key_matching = key == original_key;
2538+
* }
2539+
*
2540+
* cond;
2541+
* if( for_read ) {
2542+
* // for reading, continue to next pos
2543+
* // even if current pos is tombstone
2544+
* cond = (is_key_set && !is_key_matching) || is_key_skip;
2545+
* }
2546+
* else {
2547+
* // for writing, do not continue
2548+
* // if current pos is tombstone
2549+
* cond = is_key_set && !is_key_matching && !is_key_skip;
2550+
* }
2551+
*
2552+
* if( cond ) {
2553+
* pos += 1;
2554+
* pos %= capacity;
2555+
* }
2556+
* else {
2557+
* break;
2558+
* }
2559+
* }
2560+
*
2561+
*/
2562+
25172563
get_builder0()
25182564
if( !for_read ) {
25192565
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
@@ -2889,8 +2935,8 @@ namespace LCompilers {
28892935
* C++ equivalent:
28902936
*
28912937
* key_mask_value = key_mask[key_hash];
2892-
* is_prob_needed = key_mask_value == 1;
2893-
* if( is_prob_needed ) {
2938+
* is_prob_not_needed = key_mask_value == 1;
2939+
* if( is_prob_not_needed ) {
28942940
* is_key_matching = key == key_list[key_hash];
28952941
* if( is_key_matching ) {
28962942
* pos = key_hash;
@@ -3290,7 +3336,15 @@ namespace LCompilers {
32903336
return tuple_hash;
32913337
}
32923338
case ASR::ttypeType::Logical: {
3293-
return builder->CreateZExt(key, llvm::Type::getInt32Ty(context));
3339+
// (int32_t)key % capacity
3340+
// the modulo keeps the hash below capacity; it is required e.g. when the dict has a single key, `True`
3341+
llvm::Value* key_i32 = builder->CreateZExt(key, llvm::Type::getInt32Ty(context));
3342+
llvm::Value* logical_hash = builder->CreateZExtOrTrunc(
3343+
builder->CreateURem(key_i32,
3344+
builder->CreateZExtOrTrunc(capacity, key_i32->getType())),
3345+
capacity->getType()
3346+
);
3347+
return logical_hash;
32943348
}
32953349
default: {
32963350
throw LCompilersException("Hashing " + ASRUtils::type_to_str_python(key_asr_type) +
@@ -3536,23 +3590,29 @@ namespace LCompilers {
35363590
void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module,
35373591
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type,
35383592
std::map<std::string, std::map<std::string, int>>& name2memidx) {
3593+
/**
3594+
* C++ equivalent:
3595+
*
3596+
* // this condition will be true with 0 capacity too
3597+
* rehash_condition = 5 * occupancy >= 3 * capacity;
3598+
* if( rehash_condition ) {
3599+
* rehash();
3600+
* }
3601+
*
3602+
*/
3603+
35393604
llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict));
35403605
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict));
3541-
llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity,
3542-
llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)));
3543-
occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context),
3544-
llvm::APInt(32, 1)));
3545-
occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context));
3546-
capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context));
3547-
llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity);
35483606
// The load factor threshold is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor
3549-
llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context),
3550-
llvm::APFloat((float) 0.6));
3551-
rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold));
3552-
llvm_utils->create_if_else(rehash_condition, [&]() {
3607+
// occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity
3608+
llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get(
3609+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 5)));
3610+
llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get(
3611+
llvm::Type::getInt32Ty(context), llvm::APInt(32, 3)));
3612+
llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5,
3613+
capacity_times_3), [&]() {
35533614
rehash(dict, module, key_asr_type, value_asr_type, name2memidx);
3554-
}, [=]() {
3555-
});
3615+
}, []() {});
35563616
}
35573617

35583618
void LLVMDictSeparateChaining::rehash_all_at_once_if_needed(
@@ -3586,6 +3646,7 @@ namespace LCompilers {
35863646
llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module);
35873647
this->resolve_collision_for_write(dict, key_hash, key, value, module,
35883648
key_asr_type, value_asr_type, name2memidx);
3649+
rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx);
35893650
}
35903651

35913652
void LLVMDictSeparateChaining::write_item(llvm::Value* dict, llvm::Value* key,

0 commit comments

Comments
 (0)