1+ """This module performs gain loss analysis (switching analysis) on a DataFrame to assess customer movement between brands or products over time.
2+
3+ Gain loss analysis, also known as switching analysis, is a marketing analytics technique used to
4+ assess customer movement between brands or products over time. It helps businesses understand the dynamics of customer
5+ acquisition and churn. Specifically, gain loss analysis examines the flow of customers to and from a
6+ brand or product, quantifying:
7+
8+ 1. Gains: New customers acquired from competitors
9+ 2. Losses: Existing customers lost to competitors
10+ 3. Net change: The overall impact on market share
11+
12+ This analysis helps marketers:
13+
14+ - Identify trends in customer behavior
15+ - Evaluate the effectiveness of marketing strategies
16+ - Understand competitive dynamics in the market
17+ """
18+
119import pandas as pd
220from matplotlib .axes import Axes , SubplotBase
321
422import pyretailscience .style .graph_utils as gu
523from pyretailscience .data .contracts import CustomContract , build_expected_columns , build_non_null_columns
6- from pyretailscience .style .graph_utils import GraphStyles as gs
24+ from pyretailscience .style .graph_utils import GraphStyles
725from pyretailscience .style .tailwind import COLORS
826
9- # TODO: Consider simplifying this by reducing the color range in the get_linear_cmap function.
10- COLORMAP_MIN = 0.25
11- COLORMAP_MAX = 0.75
12-
1327
1428class GainLoss :
29+ """A class to perform gain loss analysis on a DataFrame to assess customer movement between brands or products over time."""
30+
1531 def __init__ (
1632 self ,
1733 df : pd .DataFrame ,
@@ -24,7 +40,7 @@ def __init__(
2440 group_col : str | None = None ,
2541 value_col : str = "total_price" ,
2642 agg_func : str = "sum" ,
27- ):
43+ ) -> None :
2844 """Calculate the gain loss table for a given DataFrame at the customer level.
2945
3046 Args:
@@ -48,7 +64,7 @@ def __init__(
4864
4965 if not len (p1_index ) == len (p2_index ) == len (focus_group_index ) == len (comparison_group_index ):
5066 raise ValueError (
51- "p1_index, p2_index, focus_group_index, and comparison_group_index should have the same length"
67+ "p1_index, p2_index, focus_group_index, and comparison_group_index should have the same length" ,
5268 )
5369
5470 required_cols = ["customer_id" , value_col ] + ([group_col ] if group_col is not None else [])
@@ -58,7 +74,8 @@ def __init__(
5874 extended_expectations = build_non_null_columns (columns = required_cols ),
5975 )
6076 if contract .validate () is False :
61- raise ValueError (f"The dataframe requires the columns { required_cols } and they must be non-null" )
77+ msg = f"The dataframe requires the columns { required_cols } and they must be non-null"
78+ raise ValueError (msg )
6279
6380 self .focus_group_name = focus_group_name
6481 self .comparison_group_name = comparison_group_name
@@ -80,6 +97,49 @@ def __init__(
8097 group_col = group_col ,
8198 )
8299
100+ @staticmethod
101+ def process_customer_group (
102+ focus_p1 : float ,
103+ comparison_p1 : float ,
104+ focus_p2 : float ,
105+ comparison_p2 : float ,
106+ focus_diff : float ,
107+ comparison_diff : float ,
108+ ) -> tuple [float , float , float , float , float , float ]:
109+ """Process the gain loss for a customer group.
110+
111+ Args:
112+ focus_p1 (float | int): The focus group total in the first time period.
113+ comparison_p1 (float | int): The comparison group total in the first time period.
114+ focus_p2 (float | int): The focus group total in the second time period.
115+ comparison_p2 (float | int): The comparison group total in the second time period.
116+ focus_diff (float | int): The difference in the focus group totals.
117+ comparison_diff (float | int): The difference in the comparison group totals.
118+
119+ Returns:
120+ tuple[float, float, float, float, float, float]: The gain loss for the customer group.
121+ """
122+ if focus_p1 == 0 and comparison_p1 == 0 :
123+ return focus_p2 , 0 , 0 , 0 , 0 , 0
124+ if focus_p2 == 0 and comparison_p2 == 0 :
125+ return 0 , - 1 * focus_p1 , 0 , 0 , 0 , 0
126+
127+ if focus_diff > 0 :
128+ focus_inc_dec = focus_diff if comparison_diff > 0 else max (0 , comparison_diff + focus_diff )
129+ elif comparison_diff < 0 :
130+ focus_inc_dec = focus_diff
131+ else :
132+ focus_inc_dec = min (0 , comparison_diff + focus_diff )
133+
134+ increased_focus = max (0 , focus_inc_dec )
135+ decreased_focus = min (0 , focus_inc_dec )
136+
137+ transfer = focus_diff - focus_inc_dec
138+ switch_from_comparison = max (0 , transfer )
139+ switch_to_comparison = min (0 , transfer )
140+
141+ return 0 , 0 , increased_focus , decreased_focus , switch_from_comparison , switch_to_comparison
142+
83143 @staticmethod
84144 def _calc_gain_loss (
85145 df : pd .DataFrame ,
@@ -91,8 +151,7 @@ def _calc_gain_loss(
91151 value_col : str = "total_price" ,
92152 agg_func : str = "sum" ,
93153 ) -> pd .DataFrame :
94- """
95- Calculate the gain loss table for a given DataFrame at the customer level.
154+ """Calculate the gain loss table for a given DataFrame at the customer level.
96155
97156 Args:
98157 df (pd.DataFrame): The DataFrame to calculate the gain loss table from.
@@ -102,6 +161,7 @@ def _calc_gain_loss(
102161 comparison_group_index (list[bool]): The index for the comparison group.
103162 group_col (str | None, optional): The column to group by. Defaults to None.
104163 value_col (str, optional): The column to calculate the gain loss from. Defaults to "total_price".
164+ agg_func (str, optional): The aggregation function to use. Defaults to "sum".
105165
106166 Returns:
107167 pd.DataFrame: The gain loss table.
@@ -144,14 +204,27 @@ def _calc_gain_loss(
144204 gl_df ["comparison_diff" ] = gl_df ["comparison_p2" ] - gl_df ["comparison_p1" ]
145205 gl_df ["total_diff" ] = gl_df ["total_p2" ] - gl_df ["total_p1" ]
146206
147- gl_df ["switch_from_comparison" ] = (gl_df ["focus_diff" ] - gl_df ["total_diff" ]).apply (lambda x : max (x , 0 ))
148- gl_df ["switch_to_comparison" ] = (gl_df ["focus_diff" ] - gl_df ["total_diff" ]).apply (lambda x : min (x , 0 ))
149-
150- gl_df ["new" ] = gl_df .apply (lambda x : max (x ["total_diff" ], 0 ) if x ["focus_p1" ] == 0 else 0 , axis = 1 )
151- gl_df ["lost" ] = gl_df .apply (lambda x : min (x ["total_diff" ], 0 ) if x ["focus_p2" ] == 0 else 0 , axis = 1 )
152-
153- gl_df ["increased_focus" ] = gl_df .apply (lambda x : max (x ["total_diff" ], 0 ) if x ["focus_p1" ] != 0 else 0 , axis = 1 )
154- gl_df ["decreased_focus" ] = gl_df .apply (lambda x : min (x ["total_diff" ], 0 ) if x ["focus_p2" ] != 0 else 0 , axis = 1 )
207+ (
208+ gl_df ["new" ],
209+ gl_df ["lost" ],
210+ gl_df ["increased_focus" ],
211+ gl_df ["decreased_focus" ],
212+ gl_df ["switch_from_comparison" ],
213+ gl_df ["switch_to_comparison" ],
214+ ) = zip (
215+ * gl_df .apply (
216+ lambda x : GainLoss .process_customer_group (
217+ focus_p1 = x ["focus_p1" ],
218+ comparison_p1 = x ["comparison_p1" ],
219+ focus_p2 = x ["focus_p2" ],
220+ comparison_p2 = x ["comparison_p2" ],
221+ focus_diff = x ["focus_diff" ],
222+ comparison_diff = x ["comparison_diff" ],
223+ ),
224+ axis = 1 ,
225+ ),
226+ strict = False ,
227+ )
155228
156229 return gl_df
157230
@@ -171,8 +244,8 @@ def _calc_gains_loss_table(
171244 """
172245 if group_col is None :
173246 return gain_loss_df .sum ().to_frame ("" ).T
174- else :
175- return gain_loss_df .groupby (level = 0 ).sum ()
247+
248+ return gain_loss_df .groupby (level = 0 ).sum ()
176249
177250 def plot (
178251 self ,
@@ -193,6 +266,7 @@ def plot(
193266 ax (Axes | None, optional): The axes to plot on. Defaults to None.
194267 source_text (str | None, optional): The source text to add to the plot. Defaults to None.
195268 move_legend_outside (bool, optional): Whether to move the legend outside the plot. Defaults to False.
269+ kwargs (dict[str, Any]): Additional keyword arguments to pass to the plot.
196270
197271 Returns:
198272 SubplotBase: The plot
@@ -217,7 +291,7 @@ def plot(
217291 if move_legend_outside :
218292 legend_bbox_to_anchor = (1.05 , 1 )
219293
220- # TODO: Ensure that each label ctually has data before adding to the legend
294+ # TODO: Ensure that each label actually has data before adding to the legend
221295 legend = ax .legend (
222296 [
223297 "New" ,
@@ -252,15 +326,15 @@ def plot(
252326 xycoords = "axes fraction" ,
253327 ha = "left" ,
254328 va = "center" ,
255- fontsize = gs .DEFAULT_SOURCE_FONT_SIZE ,
256- fontproperties = gs .POPPINS_LIGHT_ITALIC ,
329+ fontsize = GraphStyles .DEFAULT_SOURCE_FONT_SIZE ,
330+ fontproperties = GraphStyles .POPPINS_LIGHT_ITALIC ,
257331 color = "dimgray" ,
258332 )
259333
260334 # Set the font properties for the tick labels
261335 for tick in ax .get_xticklabels ():
262- tick .set_fontproperties (gs .POPPINS_REG )
336+ tick .set_fontproperties (GraphStyles .POPPINS_REG )
263337 for tick in ax .get_yticklabels ():
264- tick .set_fontproperties (gs .POPPINS_REG )
338+ tick .set_fontproperties (GraphStyles .POPPINS_REG )
265339
266340 return ax
0 commit comments