Skip to content

Commit e75cd23

Browse files
Merge pull request #676 from pyathena-dev/feature/669-with-fetch-mixin
Extract shared boilerplate from sync cursors into WithFetch mixin
2 parents cd56348 + b753495 commit e75cd23

File tree

7 files changed

166 additions and 679 deletions (+166 / -679 lines changed)

7 files changed

166 additions and 679 deletions (+166 / -679 lines changed)

pyathena/arrow/cursor.py

Lines changed: 4 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,17 @@
22
from __future__ import annotations
33

44
import logging
5-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, cast
5+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast
66

77
from pyathena.arrow.converter import (
88
DefaultArrowTypeConverter,
99
DefaultArrowUnloadTypeConverter,
1010
)
1111
from pyathena.arrow.result_set import AthenaArrowResultSet
12-
from pyathena.common import BaseCursor, CursorIterator
12+
from pyathena.common import CursorIterator
1313
from pyathena.error import OperationalError, ProgrammingError
1414
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
15-
from pyathena.result_set import WithResultSet
15+
from pyathena.result_set import WithFetch
1616

1717
if TYPE_CHECKING:
1818
import polars as pl
@@ -21,7 +21,7 @@
2121
_logger = logging.getLogger(__name__) # type: ignore
2222

2323

24-
class ArrowCursor(BaseCursor, CursorIterator, WithResultSet):
24+
class ArrowCursor(WithFetch):
2525
"""Cursor for handling Apache Arrow Table results from Athena queries.
2626
2727
This cursor returns query results as Apache Arrow Tables, which provide
@@ -116,8 +116,6 @@ def __init__(
116116
self._on_start_query_execution = on_start_query_execution
117117
self._connect_timeout = connect_timeout
118118
self._request_timeout = request_timeout
119-
self._query_id: Optional[str] = None
120-
self._result_set: Optional[AthenaArrowResultSet] = None
121119

122120
@staticmethod
123121
def get_default_converter(
@@ -127,45 +125,6 @@ def get_default_converter(
127125
return DefaultArrowUnloadTypeConverter()
128126
return DefaultArrowTypeConverter()
129127

130-
@property
131-
def arraysize(self) -> int:
132-
return self._arraysize
133-
134-
@arraysize.setter
135-
def arraysize(self, value: int) -> None:
136-
if value <= 0:
137-
raise ProgrammingError("arraysize must be a positive integer value.")
138-
self._arraysize = value
139-
140-
@property # type: ignore
141-
def result_set(self) -> Optional[AthenaArrowResultSet]:
142-
return self._result_set
143-
144-
@result_set.setter
145-
def result_set(self, val) -> None:
146-
self._result_set = val
147-
148-
@property
149-
def query_id(self) -> Optional[str]:
150-
return self._query_id
151-
152-
@query_id.setter
153-
def query_id(self, val) -> None:
154-
self._query_id = val
155-
156-
@property
157-
def rownumber(self) -> Optional[int]:
158-
return self.result_set.rownumber if self.result_set else None
159-
160-
@property
161-
def rowcount(self) -> int:
162-
"""Get the number of rows affected by the last operation."""
163-
return self.result_set.rowcount if self.result_set else -1
164-
165-
def close(self) -> None:
166-
if self.result_set and not self.result_set.is_closed:
167-
self.result_set.close()
168-
169128
def execute(
170129
self,
171130
operation: str,
@@ -255,46 +214,6 @@ def execute(
255214
raise OperationalError(query_execution.state_change_reason)
256215
return self
257216

258-
def executemany(
259-
self,
260-
operation: str,
261-
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
262-
**kwargs,
263-
) -> None:
264-
for parameters in seq_of_parameters:
265-
self.execute(operation, parameters, **kwargs)
266-
# Operations that have result sets are not allowed with executemany.
267-
self._reset_state()
268-
269-
def cancel(self) -> None:
270-
if not self.query_id:
271-
raise ProgrammingError("QueryExecutionId is none or empty.")
272-
self._cancel(self.query_id)
273-
274-
def fetchone(
275-
self,
276-
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
277-
if not self.has_result_set:
278-
raise ProgrammingError("No result set.")
279-
result_set = cast(AthenaArrowResultSet, self.result_set)
280-
return result_set.fetchone()
281-
282-
def fetchmany(
283-
self, size: Optional[int] = None
284-
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
285-
if not self.has_result_set:
286-
raise ProgrammingError("No result set.")
287-
result_set = cast(AthenaArrowResultSet, self.result_set)
288-
return result_set.fetchmany(size)
289-
290-
def fetchall(
291-
self,
292-
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
293-
if not self.has_result_set:
294-
raise ProgrammingError("No result set.")
295-
result_set = cast(AthenaArrowResultSet, self.result_set)
296-
return result_set.fetchall()
297-
298217
def as_arrow(self) -> "Table":
299218
"""Return query results as an Apache Arrow Table.
300219

pyathena/async_cursor.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,7 @@
77
from multiprocessing import cpu_count
88
from typing import Any, Dict, List, Optional, Tuple, Union, cast
99

10-
from pyathena.common import CursorIterator
11-
from pyathena.cursor import BaseCursor
10+
from pyathena.common import BaseCursor, CursorIterator
1211
from pyathena.error import NotSupportedError, ProgrammingError
1312
from pyathena.model import AthenaQueryExecution
1413
from pyathena.result_set import AthenaDictResultSet, AthenaResultSet

0 commit comments

Comments
 (0)