diff --git a/pyretailscience/options.py b/pyretailscience/options.py new file mode 100644 index 0000000..d354a38 --- /dev/null +++ b/pyretailscience/options.py @@ -0,0 +1,363 @@ +"""This module provides a simplified implementation of a pandas-like options system. + +It allows users to get, set, and reset various options that control the behavior +of data display and processing. The module also includes a context manager for +temporarily changing options. + +Example: + >>> set_option('display.max_rows', 100) + >>> print(get_option('display.max_rows')) + 100 + >>> with option_context('display.max_rows', 10): + ... print(get_option('display.max_rows')) + 10 + >>> print(get_option('display.max_rows')) + 100 + +""" + +from collections.abc import Generator +from contextlib import contextmanager +from functools import lru_cache +from pathlib import Path + +import toml + +OptionTypes = str | int | float | bool | list | dict | None + + +class Options: + """A class to manage configurable options.""" + + def __init__(self) -> None: + """Initializes the options with default values.""" + self._options: dict[str, OptionTypes] = { + # Database columns + "column.customer_id": "customer_id", + "column.transaction_id": "transaction_id", + "column.transaction_date": "transaction_date", + "column.transaction_time": "transaction_time", + "column.product_id": "product_id", + "column.unit_quantity": "unit_quantity", + "column.unit_price": "unit_price", + "column.unit_spend": "unit_spend", + "column.store_id": "store_id", + # Aggregation columns + "column.agg.customer_id": "customers", + "column.agg.transaction_id": "transactions", + "column.agg.product_id": "products", + "column.agg.unit_quantity": "units", + "column.agg.unit_price": "prices", + "column.agg.unit_spend": "spend", + "column.agg.store_id": "stores", + # Calculated columns + "column.calc.price_per_unit": "price_per_unit", + "column.calc.units_per_transaction": "units_per_transaction", + # Abbreviation suffix + "column.suffix.count": "cnt", + "column.suffix.percent": "pct", + "column.suffix.difference": "diff", + "column.suffix.contribution": "contrib", + } + self._descriptions: dict[str, str] = { + # Database columns + "column.customer_id": "The name of the column containing customer IDs.", + "column.transaction_id": "The name of the column containing transaction IDs.", + "column.transaction_date": "The name of the column containing transaction dates.", + "column.transaction_time": "The name of the column containing transaction times.", + "column.product_id": "The name of the column containing product IDs.", + "column.unit_quantity": "The name of the column containing the number of units sold.", + "column.unit_price": "The name of the column containing the unit price of the product.", + "column.unit_spend": ( + "The name of the column containing the total spend of the products in the transaction." 
+ "ie, unit_price * units", + ), + "column.store_id": "The name of the column containing store IDs of the transaction.", + # Aggregation columns + "column.agg.customer_id": "The name of the column containing the number of unique customers.", + "column.agg.transaction_id": "The name of the column containing the number of transactions.", + "column.agg.product_id": "The name of the column containing the number of unique products.", + "column.agg.unit_quantity": "The name of the column containing the total number of units sold.", + "column.agg.unit_price": "The name of the column containing the average unit price of products.", + "column.agg.unit_spend": ( + "The name of the column containing the total spend of the units in the transaction." + ), + "column.agg.store_id": "The name of the column containing the number of unique stores.", + # Calculated columns + "column.calc.price_per_unit": "The name of the column containing the price per unit.", + "column.calc.units_per_transaction": "The name of the column containing the units per transaction.", + # Abbreviation suffixes + "column.suffix.count": "The suffix to use for count columns.", + "column.suffix.percent": "The suffix to use for percentage columns.", + "column.suffix.difference": "The suffix to use for difference columns.", + "column.suffix.contribution": "The suffix to use for revenue contribution columns.", + } + self._default_options: dict[str, OptionTypes] = self._options.copy() + + def set_option(self, pat: str, val: OptionTypes) -> None: + """Set the value of the specified option. + + Args: + pat: The option name. + val: The value to set the option to. + + Raises: + ValueError: If the option name is unknown. + """ + if pat not in self._options: + msg = f"Unknown option: {pat}" + raise ValueError(msg) + + self._options[pat] = val + + def get_option(self, pat: str) -> OptionTypes: + """Get the value of the specified option. + + Args: + pat: The option name. + + Returns: + The value of the option. + + Raises: + ValueError: If the option name is unknown. + """ + if pat in self._options: + return self._options[pat] + + msg = f"Unknown option: {pat}" + raise ValueError(msg) + + def reset_option(self, pat: str) -> None: + """Reset the specified option to its default value. + + Args: + pat: The option name. + + Raises: + ValueError: If the option name is unknown. + """ + if pat not in self._options: + msg = f"Unknown option: {pat}" + raise ValueError(msg) + + self._options[pat] = self._default_options[pat] + + def list_options(self) -> list[str]: + """List all available options. + + Returns: + A list of all option names. + """ + return list(self._options.keys()) + + def describe_option(self, pat: str) -> str: + """Describe the specified option. + + Args: + pat: The option name. + + Returns: + A string describing the option and its current value. + + Raises: + ValueError: If the option name is unknown. + """ + if pat in self._descriptions: + return f"{pat}: {self._descriptions[pat]} (current value: {self._options[pat]})" + + msg = f"Unknown option: {pat}" + raise ValueError(msg) + + @staticmethod + def flatten_options(k: str, v: OptionTypes, parent_key: str = "") -> dict[str, OptionTypes]: + """Flatten nested options into a single dictionary.""" + if parent_key != "": + parent_key += "." 
+ + if isinstance(v, dict): + ret_dict = {} + for sub_key, sub_value in v.items(): + ret_dict.update(Options.flatten_options(sub_key, sub_value, parent_key=f"{parent_key}{k}")) + return ret_dict + + return {f"{parent_key}{k}": v} + + @classmethod + def load_from_project(cls) -> "Options": + """Try to load options from a pyretailscience.toml file in the project root directory. + + If the project root directory cannot be found, return a default Options instance. + + Returns: + An Options instance with options loaded from the pyretailscience.toml file or default + """ + options_instance = cls() + + project_root = find_project_root() + if project_root is None: + return options_instance + + toml_file = Path(project_root) / "pyretailscience.toml" + if toml_file.is_file(): + return Options.load_from_toml(toml_file) + + return options_instance + + @classmethod + def load_from_toml(cls, file_path: str | None = None) -> "Options": + """Load options from a TOML file. + + Args: + file_path: The path to the TOML file. + + Raises: + ValueError: If the TOML file contains unknown options. + """ + options_instance = cls() + + with open(file_path) as f: + toml_data = toml.load(f) + + for section, options in toml_data.items(): + for option_name, option_value in Options.flatten_options(section, options).items(): + if option_name in options_instance._options: # noqa: SLF001 + options_instance.set_option(option_name, option_value) + else: + msg = f"Unknown option in TOML file: {option_name}" + raise ValueError(msg) + + return options_instance + + +@lru_cache +def find_project_root() -> str | None: + """Returns the directory containing .git, .hg, or pyproject.toml, starting from the current working directory.""" + current_dir = Path.cwd() + + while True: + if (Path(current_dir / ".git")).is_dir() or (Path(current_dir / "pyretailscience.toml")).is_file(): + return current_dir + + parent_dir = Path(current_dir).parent + reached_root = parent_dir == current_dir + if reached_root: + return None + + current_dir = parent_dir + + +# Global instance of Options +_global_options = Options().load_from_project() + + +def set_option(pat: str, val: OptionTypes) -> None: + """Set the value of the specified option. + + This is a global function that delegates to the _global_options instance. + + Args: + pat: The option name. + val: The value to set the option to. + + Raises: + ValueError: If the option name is unknown. + """ + _global_options.set_option(pat, val) + + +def get_option(pat: str) -> OptionTypes: + """Get the value of the specified option. + + This is a global function that delegates to the _global_options instance. + + Args: + pat: The option name. + + Returns: + The value of the option. + + Raises: + ValueError: If the option name is unknown. + """ + return _global_options.get_option(pat) + + +def reset_option(pat: str) -> None: + """Reset the specified option to its default value. + + This is a global function that delegates to the _global_options instance. + + Args: + pat: The option name. + + Raises: + ValueError: If the option name is unknown. + """ + _global_options.reset_option(pat) + + +def list_options() -> list[str]: + """List all available options. + + This is a global function that delegates to the _global_options instance. + + Returns: + A list of all option names. + """ + return _global_options.list_options() + + +def describe_option(pat: str) -> str: + """Describe the specified option. + + This is a global function that delegates to the _global_options instance. + + Args: + pat: The option name. 
+ + Returns: + A string describing the option and its current value. + + Raises: + ValueError: If the option name is unknown. + """ + return _global_options.describe_option(pat) + + +@contextmanager +def option_context(*args: OptionTypes) -> Generator[None, None, None]: + """Context manager to temporarily set options. + + Temporarily set options and restore them to their previous values after the + context exits. The arguments should be supplied as alternating option names + and values. + + Args: + *args: An even number of arguments, alternating between option names (str) + and their corresponding values. + + Yields: + None + + Raises: + ValueError: If an odd number of arguments is supplied. + + Example: + >>> with option_context('display.max_rows', 10, 'display.max_columns', 5): + ... # Do something with modified options + ... pass + >>> # Options are restored to their previous values here + """ + if len(args) % 2 != 0: + raise ValueError("The context manager requires an even number of arguments") + + old_options: dict[str, OptionTypes] = {} + try: + for pat, val in zip(args[::2], args[1::2], strict=True): + old_options[pat] = get_option(pat) + set_option(pat, val) + yield + finally: + for pat, val in old_options.items(): + set_option(pat, val) diff --git a/pyretailscience/segmentation.py b/pyretailscience/segmentation.py index 4bd58ad..a020dfc 100644 --- a/pyretailscience/segmentation.py +++ b/pyretailscience/segmentation.py @@ -12,6 +12,7 @@ build_expected_unique_columns, build_non_null_columns, ) +from pyretailscience.options import get_option from pyretailscience.style.tailwind import COLORS @@ -31,7 +32,12 @@ def add_segment(self, df: pd.DataFrame) -> pd.DataFrame: ValueError: If the number of rows before and after the merge do not match. """ rows_before = len(df) - df = df.merge(self.df[["segment_name", "segment_id"]], how="left", left_on="customer_id", right_index=True) + df = df.merge( + self.df[["segment_name", "segment_id"]], + how="left", + left_on=get_option("column.customer_id"), + right_index=True, + ) rows_after = len(df) if rows_before != rows_after: raise ValueError("The number of rows before and after the merge do not match. This should not happen.") @@ -51,7 +57,7 @@ def __init__(self, df: pd.DataFrame) -> None: Raises: ValueError: If the dataframe does not have the columns customer_id, segment_name and segment_id. """ - required_cols = "customer_id", "segment_name", "segment_id" + required_cols = get_option("column.customer_id"), "segment_name", "segment_id" contract = CustomContract( df, basic_expectations=build_expected_columns(columns=required_cols), @@ -63,7 +69,9 @@ def __init__(self, df: pd.DataFrame) -> None: msg = f"The dataframe requires the columns {required_cols} and they must be non-null and unique." raise ValueError(msg) - self.df = df[["customer_id", "segment_name", "segment_id"]].set_index("customer_id") + self.df = df[[get_option("column.customer_id"), "segment_name", "segment_id"]].set_index( + get_option("column.customer_id"), + ) class ThresholdSegmentation(BaseSegmentation): @@ -74,7 +82,7 @@ def __init__( df: pd.DataFrame, thresholds: list[float], segments: dict[any, str], - value_col: str = "total_price", + value_col: str | None = None, agg_func: str = "sum", zero_segment_name: str = "Zero", zero_segment_id: str = "Z", @@ -86,7 +94,7 @@ def __init__( df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column. thresholds (List[float]): The percentile thresholds for segmentation. 
segments (Dict[str, str]): A dictionary where keys are segment IDs and values are segment names. - value_col (str): The column to use for the segmentation. + value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend"). agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". zero_segment_name (str, optional): The name of the segment for customers with zero spend. Defaults to "Zero". zero_segment_id (str, optional): The ID of the segment for customers with zero spend. Defaults to "Z". @@ -100,7 +108,9 @@ def __init__( if df.empty: raise ValueError("Input DataFrame is empty") - required_cols = ["customer_id", value_col] + value_col = get_option("column.unit_spend") if value_col is None else value_col + + required_cols = [get_option("column.customer_id"), value_col] contract = CustomContract( df, basic_expectations=build_expected_columns(columns=required_cols), @@ -128,7 +138,7 @@ def __init__( raise ValueError("The number of thresholds must match the number of segments.") # Group by customer_id and calculate total_spend - grouped_df = df.groupby("customer_id")[value_col].agg(agg_func).to_frame(value_col) + grouped_df = df.groupby(get_option("column.customer_id"))[value_col].agg(agg_func).to_frame(value_col) # Separate customers with zero spend self.df = grouped_df @@ -138,7 +148,7 @@ def __init__( zero_cust_df["segment_name"] = zero_segment_name zero_cust_df["segment_id"] = zero_segment_id - self.df = grouped_df[~zero_idx] + self.df = grouped_df[~zero_idx].copy() # Create a new column 'segment' based on the total_spend labels = list(segments.values()) @@ -161,20 +171,25 @@ class HMLSegmentation(ThresholdSegmentation): def __init__( self, df: pd.DataFrame, - value_col: str = "total_price", + value_col: str | None = None, agg_func: str = "sum", zero_value_customers: Literal["separate_segment", "exclude", "include_with_light"] = "separate_segment", ) -> None: """Segments customers into Heavy, Medium, Light and Zero spenders based on the total spend. + HMLSegmentation is a subclass of ThresholdSegmentation and based around an industry standard definition. The + thresholds for Heavy (top 20%), Medium (next 30%) and Light (bottom 50%) are chosen based on the pareto + distribution, commonly know as the 80/20 rule. It is typically used in retail to segment customers based on + their spend, transaction volume or quantities purchased. + Args: df (pd.DataFrame): A dataframe with the transaction data. The dataframe must contain a customer_id column. - value_col (str, optional): The column to use for the segmentation. Defaults to "total_price". + value_col (str, optional): The column to use for the segmentation. Defaults to get_option("column.unit_spend"). agg_func (str, optional): The aggregation function to use when grouping by customer_id. Defaults to "sum". zero_value_customers (Literal["separate_segment", "exclude", "include_with_light"], optional): How to handle customers with zero spend. Defaults to "separate_segment". """ - thresholds = [0.500, 0.800, 1] + thresholds = [0, 0.500, 0.800, 1] segments = {"L": "Light", "M": "Medium", "H": "Heavy"} super().__init__( df=df, @@ -202,9 +217,14 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None: TransactionLevelContract. 
""" - required_cols = ["customer_id", "total_price", "transaction_id", segment_col] - if "quantity" in df.columns: - required_cols.append("quantity") + required_cols = [ + get_option("column.customer_id"), + get_option("column.unit_spend"), + get_option("column.transaction_id"), + segment_col, + ] + if get_option("column.unit_quantity") in df.columns: + required_cols.append(get_option("column.unit_quantity")) contract = CustomContract( df, basic_expectations=build_expected_columns(columns=required_cols), @@ -222,18 +242,18 @@ def __init__(self, df: pd.DataFrame, segment_col: str = "segment_id") -> None: @staticmethod def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame: aggs = { - "revenue": ("total_price", "sum"), - "transactions": ("transaction_id", "nunique"), - "customers": ("customer_id", "nunique"), + get_option("column.agg.unit_spend"): (get_option("column.unit_spend"), "sum"), + get_option("column.agg.transaction_id"): (get_option("column.transaction_id"), "nunique"), + get_option("column.agg.customer_id"): (get_option("column.customer_id"), "nunique"), } total_aggs = { - "revenue": [df["total_price"].sum()], - "transactions": [df["transaction_id"].nunique()], - "customers": [df["customer_id"].nunique()], + get_option("column.agg.unit_spend"): [df[get_option("column.unit_spend")].sum()], + get_option("column.agg.transaction_id"): [df[get_option("column.transaction_id")].nunique()], + get_option("column.agg.customer_id"): [df[get_option("column.customer_id")].nunique()], } - if "quantity" in df.columns: - aggs["total_quantity"] = ("quantity", "sum") - total_aggs["total_quantity"] = [df["quantity"].sum()] + if get_option("column.unit_quantity") in df.columns: + aggs[get_option("column.agg.unit_quantity")] = (get_option("column.unit_quantity"), "sum") + total_aggs[get_option("column.agg.unit_quantity")] = [df[get_option("column.unit_quantity")].sum()] stats_df = pd.concat( [ @@ -242,9 +262,13 @@ def _calc_seg_stats(df: pd.DataFrame, segment_col: str) -> pd.DataFrame: ], ) - if "quantity" in df.columns: - stats_df["price_per_unit"] = stats_df["revenue"] / stats_df["total_quantity"] - stats_df["quantity_per_transaction"] = stats_df["total_quantity"] / stats_df["transactions"] + if get_option("column.unit_quantity") in df.columns: + stats_df[get_option("column.calc.price_per_unit")] = ( + stats_df[get_option("column.agg.unit_spend")] / stats_df[get_option("column.agg.unit_quantity")] + ) + stats_df[get_option("column.calc.units_per_transaction")] = ( + stats_df[get_option("column.agg.unit_quantity")] / stats_df[get_option("column.agg.transaction_id")] + ) return stats_df diff --git a/tests/test_options.py b/tests/test_options.py new file mode 100644 index 0000000..ef9b70b --- /dev/null +++ b/tests/test_options.py @@ -0,0 +1,218 @@ +"""Tests for the Options module.""" + +from pathlib import Path +from unittest.mock import patch + +import pytest +import toml + +import pyretailscience.options as opt + + +class TestOptions: + """Test for option handling class.""" + + def test_unknown_option_raises_value_error(self): + """Test setting/getting/resetting an unknown option raises a ValueError.""" + options = opt.Options() + with pytest.raises(ValueError, match="Unknown option: unknown.option"): + options.set_option("unknown.option", "some_value") + with pytest.raises(ValueError, match="Unknown option: unknown.option"): + options.get_option("unknown_option") + with pytest.raises(ValueError, match="Unknown option: unknown.option"): + options.reset_option("unknown_option") + with 
pytest.raises(ValueError, match="Unknown option: unknown.option"): + options.describe_option("unknown_option") + + def test_list_options_returns_all_options(self): + """Test listing all options returns all options.""" + options = opt.Options() + assert options.list_options() == list(options._options.keys()) + + def test_set_option_updates_value(self): + """Test setting an option updates the option value correctly.""" + options = opt.Options() + options.set_option("column.customer_id", "new_customer_id") + assert options.get_option("column.customer_id") == "new_customer_id" + + def test_get_option_retrieves_correct_value(self): + """Test getting an option retrieves the correct value.""" + options = opt.Options() + expected_value = options._options["column.customer_id"] + actual_value = options.get_option("column.customer_id") + assert actual_value == expected_value + + def test_reset_option_restores_default_value(self): + """Test resetting an option restores its default value.""" + options = opt.Options() + expected_value = options._options["column.customer_id"] + options.set_option("column.customer_id", "new_customer_id") + options.reset_option("column.customer_id") + assert options.get_option("column.customer_id") == expected_value + + def test_describe_option_correct_description_and_value(self): + """Test describing an option provides the correct description and current value.""" + options = opt.Options() + option = "column.customer_id" + expected_description = options._descriptions[option] + expected_value = options._options[option] + + description = options.describe_option(option) + assert description == f"{option}: {expected_description} (current value: {expected_value})" + + def test_matching_keys_between_options_and_descriptions(self): + """Test that all options have a corresponding description and vice versa.""" + options = opt.Options() + assert set(options._options.keys()) == set(options._descriptions.keys()) + + def test_context_manager_overrides_option(self): + """Test that the context manager overrides the option value correctly at the global level.""" + original_value = opt.get_option("column.customer_id") + with opt.option_context("column.customer_id", "new_customer_id"): + assert opt.get_option("column.customer_id") == "new_customer_id" + assert opt.get_option("column.customer_id") == original_value + + def test_context_manager_odd_number_of_arguments_raises_value_error(self): + """Test that the context manager raises a ValueError when an odd number of arguments is passed.""" + with ( + pytest.raises(ValueError, match="The context manager requires an even number of arguments"), + opt.option_context("column.customer_id"), + ): + pass + + def test_set_option_updates_value_global_level(self): + """Test setting an option updates the option value correctly at the global level.""" + opt.set_option("column.customer_id", "new_customer_id") + assert opt.get_option("column.customer_id") == "new_customer_id" + opt.reset_option("column.customer_id") + + def test_get_option_retrieves_correct_value_global_level(self): + """Test getting an option retrieves the correct value at the global level.""" + # Instantiate Options class to get the default value + options = opt.Options() + expected_value = options._options["column.customer_id"] + del options + + actual_value = opt.get_option("column.customer_id") + assert actual_value == expected_value + + def test_reset_option_restores_default_value_global_level(self): + """Test resetting an option restores its default value at the global level.""" 
+ # Instantiate Options class to get the default value + options = opt.Options() + expected_value = options._options["column.customer_id"] + del options + + opt.set_option("column.customer_id", "new_customer_id") + opt.reset_option("column.customer_id") + assert opt.get_option("column.customer_id") == expected_value + + def test_describe_option_correct_description_and_value_global_level(self): + """Test describing an option provides the correct description and current value at the global level.""" + option = "column.customer_id" + # Instantiate Options class to get the default value + options = opt.Options() + expected_description = options._descriptions[option] + expected_value = options._options[option] + del options + + description = opt.describe_option(option) + assert description == f"{option}: {expected_description} (current value: {expected_value})" + + def test_list_options_returns_all_options_global_level(self): + """Test listing all options returns all options at the global level.""" + options = opt.Options() + options_list = list(options._options.keys()) + del options + + assert opt.list_options() == options_list + + def test_load_invalid_format_toml(self): + """Test loading an invalid TOML file raises a ValueError.""" + test_file_path = Path("tests/toml_files/corrupt.toml").resolve() + with pytest.raises(toml.TomlDecodeError): + opt.Options.load_from_toml(test_file_path) + + def test_load_valid_toml(self): + """Test loading a valid TOML file updates the options correctly.""" + test_file_path = Path("tests/toml_files/valid.toml").resolve() + options = opt.Options.load_from_toml(test_file_path) + assert options.get_option("column.customer_id") == "new_customer_id" + assert options.get_option("column.product_id") == "new_product_id" + assert options.get_option("column.agg.customer_id") == "new_customers" + assert options.get_option("column.calc.price_per_unit") == "new_price_per_unit" + assert options.get_option("column.suffix.count") == "new_cnt" + + def test_load_invalid_option_toml(self): + """Test loading an invalid TOML file raises a ValueError.""" + test_file_path = Path("tests/toml_files/invalid_option.toml").resolve() + with pytest.raises(ValueError, match="Unknown option in TOML file: column.agg.unknown_column"): + opt.Options.load_from_toml(test_file_path) + + def test_flatten_options(self): + """Test flattening the options dictionary.""" + nested_options = { + "column": { + "customer_id": "customer_id", + "agg": { + "customer_id": "customer_id", + "product_id": "product_id", + }, + }, + } + expected_flat_options = { + "column.customer_id": "customer_id", + "column.agg.customer_id": "customer_id", + "column.agg.product_id": "product_id", + } + assert expected_flat_options == opt.Options.flatten_options("column", nested_options["column"]) + + @pytest.fixture() + def _reset_lru_cache(self): + opt.find_project_root.cache_clear() + yield + opt.find_project_root.cache_clear() + + @pytest.mark.usefixtures("_reset_lru_cache") + @patch("pathlib.Path.cwd") + @patch("pathlib.Path.is_dir") + def test_find_project_root_git_found(self, mock_is_dir, mock_cwd): + """Test finding the project root when the .git directory is found.""" + mock_cwd.return_value = Path("/home/user/project") + mock_is_dir.side_effect = [True] # .git directory exists + assert opt.find_project_root() == Path("/home/user/project") + + @pytest.mark.usefixtures("_reset_lru_cache") + @patch("pathlib.Path.cwd") + @patch("pathlib.Path.is_dir") + @patch("pathlib.Path.is_file") + def 
test_find_project_root_toml_found(self, mock_is_file, mock_is_dir, mock_cwd): + """Test finding the project root when the pyretailscience.toml file is found.""" + mock_cwd.return_value = Path("/home/user/project") + mock_is_dir.side_effect = [False] # .git directory doesn't exist + mock_is_file.side_effect = [True] # pyretailscience.toml file exists + assert opt.find_project_root() == Path("/home/user/project") + + @pytest.mark.usefixtures("_reset_lru_cache") + @patch("pathlib.Path.cwd") + @patch("pathlib.Path.is_dir") + @patch("pathlib.Path.is_file") + @patch("pathlib.Path.parent") + def test_find_project_root_no_project_found(self, mock_parent, mock_is_file, mock_is_dir, mock_cwd): + """Test finding the project root when no project root is found.""" + mock_cwd.return_value = Path("/") + mock_is_dir.side_effect = [False, False] # No .git directory + mock_is_file.side_effect = [False, False] # No pyretailscience.toml file + mock_parent.return_value = Path("/") + assert opt.find_project_root() is None + + @pytest.mark.usefixtures("_reset_lru_cache") + @patch("pathlib.Path.cwd") + @patch("pathlib.Path.is_dir") + @patch("pathlib.Path.is_file") + def test_find_project_root_found_in_parent(self, mock_is_file, mock_is_dir, mock_cwd): + """Test finding the project root when the project root is found in a parent directory.""" + mock_cwd.return_value = Path("/home/user/project/subdir") + mock_is_dir.side_effect = [False, True] # .git directory in parent + mock_is_file.side_effect = [False] # No pyretailscience.toml file + assert opt.find_project_root() == Path("/home/user/project") diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index cc923a9..f9ec029 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -3,6 +3,7 @@ import pandas as pd import pytest +from pyretailscience.options import get_option from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats, ThresholdSegmentation @@ -14,11 +15,11 @@ def base_df(self): """Return a base DataFrame for testing.""" return pd.DataFrame( { - "customer_id": [1, 2, 3, 4, 5], - "total_price": [100, 200, 150, 300, 250], - "transaction_id": [101, 102, 103, 104, 105], + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 150, 300, 250], + get_option("column.transaction_id"): [101, 102, 103, 104, 105], "segment_id": ["A", "B", "A", "B", "A"], - "quantity": [10, 20, 15, 30, 25], + get_option("column.unit_quantity"): [10, 20, 15, 30, 25], }, ) @@ -26,12 +27,12 @@ def test_correctly_calculates_revenue_transactions_customers_per_segment(self, b """Test that the method correctly calculates at the transaction-item level.""" expected_output = pd.DataFrame( { - "revenue": [500, 500, 1000], - "transactions": [3, 2, 5], - "customers": [3, 2, 5], - "total_quantity": [50, 50, 100], - "price_per_unit": [10.0, 10.0, 10.0], - "quantity_per_transaction": [16.666667, 25.0, 20.0], + get_option("column.agg.unit_spend"): [500, 500, 1000], + get_option("column.agg.transaction_id"): [3, 2, 5], + get_option("column.agg.customer_id"): [3, 2, 5], + get_option("column.agg.unit_quantity"): [50, 50, 100], + get_option("column.calc.price_per_unit"): [10.0, 10.0, 10.0], + get_option("column.calc.units_per_transaction"): [16.666667, 25.0, 20.0], }, index=["A", "B", "total"], ) @@ -43,18 +44,18 @@ def test_correctly_calculates_revenue_transactions_customers(self): """Test that the method correctly calculates at the transaction level.""" df = pd.DataFrame( { - "customer_id": [1, 2, 3, 4, 5], 
- "total_price": [100, 200, 150, 300, 250], - "transaction_id": [101, 102, 103, 104, 105], + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 150, 300, 250], + get_option("column.transaction_id"): [101, 102, 103, 104, 105], "segment_id": ["A", "B", "A", "B", "A"], }, ) expected_output = pd.DataFrame( { - "revenue": [500, 500, 1000], - "transactions": [3, 2, 5], - "customers": [3, 2, 5], + get_option("column.agg.unit_spend"): [500, 500, 1000], + get_option("column.agg.transaction_id"): [3, 2, 5], + get_option("column.agg.customer_id"): [3, 2, 5], }, index=["A", "B", "total"], ) @@ -76,12 +77,12 @@ def test_handles_dataframe_with_one_segment(self, base_df): expected_output = pd.DataFrame( { - "revenue": [1000, 1000], - "transactions": [5, 5], - "customers": [5, 5], - "total_quantity": [100, 100], - "price_per_unit": [10.0, 10.0], - "quantity_per_transaction": [20.0, 20.0], + get_option("column.agg.unit_spend"): [1000, 1000], + get_option("column.agg.transaction_id"): [5, 5], + get_option("column.agg.customer_id"): [5, 5], + get_option("column.agg.unit_quantity"): [100, 100], + get_option("column.calc.price_per_unit"): [10.0, 10.0], + get_option("column.calc.units_per_transaction"): [20.0, 20.0], }, index=["A", "total"], ) @@ -95,14 +96,19 @@ class TestThresholdSegmentation: def test_correct_segmentation(self): """Test that the method correctly segments customers based on given thresholds and segments.""" - df = pd.DataFrame({"customer_id": [1, 2, 3, 4], "total_price": [100, 200, 300, 400]}) + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4], + get_option("column.unit_spend"): [100, 200, 300, 400], + }, + ) thresholds = [0.5, 1] segments = {0: "Low", 1: "High"} seg = ThresholdSegmentation( df=df, thresholds=thresholds, segments=segments, - value_col="total_price", + value_col=get_option("column.unit_spend"), zero_value_customers="exclude", ) result_df = seg.df @@ -113,7 +119,7 @@ def test_correct_segmentation(self): def test_single_customer(self): """Test that the method correctly segments a DataFrame with only one customer.""" - df = pd.DataFrame({"customer_id": [1], "total_price": [100]}) + df = pd.DataFrame({get_option("column.customer_id"): [1], get_option("column.unit_spend"): [100]}) thresholds = [0.5, 1] segments = {0: "Low"} with pytest.raises(ValueError): @@ -127,7 +133,7 @@ def test_correct_aggregation_function(self): """Test that the correct aggregation function is applied for product_id custom segmentation.""" df = pd.DataFrame( { - "customer_id": [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5], + get_option("column.customer_id"): [1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5], "product_id": [3, 4, 4, 6, 1, 5, 7, 2, 2, 3, 2, 3, 4, 1], }, ) @@ -145,7 +151,7 @@ def test_correct_aggregation_function(self): expected_result = pd.DataFrame( { - "customer_id": [1, 2, 3, 4, 5], + get_option("column.customer_id"): [1, 2, 3, 4, 5], "product_id": [1, 4, 2, 2, 3], "segment_name": ["Low", "High", "Medium", "Medium", "Medium"], "segment_id": ["A", "C", "B", "B", "B"], @@ -167,11 +173,11 @@ def test_correctly_checks_segment_data(self): """Test that the method correctly merges segment data back into the original DataFrame.""" df = pd.DataFrame( { - "customer_id": [1, 2, 3, 4, 5], - "total_price": [100, 200, 0, 150, 0], + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 0, 150, 0], }, ) - value_col = "total_price" + value_col = get_option("column.unit_spend") agg_func = "sum" thresholds = 
[0.33, 0.66, 1] segments = {"A": "Low", "B": "Medium", "C": "High"} @@ -193,8 +199,8 @@ def test_correctly_checks_segment_data(self): # Assert the correct segment_name and segment_id expected_df = pd.DataFrame( { - "customer_id": [1, 2, 3, 4, 5], - "total_price": [100, 200, 0, 150, 0], + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 0, 150, 0], "segment_name": ["Low", "High", "Zero", "Medium", "Zero"], "segment_id": ["A", "C", "Z", "B", "Z"], }, @@ -203,11 +209,16 @@ def test_correctly_checks_segment_data(self): def test_handles_dataframe_with_duplicate_customer_id_entries(self): """Test that the method correctly handles a DataFrame with duplicate customer_id entries.""" - df = pd.DataFrame({"customer_id": [1, 2, 3, 1, 2, 3], "total_price": [100, 200, 300, 150, 250, 350]}) + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 1, 2, 3], + get_option("column.unit_spend"): [100, 200, 300, 150, 250, 350], + }, + ) my_seg = ThresholdSegmentation( df=df, - value_col="total_price", + value_col=get_option("column.unit_spend"), agg_func="sum", thresholds=[0.5, 0.8, 1], segments={"L": "Light", "M": "Medium", "H": "Heavy"}, @@ -220,8 +231,13 @@ def test_handles_dataframe_with_duplicate_customer_id_entries(self): def test_correctly_maps_segment_names_to_segment_ids_with_fixed_thresholds(self): """Test that the method correctly maps segment names to segment IDs with fixed thresholds.""" # Setup - df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) - value_col = "total_price" + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 300, 400, 500], + }, + ) + value_col = get_option("column.unit_spend") agg_func = "sum" thresholds = [0.33, 0.66, 1] segments = {1: "Low", 2: "Medium", 3: "High"} @@ -241,7 +257,12 @@ def test_correctly_maps_segment_names_to_segment_ids_with_fixed_thresholds(self) def test_thresholds_not_unique(self): """Test that the method raises an error when the thresholds are not unique.""" - df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 300, 400, 500], + }, + ) thresholds = [0.5, 0.5, 0.8, 1] segments = {1: "Low", 2: "Medium", 3: "High"} @@ -250,7 +271,12 @@ def test_thresholds_not_unique(self): def test_thresholds_too_few_segments(self): """Test that the method raises an error when there are too few/many segments for the number of thresholds.""" - df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 300, 400, 500], + }, + ) thresholds = [0.4, 0.6, 0.8, 1] segments = {1: "Low", 3: "High"} @@ -264,7 +290,12 @@ def test_thresholds_too_few_segments(self): def test_thresholds_too_too_few_thresholds(self): """Test that the method raises an error when there are too few/many thresholds for the number of segments.""" - df = pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [100, 200, 300, 400, 500]}) + df = pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [100, 200, 300, 400, 500], + }, + ) thresholds = [0.4, 1] segments = {1: "Low", 2: "Medium", 3: "High"} @@ -282,7 +313,9 @@ class TestSegTransactionStats: def 
test_handles_empty_dataframe_with_errors(self): """Test that the method raises an error when the DataFrame is missing a required column.""" - df = pd.DataFrame(columns=["total_price", "transaction_id", "segment_id", "quantity"]) + df = pd.DataFrame( + columns=[get_option("column.unit_spend"), get_option("column.transaction_id"), "segment_id", "quantity"], + ) with pytest.raises(ValueError): SegTransactionStats(df, "segment_id") @@ -294,11 +327,16 @@ class TestHMLSegmentation: @pytest.fixture() def base_df(self): """Return a base DataFrame for testing.""" - return pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [1000, 200, 0, 500, 300]}) + return pd.DataFrame( + { + get_option("column.customer_id"): [1, 2, 3, 4, 5], + get_option("column.unit_spend"): [1000, 200, 0, 500, 300], + }, + ) def test_no_transactions(self): """Test that the method raises an error when there are no transactions.""" - data = {"customer_id": [], "total_price": []} + data = {get_option("column.customer_id"): [], get_option("column.unit_spend"): []} df = pd.DataFrame(data) with pytest.raises(ValueError): HMLSegmentation(df) @@ -359,12 +397,12 @@ def test_handles_zero_spend_customers_separate_segment(self, base_df): def test_raises_value_error_if_required_columns_missing(self, base_df): """Test that the method raises an error when the DataFrame is missing a required column.""" with pytest.raises(ValueError): - HMLSegmentation(base_df.drop(columns=["customer_id"])) + HMLSegmentation(base_df.drop(columns=[get_option("column.customer_id")])) # DataFrame with only one customer def test_segments_customer_single(self): """Test that the method correctly segments a DataFrame with only one customer.""" - data = {"customer_id": [1], "total_price": [0]} + data = {get_option("column.customer_id"): [1], get_option("column.unit_spend"): [0]} df = pd.DataFrame(data) with pytest.raises(ValueError): HMLSegmentation(df) @@ -381,7 +419,7 @@ def test_input_dataframe_not_changed(self, base_df): def test_alternate_value_col(self, base_df): """Test that the method correctly segments a DataFrame with an alternate value column.""" - base_df = base_df.rename(columns={"total_price": "quantity"}) + base_df = base_df.rename(columns={get_option("column.unit_spend"): "quantity"}) hml_segmentation = HMLSegmentation(base_df, value_col="quantity") result_df = hml_segmentation.df diff --git a/tests/toml_files/corrupt.toml b/tests/toml_files/corrupt.toml new file mode 100644 index 0000000..67ff45d --- /dev/null +++ b/tests/toml_files/corrupt.toml @@ -0,0 +1 @@ +This is not toml and should generate an error. diff --git a/tests/toml_files/invalid_option.toml b/tests/toml_files/invalid_option.toml new file mode 100644 index 0000000..f6a4d00 --- /dev/null +++ b/tests/toml_files/invalid_option.toml @@ -0,0 +1,3 @@ +[column.agg] +customer_id="customer_id" +unknown_column="unknown_column" diff --git a/tests/toml_files/valid.toml b/tests/toml_files/valid.toml new file mode 100644 index 0000000..79b7cfe --- /dev/null +++ b/tests/toml_files/valid.toml @@ -0,0 +1,12 @@ +[column] +customer_id="new_customer_id" +product_id="new_product_id" + +[column.agg] +customer_id="new_customers" + +[column.calc] +price_per_unit="new_price_per_unit" + +[column.suffix] +count="new_cnt"
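
A minimal usage sketch of the options API and the updated HMLSegmentation introduced in this patch. It is not part of the diff itself: the DataFrame values mirror the HMLSegmentation test fixture, and the renamed columns "sales_amount" and "cust_id" are hypothetical names used only for illustration.

import pandas as pd

from pyretailscience.options import get_option, option_context, reset_option, set_option
from pyretailscience.segmentation import HMLSegmentation

# Build a toy transaction frame using the currently configured column names.
df = pd.DataFrame(
    {
        get_option("column.customer_id"): [1, 2, 3, 4, 5],
        get_option("column.unit_spend"): [1000, 200, 0, 500, 300],
    },
)

# value_col now defaults to get_option("column.unit_spend"), so callers no
# longer hard-code "total_price".
print(HMLSegmentation(df).df)

# Temporarily point the library at a differently named spend column.
# "sales_amount" is a hypothetical column name for this sketch only.
default_spend_col = get_option("column.unit_spend")
with option_context("column.unit_spend", "sales_amount"):
    renamed_df = df.rename(columns={default_spend_col: "sales_amount"})
    print(HMLSegmentation(renamed_df).df)

# Outside the context manager the previous name is restored.
assert get_option("column.unit_spend") == default_spend_col

# set_option/reset_option apply for the rest of the session; a
# pyretailscience.toml at the project root applies overrides at import time.
set_option("column.customer_id", "cust_id")  # hypothetical rename
reset_option("column.customer_id")

The same overrides can be made persistent by committing a pyretailscience.toml like tests/toml_files/valid.toml to the project root, which Options.load_from_project picks up when the package is imported.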