|
3 | 3 | import pandas as pd
|
4 | 4 | import pytest
|
5 | 5 |
|
6 |
| -from pyretailscience.segmentation import SegTransactionStats |
| 6 | +from pyretailscience.segmentation import HMLSegmentation, SegTransactionStats |
7 | 7 |
|
8 | 8 |
|
9 | 9 | class TestCalcSegStats:
|
@@ -99,3 +99,110 @@ def test_handles_empty_dataframe_with_errors(self):
|
99 | 99 |
|
100 | 100 | with pytest.raises(ValueError):
|
101 | 101 | SegTransactionStats(df, "segment_id")
|
| 102 | + |
| 103 | + |
| 104 | +class TestHMLSegmentation: |
| 105 | + """Tests for the HMLSegmentation class.""" |
| 106 | + |
| 107 | + @pytest.fixture() |
| 108 | + def base_df(self): |
| 109 | + """Return a base DataFrame for testing.""" |
| 110 | + return pd.DataFrame({"customer_id": [1, 2, 3, 4, 5], "total_price": [1000, 200, 0, 500, 300]}) |
| 111 | + |
| 112 | + def test_no_transactions(self): |
| 113 | + """Test that the method raises an error when there are no transactions.""" |
| 114 | + data = {"customer_id": [], "total_price": []} |
| 115 | + df = pd.DataFrame(data) |
| 116 | + with pytest.raises(ValueError): |
| 117 | + HMLSegmentation(df) |
| 118 | + |
| 119 | + # Correctly handles zero spend customers when zero_value_customers is "exclude" |
| 120 | + def test_handles_zero_spend_customers_are_excluded_in_result(self, base_df): |
| 121 | + """Test that the method correctly handles zero spend customers when zero_value_customers is "exclude".""" |
| 122 | + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="exclude") |
| 123 | + result_df = hml_segmentation.df |
| 124 | + |
| 125 | + zero_spend_customer_id = 3 |
| 126 | + |
| 127 | + assert result_df.loc[1, "segment_name"] == "Heavy" |
| 128 | + assert result_df.loc[1, "segment_id"] == "H" |
| 129 | + assert result_df.loc[2, "segment_name"] == "Light" |
| 130 | + assert result_df.loc[2, "segment_id"] == "L" |
| 131 | + assert zero_spend_customer_id not in result_df.index |
| 132 | + assert result_df.loc[4, "segment_name"] == "Medium" |
| 133 | + assert result_df.loc[4, "segment_id"] == "M" |
| 134 | + assert result_df.loc[5, "segment_name"] == "Light" |
| 135 | + assert result_df.loc[5, "segment_id"] == "L" |
| 136 | + |
| 137 | + # Correctly handles zero spend customers when zero_value_customers is "include_with_light" |
| 138 | + def test_handles_zero_spend_customers_include_with_light(self, base_df): |
| 139 | + """Test that the method correctly handles zero spend customers when zero_value_customers is "include_with_light".""" |
| 140 | + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="include_with_light") |
| 141 | + result_df = hml_segmentation.df |
| 142 | + |
| 143 | + assert result_df.loc[1, "segment_name"] == "Heavy" |
| 144 | + assert result_df.loc[1, "segment_id"] == "H" |
| 145 | + assert result_df.loc[2, "segment_name"] == "Light" |
| 146 | + assert result_df.loc[2, "segment_id"] == "L" |
| 147 | + assert result_df.loc[3, "segment_name"] == "Light" |
| 148 | + assert result_df.loc[3, "segment_id"] == "L" |
| 149 | + assert result_df.loc[4, "segment_name"] == "Medium" |
| 150 | + assert result_df.loc[4, "segment_id"] == "M" |
| 151 | + assert result_df.loc[5, "segment_name"] == "Light" |
| 152 | + assert result_df.loc[5, "segment_id"] == "L" |
| 153 | + |
| 154 | + # Correctly handles zero spend customers when zero_value_customers is "separate_segment" |
| 155 | + def test_handles_zero_spend_customers_separate_segment(self, base_df): |
| 156 | + """Test that the method correctly handles zero spend customers when zero_value_customers is "separate_segment".""" |
| 157 | + hml_segmentation = HMLSegmentation(base_df, zero_value_customers="separate_segment") |
| 158 | + result_df = hml_segmentation.df |
| 159 | + |
| 160 | + assert result_df.loc[1, "segment_name"] == "Heavy" |
| 161 | + assert result_df.loc[1, "segment_id"] == "H" |
| 162 | + assert result_df.loc[2, "segment_name"] == "Light" |
| 163 | + assert result_df.loc[2, "segment_id"] == "L" |
| 164 | + assert result_df.loc[3, "segment_name"] == "Zero" |
| 165 | + assert result_df.loc[3, "segment_id"] == "Z" |
| 166 | + assert result_df.loc[4, "segment_name"] == "Medium" |
| 167 | + assert result_df.loc[4, "segment_id"] == "M" |
| 168 | + assert result_df.loc[5, "segment_name"] == "Light" |
| 169 | + assert result_df.loc[5, "segment_id"] == "L" |
| 170 | + |
| 171 | + # Raises ValueError if required columns are missing |
| 172 | + def test_raises_value_error_if_required_columns_missing(self, base_df): |
| 173 | + """Test that the method raises an error when the DataFrame is missing a required column.""" |
| 174 | + with pytest.raises(ValueError): |
| 175 | + HMLSegmentation(base_df.drop(columns=["customer_id"])) |
| 176 | + |
| 177 | + # DataFrame with only one customer |
| 178 | + def test_segments_customer_single(self): |
| 179 | + """Test that the method correctly segments a DataFrame with only one customer.""" |
| 180 | + data = {"customer_id": [1], "total_price": [0]} |
| 181 | + df = pd.DataFrame(data) |
| 182 | + with pytest.raises(ValueError): |
| 183 | + HMLSegmentation(df) |
| 184 | + |
| 185 | + # Validate that the input dataframe is not changed |
| 186 | + def test_input_dataframe_not_changed(self, base_df): |
| 187 | + """Test that the method does not alter the original DataFrame.""" |
| 188 | + original_df = base_df.copy() |
| 189 | + |
| 190 | + hml_segmentation = HMLSegmentation(base_df) |
| 191 | + _ = hml_segmentation.df |
| 192 | + |
| 193 | + assert original_df.equals(base_df) # Check if the original dataframe is not changed |
| 194 | + |
| 195 | + def test_alternate_value_col(self, base_df): |
| 196 | + """Test that the method correctly segments a DataFrame with an alternate value column.""" |
| 197 | + base_df = base_df.rename(columns={"total_price": "quantity"}) |
| 198 | + hml_segmentation = HMLSegmentation(base_df, value_col="quantity") |
| 199 | + result_df = hml_segmentation.df |
| 200 | + |
| 201 | + assert result_df.loc[1, "segment_name"] == "Heavy" |
| 202 | + assert result_df.loc[1, "segment_id"] == "H" |
| 203 | + assert result_df.loc[2, "segment_name"] == "Light" |
| 204 | + assert result_df.loc[2, "segment_id"] == "L" |
| 205 | + assert result_df.loc[4, "segment_name"] == "Medium" |
| 206 | + assert result_df.loc[4, "segment_id"] == "M" |
| 207 | + assert result_df.loc[5, "segment_name"] == "Light" |
| 208 | + assert result_df.loc[5, "segment_id"] == "L" |
0 commit comments