Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cirro/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,11 @@ def upload_directory(directory: PathLike,
break


def download_directory(directory: str, files: List[str], s3_client: S3Client, bucket: str, prefix: str):
def download_directory(directory: str, files: List[str], s3_client: S3Client, bucket: str, prefix: str) -> List[Path]:
"""
@private
"""
local_paths = []
for file in files:
key = f'{prefix}/{file}'.lstrip('/')
local_path = Path(directory, file).expanduser()
Expand All @@ -202,6 +203,8 @@ def download_directory(directory: str, files: List[str], s3_client: S3Client, bu
s3_client.download_file(local_path=local_path,
bucket=bucket,
key=key)
local_paths.append(local_path)
return local_paths


def get_checksum(file: Path, checksum_name: str, chunk_size=1024 * 1024) -> str:
Expand Down
28 changes: 21 additions & 7 deletions cirro/sdk/file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gzip
from io import BytesIO, StringIO
from pathlib import Path
from typing import List

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -178,17 +179,22 @@ def read_bytes(self) -> BytesIO:
"""Get a generic BytesIO object representing the Data Portal File, to be passed into readers."""
return BytesIO(self._get())

def download(self, download_location: str = None):
"""Download the file to a local directory."""
def download(self, download_location: str = None) -> Path:
"""
Download the file to a local directory.

Returns:
Path to the downloaded file.
"""

if download_location is None:
raise DataPortalInputError("Must provide download location")

self._client.file.download_files(
return self._client.file.download_files(
self._file.access_context,
download_location,
[self.relative_path]
)
)[0]

def validate(self, local_path: PathLike):
"""
Expand Down Expand Up @@ -221,10 +227,18 @@ def is_valid(self, local_path: PathLike) -> bool:

class DataPortalFiles(DataPortalAssets[DataPortalFile]):
"""Collection of DataPortalFile objects."""

asset_name = "file"

def download(self, download_location: str = None) -> None:
"""Download the collection of files to a local directory."""
def download(self, download_location: str = None) -> List[Path]:
"""
Download the collection of files to a local directory.

Returns:
List of paths to downloaded files.
"""

local_paths = []
for f in self:
f.download(download_location)
local_paths.append(f.download(download_location))
return local_paths
6 changes: 4 additions & 2 deletions cirro/services/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,20 @@ def upload_files(self,
max_retries=self.transfer_retries
)

def download_files(self, access_context: FileAccessContext, directory: str, files: List[str]) -> None:
def download_files(self, access_context: FileAccessContext, directory: str, files: List[str]) -> List[Path]:
"""
Download a list of files to the specified directory

Args:
access_context (cirro.models.file.FileAccessContext): File access context, use class methods to generate
directory (str): download location
files (List[str]): relative path of files to download
Returns:
List of paths to downloaded files
"""
s3_client = self._generate_s3_client(access_context)

download_directory(
return download_directory(
directory,
files,
s3_client,
Expand Down
Loading