1+ import gc
12import logging
3+ import os
24import pickle
35import warnings
46from argparse import ArgumentParser , Namespace
57from pathlib import Path
68from typing import Optional , Union
79
810import numpy as np
11+ import psutil
912
1013from .verb_registry import Verb , hyrax_verb
1114
@@ -95,13 +98,22 @@ def _run(self, input_dir: Optional[Union[Path, str]] = None):
9598 # If the input to umap is not of the shape [samples,input_dims] we reshape the input accordingly
9699 data_sample = inference_results [index_choices ].numpy ().reshape ((sample_size , - 1 ))
97100
101+ self ._log_memory_usage ("Before fitting umap" )
98102 logger .info ("Fitting the UMAP" )
99103 # Fit a single reducer on the sampled data
100104 self .reducer .fit (data_sample )
105+ self ._log_memory_usage ("After fitting umap" )
101106
102107 # Save the reducer to our results directory
103- with open (results_dir / "umap.pickle" , "wb" ) as f :
104- pickle .dump (self .reducer , f )
108+ if self .config ["umap" ]["save_fit_umap" ]:
109+ logger .info ("Saving fitted UMAP Reducer" )
110+ with open (results_dir / "umap.pickle" , "wb" ) as f :
111+ pickle .dump (self .reducer , f )
112+
113+ # Reclaim Memory
114+ del data_sample
115+ gc .collect ()
116+ self ._log_memory_usage ("After Garbage Collection" )
105117
106118 # Run all data through the reducer in batches, writing it out as we go.
107119 batch_size = self .config ["data_loader" ]["batch_size" ]
@@ -110,36 +122,49 @@ def _run(self, input_dir: Optional[Union[Path, str]] = None):
110122 all_indexes = np .arange (0 , total_length )
111123 all_ids = np .array (list (inference_results .ids ()))
112124
113- # Process pool to do all the transforms
114- # Use 'spawn' context to safely create subprocesses after
115- # OpenMP threads are initialized by data loader
116- # TODO: See discussion here https://github.com/lincc-frameworks/hyrax/pull/297
117- # Consider getting rid of spawn if it becomes a major bottleneck
118- with mp .get_context ("spawn" ).Pool (processes = mp .cpu_count ()) as pool :
119- # Generator expression that gives a batch tuple composed of:
120- # batch ids, inference results
121- args = (
122- (
123- all_ids [batch_indexes ],
124- # We flatten all dimensions of the input array except the dimension
125- # corresponding to batch elements. This ensures that all inputs to
126- # the UMAP algorithm are flattend per input item in the batch
127- inference_results [batch_indexes ].reshape (len (batch_indexes ), - 1 ),
128- )
129- for batch_indexes in np .array_split (all_indexes , num_batches )
125+ # Generator expression that gives a batch tuple composed of:
126+ # batch ids, inference results
127+ args = (
128+ (
129+ all_ids [batch_indexes ],
130+ # We flatten all dimensions of the input array except the dimension
131+ # corresponding to batch elements. This ensures that all inputs to
132+ # the UMAP algorithm are flattened per input item in the batch
133+ inference_results [batch_indexes ].reshape (len (batch_indexes ), - 1 ),
130134 )
135+ for batch_indexes in np .array_split (all_indexes , num_batches )
136+ )
131137
132- # iterate over the mapped results to write out the umapped points
133- # imap returns results as they complete so writing should complete in parallel for large datasets
134- for batch_ids , transformed_batch in tqdm (
135- pool .imap (self ._transform_batch , args ),
138+ if self .config ["umap" ]["parallel" ]:
139+ # Process pool loop
140+ # Use 'spawn' context to safely create subprocesses after
141+ # OpenMP threads are being opened by other processes in hyrax
142+ # Not using spawn causes the issue linked below
143+ # https://github.com/lincc-frameworks/hyrax/issues/291
144+ # TODO: Find more elegant solution than just using spawn
145+ with mp .get_context ("spawn" ).Pool (processes = mp .cpu_count ()) as pool :
146+ # iterate over the mapped results to write out the umapped points
147+ # imap returns results as they complete so writing should complete
148+ # in parallel for large datasets
149+ for batch_ids , transformed_batch in tqdm (
150+ pool .imap (self ._transform_batch , args ),
151+ desc = "Creating lower dimensional representation using UMAP:" ,
152+ total = num_batches ,
153+ ):
154+ umap_results .write_batch (batch_ids , transformed_batch )
155+ else :
156+ # Sequential loop
157+ for batch_ids , batch in tqdm (
158+ args ,
136159 desc = "Creating lower dimensional representation using UMAP:" ,
137160 total = num_batches ,
138161 ):
139- logger .debug ("Writing a batch out async..." )
162+ transformed_batch = self .reducer .transform (batch )
163+ self ._log_memory_usage (f"During transformation of batch of shape { batch .shape } " )
140164 umap_results .write_batch (batch_ids , transformed_batch )
141165
142166 umap_results .write_index ()
167+ logger .info ("Finished transforming all data through UMAP" )
143168
144169 def _transform_batch (self , batch_tuple : tuple ):
145170 """Private helper to transform a single batch
@@ -163,3 +188,21 @@ def _transform_batch(self, batch_tuple: tuple):
163188 warnings .simplefilter (action = "ignore" , category = FutureWarning )
164189 logger .debug ("Transforming a batch ..." )
165190 return (batch_ids , self .reducer .transform (batch ))
191+
@staticmethod
def _log_memory_usage(message: str = ""):
    """
    Log the resident set size (RSS) of the current process in gigabytes.

    The measurement is emitted at DEBUG level through the module logger.
    When DEBUG is not enabled for that logger, the function returns
    immediately without querying the operating system, so calling it from
    per-batch loops adds negligible overhead in normal runs.

    Parameters
    ----------
    message : str, optional
        A descriptive message to include in the log output for context.

    Notes
    -----
    This method is intended for debugging and performance monitoring.
    """
    # Skip the psutil query entirely when the record would be discarded.
    if not logger.isEnabledFor(logging.DEBUG):
        return

    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    # Convert bytes to gigabytes using 1024 ** 3 bytes per GB.
    mem_gb = rss_bytes / 1024**3
    # Lazy %-style arguments: formatting happens only when the record is emitted.
    logger.debug("%s | Memory usage: %.2f GB", message, mem_gb)
0 commit comments