Skip to content

Commit 7053e9e

Browse files
committed
Tests are passing for test_lazyexpr_fields.py suite
1 parent 442d330 commit 7053e9e

File tree

5 files changed

+105
-86
lines changed

5 files changed

+105
-86
lines changed

src/blosc2/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ class Tuner(Enum):
266266
get_expr_operands,
267267
validate_expr,
268268
evaluate,
269-
_ne_evaluate,
270269
)
271270
from .proxy import Proxy, ProxySource, ProxyNDSource, ProxyNDField, SimpleProxy, jit
272271

src/blosc2/lazyexpr.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
ne = None
4646

4747

48-
def _ne_evaluate(expression, local_dict=None, **kwargs):
48+
def ne_evaluate(expression, local_dict=None, **kwargs):
4949
"""Safely evaluate expressions using numexpr when possible, falling back to numpy."""
5050
if local_dict is None:
5151
local_dict = {}
@@ -62,7 +62,11 @@ def _ne_evaluate(expression, local_dict=None, **kwargs):
6262
)
6363
# Get local vars dict from the stack frame
6464
_frame_depth = kwargs.pop("_frame_depth", 1)
65-
local_dict |= dict(sys._getframe(_frame_depth).f_locals)
65+
local_dict |= {
66+
k: v
67+
for k, v in dict(sys._getframe(_frame_depth).f_locals).items()
68+
if hasattr(v, "dtype") or np.isscalar(v)
69+
}
6670
if "out" in kwargs:
6771
out = kwargs.pop("out")
6872
out[:] = eval(expression, safe_globals, local_dict)
@@ -994,19 +998,19 @@ def fast_eval( # noqa: C901
994998
# if callable(expression):
995999
# expression(tuple(chunk_operands.values()), out[slice_], offset=offset)
9961000
# else:
997-
# _ne_evaluate(expression, chunk_operands, out=out[slice_])
1001+
# ne_evaluate(expression, chunk_operands, out=out[slice_])
9981002
# continue
9991003
if callable(expression):
10001004
result = np.empty(chunks_, dtype=out.dtype)
10011005
expression(tuple(chunk_operands.values()), result, offset=offset)
10021006
else:
10031007
if where is None:
1004-
result = _ne_evaluate(expression, chunk_operands, **ne_args)
1008+
result = ne_evaluate(expression, chunk_operands, **ne_args)
10051009
else:
10061010
# Apply the where condition (in result)
10071011
if len(where) == 2:
10081012
new_expr = f"where({expression}, _where_x, _where_y)"
1009-
result = _ne_evaluate(new_expr, chunk_operands, **ne_args)
1013+
result = ne_evaluate(new_expr, chunk_operands, **ne_args)
10101014
else:
10111015
# We do not support one or zero operands in the fast path yet
10121016
raise ValueError("Fast path: the where condition must be a tuple with two elements")
@@ -1228,7 +1232,7 @@ def slices_eval( # noqa: C901
12281232
continue
12291233

12301234
if where is None:
1231-
result = _ne_evaluate(expression, chunk_operands, **ne_args)
1235+
result = ne_evaluate(expression, chunk_operands, **ne_args)
12321236
else:
12331237
# Apply the where condition (in result)
12341238
if len(where) == 2:
@@ -1237,9 +1241,9 @@ def slices_eval( # noqa: C901
12371241
# result = np.where(result, x, y)
12381242
# numexpr is a bit faster than np.where, and we can fuse operations in this case
12391243
new_expr = f"where({expression}, _where_x, _where_y)"
1240-
result = _ne_evaluate(new_expr, chunk_operands, **ne_args)
1244+
result = ne_evaluate(new_expr, chunk_operands, **ne_args)
12411245
elif len(where) == 1:
1242-
result = _ne_evaluate(expression, chunk_operands, **ne_args)
1246+
result = ne_evaluate(expression, chunk_operands, **ne_args)
12431247
if _indices or _order:
12441248
# Return indices only makes sense when the where condition is a tuple with one element
12451249
# and result is a boolean array
@@ -1517,14 +1521,14 @@ def reduce_slices( # noqa: C901
15171521
# We don't have an actual expression, so avoid a copy
15181522
result = chunk_operands["o0"]
15191523
else:
1520-
result = _ne_evaluate(expression, chunk_operands, **ne_args)
1524+
result = ne_evaluate(expression, chunk_operands, **ne_args)
15211525
else:
15221526
# Apply the where condition (in result)
15231527
if len(where) == 2:
15241528
new_expr = f"where({expression}, _where_x, _where_y)"
1525-
result = _ne_evaluate(new_expr, chunk_operands, **ne_args)
1529+
result = ne_evaluate(new_expr, chunk_operands, **ne_args)
15261530
elif len(where) == 1:
1527-
result = _ne_evaluate(expression, chunk_operands, **ne_args)
1531+
result = ne_evaluate(expression, chunk_operands, **ne_args)
15281532
x = chunk_operands["_where_x"]
15291533
result = x[result]
15301534
else:
@@ -1937,7 +1941,7 @@ def dtype(self):
19371941
for key, value in self.operands.items()
19381942
}
19391943
if "contains" in self.expression:
1940-
_out = _ne_evaluate(self.expression, local_dict=operands)
1944+
_out = ne_evaluate(self.expression, local_dict=operands)
19411945
else:
19421946
# Create a globals dict with the functions of numpy
19431947
globals_dict = {f: getattr(np, f) for f in functions if f not in ("contains", "pow")}
@@ -1947,7 +1951,7 @@ def dtype(self):
19471951
# Sometimes, numpy gets a RuntimeWarning when evaluating expressions
19481952
# with synthetic operands (1's). Let's try with numexpr, which is not so picky
19491953
# about this.
1950-
_out = _ne_evaluate(self.expression, local_dict=operands)
1954+
_out = ne_evaluate(self.expression, local_dict=operands)
19511955
self._dtype_ = _out.dtype
19521956
self._expression_ = self.expression
19531957
return self._dtype_
@@ -3143,7 +3147,7 @@ def evaluate(
31433147
nres = na1 + na2
31443148
print(f"Elapsed time (numpy, [:]): {time() - t0:.3f} s")
31453149
t0 = time()
3146-
nres = _ne_evaluate("na1 + na2")
3150+
nres = ne_evaluate("na1 + na2")
31473151
print(f"Elapsed time (numexpr, [:]): {time() - t0:.3f} s")
31483152
nres = nres[sl] if sl is not None else nres
31493153
t0 = time()
@@ -3166,7 +3170,7 @@ def evaluate(
31663170
# nres = np.sin(na1[:]) + 2 * na1[:] + 1 + 2
31673171
print(f"Elapsed time (numpy, [:]): {time() - t0:.3f} s")
31683172
t0 = time()
3169-
nres = _ne_evaluate("tan(na1) * (sin(na2) * sin(na2) + cos(na3)) + (sqrt(na4) * 2) + 2")
3173+
nres = ne_evaluate("tan(na1) * (sin(na2) * sin(na2) + cos(na3)) + (sqrt(na4) * 2) + 2")
31703174
print(f"Elapsed time (numexpr, [:]): {time() - t0:.3f} s")
31713175
nres = nres[sl] if sl is not None else nres
31723176
t0 = time()

0 commit comments

Comments
 (0)