Skip to content

Commit 5bd512d

Browse files
committed
chore: small code simplifications
1 parent f4ba900 commit 5bd512d

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

pyretailscience/plots/line.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""This module provides flexible functionality for creating line plots from pandas DataFrames.
22
3-
It focuses on visualizing sequences that resemble time-based data, such as "days since an event" or "months since a
4-
competitor opened." However, it does not explicitly handle datetime values. For actual time-based plots using
5-
datetime objects, please refer to the **`timeline`** module.
3+
It focuses on visualizing sequences that are ordered or sequential but not necessarily categorical, such as "days since
4+
an event" or "months since a competitor opened." However, while this module can handle datetime values on the x-axis,
5+
the **timeline** module has additional features that make working with datetimes easier, such as easily resampling the
6+
data to alternate time frames.
67
78
The sequences used in this module can include values like "days since an event" (e.g., -2, -1, 0, 1, 2) or "months
89
since a competitor store opened." **This module is not intended for use with actual datetime values**. If a datetime
@@ -44,9 +45,9 @@
4445
def plot(
4546
df: pd.DataFrame,
4647
value_col: str | list[str],
47-
x_label: str,
48-
y_label: str,
49-
title: str,
48+
x_label: str | None = None,
49+
y_label: str | None = None,
50+
title: str | None = None,
5051
x_col: str | None = None,
5152
group_col: str | None = None,
5253
legend_title: str | None = None,
@@ -60,9 +61,9 @@ def plot(
6061
Args:
6162
df (pd.DataFrame): The dataframe to plot.
6263
value_col (str or list of str): The column(s) to plot.
63-
x_label (str): The x-axis label.
64-
y_label (str): The y-axis label.
65-
title (str): The title of the plot.
64+
x_label (str, optional): The x-axis label.
65+
y_label (str, optional): The y-axis label.
66+
title (str, optional): The title of the plot.
6667
x_col (str, optional): The column to be used as the x-axis. If None, the index is used.
6768
group_col (str, optional): The column used to define different lines.
6869
legend_title (str, optional): The title of the legend.
@@ -82,10 +83,10 @@ def plot(
8283
)
8384
colors = get_base_cmap()
8485

85-
if group_col is not None:
86-
pivot_df = df.pivot(index=x_col if x_col is not None else None, columns=group_col, values=value_col)
87-
else:
86+
if group_col is None:
8887
pivot_df = df.set_index(x_col if x_col is not None else df.index)[value_col]
88+
else:
89+
pivot_df = df.pivot(index=x_col if x_col is not None else None, columns=group_col, values=value_col)
8990

9091
ax = pivot_df.plot(
9192
ax=ax,
@@ -95,18 +96,15 @@ def plot(
9596
**kwargs,
9697
)
9798

98-
ax = gu.standard_graph_styles(
99-
ax,
100-
title=title,
101-
x_label=x_label,
102-
y_label=y_label,
103-
)
99+
ax = gu.standard_graph_styles(ax=ax, title=title, x_label=x_label, y_label=y_label)
104100

105101
if move_legend_outside:
106102
ax.legend(bbox_to_anchor=(1.05, 1))
107103

108104
if legend_title is not None:
109-
ax.legend(title=legend_title)
105+
legend = ax.get_legend()
106+
if legend is not None:
107+
legend.set_title(legend_title)
110108

111109
if source_text is not None:
112110
gu.add_source_text(ax=ax, source_text=source_text)

0 commit comments

Comments
 (0)