
Add scan. #531


Merged
merged 4 commits on Jun 12, 2025

Conversation

@dcherian (Contributor):

Closes #277

# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan.
# Instead we generalize by recursively applying the scan to `reduced`.
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5)
@dcherian (Contributor, Author):
Need input here on choosing a new intermediate chunksize to rechunk to, based on memory info.

@tomwhite (Member):

There are a couple of things to consider here: the number of chunks to combine at each stage, and the memory limits.

The first is like `split_every` in reduction, where the default is 4, although 6 or 8 may be better for larger workloads.

For the second, we should make sure the new chunksize is no larger than `(x.spec.allowed_mem - x.spec.reserved_mem) // 4`, where the factor of 4 comes about because of the {compressed, uncompressed} * {input, output} copies.

There is an error case where this memory constraint means the new chunksize is no larger than the existing one, so the computation can't proceed. The user can fix this either by reducing the chunksize or by increasing the memory. This is similar to this case:

cubed/cubed/core/ops.py

Lines 985 to 991 in 88c5dc4

# single axis: see how many result chunks fit in max_mem
# factor of 4 is memory for {compressed, uncompressed} x {input, output}
target_chunk_size = (max_mem - chunk_mem) // (chunk_mem * 4)
if target_chunk_size <= 1:
    raise ValueError(
        f"Not enough memory for reduction. Increase allowed_mem ({allowed_mem}) or decrease chunk size"
    )
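The two constraints above can be sketched as a small helper. The parameter names (`allowed_mem`, `reserved_mem`, `split_every`) mirror the discussion but are illustrative, not cubed's exact API:

```python
def new_scan_chunksize(shape_along_axis, chunksize_along_axis, itemsize,
                       allowed_mem, reserved_mem, split_every=4):
    """Pick an intermediate chunksize (in elements) for the scan rechunk step.

    Bounded both by a split_every-style fan-in and by the memory budget.
    """
    # Factor of 4 accounts for {compressed, uncompressed} x {input, output} copies.
    mem_limit_elems = ((allowed_mem - reserved_mem) // 4) // itemsize
    # split_every-style fan-in: combine at most `split_every` chunks per stage.
    candidate = min(shape_along_axis, chunksize_along_axis * split_every)
    new_chunksize = min(candidate, mem_limit_elems)
    # Error case: memory constraint prevents any growth, so recursion can't proceed.
    if new_chunksize <= chunksize_along_axis and chunksize_along_axis < shape_along_axis:
        raise ValueError(
            "Not enough memory for scan. Increase allowed_mem or decrease chunk size"
        )
    return new_chunksize
```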

shape=scanned.shape,
dtype=scanned.dtype,
chunks=scanned.chunks,
extra_projected_mem=scanned.chunkmem * 2, # arbitrary
@dcherian (Contributor, Author):

Need input here too.

@tomwhite (Member):

This should be the memory allocated to read from the side inputs (`scanned` and `increment` here). We double the chunk memory to account for reading the compressed Zarr chunk, so the result would be

extra_projected_mem=scanned.chunkmem * 2 + increment.chunkmem * 2

(There's an open issue #288 to make this a bit more transparent.)

@tomwhite (Member) left a comment:

Are you going to add a user-facing cumulative_sum function from the Array API? This would be a good function for the unit tests to test.
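As a sketch of what such a unit test could assert, using plain NumPy as the reference implementation (the Array API's `cumulative_sum` is an inclusive scan along an axis; the cubed-facing name here is assumed, not confirmed by this thread):

```python
import numpy as np

def reference_cumulative_sum(x, axis=0):
    # Inclusive scan along `axis`, matching np.cumsum and the Array API
    # default (include_initial=False).
    return np.cumsum(x, axis=axis)

a = np.arange(1, 7).reshape(2, 3)   # [[1, 2, 3], [4, 5, 6]]
result = reference_cumulative_sum(a, axis=1)
# result == [[1, 3, 6], [4, 9, 15]]
```

A cubed unit test would compute the same input through the new scan-backed function and compare against this NumPy result.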



# Blelloch (1990) out-of-core algorithm.
# 1. First, scan blockwise
scanned = blockwise(func, "ij", array, "ij", axis=axis)
@tomwhite (Member):

Using `map_blocks` would be simpler and would avoid the 2D assumption.
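The Blelloch-style structure being discussed (scan blockwise, scan the per-block totals, then add each block's increment) can be sketched in plain NumPy. `blocked_cumsum` is an illustrative name, not cubed's implementation; note it works along any axis, which is why avoiding the 2D `"ij"` assumption matters:

```python
import numpy as np

def blocked_cumsum(x, block_size, axis=0):
    """Out-of-core-style cumulative sum: scan per block, then add increments."""
    blocks = np.split(x, range(block_size, x.shape[axis], block_size), axis=axis)
    # 1. First, scan blockwise (each block independently).
    scanned = [np.cumsum(b, axis=axis) for b in blocks]
    # 2. Take the last element of each scanned block (the block total).
    totals = [np.take(s, [-1], axis=axis) for s in scanned]
    # 3. Scanning the totals gives the increment to add to each later block.
    increments = np.cumsum(np.concatenate(totals, axis=axis), axis=axis)
    # 4. Add block i's increment (the running total of blocks 0..i-1).
    out = [scanned[0]]
    for i, s in enumerate(scanned[1:]):
        out.append(s + np.take(increments, [i], axis=axis))
    return np.concatenate(out, axis=axis)
```

For example, `blocked_cumsum(np.arange(10), block_size=4)` matches `np.cumsum(np.arange(10))`.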

@@ -1442,3 +1443,120 @@ def smallest_blockdim(blockdims):
m = ntd[0]
out = ntd
return out


def wrapper_binop(
@tomwhite (Member):

Maybe call it something like `_scan_binop` to link it to the scan implementation? I've been using a naming convention like that elsewhere in the file.

@tomwhite tomwhite mentioned this pull request Aug 1, 2024
@tomwhite (Member):

I'm going to merge this, and then do some follow-up PRs to add cumulative sum and product functions.

@dcherian (Contributor, Author):

Sorry I dropped it. Does it still work?

@tomwhite (Member):

> Sorry I dropped it. Does it still work?

No problem! I just wrote a cumulative sum test that passes, so yes! It uses `map_direct`, which has since been deprecated in favour of `map_selection`, so I'd like to change that too (which may be a bit fiddly).

As I said, I'm very happy to push this forward building on the work you did, unless you'd like to take a look? My goal is to close it so we can mark #438 as complete.

@tomwhite (Member):

> It uses `map_direct`, which has since been deprecated in favour of `map_selection`, so I'd like to change that too (which may be a bit fiddly).

Actually it's not possible to use `map_selection`, since it only works on single arrays, and we have two (`scanned` and `increment`). Instead we can use `general_blockwise` with a key function that returns the relevant chunk in `increments` that corresponds to the chunk in the main array (slightly tricky since the chunking differs). I'll try to do that as a follow-on.
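The index-translation part of that key function can be sketched in pure Python. Both function names below are illustrative (not cubed's API); the idea is to map an output chunk's start offset to the increment chunk that covers it, since the two arrays are chunked differently:

```python
def chunk_offsets(chunks):
    """Start offset of each chunk along one axis, e.g. (4, 4, 2) -> [0, 4, 8]."""
    offsets, start = [], 0
    for c in chunks:
        offsets.append(start)
        start += c
    return offsets

def increment_chunk_index(out_index, out_chunks, inc_chunks):
    """Map an output chunk index to the increment chunk covering its start offset."""
    start = chunk_offsets(out_chunks)[out_index]
    inc_offsets = chunk_offsets(inc_chunks)
    # The last increment chunk whose start offset is <= this chunk's start.
    return max(i for i, off in enumerate(inc_offsets) if off <= start)
```

A `general_blockwise` key function would use a translation like this to return, for each output chunk, the matching chunk keys from both `scanned` and `increment`.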

@dcherian (Contributor, Author):

I won't be able to take a look till the weekend, so feel free to take over!

@tomwhite tomwhite marked this pull request as ready for review June 11, 2025 16:52
@tomwhite (Member):

I think this is basically working now - the test failures are unrelated. I'll leave it open for a bit in case you want to have a look @dcherian.

@dcherian (Contributor, Author):

Ah, I was so close! LGTM. Thanks for finishing it up.

@tomwhite tomwhite merged commit 7de9081 into cubed-dev:main Jun 12, 2025
16 of 19 checks passed
@tomwhite (Member):

Thanks for your work on this @dcherian!

Successfully merging this pull request may close these issues:

- Add scan / prefix sum primitive

2 participants