Skip to content

Commit 6be4353

Browse files
authored
Merge pull request #2157 from Shaikh-Ubaid/support_arr_comp
Support array comparison
2 parents fe2fbf0 + ff45349 commit 6be4353

File tree

8 files changed

+107
-0
lines changed

8 files changed

+107
-0
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ RUN(NAME array_size_02 LABELS cpython llvm c)
368368
RUN(NAME array_01 LABELS cpython llvm wasm c)
369369
RUN(NAME array_02 LABELS cpython wasm c)
370370
RUN(NAME array_03 LABELS cpython llvm c)
371+
RUN(NAME array_04 LABELS cpython llvm c)
371372
RUN(NAME bindc_01 LABELS cpython llvm c)
372373
RUN(NAME bindc_02 LABELS cpython llvm c)
373374
RUN(NAME bindc_04 LABELS llvm c NOFAST)

integration_tests/array_04.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from lpython import i32, Const
2+
from numpy import empty, int32
3+
4+
def main0():
5+
n: Const[i32] = 1
6+
x: i32[n, n] = empty([n, n], dtype=int32)
7+
y: i32[n, n] = empty([n, n], dtype=int32)
8+
9+
x[0, 0] = -10
10+
y[0, 0] = -10
11+
12+
print(x[0, 0], y[0, 0])
13+
assert x == y
14+
15+
y[0, 0] = 10
16+
print(x[0, 0], y[0, 0])
17+
assert x != y
18+
19+
main0()

src/libasr/asr_utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,6 +1868,12 @@ static inline int64_t get_fixed_size_of_array(ASR::dimension_t* m_dims, size_t n
18681868
return array_size;
18691869
}
18701870

1871+
static inline int64_t get_fixed_size_of_array(ASR::ttype_t* type) {
1872+
ASR::dimension_t* m_dims = nullptr;
1873+
size_t n_dims = ASRUtils::extract_dimensions_from_ttype(type, m_dims);
1874+
return ASRUtils::get_fixed_size_of_array(m_dims, n_dims);
1875+
}
1876+
18711877
inline int extract_n_dims_from_ttype(ASR::ttype_t *x) {
18721878
ASR::dimension_t* m_dims_temp = nullptr;
18731879
return extract_dimensions_from_ttype(x, m_dims_temp);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6157,6 +6157,37 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
61576157
if( ASR::is_a<ASR::Enum_t>(*dest_type) || ASR::is_a<ASR::Const_t>(*dest_type) ) {
61586158
dest_type = ASRUtils::get_contained_type(dest_type);
61596159
}
6160+
6161+
if (ASRUtils::is_array(dest_type)) {
6162+
ASR::dimension_t* m_dims = nullptr;
6163+
int n_dims = ASRUtils::extract_dimensions_from_ttype(dest_type, m_dims);
6164+
int array_size = ASRUtils::get_fixed_size_of_array(m_dims, n_dims);
6165+
if (array_size == -1) {
6166+
throw SemanticError("The truth value of an array is ambiguous. Use a.any() or a.all()", x.base.base.loc);
6167+
} else if (array_size != 1) {
6168+
throw SemanticError("The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()", x.base.base.loc);
6169+
} else {
6170+
Vec<ASR::array_index_t> argsL, argsR;
6171+
argsL.reserve(al, 1);
6172+
argsR.reserve(al, 1);
6173+
for (int i = 0; i < n_dims; i++) {
6174+
ASR::array_index_t aiL, aiR;
6175+
ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4));
6176+
ASR::expr_t* const_zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int_type));
6177+
aiL.m_right = aiR.m_right = const_zero;
6178+
aiL.m_left = aiR.m_left = nullptr;
6179+
aiL.m_step = aiR.m_step = nullptr;
6180+
aiL.loc = left->base.loc;
6181+
aiR.loc = right->base.loc;
6182+
argsL.push_back(al, aiL);
6183+
argsR.push_back(al, aiR);
6184+
}
6185+
dest_type = ASRUtils::type_get_past_array(dest_type);
6186+
left = ASRUtils::EXPR(make_ArrayItem_t(al, left->base.loc, left, argsL.p, argsL.n, dest_type, ASR::arraystorageType::RowMajor, nullptr));
6187+
right = ASRUtils::EXPR(make_ArrayItem_t(al, right->base.loc, right, argsR.p, argsR.n, dest_type, ASR::arraystorageType::RowMajor, nullptr));
6188+
}
6189+
}
6190+
61606191
if (ASRUtils::is_integer(*dest_type)) {
61616192
if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) {
61626193
int64_t left_value = -1;

tests/errors/arrays_02.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from lpython import (i8, i32, dataclass)
2+
from numpy import (empty, int8)
3+
4+
@dataclass
5+
class Foo:
6+
a : i8[4] = empty(4, dtype=int8)
7+
dim : i32 = 4
8+
9+
def trinary_majority(x : Foo, y : Foo, z : Foo) -> Foo:
10+
foo : Foo = Foo()
11+
i : i32
12+
for i in range(foo.dim):
13+
foo.a[i] = (x.a[i] & y.a[i]) | (y.a[i] & z.a[i]) | (z.a[i] & x.a[i])
14+
return foo
15+
16+
17+
t1 : Foo = Foo()
18+
t1.a = empty(4, dtype=int8)
19+
20+
t2 : Foo = Foo()
21+
t2.a = empty(4, dtype=int8)
22+
23+
t3 : Foo = Foo()
24+
t3.a = empty(4, dtype=int8)
25+
26+
r1 : Foo = trinary_majority(t1, t2, t3)
27+
28+
assert r1.a == t1.a
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"basename": "asr-arrays_02-da94458",
3+
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
4+
"infile": "tests/errors/arrays_02.py",
5+
"infile_hash": "05e70a0056dc67dbf3a54ea66965db8746f9de012561ca95cb1fdb43",
6+
"outfile": null,
7+
"outfile_hash": null,
8+
"stdout": null,
9+
"stdout_hash": null,
10+
"stderr": "asr-arrays_02-da94458.stderr",
11+
"stderr_hash": "dc0e5be7cd6de7395421aedf1ce11977206f3e35bb7cba271aed8992",
12+
"returncode": 2
13+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
semantic error: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
2+
--> tests/errors/arrays_02.py:28:8
3+
|
4+
28 | assert r1.a == t1.a
5+
| ^^^^^^^^^^^^

tests/tests.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,10 @@ asr = true
650650
filename = "errors/arrays_01.py"
651651
asr = true
652652

653+
[[test]]
654+
filename = "errors/arrays_02.py"
655+
asr = true
656+
653657
[[test]]
654658
filename = "errors/structs_02.py"
655659
asr = true

0 commit comments

Comments
 (0)