Skip to content

Commit b50cd08

Browse files
Merge branch 'px_special_inputs' into wide_form2
2 parents 236cd2c + 918b87b commit b50cd08

File tree

4 files changed

+118
-5
lines changed

4 files changed

+118
-5
lines changed

Diff for: packages/python/plotly/plotly/express/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@
5555
get_trendline_results,
5656
)
5757

58+
from ._special_inputs import ( # noqa: F401
59+
IdentityMap,
60+
Constant,
61+
)
62+
5863
from . import data, colors # noqa: F401
5964

6065
__all__ = [
@@ -95,4 +100,6 @@
95100
"colors",
96101
"set_mapbox_access_token",
97102
"get_trendline_results",
103+
"IdentityMap",
104+
"Constant",
98105
]

Diff for: packages/python/plotly/plotly/express/_core.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import plotly.graph_objs as go
22
import plotly.io as pio
33
from collections import namedtuple, OrderedDict
4+
from ._special_inputs import IdentityMap, Constant
45

56
from _plotly_utils.basevalidators import ColorscaleValidator
67
from .colors import qualitative, sequential
@@ -41,6 +42,7 @@ def __init__(self):
4142
defaults = PxDefaults()
4243
del PxDefaults
4344

45+
4446
MAPBOX_TOKEN = None
4547

4648

@@ -141,11 +143,15 @@ def make_mapping(args, variable):
141143
if variable == "dash":
142144
arg_name = "line_dash"
143145
vprefix = "line_dash"
146+
if args[vprefix + "_map"] == "identity":
147+
val_map = IdentityMap()
148+
else:
149+
val_map = args[vprefix + "_map"].copy()
144150
return Mapping(
145151
show_in_trace_name=True,
146152
variable=variable,
147153
grouper=args[arg_name],
148-
val_map=args[vprefix + "_map"].copy(),
154+
val_map=val_map,
149155
sequence=args[vprefix + "_sequence"],
150156
updater=lambda trace, v: trace.update({parent: {variable: v}}),
151157
facet=None,
@@ -937,6 +943,8 @@ def build_dataframe(args, attrables, array_attrables, constructor):
937943
else:
938944
df_output[df_input.columns] = df_input[df_input.columns]
939945

946+
constants = dict()
947+
940948
# Loop over possible arguments
941949
for field_name in attrables:
942950
# Massaging variables
@@ -968,8 +976,15 @@ def build_dataframe(args, attrables, array_attrables, constructor):
968976
"pandas MultiIndex is not supported by plotly express "
969977
"at the moment." % field
970978
)
979+
# ----------------- argument is a constant ----------------------
980+
if isinstance(argument, Constant):
981+
col_name = _check_name_not_reserved(
982+
str(argument.label) if argument.label is not None else field,
983+
reserved_names,
984+
)
985+
constants[col_name] = argument.value
971986
# ----------------- argument is a col name ----------------------
972-
if isinstance(argument, str) or isinstance(
987+
elif isinstance(argument, str) or isinstance(
973988
argument, int
974989
): # just a column name given as str or int
975990
if not df_provided:
@@ -1073,6 +1088,9 @@ def build_dataframe(args, attrables, array_attrables, constructor):
10731088
args["x" if orient_v else "y"] = "value"
10741089
args["color"] = args["color"] or "variable"
10751090

1091+
for col_name in constants:
1092+
df_output[col_name] = constants[col_name]
1093+
10761094
args["data_frame"] = df_output
10771095
return args
10781096

@@ -1491,9 +1509,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
14911509
for col, val, m in zip(grouper, group_name, grouped_mappings):
14921510
if col != one_group:
14931511
key = get_label(args, col)
1494-
mapping_labels[key] = str(val)
1495-
if m.show_in_trace_name:
1496-
trace_name_labels[key] = str(val)
1512+
if not isinstance(m.val_map, IdentityMap):
1513+
mapping_labels[key] = str(val)
1514+
if m.show_in_trace_name:
1515+
trace_name_labels[key] = str(val)
14971516
if m.variable == "animation_frame":
14981517
frame_name = val
14991518
trace_name = ", ".join(trace_name_labels.values())
+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
class IdentityMap(object):
2+
"""
3+
`dict`-like object which acts as if the value for any key is the key itself. Objects
4+
of this class can be passed in to arguments like `color_discrete_map` to
5+
use the provided data values as colors, rather than mapping them to colors cycled
6+
from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express
7+
functions, such as `line_dash_map` and `symbol_map`.
8+
"""
9+
10+
def __getitem__(self, key):
11+
return key
12+
13+
def __contains__(self, key):
14+
return True
15+
16+
def copy(self):
17+
return self
18+
19+
20+
class Constant(object):
21+
"""
22+
Objects of this class can be passed to Plotly Express functions that expect column
23+
identifiers or list-like objects to indicate that this attribute should take on a
24+
constant value. An optional label can be provided.
25+
"""
26+
27+
def __init__(self, value, label=None):
28+
self.value = value
29+
self.label = label

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py

+58
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,61 @@ def test_size_column():
323323
df = px.data.tips()
324324
fig = px.scatter(df, x=df["size"], y=df.tip)
325325
assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}<extra></extra>"
326+
327+
328+
def test_identity_map():
329+
fig = px.scatter(
330+
x=[1, 2],
331+
y=[1, 2],
332+
symbol=["a", "b"],
333+
color=["red", "blue"],
334+
color_discrete_map=px.IdentityMap(),
335+
)
336+
assert fig.data[0].marker.color == "red"
337+
assert fig.data[1].marker.color == "blue"
338+
assert "color=" not in fig.data[0].hovertemplate
339+
assert "symbol=" in fig.data[0].hovertemplate
340+
assert fig.layout.legend.title.text == "symbol"
341+
342+
fig = px.scatter(
343+
x=[1, 2],
344+
y=[1, 2],
345+
symbol=["a", "b"],
346+
color=["red", "blue"],
347+
color_discrete_map="identity",
348+
)
349+
assert fig.data[0].marker.color == "red"
350+
assert fig.data[1].marker.color == "blue"
351+
assert "color=" not in fig.data[0].hovertemplate
352+
assert "symbol=" in fig.data[0].hovertemplate
353+
assert fig.layout.legend.title.text == "symbol"
354+
355+
356+
def test_constants():
357+
fig = px.scatter(x=px.Constant(1), y=[1, 2])
358+
assert fig.data[0].x[0] == 1
359+
assert fig.data[0].x[1] == 1
360+
assert "x=" in fig.data[0].hovertemplate
361+
362+
fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
363+
assert fig.data[0].x[0] == 1
364+
assert fig.data[0].x[1] == 1
365+
assert "x=" not in fig.data[0].hovertemplate
366+
assert "time=" in fig.data[0].hovertemplate
367+
368+
fig = px.scatter(
369+
x=[1, 2],
370+
y=[1, 2],
371+
symbol=["a", "b"],
372+
color=px.Constant("red", label="the_identity_label"),
373+
hover_data=[px.Constant("data", label="the_data")],
374+
color_discrete_map=px.IdentityMap(),
375+
)
376+
assert fig.data[0].marker.color == "red"
377+
assert fig.data[0].customdata[0][0] == "data"
378+
assert fig.data[1].marker.color == "red"
379+
assert "color=" not in fig.data[0].hovertemplate
380+
assert "the_identity_label=" not in fig.data[0].hovertemplate
381+
assert "symbol=" in fig.data[0].hovertemplate
382+
assert "the_data=" in fig.data[0].hovertemplate
383+
assert fig.layout.legend.title.text == "symbol"

0 commit comments

Comments
 (0)