|
| 1 | +"""This module provides versatile plotting functionality for creating line plots from pandas DataFrames. |
| 2 | +
|
| 3 | +It is designed for plotting sequences that resemble time-based data, such as "days" or "months" since |
| 4 | +an event, but it does not explicitly handle datetime values. For actual time-based plots (using datetime |
| 5 | +objects), please refer to the `time_plot` module. |
| 6 | +
|
| 7 | +The sequences used in this module can include values such as "days since an event" (e.g., -2, -1, 0, 1, 2) |
| 8 | +or "months since a competitor store opened." **This module is not intended for use with actual datetime values**. |
| 9 | +If a datetime or datetime-like column is passed as `x_col`, a warning will be triggered suggesting the use |
| 10 | +of the `time_plot` module. |
| 11 | +
|
| 12 | +Key Features: |
| 13 | +-------------- |
| 14 | +- **Plotting Sequences or Indexes**: Plot one or more value columns (`value_col`), supporting sequences |
| 15 | + like -2, -1, 0, 1, 2 (e.g., months since an event), using either the index or a specified x-axis |
| 16 | + column (`x_col`). |
| 17 | +- **Custom X-Axis or Index**: Use any column as the x-axis (`x_col`), or plot based on the index if no |
| 18 | + x-axis column is specified. |
| 19 | +- **Multiple Lines**: Create separate lines for each unique value in `group_col` (e.g., categories). |
| 20 | +- **Comprehensive Customization**: Easily customize titles, axis labels, legends, and optionally move |
| 21 | + the legend outside the plot. |
| 22 | +- **Pre-Aggregated Data**: The data must be pre-aggregated before plotting. No aggregation occurs in |
| 23 | + this module. |
| 24 | +
|
| 25 | +### Common Scenarios and Examples: |
| 26 | +
|
| 27 | +1. **Basic Plot Showing Price Trends Since Competitor Store Opened**: |
| 28 | +
|
| 29 | + This example demonstrates how to plot the `total_price` over the number of months since a competitor |
| 30 | + store opened. The total price remains stable or increases slightly before the store opened, and then |
| 31 | + drops randomly after the competitor's store opened. |
| 32 | +
|
| 33 | + **Preparing the Data**: |
| 34 | + ```python |
| 35 | + import numpy as np |
| 36 | +
|
| 37 | + # Convert 'transaction_datetime' to a datetime column if it's not already |
| 38 | + df['transaction_datetime'] = pd.to_datetime(df['transaction_datetime']) |
| 39 | +
|
| 40 | + # Resample the data by month |
| 41 | + df['month'] = df['transaction_datetime'].dt.to_period('M') # Extract year and month |
| 42 | + df_monthly = df.groupby('month').agg({'total_price': 'sum'}).reset_index() |
| 43 | +
|
| 44 | + # Create the "months since competitor opened" column |
| 45 | + # Assume the competitor opened 60% of the way through the data |
| 46 | + competitor_opened_month_index = int(len(df_monthly) * 0.6) |
| 47 | + df_monthly['months_since_competitor_opened'] = np.arange(-competitor_opened_month_index, len(df_monthly) - competitor_opened_month_index) |
| 48 | +
|
| 49 | + # Simulate stable or increasing prices before competitor opened |
| 50 | + df_monthly.loc[df_monthly['months_since_competitor_opened'] < 0, 'total_price'] *= np.random.uniform(1.05, 1.2) |
| 51 | +
|
| 52 | + # Simulate a random drop after the competitor opened |
| 53 | + df_monthly.loc[df_monthly['months_since_competitor_opened'] >= 0, 'total_price'] *= np.random.uniform(0.8, 0.95, size=len(df_monthly[df_monthly['months_since_competitor_opened'] >= 0])) |
| 54 | + ``` |
| 55 | +
|
| 56 | + **Plotting**: |
| 57 | + ```python |
| 58 | + ax = line.plot( |
| 59 | + df=df_monthly, |
| 60 | + value_col="total_price", # Plot 'total_price' values |
| 61 | + x_col="months_since_competitor_opened", # Use 'months_since_competitor_opened' as the x-axis |
| 62 | + title="Total Price Since Competitor Store Opened", # Title of the plot |
| 63 | + x_label="Months Since Competitor Opened", # X-axis label |
| 64 | + y_label="Total Price", # Y-axis label |
| 65 | + ) |
| 66 | +
|
| 67 | + plt.show() |
| 68 | + ``` |
| 69 | +
|
| 70 | + **Use Case**: This is useful when you want to illustrate the effect of a competitor store opening |
| 71 | + on sales performance. The x-axis represents months before and after the event, and the y-axis shows |
| 72 | + how prices behaved over time. |
| 73 | +
|
| 74 | +--- |
| 75 | +
|
| 76 | +2. **Plotting Price Trends by Category (Top 3 Categories)**: |
| 77 | +
|
| 78 | + This example plots the total price for the top 3 categories before and after the competitor opened. |
| 79 | + The data is resampled by month, split by category, and tracks the months since the competitor store opened. |
| 80 | +
|
| 81 | + **Preparing the Data**: |
| 82 | + ```python |
| 83 | + import numpy as np |
| 84 | + import pandas as pd |
| 85 | +
|
| 86 | + # Convert 'transaction_datetime' to a datetime column if it's not already |
| 87 | + df['transaction_datetime'] = pd.to_datetime(df['transaction_datetime']) |
| 88 | +
|
| 89 | + # Resample the data by month and category |
| 90 | + df['month'] = df['transaction_datetime'].dt.to_period('M') # Extract year and month |
| 91 | + df_monthly = df.groupby(['month', 'category_0_name']).agg({'total_price': 'sum'}).reset_index() |
| 92 | +
|
| 93 | + # Create a separate dataframe for unique months to track "months since competitor opened" |
| 94 | + unique_months = df_monthly['month'].unique() |
| 95 | + competitor_opened_month_index = int(len(unique_months) * 0.6) # Assume competitor opened 60% of the way through |
| 96 | +
|
| 97 | + # Create 'months_since_competitor_opened' for each unique month |
| 98 | + months_since_competitor_opened = np.concatenate([ |
| 99 | + np.arange(-competitor_opened_month_index, 0), # Before competitor opened |
| 100 | + np.arange(0, len(unique_months) - competitor_opened_month_index) # After competitor opened |
| 101 | + ]) |
| 102 | +
|
| 103 | + # Create a new dataframe with the 'months_since_competitor_opened' values and merge it back |
| 104 | + months_df = pd.DataFrame({'month': unique_months, 'months_since_competitor_opened': months_since_competitor_opened}) |
| 105 | + df_monthly = df_monthly.merge(months_df, on='month', how='left') |
| 106 | +
|
| 107 | + # Filter to include months both before and after the competitor opened |
| 108 | + df_since_competitor_opened = df_monthly[(df_monthly['months_since_competitor_opened'] >= -6) & # Include 6 months before |
| 109 | + (df_monthly['months_since_competitor_opened'] <= 12)] # Include 12 months after |
| 110 | +
|
| 111 | + # Identify top 3 categories based on total_price across the selected period |
| 112 | + category_totals = df_since_competitor_opened.groupby('category_0_name')['total_price'].sum().sort_values(ascending=False) |
| 113 | +
|
| 114 | + # Get the top 3 categories |
| 115 | + top_categories = category_totals.head(3).index |
| 116 | +
|
| 117 | + # Filter the dataframe to include only the top 3 categories |
| 118 | + df_top_categories = df_since_competitor_opened[df_since_competitor_opened['category_0_name'].isin(top_categories)] |
| 119 | + ``` |
| 120 | +
|
| 121 | + **Plotting**: |
| 122 | + ```python |
| 123 | + ax = line.plot( |
| 124 | + df=df_top_categories, |
| 125 | + value_col="total_price", # Plot 'total_price' values |
| 126 | + group_col="category_0_name", # Separate lines for each category |
| 127 | + x_col="months_since_competitor_opened", # Use 'months_since_competitor_opened' as the x-axis |
| 128 | + title="Total Price for Top 3 Categories (Before and After Competitor Opened)", # Title of the plot |
| 129 | + x_label="Months Since Competitor Opened", # X-axis label |
| 130 | + y_label="Total Price", # Y-axis label |
| 131 | + legend_title="Category" # Legend title |
| 132 | + ) |
| 133 | +
|
| 134 | + plt.show() |
| 135 | + ``` |
| 136 | +
|
| 137 | + **Use Case**: Use this when you want to analyze the behavior of specific top categories before and after |
| 138 | + an event, such as the opening of a competitor store. |
| 139 | +
|
| 140 | +--- |
| 141 | +
|
| 142 | +### Customization Options: |
| 143 | +- **`value_col`**: The column or list of columns to plot (e.g., `'total_price'`). |
| 144 | +- **`group_col`**: A column whose unique values will be used to create separate lines (e.g., `'category_0_name'`). |
| 145 | +- **`x_col`**: The column to use as the x-axis (e.g., `'months_since_competitor_opened'`). **Warning**: If a datetime |
| 146 | + or datetime-like column is passed, a warning will suggest using the `time_plot` module instead. |
| 147 | +- **`title`**, **`x_label`**, **`y_label`**: Custom text for the plot title and axis labels. |
| 148 | +- **`legend_title`**: Custom title for the legend based on `group_col`. |
| 149 | +- **`move_legend_outside`**: Boolean flag to move the legend outside the plot. |
| 150 | +
|
| 151 | +--- |
| 152 | +
|
| 153 | +### Dependencies: |
| 154 | +- `pandas`: For DataFrame manipulation and grouping. |
| 155 | +- `matplotlib`: For generating plots. |
| 156 | +- `pyretailscience.style.graph_utils`: For applying consistent graph styles across the plots. |
| 157 | +
|
| 158 | +""" |
| 159 | + |
| 160 | +import logging |
| 161 | + |
| 162 | +import pandas as pd |
| 163 | +from matplotlib.axes import Axes, SubplotBase |
| 164 | + |
| 165 | +import pyretailscience.style.graph_utils as gu |
| 166 | +from pyretailscience.style.graph_utils import GraphStyles |
| 167 | +from pyretailscience.style.tailwind import COLORS |
| 168 | + |
| 169 | +logging.basicConfig(format="%(message)s", level=logging.INFO) |
| 170 | + |
| 171 | + |
| 172 | +def _check_datetime_column(df: pd.DataFrame, x_col: str) -> None: |
| 173 | + """Checks if the x_col is a datetime or convertible to datetime. |
| 174 | +
|
| 175 | + Issues a warning if the column is datetime-like, recommending |
| 176 | + the use of a time-based plot. |
| 177 | +
|
| 178 | + Args: |
| 179 | + df (pd.DataFrame): The dataframe containing the column to check. |
| 180 | + x_col (str): The column to check for datetime-like values. |
| 181 | + """ |
| 182 | + if pd.api.types.is_datetime64_any_dtype(df[x_col]): |
| 183 | + logging.warning( |
| 184 | + "The column '%s' is a datetime column. Consider using the 'time_plot' function for time-based plots.", |
| 185 | + x_col, |
| 186 | + ) |
| 187 | + else: |
| 188 | + try: |
| 189 | + pd.to_datetime(df[x_col]) |
| 190 | + logging.warning( |
| 191 | + "The column '%s' can be converted to datetime. Consider using the 'time_plot' module for time-based plots.", |
| 192 | + x_col, |
| 193 | + ) |
| 194 | + except (ValueError, TypeError): |
| 195 | + pass |
| 196 | + |
| 197 | + |
| 198 | +def plot( |
| 199 | + df: pd.DataFrame, |
| 200 | + value_col: str | list[str], |
| 201 | + group_col: str | None = None, |
| 202 | + x_col: str | None = None, |
| 203 | + title: str | None = None, |
| 204 | + x_label: str | None = None, |
| 205 | + y_label: str | None = None, |
| 206 | + legend_title: str | None = None, |
| 207 | + ax: Axes | None = None, |
| 208 | + source_text: str | None = None, |
| 209 | + move_legend_outside: bool = False, |
| 210 | + **kwargs: dict[str, any], |
| 211 | +) -> SubplotBase: |
| 212 | + """Plots the `value_col` over the specified `x_col` or index, creating a separate line for each unique value in `group_col`. |
| 213 | +
|
| 214 | + Args: |
| 215 | + df (pd.DataFrame): The dataframe to plot. |
| 216 | + value_col (str or list of str): The column(s) to plot. |
| 217 | + group_col (str, optional): The column used to define different lines. |
| 218 | + x_col (str, optional): The column to be used as the x-axis. If None, the index is used. |
| 219 | + title (str, optional): The title of the plot. |
| 220 | + x_label (str, optional): The x-axis label. |
| 221 | + y_label (str, optional): The y-axis label. |
| 222 | + legend_title (str, optional): The title of the legend. |
| 223 | + ax (Axes, optional): Matplotlib axes object to plot on. |
| 224 | + source_text (str, optional): The source text to add to the plot. |
| 225 | + move_legend_outside (bool, optional): Move the legend outside the plot. |
| 226 | + **kwargs: Additional keyword arguments for Pandas' `plot` function. |
| 227 | +
|
| 228 | + Returns: |
| 229 | + SubplotBase: The matplotlib axes object. |
| 230 | + """ |
| 231 | + if x_col is not None: |
| 232 | + _check_datetime_column(df, x_col) |
| 233 | + |
| 234 | + if group_col is not None: |
| 235 | + unique_groups = df[group_col].unique() |
| 236 | + colors = [ |
| 237 | + COLORS[color][shade] |
| 238 | + for shade in [500, 300, 700] |
| 239 | + for color in ["green", "orange", "red", "blue", "yellow", "violet", "pink"] |
| 240 | + ][: len(unique_groups)] |
| 241 | + |
| 242 | + ax = None |
| 243 | + for i, group in enumerate(unique_groups): |
| 244 | + group_df = df[df[group_col] == group] |
| 245 | + |
| 246 | + if x_col is not None: |
| 247 | + ax = group_df.plot( |
| 248 | + x=x_col, |
| 249 | + y=value_col, |
| 250 | + ax=ax, |
| 251 | + linewidth=3, |
| 252 | + color=colors[i], |
| 253 | + label=group, |
| 254 | + legend=True, |
| 255 | + **kwargs, |
| 256 | + ) |
| 257 | + else: |
| 258 | + ax = group_df.plot( |
| 259 | + y=value_col, |
| 260 | + ax=ax, |
| 261 | + linewidth=3, |
| 262 | + color=colors[i], |
| 263 | + label=group, |
| 264 | + legend=True, |
| 265 | + **kwargs, |
| 266 | + ) |
| 267 | + else: |
| 268 | + colors = COLORS["green"][500] |
| 269 | + if x_col is not None: |
| 270 | + ax = df.plot( |
| 271 | + x=x_col, |
| 272 | + y=value_col, |
| 273 | + linewidth=3, |
| 274 | + color=colors, |
| 275 | + legend=False, |
| 276 | + ax=ax, |
| 277 | + **kwargs, |
| 278 | + ) |
| 279 | + else: |
| 280 | + ax = df.plot( |
| 281 | + y=value_col, |
| 282 | + linewidth=3, |
| 283 | + color=colors, |
| 284 | + legend=False, |
| 285 | + ax=ax, |
| 286 | + **kwargs, |
| 287 | + ) |
| 288 | + |
| 289 | + ax = gu.standard_graph_styles( |
| 290 | + ax, |
| 291 | + title=title if title else f"{value_col.title()} Over {x_col.title()}" if x_col else f"{value_col.title()} Over Index", |
| 292 | + x_label=x_label if x_label else (x_col.title() if x_col else "Index"), |
| 293 | + y_label=y_label or value_col.title(), |
| 294 | + ) |
| 295 | + |
| 296 | + if move_legend_outside: |
| 297 | + ax.legend(bbox_to_anchor=(1.05, 1)) |
| 298 | + |
| 299 | + if legend_title is not None: |
| 300 | + ax.legend(title=legend_title) |
| 301 | + |
| 302 | + if source_text: |
| 303 | + ax.annotate( |
| 304 | + source_text, |
| 305 | + xy=(-0.1, -0.2), |
| 306 | + xycoords="axes fraction", |
| 307 | + ha="left", |
| 308 | + va="center", |
| 309 | + fontsize=GraphStyles.DEFAULT_SOURCE_FONT_SIZE, |
| 310 | + fontproperties=GraphStyles.POPPINS_LIGHT_ITALIC, |
| 311 | + color="dimgray", |
| 312 | + ) |
| 313 | + |
| 314 | + for tick in ax.get_xticklabels() + ax.get_yticklabels(): |
| 315 | + tick.set_fontproperties(GraphStyles.POPPINS_REG) |
| 316 | + |
| 317 | + return ax |
0 commit comments