diff --git a/CHANGELOG.md b/CHANGELOG.md index 5150a0e5c..f23fbcb89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ### Added +- Add a `preserve_dict` parameter to `ItemCollection.from_dict` and set it to False when using `ItemCollection.from_file`. ([#468](https://github.com/stac-utils/pystac/pull/468)) + ### Changed ### Fixed diff --git a/pystac/item_collection.py b/pystac/item_collection.py index 33cab9c6a..6744f4606 100644 --- a/pystac/item_collection.py +++ b/pystac/item_collection.py @@ -134,16 +134,26 @@ def clone(self) -> "ItemCollection": ) @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "ItemCollection": + def from_dict( + cls, d: Dict[str, Any], preserve_dict: bool = True + ) -> "ItemCollection": """Creates a :class:`ItemCollection` instance from a dictionary. Arguments: d : The dictionary from which the :class:`~ItemCollection` will be created + preserve_dict: If False, the dict parameter ``d`` may be modified + during this method call. Otherwise the dict is not mutated. + Defaults to True, which results results in a deepcopy of the + parameter. Set to False when possible to avoid the performance + hit of a deepcopy. """ if not cls.is_item_collection(d): raise STACTypeError("Dict is not a valid ItemCollection") - items = [pystac.Item.from_dict(item) for item in d.get("features", [])] + items = [ + pystac.Item.from_dict(item, preserve_dict=preserve_dict) + for item in d.get("features", []) + ] extra_fields = {k: v for k, v in d.items() if k not in ("features", "type")} return cls(items=items, extra_fields=extra_fields) @@ -166,7 +176,7 @@ def from_file( d = stac_io.read_json(href) - return cls.from_dict(d) + return cls.from_dict(d, preserve_dict=False) def save_object( self, diff --git a/tests/test_collection.py b/tests/test_collection.py index ddb67be92..3a681c7f4 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -183,7 +183,7 @@ def test_assets(self) -> None: collection = pystac.Collection.from_dict(data) collection.validate() - def test_to_dict_preserves_dict(self) -> None: + def test_from_dict_preserves_dict(self) -> None: path = TestCases.get_path("data-files/collections/with-assets.json") with open(path) as f: collection_dict = json.load(f) diff --git a/tests/test_item_collection.py b/tests/test_item_collection.py index cec478cd1..14c1116cb 100644 --- a/tests/test_item_collection.py +++ b/tests/test_item_collection.py @@ -1,4 +1,6 @@ +from copy import deepcopy import json +from pystac.item_collection import ItemCollection import unittest import pystac @@ -149,3 +151,15 @@ def test_identify_0_9_itemcollection(self) -> None: pystac.ItemCollection.is_item_collection(itemcollection_dict), msg="Did not correctly identify valid STAC 0.9 ItemCollection.", ) + + def test_from_dict_preserves_dict(self) -> None: + param_dict = deepcopy(self.item_collection_dict) + + # test that the parameter is preserved + _ = ItemCollection.from_dict(param_dict) + self.assertEqual(param_dict, self.item_collection_dict) + + # assert that the parameter is not preserved with + # non-default parameter + _ = ItemCollection.from_dict(param_dict, preserve_dict=False) + self.assertNotEqual(param_dict, self.item_collection_dict)