Skip to content

Commit eede1e4

Browse files
Merge pull request #668 from pyathena-dev/feature/native-asyncio-cursor-phase3
2 parents 5238436 + 65c093a commit eede1e4

File tree

15 files changed: +944 additions, -569 deletions

pyathena/aio/arrow/cursor.py

Lines changed: 3 additions & 145 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33

44
import asyncio
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, cast
6+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
77

8-
from pyathena.aio.common import AioBaseCursor
8+
from pyathena.aio.common import WithAsyncFetch
99
from pyathena.arrow.converter import (
1010
DefaultArrowTypeConverter,
1111
DefaultArrowUnloadTypeConverter,
@@ -14,7 +14,6 @@
1414
from pyathena.common import CursorIterator
1515
from pyathena.error import OperationalError, ProgrammingError
1616
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
17-
from pyathena.result_set import WithResultSet
1817

1918
if TYPE_CHECKING:
2019
import polars as pl
@@ -23,7 +22,7 @@
2322
_logger = logging.getLogger(__name__) # type: ignore
2423

2524

26-
class AioArrowCursor(AioBaseCursor, CursorIterator, WithResultSet):
25+
class AioArrowCursor(WithAsyncFetch):
2726
"""Native asyncio cursor that returns results as Apache Arrow Tables.
2827
2928
Uses ``asyncio.to_thread()`` to create the result set off the event loop.
@@ -50,7 +49,6 @@ def __init__(
5049
unload: bool = False,
5150
result_reuse_enable: bool = False,
5251
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
53-
on_start_query_execution: Optional[Callable[[str], None]] = None,
5452
connect_timeout: Optional[float] = None,
5553
request_timeout: Optional[float] = None,
5654
**kwargs,
@@ -69,10 +67,8 @@ def __init__(
6967
**kwargs,
7068
)
7169
self._unload = unload
72-
self._on_start_query_execution = on_start_query_execution
7370
self._connect_timeout = connect_timeout
7471
self._request_timeout = request_timeout
75-
self._query_id: Optional[str] = None
7672
self._result_set: Optional[AthenaArrowResultSet] = None
7773

7874
@staticmethod
@@ -83,45 +79,6 @@ def get_default_converter(
8379
return DefaultArrowUnloadTypeConverter()
8480
return DefaultArrowTypeConverter()
8581

86-
@property
87-
def arraysize(self) -> int:
88-
return self._arraysize
89-
90-
@arraysize.setter
91-
def arraysize(self, value: int) -> None:
92-
if value <= 0:
93-
raise ProgrammingError("arraysize must be a positive integer value.")
94-
self._arraysize = value
95-
96-
@property # type: ignore
97-
def result_set(self) -> Optional[AthenaArrowResultSet]:
98-
return self._result_set
99-
100-
@result_set.setter
101-
def result_set(self, val) -> None:
102-
self._result_set = val
103-
104-
@property
105-
def query_id(self) -> Optional[str]:
106-
return self._query_id
107-
108-
@query_id.setter
109-
def query_id(self, val) -> None:
110-
self._query_id = val
111-
112-
@property
113-
def rownumber(self) -> Optional[int]:
114-
return self.result_set.rownumber if self.result_set else None
115-
116-
@property
117-
def rowcount(self) -> int:
118-
return self.result_set.rowcount if self.result_set else -1
119-
120-
def close(self) -> None:
121-
"""Close the cursor and release associated resources."""
122-
if self.result_set and not self.result_set.is_closed:
123-
self.result_set.close()
124-
12582
async def execute( # type: ignore[override]
12683
self,
12784
operation: str,
@@ -133,7 +90,6 @@ async def execute( # type: ignore[override]
13390
result_reuse_enable: Optional[bool] = None,
13491
result_reuse_minutes: Optional[int] = None,
13592
paramstyle: Optional[str] = None,
136-
on_start_query_execution: Optional[Callable[[str], None]] = None,
13793
**kwargs,
13894
) -> "AioArrowCursor":
13995
"""Execute a SQL query asynchronously and return results as Arrow Tables.
@@ -148,7 +104,6 @@ async def execute( # type: ignore[override]
148104
result_reuse_enable: Enable Athena result reuse for this query.
149105
result_reuse_minutes: Minutes to reuse cached results.
150106
paramstyle: Parameter style ('qmark' or 'pyformat').
151-
on_start_query_execution: Callback called when query starts.
152107
**kwargs: Additional execution parameters.
153108
154109
Returns:
@@ -179,10 +134,6 @@ async def execute( # type: ignore[override]
179134
paramstyle=paramstyle,
180135
)
181136

182-
if self._on_start_query_execution:
183-
self._on_start_query_execution(self.query_id)
184-
if on_start_query_execution:
185-
on_start_query_execution(self.query_id)
186137
query_execution = await self._poll(self.query_id)
187138
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
188139
self.result_set = await asyncio.to_thread(
@@ -202,84 +153,6 @@ async def execute( # type: ignore[override]
202153
raise OperationalError(query_execution.state_change_reason)
203154
return self
204155

205-
async def executemany( # type: ignore[override]
206-
self,
207-
operation: str,
208-
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
209-
**kwargs,
210-
) -> None:
211-
"""Execute a SQL query multiple times with different parameters.
212-
213-
Args:
214-
operation: SQL query string to execute.
215-
seq_of_parameters: Sequence of parameter sets, one per execution.
216-
**kwargs: Additional keyword arguments passed to each ``execute()``.
217-
"""
218-
for parameters in seq_of_parameters:
219-
await self.execute(operation, parameters, **kwargs)
220-
self._reset_state()
221-
222-
async def cancel(self) -> None:
223-
"""Cancel the currently executing query.
224-
225-
Raises:
226-
ProgrammingError: If no query is currently executing.
227-
"""
228-
if not self.query_id:
229-
raise ProgrammingError("QueryExecutionId is none or empty.")
230-
await self._cancel(self.query_id)
231-
232-
def fetchone(
233-
self,
234-
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
235-
"""Fetch the next row of the result set.
236-
237-
Returns:
238-
A tuple representing the next row, or None if no more rows.
239-
240-
Raises:
241-
ProgrammingError: If no result set is available.
242-
"""
243-
if not self.has_result_set:
244-
raise ProgrammingError("No result set.")
245-
result_set = cast(AthenaArrowResultSet, self.result_set)
246-
return result_set.fetchone()
247-
248-
def fetchmany(
249-
self, size: Optional[int] = None
250-
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
251-
"""Fetch multiple rows from the result set.
252-
253-
Args:
254-
size: Maximum number of rows to fetch. Defaults to arraysize.
255-
256-
Returns:
257-
List of tuples representing the fetched rows.
258-
259-
Raises:
260-
ProgrammingError: If no result set is available.
261-
"""
262-
if not self.has_result_set:
263-
raise ProgrammingError("No result set.")
264-
result_set = cast(AthenaArrowResultSet, self.result_set)
265-
return result_set.fetchmany(size)
266-
267-
def fetchall(
268-
self,
269-
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
270-
"""Fetch all remaining rows from the result set.
271-
272-
Returns:
273-
List of tuples representing all remaining rows.
274-
275-
Raises:
276-
ProgrammingError: If no result set is available.
277-
"""
278-
if not self.has_result_set:
279-
raise ProgrammingError("No result set.")
280-
result_set = cast(AthenaArrowResultSet, self.result_set)
281-
return result_set.fetchall()
282-
283156
def as_arrow(self) -> "Table":
284157
"""Return query results as an Apache Arrow Table.
285158
@@ -301,18 +174,3 @@ def as_polars(self) -> "pl.DataFrame":
301174
raise ProgrammingError("No result set.")
302175
result_set = cast(AthenaArrowResultSet, self.result_set)
303176
return result_set.as_polars()
304-
305-
def __aiter__(self):
306-
return self
307-
308-
async def __anext__(self):
309-
row = self.fetchone()
310-
if row is None:
311-
raise StopAsyncIteration
312-
return row
313-
314-
async def __aenter__(self):
315-
return self
316-
317-
async def __aexit__(self, exc_type, exc_val, exc_tb):
318-
self.close()

pyathena/aio/common.py

Lines changed: 155 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,13 @@
55
import logging
66
import sys
77
from datetime import datetime, timedelta, timezone
8-
from typing import Any, Dict, List, Optional, Tuple, Union
8+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
99

1010
from pyathena.aio.util import async_retry_api_call
11-
from pyathena.common import BaseCursor
12-
from pyathena.error import DatabaseError, OperationalError
11+
from pyathena.common import BaseCursor, CursorIterator
12+
from pyathena.error import DatabaseError, OperationalError, ProgrammingError
1313
from pyathena.model import AthenaDatabase, AthenaQueryExecution, AthenaTableMetadata
14+
from pyathena.result_set import AthenaResultSet, WithResultSet
1415

1516
_logger = logging.getLogger(__name__) # type: ignore
1617

@@ -346,3 +347,154 @@ async def list_table_metadata( # type: ignore[override]
346347
if not next_token:
347348
break
348349
return metadata
350+
351+
352+
class WithAsyncFetch(AioBaseCursor, CursorIterator, WithResultSet):
353+
"""Mixin providing shared fetch, lifecycle, and async protocol for SQL cursors.
354+
355+
Provides properties (``arraysize``, ``result_set``, ``query_id``,
356+
``rownumber``, ``rowcount``), lifecycle methods (``close``, ``executemany``,
357+
``cancel``), default sync fetch (for cursors whose result sets load all
358+
data eagerly in ``__init__``), and the async iteration protocol.
359+
360+
Subclasses override ``execute()`` and optionally ``__init__`` and
361+
format-specific helpers.
362+
"""
363+
364+
def __init__(self, **kwargs) -> None:
365+
super().__init__(**kwargs)
366+
self._query_id: Optional[str] = None
367+
self._result_set: Optional[AthenaResultSet] = None
368+
369+
@property
370+
def arraysize(self) -> int:
371+
return self._arraysize
372+
373+
@arraysize.setter
374+
def arraysize(self, value: int) -> None:
375+
if value <= 0:
376+
raise ProgrammingError("arraysize must be a positive integer value.")
377+
self._arraysize = value
378+
379+
@property # type: ignore
380+
def result_set(self) -> Optional[AthenaResultSet]:
381+
return self._result_set
382+
383+
@result_set.setter
384+
def result_set(self, val) -> None:
385+
self._result_set = val
386+
387+
@property
388+
def query_id(self) -> Optional[str]:
389+
return self._query_id
390+
391+
@query_id.setter
392+
def query_id(self, val) -> None:
393+
self._query_id = val
394+
395+
@property
396+
def rownumber(self) -> Optional[int]:
397+
return self.result_set.rownumber if self.result_set else None
398+
399+
@property
400+
def rowcount(self) -> int:
401+
return self.result_set.rowcount if self.result_set else -1
402+
403+
def close(self) -> None:
404+
"""Close the cursor and release associated resources."""
405+
if self.result_set and not self.result_set.is_closed:
406+
self.result_set.close()
407+
408+
async def executemany( # type: ignore[override]
409+
self,
410+
operation: str,
411+
seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
412+
**kwargs,
413+
) -> None:
414+
"""Execute a SQL query multiple times with different parameters.
415+
416+
Args:
417+
operation: SQL query string to execute.
418+
seq_of_parameters: Sequence of parameter sets, one per execution.
419+
**kwargs: Additional keyword arguments passed to each ``execute()``.
420+
"""
421+
for parameters in seq_of_parameters:
422+
await self.execute(operation, parameters, **kwargs)
423+
# Operations that have result sets are not allowed with executemany.
424+
self._reset_state()
425+
426+
async def cancel(self) -> None:
427+
"""Cancel the currently executing query.
428+
429+
Raises:
430+
ProgrammingError: If no query is currently executing.
431+
"""
432+
if not self.query_id:
433+
raise ProgrammingError("QueryExecutionId is none or empty.")
434+
await self._cancel(self.query_id)
435+
436+
def fetchone(
437+
self,
438+
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
439+
"""Fetch the next row of the result set.
440+
441+
Returns:
442+
A tuple representing the next row, or None if no more rows.
443+
444+
Raises:
445+
ProgrammingError: If no result set is available.
446+
"""
447+
if not self.has_result_set:
448+
raise ProgrammingError("No result set.")
449+
result_set = cast(AthenaResultSet, self.result_set)
450+
return result_set.fetchone()
451+
452+
def fetchmany(
453+
self, size: Optional[int] = None
454+
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
455+
"""Fetch multiple rows from the result set.
456+
457+
Args:
458+
size: Maximum number of rows to fetch. Defaults to arraysize.
459+
460+
Returns:
461+
List of tuples representing the fetched rows.
462+
463+
Raises:
464+
ProgrammingError: If no result set is available.
465+
"""
466+
if not self.has_result_set:
467+
raise ProgrammingError("No result set.")
468+
result_set = cast(AthenaResultSet, self.result_set)
469+
return result_set.fetchmany(size)
470+
471+
def fetchall(
472+
self,
473+
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
474+
"""Fetch all remaining rows from the result set.
475+
476+
Returns:
477+
List of tuples representing all remaining rows.
478+
479+
Raises:
480+
ProgrammingError: If no result set is available.
481+
"""
482+
if not self.has_result_set:
483+
raise ProgrammingError("No result set.")
484+
result_set = cast(AthenaResultSet, self.result_set)
485+
return result_set.fetchall()
486+
487+
def __aiter__(self):
488+
return self
489+
490+
async def __anext__(self):
491+
row = self.fetchone()
492+
if row is None:
493+
raise StopAsyncIteration
494+
return row
495+
496+
async def __aenter__(self):
497+
return self
498+
499+
async def __aexit__(self, exc_type, exc_val, exc_tb):
500+
self.close()

0 commit comments

Comments
 (0)