Skip to content

Commit 5642129

Browse files
authored
Merge pull request #9 from scikit-learn-contrib/refactor/optimize-weak-point-assignment
Refactor: Optimized _assign_weak_points using boolean masking
2 parents 2d42fe0 + c250f36 commit 5642129

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

denmune_skl/denmune.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -359,20 +359,31 @@ def _assign_weak_points(self, mutual_neighbors, labels, cluster_parent):
359359
Modifies labels in-place.
360360
"""
361361
# We loop until convergence (no new points assigned)
362-
while True:
362+
unclassified_indices = np.where(labels == -1)[0]
363+
364+
while len(unclassified_indices) > 0:
363365
newly_assigned_count = 0
364-
unclassified_indices = np.where(labels == -1)[0]
365366

366-
if len(unclassified_indices) == 0:
367-
break
367+
# Pre-allocate a mask for points we fail to classify this round.
368+
# The default is False (drop the point); set it to True when the point must be retried next pass.
369+
keep_mask = np.zeros(len(unclassified_indices), dtype=bool)
370+
371+
# 'idx' is the position of the point within `unclassified_indices`
372+
# 'sample_idx' is the index of the point within X (the original data)
373+
for idx, sample_idx in enumerate(unclassified_indices):
374+
mnn_of_i = mutual_neighbors[sample_idx]
368375

369-
for i in unclassified_indices:
370-
mnn_of_i = mutual_neighbors[i]
376+
# Optimization: If point has no neighbors, it is noise.
377+
# Leave keep_mask[idx] as False. It will be dropped forever.
371378
if not mnn_of_i:
372379
continue
373380

374381
classified_mnn = [n for n in mnn_of_i if labels[n] != -1]
382+
383+
# If no neighbors are classified yet, we must keep this point for
384+
# the next pass.
375385
if not classified_mnn:
386+
keep_mask[idx] = True
376387
continue
377388

378389
# Vote for the cluster with max intersection
@@ -382,18 +393,22 @@ def _assign_weak_points(self, mutual_neighbors, labels, cluster_parent):
382393
]
383394

384395
if not neighbor_roots:
396+
keep_mask[idx] = True
385397
continue
386398

387399
# Find majority vote
388400
unique_roots, counts = np.unique(neighbor_roots, return_counts=True)
389401
best_cluster_root = unique_roots[np.argmax(counts)]
390402

391-
labels[i] = best_cluster_root
403+
labels[sample_idx] = best_cluster_root
392404
newly_assigned_count += 1
393405

394406
if newly_assigned_count == 0:
395407
break
396408

409+
# Apply the mask to shrink the list for the next iteration.
410+
unclassified_indices = unclassified_indices[keep_mask]
411+
397412
def _flatten_union_find(self, labels, cluster_parent, n_samples):
398413
"""
399414
Resolves the Union-Find structure into a flat array of root representatives.

pixi.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)