diff --git a/flox/visualize.py b/flox/visualize.py index fd712fd4b..0a8b84ea3 100644 --- a/flox/visualize.py +++ b/flox/visualize.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from .core import find_group_cohorts +from .core import _unique, find_group_cohorts def draw_mesh( @@ -131,14 +131,14 @@ def factorize_cohorts(by, cohorts): return factorized -def visualize_cohorts_2d(by, array, method="cohorts"): +def visualize_cohorts_2d(by, array): assert by.ndim == 2 print("finding cohorts...") before_merged = find_group_cohorts( - by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False, method=method + by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=False ).values() merged = find_group_cohorts( - by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True, method=method + by, [array.chunks[ax] for ax in range(-by.ndim, 0)], merge=True ).values() print("finished cohorts...") @@ -149,16 +149,12 @@ def visualize_cohorts_2d(by, array, method="cohorts"): ax = ax.ravel() ax[1].set_visible(False) ax = ax[[0, 2, 3]] - flat = by.ravel() - ngroups = len(np.unique(flat[~np.isnan(flat)])) + ngroups = len(_unique(by)) h0 = ax[0].imshow(by, cmap=get_colormap(ngroups)) - h1 = ax[1].imshow( - factorize_cohorts(by, before_merged), - vmin=0, - cmap=get_colormap(len(before_merged)), - ) - h2 = ax[2].imshow(factorize_cohorts(by, merged), vmin=0, cmap=get_colormap(len(merged))) + h1 = _visualize_cohorts(by, before_merged, ax=ax[1]) + h2 = _visualize_cohorts(by, merged, ax=ax[2]) + for axx in ax: axx.grid(True, which="both") axx.set_xticks(xticks) @@ -170,3 +166,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"): ax[1].set_title(f"{len(before_merged)} cohorts") ax[2].set_title(f"{len(merged)} merged cohorts") f.set_size_inches((6, 6)) + + +def _visualize_cohorts(by, cohorts, ax=None): + if ax is None: + _, ax = plt.subplots(1, 1) + + ax.imshow(factorize_cohorts(by, cohorts), vmin=0, cmap=get_colormap(len(cohorts)))