Skip to content

Commit 18ddd7d

Browse files
authored
Prod assoc perf improvements (#71)
* feat: improved performance of product association calculation
* fix: revert to matrix sparse structures
1 parent ca99359 commit 18ddd7d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pyretailscience/product_association.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
import numpy as np
4242
import pandas as pd
43-
from scipy.sparse import csc_array
43+
from scipy.sparse import csc_matrix
4444
from tqdm import tqdm
4545

4646
from pyretailscience.data.contracts import CustomContract, build_expected_columns, build_non_null_columns
@@ -234,7 +234,7 @@ def _calc_association( # noqa: C901 (ignore complexity) - Excluded due to min_*
234234
unique_combo_df[value_col] = pd.Categorical(unique_combo_df[value_col], ordered=True)
235235
unique_combo_df[group_col] = pd.Categorical(unique_combo_df[group_col], ordered=True)
236236

237-
sparse_matrix = csc_array(
237+
sparse_matrix = csc_matrix(
238238
(
239239
[1] * len(unique_combo_df),
240240
(
@@ -273,7 +273,7 @@ def _calc_association( # noqa: C901 (ignore complexity) - Excluded due to min_*
273273
target_item_col_index[target_item_loc] = True
274274
rows_with_target_item = sparse_matrix[:, target_item_col_index].getnnz(axis=1) == len(target_item_loc)
275275

276-
cooccurrences = sparse_matrix[rows_with_target_item, :].sum(axis=0).flatten()
276+
cooccurrences = np.array(sparse_matrix[rows_with_target_item, :].sum(axis=0)).flatten()
277277
if (cooccurrences == 0).all():
278278
continue
279279

0 commit comments

Comments
 (0)