Skip to content

Commit 35e6cc0

Browse files
Add async SQLAlchemy dialects for native asyncio support (#673)
Enable `create_async_engine` usage by adding 5 async dialects that mirror the existing sync ones, using SQLAlchemy's AdaptedConnection and greenlet-based await_only() bridge pattern. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 5651427 commit 35e6cc0

File tree

7 files changed

+463
-0
lines changed

7 files changed

+463
-0
lines changed

pyathena/sqlalchemy/async_arrow.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# -*- coding: utf-8 -*-
2+
from typing import TYPE_CHECKING
3+
4+
from pyathena.sqlalchemy.async_base import AthenaAioDialect
5+
from pyathena.util import strtobool
6+
7+
if TYPE_CHECKING:
8+
from types import ModuleType
9+
10+
11+
class AthenaAioArrowDialect(AthenaAioDialect):
    """Asyncio SQLAlchemy dialect for Amazon Athena returning Apache Arrow results.

    Queries are executed through ``AioArrowCursor``, which this dialect
    installs as the connection's ``cursor_class``.

    Connection URL Format:
        ``awsathena+aioarrow://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}``

    Query Parameters:
        On top of the parameters understood by the base dialect:

        - unload: If "true", use UNLOAD for Parquet output

    Example:
        >>> from sqlalchemy.ext.asyncio import create_async_engine
        >>> engine = create_async_engine(
        ...     "awsathena+aioarrow://:@athena.us-west-2.amazonaws.com/default"
        ...     "?s3_staging_dir=s3://my-bucket/athena-results/"
        ...     "&unload=true"
        ... )

    See Also:
        :class:`~pyathena.aio.arrow.cursor.AioArrowCursor`: The underlying async cursor.
        :class:`~pyathena.sqlalchemy.async_base.AthenaAioDialect`: Base async dialect.
    """

    driver = "aioarrow"
    supports_statement_cache = True

    def create_connect_args(self, url):
        """Build the connect arguments for an Arrow-backed async connection.

        Starts from the options produced by ``_create_connect_args(url)``,
        forces ``cursor_class`` to ``AioArrowCursor`` and moves the ``unload``
        URL option (when present) into ``cursor_kwargs`` as a real ``bool``.

        Args:
            url: The SQLAlchemy URL the engine was created with.

        Returns:
            A two-item list ``[positional_args, keyword_args]`` where the
            positional part is always empty.
        """
        # Deferred import: the cursor module is only loaded once connect
        # arguments are actually being built.
        from pyathena.aio.arrow.cursor import AioArrowCursor

        connect_opts = super()._create_connect_args(url)
        connect_opts["cursor_class"] = AioArrowCursor

        extra_cursor_kwargs = {}
        if "unload" in connect_opts:
            # The URL carries the flag as text; the cursor gets a bool.
            raw_unload = connect_opts.pop("unload")
            extra_cursor_kwargs["unload"] = bool(strtobool(raw_unload))
        if extra_cursor_kwargs:
            connect_opts["cursor_kwargs"] = extra_cursor_kwargs

        # Keep the final options on the dialect instance.
        self._connect_options = connect_opts
        return [[], connect_opts]

    @classmethod
    def import_dbapi(cls) -> "ModuleType":
        """Return the DBAPI-like object supplied by the base async dialect."""
        # NOTE(review): this override only defers to the parent; presumably it
        # is kept so the method exists on the concrete dialect class itself --
        # confirm before removing.
        return super().import_dbapi()

pyathena/sqlalchemy/async_base.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING, Any, Dict, List, MutableMapping, Optional, Tuple, Union, cast
5+
6+
from sqlalchemy.engine import AdaptedConnection
7+
from sqlalchemy.util.concurrency import await_only
8+
9+
import pyathena
10+
from pyathena.aio.connection import AioConnection
11+
from pyathena.error import (
12+
DatabaseError,
13+
DataError,
14+
Error,
15+
IntegrityError,
16+
InterfaceError,
17+
InternalError,
18+
NotSupportedError,
19+
OperationalError,
20+
ProgrammingError,
21+
)
22+
from pyathena.sqlalchemy.base import AthenaDialect
23+
24+
if TYPE_CHECKING:
25+
from types import ModuleType
26+
27+
from sqlalchemy import URL
28+
29+
30+
class AsyncAdaptPyathenaCursor:
    """Synchronous DBAPI-style facade over an async PyAthena cursor.

    SQLAlchemy's async engine runs the dialect's synchronous code inside a
    greenlet and bridges to coroutines with ``await_only()``.  This adapter
    holds an ``AioCursor`` (or a variant of it) and exposes blocking-looking
    methods; every call that yields a coroutine on the wrapped cursor is
    driven through ``await_only()``, the rest are forwarded as-is.
    """

    server_side = False
    __slots__ = ("_cursor",)

    def __init__(self, cursor: Any) -> None:
        """Store the async cursor that all calls are forwarded to."""
        self._cursor = cursor

    # -- plain pass-through attributes -----------------------------------

    @property
    def description(self) -> Any:
        """The wrapped cursor's ``description``, unchanged."""
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        """The wrapped cursor's ``rowcount``, unchanged."""
        return self._cursor.rowcount  # type: ignore[no-any-return]

    def setinputsizes(self, sizes: Any) -> None:
        """Forward ``setinputsizes`` to the wrapped cursor (not awaited)."""
        self._cursor.setinputsizes(sizes)

    def close(self) -> None:
        """Close the wrapped cursor (called directly, not awaited)."""
        self._cursor.close()

    # -- context-manager protocol ----------------------------------------

    def __enter__(self) -> "AsyncAdaptPyathenaCursor":
        """Return this adapter so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Close the cursor on exit; exceptions are not suppressed."""
        self.close()

    # -- statement execution ---------------------------------------------

    def execute(self, operation: str, parameters: Any = None, **kwargs: Any) -> Any:
        """Run ``operation`` on the wrapped cursor and return its awaited result."""
        pending = self._cursor.execute(operation, parameters, **kwargs)
        return await_only(pending)

    def executemany(
        self,
        operation: str,
        seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
        **kwargs: Any,
    ) -> None:
        """Execute ``operation`` once per parameter set, strictly in order."""
        for params in seq_of_parameters:
            pending = self._cursor.execute(operation, params, **kwargs)
            await_only(pending)

    # -- result fetching -------------------------------------------------

    def fetchone(self) -> Any:
        """Await and return the wrapped cursor's ``fetchone()``."""
        pending = self._cursor.fetchone()
        return await_only(pending)

    def fetchmany(self, size: Optional[int] = None) -> Any:
        """Await and return the wrapped cursor's ``fetchmany(size)``."""
        pending = self._cursor.fetchmany(size)
        return await_only(pending)

    def fetchall(self) -> Any:
        """Await and return the wrapped cursor's ``fetchall()``."""
        pending = self._cursor.fetchall()
        return await_only(pending)

    # -- PyAthena-specific methods used by AthenaDialect reflection ------

    def list_databases(self, *args: Any, **kwargs: Any) -> Any:
        """Await and return the wrapped cursor's ``list_databases(...)``."""
        pending = self._cursor.list_databases(*args, **kwargs)
        return await_only(pending)

    def get_table_metadata(self, *args: Any, **kwargs: Any) -> Any:
        """Await and return the wrapped cursor's ``get_table_metadata(...)``."""
        pending = self._cursor.get_table_metadata(*args, **kwargs)
        return await_only(pending)

    def list_table_metadata(self, *args: Any, **kwargs: Any) -> Any:
        """Await and return the wrapped cursor's ``list_table_metadata(...)``."""
        pending = self._cursor.list_table_metadata(*args, **kwargs)
        return await_only(pending)
96+
97+
class AsyncAdaptPyathenaConnection(AdaptedConnection):
    """Present an ``AioConnection`` through a synchronous DBAPI-style interface.

    ``cursor()`` is delegated to the wrapped ``AioConnection`` and each cursor
    it returns is wrapped in ``AsyncAdaptPyathenaCursor``.  ``close()`` and
    ``commit()`` are forwarded to the wrapped connection as plain calls;
    ``rollback()`` does nothing.
    """

    await_only_ = staticmethod(await_only)

    # Only the slot introduced by this class is listed here.  SQLAlchemy's
    # ``AdaptedConnection`` already declares ``_connection`` in its own
    # ``__slots__``; redeclaring an inherited slot hides the base-class
    # descriptor and allocates a second, redundant storage cell.
    __slots__ = ("dbapi",)

    def __init__(self, dbapi: "AsyncAdaptPyathenaDbapi", connection: AioConnection) -> None:
        """Remember the owning DBAPI shim and the wrapped async connection.

        Args:
            dbapi: The ``AsyncAdaptPyathenaDbapi`` object that created this
                connection.
            connection: The ``AioConnection`` every call is forwarded to.
        """
        self.dbapi = dbapi
        self._connection = connection

    @property
    def driver_connection(self) -> AioConnection:
        """The wrapped ``AioConnection`` itself."""
        return self._connection  # type: ignore[no-any-return]

    @property
    def catalog_name(self) -> Optional[str]:
        """``catalog_name`` of the wrapped connection."""
        return self._connection.catalog_name  # type: ignore[no-any-return]

    @property
    def schema_name(self) -> Optional[str]:
        """``schema_name`` of the wrapped connection."""
        return self._connection.schema_name  # type: ignore[no-any-return]

    def cursor(self) -> AsyncAdaptPyathenaCursor:
        """Create a cursor on the wrapped connection and adapt it."""
        raw_cursor = self._connection.cursor()
        return AsyncAdaptPyathenaCursor(raw_cursor)

    def close(self) -> None:
        """Close the wrapped connection (called directly, not awaited)."""
        self._connection.close()

    def commit(self) -> None:
        """Forward ``commit()`` to the wrapped connection."""
        self._connection.commit()

    def rollback(self) -> None:
        """Do nothing; present so callers can invoke ``rollback()`` safely."""
        pass
137+
138+
139+
class AsyncAdaptPyathenaDbapi:
    """Module-like DBAPI stand-in handed to the async SQLAlchemy engine.

    SQLAlchemy expects ``import_dbapi()`` to yield something that looks like a
    DBAPI module: a ``connect()`` callable, a ``paramstyle`` and the standard
    exception hierarchy.  This class provides exactly those attributes and
    opens its connections through ``AioConnection``.
    """

    paramstyle = "pyformat"

    # Standard DBAPI exceptions, re-exported from PyAthena under the
    # attribute names a DBAPI module is expected to expose.
    Error = Error
    Warning = pyathena.Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def connect(self, **kwargs: Any) -> AsyncAdaptPyathenaConnection:
        """Open an ``AioConnection`` and return it wrapped for sync use.

        Args:
            **kwargs: Passed through unchanged to ``AioConnection.create``.

        Returns:
            An ``AsyncAdaptPyathenaConnection`` bound to this object and to
            the newly created connection.
        """
        # ``AioConnection.create`` is awaited via the greenlet bridge.
        aio_connection = await_only(AioConnection.create(**kwargs))
        return AsyncAdaptPyathenaConnection(self, aio_connection)
165+
166+
167+
class AthenaAioDialect(AthenaDialect):
    """Base asyncio SQLAlchemy dialect for Amazon Athena.

    Builds on the synchronous ``AthenaDialect``: it flags itself with
    ``is_async = True`` and supplies an adapted DBAPI object whose
    connections and cursors wrap ``AioConnection`` and the async cursors.

    Connection URL Format:
        ``awsathena+aiorest://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}``

    Example:
        >>> from sqlalchemy.ext.asyncio import create_async_engine
        >>> engine = create_async_engine(
        ...     "awsathena+aiorest://:@athena.us-west-2.amazonaws.com/default"
        ...     "?s3_staging_dir=s3://my-bucket/athena-results/"
        ... )

    See Also:
        :class:`~pyathena.sqlalchemy.base.AthenaDialect`: Synchronous base dialect.
        :class:`~pyathena.aio.connection.AioConnection`: Native async connection.
    """

    is_async = True
    supports_statement_cache = True

    @classmethod
    def import_dbapi(cls) -> "ModuleType":
        """Return a fresh ``AsyncAdaptPyathenaDbapi`` as the DBAPI "module"."""
        return AsyncAdaptPyathenaDbapi()  # type: ignore[return-value]

    @classmethod
    def dbapi(cls) -> "ModuleType":  # type: ignore[override]
        """Same as :meth:`import_dbapi`, exposed under the ``dbapi`` name."""
        return AsyncAdaptPyathenaDbapi()  # type: ignore[return-value]

    def create_connect_args(self, url: "URL") -> Tuple[Tuple[str], MutableMapping[str, Any]]:
        """Translate ``url`` into connect arguments.

        Args:
            url: The SQLAlchemy URL the engine was created with.

        Returns:
            A pair ``(positional_args, keyword_args)``; the positional part is
            always an empty tuple and the keyword part is whatever
            ``_create_connect_args(url)`` produced.
        """
        connect_opts = self._create_connect_args(url)
        # Keep the final options on the dialect instance.
        self._connect_options = connect_opts
        no_positional_args = cast(Tuple[str], ())
        return no_positional_args, connect_opts

    def get_driver_connection(self, connection: Any) -> Any:
        """Return ``connection`` unchanged."""
        # NOTE(review): this hands back the object it was given, not an
        # unwrapped ``AioConnection``; presumably intentional so callers keep
        # working with the sync-adapted cursor API -- confirm.
        return connection
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# -*- coding: utf-8 -*-
2+
from typing import TYPE_CHECKING
3+
4+
from pyathena.sqlalchemy.async_base import AthenaAioDialect
5+
from pyathena.util import strtobool
6+
7+
if TYPE_CHECKING:
8+
from types import ModuleType
9+
10+
11+
class AthenaAioPandasDialect(AthenaAioDialect):
    """Asyncio SQLAlchemy dialect for Amazon Athena returning pandas DataFrames.

    Queries are executed through ``AioPandasCursor``, which this dialect
    installs as the connection's ``cursor_class``.

    Connection URL Format:
        ``awsathena+aiopandas://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}``

    Query Parameters:
        On top of the parameters understood by the base dialect:

        - unload: If "true", use UNLOAD for Parquet output
        - engine: CSV parsing engine ("c", "python", or "pyarrow")
        - chunksize: Number of rows per chunk for memory-efficient processing

    Example:
        >>> from sqlalchemy.ext.asyncio import create_async_engine
        >>> engine = create_async_engine(
        ...     "awsathena+aiopandas://:@athena.us-west-2.amazonaws.com/default"
        ...     "?s3_staging_dir=s3://my-bucket/athena-results/"
        ...     "&unload=true&chunksize=10000"
        ... )

    See Also:
        :class:`~pyathena.aio.pandas.cursor.AioPandasCursor`: The underlying async cursor.
        :class:`~pyathena.sqlalchemy.async_base.AthenaAioDialect`: Base async dialect.
    """

    driver = "aiopandas"
    supports_statement_cache = True

    def create_connect_args(self, url):
        """Build the connect arguments for a pandas-backed async connection.

        Starts from the options produced by ``_create_connect_args(url)``,
        forces ``cursor_class`` to ``AioPandasCursor`` and moves the
        cursor-specific URL options into ``cursor_kwargs``: ``unload`` as a
        ``bool``, ``engine`` untouched and ``chunksize`` as an ``int``.

        Args:
            url: The SQLAlchemy URL the engine was created with.

        Returns:
            A two-item list ``[positional_args, keyword_args]`` where the
            positional part is always empty.
        """
        # Deferred import: the cursor module is only loaded once connect
        # arguments are actually being built.
        from pyathena.aio.pandas.cursor import AioPandasCursor

        connect_opts = super()._create_connect_args(url)
        connect_opts["cursor_class"] = AioPandasCursor

        extra_cursor_kwargs = {}
        if "unload" in connect_opts:
            # The URL carries the flag as text; the cursor gets a bool.
            raw_unload = connect_opts.pop("unload")
            extra_cursor_kwargs["unload"] = bool(strtobool(raw_unload))
        if "engine" in connect_opts:
            extra_cursor_kwargs["engine"] = connect_opts.pop("engine")
        if "chunksize" in connect_opts:
            raw_chunksize = connect_opts.pop("chunksize")
            extra_cursor_kwargs["chunksize"] = int(raw_chunksize)  # type: ignore
        if extra_cursor_kwargs:
            connect_opts["cursor_kwargs"] = extra_cursor_kwargs

        # Keep the final options on the dialect instance.
        self._connect_options = connect_opts
        return [[], connect_opts]

    @classmethod
    def import_dbapi(cls) -> "ModuleType":
        """Return the DBAPI-like object supplied by the base async dialect."""
        # NOTE(review): this override only defers to the parent; presumably it
        # is kept so the method exists on the concrete dialect class itself --
        # confirm before removing.
        return super().import_dbapi()
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# -*- coding: utf-8 -*-
2+
from typing import TYPE_CHECKING
3+
4+
from pyathena.sqlalchemy.async_base import AthenaAioDialect
5+
from pyathena.util import strtobool
6+
7+
if TYPE_CHECKING:
8+
from types import ModuleType
9+
10+
11+
class AthenaAioPolarsDialect(AthenaAioDialect):
    """Asyncio SQLAlchemy dialect for Amazon Athena returning Polars DataFrames.

    Queries are executed through ``AioPolarsCursor``, which this dialect
    installs as the connection's ``cursor_class``.

    Connection URL Format:
        ``awsathena+aiopolars://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}``

    Query Parameters:
        On top of the parameters understood by the base dialect:

        - unload: If "true", use UNLOAD for Parquet output

    Example:
        >>> from sqlalchemy.ext.asyncio import create_async_engine
        >>> engine = create_async_engine(
        ...     "awsathena+aiopolars://:@athena.us-west-2.amazonaws.com/default"
        ...     "?s3_staging_dir=s3://my-bucket/athena-results/"
        ...     "&unload=true"
        ... )

    See Also:
        :class:`~pyathena.aio.polars.cursor.AioPolarsCursor`: The underlying async cursor.
        :class:`~pyathena.sqlalchemy.async_base.AthenaAioDialect`: Base async dialect.
    """

    driver = "aiopolars"
    supports_statement_cache = True

    def create_connect_args(self, url):
        """Build the connect arguments for a Polars-backed async connection.

        Starts from the options produced by ``_create_connect_args(url)``,
        forces ``cursor_class`` to ``AioPolarsCursor`` and moves the ``unload``
        URL option (when present) into ``cursor_kwargs`` as a real ``bool``.

        Args:
            url: The SQLAlchemy URL the engine was created with.

        Returns:
            A two-item list ``[positional_args, keyword_args]`` where the
            positional part is always empty.
        """
        # Deferred import: the cursor module is only loaded once connect
        # arguments are actually being built.
        from pyathena.aio.polars.cursor import AioPolarsCursor

        connect_opts = super()._create_connect_args(url)
        connect_opts["cursor_class"] = AioPolarsCursor

        extra_cursor_kwargs = {}
        if "unload" in connect_opts:
            # The URL carries the flag as text; the cursor gets a bool.
            raw_unload = connect_opts.pop("unload")
            extra_cursor_kwargs["unload"] = bool(strtobool(raw_unload))
        if extra_cursor_kwargs:
            connect_opts["cursor_kwargs"] = extra_cursor_kwargs

        # Keep the final options on the dialect instance.
        self._connect_options = connect_opts
        return [[], connect_opts]

    @classmethod
    def import_dbapi(cls) -> "ModuleType":
        """Return the DBAPI-like object supplied by the base async dialect."""
        # NOTE(review): this override only defers to the parent; presumably it
        # is kept so the method exists on the concrete dialect class itself --
        # confirm before removing.
        return super().import_dbapi()

0 commit comments

Comments
 (0)