diff --git a/CHANGELOG.md b/CHANGELOG.md index c3a1acb04..170b8cf80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,9 @@ - Bug in `pystac.serialization.identify_stac_object_type` where invalid objects with `stac_version == 1.0.0` were incorrectly identified as Catalogs ([#487](https://github.com/stac-utils/pystac/pull/487)) +- `Link` constructor classes (e.g. `Link.from_dict`, `Link.canonical`, etc.) now return + the calling class instead of always returning the `Link` class + ([#512](https://github.com/stac-utils/pystac/pull/512)) ### Removed diff --git a/pystac/asset.py b/pystac/asset.py index 63d281838..564c8b1d4 100644 --- a/pystac/asset.py +++ b/pystac/asset.py @@ -127,7 +127,8 @@ def clone(self) -> "Asset": Returns: Asset: The clone of this asset. """ - return Asset( + cls = self.__class__ + return cls( href=self.href, title=self.title, description=self.description, @@ -139,8 +140,8 @@ def clone(self) -> "Asset": def __repr__(self) -> str: return "".format(self.href) - @staticmethod - def from_dict(d: Dict[str, Any]) -> "Asset": + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "Asset": """Constructs an Asset from a dict. Returns: @@ -156,7 +157,7 @@ def from_dict(d: Dict[str, Any]) -> "Asset": if any(d): properties = d - return Asset( + return cls( href=href, media_type=media_type, title=title, diff --git a/pystac/catalog.py b/pystac/catalog.py index afd9ff455..131d86857 100644 --- a/pystac/catalog.py +++ b/pystac/catalog.py @@ -471,7 +471,8 @@ def to_dict(self, include_self_link: bool = True) -> Dict[str, Any]: return d def clone(self) -> "Catalog": - clone = Catalog( + cls = self.__class__ + clone = cls( id=self.id, description=self.description, title=self.title, diff --git a/pystac/collection.py b/pystac/collection.py index e87e72ba6..e858d7e20 100644 --- a/pystac/collection.py +++ b/pystac/collection.py @@ -638,7 +638,8 @@ def to_dict(self, include_self_link: bool = True) -> Dict[str, Any]: return d def clone(self) -> "Collection": - clone = Collection( + cls = self.__class__ + clone = cls( id=self.id, description=self.description, extent=self.extent.clone(), diff --git a/pystac/item.py b/pystac/item.py index 2aba89b65..7f561f78f 100644 --- a/pystac/item.py +++ b/pystac/item.py @@ -885,7 +885,8 @@ def to_dict(self, include_self_link: bool = True) -> Dict[str, Any]: return d def clone(self) -> "Item": - clone = Item( + cls = self.__class__ + clone = cls( id=self.id, geometry=deepcopy(self.geometry), bbox=copy(self.bbox), diff --git a/pystac/link.py b/pystac/link.py index 514dd847b..c7dbe2b4a 100644 --- a/pystac/link.py +++ b/pystac/link.py @@ -269,16 +269,16 @@ def clone(self) -> "Link": Returns: Link: The cloned link. """ - - return Link( + cls = self.__class__ + return cls( rel=self.rel, target=self.target, media_type=self.media_type, title=self.title, ) - @staticmethod - def from_dict(d: Dict[str, Any]) -> "Link": + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "Link": """Deserializes a Link from a dict. Args: @@ -297,7 +297,7 @@ def from_dict(d: Dict[str, Any]) -> "Link": if any(d): properties = d - return Link( + return cls( rel=rel, target=href, media_type=media_type, @@ -305,47 +305,48 @@ def from_dict(d: Dict[str, Any]) -> "Link": properties=properties, ) - @staticmethod - def root(c: "Catalog_Type") -> "Link": + @classmethod + def root(cls, c: "Catalog_Type") -> "Link": """Creates a link to a root Catalog or Collection.""" - return Link(pystac.RelType.ROOT, c, media_type=pystac.MediaType.JSON) + return cls(pystac.RelType.ROOT, c, media_type=pystac.MediaType.JSON) - @staticmethod - def parent(c: "Catalog_Type") -> "Link": + @classmethod + def parent(cls, c: "Catalog_Type") -> "Link": """Creates a link to a parent Catalog or Collection.""" - return Link(pystac.RelType.PARENT, c, media_type=pystac.MediaType.JSON) + return cls(pystac.RelType.PARENT, c, media_type=pystac.MediaType.JSON) - @staticmethod - def collection(c: "Collection_Type") -> "Link": + @classmethod + def collection(cls, c: "Collection_Type") -> "Link": """Creates a link to an item's Collection.""" - return Link(pystac.RelType.COLLECTION, c, media_type=pystac.MediaType.JSON) + return cls(pystac.RelType.COLLECTION, c, media_type=pystac.MediaType.JSON) - @staticmethod - def self_href(href: str) -> "Link": + @classmethod + def self_href(cls, href: str) -> "Link": """Creates a self link to a file's location.""" - return Link(pystac.RelType.SELF, href, media_type=pystac.MediaType.JSON) + return cls(pystac.RelType.SELF, href, media_type=pystac.MediaType.JSON) - @staticmethod - def child(c: "Catalog_Type", title: Optional[str] = None) -> "Link": + @classmethod + def child(cls, c: "Catalog_Type", title: Optional[str] = None) -> "Link": """Creates a link to a child Catalog or Collection.""" - return Link( + return cls( pystac.RelType.CHILD, c, title=title, media_type=pystac.MediaType.JSON ) - @staticmethod - def item(item: "Item_Type", title: Optional[str] = None) -> "Link": + @classmethod + def item(cls, item: "Item_Type", title: Optional[str] = None) -> "Link": """Creates a link to an Item.""" - return Link( + return cls( pystac.RelType.ITEM, item, title=title, media_type=pystac.MediaType.JSON ) - @staticmethod + @classmethod def canonical( + cls, item_or_collection: Union["Item_Type", "Collection_Type"], title: Optional[str] = None, ) -> "Link": """Creates a canonical link to an Item or Collection.""" - return Link( + return cls( pystac.RelType.CANONICAL, item_or_collection, title=title, diff --git a/tests/test_catalog.py b/tests/test_catalog.py index ce956667a..e536a546c 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -1153,3 +1153,9 @@ def test_from_file_returns_subclass(self) -> None: custom_catalog = self.BasicCustomCatalog.from_file(self.TEST_CASE_1) self.assertIsInstance(custom_catalog, self.BasicCustomCatalog) + + def test_clone(self) -> None: + custom_catalog = self.BasicCustomCatalog.from_file(self.TEST_CASE_1) + cloned_catalog = custom_catalog.clone() + + self.assertIsInstance(cloned_catalog, self.BasicCustomCatalog) diff --git a/tests/test_collection.py b/tests/test_collection.py index c1a6c475d..3dfaddd6d 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -388,3 +388,9 @@ def test_from_file_returns_subclass(self) -> None: custom_collection = self.BasicCustomCollection.from_file(self.MULTI_EXTENT) self.assertIsInstance(custom_collection, self.BasicCustomCollection) + + def test_clone(self) -> None: + custom_collection = self.BasicCustomCollection.from_file(self.MULTI_EXTENT) + cloned_collection = custom_collection.clone() + + self.assertIsInstance(cloned_collection, self.BasicCustomCollection) diff --git a/tests/test_item.py b/tests/test_item.py index e83951f4c..ee28cd104 100644 --- a/tests/test_item.py +++ b/tests/test_item.py @@ -742,3 +742,32 @@ def test_from_file_returns_subclass(self) -> None: custom_item = self.BasicCustomItem.from_file(self.SAMPLE_ITEM) self.assertIsInstance(custom_item, self.BasicCustomItem) + + def test_clone(self) -> None: + custom_item = self.BasicCustomItem.from_file(self.SAMPLE_ITEM) + cloned_item = custom_item.clone() + + self.assertIsInstance(cloned_item, self.BasicCustomItem) + + +class AssetSubClassTest(unittest.TestCase): + class CustomAsset(Asset): + pass + + def setUp(self) -> None: + self.maxDiff = None + with open(TestCases.get_path("data-files/item/sample-item.json")) as src: + item_dict = json.load(src) + + self.asset_dict = item_dict["assets"]["analytic"] + + def test_from_dict(self) -> None: + asset = self.CustomAsset.from_dict(self.asset_dict) + + self.assertIsInstance(asset, self.CustomAsset) + + def test_clone(self) -> None: + asset = self.CustomAsset.from_dict(self.asset_dict) + cloned_asset = asset.clone() + + self.assertIsInstance(cloned_asset, self.CustomAsset) diff --git a/tests/test_link.py b/tests/test_link.py index b9a50b13c..3e83c9987 100644 --- a/tests/test_link.py +++ b/tests/test_link.py @@ -135,3 +135,50 @@ def test_canonical_collection(self) -> None: link = pystac.Link.canonical(self.collection) expected = {"rel": "canonical", "href": None, "type": "application/json"} self.assertEqual(expected, link.to_dict()) + + +class LinkInheritanceTest(unittest.TestCase): + def setUp(self) -> None: + self.maxDiff = None + self.collection = pystac.Collection( + "collection id", "desc", extent=ARBITRARY_EXTENT + ) + self.item = pystac.Item( + id="test-item", + geometry=None, + bbox=None, + datetime=TEST_DATETIME, + properties={}, + ) + + class CustomLink(pystac.Link): + pass + + def test_from_dict(self) -> None: + link = self.CustomLink.from_dict( + {"rel": "r", "href": "t", "type": "a/b", "title": "t", "c": "d", "1": 2} + ) + self.assertIsInstance(link, self.CustomLink) + + def test_collection(self) -> None: + link = self.CustomLink.collection(self.collection) + self.assertIsInstance(link, self.CustomLink) + + def test_child(self) -> None: + link = self.CustomLink.child(self.collection) + self.assertIsInstance(link, self.CustomLink) + + def test_canonical_item(self) -> None: + link = self.CustomLink.canonical(self.item) + self.assertIsInstance(link, self.CustomLink) + + def test_canonical_collection(self) -> None: + link = self.CustomLink.canonical(self.collection) + self.assertIsInstance(link, self.CustomLink) + + def test_clone(self) -> None: + link = self.CustomLink.from_dict( + {"rel": "r", "href": "t", "type": "a/b", "title": "t", "c": "d", "1": 2} + ) + cloned_link = link.clone() + self.assertIsInstance(cloned_link, self.CustomLink)