|
48 | 48 | from typing import Optional |
49 | 49 | from collections.abc import Iterator |
50 | 50 |
|
| 51 | +try: |
| 52 | + from anyio import create_memory_object_stream, create_task_group |
| 53 | + from mcp.types import ( |
| 54 | + JSONRPCMessage, |
| 55 | + JSONRPCNotification, |
| 56 | + JSONRPCRequest, |
| 57 | + ) |
| 58 | + from mcp.shared.message import SessionMessage |
| 59 | +except ImportError: |
| 60 | + create_memory_object_stream = None |
| 61 | + create_task_group = None |
| 62 | + JSONRPCMessage = None |
| 63 | + JSONRPCNotification = None |
| 64 | + JSONRPCRequest = None |
| 65 | + SessionMessage = None |
| 66 | + |
51 | 67 |
|
52 | 68 | SENTRY_EVENT_SCHEMA = "./checkouts/data-schemas/relay/event.schema.json" |
53 | 69 |
|
@@ -592,6 +608,106 @@ def suppress_deprecation_warnings(): |
592 | 608 | yield |
593 | 609 |
|
594 | 610 |
|
| 611 | +@pytest.fixture |
| 612 | +def get_initialization_payload(): |
| 613 | + def inner(request_id: str): |
| 614 | + return SessionMessage( # type: ignore |
| 615 | + message=JSONRPCMessage( # type: ignore |
| 616 | + root=JSONRPCRequest( # type: ignore |
| 617 | + jsonrpc="2.0", |
| 618 | + id=request_id, |
| 619 | + method="initialize", |
| 620 | + params={ |
| 621 | + "protocolVersion": "2025-11-25", |
| 622 | + "capabilities": {}, |
| 623 | + "clientInfo": {"name": "test-client", "version": "1.0.0"}, |
| 624 | + }, |
| 625 | + ) |
| 626 | + ) |
| 627 | + ) |
| 628 | + |
| 629 | + return inner |
| 630 | + |
| 631 | + |
| 632 | +@pytest.fixture |
| 633 | +def get_initialized_notification_payload(): |
| 634 | + def inner(): |
| 635 | + return SessionMessage( # type: ignore |
| 636 | + message=JSONRPCMessage( # type: ignore |
| 637 | + root=JSONRPCNotification( # type: ignore |
| 638 | + jsonrpc="2.0", |
| 639 | + method="notifications/initialized", |
| 640 | + ) |
| 641 | + ) |
| 642 | + ) |
| 643 | + |
| 644 | + return inner |
| 645 | + |
| 646 | + |
| 647 | +@pytest.fixture |
| 648 | +def get_mcp_command_payload(): |
| 649 | + def inner(method: str, params, request_id: str): |
| 650 | + return SessionMessage( # type: ignore |
| 651 | + message=JSONRPCMessage( # type: ignore |
| 652 | + root=JSONRPCRequest( # type: ignore |
| 653 | + jsonrpc="2.0", |
| 654 | + id=request_id, |
| 655 | + method=method, |
| 656 | + params=params, |
| 657 | + ) |
| 658 | + ) |
| 659 | + ) |
| 660 | + |
| 661 | + return inner |
| 662 | + |
| 663 | + |
| 664 | +@pytest.fixture |
| 665 | +def stdio( |
| 666 | + get_initialization_payload, |
| 667 | + get_initialized_notification_payload, |
| 668 | + get_mcp_command_payload, |
| 669 | +): |
| 670 | + async def inner(server, method: str, params, request_id: str | None = None): |
| 671 | + if request_id is None: |
| 672 | + request_id = "1" |
| 673 | + |
| 674 | + read_stream_writer, read_stream = create_memory_object_stream(0) # type: ignore |
| 675 | + write_stream, write_stream_reader = create_memory_object_stream(0) # type: ignore |
| 676 | + |
| 677 | + result = {} |
| 678 | + |
| 679 | + async def run_server(): |
| 680 | + await server.run( |
| 681 | + read_stream, write_stream, server.create_initialization_options() |
| 682 | + ) |
| 683 | + |
| 684 | + async def simulate_client(tg, result): |
| 685 | + init_request = get_initialization_payload("1") |
| 686 | + await read_stream_writer.send(init_request) |
| 687 | + |
| 688 | + await write_stream_reader.receive() |
| 689 | + |
| 690 | + initialized_notification = get_initialized_notification_payload() |
| 691 | + await read_stream_writer.send(initialized_notification) |
| 692 | + |
| 693 | + request = get_mcp_command_payload( |
| 694 | + method, params=params, request_id=request_id |
| 695 | + ) |
| 696 | + await read_stream_writer.send(request) |
| 697 | + |
| 698 | + result["response"] = await write_stream_reader.receive() |
| 699 | + |
| 700 | + tg.cancel_scope.cancel() |
| 701 | + |
| 702 | + async with create_task_group() as tg: # type: ignore |
| 703 | + tg.start_soon(run_server) |
| 704 | + tg.start_soon(simulate_client, tg, result) |
| 705 | + |
| 706 | + return result["response"] |
| 707 | + |
| 708 | + return inner |
| 709 | + |
| 710 | + |
595 | 711 | class MockServerRequestHandler(BaseHTTPRequestHandler): |
596 | 712 | def do_GET(self): # noqa: N802 |
597 | 713 | # Process an HTTP GET request and return a response. |
|
0 commit comments