33import logging
44import os
55import sys
6- from typing import List , Optional , Union
6+ from asyncio import Queue
7+ from dataclasses import dataclass
8+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
79
810import click
911import click_logging
12+ import tqdm
1013from pystac import Item , ItemCollection
1114
1215from . import Config , functions
1316from .config import DEFAULT_S3_MAX_ATTEMPTS , DEFAULT_S3_RETRY_MODE
17+ from .messages import (
18+ ErrorAssetDownload ,
19+ FinishAssetDownload ,
20+ OpenUrl ,
21+ StartAssetDownload ,
22+ WriteChunk ,
23+ )
1424
1525logger = logging .getLogger (__name__ )
1626click_logging .basic_config (logger )
1727
28+ # Needed until we drop Python 3.8
29+ if TYPE_CHECKING :
30+ AnyQueue = Queue [Any ]
31+ Tqdm = tqdm .tqdm [Any ]
32+ else :
33+ AnyQueue = Queue
34+ Tqdm = tqdm .tqdm
35+
1836
1937@click .group ()
2038def cli () -> None :
@@ -111,6 +129,36 @@ def download(
111129
112130 $ stac-asset download -i asset-key-to-include item.json
113131 """
132+ asyncio .run (
133+ download_async (
134+ href ,
135+ directory ,
136+ alternate_assets ,
137+ include ,
138+ exclude ,
139+ file_name ,
140+ quiet ,
141+ s3_requester_pays ,
142+ s3_retry_mode ,
143+ s3_max_attempts ,
144+ warn ,
145+ )
146+ )
147+
148+
149+ async def download_async (
150+ href : Optional [str ],
151+ directory : Optional [str ],
152+ alternate_assets : List [str ],
153+ include : List [str ],
154+ exclude : List [str ],
155+ file_name : Optional [str ],
156+ quiet : bool ,
157+ s3_requester_pays : bool ,
158+ s3_retry_mode : str ,
159+ s3_max_attempts : int ,
160+ warn : bool ,
161+ ) -> None :
114162 config = Config (
115163 alternate_assets = alternate_assets ,
116164 include = include ,
@@ -125,39 +173,58 @@ def download(
125173 if href is None or href == "-" :
126174 input_dict = json .load (sys .stdin )
127175 else :
128- input_dict = json .loads (asyncio . run ( read_file (href , config ) ))
176+ input_dict = json .loads (await read_file (href , config ))
129177 if directory is None :
130- directory = os .getcwd ()
178+ directory_str = os .getcwd ()
179+ else :
180+ directory_str = str (directory )
181+
182+ if quiet :
183+ queue = None
184+ else :
185+ queue = Queue ()
131186
132187 type_ = input_dict .get ("type" )
133188 if type_ is None :
134- print ("ERROR: missing 'type' field on input dictionary" , file = sys .stderr )
189+ if not quiet :
190+ print ("ERROR: missing 'type' field on input dictionary" , file = sys .stderr )
135191 sys .exit (1 )
136192 elif type_ == "Feature" :
137193 item = Item .from_dict (input_dict )
138194 if href :
139195 item .set_self_href (href )
140196 item .make_asset_hrefs_absolute ()
141- output : Union [Item , ItemCollection ] = asyncio .run (
142- functions .download_item (
197+
198+ async def download () -> Union [Item , ItemCollection ]:
199+ return await functions .download_item (
143200 item ,
144- directory ,
201+ directory_str ,
145202 config = config ,
203+ queue = queue ,
146204 )
147- )
205+
148206 elif type_ == "FeatureCollection" :
149207 item_collection = ItemCollection .from_dict (input_dict )
150- output = asyncio .run (
151- functions .download_item_collection (
208+
209+ async def download () -> Union [Item , ItemCollection ]:
210+ return await functions .download_item_collection (
152211 item_collection ,
153- directory ,
212+ directory_str ,
154213 config = config ,
214+ queue = queue ,
155215 )
156- )
216+
157217 else :
158- print (f"ERROR: unsupported 'type' field: { type_ } " , file = sys .stderr )
218+ if not quiet :
219+ print (f"ERROR: unsupported 'type' field: { type_ } " , file = sys .stderr )
159220 sys .exit (2 )
160221
222+ task = asyncio .create_task (report_progress (queue ))
223+ output = await download ()
224+ if queue :
225+ await queue .put (None )
226+ await task
227+
161228 if not quiet :
162229 json .dump (output .to_dict (transform_hrefs = False ), sys .stdout )
163230
@@ -170,3 +237,61 @@ async def read_file(href: str, config: Config) -> bytes:
170237 async for chunk in client .open_href (href ):
171238 data += chunk
172239 return data
240+
241+
242+ async def report_progress (queue : Optional [AnyQueue ]) -> None :
243+ if queue is None :
244+ return
245+ downloads : Dict [str , Download ] = dict ()
246+ while True :
247+ message = await queue .get ()
248+ if isinstance (message , StartAssetDownload ):
249+ progress_bar = tqdm .tqdm (
250+ position = len (downloads ),
251+ unit = "B" ,
252+ unit_scale = True ,
253+ unit_divisor = 1024 ,
254+ leave = False ,
255+ )
256+ if message .item_id :
257+ description = f"{ message .item_id } [{ message .key } ]"
258+ else :
259+ description = message .key
260+ progress_bar .set_description_str (description )
261+ downloads [message .href ] = Download (
262+ key = message .key ,
263+ item_id = message .item_id ,
264+ href = message .href ,
265+ path = str (message .path ),
266+ progress_bar = progress_bar ,
267+ )
268+ elif isinstance (message , OpenUrl ):
269+ download = downloads .get (str (message .url ))
270+ if download :
271+ if message .size :
272+ download .progress_bar .reset (total = message .size )
273+ elif isinstance (message , FinishAssetDownload ):
274+ download = downloads .get (message .href )
275+ if download :
276+ download .progress_bar .close ()
277+ elif isinstance (message , ErrorAssetDownload ):
278+ download = downloads .get (message .href )
279+ if download :
280+ download .progress_bar .close ()
281+ elif isinstance (message , WriteChunk ):
282+ download = downloads .get (message .href )
283+ if download :
284+ download .progress_bar .update (message .size )
285+ elif message is None :
286+ for download in downloads .values ():
287+ download .progress_bar .close ()
288+ return
289+
290+
291+ @dataclass
292+ class Download :
293+ key : str
294+ item_id : Optional [str ]
295+ href : str
296+ path : str
297+ progress_bar : Tqdm
0 commit comments