Skip to content

Commit bf8976c

Browse files
Add tests for async SQLAlchemy dialects
- Add representative async tests in tests/pyathena/aio/sqlalchemy/: basic query, reflection, schema inspection, dialect properties - Add ASYNC_SQLALCHEMY_CONNECTION_STRING to tests/__init__.py - Register async dialects in tests/sqlalchemy/__init__.py - Support --dburi async in tests/sqlalchemy/conftest.py for running the SQLAlchemy standard test suite with async dialects - Add make test-sqla-async target Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 37741f4 commit bf8976c

File tree

6 files changed

+178
-3
lines changed

6 files changed

+178
-3
lines changed

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ test: chk
2121
test-sqla:
2222
uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/sqlalchemy/
2323

24+
.PHONY: test-sqla-async
25+
test-sqla-async:
26+
uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/sqlalchemy/ --dburi async
27+
2428
.PHONY: tox
2529
tox:
2630
uvx tox@$(TOX_VERSION) -c pyproject.toml run

tests/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
"awsathena+rest://athena.{region_name}.amazonaws.com:443/"
88
"{schema_name}?s3_staging_dir={s3_staging_dir}&location={location}"
99
)
10+
ASYNC_SQLALCHEMY_CONNECTION_STRING = (
11+
"awsathena+aiorest://athena.{region_name}.amazonaws.com:443/"
12+
"{schema_name}?s3_staging_dir={s3_staging_dir}&location={location}"
13+
)
1014

1115

1216
class Env:
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# -*- coding: utf-8 -*-
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# -*- coding: utf-8 -*-
2+
import pytest
3+
import sqlalchemy
4+
from sqlalchemy import text
5+
from sqlalchemy.ext.asyncio import create_async_engine
6+
from sqlalchemy.sql.schema import MetaData, Table
7+
8+
from tests import ASYNC_SQLALCHEMY_CONNECTION_STRING, ENV
9+
10+
11+
def _async_conn_str(**kwargs):
12+
conn_str = ASYNC_SQLALCHEMY_CONNECTION_STRING
13+
return conn_str.format(
14+
region_name=ENV.region_name,
15+
schema_name=ENV.schema,
16+
s3_staging_dir=ENV.s3_staging_dir,
17+
location=ENV.s3_staging_dir,
18+
**kwargs,
19+
)
20+
21+
22+
@pytest.fixture
23+
async def async_engine():
24+
engine = create_async_engine(_async_conn_str())
25+
try:
26+
async with engine.connect() as conn:
27+
yield engine, conn
28+
finally:
29+
await engine.dispose()
30+
31+
32+
class TestAsyncSQLAlchemyAthena:
33+
async def test_basic_query(self, async_engine):
34+
engine, conn = async_engine
35+
rows = (await conn.execute(text("SELECT * FROM one_row"))).fetchall()
36+
assert len(rows) == 1
37+
assert rows[0].number_of_rows == 1
38+
assert len(rows[0]) == 1
39+
40+
async def test_unicode(self, async_engine):
41+
engine, conn = async_engine
42+
unicode_str = "密林"
43+
returned_str = (
44+
await conn.execute(
45+
sqlalchemy.select(
46+
sqlalchemy.sql.expression.bindparam(
47+
"あまぞん", unicode_str, type_=sqlalchemy.types.String()
48+
)
49+
)
50+
)
51+
).scalar()
52+
assert returned_str == unicode_str
53+
54+
async def test_reflect_table(self, async_engine):
55+
engine, conn = async_engine
56+
one_row = await conn.run_sync(
57+
lambda sync_conn: Table("one_row", MetaData(schema=ENV.schema), autoload_with=sync_conn)
58+
)
59+
assert len(one_row.c) == 1
60+
assert one_row.c.number_of_rows is not None
61+
assert one_row.comment == "table comment"
62+
63+
async def test_reflect_schemas(self, async_engine):
64+
engine, conn = async_engine
65+
66+
def _inspect(sync_conn):
67+
insp = sqlalchemy.inspect(sync_conn)
68+
return insp.get_schema_names()
69+
70+
schemas = await conn.run_sync(_inspect)
71+
assert ENV.schema in schemas
72+
assert "default" in schemas
73+
74+
async def test_get_table_names(self, async_engine):
75+
engine, conn = async_engine
76+
77+
def _inspect(sync_conn):
78+
insp = sqlalchemy.inspect(sync_conn)
79+
return insp.get_table_names(schema=ENV.schema)
80+
81+
table_names = await conn.run_sync(_inspect)
82+
assert "many_rows" in table_names
83+
84+
async def test_has_table(self, async_engine):
85+
engine, conn = async_engine
86+
87+
def _inspect(sync_conn):
88+
insp = sqlalchemy.inspect(sync_conn)
89+
return (
90+
insp.has_table("one_row", schema=ENV.schema),
91+
insp.has_table("this_table_does_not_exist", schema=ENV.schema),
92+
)
93+
94+
exists, not_exists = await conn.run_sync(_inspect)
95+
assert exists
96+
assert not not_exists
97+
98+
async def test_get_columns(self, async_engine):
99+
engine, conn = async_engine
100+
101+
def _inspect(sync_conn):
102+
insp = sqlalchemy.inspect(sync_conn)
103+
return insp.get_columns(table_name="one_row", schema=ENV.schema)
104+
105+
columns = await conn.run_sync(_inspect)
106+
actual = columns[0]
107+
assert actual["name"] == "number_of_rows"
108+
assert isinstance(actual["type"], sqlalchemy.types.INTEGER)
109+
assert actual["nullable"]
110+
assert actual["default"] is None
111+
assert not actual["autoincrement"]
112+
assert actual["comment"] == "some comment"
113+
114+
115+
class TestAsyncDialectProperties:
116+
async def test_aiorest_dialect(self):
117+
engine = create_async_engine(_async_conn_str())
118+
try:
119+
assert engine.dialect.is_async is True
120+
assert engine.dialect.driver == "aiorest"
121+
assert engine.dialect.supports_statement_cache is True
122+
finally:
123+
await engine.dispose()
124+
125+
async def test_aiopandas_dialect(self):
126+
conn_str = _async_conn_str().replace("+aiorest", "+aiopandas")
127+
engine = create_async_engine(conn_str)
128+
try:
129+
assert engine.dialect.is_async is True
130+
assert engine.dialect.driver == "aiopandas"
131+
finally:
132+
await engine.dispose()
133+
134+
async def test_aioarrow_dialect(self):
135+
conn_str = _async_conn_str().replace("+aiorest", "+aioarrow")
136+
engine = create_async_engine(conn_str)
137+
try:
138+
assert engine.dialect.is_async is True
139+
assert engine.dialect.driver == "aioarrow"
140+
finally:
141+
await engine.dispose()
142+
143+
async def test_aiopolars_dialect(self):
144+
conn_str = _async_conn_str().replace("+aiorest", "+aiopolars")
145+
engine = create_async_engine(conn_str)
146+
try:
147+
assert engine.dialect.is_async is True
148+
assert engine.dialect.driver == "aiopolars"
149+
finally:
150+
await engine.dispose()
151+
152+
async def test_aios3fs_dialect(self):
153+
conn_str = _async_conn_str().replace("+aiorest", "+aios3fs")
154+
engine = create_async_engine(conn_str)
155+
try:
156+
assert engine.dialect.is_async is True
157+
assert engine.dialect.driver == "aios3fs"
158+
finally:
159+
await engine.dispose()

tests/sqlalchemy/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,8 @@
66
registry.register("awsathena.pandas", "pyathena.sqlalchemy.pandas", "AthenaPandasDialect")
77
registry.register("awsathena.arrow", "pyathena.sqlalchemy.arrow", "AthenaArrowDialect")
88
registry.register("awsathena.s3fs", "pyathena.sqlalchemy.s3fs", "AthenaS3FSDialect")
9+
registry.register("awsathena.aiorest", "pyathena.aio.sqlalchemy.rest", "AthenaAioRestDialect")
10+
registry.register("awsathena.aiopandas", "pyathena.aio.sqlalchemy.pandas", "AthenaAioPandasDialect")
11+
registry.register("awsathena.aioarrow", "pyathena.aio.sqlalchemy.arrow", "AthenaAioArrowDialect")
12+
registry.register("awsathena.aiopolars", "pyathena.aio.sqlalchemy.polars", "AthenaAioPolarsDialect")
13+
registry.register("awsathena.aios3fs", "pyathena.aio.sqlalchemy.s3fs", "AthenaAioS3FSDialect")

tests/sqlalchemy/conftest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
temp_table_keyword_args,
1414
)
1515

16-
from tests import ENV, SQLALCHEMY_CONNECTION_STRING
16+
from tests import ASYNC_SQLALCHEMY_CONNECTION_STRING, ENV, SQLALCHEMY_CONNECTION_STRING
1717

1818

1919
def pytest_sessionstart(session):
20-
conn_str = (
21-
SQLALCHEMY_CONNECTION_STRING + "&tblproperties=" + quote_plus("'table_type'='ICEBERG'")
20+
use_async = session.config.getoption("--dburi", None) == ["async"]
21+
base_conn_str = (
22+
ASYNC_SQLALCHEMY_CONNECTION_STRING if use_async else SQLALCHEMY_CONNECTION_STRING
2223
)
24+
conn_str = base_conn_str + "&tblproperties=" + quote_plus("'table_type'='ICEBERG'")
2325
session.config.option.dburi = [
2426
conn_str.format(
2527
region_name=ENV.region_name,

0 commit comments

Comments
 (0)