1515from xarray .core .indexes import Index
1616from xarray .core .merge import merge
1717from xarray .core .pycompat import is_dask_collection
18+ from xarray .core .variable import Variable
1819
1920if TYPE_CHECKING :
2021 from xarray .core .types import T_Xarray
@@ -156,6 +157,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
156157 return slice (None )
157158
158159
def subset_dataset_to_block(
    graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
):
    """
    Creates a task that subsets an xarray dataset to a block determined by chunk_index.
    Block extents are determined by input_chunk_bounds.
    Also creates subtasks that subset the constituent variables of a dataset.
    """
    import dask

    # this will become [[name1, variable1],
    #                   [name2, variable2],
    #                   ...]
    # which is passed to dict and then to Dataset
    data_vars = []
    coords = []

    chunk_tuple = tuple(chunk_index.values())
    chunk_dims_set = set(chunk_index)
    variable: Variable
    for name, variable in dataset.variables.items():
        # make a task that creates tuple of (dims, chunk)
        if dask.is_dask_collection(variable.data):
            # get task name for chunk: the dask array's key for this block is
            # (array-name, i0, i1, ...) with one index per dimension
            chunk = (
                variable.data.name,
                *tuple(chunk_index[dim] for dim in variable.dims),
            )

            chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
            graph[chunk_variable_task] = (
                tuple,
                [variable.dims, chunk, variable.attrs],
            )
        else:
            # only dimension coordinates and scalars are expected to be backed
            # by non-dask data here
            assert name in dataset.dims or variable.ndim == 0

            # non-dask array possibly with dimensions chunked on other variables
            # index into variable appropriately
            subsetter = {
                dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
                for dim in variable.dims
            }
            if set(variable.dims) < chunk_dims_set:
                this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims)
            else:
                this_var_chunk_tuple = chunk_tuple

            chunk_variable_task = (
                f"{name}-{gname}-{dask.base.tokenize(subsetter)}",
            ) + this_var_chunk_tuple
            # We are including a dimension coordinate,
            # minimize duplication by not copying it in the graph for every chunk.
            if variable.ndim == 0 or chunk_variable_task not in graph:
                subset = variable.isel(subsetter)
                graph[chunk_variable_task] = (
                    tuple,
                    [subset.dims, subset._data, subset.attrs],
                )

        # this task creates dict mapping variable name to above tuple
        if name in dataset._coord_names:
            coords.append([name, chunk_variable_task])
        else:
            data_vars.append([name, chunk_variable_task])

    # a task that reassembles the block Dataset from the per-variable tuples
    return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
227+
228+
159229def map_blocks (
160230 func : Callable [..., T_Xarray ],
161231 obj : DataArray | Dataset ,
@@ -280,6 +350,10 @@ def _wrapper(
280350
281351 result = func (* converted_args , ** kwargs )
282352
353+ merged_coordinates = merge (
354+ [arg .coords for arg in args if isinstance (arg , (Dataset , DataArray ))]
355+ ).coords
356+
283357 # check all dims are present
284358 missing_dimensions = set (expected ["shapes" ]) - set (result .sizes )
285359 if missing_dimensions :
@@ -295,12 +369,16 @@ def _wrapper(
295369 f"Received dimension { name !r} of length { result .sizes [name ]} . "
296370 f"Expected length { expected ['shapes' ][name ]} ."
297371 )
298- if name in expected ["indexes" ]:
299- expected_index = expected ["indexes" ][name ]
300- if not index .equals (expected_index ):
301- raise ValueError (
302- f"Expected index { name !r} to be { expected_index !r} . Received { index !r} instead."
303- )
372+
373+ # ChainMap wants MutableMapping, but xindexes is Mapping
374+ merged_indexes = collections .ChainMap (
375+ expected ["indexes" ], merged_coordinates .xindexes # type: ignore[arg-type]
376+ )
377+ expected_index = merged_indexes .get (name , None )
378+ if expected_index is not None and not index .equals (expected_index ):
379+ raise ValueError (
380+ f"Expected index { name !r} to be { expected_index !r} . Received { index !r} instead."
381+ )
304382
305383 # check that all expected variables were returned
306384 check_result_variables (result , expected , "coords" )
@@ -356,6 +434,8 @@ def _wrapper(
356434 dataarray_to_dataset (arg ) if isinstance (arg , DataArray ) else arg
357435 for arg in aligned
358436 )
437+ # rechunk any numpy variables appropriately
438+ xarray_objs = tuple (arg .chunk (arg .chunksizes ) for arg in xarray_objs )
359439
360440 merged_coordinates = merge ([arg .coords for arg in aligned ]).coords
361441
@@ -378,7 +458,7 @@ def _wrapper(
378458 new_coord_vars = template_coords - set (merged_coordinates )
379459
380460 preserved_coords = merged_coordinates .to_dataset ()[preserved_coord_vars ]
381- # preserved_coords contains all coordinates bariables that share a dimension
461+ # preserved_coords contains all coordinates variables that share a dimension
382462 # with any index variable in preserved_indexes
383463 # Drop any unneeded vars in a second pass, this is required for e.g.
384464 # if the mapped function were to drop a non-dimension coordinate variable.
@@ -403,6 +483,13 @@ def _wrapper(
403483 " Please construct a template with appropriately chunked dask arrays."
404484 )
405485
486+ new_indexes = set (template .xindexes ) - set (merged_coordinates )
487+ modified_indexes = set (
488+ name
489+ for name , xindex in coordinates .xindexes .items ()
490+ if not xindex .equals (merged_coordinates .xindexes .get (name , None ))
491+ )
492+
406493 for dim in output_chunks :
407494 if dim in input_chunks and len (input_chunks [dim ]) != len (output_chunks [dim ]):
408495 raise ValueError (
@@ -443,63 +530,7 @@ def _wrapper(
443530 dim : np .cumsum ((0 ,) + chunks_v ) for dim , chunks_v in output_chunks .items ()
444531 }
445532
446- def subset_dataset_to_block (
447- graph : dict , gname : str , dataset : Dataset , input_chunk_bounds , chunk_index
448- ):
449- """
450- Creates a task that subsets an xarray dataset to a block determined by chunk_index.
451- Block extents are determined by input_chunk_bounds.
452- Also subtasks that subset the constituent variables of a dataset.
453- """
454-
455- # this will become [[name1, variable1],
456- # [name2, variable2],
457- # ...]
458- # which is passed to dict and then to Dataset
459- data_vars = []
460- coords = []
461-
462- chunk_tuple = tuple (chunk_index .values ())
463- for name , variable in dataset .variables .items ():
464- # make a task that creates tuple of (dims, chunk)
465- if dask .is_dask_collection (variable .data ):
466- # recursively index into dask_keys nested list to get chunk
467- chunk = variable .__dask_keys__ ()
468- for dim in variable .dims :
469- chunk = chunk [chunk_index [dim ]]
470-
471- chunk_variable_task = (f"{ name } -{ gname } -{ chunk [0 ]!r} " ,) + chunk_tuple
472- graph [chunk_variable_task ] = (
473- tuple ,
474- [variable .dims , chunk , variable .attrs ],
475- )
476- else :
477- # non-dask array possibly with dimensions chunked on other variables
478- # index into variable appropriately
479- subsetter = {
480- dim : _get_chunk_slicer (dim , chunk_index , input_chunk_bounds )
481- for dim in variable .dims
482- }
483- subset = variable .isel (subsetter )
484- chunk_variable_task = (
485- f"{ name } -{ gname } -{ dask .base .tokenize (subset )} " ,
486- ) + chunk_tuple
487- graph [chunk_variable_task ] = (
488- tuple ,
489- [subset .dims , subset , subset .attrs ],
490- )
491-
492- # this task creates dict mapping variable name to above tuple
493- if name in dataset ._coord_names :
494- coords .append ([name , chunk_variable_task ])
495- else :
496- data_vars .append ([name , chunk_variable_task ])
497-
498- return (Dataset , (dict , data_vars ), (dict , coords ), dataset .attrs )
499-
500- # variable names that depend on the computation. Currently, indexes
501- # cannot be modified in the mapped function, so we exclude thos
502- computed_variables = set (template .variables ) - set (coordinates .xindexes )
533+ computed_variables = set (template .variables ) - set (coordinates .indexes )
503534 # iterate over all possible chunk combinations
504535 for chunk_tuple in itertools .product (* ichunk .values ()):
505536 # mapping from dimension name to chunk index
@@ -523,11 +554,12 @@ def subset_dataset_to_block(
523554 },
524555 "data_vars" : set (template .data_vars .keys ()),
525556 "coords" : set (template .coords .keys ()),
557+ # only include new or modified indexes to minimize duplication of data, and graph size.
526558 "indexes" : {
527559 dim : coordinates .xindexes [dim ][
528560 _get_chunk_slicer (dim , chunk_index , output_chunk_bounds )
529561 ]
530- for dim in coordinates . xindexes
562+ for dim in ( new_indexes | modified_indexes )
531563 },
532564 }
533565
@@ -541,14 +573,11 @@ def subset_dataset_to_block(
541573 gname_l = f"{ name } -{ gname } "
542574 var_key_map [name ] = gname_l
543575
544- key : tuple [Any , ...] = (gname_l ,)
545- for dim in variable .dims :
546- if dim in chunk_index :
547- key += (chunk_index [dim ],)
548- else :
549- # unchunked dimensions in the input have one chunk in the result
550- # output can have new dimensions with exactly one chunk
551- key += (0 ,)
576+ # unchunked dimensions in the input have one chunk in the result
577+ # output can have new dimensions with exactly one chunk
578+ key : tuple [Any , ...] = (gname_l ,) + tuple (
579+ chunk_index [dim ] if dim in chunk_index else 0 for dim in variable .dims
580+ )
552581
553582 # We're adding multiple new layers to the graph:
554583 # The first new layer is the result of the computation on
0 commit comments