Skip to content

Commit faf2568

Browse files
committed
chore: move legend logic to standard graph utils
1 parent 90e2baa commit faf2568

File tree

3 files changed

+64
-23
lines changed

3 files changed

+64
-23
lines changed

pyretailscience/plots/line.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@
4545
def plot(
4646
df: pd.DataFrame,
4747
value_col: str | list[str],
48-
x_label: str | None = None,
49-
y_label: str | None = None,
50-
title: str | None = None,
48+
x_label: str,
49+
y_label: str,
50+
title: str,
5151
x_col: str | None = None,
5252
group_col: str | None = None,
53-
legend_title: str | None = None,
5453
ax: Axes | None = None,
5554
source_text: str | None = None,
55+
legend_title: str | None = None,
5656
move_legend_outside: bool = False,
5757
**kwargs: dict[str, any],
5858
) -> SubplotBase:
@@ -61,9 +61,9 @@ def plot(
6161
Args:
6262
df (pd.DataFrame): The dataframe to plot.
6363
value_col (str or list of str): The column(s) to plot.
64-
x_label (str, optional): The x-axis label.
65-
y_label (str, optional): The y-axis label.
66-
title (str, optional): The title of the plot.
64+
x_label (str): The x-axis label.
65+
y_label (str): The y-axis label.
66+
title (str): The title of the plot.
6767
x_col (str, optional): The column to be used as the x-axis. If None, the index is used.
6868
group_col (str, optional): The column used to define different lines.
6969
legend_title (str, optional): The title of the legend.
@@ -96,15 +96,14 @@ def plot(
9696
**kwargs,
9797
)
9898

99-
ax = gu.standard_graph_styles(ax=ax, title=title, x_label=x_label, y_label=y_label)
100-
101-
if move_legend_outside:
102-
ax.legend(bbox_to_anchor=(1.05, 1))
103-
104-
if legend_title is not None:
105-
legend = ax.get_legend()
106-
if legend is not None:
107-
legend.set_title(legend_title)
99+
ax = gu.standard_graph_styles(
100+
ax=ax,
101+
title=title,
102+
x_label=x_label,
103+
y_label=y_label,
104+
legend_title=legend_title,
105+
move_legend_outside=move_legend_outside,
106+
)
108107

109108
if source_text is not None:
110109
gu.add_source_text(ax=ax, source_text=source_text)

pyretailscience/style/graph_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def standard_graph_styles(
6767
title_pad: int = GraphStyles.DEFAULT_TITLE_PAD,
6868
x_label_pad: int = GraphStyles.DEFAULT_AXIS_LABEL_PAD,
6969
y_label_pad: int = GraphStyles.DEFAULT_AXIS_LABEL_PAD,
70+
legend_title: str | None = None,
71+
move_legend_outside: bool = False,
7072
) -> Axes:
7173
"""Apply standard styles to a Matplotlib graph.
7274
@@ -79,6 +81,8 @@ def standard_graph_styles(
7981
x_label_pad (int, optional): The padding below the x-axis label. Defaults to GraphStyles.DEFAULT_AXIS_LABEL_PAD.
8082
y_label_pad (int, optional): The padding to the left of the y-axis label. Defaults to
8183
GraphStyles.DEFAULT_AXIS_LABEL_PAD.
84+
legend_title (str, optional): The title of the legend. If None, no legend title is applied. Defaults to None.
85+
move_legend_outside (bool, optional): Whether to move the legend outside the plot. Defaults to False.
8286
8387
Returns:
8488
Axes: The graph with the styles applied.
@@ -112,6 +116,13 @@ def standard_graph_styles(
112116
labelpad=y_label_pad,
113117
)
114118

119+
legend = ax.legend() if move_legend_outside or legend_title is not None else ax.get_legend()
120+
if legend:
121+
if move_legend_outside:
122+
legend.set_bbox_to_anchor((1.05, 1))
123+
if legend_title:
124+
legend.set_title(legend_title)
125+
115126
return ax
116127

117128

tests/plots/test_line.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_plot_moves_legend_outside(sample_dataframe):
107107
"""Test the plot function moves the legend outside the plot."""
108108
_, ax = plt.subplots()
109109

110-
# Test with move_legend_outside=True
110+
# Create the plot with move_legend_outside=True
111111
result_ax = line.plot(
112112
df=sample_dataframe,
113113
value_col="y",
@@ -120,13 +120,44 @@ def test_plot_moves_legend_outside(sample_dataframe):
120120
move_legend_outside=True,
121121
)
122122

123-
expected_coords = (1.05, 1.0)
124-
legend = result_ax.get_legend()
125-
# Check if bbox_to_anchor is set to (1.05, 1) when legend is outside
126-
bbox_anchor = legend.get_bbox_to_anchor()._bbox
123+
# Assert that standard_graph_styles was called with move_legend_outside=True
124+
gu.standard_graph_styles.assert_called_once_with(
125+
ax=result_ax,
126+
title="Test Plot Legend Outside",
127+
x_label="X Axis",
128+
y_label="Y Axis",
129+
legend_title=None,
130+
move_legend_outside=True,
131+
)
132+
133+
134+
@pytest.mark.usefixtures("_mock_get_base_cmap", "_mock_gu_functions")
135+
def test_plot_moves_legend_inside(sample_dataframe):
136+
"""Test the plot function moves the legend inside the plot."""
137+
_, ax = plt.subplots()
138+
139+
# Create the plot with move_legend_outside=False
140+
result_ax = line.plot(
141+
df=sample_dataframe,
142+
value_col="y",
143+
x_label="X Axis",
144+
y_label="Y Axis",
145+
title="Test Plot Legend Inside",
146+
x_col="x",
147+
group_col="group",
148+
ax=ax,
149+
move_legend_outside=False,
150+
)
127151

128-
assert legend is not None
129-
assert (bbox_anchor.x0, bbox_anchor.y0) == expected_coords
152+
# Assert that standard_graph_styles was called with move_legend_outside=False
153+
gu.standard_graph_styles.assert_called_once_with(
154+
ax=result_ax,
155+
title="Test Plot Legend Inside",
156+
x_label="X Axis",
157+
y_label="Y Axis",
158+
legend_title=None,
159+
move_legend_outside=False,
160+
)
130161

131162

132163
@pytest.mark.usefixtures("_mock_get_base_cmap", "_mock_gu_functions")

0 commit comments

Comments
 (0)