Skip to content

Commit 2207be8

Browse files
authored
Merge pull request #218 from lucasimi/avoid-useless-recomputations
Setting explicit random_state to lens functions
2 parents a5c1810 + 9f3875b commit 2207be8

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

app/streamlit_app.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@
118118
)
119119

120120

121+
logger = st.logger.get_logger(__name__)
122+
123+
121124
def _check_limits_mapper_graph(mapper_graph):
122125
if LIMITS_ENABLED:
123126
num_nodes = mapper_graph.number_of_nodes()
@@ -298,22 +301,34 @@ def mapper_lens_input_section(X):
298301
value=2,
299302
min_value=1,
300303
)
304+
pca_random_state = st.number_input(
305+
"PCA random state",
306+
value=VD_SEED,
307+
)
301308
_, n_feats = X.shape
302309
if pca_n > n_feats:
303310
lens = X
304311
else:
305-
lens = PCA(n_components=pca_n).fit_transform(X)
312+
lens = PCA(n_components=pca_n, random_state=pca_random_state).fit_transform(
313+
X
314+
)
306315
elif lens_type == V_LENS_UMAP:
307316
umap_n = st.number_input(
308317
"UMAP Components",
309318
value=2,
310319
min_value=1,
311320
)
321+
umap_random_state = st.number_input(
322+
"UMAP random state",
323+
value=VD_SEED,
324+
)
312325
_, n_feats = X.shape
313326
if umap_n > n_feats:
314327
lens = X
315328
else:
316-
lens = UMAP(n_components=umap_n).fit_transform(X)
329+
lens = UMAP(
330+
n_components=umap_n, random_state=umap_random_state
331+
).fit_transform(X)
317332
return lens
318333

319334

@@ -492,6 +507,7 @@ def mapper_clustering_input_section():
492507
show_spinner="Computing Mapper",
493508
)
494509
def compute_mapper(mapper, X, y):
510+
logger.info("Generating Mapper graph")
495511
mapper_graph = mapper.fit_transform(X, y)
496512
return mapper_graph
497513

@@ -599,6 +615,7 @@ def plot_color_input_section(df_X, df_y):
599615
)
600616
def compute_mapper_plot(mapper_graph, dim, seed, iterations):
601617
_check_limits_mapper_graph(mapper_graph)
618+
logger.info("Generating Mapper plot")
602619
mapper_plot = MapperPlot(
603620
mapper_graph,
604621
dim,
@@ -640,6 +657,7 @@ def mapper_plot_section(mapper_graph):
640657
def compute_mapper_fig(
641658
mapper_plot, colors, node_size, cmap, _agg, agg_name, colors_feat
642659
):
660+
logger.info("Generating Mapper figure")
643661
mapper_fig = mapper_plot.plot_plotly(
644662
colors,
645663
node_size=node_size,

0 commit comments

Comments
 (0)