Skip to content

Commit 446f553

Browse files
committed
feat: progress reporting
1 parent 74ba33a commit 446f553

File tree

15 files changed

+371
-39
lines changed

15 files changed

+371
-39
lines changed

.github/workflows/ci.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ jobs:
3232
with:
3333
python-version: ${{ matrix.python-version }}
3434
cache: "pip"
35+
- name: Update pip
36+
run: pip install -U pip
3537
- name: Install with development dependencies
3638
run: pip install .[cli,dev]
3739
- name: Check with pre-commit
@@ -48,6 +50,8 @@ jobs:
4850
with:
4951
python-version: "3.11"
5052
cache: "pip"
53+
- name: Update pip
54+
run: pip install -U pip
5155
- name: Install with development dependencies
5256
run: pip install .[cli,dev]
5357
- name: Install minimum versions of dependencies
@@ -66,6 +70,8 @@ jobs:
6670
with:
6771
python-version: "3.11"
6872
cache: "pip"
73+
- name: Update pip
74+
run: pip install -U pip
6975
- name: Install with development dependencies
7076
run: pip install .[cli,dev]
7177
- name: Test with coverage

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ repos:
1212
- pystac
1313
- pytest
1414
- types-aiofiles
15+
- types-python-dateutil
16+
- types-tqdm
1517
- repo: https://github.com/charliermarsh/ruff-pre-commit
1618
rev: "v0.0.278"
1719
hooks:

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1414
- `Client.from_config` and `Client.close` ([#46](https://github.com/stac-utils/stac-asset/pull/46))
1515
- Retry configuration for S3 ([#47](https://github.com/stac-utils/stac-asset/pull/47))
1616
- `Collection` download ([#50](https://github.com/stac-utils/stac-asset/pull/50))
17+
- Progress reporting ([#55](https://github.com/stac-utils/stac-asset/pull/55))
1718

1819
### Changed
1920

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ dependencies = [
2121
"aiobotocore>=2.5.0",
2222
"aiohttp>=3.8.4",
2323
"pystac>=1.7.3",
24+
"python-dateutil>=2.7.0",
2425
"yarl>=1.9.2",
2526
]
2627

2728
[project.optional-dependencies]
28-
cli = ["click~=8.1.5", "click-logging~=1.0.1"]
29+
cli = ["click~=8.1.5", "click-logging~=1.0.1", "tqdm~=4.65.1"]
2930
dev = [
3031
"black~=23.3",
3132
"mypy~=1.3",
@@ -35,6 +36,8 @@ dev = [
3536
"pytest-cov~=4.1",
3637
"ruff==0.0.282",
3738
"types-aiofiles~=23.1",
39+
"types-python-dateutil~=2.8.19",
40+
"types-tqdm~=4.65.0",
3841
]
3942
docs = ["pydata-sphinx-theme~=0.13", "sphinx~=7.0"]
4043

src/stac_asset/_cli.py

Lines changed: 138 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,36 @@
33
import logging
44
import os
55
import sys
6-
from typing import List, Optional, Union
6+
from asyncio import Queue
7+
from dataclasses import dataclass
8+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
79

810
import click
911
import click_logging
12+
import tqdm
1013
from pystac import Item, ItemCollection
1114

1215
from . import Config, functions
1316
from .config import DEFAULT_S3_MAX_ATTEMPTS, DEFAULT_S3_RETRY_MODE
17+
from .messages import (
18+
ErrorAssetDownload,
19+
FinishAssetDownload,
20+
OpenUrl,
21+
StartAssetDownload,
22+
WriteChunk,
23+
)
1424

1525
logger = logging.getLogger(__name__)
1626
click_logging.basic_config(logger)
1727

28+
# Needed until we drop Python 3.8
29+
if TYPE_CHECKING:
30+
AnyQueue = Queue[Any]
31+
Tqdm = tqdm.tqdm[Any]
32+
else:
33+
AnyQueue = Queue
34+
Tqdm = tqdm.tqdm
35+
1836

1937
@click.group()
2038
def cli() -> None:
@@ -111,6 +129,36 @@ def download(
111129
112130
$ stac-asset download -i asset-key-to-include item.json
113131
"""
132+
asyncio.run(
133+
download_async(
134+
href,
135+
directory,
136+
alternate_assets,
137+
include,
138+
exclude,
139+
file_name,
140+
quiet,
141+
s3_requester_pays,
142+
s3_retry_mode,
143+
s3_max_attempts,
144+
warn,
145+
)
146+
)
147+
148+
149+
async def download_async(
150+
href: Optional[str],
151+
directory: Optional[str],
152+
alternate_assets: List[str],
153+
include: List[str],
154+
exclude: List[str],
155+
file_name: Optional[str],
156+
quiet: bool,
157+
s3_requester_pays: bool,
158+
s3_retry_mode: str,
159+
s3_max_attempts: int,
160+
warn: bool,
161+
) -> None:
114162
config = Config(
115163
alternate_assets=alternate_assets,
116164
include=include,
@@ -125,39 +173,58 @@ def download(
125173
if href is None or href == "-":
126174
input_dict = json.load(sys.stdin)
127175
else:
128-
input_dict = json.loads(asyncio.run(read_file(href, config)))
176+
input_dict = json.loads(await read_file(href, config))
129177
if directory is None:
130-
directory = os.getcwd()
178+
directory_str = os.getcwd()
179+
else:
180+
directory_str = str(directory)
181+
182+
if quiet:
183+
queue = None
184+
else:
185+
queue = Queue()
131186

132187
type_ = input_dict.get("type")
133188
if type_ is None:
134-
print("ERROR: missing 'type' field on input dictionary", file=sys.stderr)
189+
if not quiet:
190+
print("ERROR: missing 'type' field on input dictionary", file=sys.stderr)
135191
sys.exit(1)
136192
elif type_ == "Feature":
137193
item = Item.from_dict(input_dict)
138194
if href:
139195
item.set_self_href(href)
140196
item.make_asset_hrefs_absolute()
141-
output: Union[Item, ItemCollection] = asyncio.run(
142-
functions.download_item(
197+
198+
async def download() -> Union[Item, ItemCollection]:
199+
return await functions.download_item(
143200
item,
144-
directory,
201+
directory_str,
145202
config=config,
203+
queue=queue,
146204
)
147-
)
205+
148206
elif type_ == "FeatureCollection":
149207
item_collection = ItemCollection.from_dict(input_dict)
150-
output = asyncio.run(
151-
functions.download_item_collection(
208+
209+
async def download() -> Union[Item, ItemCollection]:
210+
return await functions.download_item_collection(
152211
item_collection,
153-
directory,
212+
directory_str,
154213
config=config,
214+
queue=queue,
155215
)
156-
)
216+
157217
else:
158-
print(f"ERROR: unsupported 'type' field: {type_}", file=sys.stderr)
218+
if not quiet:
219+
print(f"ERROR: unsupported 'type' field: {type_}", file=sys.stderr)
159220
sys.exit(2)
160221

222+
task = asyncio.create_task(report_progress(queue))
223+
output = await download()
224+
if queue:
225+
await queue.put(None)
226+
await task
227+
161228
if not quiet:
162229
json.dump(output.to_dict(transform_hrefs=False), sys.stdout)
163230

@@ -170,3 +237,61 @@ async def read_file(href: str, config: Config) -> bytes:
170237
async for chunk in client.open_href(href):
171238
data += chunk
172239
return data
240+
241+
242+
async def report_progress(queue: Optional[AnyQueue]) -> None:
    """Draw a tqdm progress bar for each asset download reported on ``queue``.

    Messages are consumed until a ``None`` sentinel is received. Every bar
    that was opened is closed when this coroutine exits, whether it exits
    via the sentinel, an exception, or task cancellation.

    Args:
        queue: Queue of download messages, or ``None`` to disable progress
            reporting, in which case this returns immediately.
    """
    if queue is None:
        return
    # Keyed by asset href, which is how later messages refer to a download.
    downloads: Dict[str, Download] = {}
    try:
        while True:
            message = await queue.get()
            if message is None:
                # Sentinel: nothing more to report; bars are closed below.
                return
            if isinstance(message, StartAssetDownload):
                previous = downloads.get(message.href)
                if previous is not None:
                    # A repeated href would otherwise orphan the earlier bar,
                    # leaving it unclosed once it is replaced in the mapping.
                    previous.progress_bar.close()
                progress_bar = tqdm.tqdm(
                    position=len(downloads),
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    leave=False,
                )
                if message.item_id:
                    description = f"{message.item_id} [{message.key}]"
                else:
                    description = message.key
                progress_bar.set_description_str(description)
                downloads[message.href] = Download(
                    key=message.key,
                    item_id=message.item_id,
                    href=message.href,
                    path=str(message.path),
                    progress_bar=progress_bar,
                )
            elif isinstance(message, OpenUrl):
                download = downloads.get(str(message.url))
                # The total is only set when the message carries a size.
                if download is not None and message.size:
                    download.progress_bar.reset(total=message.size)
            elif isinstance(message, (FinishAssetDownload, ErrorAssetDownload)):
                download = downloads.get(message.href)
                if download is not None:
                    download.progress_bar.close()
            elif isinstance(message, WriteChunk):
                download = downloads.get(message.href)
                if download is not None:
                    download.progress_bar.update(message.size)
    finally:
        # Also runs on cancellation or error, so no bar is left open on the
        # terminal. Bars already closed above are simply closed again.
        for download in downloads.values():
            download.progress_bar.close()
289+
290+
291+
@dataclass
class Download:
    """Bookkeeping record for one asset download shown by the progress reporter."""

    # Asset key, used in the progress bar description.
    key: str
    # Id of the owning item, if any; prefixed to the description when set.
    item_id: Optional[str]
    # Asset href; also the key under which this record is stored.
    href: str
    # Download path, stored as a string.
    path: str
    # The tqdm bar that displays this download's progress.
    progress_bar: Tqdm

0 commit comments

Comments
 (0)