Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/blosc2/lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,6 @@ def update_expr(self, new_op): # noqa: C901
if hasattr(value2, "_where_args"):
value2 = value2.compute()

self._dtype = infer_dtype(op, value1, value2)
if not isinstance(value1, LazyExpr) and not isinstance(value2, LazyExpr):
# We converted some of the operands to NDArray (where() handling above)
new_operands = {"o0": value1, "o1": value2}
Expand Down Expand Up @@ -2677,8 +2676,8 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
_shape = new_expr.shape
if isinstance(new_expr, blosc2.LazyExpr):
# Restore the original expression and operands
new_expr.expression = _expression
new_expr.expression_tosave = expression
new_expr.expression = f"({_expression})" # forcibly add parenthesis
new_expr.expression_tosave = new_expr.expression
new_expr.operands = _operands
new_expr.operands_tosave = operands
else:
Expand Down
22 changes: 22 additions & 0 deletions tests/ndarray/test_lazyexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,6 +1220,10 @@ def test_dtype_infer(dtype1, dtype2, scalar):
np.testing.assert_allclose(res, nres)
assert res.dtype == nres.dtype

# Check dtype not changed by expression creation (bug fix)
assert a.dtype == dtype1
assert b.dtype == dtype2


@pytest.mark.parametrize(
"cfunc", ["np.int8", "np.int16", "np.int32", "np.int64", "np.float32", "np.float64"]
Expand Down Expand Up @@ -1330,3 +1334,21 @@ def test_missing_operator():
# Clean up
blosc2.remove_urlpath("a.b2nd")
blosc2.remove_urlpath("expr.b2nd")


# Test the chaining of multiple lazy expressions
def test_chain_expressions():
N = 1_000
dtype = "float64"
a = blosc2.linspace(0, 1, N * N, dtype=dtype, shape=(N, N))
b = blosc2.linspace(1, 2, N * N, dtype=dtype, shape=(N, N))
c = blosc2.linspace(0, 1, N, dtype=dtype, shape=(N,))

le1 = a**3 + blosc2.sin(a**2)
le2 = le1 < c
le3 = le2 & (b < 0)

le1_ = blosc2.lazyexpr("a ** 3 + sin(a ** 2)", {"a": a})
le2_ = blosc2.lazyexpr("(le1 < c)", {"le1": le1_, "c": c})
le3_ = blosc2.lazyexpr("(le2 & (b < 0))", {"le2": le2_, "b": b})
assert (le3_[:] == le3[:]).all()