Skip to content

Commit f011ea8

Browse files
Merge pull request #666 from pyathena-dev/feature/native-asyncio-cursor
Add native asyncio cursor support (Phase 1)
2 parents a13bd3c + 08b1025 commit f011ea8

File tree

13 files changed

+1161
-5
lines changed

13 files changed

+1161
-5
lines changed

pyathena/__init__.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pyathena.error import * # noqa
88

99
if TYPE_CHECKING:
10+
from pyathena.aio.connection import AioConnection
1011
from pyathena.connection import Connection, ConnectionCursor
1112
from pyathena.cursor import Cursor
1213

@@ -128,3 +129,32 @@ def connect(*args, **kwargs) -> "Connection[Any]":
128129
from pyathena.connection import Connection
129130

130131
return Connection(*args, **kwargs)
132+
133+
134+
async def aconnect(*args, **kwargs) -> "AioConnection":
135+
"""Create a new async database connection to Amazon Athena.
136+
137+
This is the async counterpart of :func:`connect`. It returns an
138+
``AioConnection`` whose cursors use native ``asyncio`` for polling
139+
and API calls, keeping the event loop free.
140+
141+
Args:
142+
**kwargs: Arguments forwarded to ``AioConnection.create()``.
143+
See :func:`connect` for the full list of supported arguments.
144+
145+
Returns:
146+
An ``AioConnection`` that produces ``AioCursor`` instances by default.
147+
148+
Example:
149+
>>> import pyathena
150+
>>> conn = await pyathena.aconnect(
151+
... s3_staging_dir='s3://my-bucket/staging/',
152+
... region_name='us-east-1',
153+
... )
154+
>>> async with conn.cursor() as cursor:
155+
... await cursor.execute("SELECT 1")
156+
... print(await cursor.fetchone())
157+
"""
158+
from pyathena.aio.connection import AioConnection
159+
160+
return await AioConnection.create(*args, **kwargs)

pyathena/aio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import annotations

pyathena/aio/common.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import annotations
3+
4+
import asyncio
5+
import logging
6+
import sys
7+
from datetime import datetime, timedelta, timezone
8+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
9+
10+
import pyathena
11+
from pyathena.aio.util import async_retry_api_call
12+
from pyathena.common import BaseCursor
13+
from pyathena.error import DatabaseError, OperationalError
14+
from pyathena.model import AthenaQueryExecution
15+
16+
_logger = logging.getLogger(__name__) # type: ignore
17+
18+
19+
class AioBaseCursor(BaseCursor):
20+
"""Async base cursor that overrides I/O methods with async equivalents.
21+
22+
Reuses ``BaseCursor.__init__``, all ``_build_*`` methods, and constants.
23+
Only the methods that perform network I/O or blocking sleep are overridden
24+
to use ``asyncio.to_thread`` / ``asyncio.sleep``.
25+
"""
26+
27+
async def _execute( # type: ignore[override]
28+
self,
29+
operation: str,
30+
parameters: Optional[Union[Dict[str, Any], List[str]]] = None,
31+
work_group: Optional[str] = None,
32+
s3_staging_dir: Optional[str] = None,
33+
cache_size: Optional[int] = 0,
34+
cache_expiration_time: Optional[int] = 0,
35+
result_reuse_enable: Optional[bool] = None,
36+
result_reuse_minutes: Optional[int] = None,
37+
paramstyle: Optional[str] = None,
38+
) -> str:
39+
if pyathena.paramstyle == "qmark" or paramstyle == "qmark":
40+
query = operation
41+
execution_parameters = cast(Optional[List[str]], parameters)
42+
else:
43+
query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters))
44+
execution_parameters = None
45+
_logger.debug(query)
46+
47+
request = self._build_start_query_execution_request(
48+
query=query,
49+
work_group=work_group,
50+
s3_staging_dir=s3_staging_dir,
51+
result_reuse_enable=result_reuse_enable,
52+
result_reuse_minutes=result_reuse_minutes,
53+
execution_parameters=execution_parameters,
54+
)
55+
query_id = await self._find_previous_query_id(
56+
query,
57+
work_group,
58+
cache_size=cache_size if cache_size else 0,
59+
cache_expiration_time=cache_expiration_time if cache_expiration_time else 0,
60+
)
61+
if query_id is None:
62+
try:
63+
response = await async_retry_api_call(
64+
self._connection.client.start_query_execution,
65+
config=self._retry_config,
66+
logger=_logger,
67+
**request,
68+
)
69+
query_id = response.get("QueryExecutionId")
70+
except Exception as e:
71+
_logger.exception("Failed to execute query.")
72+
raise DatabaseError(*e.args) from e
73+
return query_id
74+
75+
async def _get_query_execution(self, query_id: str) -> AthenaQueryExecution: # type: ignore[override]
76+
request = {"QueryExecutionId": query_id}
77+
try:
78+
response = await async_retry_api_call(
79+
self._connection.client.get_query_execution,
80+
config=self._retry_config,
81+
logger=_logger,
82+
**request,
83+
)
84+
except Exception as e:
85+
_logger.exception("Failed to get query execution.")
86+
raise OperationalError(*e.args) from e
87+
else:
88+
return AthenaQueryExecution(response)
89+
90+
async def __poll(self, query_id: str) -> AthenaQueryExecution:
91+
while True:
92+
query_execution = await self._get_query_execution(query_id)
93+
if query_execution.state in [
94+
AthenaQueryExecution.STATE_SUCCEEDED,
95+
AthenaQueryExecution.STATE_FAILED,
96+
AthenaQueryExecution.STATE_CANCELLED,
97+
]:
98+
return query_execution
99+
await asyncio.sleep(self._poll_interval)
100+
101+
async def _poll(self, query_id: str) -> AthenaQueryExecution: # type: ignore[override]
102+
try:
103+
query_execution = await self.__poll(query_id)
104+
except asyncio.CancelledError:
105+
if self._kill_on_interrupt:
106+
_logger.warning("Query canceled by user.")
107+
await self._cancel(query_id)
108+
query_execution = await self.__poll(query_id)
109+
else:
110+
raise
111+
return query_execution
112+
113+
async def _cancel(self, query_id: str) -> None: # type: ignore[override]
114+
request = {"QueryExecutionId": query_id}
115+
try:
116+
await async_retry_api_call(
117+
self._connection.client.stop_query_execution,
118+
config=self._retry_config,
119+
logger=_logger,
120+
**request,
121+
)
122+
except Exception as e:
123+
_logger.exception("Failed to cancel query.")
124+
raise OperationalError(*e.args) from e
125+
126+
async def _batch_get_query_execution( # type: ignore[override]
127+
self, query_ids: List[str]
128+
) -> List[AthenaQueryExecution]:
129+
try:
130+
response = await async_retry_api_call(
131+
self.connection._client.batch_get_query_execution,
132+
config=self._retry_config,
133+
logger=_logger,
134+
QueryExecutionIds=query_ids,
135+
)
136+
except Exception as e:
137+
_logger.exception("Failed to batch get query execution.")
138+
raise OperationalError(*e.args) from e
139+
else:
140+
return [
141+
AthenaQueryExecution({"QueryExecution": r})
142+
for r in response.get("QueryExecutions", [])
143+
]
144+
145+
async def _list_query_executions( # type: ignore[override]
146+
self,
147+
work_group: Optional[str] = None,
148+
next_token: Optional[str] = None,
149+
max_results: Optional[int] = None,
150+
) -> Tuple[Optional[str], List[AthenaQueryExecution]]:
151+
request = self._build_list_query_executions_request(
152+
work_group=work_group, next_token=next_token, max_results=max_results
153+
)
154+
try:
155+
response = await async_retry_api_call(
156+
self.connection._client.list_query_executions,
157+
config=self._retry_config,
158+
logger=_logger,
159+
**request,
160+
)
161+
except Exception as e:
162+
_logger.exception("Failed to list query executions.")
163+
raise OperationalError(*e.args) from e
164+
else:
165+
next_token = response.get("NextToken")
166+
query_ids = response.get("QueryExecutionIds")
167+
if not query_ids:
168+
return next_token, []
169+
return next_token, await self._batch_get_query_execution(query_ids)
170+
171+
async def _find_previous_query_id( # type: ignore[override]
172+
self,
173+
query: str,
174+
work_group: Optional[str],
175+
cache_size: int = 0,
176+
cache_expiration_time: int = 0,
177+
) -> Optional[str]:
178+
query_id = None
179+
if cache_size == 0 and cache_expiration_time > 0:
180+
cache_size = sys.maxsize
181+
if cache_expiration_time > 0:
182+
expiration_time = datetime.now(timezone.utc) - timedelta(seconds=cache_expiration_time)
183+
else:
184+
expiration_time = datetime.now(timezone.utc)
185+
try:
186+
next_token = None
187+
while cache_size > 0:
188+
max_results = min(cache_size, self.LIST_QUERY_EXECUTIONS_MAX_RESULTS)
189+
cache_size -= max_results
190+
next_token, query_executions = await self._list_query_executions(
191+
work_group, next_token=next_token, max_results=max_results
192+
)
193+
for execution in sorted(
194+
(
195+
e
196+
for e in query_executions
197+
if e.state == AthenaQueryExecution.STATE_SUCCEEDED
198+
and e.statement_type == AthenaQueryExecution.STATEMENT_TYPE_DML
199+
),
200+
key=lambda e: e.completion_date_time, # type: ignore
201+
reverse=True,
202+
):
203+
if (
204+
cache_expiration_time > 0
205+
and execution.completion_date_time
206+
and execution.completion_date_time.astimezone(timezone.utc)
207+
< expiration_time
208+
):
209+
next_token = None
210+
break
211+
if execution.query == query:
212+
query_id = execution.query_id
213+
break
214+
if query_id or next_token is None:
215+
break
216+
except Exception:
217+
_logger.warning("Failed to check the cache. Moving on without cache.", exc_info=True)
218+
return query_id

pyathena/aio/connection.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import annotations
3+
4+
import asyncio
5+
from typing import Any
6+
7+
from pyathena.aio.cursor import AioCursor
8+
from pyathena.connection import Connection
9+
10+
11+
class AioConnection(Connection[AioCursor]):
12+
"""Async-aware connection to Amazon Athena.
13+
14+
Wraps the synchronous ``Connection`` with async context manager support
15+
and provides ``create()`` for non-blocking initialization.
16+
17+
Example:
18+
>>> async with await AioConnection.create(
19+
... s3_staging_dir="s3://bucket/path/",
20+
... region_name="us-east-1",
21+
... ) as conn:
22+
... async with conn.cursor() as cursor:
23+
... await cursor.execute("SELECT 1")
24+
... print(await cursor.fetchone())
25+
"""
26+
27+
def __init__(self, **kwargs: Any) -> None:
28+
if "cursor_class" not in kwargs:
29+
kwargs["cursor_class"] = AioCursor
30+
super().__init__(**kwargs)
31+
32+
@classmethod
33+
async def create(
34+
cls,
35+
**kwargs: Any,
36+
) -> "AioConnection":
37+
"""Async factory for creating an ``AioConnection``.
38+
39+
Runs the (potentially blocking) ``__init__`` in a thread so that
40+
STS calls (``role_arn`` / ``serial_number``) do not block the loop.
41+
42+
Args:
43+
**kwargs: Arguments forwarded to ``AioConnection.__init__``.
44+
45+
Returns:
46+
A fully initialized ``AioConnection``.
47+
"""
48+
return await asyncio.to_thread(cls, **kwargs)
49+
50+
async def __aenter__(self) -> "AioConnection":
51+
return self
52+
53+
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
54+
self.close()

0 commit comments

Comments
 (0)