Skip to content

Update file extension and add Link.ext #1265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pystac/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def validate_owner_has_extension(
@classmethod
def ensure_owner_has_extension(
cls,
asset: pystac.Asset | AssetDefinition,
asset_or_link: pystac.Asset | AssetDefinition | pystac.Link,
add_if_missing: bool = False,
) -> None:
"""Given an :class:`~pystac.Asset`, checks if the asset's owner has this
Expand All @@ -206,15 +206,15 @@ def ensure_owner_has_extension(
STACError : If ``add_if_missing`` is ``True`` and ``asset.owner`` is
``None``.
"""
if asset.owner is None:
if asset_or_link.owner is None:
if add_if_missing:
raise pystac.STACError(
"Attempted to use add_if_missing=True for an Asset or ItemAsset "
f"Attempted to use add_if_missing=True for a {type(asset_or_link)} "
"with no owner. Use .set_owner or set add_if_missing=False."
)
else:
return
return cls.ensure_has_extension(cast(S, asset.owner), add_if_missing)
return cls.ensure_has_extension(cast(S, asset_or_link.owner), add_if_missing)

@classmethod
def validate_has_extension(cls, obj: S, add_if_missing: bool = False) -> None:
Expand Down
58 changes: 36 additions & 22 deletions pystac/extensions/ext.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeVar, cast

from pystac import Asset, Catalog, Collection, Item, STACError
from pystac import Asset, Catalog, Collection, Item, Link, STACError
from pystac.extensions.classification import ClassificationExtension
from pystac.extensions.datacube import DatacubeExtension
from pystac.extensions.eo import EOExtension
Expand All @@ -22,7 +22,8 @@
from pystac.extensions.view import ViewExtension
from pystac.extensions.xarray_assets import XarrayAssetsExtension

T = TypeVar("T", Asset, AssetDefinition)
T = TypeVar("T", Asset, AssetDefinition, Link)
U = TypeVar("U", Asset, AssetDefinition)

EXTENSION_NAMES = Literal[
"classification",
Expand Down Expand Up @@ -196,14 +197,14 @@ def xarray(self) -> XarrayAssetsExtension[Item]:
return XarrayAssetsExtension.ext(self.stac_object)


class _AssetExt(Generic[T]):
class _AssetsExt(Generic[T]):
stac_object: T

def has(self, name: EXTENSION_NAMES) -> bool:
if self.stac_object.owner is None:
raise STACError(
f"Attempted to add extension='{name}' for an Asset with no owner. "
"Use Asset.set_owner and then try to add the extension again."
f"Attempted to add extension='{name}' for an object with no owner. "
"Use `.set_owner` and then try to add the extension again."
)
else:
return cast(
Expand All @@ -213,67 +214,71 @@ def has(self, name: EXTENSION_NAMES) -> bool:
def add(self, name: EXTENSION_NAMES) -> None:
if self.stac_object.owner is None:
raise STACError(
f"Attempted to add extension='{name}' for an Asset with no owner. "
"Use Asset.set_owner and then try to add the extension again."
f"Attempted to add extension='{name}' for an object with no owner. "
"Use `.set_owner` and then try to add the extension again."
)
else:
_get_class_by_name(name).add_to(self.stac_object.owner)

def remove(self, name: EXTENSION_NAMES) -> None:
if self.stac_object.owner is None:
raise STACError(
f"Attempted to remove extension='{name}' for an Asset with no owner. "
"Use Asset.set_owner and then try to remove the extension again."
f"Attempted to remove extension='{name}' for an object with no owner. "
"Use `.set_owner` and then try to remove the extension again."
)
else:
_get_class_by_name(name).remove_from(self.stac_object.owner)


class _AssetExt(_AssetsExt[U]):
stac_object: U

@property
def classification(self) -> ClassificationExtension[T]:
def classification(self) -> ClassificationExtension[U]:
return ClassificationExtension.ext(self.stac_object)

@property
def cube(self) -> DatacubeExtension[T]:
def cube(self) -> DatacubeExtension[U]:
return DatacubeExtension.ext(self.stac_object)

@property
def eo(self) -> EOExtension[T]:
def eo(self) -> EOExtension[U]:
return EOExtension.ext(self.stac_object)

@property
def pc(self) -> PointcloudExtension[T]:
def pc(self) -> PointcloudExtension[U]:
return PointcloudExtension.ext(self.stac_object)

@property
def proj(self) -> ProjectionExtension[T]:
def proj(self) -> ProjectionExtension[U]:
return ProjectionExtension.ext(self.stac_object)

@property
def raster(self) -> RasterExtension[T]:
def raster(self) -> RasterExtension[U]:
return RasterExtension.ext(self.stac_object)

@property
def sar(self) -> SarExtension[T]:
def sar(self) -> SarExtension[U]:
return SarExtension.ext(self.stac_object)

@property
def sat(self) -> SatExtension[T]:
def sat(self) -> SatExtension[U]:
return SatExtension.ext(self.stac_object)

@property
def storage(self) -> StorageExtension[T]:
def storage(self) -> StorageExtension[U]:
return StorageExtension.ext(self.stac_object)

@property
def table(self) -> TableExtension[T]:
def table(self) -> TableExtension[U]:
return TableExtension.ext(self.stac_object)

@property
def version(self) -> BaseVersionExtension[T]:
def version(self) -> BaseVersionExtension[U]:
return BaseVersionExtension.ext(self.stac_object)

@property
def view(self) -> ViewExtension[T]:
def view(self) -> ViewExtension[U]:
return ViewExtension.ext(self.stac_object)


Expand All @@ -282,7 +287,7 @@ class AssetExt(_AssetExt[Asset]):
stac_object: Asset

@property
def file(self) -> FileExtension:
def file(self) -> FileExtension[Asset]:
return FileExtension.ext(self.stac_object)

@property
Expand All @@ -297,3 +302,12 @@ def xarray(self) -> XarrayAssetsExtension[Asset]:
@dataclass
class ItemAssetExt(_AssetExt[AssetDefinition]):
stac_object: AssetDefinition


@dataclass
class LinkExt(_AssetsExt[Link]):
stac_object: Link

@property
def file(self) -> FileExtension[Link]:
return FileExtension.ext(self.stac_object)
Loading