Skip to content

Commit 4b9e52b

Browse files
Add native asyncio specialized cursors and DRY refactors (Phase 2)
Phase 2 of native asyncio cursor support:
- DRY: Add a `_pre_fetch` flag to `AthenaResultSet.__init__` to eliminate duplicated field initialization in `AthenaAioResultSet`.
- DRY: Extract `_prepare_query()` from `BaseCursor._execute` so both the sync and async `_execute` share the non-I/O query-preparation logic.
- Add async metadata operations (`list_databases`, `get_table_metadata`, `list_table_metadata`) to `AioBaseCursor`.
- Add `AioPandasCursor`, `AioArrowCursor`, and `AioPolarsCursor` using `asyncio.to_thread()` for result-set creation (no new result-set classes are needed, since the fetch methods are in-memory only).
- Add comprehensive tests for all new cursors and metadata operations.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f011ea8 commit 4b9e52b

File tree

11 files changed

+1381
-47
lines changed

11 files changed

+1381
-47
lines changed

pyathena/aio/common.py

Lines changed: 142 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import logging
66
import sys
77
from datetime import datetime, timedelta, timezone
8-
from typing import Any, Dict, List, Optional, Tuple, Union, cast
8+
from typing import Any, Dict, List, Optional, Tuple, Union
99

10-
import pyathena
1110
from pyathena.aio.util import async_retry_api_call
1211
from pyathena.common import BaseCursor
1312
from pyathena.error import DatabaseError, OperationalError
14-
from pyathena.model import AthenaQueryExecution
13+
from pyathena.model import AthenaDatabase, AthenaQueryExecution, AthenaTableMetadata
1514

1615
_logger = logging.getLogger(__name__) # type: ignore
1716

@@ -36,13 +35,7 @@ async def _execute( # type: ignore[override]
3635
result_reuse_minutes: Optional[int] = None,
3736
paramstyle: Optional[str] = None,
3837
) -> str:
39-
if pyathena.paramstyle == "qmark" or paramstyle == "qmark":
40-
query = operation
41-
execution_parameters = cast(Optional[List[str]], parameters)
42-
else:
43-
query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters))
44-
execution_parameters = None
45-
_logger.debug(query)
38+
query, execution_parameters = self._prepare_query(operation, parameters, paramstyle)
4639

4740
request = self._build_start_query_execution_request(
4841
query=query,
@@ -216,3 +209,142 @@ async def _find_previous_query_id( # type: ignore[override]
216209
except Exception:
217210
_logger.warning("Failed to check the cache. Moving on without cache.", exc_info=True)
218211
return query_id
212+
213+
async def _list_databases(  # type: ignore[override]
    self,
    catalog_name: Optional[str],
    next_token: Optional[str] = None,
    max_results: Optional[int] = None,
) -> Tuple[Optional[str], List[AthenaDatabase]]:
    """Fetch a single page of databases from the Athena ListDatabases API.

    Returns a tuple of (next page token, databases on this page). The token
    is ``None`` when there are no further pages.

    Raises:
        OperationalError: if the API call fails after the configured retries.
    """
    request = self._build_list_databases_request(
        catalog_name=catalog_name,
        next_token=next_token,
        max_results=max_results,
    )
    try:
        response = await async_retry_api_call(
            self.connection._client.list_databases,
            config=self._retry_config,
            logger=_logger,
            **request,
        )
    except Exception as exc:
        _logger.exception("Failed to list databases.")
        raise OperationalError(*exc.args) from exc
    # Wrap each raw API entry in the model type expected by callers.
    databases = [AthenaDatabase({"Database": r}) for r in response.get("DatabaseList", [])]
    return response.get("NextToken"), databases
238+
239+
async def list_databases(  # type: ignore[override]
    self,
    catalog_name: Optional[str],
    max_results: Optional[int] = None,
) -> List[AthenaDatabase]:
    """Collect all databases for *catalog_name*, following pagination.

    Repeatedly calls :meth:`_list_databases` until the API stops returning
    a next-page token, then returns the accumulated list.
    """
    results: List[AthenaDatabase] = []
    token: Optional[str] = None
    while True:
        token, page = await self._list_databases(
            catalog_name=catalog_name,
            next_token=token,
            max_results=max_results,
        )
        results.extend(page)
        if not token:
            # No more pages — return everything gathered so far.
            return results
256+
257+
async def _get_table_metadata(  # type: ignore[override]
    self,
    table_name: str,
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    logging_: bool = True,
) -> AthenaTableMetadata:
    """Fetch metadata for a single table via the Athena GetTableMetadata API.

    Falls back to the cursor's default catalog/schema when *catalog_name* or
    *schema_name* is not given. When *logging_* is False, failures are raised
    without being logged (useful for existence probes).

    Raises:
        OperationalError: if the API call fails after the configured retries.
    """
    request: Dict[str, Any] = {
        "CatalogName": catalog_name if catalog_name else self._catalog_name,
        "DatabaseName": schema_name if schema_name else self._schema_name,
        "TableName": table_name,
    }
    if self._work_group:
        request.update({"WorkGroup": self._work_group})
    try:
        # Use the same client access path (`self.connection._client`) as the
        # other async metadata methods in this module for consistency; the
        # original used `self._connection.client` here only.
        response = await async_retry_api_call(
            self.connection._client.get_table_metadata,
            config=self._retry_config,
            logger=_logger,
            **request,
        )
    except Exception as e:
        if logging_:
            _logger.exception("Failed to get table metadata.")
        raise OperationalError(*e.args) from e
    else:
        return AthenaTableMetadata(response)
284+
285+
async def get_table_metadata(  # type: ignore[override]
    self,
    table_name: str,
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    logging_: bool = True,
) -> AthenaTableMetadata:
    """Public async wrapper that delegates to :meth:`_get_table_metadata`."""
    metadata = await self._get_table_metadata(
        table_name=table_name,
        catalog_name=catalog_name,
        schema_name=schema_name,
        logging_=logging_,
    )
    return metadata
298+
299+
async def _list_table_metadata(  # type: ignore[override]
    self,
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    expression: Optional[str] = None,
    next_token: Optional[str] = None,
    max_results: Optional[int] = None,
) -> Tuple[Optional[str], List[AthenaTableMetadata]]:
    """Fetch one page of table metadata via the Athena ListTableMetadata API.

    Returns a tuple of (next page token, table metadata entries on this
    page). The token is ``None`` when there are no further pages.

    Raises:
        OperationalError: if the API call fails after the configured retries.
    """
    request = self._build_list_table_metadata_request(
        catalog_name=catalog_name,
        schema_name=schema_name,
        expression=expression,
        next_token=next_token,
        max_results=max_results,
    )
    try:
        response = await async_retry_api_call(
            self.connection._client.list_table_metadata,
            config=self._retry_config,
            logger=_logger,
            **request,
        )
    except Exception as exc:
        _logger.exception("Failed to list table metadata.")
        raise OperationalError(*exc.args) from exc
    # Wrap each raw API entry in the model type expected by callers.
    tables = [
        AthenaTableMetadata({"TableMetadata": r})
        for r in response.get("TableMetadataList", [])
    ]
    return response.get("NextToken"), tables
329+
330+
async def list_table_metadata(  # type: ignore[override]
    self,
    catalog_name: Optional[str] = None,
    schema_name: Optional[str] = None,
    expression: Optional[str] = None,
    max_results: Optional[int] = None,
) -> List[AthenaTableMetadata]:
    """Collect table metadata for a schema, following pagination to the end.

    Repeatedly calls :meth:`_list_table_metadata` until the API stops
    returning a next-page token, then returns the accumulated list.
    """
    collected: List[AthenaTableMetadata] = []
    token: Optional[str] = None
    while True:
        token, page = await self._list_table_metadata(
            catalog_name=catalog_name,
            schema_name=schema_name,
            expression=expression,
            next_token=token,
            max_results=max_results,
        )
        collected.extend(page)
        if not token:
            # No more pages — return everything gathered so far.
            return collected

pyathena/aio/result_set.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
import collections
54
import logging
65
from typing import (
76
TYPE_CHECKING,
87
Any,
9-
Deque,
108
Dict,
119
List,
1210
Optional,
@@ -16,7 +14,6 @@
1614
)
1715

1816
from pyathena.aio.util import async_retry_api_call
19-
from pyathena.common import CursorIterator
2017
from pyathena.converter import Converter
2118
from pyathena.error import OperationalError, ProgrammingError
2219
from pyathena.model import AthenaQueryExecution
@@ -32,9 +29,9 @@
3229
class AthenaAioResultSet(AthenaResultSet):
3330
"""Async result set that provides async fetch methods.
3431
35-
Because ``AthenaResultSet.__init__`` calls ``_pre_fetch`` (a blocking API
36-
call), this class overrides ``__init__`` to skip it and provides an
37-
``async create()`` classmethod factory instead.
32+
Skips the synchronous ``_pre_fetch`` by passing ``_pre_fetch=False`` to
33+
the parent ``__init__`` and provides an ``async create()`` classmethod
34+
factory instead.
3835
"""
3936

4037
def __init__(
@@ -45,31 +42,15 @@ def __init__(
4542
arraysize: int,
4643
retry_config: RetryConfig,
4744
) -> None:
48-
# Replicate parent field initialization without calling _pre_fetch.
49-
CursorIterator.__init__(self, arraysize=arraysize)
50-
self._connection: Optional["Connection[Any]"] = connection
51-
self._converter = converter
52-
self._query_execution: Optional[AthenaQueryExecution] = query_execution
53-
if not self._query_execution:
54-
raise ProgrammingError("Required argument `query_execution` not found.")
55-
self._retry_config = retry_config
56-
self._client = connection.session.client(
57-
"s3",
58-
region_name=connection.region_name,
59-
config=connection.config,
60-
**connection._client_kwargs,
45+
super().__init__(
46+
connection=connection,
47+
converter=converter,
48+
query_execution=query_execution,
49+
arraysize=arraysize,
50+
retry_config=retry_config,
51+
_pre_fetch=False,
6152
)
6253

63-
self._metadata: Optional[Tuple[Dict[str, Any], ...]] = None
64-
self._rows: Deque[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]] = (
65-
collections.deque()
66-
)
67-
self._next_token: Optional[str] = None
68-
69-
if self.state == AthenaQueryExecution.STATE_SUCCEEDED:
70-
self._rownumber = 0
71-
# NOTE: _pre_fetch is NOT called here; use create() instead.
72-
7354
@classmethod
7455
async def create(
7556
cls,

0 commit comments

Comments (0)