66import numpy as np
77import pandas as pd
88
9- from .core import find_group_cohorts
9+ from .core import _unique , find_group_cohorts
1010
1111
1212def draw_mesh (
@@ -131,14 +131,14 @@ def factorize_cohorts(by, cohorts):
131131 return factorized
132132
133133
134- def visualize_cohorts_2d (by , array , method = "cohorts" ):
134+ def visualize_cohorts_2d (by , array ):
135135 assert by .ndim == 2
136136 print ("finding cohorts..." )
137137 before_merged = find_group_cohorts (
138- by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = False , method = method
138+ by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = False
139139 ).values ()
140140 merged = find_group_cohorts (
141- by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = True , method = method
141+ by , [array .chunks [ax ] for ax in range (- by .ndim , 0 )], merge = True
142142 ).values ()
143143 print ("finished cohorts..." )
144144
@@ -149,16 +149,12 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
149149 ax = ax .ravel ()
150150 ax [1 ].set_visible (False )
151151 ax = ax [[0 , 2 , 3 ]]
152- flat = by .ravel ()
153- ngroups = len (np .unique (flat [~ np .isnan (flat )]))
154152
153+ ngroups = len (_unique (by ))
155154 h0 = ax [0 ].imshow (by , cmap = get_colormap (ngroups ))
156- h1 = ax [1 ].imshow (
157- factorize_cohorts (by , before_merged ),
158- vmin = 0 ,
159- cmap = get_colormap (len (before_merged )),
160- )
161- h2 = ax [2 ].imshow (factorize_cohorts (by , merged ), vmin = 0 , cmap = get_colormap (len (merged )))
155+ h1 = _visualize_cohorts (by , before_merged , ax = ax [1 ])
156+ h2 = _visualize_cohorts (by , merged , ax = ax [2 ])
157+
162158 for axx in ax :
163159 axx .grid (True , which = "both" )
164160 axx .set_xticks (xticks )
@@ -170,3 +166,10 @@ def visualize_cohorts_2d(by, array, method="cohorts"):
170166 ax [1 ].set_title (f"{ len (before_merged )} cohorts" )
171167 ax [2 ].set_title (f"{ len (merged )} merged cohorts" )
172168 f .set_size_inches ((6 , 6 ))
169+
170+
171+ def _visualize_cohorts (by , cohorts , ax = None ):
172+ if ax is None :
173+ _ , ax = plt .subplots (1 , 1 )
174+
175+ ax .imshow (factorize_cohorts (by , cohorts ), vmin = 0 , cmap = get_colormap (len (cohorts )))
0 commit comments