Skip to content

Commit 606528f

Browse files
committed
feat: add line plot
1 parent e201fea commit 606528f

File tree

2 files changed

+317
-0
lines changed

2 files changed

+317
-0
lines changed

pyretailscience/plots/__init__.py

Whitespace-only changes.

pyretailscience/plots/line.py

Lines changed: 317 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,317 @@
1+
"""This module provides versatile plotting functionality for creating line plots from pandas DataFrames.
2+
3+
It is designed for plotting sequences that resemble time-based data, such as "days" or "months" since
4+
an event, but it does not explicitly handle datetime values. For actual time-based plots (using datetime
5+
objects), please refer to the `time_plot` module.
6+
7+
The sequences used in this module can include values such as "days since an event" (e.g., -2, -1, 0, 1, 2)
8+
or "months since a competitor store opened." **This module is not intended for use with actual datetime values**.
9+
If a datetime or datetime-like column is passed as `x_col`, a warning will be triggered suggesting the use
10+
of the `time_plot` module.
11+
12+
Key Features:
13+
--------------
14+
- **Plotting Sequences or Indexes**: Plot one or more value columns (`value_col`), supporting sequences
15+
like -2, -1, 0, 1, 2 (e.g., months since an event), using either the index or a specified x-axis
16+
column (`x_col`).
17+
- **Custom X-Axis or Index**: Use any column as the x-axis (`x_col`), or plot based on the index if no
18+
x-axis column is specified.
19+
- **Multiple Lines**: Create separate lines for each unique value in `group_col` (e.g., categories).
20+
- **Comprehensive Customization**: Easily customize titles, axis labels, legends, and optionally move
21+
the legend outside the plot.
22+
- **Pre-Aggregated Data**: The data must be pre-aggregated before plotting. No aggregation occurs in
23+
this module.
24+
25+
### Common Scenarios and Examples:
26+
27+
1. **Basic Plot Showing Price Trends Since Competitor Store Opened**:
28+
29+
This example demonstrates how to plot the `total_price` over the number of months since a competitor
30+
store opened. The total price remains stable or increases slightly before the store opened, and then
31+
drops randomly after the competitor's store opened.
32+
33+
**Preparing the Data**:
34+
```python
35+
import numpy as np
36+
37+
# Convert 'transaction_datetime' to a datetime column if it's not already
38+
df['transaction_datetime'] = pd.to_datetime(df['transaction_datetime'])
39+
40+
# Resample the data by month
41+
df['month'] = df['transaction_datetime'].dt.to_period('M') # Extract year and month
42+
df_monthly = df.groupby('month').agg({'total_price': 'sum'}).reset_index()
43+
44+
# Create the "months since competitor opened" column
45+
# Assume the competitor opened 60% of the way through the data
46+
competitor_opened_month_index = int(len(df_monthly) * 0.6)
47+
df_monthly['months_since_competitor_opened'] = np.arange(-competitor_opened_month_index, len(df_monthly) - competitor_opened_month_index)
48+
49+
# Simulate stable or increasing prices before competitor opened
50+
df_monthly.loc[df_monthly['months_since_competitor_opened'] < 0, 'total_price'] *= np.random.uniform(1.05, 1.2)
51+
52+
# Simulate a random drop after the competitor opened
53+
df_monthly.loc[df_monthly['months_since_competitor_opened'] >= 0, 'total_price'] *= np.random.uniform(0.8, 0.95, size=len(df_monthly[df_monthly['months_since_competitor_opened'] >= 0]))
54+
```
55+
56+
**Plotting**:
57+
```python
58+
ax = line.plot(
59+
df=df_monthly,
60+
value_col="total_price", # Plot 'total_price' values
61+
x_col="months_since_competitor_opened", # Use 'months_since_competitor_opened' as the x-axis
62+
title="Total Price Since Competitor Store Opened", # Title of the plot
63+
x_label="Months Since Competitor Opened", # X-axis label
64+
y_label="Total Price", # Y-axis label
65+
)
66+
67+
plt.show()
68+
```
69+
70+
**Use Case**: This is useful when you want to illustrate the effect of a competitor store opening
71+
on sales performance. The x-axis represents months before and after the event, and the y-axis shows
72+
how prices behaved over time.
73+
74+
---
75+
76+
2. **Plotting Price Trends by Category (Top 3 Categories)**:
77+
78+
This example plots the total price for the top 3 categories before and after the competitor opened.
79+
The data is resampled by month, split by category, and tracks the months since the competitor store opened.
80+
81+
**Preparing the Data**:
82+
```python
83+
import numpy as np
84+
import pandas as pd
85+
86+
# Convert 'transaction_datetime' to a datetime column if it's not already
87+
df['transaction_datetime'] = pd.to_datetime(df['transaction_datetime'])
88+
89+
# Resample the data by month and category
90+
df['month'] = df['transaction_datetime'].dt.to_period('M') # Extract year and month
91+
df_monthly = df.groupby(['month', 'category_0_name']).agg({'total_price': 'sum'}).reset_index()
92+
93+
# Create a separate dataframe for unique months to track "months since competitor opened"
94+
unique_months = df_monthly['month'].unique()
95+
competitor_opened_month_index = int(len(unique_months) * 0.6) # Assume competitor opened 60% of the way through
96+
97+
# Create 'months_since_competitor_opened' for each unique month
98+
months_since_competitor_opened = np.concatenate([
99+
np.arange(-competitor_opened_month_index, 0), # Before competitor opened
100+
np.arange(0, len(unique_months) - competitor_opened_month_index) # After competitor opened
101+
])
102+
103+
# Create a new dataframe with the 'months_since_competitor_opened' values and merge it back
104+
months_df = pd.DataFrame({'month': unique_months, 'months_since_competitor_opened': months_since_competitor_opened})
105+
df_monthly = df_monthly.merge(months_df, on='month', how='left')
106+
107+
# Filter to include months both before and after the competitor opened
108+
df_since_competitor_opened = df_monthly[(df_monthly['months_since_competitor_opened'] >= -6) & # Include 6 months before
109+
(df_monthly['months_since_competitor_opened'] <= 12)] # Include 12 months after
110+
111+
# Identify top 3 categories based on total_price across the selected period
112+
category_totals = df_since_competitor_opened.groupby('category_0_name')['total_price'].sum().sort_values(ascending=False)
113+
114+
# Get the top 3 categories
115+
top_categories = category_totals.head(3).index
116+
117+
# Filter the dataframe to include only the top 3 categories
118+
df_top_categories = df_since_competitor_opened[df_since_competitor_opened['category_0_name'].isin(top_categories)]
119+
```
120+
121+
**Plotting**:
122+
```python
123+
ax = line.plot(
124+
df=df_top_categories,
125+
value_col="total_price", # Plot 'total_price' values
126+
group_col="category_0_name", # Separate lines for each category
127+
x_col="months_since_competitor_opened", # Use 'months_since_competitor_opened' as the x-axis
128+
title="Total Price for Top 3 Categories (Before and After Competitor Opened)", # Title of the plot
129+
x_label="Months Since Competitor Opened", # X-axis label
130+
y_label="Total Price", # Y-axis label
131+
legend_title="Category" # Legend title
132+
)
133+
134+
plt.show()
135+
```
136+
137+
**Use Case**: Use this when you want to analyze the behavior of specific top categories before and after
138+
an event, such as the opening of a competitor store.
139+
140+
---
141+
142+
### Customization Options:
143+
- **`value_col`**: The column or list of columns to plot (e.g., `'total_price'`).
144+
- **`group_col`**: A column whose unique values will be used to create separate lines (e.g., `'category_0_name'`).
145+
- **`x_col`**: The column to use as the x-axis (e.g., `'months_since_competitor_opened'`). **Warning**: If a datetime
146+
or datetime-like column is passed, a warning will suggest using the `time_plot` module instead.
147+
- **`title`**, **`x_label`**, **`y_label`**: Custom text for the plot title and axis labels.
148+
- **`legend_title`**: Custom title for the legend based on `group_col`.
149+
- **`move_legend_outside`**: Boolean flag to move the legend outside the plot.
150+
151+
---
152+
153+
### Dependencies:
154+
- `pandas`: For DataFrame manipulation and grouping.
155+
- `matplotlib`: For generating plots.
156+
- `pyretailscience.style.graph_utils`: For applying consistent graph styles across the plots.
157+
158+
"""
159+
160+
import logging
161+
162+
import pandas as pd
163+
from matplotlib.axes import Axes, SubplotBase
164+
165+
import pyretailscience.style.graph_utils as gu
166+
from pyretailscience.style.graph_utils import GraphStyles
167+
from pyretailscience.style.tailwind import COLORS
168+
169+
logging.basicConfig(format="%(message)s", level=logging.INFO)
170+
171+
172+
def _check_datetime_column(df: pd.DataFrame, x_col: str) -> None:
173+
"""Checks if the x_col is a datetime or convertible to datetime.
174+
175+
Issues a warning if the column is datetime-like, recommending
176+
the use of a time-based plot.
177+
178+
Args:
179+
df (pd.DataFrame): The dataframe containing the column to check.
180+
x_col (str): The column to check for datetime-like values.
181+
"""
182+
if pd.api.types.is_datetime64_any_dtype(df[x_col]):
183+
logging.warning(
184+
"The column '%s' is a datetime column. Consider using the 'time_plot' function for time-based plots.",
185+
x_col,
186+
)
187+
else:
188+
try:
189+
pd.to_datetime(df[x_col])
190+
logging.warning(
191+
"The column '%s' can be converted to datetime. Consider using the 'time_plot' module for time-based plots.",
192+
x_col,
193+
)
194+
except (ValueError, TypeError):
195+
pass
196+
197+
198+
def plot(
199+
df: pd.DataFrame,
200+
value_col: str | list[str],
201+
group_col: str | None = None,
202+
x_col: str | None = None,
203+
title: str | None = None,
204+
x_label: str | None = None,
205+
y_label: str | None = None,
206+
legend_title: str | None = None,
207+
ax: Axes | None = None,
208+
source_text: str | None = None,
209+
move_legend_outside: bool = False,
210+
**kwargs: dict[str, any],
211+
) -> SubplotBase:
212+
"""Plots the `value_col` over the specified `x_col` or index, creating a separate line for each unique value in `group_col`.
213+
214+
Args:
215+
df (pd.DataFrame): The dataframe to plot.
216+
value_col (str or list of str): The column(s) to plot.
217+
group_col (str, optional): The column used to define different lines.
218+
x_col (str, optional): The column to be used as the x-axis. If None, the index is used.
219+
title (str, optional): The title of the plot.
220+
x_label (str, optional): The x-axis label.
221+
y_label (str, optional): The y-axis label.
222+
legend_title (str, optional): The title of the legend.
223+
ax (Axes, optional): Matplotlib axes object to plot on.
224+
source_text (str, optional): The source text to add to the plot.
225+
move_legend_outside (bool, optional): Move the legend outside the plot.
226+
**kwargs: Additional keyword arguments for Pandas' `plot` function.
227+
228+
Returns:
229+
SubplotBase: The matplotlib axes object.
230+
"""
231+
if x_col is not None:
232+
_check_datetime_column(df, x_col)
233+
234+
if group_col is not None:
235+
unique_groups = df[group_col].unique()
236+
colors = [
237+
COLORS[color][shade]
238+
for shade in [500, 300, 700]
239+
for color in ["green", "orange", "red", "blue", "yellow", "violet", "pink"]
240+
][: len(unique_groups)]
241+
242+
ax = None
243+
for i, group in enumerate(unique_groups):
244+
group_df = df[df[group_col] == group]
245+
246+
if x_col is not None:
247+
ax = group_df.plot(
248+
x=x_col,
249+
y=value_col,
250+
ax=ax,
251+
linewidth=3,
252+
color=colors[i],
253+
label=group,
254+
legend=True,
255+
**kwargs,
256+
)
257+
else:
258+
ax = group_df.plot(
259+
y=value_col,
260+
ax=ax,
261+
linewidth=3,
262+
color=colors[i],
263+
label=group,
264+
legend=True,
265+
**kwargs,
266+
)
267+
else:
268+
colors = COLORS["green"][500]
269+
if x_col is not None:
270+
ax = df.plot(
271+
x=x_col,
272+
y=value_col,
273+
linewidth=3,
274+
color=colors,
275+
legend=False,
276+
ax=ax,
277+
**kwargs,
278+
)
279+
else:
280+
ax = df.plot(
281+
y=value_col,
282+
linewidth=3,
283+
color=colors,
284+
legend=False,
285+
ax=ax,
286+
**kwargs,
287+
)
288+
289+
ax = gu.standard_graph_styles(
290+
ax,
291+
title=title if title else f"{value_col.title()} Over {x_col.title()}" if x_col else f"{value_col.title()} Over Index",
292+
x_label=x_label if x_label else (x_col.title() if x_col else "Index"),
293+
y_label=y_label or value_col.title(),
294+
)
295+
296+
if move_legend_outside:
297+
ax.legend(bbox_to_anchor=(1.05, 1))
298+
299+
if legend_title is not None:
300+
ax.legend(title=legend_title)
301+
302+
if source_text:
303+
ax.annotate(
304+
source_text,
305+
xy=(-0.1, -0.2),
306+
xycoords="axes fraction",
307+
ha="left",
308+
va="center",
309+
fontsize=GraphStyles.DEFAULT_SOURCE_FONT_SIZE,
310+
fontproperties=GraphStyles.POPPINS_LIGHT_ITALIC,
311+
color="dimgray",
312+
)
313+
314+
for tick in ax.get_xticklabels() + ax.get_yticklabels():
315+
tick.set_fontproperties(GraphStyles.POPPINS_REG)
316+
317+
return ax

0 commit comments

Comments
 (0)