diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index df1fefe7a3..4212f5806d 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -543,6 +543,7 @@ RUN(NAME test_dict_13 LABELS cpython llvm c) RUN(NAME test_dict_bool LABELS cpython llvm) RUN(NAME test_dict_increment LABELS cpython llvm) RUN(NAME test_dict_keys_values LABELS cpython llvm) +RUN(NAME test_dict_nested1 LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm) RUN(NAME test_set_add LABELS cpython llvm) RUN(NAME test_set_remove LABELS cpython llvm) diff --git a/integration_tests/test_dict_nested1.py b/integration_tests/test_dict_nested1.py new file mode 100644 index 0000000000..14de899610 --- /dev/null +++ b/integration_tests/test_dict_nested1.py @@ -0,0 +1,9 @@ +from lpython import i32 + +def test_nested_dict(): + d: dict[i32, dict[i32, i32]] = {1001: {2002: 3003}, 1002: {101: 2}} + d[1001] = d[1002] + d[1001][100] = 4005 + assert d[1001][100] == 4005 + +test_nested_dict() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index cefce251ab..76b9a40291 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1433,7 +1433,9 @@ namespace LCompilers { return builder->CreateICmpEQ(left, right); } case ASR::ttypeType::Logical: { - return builder->CreateICmpEQ(left, right); + llvm::Value* left_i32 = builder->CreateZExt(left, llvm::Type::getInt32Ty(context)); + llvm::Value* right_i32 = builder->CreateZExt(right, llvm::Type::getInt32Ty(context)); + return builder->CreateICmpEQ(left_i32, right_i32); } case ASR::ttypeType::Real: { return builder->CreateFCmpOEQ(left, right); @@ -1515,6 +1517,10 @@ namespace LCompilers { switch( asr_type->type ) { case ASR::ttypeType::Integer: case ASR::ttypeType::Logical: { + if( asr_type->type == ASR::ttypeType::Logical ) { + left = builder->CreateZExt(left, llvm::Type::getInt32Ty(context)); + right = builder->CreateZExt(right, llvm::Type::getInt32Ty(context)); + } switch( overload_id ) { case 0: { pred = llvm::CmpInst::Predicate::ICMP_SLT; @@ -1640,7 +1646,7 @@ namespace LCompilers { overload_id, int32_type); } default: { - throw LCompilersException("LLVMUtils::is_equal_by_value isn't implemented for " + + throw LCompilersException("LLVMUtils::is_ineq_by_value isn't implemented for " + ASRUtils::type_to_str_python(asr_type)); } } @@ -1705,7 +1711,7 @@ namespace LCompilers { } case ASR::ttypeType::Dict: { ASR::Dict_t* dict_type = ASR::down_cast(asr_type); - // set dict api here? + set_dict_api(dict_type); dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx); break ; } @@ -2514,6 +2520,46 @@ namespace LCompilers { llvm::Value* key, llvm::Value* key_list, llvm::Value* key_mask, llvm::Module& module, ASR::ttype_t* key_asr_type, bool for_read) { + + /** + * C++ equivalent: + * + * pos = key_hash; + * + * while( true ) { + * is_key_skip = key_mask_value == 3; // tombstone + * is_key_set = key_mask_value != 0; + * is_key_matching = 0; + * + * compare_keys = is_key_set && !is_key_skip; + * if( compare_keys ) { + * original_key = key_list[pos]; + * is_key_matching = key == original_key; + * } + * + * cond; + * if( for_read ) { + * // for reading, continue to next pos + * // even if current pos is tombstone + * cond = (is_key_set && !is_key_matching) || is_key_skip; + * } + * else { + * // for writing, do not continue + * // if current pos is tombstone + * cond = is_key_set && !is_key_matching && !is_key_skip; + * } + * + * if( cond ) { + * pos += 1; + * pos %= capacity; + * } + * else { + * break; + * } + * } + * + */ + get_builder0() if( !for_read ) { pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); @@ -2889,8 +2935,8 @@ namespace LCompilers { * C++ equivalent: * * key_mask_value = key_mask[key_hash]; - * is_prob_needed = key_mask_value == 1; - * if( is_prob_needed ) { + * is_prob_not_needed = key_mask_value == 1; + * if( is_prob_not_needed ) { * is_key_matching = key == key_list[key_hash]; * if( is_key_matching ) { * pos = key_hash; @@ -3290,7 +3336,15 @@ namespace LCompilers { return tuple_hash; } case ASR::ttypeType::Logical: { - return builder->CreateZExt(key, llvm::Type::getInt32Ty(context)); + // (int32_t)key % capacity + // modulo is required for the case when dict has a single key, `True` + llvm::Value* key_i32 = builder->CreateZExt(key, llvm::Type::getInt32Ty(context)); + llvm::Value* logical_hash = builder->CreateZExtOrTrunc( + builder->CreateURem(key_i32, + builder->CreateZExtOrTrunc(capacity, key_i32->getType())), + capacity->getType() + ); + return logical_hash; } default: { throw LCompilersException("Hashing " + ASRUtils::type_to_str_python(key_asr_type) + @@ -3536,23 +3590,29 @@ namespace LCompilers { void LLVMDict::rehash_all_at_once_if_needed(llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { + /** + * C++ equivalent: + * + * // this condition will be true with 0 capacity too + * rehash_condition = 5 * occupancy >= 3 * capacity; + * if( rehash_condition ) { + * rehash(); + * } + * + */ + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity, - llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))); - occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, 1))); - occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context)); - capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context)); - llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity); // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor - llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), - llvm::APFloat((float) 0.6)); - rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold)); - llvm_utils->create_if_else(rehash_condition, [&]() { + // occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity + llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 5))); + llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 3))); + llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5, + capacity_times_3), [&]() { rehash(dict, module, key_asr_type, value_asr_type, name2memidx); - }, [=]() { - }); + }, []() {}); } void LLVMDictSeparateChaining::rehash_all_at_once_if_needed( @@ -3586,6 +3646,7 @@ namespace LCompilers { llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); this->resolve_collision_for_write(dict, key_hash, key, value, module, key_asr_type, value_asr_type, name2memidx); + rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx); } void LLVMDictSeparateChaining::write_item(llvm::Value* dict, llvm::Value* key, diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 5d9816c986..8de852e00d 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -5027,6 +5027,97 @@ class BodyVisitor : public CommonVisitor { } } + bool visit_SubscriptUtil(const AST::Subscript_t &x, const AST::Assign_t &assign_node, + ASR::expr_t *tmp_value, int32_t recursion_level) { + if (AST::is_a(*x.m_value)) { + std::string name = AST::down_cast(x.m_value)->m_id; + ASR::symbol_t *s = current_scope->resolve_symbol(name); + if (!s) { + throw SemanticError("Variable: '" + name + "' is not declared", + x.base.base.loc); + } + ASR::Variable_t *v = ASR::down_cast(s); + ASR::ttype_t *type = v->m_type; + if (ASR::is_a(*type)) { + this->visit_expr(*x.m_slice); + ASR::expr_t *key = ASRUtils::EXPR(tmp); + ASR::expr_t* se = ASR::down_cast( + ASR::make_Var_t(al, x.base.base.loc, s)); + if( recursion_level == 0 ) { + // dict insert case; + ASR::ttype_t *key_type = ASR::down_cast(type)->m_key_type; + ASR::ttype_t *value_type = ASR::down_cast(type)->m_value_type; + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) { + std::string ktype = ASRUtils::type_to_str_python(ASRUtils::expr_type(key)); + std::string totype = ASRUtils::type_to_str_python(key_type); + diag.add(diag::Diagnostic( + "Type mismatch in dictionary key, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch (found: '" + ktype + "', expected: '" + totype + "')", + {key->base.loc}) + }) + ); + throw SemanticAbort(); + } + if (tmp_value == nullptr) { + if (AST::is_a(*assign_node.m_value)) { + LCOMPILERS_ASSERT(AST::down_cast(assign_node.m_value)->n_elts == 0); + Vec list_ele; + list_ele.reserve(al, 1); + tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, assign_node.base.base.loc, + list_ele.p, list_ele.size(), value_type)); + } else if (AST::is_a(*assign_node.m_value)) { + LCOMPILERS_ASSERT(AST::down_cast(assign_node.m_value)->n_keys == 0); + Vec dict_ele; + dict_ele.reserve(al, 1); + tmp_value = ASRUtils::EXPR(ASR::make_DictConstant_t(al, assign_node.base.base.loc, + dict_ele.p, dict_ele.size(), dict_ele.p, dict_ele.size(), value_type)); + } + } + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) { + std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value)); + std::string totype = ASRUtils::type_to_str_python(value_type); + diag.add(diag::Diagnostic( + "Type mismatch in dictionary value, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')", + {tmp_value->base.loc}) + }) + ); + throw SemanticAbort(); + } + tmp = nullptr; + tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value)); + } + else { + tmp = make_DictItem_t(al, x.base.base.loc, se, key, nullptr, + ASR::down_cast(type)->m_value_type, nullptr); + } + return true; + } else if (ASRUtils::is_immutable(type)) { + throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support" + " item assignment", x.base.base.loc); + } + } else if( AST::is_a(*x.m_value) ) { + AST::Subscript_t *sb = AST::down_cast(x.m_value); + bool return_val = visit_SubscriptUtil(*sb, assign_node, tmp_value, recursion_level + 1); + if( return_val && tmp ) { + ASR::expr_t *dict = ASRUtils::EXPR(tmp); + this->visit_expr(*x.m_slice); + ASR::expr_t *key = ASRUtils::EXPR(tmp); + if( recursion_level == 0 ) { + tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, dict, key, tmp_value)); + } + else { + tmp = make_DictItem_t(al, x.base.base.loc, dict, key, nullptr, + ASR::down_cast(ASRUtils::expr_type(dict))->m_value_type, nullptr); + } + } + return return_val; + } + return false; + } + void visit_Assign(const AST::Assign_t &x) { ASR::expr_t *target, *assign_value = nullptr, *tmp_value; bool is_c_p_pointer_call_copy = is_c_p_pointer_call; @@ -5061,61 +5152,8 @@ class BodyVisitor : public CommonVisitor { check_is_assign_to_input_param(x.m_targets[i]); if (AST::is_a(*x.m_targets[i])) { AST::Subscript_t *sb = AST::down_cast(x.m_targets[i]); - if (AST::is_a(*sb->m_value)) { - std::string name = AST::down_cast(sb->m_value)->m_id; - ASR::symbol_t *s = current_scope->resolve_symbol(name); - if (!s) { - throw SemanticError("Variable: '" + name + "' is not declared", - x.base.base.loc); - } - ASR::Variable_t *v = ASR::down_cast(s); - ASR::ttype_t *type = v->m_type; - if (ASR::is_a(*type)) { - // dict insert case; - this->visit_expr(*sb->m_slice); - ASR::expr_t *key = ASRUtils::EXPR(tmp); - ASR::ttype_t *key_type = ASR::down_cast(type)->m_key_type; - ASR::ttype_t *value_type = ASR::down_cast(type)->m_value_type; - if (!ASRUtils::check_equal_type(ASRUtils::expr_type(key), key_type)) { - std::string ktype = ASRUtils::type_to_str_python(ASRUtils::expr_type(key)); - std::string totype = ASRUtils::type_to_str_python(key_type); - diag.add(diag::Diagnostic( - "Type mismatch in dictionary key, the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch (found: '" + ktype + "', expected: '" + totype + "')", - {key->base.loc}) - }) - ); - throw SemanticAbort(); - } - if (tmp_value == nullptr && AST::is_a(*x.m_value)) { - LCOMPILERS_ASSERT(AST::down_cast(x.m_value)->n_elts == 0); - Vec list_ele; - list_ele.reserve(al, 1); - tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, list_ele.p, - list_ele.size(), value_type)); - } - if (!ASRUtils::check_equal_type(ASRUtils::expr_type(tmp_value), value_type)) { - std::string vtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(tmp_value)); - std::string totype = ASRUtils::type_to_str_python(value_type); - diag.add(diag::Diagnostic( - "Type mismatch in dictionary value, the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch (found: '" + vtype + "', expected: '" + totype + "')", - {tmp_value->base.loc}) - }) - ); - throw SemanticAbort(); - } - ASR::expr_t* se = ASR::down_cast( - ASR::make_Var_t(al, x.base.base.loc, s)); - tmp = nullptr; - tmp_vec.push_back(make_DictInsert_t(al, x.base.base.loc, se, key, tmp_value)); - continue; - } else if (ASRUtils::is_immutable(type)) { - throw SemanticError("'" + ASRUtils::type_to_str_python(type) + "' object does not support" - " item assignment", x.base.base.loc); - } + if( visit_SubscriptUtil(*sb, x, tmp_value, 0) ) { + continue; } } else if (AST::is_a(*x.m_targets[i])) { AST::Attribute_t *attr = AST::down_cast(x.m_targets[i]); @@ -5136,12 +5174,20 @@ class BodyVisitor : public CommonVisitor { this->visit_expr(*x.m_targets[i]); target = ASRUtils::EXPR(tmp); ASR::ttype_t *target_type = ASRUtils::expr_type(target); - if (tmp_value == nullptr && AST::is_a(*x.m_value)) { - LCOMPILERS_ASSERT(AST::down_cast(x.m_value)->n_elts == 0); - Vec list_ele; - list_ele.reserve(al, 1); - tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, list_ele.p, - list_ele.size(), target_type)); + if (tmp_value == nullptr) { + if (AST::is_a(*x.m_value)) { + LCOMPILERS_ASSERT(AST::down_cast(x.m_value)->n_elts == 0); + Vec list_ele; + list_ele.reserve(al, 1); + tmp_value = ASRUtils::EXPR(ASR::make_ListConstant_t(al, x.base.base.loc, list_ele.p, + list_ele.size(), target_type)); + } else if (AST::is_a(*x.m_value)) { + LCOMPILERS_ASSERT(AST::down_cast(x.m_value)->n_keys == 0); + Vec dict_ele; + dict_ele.reserve(al, 1); + tmp_value = ASRUtils::EXPR(ASR::make_DictConstant_t(al, x.base.base.loc, dict_ele.p, + dict_ele.size(), dict_ele.p, dict_ele.size(), target_type)); + } } if (tmp_value == nullptr && ASR::is_a(*target)) { ASR::Var_t *var_tar = ASR::down_cast(target); @@ -6023,9 +6069,14 @@ class BodyVisitor : public CommonVisitor { void visit_Dict(const AST::Dict_t &x) { LCOMPILERS_ASSERT(x.n_keys == x.n_values); - if( x.n_keys == 0 && ann_assign_target_type != nullptr ) { - tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0, - nullptr, 0, ann_assign_target_type); + if( x.n_keys == 0 ) { + if( ann_assign_target_type != nullptr ) { + tmp = ASR::make_DictConstant_t(al, x.base.base.loc, nullptr, 0, + nullptr, 0, ann_assign_target_type); + } + else { + tmp = nullptr; + } return ; } Vec keys; diff --git a/tests/reference/asr-test_assign6-05cd64f.json b/tests/reference/asr-test_assign6-05cd64f.json index 765658fda0..4bab9d7802 100644 --- a/tests/reference/asr-test_assign6-05cd64f.json +++ b/tests/reference/asr-test_assign6-05cd64f.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_assign6-05cd64f.stderr", - "stderr_hash": "294865737572b9ab043b8ebab73fe949fa2bb73e9790c6a04d87dc50", + "stderr_hash": "5bc5e0f7454a31bb924cf1318c59e73da2446502181b92faffd9f5d4", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_assign6-05cd64f.stderr b/tests/reference/asr-test_assign6-05cd64f.stderr index b9594977bd..3eb1a1d84e 100644 --- a/tests/reference/asr-test_assign6-05cd64f.stderr +++ b/tests/reference/asr-test_assign6-05cd64f.stderr @@ -2,4 +2,4 @@ semantic error: 'str' object does not support item assignment --> tests/errors/test_assign6.py:4:5 | 4 | s[0] = 'f' - | ^^^^^^^^^^ + | ^^^^ diff --git a/tests/reference/asr-test_assign7-beebac3.json b/tests/reference/asr-test_assign7-beebac3.json index 7ddee4fb5a..e5197e2be8 100644 --- a/tests/reference/asr-test_assign7-beebac3.json +++ b/tests/reference/asr-test_assign7-beebac3.json @@ -8,6 +8,6 @@ "stdout": null, "stdout_hash": null, "stderr": "asr-test_assign7-beebac3.stderr", - "stderr_hash": "d12f04efad566740bd562fbe9c00a058210a9adf0f5297475fc41fe6", + "stderr_hash": "109f7da7ac86c0c2add0ff034655336396cb58ebe81570c1d0ce6e81", "returncode": 2 } \ No newline at end of file diff --git a/tests/reference/asr-test_assign7-beebac3.stderr b/tests/reference/asr-test_assign7-beebac3.stderr index 88ea06ffd7..87c04ca904 100644 --- a/tests/reference/asr-test_assign7-beebac3.stderr +++ b/tests/reference/asr-test_assign7-beebac3.stderr @@ -2,4 +2,4 @@ semantic error: 'tuple[i32, i32]' object does not support item assignment --> tests/errors/test_assign7.py:4:5 | 4 | t[0] = 3 - | ^^^^^^^^ + | ^^^^