33
44import asyncio
55import logging
6- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Tuple , Union , cast
6+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union , cast
77
8- from pyathena .aio .common import AioBaseCursor
8+ from pyathena .aio .common import WithAsyncFetch
99from pyathena .arrow .converter import (
1010 DefaultArrowTypeConverter ,
1111 DefaultArrowUnloadTypeConverter ,
1414from pyathena .common import CursorIterator
1515from pyathena .error import OperationalError , ProgrammingError
1616from pyathena .model import AthenaCompression , AthenaFileFormat , AthenaQueryExecution
17- from pyathena .result_set import WithResultSet
1817
1918if TYPE_CHECKING :
2019 import polars as pl
2322_logger = logging .getLogger (__name__ ) # type: ignore
2423
2524
26- class AioArrowCursor (AioBaseCursor , CursorIterator , WithResultSet ):
25+ class AioArrowCursor (WithAsyncFetch ):
2726 """Native asyncio cursor that returns results as Apache Arrow Tables.
2827
2928 Uses ``asyncio.to_thread()`` to create the result set off the event loop.
@@ -50,7 +49,6 @@ def __init__(
5049 unload : bool = False ,
5150 result_reuse_enable : bool = False ,
5251 result_reuse_minutes : int = CursorIterator .DEFAULT_RESULT_REUSE_MINUTES ,
53- on_start_query_execution : Optional [Callable [[str ], None ]] = None ,
5452 connect_timeout : Optional [float ] = None ,
5553 request_timeout : Optional [float ] = None ,
5654 ** kwargs ,
@@ -69,10 +67,8 @@ def __init__(
6967 ** kwargs ,
7068 )
7169 self ._unload = unload
72- self ._on_start_query_execution = on_start_query_execution
7370 self ._connect_timeout = connect_timeout
7471 self ._request_timeout = request_timeout
75- self ._query_id : Optional [str ] = None
7672 self ._result_set : Optional [AthenaArrowResultSet ] = None
7773
7874 @staticmethod
@@ -83,45 +79,6 @@ def get_default_converter(
8379 return DefaultArrowUnloadTypeConverter ()
8480 return DefaultArrowTypeConverter ()
8581
86- @property
87- def arraysize (self ) -> int :
88- return self ._arraysize
89-
90- @arraysize .setter
91- def arraysize (self , value : int ) -> None :
92- if value <= 0 :
93- raise ProgrammingError ("arraysize must be a positive integer value." )
94- self ._arraysize = value
95-
96- @property # type: ignore
97- def result_set (self ) -> Optional [AthenaArrowResultSet ]:
98- return self ._result_set
99-
100- @result_set .setter
101- def result_set (self , val ) -> None :
102- self ._result_set = val
103-
104- @property
105- def query_id (self ) -> Optional [str ]:
106- return self ._query_id
107-
108- @query_id .setter
109- def query_id (self , val ) -> None :
110- self ._query_id = val
111-
112- @property
113- def rownumber (self ) -> Optional [int ]:
114- return self .result_set .rownumber if self .result_set else None
115-
116- @property
117- def rowcount (self ) -> int :
118- return self .result_set .rowcount if self .result_set else - 1
119-
120- def close (self ) -> None :
121- """Close the cursor and release associated resources."""
122- if self .result_set and not self .result_set .is_closed :
123- self .result_set .close ()
124-
12582 async def execute ( # type: ignore[override]
12683 self ,
12784 operation : str ,
@@ -133,7 +90,6 @@ async def execute( # type: ignore[override]
13390 result_reuse_enable : Optional [bool ] = None ,
13491 result_reuse_minutes : Optional [int ] = None ,
13592 paramstyle : Optional [str ] = None ,
136- on_start_query_execution : Optional [Callable [[str ], None ]] = None ,
13793 ** kwargs ,
13894 ) -> "AioArrowCursor" :
13995 """Execute a SQL query asynchronously and return results as Arrow Tables.
@@ -148,7 +104,6 @@ async def execute( # type: ignore[override]
148104 result_reuse_enable: Enable Athena result reuse for this query.
149105 result_reuse_minutes: Minutes to reuse cached results.
150106 paramstyle: Parameter style ('qmark' or 'pyformat').
151- on_start_query_execution: Callback called when query starts.
152107 **kwargs: Additional execution parameters.
153108
154109 Returns:
@@ -179,10 +134,6 @@ async def execute( # type: ignore[override]
179134 paramstyle = paramstyle ,
180135 )
181136
182- if self ._on_start_query_execution :
183- self ._on_start_query_execution (self .query_id )
184- if on_start_query_execution :
185- on_start_query_execution (self .query_id )
186137 query_execution = await self ._poll (self .query_id )
187138 if query_execution .state == AthenaQueryExecution .STATE_SUCCEEDED :
188139 self .result_set = await asyncio .to_thread (
@@ -202,84 +153,6 @@ async def execute( # type: ignore[override]
202153 raise OperationalError (query_execution .state_change_reason )
203154 return self
204155
205- async def executemany ( # type: ignore[override]
206- self ,
207- operation : str ,
208- seq_of_parameters : List [Optional [Union [Dict [str , Any ], List [str ]]]],
209- ** kwargs ,
210- ) -> None :
211- """Execute a SQL query multiple times with different parameters.
212-
213- Args:
214- operation: SQL query string to execute.
215- seq_of_parameters: Sequence of parameter sets, one per execution.
216- **kwargs: Additional keyword arguments passed to each ``execute()``.
217- """
218- for parameters in seq_of_parameters :
219- await self .execute (operation , parameters , ** kwargs )
220- self ._reset_state ()
221-
222- async def cancel (self ) -> None :
223- """Cancel the currently executing query.
224-
225- Raises:
226- ProgrammingError: If no query is currently executing.
227- """
228- if not self .query_id :
229- raise ProgrammingError ("QueryExecutionId is none or empty." )
230- await self ._cancel (self .query_id )
231-
232- def fetchone (
233- self ,
234- ) -> Optional [Union [Tuple [Optional [Any ], ...], Dict [Any , Optional [Any ]]]]:
235- """Fetch the next row of the result set.
236-
237- Returns:
238- A tuple representing the next row, or None if no more rows.
239-
240- Raises:
241- ProgrammingError: If no result set is available.
242- """
243- if not self .has_result_set :
244- raise ProgrammingError ("No result set." )
245- result_set = cast (AthenaArrowResultSet , self .result_set )
246- return result_set .fetchone ()
247-
248- def fetchmany (
249- self , size : Optional [int ] = None
250- ) -> List [Union [Tuple [Optional [Any ], ...], Dict [Any , Optional [Any ]]]]:
251- """Fetch multiple rows from the result set.
252-
253- Args:
254- size: Maximum number of rows to fetch. Defaults to arraysize.
255-
256- Returns:
257- List of tuples representing the fetched rows.
258-
259- Raises:
260- ProgrammingError: If no result set is available.
261- """
262- if not self .has_result_set :
263- raise ProgrammingError ("No result set." )
264- result_set = cast (AthenaArrowResultSet , self .result_set )
265- return result_set .fetchmany (size )
266-
267- def fetchall (
268- self ,
269- ) -> List [Union [Tuple [Optional [Any ], ...], Dict [Any , Optional [Any ]]]]:
270- """Fetch all remaining rows from the result set.
271-
272- Returns:
273- List of tuples representing all remaining rows.
274-
275- Raises:
276- ProgrammingError: If no result set is available.
277- """
278- if not self .has_result_set :
279- raise ProgrammingError ("No result set." )
280- result_set = cast (AthenaArrowResultSet , self .result_set )
281- return result_set .fetchall ()
282-
283156 def as_arrow (self ) -> "Table" :
284157 """Return query results as an Apache Arrow Table.
285158
@@ -301,18 +174,3 @@ def as_polars(self) -> "pl.DataFrame":
301174 raise ProgrammingError ("No result set." )
302175 result_set = cast (AthenaArrowResultSet , self .result_set )
303176 return result_set .as_polars ()
304-
305- def __aiter__ (self ):
306- return self
307-
308- async def __anext__ (self ):
309- row = self .fetchone ()
310- if row is None :
311- raise StopAsyncIteration
312- return row
313-
314- async def __aenter__ (self ):
315- return self
316-
317- async def __aexit__ (self , exc_type , exc_val , exc_tb ):
318- self .close ()
0 commit comments