Skip to content

Commit 4909b82

Browse files
authored
Merge pull request #109 from MarinManuel/master
- Fixed bug preventing non-string columns to be used - Modified the way data is accessed from pivoted_plot_data to avoid problems with NaN in some lines (from PR #110
2 parents 8d76edd + 21be690 commit 4909b82

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

dabest/_classes.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,21 @@ def __init__(self, data, idx, x, y, paired, id_col, ci,
127127

128128
else:
129129
x1_level = data_in[x[0]].unique()
130+
elif experiment is not None:
131+
experiment_label = data_in[experiment].unique()
132+
x1_level = data_in[x[0]].unique()
130133
self.__experiment_label = experiment_label
131134
self.__x1_level = x1_level
132135

133136

134-
# Check if idx is specified
135-
if delta2 is False and not idx:
136-
err = '`idx` is not a column in `data`. Please check.'
137-
raise IndexError(err)
137+
# # Check if idx is specified
138+
# if delta2 is False and not idx:
139+
# err = '`idx` is not a column in `data`. Please check.'
140+
# raise IndexError(err)
138141

139142

140143
# create new x & idx and record the second variable if this is a valid 2x2 ANOVA case
141-
if delta2 is True:
144+
if idx is None and x is not None and y is not None:
142145
# add a new column which is a combination of experiment and the first variable
143146
new_col_name = experiment+x[0]
144147
while new_col_name in data_in.columns:
@@ -165,7 +168,7 @@ def __init__(self, data, idx, x, y, paired, id_col, ci,
165168

166169

167170
# Determine the kind of estimation plot we need to produce.
168-
if all([isinstance(i, str) for i in idx]):
171+
if all([isinstance(i, (str, int, float)) for i in idx]):
169172
# flatten out idx.
170173
all_plot_groups = pd.unique([t for t in idx]).tolist()
171174
if len(idx) > len(all_plot_groups):

dabest/plotter.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -400,35 +400,24 @@ def EffectSizeDataFramePlotter(EffectSizeDataFrame, **plot_kwargs):
400400
# Plot the raw data as a slopegraph.
401401
# Pivot the long (melted) data.
402402
if color_col is None:
403-
pivot_values = yvar
403+
pivot_values = [yvar]
404404
else:
405405
pivot_values = [yvar, color_col]
406406
pivoted_plot_data = pd.pivot(data=plot_data, index=dabest_obj.id_col,
407-
columns=xvar, values=pivot_values)
407+
columns=xvar, values=pivot_values)
408408
x_start = 0
409409
for ii, current_tuple in enumerate(temp_idx):
410-
if len(temp_idx) > 1:
411-
# Select only the data for the current tuple.
412-
if color_col is None:
413-
current_pair = pivoted_plot_data.reindex(columns=current_tuple)
414-
else:
415-
current_pair = pivoted_plot_data[yvar].reindex(columns=current_tuple)
416-
else:
417-
if color_col is None:
418-
current_pair = pivoted_plot_data
419-
else:
420-
current_pair = pivoted_plot_data[yvar]
410+
current_pair = pivoted_plot_data.loc[:, pd.MultiIndex.from_product([pivot_values, current_tuple])].dropna()
421411
grp_count = len(current_tuple)
422412
# Iterate through the data for the current tuple.
423413
for ID, observation in current_pair.iterrows():
424414
x_points = [t for t in range(x_start, x_start + grp_count)]
425-
y_points = observation.tolist()
415+
y_points = observation[yvar].tolist()
426416

427417
if color_col is None:
428418
slopegraph_kwargs['color'] = ytick_color
429419
else:
430-
color_key = pivoted_plot_data[color_col,
431-
current_tuple[0]].loc[ID]
420+
color_key = observation[color_col][0]
432421
if isinstance(color_key, str) == True:
433422
slopegraph_kwargs['color'] = plot_palette_raw[color_key]
434423
slopegraph_kwargs['label'] = color_key

0 commit comments

Comments
 (0)