Skip to content

Commit 6e0785b

Browse files
committed
Return calling class in from_file
1 parent f1f02a2 commit 6e0785b

File tree

8 files changed

+95
-58
lines changed

8 files changed

+95
-58
lines changed

pystac/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,18 @@ def read_file(href: str) -> STACObject:
9494
a :class:`~pystac.STACObject` and must be read using
9595
:meth:`ItemCollection.from_file <pystac.ItemCollection.from_file>`
9696
"""
97-
return STACObject.from_file(href)
97+
stac_io = StacIO.default()
98+
d = stac_io.read_json(href)
99+
typ = pystac.serialization.identify.identify_stac_object_type(d)
100+
101+
if typ == STACObjectType.CATALOG:
102+
return Catalog.from_file(href)
103+
elif typ == STACObjectType.COLLECTION:
104+
return Collection.from_file(href)
105+
elif typ == STACObjectType.ITEM:
106+
return Item.from_file(href)
107+
else:
108+
raise STACTypeError(f"Cannot read file of type {typ}")
98109

99110

100111
def write_file(

pystac/catalog.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,12 @@ def full_copy(
946946

947947
@classmethod
948948
def from_file(cls, href: str, stac_io: Optional[pystac.StacIO] = None) -> "Catalog":
949+
if stac_io is None:
950+
stac_io = pystac.StacIO.default()
951+
949952
result = super().from_file(href, stac_io)
950953
if not isinstance(result, Catalog):
951954
raise pystac.STACTypeError(f"{result} is not a {Catalog}.")
955+
result._stac_io = stac_io
956+
952957
return result

pystac/serialization/__init__.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,8 @@
11
# flake8: noqa
2-
from typing import Any, Dict, Optional, TYPE_CHECKING
3-
4-
import pystac
52
from pystac.serialization.identify import (
63
STACVersionRange,
74
identify_stac_object,
85
identify_stac_object_type,
96
)
107
from pystac.serialization.common_properties import merge_common_properties
118
from pystac.serialization.migrate import migrate_to_latest
12-
13-
if TYPE_CHECKING:
14-
from pystac.stac_object import STACObject
15-
from pystac.catalog import Catalog
16-
17-
18-
def stac_object_from_dict(
19-
d: Dict[str, Any], href: Optional[str] = None, root: Optional["Catalog"] = None
20-
) -> "STACObject":
21-
"""Determines how to deserialize a dictionary into a STAC object.
22-
23-
Args:
24-
d : The dict to parse.
25-
href : Optional href that is the file location of the object being
26-
parsed.
27-
root : Optional root of the catalog for this object.
28-
If provided, the root's resolved object cache can be used to search for
29-
previously resolved instances of the STAC object.
30-
31-
Note: This is used internally in StacIO instances to deserialize STAC Objects.
32-
"""
33-
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
34-
collection_cache = None
35-
if root is not None:
36-
collection_cache = root._resolved_objects.as_collection_cache()
37-
38-
# Merge common properties in case this is an older STAC object.
39-
merge_common_properties(d, json_href=href, collection_cache=collection_cache)
40-
41-
info = identify_stac_object(d)
42-
43-
d = migrate_to_latest(d, info)
44-
45-
if info.object_type == pystac.STACObjectType.CATALOG:
46-
return pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)
47-
48-
if info.object_type == pystac.STACObjectType.COLLECTION:
49-
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)
50-
51-
if info.object_type == pystac.STACObjectType.ITEM:
52-
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)
53-
54-
raise pystac.STACTypeError(f"Unknown STAC object type {info.object_type}")

pystac/stac_io.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919

2020
import pystac
2121
from pystac.utils import safe_urlparse
22-
import pystac.serialization
22+
from pystac.serialization import (
23+
merge_common_properties,
24+
identify_stac_object_type,
25+
identify_stac_object,
26+
migrate_to_latest,
27+
)
2328

2429
# Use orjson if available
2530
try:
@@ -95,12 +100,31 @@ def stac_object_from_dict(
95100
href: Optional[str] = None,
96101
root: Optional["Catalog_Type"] = None,
97102
) -> "STACObject_Type":
98-
result = pystac.serialization.stac_object_from_dict(d, href, root)
99-
if isinstance(result, pystac.Catalog):
100-
# Set the stac_io instance for usage by io operations
101-
# where this catalog is the root.
103+
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
104+
collection_cache = None
105+
if root is not None:
106+
collection_cache = root._resolved_objects.as_collection_cache()
107+
108+
# Merge common properties in case this is an older STAC object.
109+
merge_common_properties(
110+
d, json_href=href, collection_cache=collection_cache
111+
)
112+
113+
info = identify_stac_object(d)
114+
d = migrate_to_latest(d, info)
115+
116+
if info.object_type == pystac.STACObjectType.CATALOG:
117+
result = pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)
102118
result._stac_io = self
103-
return result
119+
return result
120+
121+
if info.object_type == pystac.STACObjectType.COLLECTION:
122+
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)
123+
124+
if info.object_type == pystac.STACObjectType.ITEM:
125+
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)
126+
127+
raise ValueError(f"Unknown STAC object type {info.object_type}")
104128

105129
def read_json(
106130
self, source: Union[str, "Link_Type"], *args: Any, **kwargs: Any
@@ -302,7 +326,30 @@ def stac_object_from_dict(
302326
root: Optional["Catalog_Type"] = None,
303327
) -> "STACObject_Type":
304328
STAC_IO.issue_deprecation_warning()
305-
return pystac.serialization.stac_object_from_dict(d, href, root)
329+
if identify_stac_object_type(d) == pystac.STACObjectType.ITEM:
330+
collection_cache = None
331+
if root is not None:
332+
collection_cache = root._resolved_objects.as_collection_cache()
333+
334+
# Merge common properties in case this is an older STAC object.
335+
merge_common_properties(
336+
d, json_href=href, collection_cache=collection_cache
337+
)
338+
339+
info = identify_stac_object(d)
340+
341+
d = migrate_to_latest(d, info)
342+
343+
if info.object_type == pystac.STACObjectType.CATALOG:
344+
return pystac.Catalog.from_dict(d, href=href, root=root, migrate=False)
345+
346+
if info.object_type == pystac.STACObjectType.COLLECTION:
347+
return pystac.Collection.from_dict(d, href=href, root=root, migrate=False)
348+
349+
if info.object_type == pystac.STACObjectType.ITEM:
350+
return pystac.Item.from_dict(d, href=href, root=root, migrate=False)
351+
352+
raise ValueError(f"Unknown STAC object type {info.object_type}")
306353

307354
# This is set in __init__.py
308355
_STAC_OBJECT_CLASSES = None

pystac/stac_object.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from pystac import STACError
77
from pystac.link import Link
88
from pystac.utils import is_absolute_href, make_absolute_href
9+
from pystac import serialization
10+
from pystac.serialization.identify import identify_stac_object
911

1012
if TYPE_CHECKING:
1113
from pystac.catalog import Catalog as Catalog_Type
@@ -469,7 +471,10 @@ def from_file(
469471
if not is_absolute_href(href):
470472
href = make_absolute_href(href)
471473

472-
o = stac_io.read_stac_object(href)
474+
d = stac_io.read_json(href)
475+
info = identify_stac_object(d)
476+
d = serialization.migrate.migrate_to_latest(d, info)
477+
o = cls.from_dict(d, href=href)
473478

474479
# Set the self HREF, if it's not already set to something else.
475480
if o.get_self_href() is None:

tests/test_catalog.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1122,7 +1122,12 @@ def setUp(self) -> None:
11221122
self.stac_io = pystac.StacIO.default()
11231123

11241124
def test_from_dict_returns_subclass(self) -> None:
1125-
11261125
catalog_dict = self.stac_io.read_json(self.TEST_CASE_1)
11271126
custom_catalog = self.BasicCustomCatalog.from_dict(catalog_dict)
1127+
1128+
self.assertIsInstance(custom_catalog, self.BasicCustomCatalog)
1129+
1130+
def test_from_file_returns_subclass(self) -> None:
1131+
custom_catalog = self.BasicCustomCatalog.from_file(self.TEST_CASE_1)
1132+
11281133
self.assertIsInstance(custom_catalog, self.BasicCustomCatalog)

tests/test_collection.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,12 @@ def setUp(self) -> None:
279279
self.stac_io = pystac.StacIO.default()
280280

281281
def test_from_dict_returns_subclass(self) -> None:
282-
283282
collection_dict = self.stac_io.read_json(self.MULTI_EXTENT)
284283
custom_collection = self.BasicCustomCollection.from_dict(collection_dict)
284+
285+
self.assertIsInstance(custom_collection, self.BasicCustomCollection)
286+
287+
def test_from_file_returns_subclass(self) -> None:
288+
custom_collection = self.BasicCustomCollection.from_file(self.MULTI_EXTENT)
289+
285290
self.assertIsInstance(custom_collection, self.BasicCustomCollection)

tests/test_item.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,12 @@ def setUp(self) -> None:
713713
self.stac_io = pystac.StacIO.default()
714714

715715
def test_from_dict_returns_subclass(self) -> None:
716-
717716
item_dict = self.stac_io.read_json(self.SAMPLE_ITEM)
718717
custom_item = self.BasicCustomItem.from_dict(item_dict)
718+
719+
self.assertIsInstance(custom_item, self.BasicCustomItem)
720+
721+
def test_from_file_returns_subclass(self) -> None:
722+
custom_item = self.BasicCustomItem.from_file(self.SAMPLE_ITEM)
723+
719724
self.assertIsInstance(custom_item, self.BasicCustomItem)

0 commit comments

Comments
 (0)