Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 44 additions & 26 deletions pyretailscience/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import pyretailscience.style.graph_utils as gu
from pyretailscience.data.contracts import (
CustomContract,
TransactionItemLevelContract,
TransactionLevelContract,
build_expected_columns,
build_expected_unique_columns,
build_non_null_columns,
Expand Down Expand Up @@ -146,34 +144,49 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
TransactionLevelContract.

"""
required_cols = ["customer_id", "total_price", "transaction_id", segment_col]
contract = CustomContract(
df,
basic_expectations=build_expected_columns(columns=required_cols),
extended_expectations=build_non_null_columns(columns=required_cols),
)

if contract.validate() is False:
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
raise ValueError(msg)

self.segment_col = segment_col
if TransactionItemLevelContract(df).validate() is True:
stats_df = df.groupby(segment_col).agg(
revenue=("total_price", "sum"),
transactions=("transaction_id", "nunique"),
customers=("customer_id", "nunique"),
total_quantity=("quantity", "sum"),
)

self.df = self._calc_seg_stats(df, segment_col)

@staticmethod
def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
aggs = {
"revenue": ("total_price", "sum"),
"transactions": ("transaction_id", "nunique"),
"customers": ("customer_id", "nunique"),
}
total_aggs = {
"revenue": [df["total_price"].sum()],
"transactions": [df["transaction_id"].nunique()],
"customers": [df["customer_id"].nunique()],
}
if "quantity" in df.columns:
aggs["total_quantity"] = ("quantity", "sum")
total_aggs["total_quantity"] = [df["quantity"].sum()]

stats_df = pd.concat(
[
df.groupby(segment_col).agg(**aggs),
Comment on lines +178 to +182
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: Use the agg method directly with a dictionary for better readability and maintainability [Enhancement, importance: 5]

Suggested change
total_aggs["total_quantity"] = [df["quantity"].sum()]
stats_df = pd.concat(
[
df.groupby(segment_col).agg(**aggs),
stats_df = pd.concat(
[
df.groupby(segment_col).agg(aggs),
pd.DataFrame(total_aggs, index=["total"]),
],
)

pd.DataFrame(total_aggs, index=["total"]),
],
)

if "quantity" in df.columns:
stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"]
stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"]
elif TransactionLevelContract(df).validate() is True:
stats_df = df.groupby(segment_col).agg(
revenue=("total_price", "sum"),
transactions=("transaction_id", "nunique"),
customers=("customer_id", "nunique"),
)
else:
raise NotImplementedError(
"The dataframe does not comply with the TransactionItemLevelContract or TransactionLevelContract. "
"These are the only two contracts supported at this time.",
)
total_num_customers = df["customer_id"].nunique()
stats_df["spend_per_cust"] = stats_df["revenue"] / stats_df["customers"]
stats_df["spend_per_transaction"] = stats_df["revenue"] / stats_df["transactions"]
stats_df["transactions_per_customer"] = stats_df["transactions"] / stats_df["customers"]
stats_df["customers_pct"] = stats_df["customers"] / total_num_customers

self.df = stats_df
return stats_df

def plot(
self,
Expand All @@ -185,6 +198,7 @@ def plot(
orientation: Literal["vertical", "horizontal"] = "vertical",
sort_order: Literal["ascending", "descending", None] = None,
source_text: str | None = None,
hide_total: bool = True,
**kwargs: dict[str, any],
) -> SubplotBase:
"""Plots the value_col by segment.
Expand All @@ -203,6 +217,7 @@ def plot(
sort_order (Literal["ascending", "descending", None], optional): The sort order of the segments.
Defaults to None. If None, the segments are plotted in the order they appear in the dataframe.
source_text (str, optional): The source text to add to the plot. Defaults to None.
hide_total (bool, optional): Whether to hide the total row. Defaults to True.
**kwargs: Additional keyword arguments to pass to the Pandas plot function.

Returns:
Expand All @@ -223,6 +238,9 @@ def plot(
kind = "barh"

val_s = self.df[value_col]
if hide_total:
val_s = val_s[val_s.index != "total"]

if sort_order is not None:
ascending = sort_order == "ascending"
val_s = val_s.sort_values(ascending=ascending)
Expand Down
101 changes: 101 additions & 0 deletions tests/test_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Tests for the SegTransactionStats class."""

import pandas as pd
import pytest

from pyretailscience.segmentation import SegTransactionStats


class TestCalcSegStats:
"""Tests for the _calc_seg_stats method."""

@pytest.fixture()
def base_df(self):
"""Return a base DataFrame for testing."""
return pd.DataFrame(
{
"customer_id": [1, 2, 3, 4, 5],
"total_price": [100, 200, 150, 300, 250],
"transaction_id": [101, 102, 103, 104, 105],
"segment_id": ["A", "B", "A", "B", "A"],
"quantity": [10, 20, 15, 30, 25],
},
)

def test_correctly_calculates_revenue_transactions_customers_per_segment(self, base_df):
"""Test that the method correctly calculates at the transaction-item level."""
expected_output = pd.DataFrame(
{
"revenue": [500, 500, 1000],
"transactions": [3, 2, 5],
"customers": [3, 2, 5],
"total_quantity": [50, 50, 100],
"price_per_unit": [10.0, 10.0, 10.0],
"quantity_per_transaction": [16.666667, 25.0, 20.0],
},
index=["A", "B", "total"],
)

segment_stats = SegTransactionStats._calc_seg_stats(base_df, "segment_id")
pd.testing.assert_frame_equal(segment_stats, expected_output)

def test_correctly_calculates_revenue_transactions_customers(self):
"""Test that the method correctly calculates at the transaction level."""
df = pd.DataFrame(
{
"customer_id": [1, 2, 3, 4, 5],
"total_price": [100, 200, 150, 300, 250],
"transaction_id": [101, 102, 103, 104, 105],
"segment_id": ["A", "B", "A", "B", "A"],
},
)

expected_output = pd.DataFrame(
{
"revenue": [500, 500, 1000],
"transactions": [3, 2, 5],
"customers": [3, 2, 5],
},
index=["A", "B", "total"],
)

segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id")
pd.testing.assert_frame_equal(segment_stats, expected_output)

def test_does_not_alter_original_dataframe(self, base_df):
"""Test that the method does not alter the original DataFrame."""
original_df = base_df.copy()
_ = SegTransactionStats._calc_seg_stats(base_df, "segment_id")

pd.testing.assert_frame_equal(base_df, original_df)

def test_handles_dataframe_with_one_segment(self, base_df):
"""Test that the method correctly handles a DataFrame with only one segment."""
df = base_df.copy()
df["segment_id"] = "A"

expected_output = pd.DataFrame(
{
"revenue": [1000, 1000],
"transactions": [5, 5],
"customers": [5, 5],
"total_quantity": [100, 100],
"price_per_unit": [10.0, 10.0],
"quantity_per_transaction": [20.0, 20.0],
},
index=["A", "total"],
)

segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id")
pd.testing.assert_frame_equal(segment_stats, expected_output)


class TestSegTransactionStats:
"""Tests for the SegTransactionStats class."""

def test_handles_empty_dataframe_with_errors(self):
"""Test that the method raises an error when the DataFrame is missing a required column."""
df = pd.DataFrame(columns=["total_price", "transaction_id", "segment_id", "quantity"])

with pytest.raises(ValueError):
SegTransactionStats(df, "segment_id")