@@ -1212,8 +1212,13 @@ def plot(self):
12121212
12131213 row_min , row_max , col_min , col_max = self ._selection_bounds ()
12141214 dim_names = self .model ().xlabels [0 ]
1215+ # label for each selected column
12151216 xlabels = self .model ().xlabels [1 ][col_min :col_max ]
1216- ylabels = self .model ().ylabels [1 :][row_min :row_max ]
1217+ # list of selected labels for each index column
1218+ labels_per_index_column = [col_labels [row_min :row_max ] for col_labels in self .model ().ylabels [1 :]]
1219+ # list of (str) label for each selected row
1220+ ylabels = [[str (label ) for label in row_labels ]
1221+ for row_labels in zip (* labels_per_index_column )]
12171222
12181223 assert data .ndim == 2
12191224
@@ -1225,19 +1230,17 @@ def plot(self):
12251230 if data .shape [1 ] == 1 :
12261231 # plot one column
12271232 xlabel = ',' .join (dim_names [:- 1 ])
1228- xticklabels = ['\n ' .join ([str (ylabels [c ][r ]) for c in range (len (ylabels ))])
1229- for r in range (row_max - row_min )]
1230- xdata = np .arange (row_max - row_min , dtype = int )
1233+ xticklabels = ['\n ' .join (ylabels [row ]) for row in range (row_max - row_min )]
1234+ xdata = np .arange (row_max - row_min )
12311235 ax .plot (xdata , data [:, 0 ])
12321236 ax .set_ylabel (xlabels [0 ])
12331237 else :
12341238 # plot each row as a line
12351239 xlabel = dim_names [- 1 ]
12361240 xticklabels = [str (label ) for label in xlabels ]
1237- xdata = np .arange (col_max - col_min , dtype = int )
1241+ xdata = np .arange (col_max - col_min )
12381242 for row in range (len (data )):
1239- label = ',' .join ([str (label ) for label in ylabels [row ]])
1240- ax .plot (xdata , data [row ], label = label )
1243+ ax .plot (xdata , data [row ], label = ' ' .join (ylabels [row ]))
12411244
12421245 # set x axis
12431246 ax .set_xlabel (xlabel )
0 commit comments