Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Any,
Callable,
Dict,
Optional,
Union,
)
from urllib.parse import urlparse
Expand Down Expand Up @@ -128,11 +129,11 @@ def s3v4_rest_signer(properties: Properties, request: "AWSRequest", **_: Any) ->
SIGNERS: Dict[str, Callable[[Properties, "AWSRequest"], "AWSRequest"]] = {"S3V4RestSigner": s3v4_rest_signer}


def _file(_: Properties) -> LocalFileSystem:
def _file(_: Properties, __: Optional[str]) -> LocalFileSystem:
return LocalFileSystem(auto_mkdir=True)


def _s3(properties: Properties) -> AbstractFileSystem:
def _s3(properties: Properties, netloc: Optional[str]) -> AbstractFileSystem:
from s3fs import S3FileSystem

client_kwargs = {
Expand Down Expand Up @@ -179,7 +180,7 @@ def _s3(properties: Properties) -> AbstractFileSystem:
return fs


def _gs(properties: Properties) -> AbstractFileSystem:
def _gs(properties: Properties, netloc: Optional[str]) -> AbstractFileSystem:
# https://gcsfs.readthedocs.io/en/latest/api.html#gcsfs.core.GCSFileSystem
from gcsfs import GCSFileSystem

Expand All @@ -197,20 +198,25 @@ def _gs(properties: Properties) -> AbstractFileSystem:
)


def _adls(properties: Properties) -> AbstractFileSystem:
def _adls(properties: Properties, netloc: Optional[str]) -> AbstractFileSystem:
# https://fsspec.github.io/adlfs/api/

from adlfs import AzureBlobFileSystem
from azure.core.credentials import AccessToken
from azure.core.credentials_async import AsyncTokenCredential

for key, sas_token in {
key.replace(f"{ADLS_SAS_TOKEN}.", ""): value for key, value in properties.items() if key.startswith(ADLS_SAS_TOKEN)
}.items():
if ADLS_ACCOUNT_NAME not in properties:
properties[ADLS_ACCOUNT_NAME] = key.split(".")[0]
if ADLS_SAS_TOKEN not in properties:
properties[ADLS_SAS_TOKEN] = sas_token
# https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction-abfs-uri#uri-syntax
if netloc:
account_uri = netloc.split("@")[-1]
else:
account_uri = None

if not properties.get(ADLS_ACCOUNT_NAME) and account_uri:
properties[ADLS_ACCOUNT_NAME] = account_uri.split(".")[0]

# Fixes https://github.com/apache/iceberg-python/issues/1146
if not properties.get(ADLS_SAS_TOKEN) and account_uri:
properties[ADLS_SAS_TOKEN] = properties.get(f"{ADLS_SAS_TOKEN}.{account_uri}")

class StaticTokenCredential(AsyncTokenCredential):
_DEFAULT_EXPIRY_SECONDS = 3600
Expand Down Expand Up @@ -243,7 +249,7 @@ async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
)


def _hf(properties: Properties) -> AbstractFileSystem:
def _hf(properties: Properties, netloc: Optional[str]) -> AbstractFileSystem:
from huggingface_hub import HfFileSystem

return HfFileSystem(
Expand Down Expand Up @@ -368,7 +374,7 @@ class FsspecFileIO(FileIO):
def __init__(self, properties: Properties):
self._scheme_to_fs = {}
self._scheme_to_fs.update(SCHEME_TO_FS)
self.get_fs: Callable[[str], AbstractFileSystem] = lru_cache(self._get_fs)
self.get_fs: Callable[[str, Optional[str]], AbstractFileSystem] = lru_cache(self._get_fs)
super().__init__(properties=properties)

def new_input(self, location: str) -> FsspecInputFile:
Expand All @@ -381,7 +387,7 @@ def new_input(self, location: str) -> FsspecInputFile:
FsspecInputFile: An FsspecInputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
return FsspecInputFile(location=location, fs=fs)

def new_output(self, location: str) -> FsspecOutputFile:
Expand All @@ -394,7 +400,7 @@ def new_output(self, location: str) -> FsspecOutputFile:
FsspecOutputFile: An FsspecOutputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
return FsspecOutputFile(location=location, fs=fs)

def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
Expand All @@ -411,14 +417,14 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
str_location = location

uri = urlparse(str_location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
fs.rm(str_location)

def _get_fs(self, scheme: str) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme."""
def _get_fs(self, scheme: str, netloc: Optional[str] = None) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme and netloc."""
if scheme not in self._scheme_to_fs:
raise ValueError(f"No registered filesystem for scheme: {scheme}")
return self._scheme_to_fs[scheme](self.properties)
return self._scheme_to_fs[scheme](self.properties, netloc)

def __getstate__(self) -> Dict[str, Any]:
"""Create a dictionary of the FsSpecFileIO fields used when pickling."""
Expand Down