Skip to content

Commit 4bb702a

Browse files
committed
Update code and tests to pass with type annotations
1 parent d531cc2 commit 4bb702a

34 files changed

+905
-683
lines changed

pystac/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@ class STACError(Exception):
1313
pass
1414

1515

16+
class STACTypeError(Exception):
17+
"""A STACTypeError is raised when encountering a representation of
18+
a STAC entity that is not correct for the context; for example, if
19+
a Catalog JSON was read in as an Item.
20+
"""
21+
pass
22+
23+
1624
from typing import Any, Dict, Optional
1725
from pystac.version import (__version__, get_stac_version, set_stac_version) # type:ignore
1826
from pystac.stac_io import STAC_IO
@@ -21,8 +29,12 @@ class STACError(Exception):
2129
from pystac.media_type import MediaType # type:ignore
2230
from pystac.link import (Link, HIERARCHICAL_LINKS) # type:ignore
2331
from pystac.catalog import (Catalog, CatalogType) # type:ignore
24-
from pystac.collection import (Collection, Extent, SpatialExtent, TemporalExtent, # type:ignore
25-
Provider) # type:ignore
32+
from pystac.collection import (
33+
Collection, # type:ignore
34+
Extent, # type:ignore
35+
SpatialExtent, # type:ignore
36+
TemporalExtent, # type:ignore
37+
Provider) # type:ignore
2638
from pystac.item import (Item, Asset, CommonMetadata) # type:ignore
2739

2840
from pystac.serialization import stac_object_from_dict

pystac/cache.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from collections import ChainMap
22
from copy import copy
3-
from pystac.collection import Collection
4-
from typing import Any, Dict, List, Optional, Tuple, Union, cast
5-
from pystac.stac_object import STACObject
3+
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast
64

7-
import pystac
5+
import pystac as ps
86

7+
if TYPE_CHECKING:
8+
from pystac.stac_object import STACObject
9+
from pystac.collection import Collection
910

10-
def get_cache_key(stac_object: STACObject) -> Tuple[str, bool]:
11+
12+
def get_cache_key(stac_object: "STACObject") -> Tuple[str, bool]:
1113
"""Produce a cache key for the given STAC object.
1214
1315
If a self href is set, use that as the cache key.
@@ -56,16 +58,16 @@ class ResolvedObjectCache:
5658
ids_to_collections (Dict[str, Collection]): Map of collection IDs to collections.
5759
"""
5860
def __init__(self,
59-
id_keys_to_objects: Optional[Dict[str, STACObject]] = None,
60-
hrefs_to_objects: Optional[Dict[str, STACObject]] = None,
61-
ids_to_collections: Dict[str, Collection] = None):
61+
id_keys_to_objects: Optional[Dict[str, "STACObject"]] = None,
62+
hrefs_to_objects: Optional[Dict[str, "STACObject"]] = None,
63+
ids_to_collections: Dict[str, "Collection"] = None):
6264
self.id_keys_to_objects = id_keys_to_objects or {}
6365
self.hrefs_to_objects = hrefs_to_objects or {}
6466
self.ids_to_collections = ids_to_collections or {}
6567

6668
self._collection_cache = None
6769

68-
def get_or_cache(self, obj: STACObject) -> STACObject:
70+
def get_or_cache(self, obj: "STACObject") -> "STACObject":
6971
"""Gets the STACObject that is the cached version of the given STACObject; or, if
7072
none exists, sets the cached object to the given object.
7173
@@ -91,7 +93,7 @@ def get_or_cache(self, obj: STACObject) -> STACObject:
9193
self.cache(obj)
9294
return obj
9395

94-
def get(self, obj: STACObject) -> Optional[STACObject]:
96+
def get(self, obj: "STACObject") -> Optional["STACObject"]:
9597
"""Get the cached object that has the same cache key as the given object.
9698
9799
Args:
@@ -107,7 +109,7 @@ def get(self, obj: STACObject) -> Optional[STACObject]:
107109
else:
108110
return self.id_keys_to_objects.get(key)
109111

110-
def get_by_href(self, href: str) -> Optional[STACObject]:
112+
def get_by_href(self, href: str) -> Optional["STACObject"]:
111113
"""Gets the cached object at href.
112114
113115
Args:
@@ -118,7 +120,7 @@ def get_by_href(self, href: str) -> Optional[STACObject]:
118120
"""
119121
return self.hrefs_to_objects.get(href)
120122

121-
def get_collection_by_id(self, id: str) -> Optional[Collection]:
123+
def get_collection_by_id(self, id: str) -> Optional["Collection"]:
122124
"""Retrieved a cached Collection by its ID.
123125
124126
Args:
@@ -130,7 +132,7 @@ def get_collection_by_id(self, id: str) -> Optional[Collection]:
130132
"""
131133
return self.ids_to_collections.get(id)
132134

133-
def cache(self, obj: STACObject) -> None:
135+
def cache(self, obj: "STACObject") -> None:
134136
"""Set the given object into the cache.
135137
136138
Args:
@@ -142,10 +144,10 @@ def cache(self, obj: STACObject) -> None:
142144
else:
143145
self.id_keys_to_objects[key] = obj
144146

145-
if isinstance(obj, Collection):
147+
if isinstance(obj, ps.Collection):
146148
self.ids_to_collections[obj.id] = obj
147149

148-
def remove(self, obj: STACObject) -> None:
150+
def remove(self, obj: "STACObject") -> None:
149151
"""Removes any cached object that matches the given object's cache key.
150152
151153
Args:
@@ -158,10 +160,10 @@ def remove(self, obj: STACObject) -> None:
158160
else:
159161
self.id_keys_to_objects.pop(key, None)
160162

161-
if obj.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION:
163+
if obj.STAC_OBJECT_TYPE == ps.STACObjectType.COLLECTION:
162164
self.id_keys_to_objects.pop(obj.id, None)
163165

164-
def __contains__(self, obj: STACObject) -> bool:
166+
def __contains__(self, obj: "STACObject") -> bool:
165167
key, is_href = get_cache_key(obj)
166168
return key in self.hrefs_to_objects if is_href else key in self.id_keys_to_objects
167169

@@ -213,23 +215,25 @@ class CollectionCache:
213215
and will set Collection JSON that it reads in order to merge in common properties.
214216
"""
215217
def __init__(self,
216-
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None,
217-
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None):
218+
cached_ids: Dict[str, Union["Collection", Dict[str, Any]]] = None,
219+
cached_hrefs: Dict[str, Union["Collection", Dict[str, Any]]] = None):
218220
self.cached_ids = cached_ids or {}
219221
self.cached_hrefs = cached_hrefs or {}
220222

221-
def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]:
223+
def get_by_id(self, collection_id: str) -> Optional[Union["Collection", Dict[str, Any]]]:
222224
return self.cached_ids.get(collection_id)
223225

224-
def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
226+
def get_by_href(self, href: str) -> Optional[Union["Collection", Dict[str, Any]]]:
225227
return self.cached_hrefs.get(href)
226228

227229
def contains_id(self, collection_id: str) -> bool:
228230
return collection_id in self.cached_ids
229231

230-
def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[str] = None) -> None:
232+
def cache(self,
233+
collection: Union["Collection", Dict[str, Any]],
234+
href: Optional[str] = None) -> None:
231235
"""Caches a collection JSON."""
232-
if isinstance(collection, Collection):
236+
if isinstance(collection, ps.Collection):
233237
self.cached_ids[collection.id] = collection
234238
else:
235239
self.cached_ids[collection['id']] = collection
@@ -241,24 +245,24 @@ def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[st
241245
class ResolvedObjectCollectionCache(CollectionCache):
242246
def __init__(self,
243247
resolved_object_cache: ResolvedObjectCache,
244-
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None,
245-
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None):
248+
cached_ids: Dict[str, Union["Collection", Dict[str, Any]]] = None,
249+
cached_hrefs: Dict[str, Union["Collection", Dict[str, Any]]] = None):
246250
super().__init__(cached_ids, cached_hrefs)
247251
self.resolved_object_cache = resolved_object_cache
248252

249-
def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]:
253+
def get_by_id(self, collection_id: str) -> Optional[Union["Collection", Dict[str, Any]]]:
250254
result = self.resolved_object_cache.get_collection_by_id(collection_id)
251255
if result is None:
252256
return super().get_by_id(collection_id)
253257
else:
254258
return result
255259

256-
def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
260+
def get_by_href(self, href: str) -> Optional[Union["Collection", Dict[str, Any]]]:
257261
result = self.resolved_object_cache.get_by_href(href)
258262
if result is None:
259263
return super().get_by_href(href)
260264
else:
261-
return cast(Collection, result)
265+
return cast(ps.Collection, result)
262266

263267
def contains_id(self, collection_id: str) -> bool:
264268
return (self.resolved_object_cache.contains_collection_id(collection_id)

0 commit comments

Comments
 (0)