diff --git a/pystac/extensions/base.py b/pystac/extensions/base.py index 826eef48a..0209feed4 100644 --- a/pystac/extensions/base.py +++ b/pystac/extensions/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations +import warnings from abc import ABC, abstractmethod from typing import ( Any, @@ -15,6 +17,75 @@ import pystac +# Exception registration code ported with modifications from xarray +# https://github.com/pydata/xarray/blob/master/xarray/core/extensions.py + + +class ExtensionRegistrationWarning(Warning): + """Warning for conflicts in extension registration.""" + + +class _CachedExtension: + """Custom property-like object (descriptor) for caching extensions.""" + + def __init__(self, name, extension): + self._name = name + self._extension = extension + + def __get__(self, obj, cls): + if obj is None: + # we're accessing the attribute of the class, i.e., Item.proj + return self._extension + + try: + cache = obj._cache + except AttributeError: + cache = obj._cache = {} + + try: + return cache[self._name] + except KeyError: + pass + + try: + extension_obj = self._extension.ext(obj) + except pystac.ExtensionNotImplemented as e: + e.args = ( + f"{e.args[0]}\nHint: to add a new extension to this pystac object use:" + f"``{self._extension.__name__}.add_to(obj)``", + ) + raise e + except AttributeError: + # __getattr__ on data object will swallow any AttributeErrors + # raised when initializing the extension, so we need to raise as + # something else: + raise RuntimeError(f"error initializing {self._name!r} extension.") + + cache[self._name] = extension_obj + return extension_obj + + def __set__(self, obj, *args): + raise NotImplementedError( + "To add a new extension to this pystac object use: " + f"``{self._extension.__name__}.{self._name}.add_to(obj)``" + ) + + +def register_extension(name, cls): + def decorator(extension): + if hasattr(cls, name): + warnings.warn( + f"registration of extension {extension!r} under name {name!r} for type " + "{cls!r} is overriding a preexisting attribute with the same name.", + ExtensionRegistrationWarning, + stacklevel=2, + ) + setattr(cls, name, _CachedExtension(name, extension)) + return extension + + return decorator + + class SummariesExtension: """Base class for extending the properties in :attr:`pystac.Collection.summaries` to include properties defined by a STAC Extension. diff --git a/pystac/extensions/eo.py b/pystac/extensions/eo.py index 8f1b4951b..e34dd7ab6 100644 --- a/pystac/extensions/eo.py +++ b/pystac/extensions/eo.py @@ -21,6 +21,7 @@ ExtensionManagementMixin, PropertiesExtension, SummariesExtension, + register_extension, ) from pystac.extensions.hooks import ExtensionHooks from pystac.serialization.identify import STACJSONDescription, STACVersionID @@ -275,6 +276,7 @@ def band_description(common_name: str) -> Optional[str]: return None +@register_extension("eo", pystac.Item) class EOExtension( Generic[T], PropertiesExtension, diff --git a/pystac/extensions/projection.py b/pystac/extensions/projection.py index eb5fb5b27..9c81dfe3f 100644 --- a/pystac/extensions/projection.py +++ b/pystac/extensions/projection.py @@ -6,10 +6,12 @@ from typing import Any, Dict, Generic, Iterable, List, Optional, TypeVar, Union, cast import pystac + from pystac.extensions.base import ( ExtensionManagementMixin, PropertiesExtension, SummariesExtension, + register_extension, ) from pystac.extensions.hooks import ExtensionHooks @@ -29,6 +31,7 @@ TRANSFORM_PROP: str = PREFIX + "transform" +@register_extension("proj", pystac.Item) class ProjectionExtension( Generic[T], PropertiesExtension,