Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions examples/servers/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
from __future__ import annotations

import asyncio
import logging
import math

Expand Down Expand Up @@ -64,20 +65,38 @@ def calculate_sum_wrapped(s: LanguageServer, *args):

@server.command("calculate.pow")
def calculate_pow(x: float, n):
""""""
"""One typed, one un-typed argument"""
logging.info("x: %r, n: %r", x, n)
return x**n


@server.command("calculate.pow.wrapped")
def calculate_pow_wrapped(s: LanguageServer, x, n: int):
"""Using *args to accept any number of arguments"""
"""One typed, one un-typed argument"""
s.window_log_message(
types.LogMessageParams(type=types.MessageType.Info, message=f"{x=}, {n=}")
)
return calculate_pow(x, n)


@server.command("calculate.pow.async")
async def calculate_pow_async(x: float, n):
"""One typed, one un-typed argument, async"""
await asyncio.sleep(1)

logging.info("x: %r, n: %r", x, n)
return x**n


@server.command("calculate.pow.async.wrapped")
async def calculate_pow_async_wrapped(s: LanguageServer, x, n: int):
"""One typed, one un-typed argument, async"""
s.window_log_message(
types.LogMessageParams(type=types.MessageType.Info, message=f"{x=}, {n=}")
)
return await calculate_pow_async(x, n)


@server.command("calculate.div")
def calculate_div(x: float, n):
""""""
Expand Down
4 changes: 4 additions & 0 deletions pygls/feature_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def wrap_with_server(f, server):
async def wrapped(*args, **kwargs):
return await f(server, *args, **kwargs)

# Used by `workspace/executeCommand` to access the original function's
# signature. Mirrors how functools.partial works.
wrapped.func = f # type: ignore[attr-defined]

else:
wrapped = functools.partial(f, server)
if is_thread_function(f):
Expand Down
16 changes: 7 additions & 9 deletions pygls/protocol/language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,16 +415,14 @@ def _get_handler_params_annotations(handler: Callable[..., Any]):
"""Return the parameters and corresponding type annotations for the given handler
function."""

try:
annotations = typing.get_type_hints(handler)
params = inspect.signature(handler).parameters
except TypeError:
# If the user's handler requests the language server instance, the real function
# is wrapped inside whatever `functools.partial()` returns.
if not hasattr(handler, "func"):
raise

# If the user's handler requests the language server instance, the real function
# is wrapped inside whatever `functools.partial()` returns.
if hasattr(handler, "func"):
annotations = typing.get_type_hints(handler.func)
params = inspect.signature(handler.func).parameters

else:
annotations = typing.get_type_hints(handler)
params = inspect.signature(handler).parameters

return params, annotations
11 changes: 10 additions & 1 deletion tests/e2e/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,24 @@ async def test_calculate_pow_invalid(
@pytest.mark.asyncio(loop_scope="module")
@pytest.mark.parametrize(
"name",
["calculate.pow", "calculate.pow.wrapped"],
[
"calculate.pow",
"calculate.pow.async",
"calculate.pow.wrapped",
"calculate.pow.async.wrapped",
],
)
async def test_calculate_pow(
commands: Tuple[LanguageClient, types.InitializeResult],
name: str,
runtime: str,
):
"""Ensure that the example commands server can execute both the wrapped and
unwrapped ``calculate.pow`` commands correctly."""

if runtime in {"pyodide"} and "async" in name:
pytest.skip("async handlers not supported in this runtime")

client, initialize_result = commands

provider = initialize_result.capabilities.execute_command_provider
Expand Down
Loading