Skip to content

Commit 9164765

Browse files
authored
improve RPC types (#410)
1 parent 61c3db8 commit 9164765

File tree

1 file changed

+18
-26
lines changed

1 file changed

+18
-26
lines changed

livekit-rtc/livekit/rtc/participant.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import os
2020
import mimetypes
2121
import aiofiles
22-
from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast
22+
from typing import List, Union, Callable, Dict, Awaitable, Optional, Mapping, cast, TypeVar
2323
from abc import abstractmethod, ABC
2424

2525
from ._ffi_client import FfiClient, FfiHandle
@@ -144,6 +144,12 @@ def disconnect_reason(
144144
return self._info.disconnect_reason
145145

146146

147+
RpcHandler = Callable[["RpcInvocationData"], Union[Awaitable[Optional[str]], Optional[str]]]
148+
F = TypeVar(
149+
"F", bound=Callable[[RpcInvocationData], Union[Awaitable[Optional[str]], Optional[str]]]
150+
)
151+
152+
147153
class LocalParticipant(Participant):
148154
"""Represents the local participant in a room."""
149155

@@ -155,9 +161,7 @@ def __init__(
155161
super().__init__(owned_info)
156162
self._room_queue = room_queue
157163
self._track_publications: dict[str, LocalTrackPublication] = {} # type: ignore
158-
self._rpc_handlers: Dict[
159-
str, Callable[[RpcInvocationData], Union[Awaitable[str], str]]
160-
] = {}
164+
self._rpc_handlers: Dict[str, RpcHandler] = {}
161165

162166
@property
163167
def track_publications(self) -> Mapping[str, LocalTrackPublication]:
@@ -328,8 +332,8 @@ async def perform_rpc(
328332
def register_rpc_method(
329333
self,
330334
method_name: str,
331-
handler: Optional[Callable[[RpcInvocationData], Union[Awaitable[str], str]]] = None,
332-
) -> Union[None, Callable]:
335+
handler: Optional[F] = None,
336+
) -> Union[F, Callable[[F], F]]:
333337
"""
334338
Establishes the participant as a receiver for calls of the specified RPC method.
335339
Can be used either as a decorator or a regular method.
@@ -366,18 +370,17 @@ async def greet_handler(data: RpcInvocationData) -> str:
366370
room.local_participant.register_rpc_method('greet', greet_handler)
367371
"""
368372

369-
def register(handler_func):
373+
def register(handler_func: F) -> F:
370374
self._rpc_handlers[method_name] = handler_func
371375
req = proto_ffi.FfiRequest()
372376
req.register_rpc_method.local_participant_handle = self._ffi_handle.handle
373377
req.register_rpc_method.method = method_name
374378
FfiClient.instance.request(req)
379+
return handler_func
375380

376381
if handler is not None:
377-
register(handler)
378-
return None
382+
return register(handler)
379383
else:
380-
# Called as a decorator
381384
return register
382385

383386
def unregister_rpc_method(self, method: str) -> None:
@@ -438,33 +441,22 @@ async def _handle_rpc_method_invocation(
438441
else:
439442
try:
440443
if asyncio.iscoroutinefunction(handler):
441-
async_handler = cast(Callable[[RpcInvocationData], Awaitable[str]], handler)
442-
443-
async def run_handler():
444-
try:
445-
return await async_handler(params)
446-
except asyncio.CancelledError:
447-
# This will be caught by the outer try-except if it's due to timeout
448-
raise
449-
450444
try:
451445
response_payload = await asyncio.wait_for(
452-
run_handler(), timeout=response_timeout
446+
handler(params), timeout=response_timeout
453447
)
454448
except asyncio.TimeoutError:
455449
raise RpcError._built_in(RpcError.ErrorCode.RESPONSE_TIMEOUT)
456450
except asyncio.CancelledError:
457451
raise RpcError._built_in(RpcError.ErrorCode.RECIPIENT_DISCONNECTED)
458452
else:
459-
sync_handler = cast(Callable[[RpcInvocationData], str], handler)
460-
response_payload = sync_handler(params)
453+
response_payload = cast(Optional[str], handler(params))
461454
except RpcError as error:
462455
response_error = error
463-
except Exception as error:
456+
except Exception:
464457
logger.exception(
465458
f"Uncaught error returned by RPC handler for {method}. "
466459
"Returning APPLICATION_ERROR instead. "
467-
f"Original error: {error}"
468460
)
469461
response_error = RpcError._built_in(RpcError.ErrorCode.APPLICATION_ERROR)
470462

@@ -480,8 +472,8 @@ async def run_handler():
480472
res = FfiClient.instance.request(req)
481473

482474
if res.rpc_method_invocation_response.error:
483-
message = res.rpc_method_invocation_response.error
484-
logger.exception(f"error sending rpc method invocation response: {message}")
475+
err = res.rpc_method_invocation_response.error
476+
logger.error(f"error sending rpc method invocation response: {err}")
485477

486478
async def set_metadata(self, metadata: str) -> None:
487479
"""

0 commit comments

Comments
 (0)