Skip to content

Commit f7667f3

Browse files
authored
Merge pull request #510 from Blosc/fix_lazyexpr
Fixes to infer_shape and SimpleProxy for lazy expression handling of non-blosc2 inputs
2 parents 9ca0288 + 08773e0 commit f7667f3

File tree

11 files changed

+537
-141
lines changed

11 files changed

+537
-141
lines changed

ADD_LAZYFUNCS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Once you have written a (public API) function in Blosc2, it is important to:
55
* Add it to the list of functions in ``__all__`` in the ``__init__.py`` file
66
* If it is present in numpy, add it to the relevant dictionary (``local_ufunc_map``, ``ufunc_map``, ``ufunc_map_1param``) in ``ndarray.py``
77

8+
If your function is implemented at the Blosc2 level (and not via either the `LazyUDF` or `LazyExpr` classes), you will need to add some conversion of the inputs to SimpleProxy instances (see e.g. ``matmul`` for an example).
9+
810
Finally, you also need to deal with it correctly within ``shape_utils.py``.
911

1012
If the function does not change the shape of the output, simply add it to ``elementwise_funcs`` and you're done.

src/blosc2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def _raise(exc):
445445
result_type,
446446
can_cast,
447447
)
448-
from .proxy import Proxy, ProxySource, ProxyNDSource, ProxyNDField, SimpleProxy, jit
448+
from .proxy import Proxy, ProxySource, ProxyNDSource, ProxyNDField, SimpleProxy, jit, as_simpleproxy
449449

450450
from .schunk import SChunk, open
451451
from . import linalg
@@ -648,6 +648,7 @@ def _raise(exc):
648648
"asarray",
649649
"asin",
650650
"asinh",
651+
"as_simpleproxy",
651652
"astype",
652653
"atan",
653654
"atan2",

src/blosc2/lazyexpr.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
process_key,
4848
)
4949

50+
from .proxy import _convert_dtype
5051
from .shape_utils import constructors, elementwise_funcs, infer_shape, linalg_attrs, linalg_funcs, reducers
5152

5253
if not blosc2.IS_WASM:
@@ -433,13 +434,13 @@ def convert_inputs(inputs):
433434
return []
434435
inputs_ = []
435436
for obj in inputs:
436-
if not isinstance(obj, blosc2.Array) and not np.isscalar(obj):
437+
if not isinstance(obj, (np.ndarray, blosc2.Operand)) and not np.isscalar(obj):
437438
try:
438-
obj = np.asarray(obj)
439+
obj = blosc2.SimpleProxy(obj)
439440
except Exception:
440441
print(
441442
"Inputs not being np.ndarray, Array or Python scalar objects"
442-
" should be convertible to np.ndarray."
443+
" should be convertible to SimpleProxy."
443444
)
444445
raise
445446
inputs_.append(obj)
@@ -885,7 +886,13 @@ def validate_inputs(inputs: dict, out=None, reduce=False) -> tuple: # noqa: C90
885886
return shape, None, None, False
886887

887888
# More checks specific of NDArray inputs
888-
NDinputs = [input for input in inputs if hasattr(input, "chunks")]
889+
# NDInputs are either non-SimpleProxy with chunks or are SimpleProxy with src having chunks
890+
NDinputs = [
891+
input
892+
for input in inputs
893+
if (hasattr(input, "chunks") and not isinstance(input, blosc2.SimpleProxy))
894+
or (isinstance(input, blosc2.SimpleProxy) and hasattr(input.src, "chunks"))
895+
]
889896
if not NDinputs:
890897
# All inputs are NumPy arrays, so we cannot take the fast path
891898
if inputs and hasattr(inputs[0], "shape"):
@@ -1076,9 +1083,9 @@ def fill_chunk_operands( # noqa: C901
10761083
if nchunk == 0:
10771084
# Initialize the iterator for reading the chunks
10781085
# Take any operand (all should have the same shape and chunks)
1079-
arr = next(iter(operands.values()))
1086+
key, arr = next(iter(operands.items()))
10801087
chunks_idx, _ = get_chunks_idx(arr.shape, arr.chunks)
1081-
info = (reduc, aligned, low_mem, chunks_idx)
1088+
info = (reduc, aligned[key], low_mem, chunks_idx)
10821089
iter_chunks = read_nchunk(list(operands.values()), info)
10831090
# Run the asynchronous file reading function from a synchronous context
10841091
chunks = next(iter_chunks)
@@ -1094,7 +1101,7 @@ def fill_chunk_operands( # noqa: C901
10941101
# The chunk is a special zero chunk, so we can treat it as a scalar
10951102
chunk_operands[key] = np.zeros((), dtype=value.dtype)
10961103
continue
1097-
if aligned:
1104+
if aligned[key]:
10981105
buff = blosc2.decompress2(chunks[i])
10991106
bsize = value.dtype.itemsize * math.prod(chunks_)
11001107
chunk_operands[key] = np.frombuffer(buff[:bsize], dtype=value.dtype).reshape(chunks_)
@@ -1114,10 +1121,6 @@ def fill_chunk_operands( # noqa: C901
11141121
chunk_operands[key] = value[()]
11151122
continue
11161123

1117-
if isinstance(value, np.ndarray | blosc2.C2Array):
1118-
chunk_operands[key] = value[slice_]
1119-
continue
1120-
11211124
if not full_chunk or not isinstance(value, blosc2.NDArray):
11221125
# The chunk is not a full one, or has padding, or is not a blosc2.NDArray,
11231126
# so we need to go the slow path
@@ -1142,7 +1145,7 @@ def fill_chunk_operands( # noqa: C901
11421145
value.get_slice_numpy(chunk_operands[key], (starts, stops))
11431146
continue
11441147

1145-
if aligned:
1148+
if aligned[key]:
11461149
# Decompress the whole chunk and store it
11471150
buff = value.schunk.decompress_chunk(nchunk)
11481151
bsize = value.dtype.itemsize * math.prod(chunks_)
@@ -1202,7 +1205,10 @@ def fast_eval( # noqa: C901
12021205
if blocks is None:
12031206
blocks = basearr.blocks
12041207
# Check whether the partitions are aligned and behaved
1205-
aligned = blosc2.are_partitions_aligned(shape, chunks, blocks)
1208+
aligned = {
1209+
k: False if not hasattr(k, "chunks") else blosc2.are_partitions_aligned(k.shape, k.chunks, k.blocks)
1210+
for k in operands
1211+
}
12061212
behaved = blosc2.are_partitions_behaved(shape, chunks, blocks)
12071213

12081214
# Check that all operands are NDArray for fast path
@@ -1226,7 +1232,7 @@ def fast_eval( # noqa: C901
12261232
offset = tuple(s.start for s in cslice) # offset for the udf
12271233
chunks_ = tuple(s.stop - s.start for s in cslice)
12281234

1229-
full_chunk = chunks_ == chunks
1235+
full_chunk = chunks_ == chunks # slice is same as chunk
12301236
fill_chunk_operands(
12311237
operands, cslice, chunks_, full_chunk, aligned, nchunk, iter_disk, chunk_operands
12321238
)
@@ -1810,7 +1816,7 @@ def reduce_slices( # noqa: C901
18101816
same_chunks = all(operand.chunks == o.chunks for o in operands.values() if hasattr(o, "chunks"))
18111817
same_blocks = all(operand.blocks == o.blocks for o in operands.values() if hasattr(o, "blocks"))
18121818
fast_path = same_shape and same_chunks and same_blocks and (0 not in operand.chunks)
1813-
aligned, iter_disk = False, False
1819+
aligned, iter_disk = dict.fromkeys(operands.keys(), False), False
18141820
if fast_path:
18151821
chunks = operand.chunks
18161822
# Check that all operands are NDArray for fast path
@@ -2213,7 +2219,9 @@ def result_type(
22132219
# Follow NumPy rules for scalar-array operations
22142220
# Create small arrays with the same dtypes and let NumPy's type promotion determine the result type
22152221
arrs = [
2216-
value if (np.isscalar(value) or not hasattr(value, "dtype")) else np.array([0], dtype=value.dtype)
2222+
value
2223+
if (np.isscalar(value) or not hasattr(value, "dtype"))
2224+
else np.array([0], dtype=_convert_dtype(value.dtype))
22172225
for value in arrays_and_dtypes
22182226
]
22192227
return np.result_type(*arrs)
@@ -2255,15 +2263,29 @@ def __init__(self, new_op): # noqa: C901
22552263
return
22562264
value1, op, value2 = new_op
22572265
dtype_ = check_dtype(op, value1, value2) # perform some checks
2266+
# Check that operands are proper Operands, LazyArray or scalars; if not, convert to SimpleProxy objects
2267+
value1 = (
2268+
blosc2.SimpleProxy(value1)
2269+
if not (isinstance(value1, (blosc2.Operand, np.ndarray)) or np.isscalar(value1))
2270+
else value1
2271+
)
22582272
if value2 is None:
22592273
if isinstance(value1, LazyExpr):
22602274
self.expression = value1.expression if op is None else f"{op}({value1.expression})"
22612275
self.operands = value1.operands
22622276
else:
2277+
if np.isscalar(value1):
2278+
value1 = ne_evaluate(f"{op}({value1})")
2279+
op = None
22632280
self.operands = {"o0": value1}
22642281
self.expression = "o0" if op is None else f"{op}(o0)"
22652282
return
2266-
elif isinstance(value1, LazyExpr) or isinstance(value2, LazyExpr):
2283+
value2 = (
2284+
blosc2.SimpleProxy(value2)
2285+
if not (isinstance(value2, (blosc2.Operand, np.ndarray)) or np.isscalar(value2))
2286+
else value2
2287+
)
2288+
if isinstance(value1, LazyExpr) or isinstance(value2, LazyExpr):
22672289
if isinstance(value1, LazyExpr):
22682290
newexpr = value1.update_expr(new_op)
22692291
else:
@@ -2274,7 +2296,8 @@ def __init__(self, new_op): # noqa: C901
22742296
return
22752297
elif op in funcs_2args:
22762298
if np.isscalar(value1) and np.isscalar(value2):
2277-
self.expression = f"{op}({value1}, {value2})"
2299+
self.expression = "o0"
2300+
self.operands = {"o0": ne_evaluate(f"{op}({value1}, {value2})")} # eager evaluation
22782301
elif np.isscalar(value2):
22792302
self.operands = {"o0": value1}
22802303
self.expression = f"{op}(o0, {value2})"
@@ -2288,7 +2311,8 @@ def __init__(self, new_op): # noqa: C901
22882311

22892312
self._dtype = dtype_
22902313
if np.isscalar(value1) and np.isscalar(value2):
2291-
self.expression = f"({value1} {op} {value2})"
2314+
self.expression = "o0"
2315+
self.operands = {"o0": ne_evaluate(f"({value1} {op} {value2})")} # eager evaluation
22922316
elif np.isscalar(value2):
22932317
self.operands = {"o0": value1}
22942318
self.expression = f"(o0 {op} {value2})"
@@ -2530,7 +2554,7 @@ def where(self, value1=None, value2=None):
25302554
# This just acts as a 'decorator' for the existing expression
25312555
if value1 is not None and value2 is not None:
25322556
# Guess the outcome dtype for value1 and value2
2533-
dtype = np.result_type(value1, value2)
2557+
dtype = blosc2.result_type(value1, value2)
25342558
args = {"_where_x": value1, "_where_y": value2}
25352559
elif value1 is not None:
25362560
if hasattr(value1, "dtype"):
@@ -2736,7 +2760,7 @@ def find_args(expr):
27362760

27372761
def _compute_expr(self, item, kwargs): # noqa : C901
27382762
# ne_evaluate will need safe_blosc2_globals for some functions (e.g. clip, logaddexp)
2739-
# that are implemenetd in python-blosc2 not in numexpr
2763+
# that are implemented in python-blosc2 not in numexpr
27402764
global safe_blosc2_globals
27412765
if len(safe_blosc2_globals) == 0:
27422766
# First eval call, fill blosc2_safe_globals for ne_evaluate
@@ -2881,7 +2905,7 @@ def __getitem__(self, item):
28812905
# Squeeze single-element dimensions when indexing with integers
28822906
# See e.g. examples/ndarray/animated_plot.py
28832907
if isinstance(item, int) or (hasattr(item, "__iter__") and any(isinstance(i, int) for i in item)):
2884-
result = result.squeeze()
2908+
result = result.squeeze(axis=tuple(i for i in range(result.ndim) if result.shape[i] == 1))
28852909
return result
28862910

28872911
def slice(self, item):
@@ -3008,7 +3032,7 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
30083032
_operands = operands | local_vars
30093033
# Check that operands are proper Operands, LazyArray or scalars; if not, convert to NDArray objects
30103034
for op, val in _operands.items():
3011-
if not (isinstance(val, (blosc2.Operand, blosc2.LazyArray, np.ndarray)) or np.isscalar(val)):
3035+
if not (isinstance(val, (blosc2.Operand, np.ndarray)) or np.isscalar(val)):
30123036
_operands[op] = blosc2.SimpleProxy(val)
30133037
# for scalars just return value (internally converts to () if necessary)
30143038
opshapes = {k: v if not hasattr(v, "shape") else v.shape for k, v in _operands.items()}

0 commit comments

Comments
 (0)