@@ -1212,8 +1212,13 @@ def plot(self):
1212
1212
1213
1213
row_min , row_max , col_min , col_max = self ._selection_bounds ()
1214
1214
dim_names = self .model ().xlabels [0 ]
1215
+ # label for each selected column
1215
1216
xlabels = self .model ().xlabels [1 ][col_min :col_max ]
1216
- ylabels = self .model ().ylabels [1 :][row_min :row_max ]
1217
+ # list of selected labels for each index column
1218
+ labels_per_index_column = [col_labels [row_min :row_max ] for col_labels in self .model ().ylabels [1 :]]
1219
+ # list of (str) label for each selected row
1220
+ ylabels = [[str (label ) for label in row_labels ]
1221
+ for row_labels in zip (* labels_per_index_column )]
1217
1222
1218
1223
assert data .ndim == 2
1219
1224
@@ -1225,19 +1230,17 @@ def plot(self):
1225
1230
if data .shape [1 ] == 1 :
1226
1231
# plot one column
1227
1232
xlabel = ',' .join (dim_names [:- 1 ])
1228
- xticklabels = ['\n ' .join ([str (ylabels [c ][r ]) for c in range (len (ylabels ))])
1229
- for r in range (row_max - row_min )]
1230
- xdata = np .arange (row_max - row_min , dtype = int )
1233
+ xticklabels = ['\n ' .join (ylabels [row ]) for row in range (row_max - row_min )]
1234
+ xdata = np .arange (row_max - row_min )
1231
1235
ax .plot (xdata , data [:, 0 ])
1232
1236
ax .set_ylabel (xlabels [0 ])
1233
1237
else :
1234
1238
# plot each row as a line
1235
1239
xlabel = dim_names [- 1 ]
1236
1240
xticklabels = [str (label ) for label in xlabels ]
1237
- xdata = np .arange (col_max - col_min , dtype = int )
1241
+ xdata = np .arange (col_max - col_min )
1238
1242
for row in range (len (data )):
1239
- label = ',' .join ([str (label ) for label in ylabels [row ]])
1240
- ax .plot (xdata , data [row ], label = label )
1243
+ ax .plot (xdata , data [row ], label = ' ' .join (ylabels [row ]))
1241
1244
1242
1245
# set x axis
1243
1246
ax .set_xlabel (xlabel )
0 commit comments