Skip to content

fix: Do not throw error when in a uvloop context #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ DOCKER_IMAGE ?= downloads.unstructured.io/unstructured-io/unstructured-api:lates

.PHONY: install-test
install-test:
pip install pytest pytest-asyncio pytest-mock requests_mock pypdf deepdiff requests-toolbelt
pip install pytest pytest-asyncio pytest-mock requests_mock pypdf deepdiff requests-toolbelt uvloop

.PHONY: install-dev
install-dev:
Expand Down
29 changes: 29 additions & 0 deletions _test_unstructured_client/integration/test_integration_freemium.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,32 @@ def test_partition_handling_server_error(error, split_pdf, monkeypatch, doc_path

with pytest.raises(sdk_raises):
response = client.general.partition(req)


def test_uvloop_partitions_without_errors(client, doc_path):
async def call_api():
filename = "layout-parser-paper-fast.pdf"
with open(doc_path / filename, "rb") as f:
files = shared.Files(
content=f.read(),
file_name=filename,
)

req = shared.PartitionParameters(
files=files,
strategy="fast",
languages=["eng"],
split_pdf_page=True,
)

resp = client.general.partition(req)

if resp is not None:
return resp.elements
else:
return []

import uvloop
uvloop.install()
elements = asyncio.run(call_api())
assert len(elements) > 0
22 changes: 18 additions & 4 deletions src/unstructured_client/_hooks/custom/split_pdf_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ async def run_tasks(coroutines: list[Awaitable], allow_failed: bool = False) ->
return sorted(results, key=lambda x: x[0])


def context_is_uvloop():
"""Return true if uvloop is installed and we're currently in a uvloop context. Our asyncio splitting code currently doesn't work under uvloop."""
try:
import uvloop # pylint: disable=import-outside-toplevel
loop = asyncio.get_event_loop()
return isinstance(loop, uvloop.Loop)
except ImportError:
return False

def get_optimal_split_size(num_pages: int, concurrency_level: int) -> int:
"""Distributes pages to workers evenly based on the number of pages and desired concurrency level."""
if num_pages < MAX_PAGES_PER_SPLIT * concurrency_level:
Expand All @@ -94,10 +103,6 @@ class SplitPdfHook(SDKInitHook, BeforeRequestHook, AfterSuccessHook, AfterErrorH
"""

def __init__(self) -> None:
# This allows us to use an event loop in an env with an existing loop
# Temporary fix until we can improve the async splitting behavior
nest_asyncio.apply()

self.client: Optional[requests.Session] = None
self.coroutines_to_execute: dict[
str, list[Coroutine[Any, Any, requests.Response]]
Expand All @@ -121,6 +126,8 @@ def sdk_init(
self.client = client
return base_url, client


# pylint: disable=too-many-return-statements
def before_request(
self, hook_ctx: BeforeRequestContext, request: requests.PreparedRequest
) -> Union[requests.PreparedRequest, Exception]:
Expand All @@ -143,6 +150,13 @@ def before_request(
logger.warning("HTTP client not accessible! Continuing without splitting.")
return request

if context_is_uvloop():
logger.warning("Splitting is currently incompatible with uvloop. Continuing without splitting.")
return request

# This allows us to use an event loop in an env with an existing loop
# Temporary fix until we can improve the async splitting behavior
nest_asyncio.apply()
operation_id = hook_ctx.operation_id
content_type = request.headers.get("Content-Type")
body = request.body
Expand Down