Dask reshape bug for arrays with fully chunked leading axes #456

ravwojdyla opened this issue Feb 4, 2021 · 1 comment
Labels
bug Something isn't working

Comments

@ravwojdyla
Collaborator

Long story short, the problematic line in our code is: https://github.com/pystatgen/sgkit/blob/41827f3fd116d59ab4dc8b119a15ad5f3be730b9/sgkit/stats/regenie.py#L364
dask/dask#6748 is a special case optimisation:

When the slow-moving (early) axes in .reshape are all size 1

Our YP[i] happens to fall into that category. And afaiu dask/dask#6748 introduced a bug. It might be hard to see that in the regenie code, so here's a distilled reproduction:

> # c4038add1a087ba3a82207a557cdcad9753b689d is the https://github.com/dask/dask/pull/6748
> dask git:(c4038add) g ck c4038add1a087ba3a82207a557cdcad9753b689d
HEAD is now at c4038add Avoid rechunking in reshape with chunksize=1 (#6748)
> dask git:(c4038add) python
>>> import numpy as np
>>> import dask.array as da
>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (4,)))
>>> a.reshape(6,4)
dask.array<reshape, shape=(6, 4), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # merging dimensions at the front works fine, now let's try the last two (which is our use case)
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # NOTICE: the shape is (2, 24) NOT (2,12)!, now let's try to compute this:
>>> a.reshape(2,12).compute()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/rav/projects/dask/dask/base.py", line 167, in compute
    (result,) = compute(self, traverse=False, **kwargs)
  File "/Users/rav/projects/dask/dask/base.py", line 454, in compute
    results = schedule(dsk, keys, **kwargs)
  File "/Users/rav/projects/dask/dask/threaded.py", line 76, in get
    results = get_async(
  File "/Users/rav/projects/dask/dask/local.py", line 503, in get_async
    return nested_get(result, state["cache"])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in nested_get
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in <listcomp>
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in nested_get
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in <listcomp>
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in nested_get
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 299, in <listcomp>
    return tuple([nested_get(i, coll) for i in ind])
  File "/Users/rav/projects/dask/dask/local.py", line 301, in nested_get
    return coll[ind]
KeyError: ('reshape-aa29de8b0f6be5be25495836ed047c4a', 1, 0)

Notice the invalid shape after a.reshape(2,12): dask reports (2, 24) instead of (2, 12), and compute() then fails with a KeyError.
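As far as I understand, a dask array's shape is not stored separately but is the per-axis sum of its chunk sizes, so a wrong shape points at wrong output chunks. Below is a minimal pure-Python sketch of that relationship. The helper names and both chunk layouts are assumptions chosen to be consistent with the reprs above (chunksize (1, 4), shapes (2, 12) and (2, 24)); they are not read out of dask.

```python
def shape_from_chunks(chunks):
    """Derive an array shape as the per-axis sum of chunk sizes."""
    return tuple(sum(axis_chunks) for axis_chunks in chunks)


def check_chunks(chunks, expected_shape):
    """Raise ValueError if the chunks do not add up to the expected shape."""
    actual = shape_from_chunks(chunks)
    if actual != tuple(expected_shape):
        raise ValueError(
            f"chunks {chunks} give shape {actual}, expected {tuple(expected_shape)}"
        )


# Assumed layout for a correct (2, 12) result: three chunks of 4 per row.
good_chunks = ((1, 1), (4, 4, 4))
# Assumed layout consistent with the (2, 24) repr: six chunks of 4 per row.
bad_chunks = ((1, 1), (4, 4, 4, 4, 4, 4))

print(shape_from_chunks(good_chunks))  # (2, 12)
print(shape_from_chunks(bad_chunks))   # (2, 24)
check_chunks(good_chunks, (2, 12))     # passes silently
```

Calling check_chunks(bad_chunks, (2, 12)) raises a ValueError, which is the kind of early failure that is easier to read than the KeyError from the scheduler.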

The same code works fine without dask/dask#6748:

> dask git:(c4038add) g ck head~1
Previous HEAD position was c4038add Avoid rechunking in reshape with chunksize=1 (#6748)
HEAD is now at 94bdd4e3 Try to make categoricals work on join (#6205)
> dask git:(94bdd4e3) python
>>> import numpy as np
>>> import dask.array as da
>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (4,)))
>>> a.reshape(6,4)
dask.array<reshape, shape=(6, 4), dtype=int64, chunksize=(3, 4), chunktype=numpy.ndarray>
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 12), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> a.reshape(2,12).compute()
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])

This issue seems to depend on the chunking of the array (a). In the case above (and in our regenie case) the leading axes are completely chunked, i.e. every chunk along them has size 1, see:

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (4,)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # BAD

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (2,2)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 2), chunktype=numpy.ndarray>
>>> # BAD

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 1, 1), (1, 1, 1, 1)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 24), dtype=int64, chunksize=(1, 1), chunktype=numpy.ndarray>
>>> # BAD

>>> a = da.from_array(np.arange(24).reshape(2, 3, 4), chunks=((1, 1), (1, 2), (2, 2)))
>>> a.reshape(2,12)
dask.array<reshape, shape=(2, 12), dtype=int64, chunksize=(1, 4), chunktype=numpy.ndarray>
>>> # OK
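The pattern in the four cases above can be restated as a small predicate: the reshape goes wrong exactly when every chunk along the two leading axes has size 1. The sketch below only encodes that observation; the function name is hypothetical and this is not claimed to be the condition dask checks internally.

```python
def leading_axes_fully_chunked(chunks, n_leading):
    """Return True if every chunk along the first n_leading axes has size 1."""
    return all(
        size == 1
        for axis_chunks in chunks[:n_leading]
        for size in axis_chunks
    )


# The four chunk layouts tried above, paired with the observed outcome.
cases = [
    (((1, 1), (1, 1, 1), (4,)), "BAD"),
    (((1, 1), (1, 1, 1), (2, 2)), "BAD"),
    (((1, 1), (1, 1, 1), (1, 1, 1, 1)), "BAD"),
    (((1, 1), (1, 2), (2, 2)), "OK"),
]

for chunks, observed in cases:
    predicted = "BAD" if leading_axes_fully_chunked(chunks, n_leading=2) else "OK"
    print(chunks, "observed:", observed, "predicted:", predicted)
```

For these four layouts the prediction matches the observed outcome; other layouts are untested here.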

Btw, this is a good example of how valuable Eric's asserts are in this case. It's already pretty hard to debug this code; imagine if it just failed at compute() with a cryptic KeyError: ('reshape-aa29de8b0f6be5be25495836ed047c4a', 1, 0). So big +1 to https://github.com/pystatgen/sgkit/issues/267
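To illustrate the kind of assert meant here, a minimal sketch of a fail-fast wrapper around reshape follows. The name checked_reshape is hypothetical (it is not an sgkit or dask function), and it assumes a fully specified target shape with no -1 placeholder and an object exposing .reshape and .shape, as dask and NumPy arrays do.

```python
def checked_reshape(arr, shape):
    """Reshape arr and fail immediately if the result reports another shape.

    Works on any object with .reshape(shape) and .shape; the target shape
    must be fully specified (no -1 placeholder).
    """
    shape = tuple(shape)
    out = arr.reshape(shape)
    actual = tuple(out.shape)
    if actual != shape:
        raise AssertionError(
            f"reshape to {shape} produced an array of shape {actual}"
        )
    return out
```

With a lazy array the check runs at graph-construction time, so a bad shape such as (2, 24) would surface at the reshape call rather than later inside compute().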

Originally posted by @ravwojdyla in https://github.com/pystatgen/sgkit/issues/430#issuecomment-772868301

@tomwhite
Collaborator

tomwhite commented Apr 6, 2021

This has now been fixed in Dask, so it should be possible to check whether it resolves the issue here.
