Skip to content

Commit 97d8c7c

Browse files
committed
Reducing umap memory footprint
1 parent 33cf793 commit 97d8c7c

File tree

2 files changed

+73
-24
lines changed

2 files changed

+73
-24
lines changed

src/hyrax/hyrax_default_config.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,12 @@ inference_dir = false
226226
# Number of data points used to fit the umap transform.
227227
fit_sample_size = 1024
228228

229+
# Save the fitted umap as a pickle file
230+
save_fit_umap = true
231+
232+
# Use multiprocessing during transforming to umap space (More memory intensive)
233+
parallel = false
234+
229235
# Name of the umap implementation to use
230236
name = "umap.UMAP"
231237

src/hyrax/verbs/umap.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import gc
12
import logging
3+
import os
24
import pickle
35
import warnings
46
from argparse import ArgumentParser, Namespace
57
from pathlib import Path
68
from typing import Optional, Union
79

810
import numpy as np
11+
import psutil
912

1013
from .verb_registry import Verb, hyrax_verb
1114

@@ -95,13 +98,22 @@ def _run(self, input_dir: Optional[Union[Path, str]] = None):
9598
# If the input to umap is not of the shape [samples,input_dims] we reshape the input accordingly
9699
data_sample = inference_results[index_choices].numpy().reshape((sample_size, -1))
97100

101+
self._log_memory_usage("Before fitting umap")
98102
logger.info("Fitting the UMAP")
99103
# Fit a single reducer on the sampled data
100104
self.reducer.fit(data_sample)
105+
self._log_memory_usage("After fitting umap")
101106

102107
# Save the reducer to our results directory
103-
with open(results_dir / "umap.pickle", "wb") as f:
104-
pickle.dump(self.reducer, f)
108+
if self.config["umap"]["save_fit_umap"]:
109+
logger.info("Saving fitted UMAP Reducer")
110+
with open(results_dir / "umap.pickle", "wb") as f:
111+
pickle.dump(self.reducer, f)
112+
113+
# Reclaim Memory
114+
del data_sample
115+
gc.collect()
116+
self._log_memory_usage("After Garbage Collection")
105117

106118
# Run all data through the reducer in batches, writing it out as we go.
107119
batch_size = self.config["data_loader"]["batch_size"]
@@ -110,36 +122,49 @@ def _run(self, input_dir: Optional[Union[Path, str]] = None):
110122
all_indexes = np.arange(0, total_length)
111123
all_ids = np.array(list(inference_results.ids()))
112124

113-
# Process pool to do all the transforms
114-
# Use 'spawn' context to safely create subprocesses after
115-
# OpenMP threads are initialized by the data loader
116-
# TODO: See discussion here https://github.com/lincc-frameworks/hyrax/pull/297
117-
# Consider getting rid of spawn if it becomes a major bottleneck
118-
with mp.get_context("spawn").Pool(processes=mp.cpu_count()) as pool:
119-
# Generator expression that gives a batch tuple composed of:
120-
# batch ids, inference results
121-
args = (
122-
(
123-
all_ids[batch_indexes],
124-
# We flatten all dimensions of the input array except the dimension
125-
# corresponding to batch elements. This ensures that all inputs to
126-
# the UMAP algorithm are flattend per input item in the batch
127-
inference_results[batch_indexes].reshape(len(batch_indexes), -1),
128-
)
129-
for batch_indexes in np.array_split(all_indexes, num_batches)
125+
# Generator expression that gives a batch tuple composed of:
126+
# batch ids, inference results
127+
args = (
128+
(
129+
all_ids[batch_indexes],
130+
# We flatten all dimensions of the input array except the dimension
131+
# corresponding to batch elements. This ensures that all inputs to
132+
# the UMAP algorithm are flattend per input item in the batch
133+
inference_results[batch_indexes].reshape(len(batch_indexes), -1),
130134
)
135+
for batch_indexes in np.array_split(all_indexes, num_batches)
136+
)
131137

132-
# iterate over the mapped results to write out the umapped points
133-
# imap returns results as they complete so writing should complete in parallel for large datasets
134-
for batch_ids, transformed_batch in tqdm(
135-
pool.imap(self._transform_batch, args),
138+
if self.config["umap"]["parallel"]:
139+
# Process pool loop
140+
# Use 'spawn' context to safely create subprocesses after
141+
# OpenMP threads are initialized by other processes in hyrax
142+
# Not using spawn causes the issue linked below
143+
# https://github.com/lincc-frameworks/hyrax/issues/291
144+
# TODO: Find more elegant solution than just using spawn
145+
with mp.get_context("spawn").Pool(processes=mp.cpu_count()) as pool:
146+
# iterate over the mapped results to write out the umapped points
147+
# imap returns results as they complete so writing should complete
148+
# in parallel for large datasets
149+
for batch_ids, transformed_batch in tqdm(
150+
pool.imap(self._transform_batch, args),
151+
desc="Creating lower dimensional representation using UMAP:",
152+
total=num_batches,
153+
):
154+
umap_results.write_batch(batch_ids, transformed_batch)
155+
else:
156+
# Sequential loop
157+
for batch_ids, batch in tqdm(
158+
args,
136159
desc="Creating lower dimensional representation using UMAP:",
137160
total=num_batches,
138161
):
139-
logger.debug("Writing a batch out async...")
162+
transformed_batch = self.reducer.transform(batch)
163+
self._log_memory_usage(f"During transformation of batch of shape {batch.shape}")
140164
umap_results.write_batch(batch_ids, transformed_batch)
141165

142166
umap_results.write_index()
167+
logger.info("Finished transforming all data through UMAP")
143168

144169
def _transform_batch(self, batch_tuple: tuple):
145170
"""Private helper to transform a single batch
@@ -163,3 +188,21 @@ def _transform_batch(self, batch_tuple: tuple):
163188
warnings.simplefilter(action="ignore", category=FutureWarning)
164189
logger.debug("Transforming a batch ...")
165190
return (batch_ids, self.reducer.transform(batch))
191+
192+
@staticmethod
193+
def _log_memory_usage(message: str = ""):
194+
"""
195+
Log the current resident set size (RSS) memory usage of the current process in gigabytes.
196+
197+
Parameters
198+
----------
199+
message : str, optional
200+
A descriptive message to include in the log output for context.
201+
202+
Notes
203+
-----
204+
This method is intended for debugging and performance monitoring.
205+
"""
206+
process = psutil.Process(os.getpid())
207+
mem_gb = process.memory_info().rss / 1024**3
208+
logger.debug(f"{message} | Memory usage: {mem_gb:.2f} GB")

0 commit comments

Comments
 (0)