Skip to content

Commit cde0c68

Browse files
committed
Add type annotations to everything
1 parent d71c303 commit cde0c68

30 files changed

+2813
-2487
lines changed

pystac/__init__.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,25 @@ class STACError(Exception):
1313
pass
1414

1515

16-
from pystac.version import (__version__, get_stac_version, set_stac_version)
16+
from typing import Any, Dict, Optional
17+
from pystac.version import (__version__, get_stac_version, set_stac_version) # type:ignore
1718
from pystac.stac_io import STAC_IO
18-
from pystac.extensions import Extensions
19-
from pystac.stac_object import (STACObject, STACObjectType)
20-
from pystac.media_type import MediaType
21-
from pystac.link import (Link, HIERARCHICAL_LINKS)
22-
from pystac.catalog import (Catalog, CatalogType)
23-
from pystac.collection import (Collection, Extent, SpatialExtent, TemporalExtent, Provider)
24-
from pystac.item import (Item, Asset, CommonMetadata)
19+
from pystac.extensions import Extensions # type:ignore
20+
from pystac.stac_object import (STACObject, STACObjectType) # type:ignore
21+
from pystac.media_type import MediaType # type:ignore
22+
from pystac.link import (Link, HIERARCHICAL_LINKS) # type:ignore
23+
from pystac.catalog import (Catalog, CatalogType) # type:ignore
24+
from pystac.collection import (Collection, Extent, SpatialExtent, TemporalExtent, # type:ignore
25+
Provider) # type:ignore
26+
from pystac.item import (Item, Asset, CommonMetadata) # type:ignore
2527

2628
from pystac.serialization import stac_object_from_dict
2729

2830
import pystac.validation
2931

3032
STAC_IO.stac_object_from_dict = stac_object_from_dict
3133

32-
from pystac import extensions
34+
import pystac.extensions.base
3335
import pystac.extensions.eo
3436
import pystac.extensions.label
3537
import pystac.extensions.pointcloud
@@ -43,19 +45,24 @@ class STACError(Exception):
4345
import pystac.extensions.view
4446
import pystac.extensions.file
4547

46-
STAC_EXTENSIONS = extensions.base.RegisteredSTACExtensions([
47-
extensions.eo.EO_EXTENSION_DEFINITION, extensions.label.LABEL_EXTENSION_DEFINITION,
48-
extensions.pointcloud.POINTCLOUD_EXTENSION_DEFINITION,
49-
extensions.projection.PROJECTION_EXTENSION_DEFINITION, extensions.sar.SAR_EXTENSION_DEFINITION,
50-
extensions.sat.SAT_EXTENSION_DEFINITION, extensions.scientific.SCIENTIFIC_EXTENSION_DEFINITION,
51-
extensions.single_file_stac.SFS_EXTENSION_DEFINITION,
52-
extensions.timestamps.TIMESTAMPS_EXTENSION_DEFINITION,
53-
extensions.version.VERSION_EXTENSION_DEFINITION, extensions.view.VIEW_EXTENSION_DEFINITION,
54-
extensions.file.FILE_EXTENSION_DEFINITION
55-
])
56-
57-
58-
def read_file(href):
48+
STAC_EXTENSIONS: pystac.extensions.base.RegisteredSTACExtensions = pystac.extensions.base.RegisteredSTACExtensions(
49+
[
50+
pystac.extensions.eo.EO_EXTENSION_DEFINITION,
51+
pystac.extensions.label.LABEL_EXTENSION_DEFINITION,
52+
pystac.extensions.pointcloud.POINTCLOUD_EXTENSION_DEFINITION,
53+
pystac.extensions.projection.PROJECTION_EXTENSION_DEFINITION,
54+
pystac.extensions.sar.SAR_EXTENSION_DEFINITION,
55+
pystac.extensions.sat.SAT_EXTENSION_DEFINITION,
56+
pystac.extensions.scientific.SCIENTIFIC_EXTENSION_DEFINITION,
57+
pystac.extensions.single_file_stac.SFS_EXTENSION_DEFINITION,
58+
pystac.extensions.timestamps.TIMESTAMPS_EXTENSION_DEFINITION,
59+
pystac.extensions.version.VERSION_EXTENSION_DEFINITION,
60+
pystac.extensions.view.VIEW_EXTENSION_DEFINITION,
61+
pystac.extensions.file.FILE_EXTENSION_DEFINITION
62+
])
63+
64+
65+
def read_file(href: str) -> STACObject:
5966
"""Reads a STAC object from a file.
6067
6168
This method will return either a Catalog, a Collection, or an Item based on what the
@@ -73,7 +80,7 @@ def read_file(href):
7380
return STACObject.from_file(href)
7481

7582

76-
def write_file(obj, include_self_link=True, dest_href=None):
83+
def write_file(obj: STACObject, include_self_link: bool = True, dest_href: Optional[str] = None):
7784
"""Writes a STACObject to a file.
7885
7986
This will write only the Catalog, Collection or Item ``obj``. It will not attempt
@@ -96,7 +103,7 @@ def write_file(obj, include_self_link=True, dest_href=None):
96103
obj.save_object(include_self_link=include_self_link, dest_href=dest_href)
97104

98105

99-
def read_dict(d, href=None, root=None):
106+
def read_dict(d: Dict[str, Any], href: Optional[str] = None, root: Optional[Catalog] = None):
100107
"""Reads a STAC object from a dict representing the serialized JSON version of the
101108
STAC object.
102109

pystac/cache.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from collections import ChainMap
22
from copy import copy
3+
from pystac.collection import Collection
4+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
5+
from pystac.stac_object import STACObject
36

47
import pystac
58

69

7-
def get_cache_key(stac_object):
10+
def get_cache_key(stac_object: STACObject) -> Tuple[str, bool]:
811
"""Produce a cache key for the given STAC object.
912
1013
If a self href is set, use that as the cache key.
@@ -20,7 +23,7 @@ def get_cache_key(stac_object):
2023
if href is not None:
2124
return (href, True)
2225
else:
23-
ids = []
26+
ids: List[str] = []
2427
obj = stac_object
2528
while obj is not None:
2629
ids.append(obj.id)
@@ -52,14 +55,17 @@ class ResolvedObjectCache:
5255
their cached object.
5356
ids_to_collections (Dict[str, Collection]): Map of collection IDs to collections.
5457
"""
55-
def __init__(self, id_keys_to_objects=None, hrefs_to_objects=None, ids_to_collections=None):
58+
def __init__(self,
59+
id_keys_to_objects: Optional[Dict[str, STACObject]] = None,
60+
hrefs_to_objects: Optional[Dict[str, STACObject]] = None,
61+
ids_to_collections: Dict[str, Collection] = None):
5662
self.id_keys_to_objects = id_keys_to_objects or {}
5763
self.hrefs_to_objects = hrefs_to_objects or {}
5864
self.ids_to_collections = ids_to_collections or {}
5965

6066
self._collection_cache = None
6167

62-
def get_or_cache(self, obj):
68+
def get_or_cache(self, obj: STACObject) -> STACObject:
6369
"""Gets the STACObject that is the cached version of the given STACObject; or, if
6470
none exists, sets the cached object to the given object.
6571
@@ -85,7 +91,7 @@ def get_or_cache(self, obj):
8591
self.cache(obj)
8692
return obj
8793

88-
def get(self, obj):
94+
def get(self, obj: STACObject) -> Optional[STACObject]:
8995
"""Get the cached object that has the same cache key as the given object.
9096
9197
Args:
@@ -101,7 +107,7 @@ def get(self, obj):
101107
else:
102108
return self.id_keys_to_objects.get(key)
103109

104-
def get_by_href(self, href):
110+
def get_by_href(self, href: str) -> Optional[STACObject]:
105111
"""Gets the cached object at href.
106112
107113
Args:
@@ -112,7 +118,7 @@ def get_by_href(self, href):
112118
"""
113119
return self.hrefs_to_objects.get(href)
114120

115-
def get_collection_by_id(self, id):
121+
def get_collection_by_id(self, id: str) -> Optional[Collection]:
116122
"""Retrieved a cached Collection by its ID.
117123
118124
Args:
@@ -124,7 +130,7 @@ def get_collection_by_id(self, id):
124130
"""
125131
return self.ids_to_collections.get(id)
126132

127-
def cache(self, obj):
133+
def cache(self, obj: STACObject) -> None:
128134
"""Set the given object into the cache.
129135
130136
Args:
@@ -136,10 +142,10 @@ def cache(self, obj):
136142
else:
137143
self.id_keys_to_objects[key] = obj
138144

139-
if obj.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION:
145+
if isinstance(obj, Collection):
140146
self.ids_to_collections[obj.id] = obj
141147

142-
def remove(self, obj):
148+
def remove(self, obj: STACObject) -> None:
143149
"""Removes any cached object that matches the given object's cache key.
144150
145151
Args:
@@ -155,21 +161,21 @@ def remove(self, obj):
155161
if obj.STAC_OBJECT_TYPE == pystac.STACObjectType.COLLECTION:
156162
self.id_keys_to_objects.pop(obj.id, None)
157163

158-
def __contains__(self, obj):
164+
def __contains__(self, obj: STACObject) -> bool:
159165
key, is_href = get_cache_key(obj)
160166
return key in self.hrefs_to_objects if is_href else key in self.id_keys_to_objects
161167

162-
def contains_collection_id(self, collection_id):
168+
def contains_collection_id(self, collection_id: str) -> bool:
163169
"""Returns True if there is a collection with given collection ID is cached."""
164170
return collection_id in self.ids_to_collections
165171

166-
def as_collection_cache(self):
172+
def as_collection_cache(self) -> "CollectionCache":
167173
if self._collection_cache is None:
168174
self._collection_cache = ResolvedObjectCollectionCache(self)
169175
return self._collection_cache
170176

171177
@staticmethod
172-
def merge(first, second):
178+
def merge(first: "ResolvedObjectCache", second: "ResolvedObjectCache") -> "ResolvedObjectCache":
173179
"""Merges two ResolvedObjectCache.
174180
175181
The merged cache will give preference to the first argument; that is, if there
@@ -206,55 +212,65 @@ class CollectionCache:
206212
The CollectionCache will contain collections as either as dicts or PySTAC Collections,
207213
and will set Collection JSON that it reads in order to merge in common properties.
208214
"""
209-
def __init__(self, cached_ids=None, cached_hrefs=None):
215+
def __init__(self,
216+
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None,
217+
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None):
210218
self.cached_ids = cached_ids or {}
211219
self.cached_hrefs = cached_hrefs or {}
212220

213-
def get_by_id(self, collection_id):
221+
def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]:
214222
return self.cached_ids.get(collection_id)
215223

216-
def get_by_href(self, href):
224+
def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
217225
return self.cached_hrefs.get(href)
218226

219-
def contains_id(self, collection_id):
227+
def contains_id(self, collection_id: str) -> bool:
220228
return collection_id in self.cached_ids
221229

222-
def cache(self, collection, href=None):
230+
def cache(self, collection: Union[Collection, Dict[str, Any]], href: Optional[str] = None) -> None:
223231
"""Caches a collection JSON."""
224-
self.cached_ids[collection['id']] = collection
232+
if isinstance(collection, Collection):
233+
self.cached_ids[collection.id] = collection
234+
else:
235+
self.cached_ids[collection['id']] = collection
225236

226237
if href is not None:
227238
self.cached_hrefs[href] = collection
228239

229240

230241
class ResolvedObjectCollectionCache(CollectionCache):
231-
def __init__(self, resolved_object_cache, cached_ids=None, cached_hrefs=None):
242+
def __init__(self,
243+
resolved_object_cache: ResolvedObjectCache,
244+
cached_ids: Dict[str, Union[Collection, Dict[str, Any]]] = None,
245+
cached_hrefs: Dict[str, Union[Collection, Dict[str, Any]]] = None):
232246
super().__init__(cached_ids, cached_hrefs)
233247
self.resolved_object_cache = resolved_object_cache
234248

235-
def get_by_id(self, collection_id):
249+
def get_by_id(self, collection_id: str) -> Optional[Union[Collection, Dict[str, Any]]]:
236250
result = self.resolved_object_cache.get_collection_by_id(collection_id)
237251
if result is None:
238252
return super().get_by_id(collection_id)
239253
else:
240254
return result
241255

242-
def get_by_href(self, href):
256+
def get_by_href(self, href: str) -> Optional[Union[Collection, Dict[str, Any]]]:
243257
result = self.resolved_object_cache.get_by_href(href)
244258
if result is None:
245259
return super().get_by_href(href)
246260
else:
247-
return result
261+
return cast(Collection, result)
248262

249-
def contains_id(self, collection_id):
263+
def contains_id(self, collection_id: str) -> bool:
250264
return (self.resolved_object_cache.contains_collection_id(collection_id)
251265
or super().contains_id(collection_id))
252266

253-
def cache(self, collection, href=None):
267+
def cache(self, collection: Dict[str, Any], href: Optional[str] = None) -> None:
254268
super().cache(collection, href)
255269

256270
@staticmethod
257-
def merge(resolved_object_cache, first, second):
271+
def merge(resolved_object_cache: ResolvedObjectCache,
272+
first: Optional["ResolvedObjectCollectionCache"],
273+
second: Optional["ResolvedObjectCollectionCache"]) -> "ResolvedObjectCollectionCache":
258274
first_cached_ids = {}
259275
if first is not None:
260276
first_cached_ids = copy(first.cached_ids)

0 commit comments

Comments
 (0)