Skip to content

Use Dask's 'broadcast trick' to save memory for single-valued arrays #359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
from cubed.utils import memory_repr
from cubed.utils import broadcast_trick, memory_repr


class VirtualEmptyArray:
Expand All @@ -33,7 +33,8 @@ def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
return nxp.empty(indexer.shape, dtype=self.dtype)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)

@property
def oindex(self):
Expand Down Expand Up @@ -68,7 +69,10 @@ def __getitem__(self, key):
if not isinstance(key, tuple):
key = (key,)
indexer = BasicIndexer(key, self.template)
return nxp.full(indexer.shape, fill_value=self.fill_value, dtype=self.dtype)
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.full)(
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
)

@property
def oindex(self):
Expand Down
16 changes: 16 additions & 0 deletions cubed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from cubed.backend_array_api import namespace as nxp
from cubed.utils import (
block_id_to_offset,
broadcast_trick,
chunk_memory,
extract_stack_summaries,
join_path,
Expand Down Expand Up @@ -153,3 +156,16 @@ def test_map_nested_iterators():
assert count == 2
assert list(out1) == [4, 5]
assert count == 4


def test_broadcast_trick():
    """Check that ``broadcast_trick`` matches the plain constructor's values
    while its base array holds only a single byte for an ``int8`` result.
    """
    shape = (10, 10)
    dense = nxp.ones(shape, dtype=nxp.int8)
    compact = broadcast_trick(nxp.ones)(shape, dtype=nxp.int8)

    # Same contents, but the broadcast version is a view over one element.
    assert_array_equal(dense, compact)
    assert dense.nbytes == 100
    assert compact.base.nbytes == 1

    # Zero-dimensional shape: the wrapped and plain constructors must agree.
    scalar_dense = nxp.ones((), dtype=nxp.int8)
    scalar_compact = broadcast_trick(nxp.ones)((), dtype=nxp.int8)
    assert_array_equal(scalar_dense, scalar_compact)
21 changes: 21 additions & 0 deletions cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import traceback
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from itertools import islice
from math import prod
from operator import add
Expand All @@ -17,6 +18,7 @@
import numpy as np
import tlz as toolz

from cubed.backend_array_api import namespace as nxp
from cubed.types import T_DType, T_RectangularChunks, T_RegularChunks
from cubed.vendor.dask.array.core import _check_regular_chunks

Expand Down Expand Up @@ -289,3 +291,22 @@ def map_nested(func, seq):
return map(lambda item: map_nested(func, item), seq)
else:
return func(seq)


def _broadcast_trick_inner(func, shape, *args, **kwargs):
    """Create a minimal array with ``func`` and broadcast it to ``shape``.

    ``func`` is called with ``shape=()`` when the requested shape is the empty
    tuple and with ``shape=1`` otherwise; all other positional and keyword
    arguments are forwarded unchanged. The small result is then expanded to
    the requested shape with ``nxp.broadcast_to``.
    """
    # cupy-specific hack. numpy is happy with hardcoded shape=().
    null_shape = () if shape == () else 1

    return nxp.broadcast_to(func(*args, shape=null_shape, **kwargs), shape)


def broadcast_trick(func):
    """Apply Dask's broadcast trick to array API functions that produce arrays
    containing a single value to save space in memory.

    Note that this should only be used for arrays that are never mutated.

    Parameters
    ----------
    func : callable
        Array creation function accepting a ``shape`` keyword argument
        (for example ``nxp.ones`` or ``nxp.full``).

    Returns
    -------
    functools.partial
        Callable taking ``shape`` as its first argument, followed by any
        further arguments for ``func``. It carries ``func``'s docstring and,
        when ``func`` has one, its ``__name__``.
    """
    inner = partial(_broadcast_trick_inner, func)
    inner.__doc__ = getattr(func, "__doc__", None)
    # Not every callable defines ``__name__`` (``functools.partial`` objects
    # and callable instances do not), so copy it only when it exists instead
    # of raising AttributeError.
    name = getattr(func, "__name__", None)
    if name is not None:
        inner.__name__ = name
    return inner