Skip to content

Commit c81b41c

Browse files
committed
Improve tqdm progress bar when sampling tables of known sizes
1 parent 9e2701e commit c81b41c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

subsetter/sampler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,7 @@ def __init__(self, source: DatabaseConfig, config: SamplerConfig) -> None:
692692
self.compact_columns: Dict[Tuple[str, str], Set[str]] = {}
693693
self.temp_tables = TempTableCreator()
694694
self.passthrough_tables: Set[str] = set()
695+
self.cached_table_sizes: Dict[Tuple[str, str], int] = {}
695696

696697
def sample(
697698
self,
@@ -866,6 +867,7 @@ def _materialize_tables(
866867
primary_key=table.primary_key,
867868
)
868869
)
870+
self.cached_table_sizes[(schema, table_name)] = rowcount
869871
LOGGER.info(
870872
"Materialized %d rows for %s.%s in temporary table",
871873
rowcount,
@@ -1005,9 +1007,9 @@ def _copy_results(
10051007

10061008
rows = 0
10071009

1008-
def _count_rows(result):
1010+
def _count_rows(result, total: Optional[int]):
10091011
nonlocal rows
1010-
for row in tqdm(result, desc="row progress", unit="rows"):
1012+
for row in tqdm(result, total=total, desc="row progress", unit="rows"):
10111013
# result_processor
10121014
rows += 1
10131015
yield row
@@ -1017,7 +1019,7 @@ def _count_rows(result):
10171019
schema,
10181020
table_name,
10191021
columns,
1020-
_count_rows(result),
1022+
_count_rows(result, self.cached_table_sizes.get((schema, table_name))),
10211023
filter_view=filter_view,
10221024
multiplier=(
10231025
1

0 commit comments

Comments
 (0)