Skip to content

Commit a3c6053

Browse files
shoyerXarray-Beam authors
authored andcommitted
Skip rechunking if source and target chunks are the same
This avoids an unnecessary shuffle. PiperOrigin-RevId: 697681675
1 parent 762228b commit a3c6053

File tree

4 files changed

+31
-16
lines changed

4 files changed

+31
-16
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
setuptools.setup(
4343
name='xarray-beam',
44-
version='0.6.3',
44+
version='0.6.4', # keep in sync with __init__.py
4545
license='Apache 2.0',
4646
author='Google LLC',
4747
author_email='[email protected]',

xarray_beam/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@
5151
DatasetToZarr,
5252
)
5353

54-
__version__ = '0.6.3'
54+
__version__ = '0.6.4' # keep in sync with setup.py

xarray_beam/_src/rechunk.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,12 @@ def __init__(
547547
self.dim_sizes = dim_sizes
548548
self.source_chunks = normalize_chunks(source_chunks, dim_sizes)
549549
self.target_chunks = normalize_chunks(target_chunks, dim_sizes)
550+
551+
if self.source_chunks == self.target_chunks:
552+
self.stage_in = self.stage_out = []
553+
logging.info(f'Rechunk with chunks {self.source_chunks} is a no-op')
554+
return
555+
550556
plan = rechunking_plan(
551557
dim_sizes,
552558
self.source_chunks,

xarray_beam/_src/rechunk_test.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_normalize_chunks_errors(self):
6363

6464
def test_rechunking_plan(self):
6565
# this trivial case fits entirely into memory
66-
plan, = rechunk.rechunking_plan(
66+
(plan,) = rechunk.rechunking_plan(
6767
dim_sizes={'x': 10, 'y': 20},
6868
source_chunks={'x': 1, 'y': 20},
6969
target_chunks={'x': 10, 'y': 1},
@@ -75,7 +75,7 @@ def test_rechunking_plan(self):
7575
self.assertEqual(plan, expected)
7676

7777
# this harder case doesn't
78-
(read_chunks, _, write_chunks), = rechunk.rechunking_plan(
78+
((read_chunks, _, write_chunks),) = rechunk.rechunking_plan(
7979
dim_sizes={'t': 1000, 'x': 200, 'y': 300},
8080
source_chunks={'t': 1, 'x': 200, 'y': 300},
8181
target_chunks={'t': 1000, 'x': 20, 'y': 20},
@@ -361,15 +361,11 @@ def test_consolidate_with_unchunked_vars(self):
361361
]
362362
with self.assertRaisesRegex(
363363
ValueError,
364-
re.escape(
365-
textwrap.dedent(
366-
"""
364+
re.escape(textwrap.dedent("""
367365
combining nested dataset chunks for vars=None with offsets={'x': [0, 10]} failed.
368366
Leading datasets along dimension 'x':
369367
<xarray.Dataset>
370-
"""
371-
).strip()
372-
),
368+
""").strip()),
373369
):
374370
inconsistent_inputs | xbeam.ConsolidateChunks({'x': -1})
375371

@@ -449,14 +445,10 @@ def test_consolidate_variables_merge_fails(self):
449445
]
450446
with self.assertRaisesRegex(
451447
ValueError,
452-
re.escape(
453-
textwrap.dedent(
454-
"""
448+
re.escape(textwrap.dedent("""
455449
merging dataset chunks with variables [{'foo'}, {'bar'}] failed.
456450
<xarray.Dataset>
457-
"""
458-
).strip()
459-
),
451+
""").strip()),
460452
):
461453
inputs | xbeam.ConsolidateVariables()
462454

@@ -816,6 +808,23 @@ def test_rechunk_inconsistent_dimensions(self):
816808
)
817809
self.assertIdenticalChunks(actual, expected)
818810

811+
def test_rechunk_same_source_and_target_chunks(self):
812+
rs = np.random.RandomState(0)
813+
ds = xarray.Dataset({'foo': (('x', 'y'), rs.rand(2, 3))})
814+
p = test_util.EagerPipeline()
815+
inputs = p | xbeam.DatasetToChunks(ds, {'x': 1}, split_vars=True)
816+
rechunk_transform = xbeam.Rechunk(
817+
dim_sizes=ds.sizes,
818+
source_chunks={'x': 1},
819+
target_chunks={'x': 1},
820+
itemsize=8,
821+
)
822+
# no rechunk stages
823+
self.assertEqual(rechunk_transform.stage_in, [])
824+
self.assertEqual(rechunk_transform.stage_out, [])
825+
outputs = inputs | rechunk_transform
826+
self.assertIdenticalChunks(outputs, inputs)
827+
819828

820829
if __name__ == '__main__':
821830
absltest.main()

0 commit comments

Comments
 (0)