@@ -1211,8 +1211,8 @@ def plot(self):
1211
1211
1212
1212
row_min , row_max , col_min , col_max = self ._selection_bounds ()
1213
1213
dim_names = self .model ().xlabels [0 ]
1214
- xlabels = self .model ().xlabels
1215
- ylabels = self .model ().ylabels
1214
+ xlabels = self .model ().xlabels [ 1 ][ col_min : col_max ]
1215
+ ylabels = self .model ().ylabels [ 1 :][ row_min : row_max ]
1216
1216
1217
1217
assert data .ndim == 2
1218
1218
@@ -1224,27 +1224,29 @@ def plot(self):
1224
1224
if data .shape [1 ] == 1 :
1225
1225
# plot one column
1226
1226
xlabel = ',' .join (dim_names [:- 1 ])
1227
- xticklabels = ['\n ' .join ([str (ylabels [j ][r ]) for j in range (1 , len (ylabels ))])
1228
- for r in range (row_min , row_max )]
1229
- ax .plot (data [:, 0 ])
1230
- ax .set_ylabel (xlabels [1 ][col_min ])
1227
+ xticklabels = ['\n ' .join ([str (ylabels [c ][r ]) for c in range (len (ylabels ))])
1228
+ for r in range (row_max - row_min )]
1229
+ xdata = np .arange (row_max - row_min , dtype = int )
1230
+ ax .plot (xdata , data [:, 0 ])
1231
+ ax .set_ylabel (xlabels [0 ])
1231
1232
else :
1232
1233
# plot each row as a line
1233
1234
xlabel = dim_names [- 1 ]
1234
- xticklabels = [str (xlabels [1 ][c ]) for c in range (col_min , col_max )]
1235
+ xticklabels = [str (label ) for label in xlabels ]
1236
+ xdata = np .arange (col_max - col_min , dtype = int )
1235
1237
for row in range (len (data )):
1236
- label = ',' .join ([str (ylabels [j ][row_min + row ])
1237
- for j in range (1 , len (ylabels ))])
1238
- ax .plot (data [row ], label = label )
1238
+ label = ',' .join ([str (label ) for label in ylabels [row ]])
1239
+ ax .plot (xdata , data [row ], label = label )
1239
1240
1240
1241
# set x axis
1241
1242
ax .set_xlabel (xlabel )
1242
- ax .set_xlim (0 , len ( xticklabels ) - 1 )
1243
+ ax .set_xlim (( xdata [ 0 ], xdata [ - 1 ]) )
1243
1244
# we need to do that because matplotlib is smart enough to
1244
1245
# not show all ticks but a selection. However, that selection
1245
1246
# may include ticks outside the range of x axis
1246
1247
xticks = [t for t in ax .get_xticks ().astype (int ) if t <= len (xticklabels ) - 1 ]
1247
- xticklabels = [xticklabels [j ] for j in xticks ]
1248
+ xticklabels = [xticklabels [t ] for t in xticks ]
1249
+ ax .set_xticks (xticks )
1248
1250
ax .set_xticklabels (xticklabels )
1249
1251
1250
1252
if data .shape [1 ] != 1 :
0 commit comments