Skip to content

Commit 9a0d672

Browse files
authored
fix: gain loss calc error with negative values (#56)
1 parent 1a55803 commit 9a0d672

File tree

2 files changed

+159
-173
lines changed

2 files changed

+159
-173
lines changed

pyretailscience/gain_loss.py

Lines changed: 99 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
1+
"""This module performs gain loss analysis (switching analysis) on a DataFrame to assess customer movement between brands or products over time.
2+
3+
Gain loss analysis, also known as switching analysis, is a marketing analytics technique used to
4+
assess customer movement between brands or products over time. It helps businesses understand the dynamics of customer
5+
acquisition and churn. Here's a concise definition: Gain loss analysis examines the flow of customers to and from a
6+
brand or product, quantifying:
7+
8+
1. Gains: New customers acquired from competitors
9+
2. Losses: Existing customers lost to competitors
10+
3. Net change: The overall impact on market share
11+
12+
This analysis helps marketers:
13+
14+
- Identify trends in customer behavior
15+
- Evaluate the effectiveness of marketing strategies
16+
- Understand competitive dynamics in the market
17+
"""
18+
119
import pandas as pd
220
from matplotlib.axes import Axes, SubplotBase
321

422
import pyretailscience.style.graph_utils as gu
523
from pyretailscience.data.contracts import CustomContract, build_expected_columns, build_non_null_columns
6-
from pyretailscience.style.graph_utils import GraphStyles as gs
24+
from pyretailscience.style.graph_utils import GraphStyles
725
from pyretailscience.style.tailwind import COLORS
826

9-
# TODO: Consider simplifying this by reducing the color range in the get_linear_cmap function.
10-
COLORMAP_MIN = 0.25
11-
COLORMAP_MAX = 0.75
12-
1327

1428
class GainLoss:
29+
"""A class to perform gain loss analysis on a DataFrame to assess customer movement between brands or products over time."""
30+
1531
def __init__(
1632
self,
1733
df: pd.DataFrame,
@@ -24,7 +40,7 @@ def __init__(
2440
group_col: str | None = None,
2541
value_col: str = "total_price",
2642
agg_func: str = "sum",
27-
):
43+
) -> None:
2844
"""Calculate the gain loss table for a given DataFrame at the customer level.
2945
3046
Args:
@@ -48,7 +64,7 @@ def __init__(
4864

4965
if not len(p1_index) == len(p2_index) == len(focus_group_index) == len(comparison_group_index):
5066
raise ValueError(
51-
"p1_index, p2_index, focus_group_index, and comparison_group_index should have the same length"
67+
"p1_index, p2_index, focus_group_index, and comparison_group_index should have the same length",
5268
)
5369

5470
required_cols = ["customer_id", value_col] + ([group_col] if group_col is not None else [])
@@ -58,7 +74,8 @@ def __init__(
5874
extended_expectations=build_non_null_columns(columns=required_cols),
5975
)
6076
if contract.validate() is False:
61-
raise ValueError(f"The dataframe requires the columns {required_cols} and they must be non-null")
77+
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
78+
raise ValueError(msg)
6279

6380
self.focus_group_name = focus_group_name
6481
self.comparison_group_name = comparison_group_name
@@ -80,6 +97,49 @@ def __init__(
8097
group_col=group_col,
8198
)
8299

100+
@staticmethod
101+
def process_customer_group(
102+
focus_p1: float,
103+
comparison_p1: float,
104+
focus_p2: float,
105+
comparison_p2: float,
106+
focus_diff: float,
107+
comparison_diff: float,
108+
) -> tuple[float, float, float, float, float, float]:
109+
"""Process the gain loss for a customer group.
110+
111+
Args:
112+
focus_p1 (float | int): The focus group total in the first time period.
113+
comparison_p1 (float | int): The comparison group total in the first time period.
114+
focus_p2 (float | int): The focus group total in the second time period.
115+
comparison_p2 (float | int): The comparison group total in the second time period.
116+
focus_diff (float | int): The difference in the focus group totals.
117+
comparison_diff (float | int): The difference in the comparison group totals.
118+
119+
Returns:
120+
tuple[float, float, float, float, float, float]: The gain loss for the customer group.
121+
"""
122+
if focus_p1 == 0 and comparison_p1 == 0:
123+
return focus_p2, 0, 0, 0, 0, 0
124+
if focus_p2 == 0 and comparison_p2 == 0:
125+
return 0, -1 * focus_p1, 0, 0, 0, 0
126+
127+
if focus_diff > 0:
128+
focus_inc_dec = focus_diff if comparison_diff > 0 else max(0, comparison_diff + focus_diff)
129+
elif comparison_diff < 0:
130+
focus_inc_dec = focus_diff
131+
else:
132+
focus_inc_dec = min(0, comparison_diff + focus_diff)
133+
134+
increased_focus = max(0, focus_inc_dec)
135+
decreased_focus = min(0, focus_inc_dec)
136+
137+
transfer = focus_diff - focus_inc_dec
138+
switch_from_comparison = max(0, transfer)
139+
switch_to_comparison = min(0, transfer)
140+
141+
return 0, 0, increased_focus, decreased_focus, switch_from_comparison, switch_to_comparison
142+
83143
@staticmethod
84144
def _calc_gain_loss(
85145
df: pd.DataFrame,
@@ -91,8 +151,7 @@ def _calc_gain_loss(
91151
value_col: str = "total_price",
92152
agg_func: str = "sum",
93153
) -> pd.DataFrame:
94-
"""
95-
Calculate the gain loss table for a given DataFrame at the customer level.
154+
"""Calculate the gain loss table for a given DataFrame at the customer level.
96155
97156
Args:
98157
df (pd.DataFrame): The DataFrame to calculate the gain loss table from.
@@ -102,6 +161,7 @@ def _calc_gain_loss(
102161
comparison_group_index (list[bool]): The index for the comparison group.
103162
group_col (str | None, optional): The column to group by. Defaults to None.
104163
value_col (str, optional): The column to calculate the gain loss from. Defaults to "total_price".
164+
agg_func (str, optional): The aggregation function to use. Defaults to "sum".
105165
106166
Returns:
107167
pd.DataFrame: The gain loss table.
@@ -144,14 +204,27 @@ def _calc_gain_loss(
144204
gl_df["comparison_diff"] = gl_df["comparison_p2"] - gl_df["comparison_p1"]
145205
gl_df["total_diff"] = gl_df["total_p2"] - gl_df["total_p1"]
146206

147-
gl_df["switch_from_comparison"] = (gl_df["focus_diff"] - gl_df["total_diff"]).apply(lambda x: max(x, 0))
148-
gl_df["switch_to_comparison"] = (gl_df["focus_diff"] - gl_df["total_diff"]).apply(lambda x: min(x, 0))
149-
150-
gl_df["new"] = gl_df.apply(lambda x: max(x["total_diff"], 0) if x["focus_p1"] == 0 else 0, axis=1)
151-
gl_df["lost"] = gl_df.apply(lambda x: min(x["total_diff"], 0) if x["focus_p2"] == 0 else 0, axis=1)
152-
153-
gl_df["increased_focus"] = gl_df.apply(lambda x: max(x["total_diff"], 0) if x["focus_p1"] != 0 else 0, axis=1)
154-
gl_df["decreased_focus"] = gl_df.apply(lambda x: min(x["total_diff"], 0) if x["focus_p2"] != 0 else 0, axis=1)
207+
(
208+
gl_df["new"],
209+
gl_df["lost"],
210+
gl_df["increased_focus"],
211+
gl_df["decreased_focus"],
212+
gl_df["switch_from_comparison"],
213+
gl_df["switch_to_comparison"],
214+
) = zip(
215+
*gl_df.apply(
216+
lambda x: GainLoss.process_customer_group(
217+
focus_p1=x["focus_p1"],
218+
comparison_p1=x["comparison_p1"],
219+
focus_p2=x["focus_p2"],
220+
comparison_p2=x["comparison_p2"],
221+
focus_diff=x["focus_diff"],
222+
comparison_diff=x["comparison_diff"],
223+
),
224+
axis=1,
225+
),
226+
strict=False,
227+
)
155228

156229
return gl_df
157230

@@ -171,8 +244,8 @@ def _calc_gains_loss_table(
171244
"""
172245
if group_col is None:
173246
return gain_loss_df.sum().to_frame("").T
174-
else:
175-
return gain_loss_df.groupby(level=0).sum()
247+
248+
return gain_loss_df.groupby(level=0).sum()
176249

177250
def plot(
178251
self,
@@ -193,6 +266,7 @@ def plot(
193266
ax (Axes | None, optional): The axes to plot on. Defaults to None.
194267
source_text (str | None, optional): The source text to add to the plot. Defaults to None.
195268
move_legend_outside (bool, optional): Whether to move the legend outside the plot. Defaults to False.
269+
kwargs (dict[str, Any]): Additional keyword arguments to pass to the plot.
196270
197271
Returns:
198272
SubplotBase: The plot
@@ -217,7 +291,7 @@ def plot(
217291
if move_legend_outside:
218292
legend_bbox_to_anchor = (1.05, 1)
219293

220-
# TODO: Ensure that each label ctually has data before adding to the legend
294+
# TODO: Ensure that each label actually has data before adding to the legend
221295
legend = ax.legend(
222296
[
223297
"New",
@@ -252,15 +326,15 @@ def plot(
252326
xycoords="axes fraction",
253327
ha="left",
254328
va="center",
255-
fontsize=gs.DEFAULT_SOURCE_FONT_SIZE,
256-
fontproperties=gs.POPPINS_LIGHT_ITALIC,
329+
fontsize=GraphStyles.DEFAULT_SOURCE_FONT_SIZE,
330+
fontproperties=GraphStyles.POPPINS_LIGHT_ITALIC,
257331
color="dimgray",
258332
)
259333

260334
# Set the font properties for the tick labels
261335
for tick in ax.get_xticklabels():
262-
tick.set_fontproperties(gs.POPPINS_REG)
336+
tick.set_fontproperties(GraphStyles.POPPINS_REG)
263337
for tick in ax.get_yticklabels():
264-
tick.set_fontproperties(gs.POPPINS_REG)
338+
tick.set_fontproperties(GraphStyles.POPPINS_REG)
265339

266340
return ax

0 commit comments

Comments
 (0)