diff --git a/src/json_stream/base.py b/src/json_stream/base.py index 19c6922..f9ccd0f 100644 --- a/src/json_stream/base.py +++ b/src/json_stream/base.py @@ -14,25 +14,26 @@ class TransientAccessException(Exception): class StreamingJSONBase(ABC): @classmethod - def factory(cls, token, token_stream, persistent): + def factory(cls, token, token_stream, persistent, level = 0): if persistent: if token == '{': - return PersistentStreamingJSONObject(token_stream) + return PersistentStreamingJSONObject(token_stream, level) if token == '[': - return PersistentStreamingJSONList(token_stream) + return PersistentStreamingJSONList(token_stream, level) else: if token == '{': - return TransientStreamingJSONObject(token_stream) + return TransientStreamingJSONObject(token_stream, level) if token == '[': - return TransientStreamingJSONList(token_stream) + return TransientStreamingJSONList(token_stream, level) raise ValueError(f"Unknown operator {token}") # pragma: no cover _persistent_children: bool - def __init__(self, token_stream): + def __init__(self, token_stream, level = 0): self.streaming = True self._stream = token_stream self._child: Optional[StreamingJSONBase] = None + self.level = level def _clear_child(self): if self._child is not None: @@ -52,6 +53,9 @@ def _iter_items(self): def _done(self): self.streaming = False + park_cursor_func = getattr(self._stream, "park_cursor", None) + if park_cursor_func and self.level == 0: + park_cursor_func() raise StopIteration() def read_all(self): @@ -80,8 +84,8 @@ def __deepcopy__(self, memo): class PersistentStreamingJSONBase(StreamingJSONBase, ABC): - def __init__(self, token_stream): - super().__init__(token_stream) + def __init__(self, token_stream, level = 0): + super().__init__(token_stream, level) self._data = self._init_persistent_data() self._persistent_children = True @@ -104,8 +108,8 @@ def __repr__(self): # pragma: no cover class TransientStreamingJSONBase(StreamingJSONBase, ABC): - def __init__(self, token_stream): - super().__init__(token_stream) + def __init__(self, token_stream, level = 0): + super().__init__(token_stream, level) self._started = False self._persistent_children = False @@ -146,7 +150,7 @@ def _load_item(self): else: # pragma: no cover raise ValueError(f"Expecting value, comma or ], got {v}") if token_type == TokenType.OPERATOR: - self._child = v = self.factory(v, self._stream, self._persistent_children) + self._child = v = self.factory(v, self._stream, self._persistent_children, self.level+1) return v def _get__iter__(self): @@ -179,8 +183,8 @@ def __getitem__(self, k) -> Any: class TransientStreamingJSONList(TransientStreamingJSONBase, StreamingJSONList): - def __init__(self, token_stream): - super().__init__(token_stream) + def __init__(self, token_stream, level = 0): + super().__init__(token_stream,level) self._index = -1 def _load_item(self): @@ -214,7 +218,7 @@ def _load_item(self): token_type, v = next(self._stream) if token_type == TokenType.OPERATOR: - self._child = v = self.factory(v, self._stream, self._persistent_children) + self._child = v = self.factory(v, self._stream, self._persistent_children, self.level+1) return k, v def _get__iter__(self):