Slow performance with groupby using a custom DataArray grouper #8377

Closed
alessioarena opened this issue Oct 26, 2023 · 6 comments · Fixed by #8758
@alessioarena

What is your issue?

I have code that computes a per-pixel nearest-neighbor match between two datasets, then performs a groupby + aggregation.
The calculation is generally lazy, using dask.

I recently noticed slow performance of groupby used this way, with the lazy computation taking in excess of 10 minutes for an index of approximately 4000 by 4000.

I did a bit of digging around and noticed that the slow line is this:

Timer unit: 1e-09 s

Total time: 0.263679 s
File: /env/lib/python3.10/site-packages/xarray/core/duck_array_ops.py
Function: array_equiv at line 260

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   260                                           def array_equiv(arr1, arr2):
   261                                               """Like np.array_equal, but also allows values to be NaN in both arrays"""
   262     22140   96490101.0   4358.2     36.6      arr1 = asarray(arr1)
   263     22140   34155953.0   1542.7     13.0      arr2 = asarray(arr2)
   264     22140  119855572.0   5413.5     45.5      lazy_equiv = lazy_array_equiv(arr1, arr2)
   265     22140    7390478.0    333.8      2.8      if lazy_equiv is None:
   266                                                   with warnings.catch_warnings():
   267                                                       warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
   268                                                       flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
   269                                                       return bool(flag_array.all())
   270                                               else:
   271     22140    5787053.0    261.4      2.2          return lazy_equiv

Total time: 242.247 s
File: /env/lib/python3.10/site-packages/xarray/core/indexing.py
Function: __getitem__ at line 1419

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1419                                               def __getitem__(self, key):
  1420     22140   26764337.0   1208.9      0.0          if not isinstance(key, VectorizedIndexer):
  1421                                                       # if possible, short-circuit when keys are effectively slice(None)
  1422                                                       # This preserves dask name and passes lazy array equivalence checks
  1423                                                       # (see duck_array_ops.lazy_array_equiv)
  1424     22140   10513930.0    474.9      0.0              rewritten_indexer = False
  1425     22140    4602305.0    207.9      0.0              new_indexer = []
  1426     66420   61804870.0    930.5      0.0              for idim, k in enumerate(key.tuple):
  1427     88560   78516641.0    886.6      0.0                  if isinstance(k, Iterable) and (
  1428     22140  151748667.0   6854.1      0.1                      not is_duck_dask_array(k)
  1429     22140        2e+11    1e+07     93.6                      and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
  1430                                                           ):
  1431                                                               new_indexer.append(slice(None))
  1432                                                               rewritten_indexer = True
  1433                                                           else:
  1434     44280   40322984.0    910.6      0.0                      new_indexer.append(k)
  1435     22140    4847251.0    218.9      0.0              if rewritten_indexer:
  1436                                                           key = type(key)(tuple(new_indexer))
  1437                                           
  1438     22140   24251221.0   1095.4      0.0          if isinstance(key, BasicIndexer):
  1439                                                       return self.array[key.tuple]
  1440     22140    9613954.0    434.2      0.0          elif isinstance(key, VectorizedIndexer):
  1441                                                       return self.array.vindex[key.tuple]
  1442                                                   else:
  1443     22140    8618414.0    389.3      0.0              assert isinstance(key, OuterIndexer)
  1444     22140   26601491.0   1201.5      0.0              key = key.tuple
  1445     22140    6010672.0    271.5      0.0              try:
  1446     22140        2e+10 678487.7      6.2                  return self.array[key]
  1447                                                       except NotImplementedError:
  1448                                                           # manual orthogonal indexing.
  1449                                                           # TODO: port this upstream into dask in a saner way.
  1450                                                           value = self.array
  1451                                                           for axis, subkey in reversed(list(enumerate(key))):
  1452                                                               value = value[(slice(None),) * axis + (subkey,)]
  1453                                                           return value

The test duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim])) is repeated many times. Although each call is reasonably fast, the total adds up to a lot of time that could be reduced by first checking that the lengths are equal, like

                if isinstance(k, Iterable) and (
                    not is_duck_dask_array(k)
                    and len(k) == self.array.shape[idim]
                    and duck_array_ops.array_equiv(k, np.arange(self.array.shape[idim]))
                ):

This would work better because, although array_equiv ultimately performs an equivalent comparison, the array to test against is always created with np.arange first, and that allocation is the real bottleneck:

         74992059 function calls (73375414 primitive calls) in 298.934 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    22140  225.296    0.010  225.296    0.010 {built-in method numpy.arange}
   177123    3.192    0.000    3.670    0.000 inspect.py:2920(__init__)
110702/110701    2.180    0.000    2.180    0.000 {built-in method numpy.asarray}
11690863/11668723    2.036    0.000    5.043    0.000 {built-in method builtins.isinstance}
   287827    1.876    0.000    3.768    0.000 utils.py:25(meta_from_array)
   132843    1.872    0.000    7.649    0.000 inspect.py:2280(_signature_from_function)
   974166    1.485    0.000    2.558    0.000 inspect.py:2637(__init__)
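The proposed length guard can be sketched in isolation. This is a hypothetical standalone helper (`indexer_matches_full_range` is not xarray's actual function name), mirroring the logic above:

```python
import numpy as np

def indexer_matches_full_range(k, dim_size):
    """Return True if indexer k is equivalent to slice(None) for an axis
    of length dim_size. The cheap length check runs first, so the costly
    np.arange allocation and elementwise comparison only happen when the
    lengths already match."""
    if not hasattr(k, "__len__") or len(k) != dim_size:
        return False  # short-circuit: cannot match without equal length
    return bool(np.array_equal(np.asarray(k), np.arange(dim_size)))

# A mismatched length never allocates the arange:
print(indexer_matches_full_range([0, 1], 4_000))  # False, cheap
print(indexer_matches_full_range(range(4), 4))    # True
```

For a 4000-element axis probed 22140 times (as in the profile above), skipping the `np.arange` allocation on every length mismatch is exactly the saving the guard is after.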
@alessioarena alessioarena added the needs triage Issue that has not been reviewed by xarray team member label Oct 26, 2023
@TomNicholas TomNicholas added topic-groupby topic-performance and removed needs triage Issue that has not been reviewed by xarray team member labels Oct 26, 2023
@mathause
Collaborator

Good find. I think we are open to this change. However, we test for Iterable, which does not necessarily provide __len__. So the isinstance check will have to be changed to use isinstance(k, Sequence) (or Collection?), see https://docs.python.org/3/library/collections.abc.html#collections-abstract-base-classes (I am not super confident about this)

Are you interested in opening a PR?
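The distinction mathause raises can be checked directly with the abstract base classes; a quick sketch (plain Python, just illustrating the isinstance behavior):

```python
from collections.abc import Collection, Iterable, Sequence, Sized
import numpy as np

gen = (i for i in range(3))          # a generator is Iterable but has no __len__
print(isinstance(gen, Iterable))     # True
print(isinstance(gen, Sized))        # False, so len(gen) would raise TypeError

arr = np.arange(3)                   # ndarray defines __len__, __iter__, __contains__
print(isinstance(arr, Sequence))     # False: ndarray is not registered as a Sequence
print(isinstance(arr, Collection))   # True: structurally Sized + Iterable + Container
```

Note that testing for Sequence would reject numpy arrays, which suggests Collection (or a plain `hasattr(k, "__len__")` check) is the safer choice here.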

@dcherian
Contributor

Hmmm... this is an "optimization" I tried to add, and given all this additional complexity perhaps we can just delete it.

For reference, it avoids changing the dask token for array[[0, 1, 2, 3]] if that indexer is identical to array[slice(None)]. Arguably, something like this should go in dask.
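The rewrite dcherian describes can be sketched on its own. This is a hypothetical standalone function (`rewrite_indexer` is an illustrative name, not xarray's), showing how an arange-equivalent integer indexer is turned back into slice(None) so the underlying array can be returned untouched:

```python
import numpy as np

def rewrite_indexer(key, shape):
    """Replace any integer-array index equivalent to np.arange(dim_size)
    with slice(None). Downstream code can then short-circuit, keeping the
    original array (and hence its lazy/dask identity) instead of building
    a new fancy-indexing result."""
    new_key = []
    for dim_size, k in zip(shape, key):
        if isinstance(k, (list, np.ndarray)) and np.array_equal(k, np.arange(dim_size)):
            new_key.append(slice(None))  # effectively a no-op index
        else:
            new_key.append(k)
    return tuple(new_key)

# First axis indexer equals arange(3) and becomes slice(None);
# the second stays as-is:
print(rewrite_indexer(([0, 1, 2], [2, 0]), (3, 3)))
```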

@max-sixty
Collaborator

max-sixty commented Oct 26, 2023

Great find @alessioarena , impressive work!


Can I ask — are array[[0, 1, 2, 3]]-style indexers frequently being passed to this method? For example, does some internal or dask function pass them? Or is it just a theoretical case that a user might hit?

If not, I agree with removing it, I wouldn't think it's that common?

(If it is created by some internal function — possibly we could pass a range object instead, which can then be checked without materializing a whole array. Though also possibly we can pass slice(None) itself...)
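The range-object idea is attractive because two range objects compare in O(1), with no array ever materialized. A sketch of that check (`is_full_range` is a hypothetical helper, not proposed API):

```python
def is_full_range(k, dim_size):
    """Return True if k is a range covering exactly 0..dim_size-1.
    Comparing range objects is O(1): Python compares their normalized
    start/stop/step, never iterating the values."""
    return isinstance(k, range) and k == range(dim_size)

# True without allocating 16 million integers:
print(is_full_range(range(4000 * 4000), 4000 * 4000))
print(is_full_range([0, 1, 2], 3))   # False: not a range object
```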

@dcherian
Contributor

Yes, when you sort an already-sorted array. But we don't want to do this for numpy, because np.sort implies a copy.

I think this was a bad addition, let's remove it. There will be a failing test that can be deleted.
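To illustrate the sorted-array case dcherian mentions: sorting an already-sorted coordinate yields an argsort indexer that is exactly an arange, which is what the short-circuit was designed to catch (a minimal numpy sketch, not xarray's sort path):

```python
import numpy as np

# Sorting an already-sorted coordinate produces an arange-like indexer:
sorted_coord = np.array([10, 20, 30, 40])
indexer = np.argsort(sorted_coord, kind="stable")
print(np.array_equal(indexer, np.arange(4)))  # True

# Rewriting such an indexer to slice(None) lets a lazy backend keep the
# original array (and its token) instead of building a fancy-indexing graph.
```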

@alessioarena
Author

Thanks all for jumping on this so quickly.

I'm happy to open a PR if that is the preference, or to leave it to @dcherian to revert the addition.

Thanks heaps for all the amazing work you are doing! I'm quite a heavy and happy user of xarray/dask.

@max-sixty
Collaborator

Hi @alessioarena — a PR would be great if you'd be up for it — thank you!
