Skip to content

Commit 65c093a

Browse files
Fix AioSparkCursor bugs and add missing tests
- Change _calculate exception from OperationalError to DatabaseError to match sync SparkBaseCursor behavior
- Add retry logic to _read_s3_file_as_text via async_retry_api_call to match sync version's retry_api_call usage
- Remove incorrect # type: ignore[override] on name-mangled __poll
- Add test_async_iterator for AioS3FSCursor
- Add test_executemany and test_context_manager for AioSparkCursor
- Move runtime import to top-level in S3FS test file

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ddc401e commit 65c093a

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

pyathena/aio/spark/cursor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
 from typing import Any, Dict, List, Optional, Union, cast

 from pyathena.aio.util import async_retry_api_call
-from pyathena.error import NotSupportedError, OperationalError, ProgrammingError
+from pyathena.error import DatabaseError, NotSupportedError, OperationalError, ProgrammingError
 from pyathena.model import (
     AthenaCalculationExecution,
     AthenaCalculationExecutionStatus,
@@ -123,10 +123,10 @@ async def _calculate(  # type: ignore[override]
             calculation_id = response.get("CalculationExecutionId")
         except Exception as e:
             _logger.exception("Failed to execute calculation.")
-            raise OperationalError(*e.args) from e
+            raise DatabaseError(*e.args) from e
         return cast(str, calculation_id)

-    async def __poll(  # type: ignore[override]
+    async def __poll(
         self, query_id: str
     ) -> Union[AthenaQueryExecution, AthenaCalculationExecution]:
         while True:
@@ -181,8 +181,10 @@ async def _terminate_session(self) -> None:  # type: ignore[override]

     async def _read_s3_file_as_text(self, uri) -> str:  # type: ignore[override]
         bucket, key = parse_output_location(uri)
-        response = await asyncio.to_thread(
+        response = await async_retry_api_call(
             self._client.get_object,
+            config=self._retry_config,
+            logger=_logger,
             Bucket=bucket,
             Key=key,
         )

tests/pyathena/aio/s3fs/test_cursor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import pytest

+from pyathena.aio.s3fs.cursor import AioS3FSCursor
 from pyathena.error import ProgrammingError
 from pyathena.s3fs.result_set import AthenaS3FSResultSet
 from tests import ENV
@@ -58,11 +59,16 @@ async def test_invalid_arraysize(self, aio_s3fs_cursor):
         with pytest.raises(ProgrammingError):
             aio_s3fs_cursor.arraysize = -1

+    async def test_async_iterator(self, aio_s3fs_cursor):
+        await aio_s3fs_cursor.execute("SELECT * FROM one_row")
+        rows = []
+        async for row in aio_s3fs_cursor:
+            rows.append(row)
+        assert rows == [(1,)]
+
     async def test_context_manager(self):
         conn = await _aio_connect(schema_name=ENV.schema)
         try:
-            from pyathena.aio.s3fs.cursor import AioS3FSCursor
-
             async with conn.cursor(AioS3FSCursor) as cursor:
                 await cursor.execute("SELECT * FROM one_row")
                 assert await cursor.fetchone() == (1,)

tests/pyathena/aio/spark/test_cursor.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@

 import pytest

-from pyathena.error import DatabaseError, OperationalError
+from pyathena.error import DatabaseError, NotSupportedError, OperationalError
 from pyathena.model import AthenaCalculationExecutionStatus
 from tests import ENV
+from tests.pyathena.aio.conftest import _aio_connect


 class TestAioSparkCursor:
@@ -142,3 +143,22 @@ async def cancel_after_delay(c):
         )

         await task
+
+    async def test_executemany(self, aio_spark_cursor):
+        with pytest.raises(NotSupportedError):
+            await aio_spark_cursor.executemany("SELECT 1", [])
+
+    async def test_context_manager(self):
+        import asyncio
+
+        from pyathena.aio.spark.cursor import AioSparkCursor
+
+        conn = await _aio_connect(
+            schema_name=ENV.schema,
+            cursor_class=AioSparkCursor,
+            work_group=ENV.spark_work_group,
+        )
+        cursor = await asyncio.to_thread(conn.cursor)
+        async with cursor:
+            await cursor.execute("print('hello')")
+            assert await cursor.get_std_out() == "hello"

0 commit comments

Comments
 (0)