Skip to content

Commit 444a394

Browse files
wip
1 parent d60b4f4 commit 444a394

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -899,20 +899,36 @@ def build_dataframe(args, attrables, array_attrables, constructor):
899899
if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
900900
args["data_frame"] = pd.DataFrame(args["data_frame"])
901901

902-
if not args.get("x", None) and not args.get("y", None) and df_provided:
902+
wide_traces = [go.Scatter, go.Bar, go.Violin, go.Box, go.Histogram]
903+
has_x = args.get("x", None) is not None
904+
has_y = args.get("y", None) is not None
905+
if not has_x and not has_y and df_provided and constructor in wide_traces:
906+
index_name = args["data_frame"].index.name or "index"
907+
id_vars = [index_name]
908+
# TODO multi-level index
909+
# TODO multi-level columns
910+
# TODO orientation
911+
912+
# TODO do we need to add everything to this candidate list basically? array_attrables?
913+
# TODO will we need to be able to glue in non-string values here, like arrays and stuff?
914+
# ...like maybe this needs to run after we've glued together the data frame?
915+
for candidate in ["color", "symbol", "line_dash", "facet_row", "facet_col"] + [
916+
"line_group",
917+
"animation_group",
918+
]:
919+
if args.get(candidate, None) not in [None, index_name, "value", "variable"]:
920+
id_vars.append(args[candidate])
921+
args["data_frame"] = args["data_frame"].reset_index().melt(id_vars=id_vars)
903922
if constructor in [go.Scatter, go.Bar]:
904-
args["data_frame"] = args["data_frame"].reset_index().melt(id_vars="index")
905-
args["x"] = "index"
923+
args["x"] = index_name
906924
args["y"] = "value"
907-
args["color"] = "variable"
925+
args["color"] = args["color"] or "variable"
908926
if constructor in [go.Violin, go.Box]:
909-
args["data_frame"] = args["data_frame"].reset_index().melt(id_vars="index")
910927
args["x"] = "variable"
911928
args["y"] = "value"
912929
if constructor in [go.Histogram]:
913-
args["data_frame"] = args["data_frame"].reset_index().melt(id_vars="index")
914930
args["x"] = "value"
915-
args["color"] = "variable"
931+
args["color"] = args["color"] or "variable"
916932

917933
df_input = args["data_frame"]
918934

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_build_df_from_lists():
225225
output = {key: key for key in args}
226226
df = pd.DataFrame(args)
227227
args["data_frame"] = None
228-
out = build_dataframe(args, all_attrables, array_attrables)
228+
out = build_dataframe(args, all_attrables, array_attrables, None)
229229
assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1))
230230
out.pop("data_frame")
231231
assert out == output
@@ -235,7 +235,7 @@ def test_build_df_from_lists():
235235
output = {key: key for key in args}
236236
df = pd.DataFrame(args)
237237
args["data_frame"] = None
238-
out = build_dataframe(args, all_attrables, array_attrables)
238+
out = build_dataframe(args, all_attrables, array_attrables, None)
239239
assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1))
240240
out.pop("data_frame")
241241
assert out == output
@@ -244,7 +244,7 @@ def test_build_df_from_lists():
244244
def test_build_df_with_index():
245245
tips = px.data.tips()
246246
args = dict(data_frame=tips, x=tips.index, y="total_bill")
247-
out = build_dataframe(args, all_attrables, array_attrables)
247+
out = build_dataframe(args, all_attrables, array_attrables, None)
248248
assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"])
249249

250250

@@ -254,15 +254,15 @@ def test_non_matching_index():
254254
expected = pd.DataFrame(dict(x=["a", "b", "c"], y=[1, 2, 3]))
255255

256256
args = dict(data_frame=df, x=df.index, y="y")
257-
out = build_dataframe(args, all_attrables, array_attrables)
257+
out = build_dataframe(args, all_attrables, array_attrables, None)
258258
assert_frame_equal(expected, out["data_frame"])
259259

260260
args = dict(data_frame=None, x=df.index, y=df.y)
261-
out = build_dataframe(args, all_attrables, array_attrables)
261+
out = build_dataframe(args, all_attrables, array_attrables, None)
262262
assert_frame_equal(expected, out["data_frame"])
263263

264264
args = dict(data_frame=None, x=["a", "b", "c"], y=df.y)
265-
out = build_dataframe(args, all_attrables, array_attrables)
265+
out = build_dataframe(args, all_attrables, array_attrables, None)
266266
assert_frame_equal(expected, out["data_frame"])
267267

268268

0 commit comments

Comments
 (0)