1
+ """This module performs gain loss analysis (switching analysis) on a DataFrame to assess customer movement between brands or products over time.
2
+
3
+ Gain loss analysis, also known as switching analysis, is a marketing analytics technique used to
4
+ assess customer movement between brands or products over time. It helps businesses understand the dynamics of customer
5
+ acquisition and churn. Gain loss analysis examines the flow of customers to and from a
6
+ brand or product, quantifying:
7
+
8
+ 1. Gains: New customers acquired from competitors
9
+ 2. Losses: Existing customers lost to competitors
10
+ 3. Net change: The overall impact on market share
11
+
12
+ This analysis helps marketers:
13
+
14
+ - Identify trends in customer behavior
15
+ - Evaluate the effectiveness of marketing strategies
16
+ - Understand competitive dynamics in the market
17
+ """
18
+
1
19
import pandas as pd
2
20
from matplotlib .axes import Axes , SubplotBase
3
21
4
22
import pyretailscience .style .graph_utils as gu
5
23
from pyretailscience .data .contracts import CustomContract , build_expected_columns , build_non_null_columns
6
- from pyretailscience .style .graph_utils import GraphStyles as gs
24
+ from pyretailscience .style .graph_utils import GraphStyles
7
25
from pyretailscience .style .tailwind import COLORS
8
26
9
- # TODO: Consider simplifying this by reducing the color range in the get_linear_cmap function.
10
- COLORMAP_MIN = 0.25
11
- COLORMAP_MAX = 0.75
12
-
13
27
14
28
class GainLoss :
29
+ """A class to perform gain loss analysis on a DataFrame to assess customer movement between brands or products over time."""
30
+
15
31
def __init__ (
16
32
self ,
17
33
df : pd .DataFrame ,
@@ -24,7 +40,7 @@ def __init__(
24
40
group_col : str | None = None ,
25
41
value_col : str = "total_price" ,
26
42
agg_func : str = "sum" ,
27
- ):
43
+ ) -> None :
28
44
"""Calculate the gain loss table for a given DataFrame at the customer level.
29
45
30
46
Args:
@@ -48,7 +64,7 @@ def __init__(
48
64
49
65
if not len (p1_index ) == len (p2_index ) == len (focus_group_index ) == len (comparison_group_index ):
50
66
raise ValueError (
51
- "p1_index, p2_index, focus_group_index, and comparison_group_index should have the same length"
67
+ "p1_index, p2_index, focus_group_index, and comparison_group_index should have the same length" ,
52
68
)
53
69
54
70
required_cols = ["customer_id" , value_col ] + ([group_col ] if group_col is not None else [])
@@ -58,7 +74,8 @@ def __init__(
58
74
extended_expectations = build_non_null_columns (columns = required_cols ),
59
75
)
60
76
if contract .validate () is False :
61
- raise ValueError (f"The dataframe requires the columns { required_cols } and they must be non-null" )
77
+ msg = f"The dataframe requires the columns { required_cols } and they must be non-null"
78
+ raise ValueError (msg )
62
79
63
80
self .focus_group_name = focus_group_name
64
81
self .comparison_group_name = comparison_group_name
@@ -80,6 +97,49 @@ def __init__(
80
97
group_col = group_col ,
81
98
)
82
99
100
+ @staticmethod
101
+ def process_customer_group (
102
+ focus_p1 : float ,
103
+ comparison_p1 : float ,
104
+ focus_p2 : float ,
105
+ comparison_p2 : float ,
106
+ focus_diff : float ,
107
+ comparison_diff : float ,
108
+ ) -> tuple [float , float , float , float , float , float ]:
109
+ """Process the gain loss for a customer group.
110
+
111
+ Args:
112
+ focus_p1 (float | int): The focus group total in the first time period.
113
+ comparison_p1 (float | int): The comparison group total in the first time period.
114
+ focus_p2 (float | int): The focus group total in the second time period.
115
+ comparison_p2 (float | int): The comparison group total in the second time period.
116
+ focus_diff (float | int): The difference in the focus group totals.
117
+ comparison_diff (float | int): The difference in the comparison group totals.
118
+
119
+ Returns:
120
+ tuple[float, float, float, float, float, float]: The gain loss for the customer group.
121
+ """
122
+ if focus_p1 == 0 and comparison_p1 == 0 :
123
+ return focus_p2 , 0 , 0 , 0 , 0 , 0
124
+ if focus_p2 == 0 and comparison_p2 == 0 :
125
+ return 0 , - 1 * focus_p1 , 0 , 0 , 0 , 0
126
+
127
+ if focus_diff > 0 :
128
+ focus_inc_dec = focus_diff if comparison_diff > 0 else max (0 , comparison_diff + focus_diff )
129
+ elif comparison_diff < 0 :
130
+ focus_inc_dec = focus_diff
131
+ else :
132
+ focus_inc_dec = min (0 , comparison_diff + focus_diff )
133
+
134
+ increased_focus = max (0 , focus_inc_dec )
135
+ decreased_focus = min (0 , focus_inc_dec )
136
+
137
+ transfer = focus_diff - focus_inc_dec
138
+ switch_from_comparison = max (0 , transfer )
139
+ switch_to_comparison = min (0 , transfer )
140
+
141
+ return 0 , 0 , increased_focus , decreased_focus , switch_from_comparison , switch_to_comparison
142
+
83
143
@staticmethod
84
144
def _calc_gain_loss (
85
145
df : pd .DataFrame ,
@@ -91,8 +151,7 @@ def _calc_gain_loss(
91
151
value_col : str = "total_price" ,
92
152
agg_func : str = "sum" ,
93
153
) -> pd .DataFrame :
94
- """
95
- Calculate the gain loss table for a given DataFrame at the customer level.
154
+ """Calculate the gain loss table for a given DataFrame at the customer level.
96
155
97
156
Args:
98
157
df (pd.DataFrame): The DataFrame to calculate the gain loss table from.
@@ -102,6 +161,7 @@ def _calc_gain_loss(
102
161
comparison_group_index (list[bool]): The index for the comparison group.
103
162
group_col (str | None, optional): The column to group by. Defaults to None.
104
163
value_col (str, optional): The column to calculate the gain loss from. Defaults to "total_price".
164
+ agg_func (str, optional): The aggregation function to use. Defaults to "sum".
105
165
106
166
Returns:
107
167
pd.DataFrame: The gain loss table.
@@ -144,14 +204,27 @@ def _calc_gain_loss(
144
204
gl_df ["comparison_diff" ] = gl_df ["comparison_p2" ] - gl_df ["comparison_p1" ]
145
205
gl_df ["total_diff" ] = gl_df ["total_p2" ] - gl_df ["total_p1" ]
146
206
147
- gl_df ["switch_from_comparison" ] = (gl_df ["focus_diff" ] - gl_df ["total_diff" ]).apply (lambda x : max (x , 0 ))
148
- gl_df ["switch_to_comparison" ] = (gl_df ["focus_diff" ] - gl_df ["total_diff" ]).apply (lambda x : min (x , 0 ))
149
-
150
- gl_df ["new" ] = gl_df .apply (lambda x : max (x ["total_diff" ], 0 ) if x ["focus_p1" ] == 0 else 0 , axis = 1 )
151
- gl_df ["lost" ] = gl_df .apply (lambda x : min (x ["total_diff" ], 0 ) if x ["focus_p2" ] == 0 else 0 , axis = 1 )
152
-
153
- gl_df ["increased_focus" ] = gl_df .apply (lambda x : max (x ["total_diff" ], 0 ) if x ["focus_p1" ] != 0 else 0 , axis = 1 )
154
- gl_df ["decreased_focus" ] = gl_df .apply (lambda x : min (x ["total_diff" ], 0 ) if x ["focus_p2" ] != 0 else 0 , axis = 1 )
207
+ (
208
+ gl_df ["new" ],
209
+ gl_df ["lost" ],
210
+ gl_df ["increased_focus" ],
211
+ gl_df ["decreased_focus" ],
212
+ gl_df ["switch_from_comparison" ],
213
+ gl_df ["switch_to_comparison" ],
214
+ ) = zip (
215
+ * gl_df .apply (
216
+ lambda x : GainLoss .process_customer_group (
217
+ focus_p1 = x ["focus_p1" ],
218
+ comparison_p1 = x ["comparison_p1" ],
219
+ focus_p2 = x ["focus_p2" ],
220
+ comparison_p2 = x ["comparison_p2" ],
221
+ focus_diff = x ["focus_diff" ],
222
+ comparison_diff = x ["comparison_diff" ],
223
+ ),
224
+ axis = 1 ,
225
+ ),
226
+ strict = False ,
227
+ )
155
228
156
229
return gl_df
157
230
@@ -171,8 +244,8 @@ def _calc_gains_loss_table(
171
244
"""
172
245
if group_col is None :
173
246
return gain_loss_df .sum ().to_frame ("" ).T
174
- else :
175
- return gain_loss_df .groupby (level = 0 ).sum ()
247
+
248
+ return gain_loss_df .groupby (level = 0 ).sum ()
176
249
177
250
def plot (
178
251
self ,
@@ -193,6 +266,7 @@ def plot(
193
266
ax (Axes | None, optional): The axes to plot on. Defaults to None.
194
267
source_text (str | None, optional): The source text to add to the plot. Defaults to None.
195
268
move_legend_outside (bool, optional): Whether to move the legend outside the plot. Defaults to False.
269
+ kwargs (dict[str, Any]): Additional keyword arguments to pass to the plot.
196
270
197
271
Returns:
198
272
SubplotBase: The plot
@@ -217,7 +291,7 @@ def plot(
217
291
if move_legend_outside :
218
292
legend_bbox_to_anchor = (1.05 , 1 )
219
293
220
- # TODO: Ensure that each label ctually has data before adding to the legend
294
+ # TODO: Ensure that each label actually has data before adding to the legend
221
295
legend = ax .legend (
222
296
[
223
297
"New" ,
@@ -252,15 +326,15 @@ def plot(
252
326
xycoords = "axes fraction" ,
253
327
ha = "left" ,
254
328
va = "center" ,
255
- fontsize = gs .DEFAULT_SOURCE_FONT_SIZE ,
256
- fontproperties = gs .POPPINS_LIGHT_ITALIC ,
329
+ fontsize = GraphStyles .DEFAULT_SOURCE_FONT_SIZE ,
330
+ fontproperties = GraphStyles .POPPINS_LIGHT_ITALIC ,
257
331
color = "dimgray" ,
258
332
)
259
333
260
334
# Set the font properties for the tick labels
261
335
for tick in ax .get_xticklabels ():
262
- tick .set_fontproperties (gs .POPPINS_REG )
336
+ tick .set_fontproperties (GraphStyles .POPPINS_REG )
263
337
for tick in ax .get_yticklabels ():
264
- tick .set_fontproperties (gs .POPPINS_REG )
338
+ tick .set_fontproperties (GraphStyles .POPPINS_REG )
265
339
266
340
return ax
0 commit comments