Skip to content

Commit 8a1896b

Browse files
authored
Clean up marimo progress bar (#8184)
* Removed extra "s" from draws in progress bar * Remove s from step_name * Removed second Draws column * Fix draw column: restore for Rich, deduplicate for marimo * Fixed test failure * Do not zero out final speed * Remove rounding for sampling <1s; simplify rich progress label * Remove guardrails on progbar stat columns; update column names
1 parent 6c48a0b commit 8a1896b

File tree

4 files changed

+25
-23
lines changed

4 files changed

+25
-23
lines changed

pymc/progress_bar/marimo_progress.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
def format_time(seconds: float) -> str:
2222
"""Format elapsed time as mm:ss or hh:mm:ss."""
23+
if seconds < 1:
24+
return f"{seconds:.1f}s"
2325
minutes, secs = divmod(int(seconds), 60)
2426
hours, minutes = divmod(minutes, 60)
2527
if hours > 0:
@@ -201,24 +203,26 @@ def _render_html(self) -> str:
201203
"""Generate HTML for all progress bars as a table with headers."""
202204
stat_keys = []
203205
if self.full_stats and self._task_state and self._task_state[0]["stats"]:
204-
stat_keys = list(self._task_state[0]["stats"].keys())
206+
stat_keys = [
207+
k for k in self._task_state[0]["stats"].keys() if k != self.step_name.lower()
208+
]
205209

206210
header_cells = ["Progress", self.step_name]
207211

208-
abbreviations = {
209-
"divergences": "Div",
210-
"diverging": "Div",
211-
"step_size": "Step",
212-
"tree_size": "Tree",
213-
"tree_depth": "Depth",
214-
"n_steps": "Steps",
215-
"energy_error": "E-err",
216-
"max_energy_error": "Max-E",
217-
"mean_tree_accept": "Accept",
218-
"scaling": "Scale",
212+
column_names = {
213+
"divergences": "Divergences",
214+
"diverging": "Divergences",
215+
"step_size": "Step size",
216+
"tree_size": "Grad evals",
217+
"tree_depth": "Tree depth",
218+
"n_steps": "Grad evals",
219+
"energy_error": "Energy error",
220+
"max_energy_error": "Max energy error",
221+
"mean_tree_accept": "Mean tree accept",
222+
"scaling": "Scaling",
219223
"tune": "Tune",
220224
}
221-
header_cells += [abbreviations.get(k, k[:6].capitalize()) for k in stat_keys]
225+
header_cells += [column_names.get(k, k.replace("_", " ").capitalize()) for k in stat_keys]
222226

223227
header_cells += ["Speed", "Elapsed"]
224228

pymc/progress_bar/progress.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class MCMCProgressBarManager(ProgressBarManager):
214214
Tracks progress via draw count with support for tuning phases.
215215
"""
216216

217-
step_name: str = "Draws"
217+
step_name: str = "Draw"
218218

219219
def __init__(
220220
self,
@@ -249,7 +249,7 @@ def __init__(
249249
)
250250

251251
progress_columns, progress_stats = step_method._progressbar_config(chains)
252-
progress_stats["draws"] = [0] * chains
252+
progress_stats["draw"] = [0] * chains
253253

254254
self.progress_stats = progress_stats
255255
self.update_stats_functions = step_method._make_progressbar_update_functions()
@@ -318,7 +318,7 @@ def update(self, chain_idx: int, is_last: bool, draw: int, tuning: bool, stats)
318318
chain_idx = 0
319319

320320
failing, all_step_stats = self._extract_stats(stats)
321-
all_step_stats["draws"] = draw + 1 if not self.combined_progress else draw
321+
all_step_stats["draw"] = draw + 1 if not self.combined_progress else draw
322322

323323
self._backend.update(
324324
task_id=chain_idx,

pymc/progress_bar/rich_progress.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def _create_progress_bar(
199199
columns += [
200200
TextColumn(
201201
"{task.fields[sampling_speed]:0.2f} {task.fields[speed_unit]}",
202-
table_column=Column("Sampling Speed", ratio=1),
202+
table_column=Column("Speed", ratio=1),
203203
),
204204
TimeElapsedColumn(table_column=Column("Elapsed", ratio=1)),
205205
TimeRemainingColumn(table_column=Column("Remaining", ratio=1)),
@@ -285,8 +285,6 @@ def update(
285285
self._progress.update(
286286
rich_task_id,
287287
completed=self.total if not self.combined else self.total * self.n_bars,
288-
sampling_speed=0,
289-
speed_unit="",
290288
failing=failing,
291289
refresh=True,
292290
**stats,

tests/progress_bar/test_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def capturing_init(self, *args, **kwargs):
6363
for chain in range(chains):
6464
task = _get_task(manager, chain)
6565
assert task.completed == task.total == total
66-
assert task.fields["draws"] == total
66+
assert task.fields["draw"] == total
6767

6868

6969
def test_mcmc_combined_bar_ends_at_total():
@@ -92,7 +92,7 @@ def capturing_init(self, *args, **kwargs):
9292
total = (draws + tune) * chains
9393
task = _get_task(manager, 0)
9494
assert task.completed == task.total == total
95-
assert task.fields["draws"] == total
95+
assert task.fields["draw"] == total
9696

9797

9898
def test_mcmc_draws_stat_shows_completed_count():
@@ -104,13 +104,13 @@ def test_mcmc_draws_stat_shows_completed_count():
104104

105105
with manager:
106106
manager.update(chain_idx=0, is_last=False, draw=0, tuning=True, stats=NUTS_DUMMY_STATS)
107-
assert _get_task(manager).fields["draws"] == 1
107+
assert _get_task(manager).fields["draw"] == 1
108108

109109
for i in range(1, 15):
110110
manager.update(
111111
chain_idx=0, is_last=i == 14, draw=i, tuning=i < 5, stats=NUTS_DUMMY_STATS
112112
)
113-
assert _get_task(manager).fields["draws"] == 15
113+
assert _get_task(manager).fields["draw"] == 15
114114

115115

116116
def test_smc_bar_starts_at_zero_ends_at_one(imh_kernel):

0 commit comments

Comments
 (0)