Skip to content

feat: segment stats calc now uses duckdb to improve performance #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ scipy = "^1.13.0"
scikit-learn = "^1.4.2"
matplotlib-set-diagrams = "~0.0.2"
toml = "^0.10.2"
duckdb = "^1.0.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
Expand Down
6 changes: 6 additions & 0 deletions pyretailscience/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(self) -> None:
# Calculated columns
"column.calc.price_per_unit": "price_per_unit",
"column.calc.units_per_transaction": "units_per_transaction",
"column.calc.spend_per_customer": "spend_per_customer",
"column.calc.spend_per_transaction": "spend_per_transaction",
"column.calc.transactions_per_customer": "transactions_per_customer",
# Abbreviation suffix
"column.suffix.count": "cnt",
"column.suffix.percent": "pct",
Expand Down Expand Up @@ -86,6 +89,9 @@ def __init__(self) -> None:
# Calculated columns
"column.calc.price_per_unit": "The name of the column containing the price per unit.",
"column.calc.units_per_transaction": "The name of the column containing the units per transaction.",
"column.calc.spend_per_customer": "The name of the column containing the spend per customer.",
"column.calc.spend_per_transaction": "The name of the column containing the spend per transaction.",
"column.calc.transactions_per_customer": "The name of the column containing the transactions per customer.",
# Abbreviation suffixes
"column.suffix.count": "The suffix to use for count columns.",
"column.suffix.percent": "The suffix to use for percentage columns.",
Expand Down
105 changes: 66 additions & 39 deletions pyretailscience/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import Literal

import duckdb
import pandas as pd
from duckdb import DuckDBPyRelation
from matplotlib.axes import Axes, SubplotBase

import pyretailscience.style.graph_utils as gu
Expand Down Expand Up @@ -204,12 +206,14 @@ def __init__(
class SegTransactionStats:
"""Calculates transaction statistics by segment."""

def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
def __init__(self, data: pd.DataFrame | DuckDBPyRelation, segment_col: str = "segment_id") -> None:
"""Calculates transaction statistics by segment.

Args:
df (pd.DataFrame): A dataframe with the transaction data. The dataframe must comply with the
TransactionItemLevelContract or the TransactionLevelContract.
data (pd.DataFrame | DuckDBPyRelation): The transaction data. It must contain the columns customer_id,
unit_spend, transaction_id and the column named by segment_col. If it also contains the column
unit_quantity, then price_per_unit (unit_spend / unit_quantity) and units_per_transaction
(unit_quantity / number of transactions) are calculated as well.
segment_col (str, optional): The column to use for the segmentation. Defaults to "segment_id".

Raises:
Expand All @@ -223,54 +227,77 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None:
get_option("column.transaction_id"),
segment_col,
]
if get_option("column.unit_quantity") in df.columns:
if get_option("column.unit_quantity") in data.columns:
required_cols.append(get_option("column.unit_quantity"))
contract = CustomContract(
df,
basic_expectations=build_expected_columns(columns=required_cols),
extended_expectations=build_non_null_columns(columns=required_cols),
)

if contract.validate() is False:
msg = f"The dataframe requires the columns {required_cols} and they must be non-null"
missing_cols = set(required_cols) - set(data.columns)

if len(missing_cols) > 0:
msg = f"The following columns are required but missing: {missing_cols}"
raise ValueError(msg)

self.segment_col = segment_col

self.df = self._calc_seg_stats(df, segment_col)
self.df = self._calc_seg_stats(data, segment_col)

@staticmethod
def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame:
aggs = {
get_option("column.agg.unit_spend"): (get_option("column.unit_spend"), "sum"),
get_option("column.agg.transaction_id"): (get_option("column.transaction_id"), "nunique"),
get_option("column.agg.customer_id"): (get_option("column.customer_id"), "nunique"),
}
total_aggs = {
get_option("column.agg.unit_spend"): [df[get_option("column.unit_spend")].sum()],
get_option("column.agg.transaction_id"): [df[get_option("column.transaction_id")].nunique()],
get_option("column.agg.customer_id"): [df[get_option("column.customer_id")].nunique()],
}
if get_option("column.unit_quantity") in df.columns:
aggs[get_option("column.agg.unit_quantity")] = (get_option("column.unit_quantity"), "sum")
total_aggs[get_option("column.agg.unit_quantity")] = [df[get_option("column.unit_quantity")].sum()]

stats_df = pd.concat(
[
df.groupby(segment_col).agg(**aggs),
pd.DataFrame(total_aggs, index=["total"]),
],
)
def _calc_seg_stats(data: pd.DataFrame | DuckDBPyRelation, segment_col: str) -> pd.DataFrame:
"""Calculates the transaction statistics by segment.

Args:
data (pd.DataFrame | DuckDBPyRelation): The transaction data.
segment_col (str): The column to use for the segmentation.

Returns:
pd.DataFrame: The transaction statistics by segment.

"""
if isinstance(data, pd.DataFrame):
data = duckdb.from_df(data)
elif not isinstance(data, DuckDBPyRelation):
raise TypeError("data must be either a pandas DataFrame or a DuckDBPyRelation")

base_aggs = [
f"SUM({get_option('column.unit_spend')}) as {get_option('column.agg.unit_spend')},",
f"COUNT(DISTINCT {get_option('column.transaction_id')}) as {get_option('column.agg.transaction_id')},",
f"COUNT(DISTINCT {get_option('column.customer_id')}) as {get_option('column.agg.customer_id')},",
]

total_customers = data.aggregate("COUNT(DISTINCT customer_id)").fetchone()[0]
return_cols = [
"*,",
f"{get_option('column.agg.unit_spend')} / {get_option('column.agg.customer_id')} ",
f"as {get_option('column.calc.spend_per_customer')},",
f"{get_option('column.agg.unit_spend')} / {get_option('column.agg.transaction_id')} ",
f"as {get_option('column.calc.spend_per_transaction')},",
f"{get_option('column.agg.transaction_id')} / {get_option('column.agg.customer_id')} ",
f"as {get_option('column.calc.transactions_per_customer')},",
f"{get_option('column.agg.customer_id')} / {total_customers}",
f"as customers_{get_option('column.suffix.percent')},",
]

if get_option("column.unit_quantity") in df.columns:
stats_df[get_option("column.calc.price_per_unit")] = (
stats_df[get_option("column.agg.unit_spend")] / stats_df[get_option("column.agg.unit_quantity")]
if get_option("column.unit_quantity") in data.columns:
base_aggs.append(
f"SUM({get_option('column.unit_quantity')})::bigint as {get_option('column.agg.unit_quantity')},",
)
stats_df[get_option("column.calc.units_per_transaction")] = (
stats_df[get_option("column.agg.unit_quantity")] / stats_df[get_option("column.agg.transaction_id")]
return_cols.extend(
[
f"({get_option('column.agg.unit_spend')} / {get_option('column.agg.unit_quantity')}) ",
f"as {get_option('column.calc.price_per_unit')},",
f"({get_option('column.agg.unit_quantity')} / {get_option('column.agg.transaction_id')}) ",
f"as {get_option('column.calc.units_per_transaction')},",
],
)

return stats_df
segment_stats = data.aggregate(f"{segment_col} as segment_name," + "".join(base_aggs))
total_stats = data.aggregate("'Total' as segment_name," + "".join(base_aggs))
final_stats_df = segment_stats.union(total_stats).select("".join(return_cols)).df()
final_stats_df = final_stats_df.set_index("segment_name").sort_index()

# Make sure Total is the last row
desired_index_sort = final_stats_df.index.drop("Total").tolist() + ["Total"] # noqa: RUF005

return final_stats_df.reindex(desired_index_sort)

def plot(
self,
Expand Down
30 changes: 21 additions & 9 deletions tests/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ def test_correctly_calculates_revenue_transactions_customers_per_segment(self, b
"""Test that the method correctly calculates at the transaction-item level."""
expected_output = pd.DataFrame(
{
get_option("column.agg.unit_spend"): [500, 500, 1000],
"segment_name": ["A", "B", "Total"],
get_option("column.agg.unit_spend"): [500.0, 500.0, 1000.0],
get_option("column.agg.transaction_id"): [3, 2, 5],
get_option("column.agg.customer_id"): [3, 2, 5],
get_option("column.agg.unit_quantity"): [50, 50, 100],
get_option("column.calc.spend_per_customer"): [166.666667, 250.0, 200.0],
get_option("column.calc.spend_per_transaction"): [166.666667, 250.0, 200.0],
get_option("column.calc.transactions_per_customer"): [1.0, 1.0, 1.0],
f"customers_{get_option('column.suffix.percent')}": [0.6, 0.4, 1.0],
get_option("column.calc.price_per_unit"): [10.0, 10.0, 10.0],
get_option("column.calc.units_per_transaction"): [16.666667, 25.0, 20.0],
},
index=["A", "B", "total"],
)
).set_index("segment_name")

segment_stats = SegTransactionStats._calc_seg_stats(base_df, "segment_id")
pd.testing.assert_frame_equal(segment_stats, expected_output)
Expand All @@ -53,12 +57,16 @@ def test_correctly_calculates_revenue_transactions_customers(self):

expected_output = pd.DataFrame(
{
get_option("column.agg.unit_spend"): [500, 500, 1000],
"segment_name": ["A", "B", "Total"],
get_option("column.agg.unit_spend"): [500.0, 500.0, 1000.0],
get_option("column.agg.transaction_id"): [3, 2, 5],
get_option("column.agg.customer_id"): [3, 2, 5],
get_option("column.calc.spend_per_customer"): [166.666667, 250.0, 200.0],
get_option("column.calc.spend_per_transaction"): [166.666667, 250.0, 200.0],
get_option("column.calc.transactions_per_customer"): [1.0, 1.0, 1.0],
f"customers_{get_option('column.suffix.percent')}": [0.6, 0.4, 1.0],
},
index=["A", "B", "total"],
)
).set_index("segment_name")

segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id")
pd.testing.assert_frame_equal(segment_stats, expected_output)
Expand All @@ -77,15 +85,19 @@ def test_handles_dataframe_with_one_segment(self, base_df):

expected_output = pd.DataFrame(
{
get_option("column.agg.unit_spend"): [1000, 1000],
"segment_name": ["A", "Total"],
get_option("column.agg.unit_spend"): [1000.0, 1000.0],
get_option("column.agg.transaction_id"): [5, 5],
get_option("column.agg.customer_id"): [5, 5],
get_option("column.agg.unit_quantity"): [100, 100],
get_option("column.calc.spend_per_customer"): [200.0, 200.0],
get_option("column.calc.spend_per_transaction"): [200.0, 200.0],
get_option("column.calc.transactions_per_customer"): [1.0, 1.0],
f"customers_{get_option('column.suffix.percent')}": [1.0, 1.0],
get_option("column.calc.price_per_unit"): [10.0, 10.0],
get_option("column.calc.units_per_transaction"): [20.0, 20.0],
},
index=["A", "total"],
)
).set_index("segment_name")

segment_stats = SegTransactionStats._calc_seg_stats(df, "segment_id")
pd.testing.assert_frame_equal(segment_stats, expected_output)
Expand Down