Skip to content

Commit 5238436

Browse files
Merge pull request #667 from pyathena-dev/feature/native-asyncio-cursor-phase2
Add native asyncio specialized cursors and DRY refactors (Phase 2)
2 parents f011ea8 + b5353e7 commit 5238436

File tree

19 files changed

+1604
-51
lines changed

19 files changed

+1604
-51
lines changed

pyathena/aio/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
# -*- coding: utf-8 -*-
2-
from __future__ import annotations

pyathena/aio/arrow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# -*- coding: utf-8 -*-

pyathena/aio/arrow/cursor.py

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import annotations
3+
4+
import asyncio
5+
import logging
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union, cast
7+
8+
from pyathena.aio.common import AioBaseCursor
9+
from pyathena.arrow.converter import (
10+
DefaultArrowTypeConverter,
11+
DefaultArrowUnloadTypeConverter,
12+
)
13+
from pyathena.arrow.result_set import AthenaArrowResultSet
14+
from pyathena.common import CursorIterator
15+
from pyathena.error import OperationalError, ProgrammingError
16+
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
17+
from pyathena.result_set import WithResultSet
18+
19+
if TYPE_CHECKING:
20+
import polars as pl
21+
from pyarrow import Table
22+
23+
_logger = logging.getLogger(__name__) # type: ignore
24+
25+
26+
class AioArrowCursor(AioBaseCursor, CursorIterator, WithResultSet):
    """Asyncio-native cursor that materializes query results as Arrow Tables.

    The expensive part — downloading result data from S3 and building the
    Arrow Table — happens inside ``AthenaArrowResultSet.__init__``, which this
    cursor runs in a worker thread via ``asyncio.to_thread()`` so the event
    loop is never blocked.  After that, all result data is in memory, so the
    fetch/conversion methods are plain synchronous calls.

    Example:
        >>> async with await pyathena.aconnect(...) as conn:
        ...     cursor = conn.cursor(AioArrowCursor)
        ...     await cursor.execute("SELECT * FROM my_table")
        ...     table = cursor.as_arrow()
    """

    def __init__(
        self,
        s3_staging_dir: Optional[str] = None,
        schema_name: Optional[str] = None,
        catalog_name: Optional[str] = None,
        work_group: Optional[str] = None,
        poll_interval: float = 1,
        encryption_option: Optional[str] = None,
        kms_key: Optional[str] = None,
        kill_on_interrupt: bool = True,
        unload: bool = False,
        result_reuse_enable: bool = False,
        result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
        on_start_query_execution: Optional[Callable[[str], None]] = None,
        connect_timeout: Optional[float] = None,
        request_timeout: Optional[float] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            s3_staging_dir=s3_staging_dir,
            schema_name=schema_name,
            catalog_name=catalog_name,
            work_group=work_group,
            poll_interval=poll_interval,
            encryption_option=encryption_option,
            kms_key=kms_key,
            kill_on_interrupt=kill_on_interrupt,
            result_reuse_enable=result_reuse_enable,
            result_reuse_minutes=result_reuse_minutes,
            **kwargs,
        )
        # When True, results are transferred via Athena UNLOAD (Parquet on S3)
        # instead of the default CSV result location.
        self._unload = unload
        self._on_start_query_execution = on_start_query_execution
        self._connect_timeout = connect_timeout
        self._request_timeout = request_timeout
        self._query_id: Optional[str] = None
        self._result_set: Optional[AthenaArrowResultSet] = None

    @staticmethod
    def get_default_converter(
        unload: bool = False,
    ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]:
        """Return the type converter matching the cursor's transfer mode."""
        return DefaultArrowUnloadTypeConverter() if unload else DefaultArrowTypeConverter()

    @property
    def arraysize(self) -> int:
        """Number of rows fetched per ``fetchmany()`` call by default."""
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        if value <= 0:
            raise ProgrammingError("arraysize must be a positive integer value.")
        self._arraysize = value

    @property  # type: ignore
    def result_set(self) -> Optional[AthenaArrowResultSet]:
        return self._result_set

    @result_set.setter
    def result_set(self, val) -> None:
        self._result_set = val

    @property
    def query_id(self) -> Optional[str]:
        return self._query_id

    @query_id.setter
    def query_id(self, val) -> None:
        self._query_id = val

    @property
    def rownumber(self) -> Optional[int]:
        rs = self.result_set
        return rs.rownumber if rs else None

    @property
    def rowcount(self) -> int:
        rs = self.result_set
        return rs.rowcount if rs else -1

    def close(self) -> None:
        """Close the cursor and release associated resources."""
        rs = self.result_set
        if rs and not rs.is_closed:
            rs.close()

    def _checked_result_set(self) -> AthenaArrowResultSet:
        # Shared guard: every fetch/conversion method requires a result set.
        if not self.has_result_set:
            raise ProgrammingError("No result set.")
        return cast(AthenaArrowResultSet, self.result_set)

    async def execute(  # type: ignore[override]
        self,
        operation: str,
        parameters: Optional[Union[Dict[str, Any], List[str]]] = None,
        work_group: Optional[str] = None,
        s3_staging_dir: Optional[str] = None,
        cache_size: Optional[int] = 0,
        cache_expiration_time: Optional[int] = 0,
        result_reuse_enable: Optional[bool] = None,
        result_reuse_minutes: Optional[int] = None,
        paramstyle: Optional[str] = None,
        on_start_query_execution: Optional[Callable[[str], None]] = None,
        **kwargs,
    ) -> "AioArrowCursor":
        """Run a SQL query asynchronously; results become available as Arrow.

        Args:
            operation: SQL query string to execute.
            parameters: Query parameters for parameterized queries.
            work_group: Athena workgroup to use for this query.
            s3_staging_dir: S3 location for query results.
            cache_size: Number of queries to check for result caching.
            cache_expiration_time: Cache expiration time in seconds.
            result_reuse_enable: Enable Athena result reuse for this query.
            result_reuse_minutes: Minutes to reuse cached results.
            paramstyle: Parameter style ('qmark' or 'pyformat').
            on_start_query_execution: Callback called when query starts.
            **kwargs: Additional execution parameters.

        Returns:
            Self reference for method chaining.

        Raises:
            ProgrammingError: If unload mode is enabled without an
                ``s3_staging_dir``.
            OperationalError: If the query does not reach the SUCCEEDED state.
        """
        self._reset_state()
        unload_location: Optional[str] = None
        if self._unload:
            s3_staging_dir = s3_staging_dir or self._s3_staging_dir
            if not s3_staging_dir:
                raise ProgrammingError("If the unload option is used, s3_staging_dir is required.")
            # Rewrite the query as an UNLOAD statement targeting Parquet/Snappy.
            operation, unload_location = self._formatter.wrap_unload(
                operation,
                s3_staging_dir=s3_staging_dir,
                format_=AthenaFileFormat.FILE_FORMAT_PARQUET,
                compression=AthenaCompression.COMPRESSION_SNAPPY,
            )
        self.query_id = await self._execute(
            operation,
            parameters=parameters,
            work_group=work_group,
            s3_staging_dir=s3_staging_dir,
            cache_size=cache_size,
            cache_expiration_time=cache_expiration_time,
            result_reuse_enable=result_reuse_enable,
            result_reuse_minutes=result_reuse_minutes,
            paramstyle=paramstyle,
        )

        # Cursor-level callback first, then the per-call one.
        for callback in (self._on_start_query_execution, on_start_query_execution):
            if callback:
                callback(self.query_id)

        execution = await self._poll(self.query_id)
        if execution.state != AthenaQueryExecution.STATE_SUCCEEDED:
            raise OperationalError(execution.state_change_reason)
        # Result-set construction downloads data from S3; run it off-loop.
        self.result_set = await asyncio.to_thread(
            AthenaArrowResultSet,
            connection=self._connection,
            converter=self._converter,
            query_execution=execution,
            arraysize=self.arraysize,
            retry_config=self._retry_config,
            unload=self._unload,
            unload_location=unload_location,
            connect_timeout=self._connect_timeout,
            request_timeout=self._request_timeout,
            **kwargs,
        )
        return self

    async def executemany(  # type: ignore[override]
        self,
        operation: str,
        seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
        **kwargs,
    ) -> None:
        """Execute the same SQL query once per parameter set.

        Args:
            operation: SQL query string to execute.
            seq_of_parameters: Sequence of parameter sets, one per execution.
            **kwargs: Additional keyword arguments passed to each ``execute()``.
        """
        for parameters in seq_of_parameters:
            await self.execute(operation, parameters, **kwargs)
        # Per DB-API convention, executemany leaves no result set behind.
        self._reset_state()

    async def cancel(self) -> None:
        """Cancel the currently executing query.

        Raises:
            ProgrammingError: If no query is currently executing.
        """
        if not self.query_id:
            raise ProgrammingError("QueryExecutionId is none or empty.")
        await self._cancel(self.query_id)

    def fetchone(
        self,
    ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
        """Fetch the next row of the result set.

        Returns:
            A tuple representing the next row, or None if no more rows.

        Raises:
            ProgrammingError: If no result set is available.
        """
        return self._checked_result_set().fetchone()

    def fetchmany(
        self, size: Optional[int] = None
    ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
        """Fetch up to ``size`` rows from the result set.

        Args:
            size: Maximum number of rows to fetch. Defaults to arraysize.

        Returns:
            List of tuples representing the fetched rows.

        Raises:
            ProgrammingError: If no result set is available.
        """
        return self._checked_result_set().fetchmany(size)

    def fetchall(
        self,
    ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
        """Fetch all remaining rows from the result set.

        Returns:
            List of tuples representing all remaining rows.

        Raises:
            ProgrammingError: If no result set is available.
        """
        return self._checked_result_set().fetchall()

    def as_arrow(self) -> "Table":
        """Return query results as an Apache Arrow Table.

        Returns:
            Apache Arrow Table containing all query results.

        Raises:
            ProgrammingError: If no result set is available.
        """
        return self._checked_result_set().as_arrow()

    def as_polars(self) -> "pl.DataFrame":
        """Return query results as a Polars DataFrame.

        Returns:
            Polars DataFrame containing all query results.

        Raises:
            ProgrammingError: If no result set is available.
        """
        return self._checked_result_set().as_polars()

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Rows are already in memory, so fetchone() is synchronous here.
        row = self.fetchone()
        if row is None:
            raise StopAsyncIteration
        return row

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self.close()

0 commit comments

Comments
 (0)