diff --git a/docs/client/index.md b/docs/client/index.md index 01287da05..be8a73c2a 100644 --- a/docs/client/index.md +++ b/docs/client/index.md @@ -145,7 +145,7 @@ The resource verbs come in pairs: two ways to list, one way to read. `read_resource` returns `contents`, a list of `TextResourceContents` or `BlobResourceContents`. Same idea as tool content: narrow with `isinstance`, then read `.text` (or `.blob`). -A client can also be told when a resource changes. On 2025-era connections that is `subscribe_resource(uri)` / `unsubscribe_resource(uri)` - a method pair `MCPServer` doesn't implement, so on the 2026-07-28 wire (where those verbs no longer exist) the request answers `-32601`, *Method not found*. The 2026 replacement is a `subscriptions/listen` stream, which `MCPServer` *does* serve - `server_capabilities.resources.subscribe` is `True` there, and the server side of the story is **[Subscriptions](../handlers/subscriptions.md)**. +A client can also be told when a resource changes. On 2025-era connections that is `subscribe_resource(uri)` / `unsubscribe_resource(uri)` - a method pair `MCPServer` doesn't implement, so on the 2026-07-28 wire (where those verbs no longer exist) the request answers `-32601`, *Method not found*. The 2026 replacement is a `subscriptions/listen` stream, which `MCPServer` *does* serve - `server_capabilities.resources.subscribe` is `True` there, and **[Subscriptions](../handlers/subscriptions.md)** tells both sides of the story - the client end is [`client.listen(...)`](../handlers/subscriptions.md#the-client-side). ## Prompts diff --git a/docs/handlers/subscriptions.md b/docs/handlers/subscriptions.md index 6ff85dd86..30094d542 100644 --- a/docs/handlers/subscriptions.md +++ b/docs/handlers/subscriptions.md @@ -85,6 +85,27 @@ Down on the low-level `Server` there is no pre-wired anything — and the same p * `ListenHandler(bus)` is the same handler `MCPServer` registers; `on_subscriptions_listen=` is an ordinary handler slot. Don't want the SDK's semantics? Write your own handler for the slot — the spec obligations come with it. * `ListenHandler.close()` gracefully ends every open stream: each one receives the listen request's result as its final frame, the spec's signal that the server ended the subscription deliberately — a clean end, as opposed to the abrupt drop a client may treat as a cue to reconnect. Without it, streams end when the client disconnects. +## The client side + +Consuming a subscription is one context manager: + +```python title="client.py" hl_lines="9 10" +--8<-- "docs_src/subscriptions/tutorial003.py" +``` + +* `client.listen(...)` takes the filter as keyword arguments — they mirror the wire `SubscriptionFilter` field for field. Entering sends the request and returns once the server's acknowledgment arrives, so `sub.honored` (the subset the server agreed to deliver) is always there before the first event. +* Iteration yields the same four typed events the server publishes: `ToolsListChanged`, `PromptsListChanged`, `ResourcesListChanged`, and `ResourceUpdated(uri=...)` — where the URI may be a sub-resource of one you subscribed to, at the server's discretion. An event is a cue to refetch — it carries no payload beyond identity, and duplicates pending consumption collapse into one. +* Leaving the block ends the subscription, with the transport's own spelling: over streamable HTTP the request's response stream is closed (that is the 2026 cancellation signal), on stream transports `notifications/cancelled` is sent. +* The stream's two endings are control flow. The server closing gracefully simply ends the `async for`; an abrupt drop raises `SubscriptionLost`. The distinction is diagnostic — a clean end versus a connection worth suspecting — not a difference in what to do next: either way the stream is gone, nothing is replayed, and a watcher that still cares re-listens and refetches. Servers close streams gracefully for their own reasons — shutdown, or shedding a subscriber whose backlog grew past bounds, as this SDK's `ListenHandler` does — so a graceful close is not a signal to stop watching: + +```python title="watch.py" hl_lines="15 16" +--8<-- "docs_src/subscriptions/tutorial004.py" +``` + +* Checking the acknowledgment (the spec's client SHOULD) is reading `sub.honored` — the kinds this stream will actually receive. A server may narrow the filter it agrees to honor (a multi-tenant server declining a URI, say), and `sub.honored` is that delivery contract — it says nothing about what exists in the catalog. Multiple subscriptions may be open concurrently; each demultiplexes by its own subscription id. +* Tool calls and other requests run freely beside an open stream — from the same task between events, or from sibling tasks sharing the client. A watcher task that refetches inside its event loop is the intended pattern, not a re-entrancy hazard. +* `listen()` requires a 2026-07-28 connection and raises `ListenNotSupportedError` on older ones, steering to the deprecated `subscribe_resource` and `message_handler` spelling those wires use. + ## Recap * A client opts in with one `subscriptions/listen` request; the response is the stream. There is nothing to configure server-side — serving it is built in. @@ -92,3 +113,4 @@ Down on the low-level `Server` there is no pre-wired anything — and the same p * Streams receive only what their filter requested; URIs match exactly; nothing is replayed. * Scaling out means implementing `SubscriptionBus` — two methods — over your own pub/sub, and passing it as `MCPServer(subscriptions=...)`. * The low-level spelling is the same machinery held in your hands: a bus, `ListenHandler(bus)`, one constructor argument. +* Consuming is `async with client.listen(...)` and `async for event in sub` — typed events, honored filter on the handle, clean end vs `SubscriptionLost`. diff --git a/docs/migration.md b/docs/migration.md index 9ff5a054c..cf64aaf97 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1536,6 +1536,23 @@ The 2026-07-28 revision reintroduces Tasks as an official extension: [SEP-2663]( ## Deprecations +### Client resource-subscription methods deprecated (SEP-2575) + +[SEP-2575](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/2575) removes `resources/subscribe` and `resources/unsubscribe` from the 2026-07-28 wire; per-URI subscriptions travel in the `subscriptions/listen` filter instead. The client verbs now carry `typing_extensions.deprecated`: + +- `Client.subscribe_resource()` / `Client.unsubscribe_resource()` +- `ClientSession.subscribe_resource()` / `ClientSession.unsubscribe_resource()` + +They keep working against 2025-era servers; a 2026-07-28 server answers them with `-32601` (method not found). Migrate to the listen driver: + +```python +async with client.listen(resource_subscriptions=["note://todo"]) as sub: + async for event in sub: # ResourceUpdated(uri="note://todo") + ... +``` + +See the [Subscriptions](handlers/subscriptions.md#the-client-side) page for the full client-side contract (typed events, the honored filter, clean end vs `SubscriptionLost`). + ### Roots, Sampling, and Logging methods deprecated (SEP-2577) [SEP-2577](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/2577) deprecates the Roots, Sampling, and Logging features as of the 2026-07-28 spec. The deprecation is advisory only: there are no wire-level changes, capability negotiation is unchanged, and every method keeps working for sessions negotiating 2025-11-25 and earlier. diff --git a/docs/whats-new.md b/docs/whats-new.md index d197833db..c19b11f9f 100644 --- a/docs/whats-new.md +++ b/docs/whats-new.md @@ -190,9 +190,9 @@ That file is the pitch in one place: one server, one `Resolve`-backed tool, and ### Change notifications become one stream -At 2026-07-28 the standalone HTTP GET stream and `resources/subscribe` are replaced by `subscriptions/listen`: the client opens one long-lived stream and names the notification kinds it wants. `MCPServer` serves it out of the box; you publish with `await ctx.notify_resource_updated(uri)` (and `notify_tools_changed()`, and so on), and multi-replica deployments plug in a shared `SubscriptionBus`. Two honest caveats as of `2.0.0b1`: the Python `Client` cannot open the listen stream yet (the driver ships in a later pre-release), and over stdio the server does not serve it. The net for a Python *client* on that release is that nothing delivers change notifications on a 2026-07-28 connection; a host that relies on `resources/updated` should connect with `mode="legacy"` until the driver lands. +At 2026-07-28 the standalone HTTP GET stream and `resources/subscribe` are replaced by `subscriptions/listen`: the client opens one long-lived stream and names the notification kinds it wants. `MCPServer` serves it out of the box; you publish with `await ctx.notify_resource_updated(uri)` (and `notify_tools_changed()`, and so on), and multi-replica deployments plug in a shared `SubscriptionBus`. On the client (since `2.0.0b2`), `async with client.listen(...)` opens the stream: the filter goes in as keyword arguments, typed change events come back, and `sub.honored` is the subset the server agreed to deliver. One honest caveat: over stdio the server does not serve the stream yet. -**[Subscriptions](handlers/subscriptions.md)** on the server, and **[Deploy & scale](run/deploy.md)** for the bus. +**[Subscriptions](handlers/subscriptions.md)** covers both sides, and **[Deploy & scale](run/deploy.md)** the bus. ### The rest, quickly diff --git a/docs_src/subscriptions/tutorial003.py b/docs_src/subscriptions/tutorial003.py new file mode 100644 index 000000000..109f838b9 --- /dev/null +++ b/docs_src/subscriptions/tutorial003.py @@ -0,0 +1,14 @@ +from mcp import Client +from mcp.client.subscriptions import ResourceUpdated + +from .tutorial001 import mcp + + +async def watch_todo() -> str: + """Wait for the todo note to change once, then stop listening.""" + async with Client(mcp) as client: + async with client.listen(resource_subscriptions=["note://todo"]) as sub: + async for event in sub: + assert isinstance(event, ResourceUpdated) + return f"changed: {event.uri}" + return "the server closed the stream before any change" diff --git a/docs_src/subscriptions/tutorial004.py b/docs_src/subscriptions/tutorial004.py new file mode 100644 index 000000000..9fefce49f --- /dev/null +++ b/docs_src/subscriptions/tutorial004.py @@ -0,0 +1,20 @@ +import anyio + +from mcp import Client +from mcp.client.subscriptions import SubscriptionLost + + +async def watch(client: Client, uri: str) -> None: + """Keep one resource fresh for as long as the client lives.""" + while True: + try: + async with client.listen(resource_subscriptions=[uri]) as sub: + await client.read_resource(uri) # refetch: no replay across streams + async for _event in sub: + await client.read_resource(uri) + except SubscriptionLost: + pass + # Graceful close or abrupt drop, the stream is gone either way. Back + # off before re-listening - a graceful close may be the server + # shedding load, and reconnecting instantly recreates the pressure. + await anyio.sleep(1) diff --git a/examples/stories/subscriptions/client.py b/examples/stories/subscriptions/client.py index 379d69bc6..d2053aaf7 100644 --- a/examples/stories/subscriptions/client.py +++ b/examples/stories/subscriptions/client.py @@ -4,88 +4,34 @@ import mcp_types as types from mcp.client import Client +from mcp.client.subscriptions import ResourceUpdated, ToolsListChanged from stories._harness import Target, run_client -SUBSCRIPTION_ID = "io.modelcontextprotocol/subscriptionId" - async def main(target: Target, *, mode: str = "auto") -> None: - # Stream frames arrive as ordinary server notifications; `message_handler` - # is constructor-only on `Client`, so the list it fills exists first. - received: list[types.ServerNotification] = [] - arrival = anyio.Event() - - async def on_message(message: object) -> None: - nonlocal arrival - if isinstance( - message, - types.SubscriptionsAcknowledgedNotification - | types.ResourceUpdatedNotification - | types.ToolListChangedNotification, - ): - received.append(message) - arrival.set() - arrival = anyio.Event() - - async def wait_for(count: int) -> None: - with anyio.fail_after(10): - while len(received) < count: - await arrival.wait() - - async with Client(target, mode=mode, message_handler=on_message) as client: + async with Client(target, mode=mode) as client: before = await client.list_tools() assert "search" not in {tool.name for tool in before.tools} - async with anyio.create_task_group() as tg: - # There is no client-side listen API yet, so the story drops to the - # `client.session` escape hatch. The request parks for the stream's - # lifetime, so it runs as a task; cancelling it releases the local - # awaiting scope. In-memory that also ends the server's stream; over - # HTTP today nothing aborts the POST, so the server-side stream ends - # when the connection closes (the `Client` exit right below). - async def listen() -> None: - request = types.SubscriptionsListenRequest( - params=types.SubscriptionsListenRequestParams( - notifications=types.SubscriptionFilter( - tools_list_changed=True, resource_subscriptions=["note://todo"] - ) - ) - ) - await client.session.send_request(request, types.SubscriptionsListenResult) - - tg.start_soon(listen) - - # ── the ack is the first frame: it echoes the honored filter, tagged ── - await wait_for(1) - ack = received[0] - assert isinstance(ack, types.SubscriptionsAcknowledgedNotification), ack - assert ack.params.notifications.tools_list_changed is True - assert ack.params.notifications.resource_subscriptions == ["note://todo"] - assert ack.params.meta is not None and SUBSCRIPTION_ID in ack.params.meta + async with client.listen(tools_list_changed=True, resource_subscriptions=["note://todo"]) as sub: + # ── entering waited for the ack: the honored filter is already in hand ── + assert sub.honored.tools_list_changed is True + assert sub.honored.resource_subscriptions == ["note://todo"] # ── exact-URI filtering: an unsubscribed note edit stays silent ── await client.call_tool("edit_note", {"name": "journal", "text": "day two"}) - # ── the subscribed URI delivers, carrying the same subscription id ── + # ── the subscribed URI delivers ── await client.call_tool("edit_note", {"name": "todo", "text": "water plants"}) - await wait_for(2) - updated = received[1] - assert isinstance(updated, types.ResourceUpdatedNotification), updated - assert updated.params.uri == "note://todo" - assert updated.params.meta is not None - assert updated.params.meta[SUBSCRIPTION_ID] == ack.params.meta[SUBSCRIPTION_ID] - assert len(received) == 2, "the journal edit must not have been delivered" + with anyio.fail_after(10): + event = await anext(sub) + assert event == ResourceUpdated(uri="note://todo"), "the journal edit must not have been delivered" # ── a runtime tool registration announces itself ── await client.call_tool("enable_search", {}) - await wait_for(3) - assert isinstance(received[2], types.ToolListChangedNotification), received[2] - - # The client is done listening: cancel the parked request and let - # the connection teardown below end the stream server-side. - tg.cancel_scope.cancel() + with anyio.fail_after(10): + assert await anext(sub) == ToolsListChanged() - # list_changed told us to re-fetch - the new tool is callable, and the - # session outlives the closed stream. + # ── leaving the block closed the stream; the session lives on ── tools = await client.list_tools() assert "search" in {tool.name for tool in tools.tools} result = await client.call_tool("search", {"query": "water"}) diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index d581fe6a5..ab4813840 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -6,7 +6,7 @@ import logging import uuid from collections.abc import Awaitable, Callable, Mapping, Sequence -from contextlib import AsyncExitStack +from contextlib import AbstractAsyncContextManager, AsyncExitStack from dataclasses import KW_ONLY, dataclass, field from typing import Any, Literal, TypeVar, cast @@ -58,6 +58,8 @@ SamplingFnT, ) from mcp.client.streamable_http import streamable_http_client +from mcp.client.subscriptions import ServerEvent, Subscription +from mcp.client.subscriptions import listen as _listen from mcp.server import Server from mcp.server.mcpserver import MCPServer from mcp.server.runner import modern_on_request @@ -67,6 +69,7 @@ from mcp.shared.extension import validate_extension_identifier from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.session import RequestResponder +from mcp.shared.subscriptions import event_to_notification logger = logging.getLogger(__name__) @@ -662,13 +665,68 @@ async def retry(r: InputResponses | None, s: str | None) -> ReadResourceResult | # Driver rounds carry inputResponses, so a terminal result reached through them is never cached (spec MUST). return await self._drive_input_required(first, retry) + def listen( + self, + *, + tools_list_changed: bool = False, + prompts_list_changed: bool = False, + resources_list_changed: bool = False, + resource_subscriptions: Sequence[str] = (), + ) -> AbstractAsyncContextManager[Subscription]: + """Open a `subscriptions/listen` stream of typed change events (2026-07-28 only). + + Keyword args mirror the wire `SubscriptionFilter`; entering waits for the ack (honored subset: `sub.honored`): + + async with client.listen(tools_list_changed=True) as sub: + async for event in sub: + tools = await client.list_tools() # refetch on change + + A graceful close ends the loop; an abrupt drop raises `SubscriptionLost`. No replay: re-listen and refetch. + + Raises: + ListenNotSupportedError: The negotiated protocol version predates 2026-07-28. + MCPError: The server rejected the request or the connection failed first. + SubscriptionLost: The stream ended before it was acknowledged. + TimeoutError: The read timeout elapsed before the acknowledgment. + """ + return _listen( + self.session, + tools_list_changed=tools_list_changed, + prompts_list_changed=prompts_list_changed, + resources_list_changed=resources_list_changed, + resource_subscriptions=resource_subscriptions, + on_event=self._evict_for_listen_event if self._response_cache is not None else None, + ) + + async def _evict_for_listen_event(self, event: ServerEvent) -> None: + """Finish response-cache eviction before a listen consumer can refetch. + + Without it the iterator wakes first and refetches a still-warm entry, with no + corrective wake (events are deduplicated level triggers). The tee path repeats + the eviction; deliberate: idempotent, and it covers non-iterating consumers. + """ + cache = self._response_cache + assert cache is not None # installed as the event barrier only when a cache exists + try: + await cache.evict_for_notification(event_to_notification(event, {})) + except Exception: # boundary: eviction reaches user store code; a cache fault must not block delivery + logger.exception("Response cache eviction failed; the event is still delivered") + + @deprecated( + "resources/subscribe is removed as of 2026-07-28; use Client.listen() instead.", + category=MCPDeprecationWarning, + ) async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> EmptyResult: - """Subscribe to resource updates.""" - return await self.session.subscribe_resource(uri, meta=meta) + """Subscribe to resource updates (2025-era servers only).""" + return await self.session.subscribe_resource(uri, meta=meta) # pyright: ignore[reportDeprecated] + @deprecated( + "resources/unsubscribe is removed as of 2026-07-28; use Client.listen() instead.", + category=MCPDeprecationWarning, + ) async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> EmptyResult: - """Unsubscribe from resource updates.""" - return await self.session.unsubscribe_resource(uri, meta=meta) + """Unsubscribe from resource updates (2025-era servers only).""" + return await self.session.unsubscribe_resource(uri, meta=meta) # pyright: ignore[reportDeprecated] async def call_tool( self, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 804180e05..097ade1c9 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -16,6 +16,7 @@ from mcp_types import ( CLIENT_CAPABILITIES_META_KEY, CLIENT_INFO_META_KEY, + CONNECTION_CLOSED, INTERNAL_ERROR, METHOD_NOT_FOUND, PROTOCOL_VERSION_META_KEY, @@ -35,8 +36,9 @@ from mcp.client._transport import ReadStream, WriteStream from mcp.client.extension import NotificationBinding, ResultClaim, UnexpectedClaimedResult +from mcp.client.subscriptions import ListenRoute from mcp.shared._compat import resync_tracer -from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, ProgressFnT +from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, ProgressFnT, as_request_id from mcp.shared.exceptions import MCPDeprecationWarning, MCPError from mcp.shared.inbound import ( MCP_METHOD_HEADER, @@ -48,9 +50,10 @@ mcp_param_headers, x_mcp_header_map, ) -from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher, cancelled_request_id_from_params from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.subscriptions import SUBSCRIPTION_ID_META_KEY, event_from_wire from mcp.shared.transport_context import TransportContext DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -93,7 +96,16 @@ def stamp(data: dict[str, Any], opts: CallOptions) -> None: meta[PROTOCOL_VERSION_META_KEY] = protocol_version meta[CLIENT_INFO_META_KEY] = client_info meta[CLIENT_CAPABILITIES_META_KEY] = capabilities - opts["cancel_on_abandon"] = False + # `cancel_on_abandon` stays at the dispatcher default (True): the + # courtesy `notifications/cancelled` is the abandon signal. On the + # stream transports it is the 2026 wire's cancellation spelling; the + # streamable-HTTP transport translates it into aborting the request's + # own POST instead of writing it (the 2026 HTTP wire has no + # client-to-server notifications - closing the stream is the signal). + # The negotiation methods still opt out, mirroring `_preconnect_stamp`: + # the spec forbids cancelling them. + if data["method"] in ("initialize", "server/discover"): + opts["cancel_on_abandon"] = False headers = opts.setdefault("headers", {}) headers[MCP_PROTOCOL_VERSION_HEADER] = protocol_version headers[MCP_METHOD_HEADER] = data["method"] @@ -351,6 +363,8 @@ def __init__( self._negotiated_version: str | None = None self._stamp: Callable[[dict[str, Any], CallOptions], None] = _preconnect_stamp self._task_group: anyio.abc.TaskGroup | None = None + # subscriptions/listen demux routes; membership decides ack consumption (raw listens are never registered) + self._listen_routes: dict[RequestId, ListenRoute] = {} if dispatcher is not None: if read_stream is not None or write_stream is not None: raise ValueError("pass read_stream/write_stream or dispatcher, not both") @@ -379,7 +393,9 @@ async def __aenter__(self) -> Self: for binding in self._notification_bindings.values(): send, receive = anyio.create_memory_object_stream[BaseModel](_NOTIFICATION_QUEUE_SIZE) self._binding_queues[binding.method] = (send, receive) - await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + await self._task_group.start( + self._dispatcher.run, self._on_request, self._on_notify, self._intercept_notification + ) for binding in self._notification_bindings.values(): _, receive = self._binding_queues[binding.method] self._task_group.start_soon(self._deliver_bound_notifications, binding, receive) @@ -413,6 +429,7 @@ async def __aexit__( result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) finally: self._close_binding_queues() + self._settle_listen_routes_closed() await resync_tracer() return result @@ -850,15 +867,23 @@ async def read_resource( raise _input_required_unexpected("read_resource") return result + @deprecated( + "resources/subscribe is removed as of 2026-07-28; use Client.listen() instead.", + category=MCPDeprecationWarning, + ) async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: - """Send a resources/subscribe request.""" + """Send a resources/subscribe request (2025-era servers only).""" return await self.send_request( types.SubscribeRequest(params=types.SubscribeRequestParams(uri=uri, _meta=meta)), types.EmptyResult, ) + @deprecated( + "resources/unsubscribe is removed as of 2026-07-28; use Client.listen() instead.", + category=MCPDeprecationWarning, + ) async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult: - """Send a resources/unsubscribe request.""" + """Send a resources/unsubscribe request (2025-era servers only).""" return await self.send_request( types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=uri, _meta=meta)), types.EmptyResult, @@ -1216,6 +1241,62 @@ async def dispatch_input_request( case types.ListRootsRequest(): # pragma: no branch return await self._list_roots_callback(ctx) + def _register_listen_route(self, request_id: RequestId) -> ListenRoute: + """Create the demux route for a listen request id; the caller registers BEFORE sending.""" + route = ListenRoute() + self._listen_routes[request_id] = route + return route + + def _unregister_listen_route(self, request_id: RequestId) -> None: + """Drop a listen route; the handle owns membership, so a missing key is a no-op.""" + self._listen_routes.pop(request_id, None) + + def _settle_listen_routes_closed(self) -> None: + """Settle all open listen routes as lost on session exit; cancelled driver tasks cannot.""" + closed = MCPError(code=CONNECTION_CLOSED, message="Connection closed") + for route in self._listen_routes.values(): + route.settle("lost", error=closed) + self._listen_routes.clear() + + def _intercept_notification(self, method: str, params: Mapping[str, Any] | None) -> bool: + """Wire-order listen demux, run synchronously on the dispatcher's receive path. + + Bookkeeping must advance in receive order with the listen result (resolved on + this same path); the spawned `_on_notify` path would race it and drop events. + Returns True to consume the frame: a live route's ack is driver state, never surfaced. + """ + if not self._listen_routes: + return False + if method == "notifications/cancelled": + request_id = cancelled_request_id_from_params(params) + if request_id is not None and (listen_route := self._listen_routes.get(request_id)) is not None: + # a server-sent cancel naming a listen request is that stream's teardown signal + listen_route.settle("lost") + return False # _on_notify swallows every cancelled either way (v1 parity) + if params is None: + return False + meta = params.get("_meta") + if not isinstance(meta, Mapping): + return False + # as_request_id is not a tripwire: raw wire _meta can carry a non-id (even unhashable) value + subscription_id = as_request_id(cast("Mapping[str, Any]", meta).get(SUBSCRIPTION_ID_META_KEY)) + if subscription_id is None or (listen_route := self._listen_routes.get(subscription_id)) is None: + return False + if method == "notifications/subscriptions/acknowledged": + raw_filter = params.get("notifications") + if raw_filter is None: + # malformed, not an empty filter: leave it to the spawned path's validation warning + return False + try: + honored = types.SubscriptionFilter.model_validate(raw_filter) + except ValidationError: + return False + listen_route.set_acked(honored) + return True + if (event := event_from_wire(method, params)) is not None: + listen_route.deliver(event) + return False # events (and any other stamped frame) still tee as usual + async def _on_notify( self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> None: @@ -1250,7 +1331,7 @@ async def _on_notify( logger.warning("Failed to validate notification: %s", method, exc_info=True) return if isinstance(notification, types.CancelledNotification): - # The dispatcher already applied the cancellation; not surfaced to message_handler. + # Never surfaced (v1 parity): the dispatcher already applied it; listen cancels settled by the intercept. return try: if isinstance(notification, types.LoggingMessageNotification): diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index f28eb7c7a..4929bdf7f 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -13,6 +13,7 @@ from anyio.abc import TaskGroup from httpx_sse import EventSource, ServerSentEvent, aconnect_sse from mcp_types import ( + CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_REQUEST, METHOD_NOT_FOUND, @@ -26,6 +27,7 @@ RequestId, jsonrpc_message_adapter, ) +from mcp_types.version import MODERN_PROTOCOL_VERSIONS from pydantic import ValidationError from mcp.client._transport import TransportStreams @@ -33,6 +35,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +from mcp.shared.jsonrpc_dispatcher import cancelled_request_id_from_params from mcp.shared.message import ClientMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -70,6 +73,19 @@ class RequestContext: read_stream_writer: StreamWriter +@dataclass(slots=True) +class _InFlightPost: + """A request POST in flight: its abort scope and the era it was sent under. + + `modern` is the negotiated-version cache as of this request's dequeue, so a + later cancel frame is interpreted under the era the request actually ran + with, not whatever the cache says by then. + """ + + scope: anyio.CancelScope + modern: bool + + class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" @@ -81,11 +97,18 @@ def __init__(self, url: str) -> None: """ self.url = url self.session_id: str | None = None - # Captured from each stamped POST's metadata. Reused on outbound HTTP that carries - # no per-message header (transport-internal GET/DELETE, and dispatcher-written - # response/error/cancel POSTs that bypass the session's stamp). Cleared when an - # `initialize` POST goes out so a probe-stamped value cannot leak onto the handshake. + # Captured from each stamped message's metadata, synchronously in the + # post_writer loop so the cache always reflects wire order (a POST task's + # scheduling is arbitrary). Reused on outbound HTTP that carries no + # per-message header (transport-internal GET/DELETE, and dispatcher-written + # response/error POSTs that bypass the session's stamp), and consulted by + # `_consume_modern_cancellation`. Cleared when an `initialize` message is + # dequeued so a probe-stamped value cannot leak onto the handshake. self._protocol_version_header: str | None = None + # Every request's POST runs inside one of these so an outbound + # `notifications/cancelled` at 2026 can abort it; see + # `_consume_modern_cancellation`. Keys are verbatim-typed ("1" is not 1). + self._in_flight_posts: dict[RequestId, _InFlightPost] = {} def _prepare_headers(self) -> dict[str, str]: """Build MCP-specific request headers for any outbound HTTP request. @@ -93,9 +116,9 @@ def _prepare_headers(self) -> dict[str, str]: These are merged with the ``httpx.AsyncClient`` defaults (these take precedence). The cached ``MCP-Protocol-Version`` is included whenever present so messages that don't pass through the session's stamp — - response/error/cancel POSTs, transport-internal GET/DELETE — still - carry the negotiated version. Per-message headers are layered on top - by the caller. + response/error POSTs, legacy cancel frames, transport-internal + GET/DELETE — still carry the negotiated version. Per-message headers + are layered on top by the caller. """ headers: dict[str, str] = { "accept": "application/json, text/event-stream", @@ -245,19 +268,57 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: await event_source.response.aclose() break + def _consume_modern_cancellation(self, session_message: SessionMessage) -> bool: + """Translate an outbound `notifications/cancelled` at 2026; True means "do not POST". + + The 2026 wire defines no client-to-server notifications over streamable + HTTP: closing a request's response stream IS its cancellation signal. + The dispatcher still emits the courtesy frame as its abandon signal + (every outbound cancel names one of our own request ids - the spec + forbids cancelling a request the sender did not issue), so this + transport translates it: when the named request's POST is in flight, + that POST's own recorded era decides - abort-and-swallow at 2026, POST + the frame below it (where the frame is the signal and a disconnect + explicitly is not). With no POST to consult, the cached negotiated + version decides; at 2026 the frame is swallowed even unmatched, so a + late cancel racing the response cannot leak onto the wire. + """ + message = session_message.message + if not (isinstance(message, JSONRPCNotification) and message.method == "notifications/cancelled"): + return False + request_id = cancelled_request_id_from_params(message.params) + post = self._in_flight_posts.get(request_id) if request_id is not None else None + if post is not None: + if not post.modern: + return False + logger.debug("aborting in-flight POST for cancelled request %r", request_id) + post.scope.cancel() + return True + return self._protocol_version_header in MODERN_PROTOCOL_VERSIONS + + async def _run_request_post( + self, + post_fn: Callable[[], Awaitable[None]], + post: _InFlightPost, + request_id: RequestId, + ) -> None: + """Run one request's POST inside its abort scope (see `_consume_modern_cancellation`).""" + try: + with post.scope: + await post_fn() + finally: + # Identity-guarded: a reused id may already have a successor + # registered while this task unwinds - popping by key alone would + # evict the live entry and leave the new POST unabortable. + if self._in_flight_posts.get(request_id) is post: + del self._in_flight_posts[request_id] + async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" message = ctx.session_message.message - is_initialization = self._is_initialization_request(message) - if is_initialization: - # `initialize` is the negotiation, not a "subsequent request" — discard any - # probe-stamped value so the discover→fallback path can't leak it onto the handshake. - self._protocol_version_header = None headers = self._prepare_headers() if ctx.metadata is not None and ctx.metadata.headers is not None: headers.update(ctx.metadata.headers) - if MCP_PROTOCOL_VERSION_HEADER in ctx.metadata.headers: - self._protocol_version_header = ctx.metadata.headers[MCP_PROTOCOL_VERSION_HEADER] async with ctx.client.stream( "POST", @@ -302,7 +363,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: await ctx.read_stream_writer.send(session_message) return - if is_initialization: + if self._is_initialization_request(message): self._maybe_extract_session_id_from_response(response) # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: @@ -378,9 +439,29 @@ async def _handle_sse_response( logger.debug("SSE stream ended", exc_info=True) # pragma: lax no cover # Stream ended without response - reconnect if we received an event with ID - if last_event_id is not None: # pragma: no branch + if last_event_id is not None: logger.info("SSE stream disconnected, reconnecting...") await self._handle_reconnection(ctx, last_event_id, retry_interval_ms) + else: + # Not resumable: resolve the waiter, else a listen stream's consumer + # would hang forever instead of learning the subscription is lost. + await self._resolve_abandoned_request( + ctx.read_stream_writer, original_request_id, "SSE stream ended without a response" + ) + + async def _resolve_abandoned_request( + self, read_stream_writer: StreamWriter, request_id: RequestId, message: str + ) -> None: + """Resolve a request whose response can never arrive with a synthesized error. + + Best-effort: a closed read stream means the session is tearing down. + """ + error_data = ErrorData(code=CONNECTION_CLOSED, message=message) + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data)) + try: + await read_stream_writer.send(error_msg) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("read stream closed before request %r could be resolved", request_id) async def _handle_reconnection( self, @@ -390,9 +471,19 @@ async def _handle_reconnection( attempt: int = 0, ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" - # Bail if max retries exceeded - if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover + # Only requests reconnect: every caller arrives from a request's response stream. + assert isinstance(ctx.session_message.message, JSONRPCRequest) + original_request_id = ctx.session_message.message.id + + if attempt >= MAX_RECONNECTION_ATTEMPTS: + # Resolve on give-up: a request with no read timeout (a listen + # stream) would otherwise hang its caller forever. logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") + await self._resolve_abandoned_request( + ctx.read_stream_writer, + original_request_id, + "SSE stream ended and reconnection attempts were exhausted", + ) return # Always wait - use server value or default @@ -402,11 +493,6 @@ async def _handle_reconnection( headers = self._prepare_headers() headers[LAST_EVENT_ID] = last_event_id - # Extract original request ID to map responses - original_request_id = None - if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch - original_request_id = ctx.session_message.message.id - try: async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: event_source.response.raise_for_status() @@ -455,6 +541,8 @@ async def post_writer( async def _handle_message(session_message: SessionMessage) -> None: message = session_message.message + if self._consume_modern_cancellation(session_message): + return metadata = ( session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) @@ -470,6 +558,15 @@ async def _handle_message(session_message: SessionMessage) -> None: if self._is_initialized_notification(message): start_get_stream() + if self._is_initialization_request(message): + # `initialize` is the negotiation, not a "subsequent request" — discard any + # probe-stamped value so the discover→fallback path can't leak it onto the handshake. + self._protocol_version_header = None + elif metadata is not None and metadata.headers is not None: + stamped_version = metadata.headers.get(MCP_PROTOCOL_VERSION_HEADER) + if stamped_version is not None: + self._protocol_version_header = stamped_version + ctx = RequestContext( client=client, session_id=self.session_id, @@ -486,7 +583,15 @@ async def handle_request_async(): # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): - tg.start_soon(handle_request_async) + # Register the abort scope before the spawn: the next + # message through this loop can already be the abandon + # signal for this id, ahead of the task ever running. + post = _InFlightPost( + scope=anyio.CancelScope(), + modern=self._protocol_version_header in MODERN_PROTOCOL_VERSIONS, + ) + self._in_flight_posts[message.id] = post + tg.start_soon(self._run_request_post, handle_request_async, post, message.id) else: await handle_request_async() diff --git a/src/mcp/client/subscriptions.py b/src/mcp/client/subscriptions.py new file mode 100644 index 000000000..27283909b --- /dev/null +++ b/src/mcp/client/subscriptions.py @@ -0,0 +1,282 @@ +"""Client-side `subscriptions/listen` driver (2026-07-28, SEP-2575). + +`listen()` opens the stream as an async context manager: entering waits for +the server's acknowledgment, iteration yields typed change events, a graceful +server close ends the loop, and an abrupt drop raises `SubscriptionLost`. +There is no replay and no automatic re-listen: a client that re-opens a +subscription refetches what it depends on. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator, Awaitable, Callable, Sequence +from contextlib import asynccontextmanager +from itertools import count +from typing import TYPE_CHECKING, Literal + +import anyio +import mcp_types as types +from mcp_types.version import MODERN_PROTOCOL_VERSIONS + +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import MCPError +from mcp.shared.subscriptions import ( + PromptsListChanged, + ResourcesListChanged, + ResourceUpdated, + ServerEvent, + ToolsListChanged, + event_matches, +) + +if TYPE_CHECKING: + from mcp.client.session import ClientSession + +__all__ = [ + "ListenNotSupportedError", + "OnEvent", + "PromptsListChanged", + "ResourceUpdated", + "ResourcesListChanged", + "ServerEvent", + "Subscription", + "SubscriptionLost", + "ToolsListChanged", + "listen", +] + +_listen_ids = count(1) +"""Process-wide `listen-N` sequence: string ids can never collide with a dispatcher's minted ints.""" + +_MAX_PENDING_EVENTS = 1024 +"""Backlog backstop: the spec allows sub-resource URIs, so distinct pending +`ResourceUpdated` events are unbounded; overflowing this cap settles the +subscription lost rather than growing client memory.""" + +_SubscriptionEnd = Literal["graceful", "lost", "local"] + + +class ListenNotSupportedError(RuntimeError): + """`subscriptions/listen` requires a 2026-07-28 connection.""" + + def __init__(self, negotiated_version: str | None) -> None: + self.negotiated_version = negotiated_version + super().__init__( + f"subscriptions/listen is not available at protocol version {negotiated_version!r}; it requires " + "2026-07-28. On earlier versions use subscribe_resource() and the change notifications delivered " + "through message_handler." + ) + + +class SubscriptionLost(RuntimeError): + """The stream ended without the server's graceful close; re-listen and refetch.""" + + +class ListenRoute: + """Package-internal demux state for one listen stream, fed synchronously in receive order by the session.""" + + def __init__(self) -> None: + self.honored: types.SubscriptionFilter | None = None + self.acked = anyio.Event() + self.error: MCPError | None = None + self.end: _SubscriptionEnd | None = None + self._honored_uris: frozenset[str] = frozenset() + self._pending: dict[ServerEvent, None] = {} + self._wake = anyio.Event() + + def set_acked(self, honored: types.SubscriptionFilter) -> None: + """Record the acknowledged filter; the first ack wins.""" + if not self.acked.is_set(): + self.honored = honored + self._honored_uris = frozenset(honored.resource_subscriptions or ()) + self.acked.set() + + def deliver(self, event: ServerEvent) -> None: + """Queue an event within the honored filter, deduplicated against the backlog. + + Any `ResourceUpdated` is admitted once URI subscriptions were honored at + all: the spec allows the stamped URI to be a sub-resource of a subscribed one. + """ + if self.end is not None or self.honored is None: + return + if isinstance(event, ResourceUpdated): + admitted = bool(self._honored_uris) + else: + admitted = event_matches(self.honored, self._honored_uris, event) + if not admitted or event in self._pending: + return + if len(self._pending) >= _MAX_PENDING_EVENTS: + self.settle( + "lost", + error=MCPError( + types.INTERNAL_ERROR, + f"subscription backlog exceeded {_MAX_PENDING_EVENTS} unconsumed events; re-listen and refetch", + ), + ) + return + self._pending[event] = None + self._wake.set() + + def settle(self, end: _SubscriptionEnd, error: MCPError | None = None) -> None: + """Record the stream's end; the first reason wins and wakes both waiters.""" + if self.end is None: + self.end = end + self.error = error + self.acked.set() + self._wake.set() + + async def next_event(self) -> ServerEvent | _SubscriptionEnd: + """Peek the next pending event, or the stream's end once the backlog drains. + + A "local" end short-circuits the backlog; the other endings drain it first, + so a graceful close never swallows events that preceded it. + """ + while True: + # Snapshot the wake event before checking state so a deliver landing after the checks cannot be missed. + wake = self._wake + if self.end == "local": + return self.end + if self._pending: + return next(iter(self._pending)) + if self.end is not None: + return self.end + await wake.wait() + self._wake = anyio.Event() + + def consume(self, event: ServerEvent) -> None: + """Remove a peeked event from the backlog.""" + self._pending.pop(event, None) + + +OnEvent = Callable[[ServerEvent], Awaitable[None]] +"""Per-event barrier awaited before a `Subscription` returns each event to its consumer.""" + + +class Subscription: + """One open `subscriptions/listen` stream: an async iterator of typed events. + + Produced by `listen()` / `Client.listen()`, not constructed directly. + """ + + def __init__( + self, + route: ListenRoute, + subscription_id: types.RequestId, + honored: types.SubscriptionFilter, + on_event: OnEvent | None = None, + ): + self._route = route + self._on_event = on_event + self.subscription_id = subscription_id + """The listen request's JSON-RPC id, stamped into every frame's `_meta`.""" + self.honored = honored + """The subset of the requested filter the server agreed to deliver.""" + + def __aiter__(self) -> Subscription: + return self + + async def __anext__(self) -> ServerEvent: + """Yield the next change event; the loop ends when the stream does. + + Raises: + SubscriptionLost: the stream dropped without the server's graceful close. + """ + outcome = await self._route.next_event() + if isinstance(outcome, str): + if outcome == "lost": + raise SubscriptionLost( + f"subscription {self.subscription_id!r} ended without the server's graceful close;" + " re-listen and refetch" + ) from self._route.error + raise StopAsyncIteration + if self._on_event is not None: + # The event stays pending while the barrier runs: a cancellation or a + # raising barrier leaves it for the next anext instead of dropping it. + await self._on_event(outcome) + self._route.consume(outcome) + return outcome + + +@asynccontextmanager +async def listen( + session: ClientSession, + *, + tools_list_changed: bool = False, + prompts_list_changed: bool = False, + resources_list_changed: bool = False, + resource_subscriptions: Sequence[str] = (), + on_event: OnEvent | None = None, +) -> AsyncIterator[Subscription]: + """Open one `subscriptions/listen` stream on `session` (2026-07-28 only). + + Entering sends the request and returns once the server's acknowledgment + arrives; exiting ends the subscription. `on_event` is awaited before each + event is returned - the seam `Client.listen` uses to finish cache eviction + before the consumer can refetch. + + Raises: + ListenNotSupportedError: negotiated version predates 2026-07-28. + MCPError: the server rejected the request, or the connection failed pre-ack. + SubscriptionLost: the stream ended before it was acknowledged. + TimeoutError: the session's read timeout elapsed before the acknowledgment. + """ + if session.protocol_version not in MODERN_PROTOCOL_VERSIONS: + raise ListenNotSupportedError(session.protocol_version) + if isinstance(resource_subscriptions, str): + raise TypeError("resource_subscriptions takes a sequence of URIs, not a bare string") + request = types.SubscriptionsListenRequest( + params=types.SubscriptionsListenRequestParams( + notifications=types.SubscriptionFilter( + tools_list_changed=tools_list_changed or None, + prompts_list_changed=prompts_list_changed or None, + resources_list_changed=resources_list_changed or None, + resource_subscriptions=list(resource_subscriptions) or None, + ) + ) + ) + task_group = session._task_group # pyright: ignore[reportPrivateUsage] + if task_group is None: + raise RuntimeError("listen() requires an entered session") + request_id: types.RequestId = f"listen-{next(_listen_ids)}" + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + opts: CallOptions = {"request_id": request_id} + session._stamp(data, opts) # pyright: ignore[reportPrivateUsage] + driver_scope = anyio.CancelScope() + + async def drive() -> None: + # Deliberately no result timeout: the response arrives when the stream ends. + with driver_scope: + try: + await session._dispatcher.send_raw_request( # pyright: ignore[reportPrivateUsage] + data["method"], data.get("params"), opts + ) + except MCPError as error: + route.settle("lost", error=error) + return + except ValueError as error: + # A raw request id collided with our minted listen id: fail this subscription + # and release the route in this same slice, so it cannot consume the raw caller's ack. + session._unregister_listen_route(request_id) # pyright: ignore[reportPrivateUsage] + route.settle("lost", error=MCPError(types.INTERNAL_ERROR, str(error))) + return + # A result, whatever its body, is the spec's graceful close; with no prior ack + # it opens the subscription already closed. + route.set_acked(types.SubscriptionFilter()) + route.settle("graceful") + + # Register the demux route before the request is written so the ack cannot race it. + route = session._register_listen_route(request_id) # pyright: ignore[reportPrivateUsage] + try: + task_group.start_soon(drive) + with anyio.fail_after(session._session_read_timeout_seconds): # pyright: ignore[reportPrivateUsage] + await route.acked.wait() + if route.honored is None: + # Only reachable on failure paths: a graceful no-ack result acked an empty filter in drive(). + if route.error is not None: + raise route.error + raise SubscriptionLost(f"subscription {request_id!r} ended before it was acknowledged") + yield Subscription(route, request_id, route.honored, on_event) + finally: + route.settle("local") + driver_scope.cancel() + session._unregister_listen_route(request_id) # pyright: ignore[reportPrivateUsage] diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 28d06761d..2b7fdf35e 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -16,14 +16,14 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.subscriptions import ( +from mcp.server.subscriptions import SubscriptionBus +from mcp.shared.exceptions import MCPDeprecationWarning +from mcp.shared.subscriptions import ( PromptsListChanged, ResourcesListChanged, ResourceUpdated, - SubscriptionBus, ToolsListChanged, ) -from mcp.shared.exceptions import MCPDeprecationWarning if TYPE_CHECKING: from mcp.server.mcpserver.server import MCPServer diff --git a/src/mcp/server/subscriptions.py b/src/mcp/server/subscriptions.py index d071cfdbf..6b0b3d49b 100644 --- a/src/mcp/server/subscriptions.py +++ b/src/mcp/server/subscriptions.py @@ -13,6 +13,8 @@ `MCPServer` registers one automatically; lowlevel `Server` users pass an instance as `on_subscriptions_listen=`. +The event vocabulary lives in `mcp.shared.subscriptions`, shared with the client driver, and is re-exported here. + Per the spec, the handler acknowledges first (the ack is the first frame on the stream), tags every frame with the listen request's JSON-RPC id under `_meta["io.modelcontextprotocol/subscriptionId"]`, and never delivers an @@ -24,7 +26,6 @@ import logging from collections.abc import Callable -from dataclasses import dataclass from typing import Any, Protocol import anyio @@ -33,56 +34,39 @@ from mcp_types import ( INTERNAL_ERROR, INVALID_REQUEST, - NotificationParams, - PromptListChangedNotification, - ResourceListChangedNotification, - ResourceUpdatedNotification, - ResourceUpdatedNotificationParams, - ServerNotification, SubscriptionFilter, SubscriptionsAcknowledgedNotification, SubscriptionsAcknowledgedNotificationParams, SubscriptionsListenRequestParams, SubscriptionsListenResult, - ToolListChangedNotification, ) from mcp.server.context import ServerRequestContext from mcp.shared.exceptions import MCPError +from mcp.shared.subscriptions import ( + SUBSCRIPTION_ID_META_KEY, + PromptsListChanged, + ResourcesListChanged, + ResourceUpdated, + ServerEvent, + ToolsListChanged, + event_matches, + event_to_notification, +) -logger = logging.getLogger(__name__) - -SUBSCRIPTION_ID_META_KEY = "io.modelcontextprotocol/subscriptionId" -"""The `_meta` key carrying the subscription id on every listen-stream frame. - -The value is the `subscriptions/listen` request's JSON-RPC id, verbatim. -""" - - -@dataclass(frozen=True) -class ToolsListChanged: - """The server's tool list changed.""" - - -@dataclass(frozen=True) -class PromptsListChanged: - """The server's prompt list changed.""" - - -@dataclass(frozen=True) -class ResourcesListChanged: - """The server's resource list changed.""" - - -@dataclass(frozen=True) -class ResourceUpdated: - """The resource at `uri` changed and may need to be read again.""" - - uri: str - +__all__ = [ + "SUBSCRIPTION_ID_META_KEY", + "InMemorySubscriptionBus", + "ListenHandler", + "PromptsListChanged", + "ResourceUpdated", + "ResourcesListChanged", + "ServerEvent", + "SubscriptionBus", + "ToolsListChanged", +] -ServerEvent = ToolsListChanged | PromptsListChanged | ResourcesListChanged | ResourceUpdated -"""An event a server publishes for delivery to listen subscribers.""" +logger = logging.getLogger(__name__) class SubscriptionBus(Protocol): @@ -170,32 +154,6 @@ def _honored_subset(requested: SubscriptionFilter) -> SubscriptionFilter: ) -def _event_matches(honored: SubscriptionFilter, uris: frozenset[str], event: ServerEvent) -> bool: - """Whether `event` is within the stream's honored filter. - - `uris` is the honored `resource_subscriptions` as a set: matching runs on - every publish, and the wire filter may name many URIs. - """ - if isinstance(event, ToolsListChanged): - return honored.tools_list_changed is True - if isinstance(event, PromptsListChanged): - return honored.prompts_list_changed is True - if isinstance(event, ResourcesListChanged): - return honored.resources_list_changed is True - return event.uri in uris - - -def _event_to_notification(event: ServerEvent, meta: dict[str, Any]) -> ServerNotification: - """Build the stamped wire notification for `event`.""" - if isinstance(event, ToolsListChanged): - return ToolListChangedNotification(params=NotificationParams(_meta=meta)) - if isinstance(event, PromptsListChanged): - return PromptListChangedNotification(params=NotificationParams(_meta=meta)) - if isinstance(event, ResourcesListChanged): - return ResourceListChangedNotification(params=NotificationParams(_meta=meta)) - return ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri=event.uri, _meta=meta)) - - class ListenHandler: """Serves `subscriptions/listen`: one call is one subscription stream. @@ -244,7 +202,7 @@ async def __call__( send, recv = anyio.create_memory_object_stream[ServerEvent](self._max_buffered_events) def deliver(event: ServerEvent) -> None: - if _event_matches(honored, honored_uris, event): + if event_matches(honored, honored_uris, event): try: send.send_nowait(event) except anyio.ClosedResourceError: @@ -273,7 +231,7 @@ def deliver(event: ServerEvent) -> None: ) async for event in recv: await ctx.session.send_notification( - _event_to_notification(event, meta), related_request_id=subscription_id + event_to_notification(event, meta), related_request_id=subscription_id ) finally: _safe_unsubscribe(unsubscribe) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index fd3e69d49..e17283afa 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -28,7 +28,15 @@ from pydantic import ValidationError from mcp.shared._compat import resync_tracer -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import ( + CallOptions, + OnNotify, + OnNotifyIntercept, + OnRequest, + ProgressFnT, + coerce_request_id, + run_notify_intercept, +) from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext @@ -56,7 +64,8 @@ class _DirectDispatchContext: _back_request: _Request _back_notify: _Notify request_id: RequestId | None = None - """A dispatcher-synthesized id for requests; `None` for notifications.""" + """The caller-supplied `CallOptions["request_id"]`, else a dispatcher-synthesized + id for requests; `None` for notifications.""" message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework """Always `None`: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None @@ -105,7 +114,9 @@ def __init__(self, transport_ctx: TransportContext, *, raise_handler_exceptions: self._peer: DirectDispatcher | None = None self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None + self._on_notify_intercept: OnNotifyIntercept | None = None self._next_id = 0 + self._in_flight_ids: set[RequestId] = set() self._ready = anyio.Event() self._close_event = anyio.Event() self._running = False @@ -156,6 +167,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: @@ -167,6 +179,7 @@ async def run( try: self._on_request = on_request self._on_notify = on_notify + self._on_notify_intercept = on_notify_intercept self._running = True self._ready.set() task_status.started() @@ -227,9 +240,28 @@ async def _dispatch_request( # waiting on a peer whose run() has not started yet. await self._wait_ready() assert self._on_request is not None - # Synthesize an id: the DispatchContext contract reserves None for notifications. - self._next_id += 1 - dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id) + supplied_id = opts.get("request_id") + if supplied_id is not None: + request_id: RequestId = supplied_id + # Collisions use the same coerced domain as JSONRPCDispatcher's + # pending keys, so this in-memory stand-in raises for exactly + # the ids the wire dispatcher would; the context still sees + # the verbatim value. + in_flight_key = coerce_request_id(request_id) + if in_flight_key in self._in_flight_ids: + raise ValueError(f"request id {request_id!r} is already in flight") + else: + # Synthesize an id (the DispatchContext contract reserves None + # for notifications), minting past any key a supplied id + # occupies: the collision error is reserved for the caller + # who actually chose the id. + self._next_id += 1 + while self._next_id in self._in_flight_ids: + self._next_id += 1 + request_id = self._next_id + in_flight_key = request_id + self._in_flight_ids.add(in_flight_key) + dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=request_id) try: return await self._on_request(dctx, method, params) except MCPError: @@ -247,6 +279,8 @@ async def _dispatch_request( raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e logger.exception("request handler raised") raise MCPError(code=INTERNAL_ERROR, message="Internal server error") from None + finally: + self._in_flight_ids.discard(in_flight_key) except TimeoutError: raise MCPError( code=REQUEST_TIMEOUT, @@ -263,6 +297,8 @@ async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) # dropped, not raised back into the sender's call. logger.debug("dropped notification %r to closed DirectDispatcher", method) return + if run_notify_intercept(self._on_notify_intercept, method, params): + return assert self._on_notify is not None dctx = self._make_context() await self._on_notify(dctx, method, params) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index de83189f1..f109638f2 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -16,6 +16,7 @@ embedding a server in-process. """ +import logging from collections.abc import Awaitable, Callable, Mapping from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable @@ -26,19 +27,46 @@ from mcp.shared.message import MessageMetadata from mcp.shared.transport_context import TransportContext +logger = logging.getLogger(__name__) + __all__ = [ "CallOptions", "DispatchContext", "Dispatcher", "OnNotify", + "OnNotifyIntercept", "OnRequest", "Outbound", "ProgressFnT", + "as_request_id", + "coerce_request_id", + "run_notify_intercept", ] TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) +def as_request_id(value: object) -> RequestId | None: + """Narrow an untyped wire value to a `RequestId`, or None; rejects bool (True would alias request id 1).""" + if isinstance(value, str | int) and not isinstance(value, bool): + return value + return None + + +def coerce_request_id(request_id: RequestId) -> RequestId: + """Coerce a stringified int request id back to int so a peer-echoed id still correlates (matches the TS SDK). + + This is the collision/correlation domain dispatchers share: "7" and 7 are one + id for correlation purposes, even where the wire carries the verbatim value. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + class ProgressFnT(Protocol): """Callback invoked when a progress notification arrives for a pending request.""" @@ -51,6 +79,18 @@ class CallOptions(TypedDict, total=False): All keys are optional. Dispatchers ignore keys they do not understand. """ + request_id: RequestId + """Send the request under this caller-supplied id instead of a dispatcher-minted one. + + The peer sees the value verbatim ("7" stays a string). A value that collides + with one of the sender's own in-flight request ids raises `ValueError`. + Callers that need to know a request's id before its result arrives (a + `subscriptions/listen` stream is demultiplexed by it) mint their own ids + here; string ids that don't parse as integers can never collide with the + dispatcher's minted sequence. Per the class contract, dispatchers that + predate this key ignore it and mint as usual. + """ + timeout: float """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" @@ -184,6 +224,25 @@ async def progress(self, progress: float, total: float | None = None, message: s OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] """Handler for inbound notifications: `(ctx, method, params)`.""" +OnNotifyIntercept = Callable[[str, Mapping[str, Any] | None], bool] +"""Synchronous receive-order intercept for inbound notifications: `(method, params) -> consumed`. + +Runs before `on_notify` is scheduled so correlation state advances in wire order +relative to response resolution (the client's listen demux depends on this). +Returning True consumes the notification. Must not block the receive path. +""" + + +def run_notify_intercept(intercept: OnNotifyIntercept | None, method: str, params: Mapping[str, Any] | None) -> bool: + """Invoke `intercept`, containing a raise to that one notification (never the receive loop).""" + if intercept is None: + return False + try: + return intercept(method, params) + except Exception: + logger.exception("notification intercept raised; passing %r through", method) + return False + class Dispatcher(Outbound, Protocol[TransportT_co]): """A duplex request/notification channel with call-return semantics. @@ -198,6 +257,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: @@ -205,7 +265,9 @@ async def run( Each inbound request is dispatched to `on_request` in its own task; the returned dict (or raised `MCPError`) is sent back as the response. - Inbound notifications go to `on_notify`. + Implementations MUST offer every inbound notification to + `on_notify_intercept` synchronously in receive order (via + `run_notify_intercept`), handing only unconsumed ones to `on_notify`. `task_status.started()` is called once the dispatcher is ready to accept `send_request`/`notify` calls, so callers can use diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 64fcd3298..42798fdc5 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -39,7 +39,18 @@ from mcp.shared._compat import resync_tracer from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import ( + CallOptions, + DispatchContext, + Dispatcher, + OnNotify, + OnNotifyIntercept, + OnRequest, + ProgressFnT, + as_request_id, + coerce_request_id, + run_notify_intercept, +) from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import ( ClientMessageMetadata, @@ -49,7 +60,12 @@ ) from mcp.shared.transport_context import TransportContext -__all__ = ["JSONRPCDispatcher", "handler_exception_to_error_data", "progress_token_from_params"] +__all__ = [ + "JSONRPCDispatcher", + "cancelled_request_id_from_params", + "handler_exception_to_error_data", + "progress_token_from_params", +] logger = logging.getLogger(__name__) @@ -93,14 +109,9 @@ def progress_token_from_params(params: Mapping[str, Any] | None) -> ProgressToke return None -def _coerce_id(request_id: RequestId) -> RequestId: - """Coerce a stringified int request ID back to int so a peer-echoed ID still correlates (matches the TS SDK).""" - if isinstance(request_id, str): - try: - return int(request_id) - except ValueError: - pass - return request_id +def cancelled_request_id_from_params(params: Mapping[str, Any] | None) -> RequestId | None: + """Read `params.requestId` from a `notifications/cancelled` (`as_request_id` shape rules).""" + return as_request_id((params or {}).get("requestId")) @dataclass(slots=True) @@ -285,6 +296,7 @@ def __init__( self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._on_notify_intercept: OnNotifyIntercept | None = None self._tg: anyio.abc.TaskGroup | None = None self._running = False self._closed = False @@ -314,7 +326,22 @@ async def send_raw_request( if not self._running: raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run()") opts = opts or {} - request_id = self._allocate_id() + supplied_id = opts.get("request_id") + if supplied_id is not None: + request_id: RequestId = supplied_id + # The pending key gets the same coercion `_resolve_pending` applies + # to inbound response ids, so a supplied "7" still correlates + # whether the peer echoes "7" or 7. The wire id stays verbatim. + pending_key = coerce_request_id(request_id) + if pending_key in self._pending: + raise ValueError(f"request id {request_id!r} is already in flight") + else: + # Mint past any key a supplied id occupies: the collision error is + # reserved for the caller who actually chose the id. + request_id = self._allocate_id() + while request_id in self._pending: + request_id = self._allocate_id() + pending_key = request_id out_params = dict(params) if params is not None else {} out_meta = dict(out_params.get("_meta") or {}) on_progress = opts.get("on_progress") @@ -327,7 +354,7 @@ async def send_raw_request( # a WouldBlock later just means the waiter already has its one outcome. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) - self._pending[request_id] = pending + self._pending[pending_key] = pending plan = _plan_outbound(_related_request_id, opts) # Spec MUST: only previously-issued requests may be cancelled. A write @@ -398,7 +425,7 @@ async def send_raw_request( raise finally: # Remove the waiter on every path so a late response is dropped, not leaked. - self._pending.pop(request_id, None) + self._pending.pop(pending_key, None) send.close() receive.close() @@ -439,6 +466,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: @@ -447,6 +475,7 @@ async def run( `task_status.started()` fires once `send_raw_request` is usable. Single-shot: once the loop ends the dispatcher stays closed and cannot be restarted. """ + self._on_notify_intercept = on_notify_intercept try: # LIFO exits: the write stream closes only after the task-group join, so teardown writes still land. async with self._write_stream: @@ -548,7 +577,7 @@ async def _dispatch_request( # TODO(maxisbey): duplicate ids blind-overwrite (v1/TS parity); revisit # rejecting with INVALID_REQUEST. Key coerced so a stringified # `notifications/cancelled` id still correlates. - self._in_flight[_coerce_id(req.id)] = _InFlight(scope=scope, dctx=dctx) + self._in_flight[coerce_request_id(req.id)] = _InFlight(scope=scope, dctx=dctx) if req.method in self._inline_methods: # Spawn so `sender_ctx` applies, but park the read loop until the # handler returns - that's the inline ordering guarantee. @@ -576,25 +605,22 @@ def _dispatch_notification( `notifications/cancelled` and `notifications/progress` are intercepted here (they correlate against the `_in_flight`/`_pending` tables this - layer owns) and still teed to `on_notify` afterwards. + layer owns) and still teed to `on_notify` afterwards. The caller's + `on_notify_intercept` then runs in receive order; only unconsumed + notifications reach the spawned `on_notify`. """ if msg.method == "notifications/cancelled": - match msg.params: - # bool subclasses int: the guards keep True from aliasing request id 1. - case {"requestId": str() | int() as rid} if ( - not isinstance(rid, bool) and (in_flight := self._in_flight.get(_coerce_id(rid))) is not None - ): - in_flight.dctx.cancel_requested.set() - if self._peer_cancel_mode == "interrupt": - in_flight.scope.cancel() - case _: - pass + rid = cancelled_request_id_from_params(msg.params) + if rid is not None and (in_flight := self._in_flight.get(coerce_request_id(rid))) is not None: + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() elif msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( not isinstance(token, bool) and not isinstance(progress, bool) - and (pending := self._pending.get(_coerce_id(token))) is not None + and (pending := self._pending.get(coerce_request_id(token))) is not None and pending.on_progress is not None ): total = msg.params.get("total") @@ -608,6 +634,8 @@ def _dispatch_notification( ) case _: pass + if run_notify_intercept(self._on_notify_intercept, msg.method, msg.params): + return try: transport_ctx = self._transport_builder(metadata) except Exception: @@ -620,7 +648,7 @@ def _dispatch_notification( self._spawn(_contained_notify(on_notify), dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: - pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None + pending = self._pending.get(coerce_request_id(request_id)) if request_id is not None else None if pending is None: logger.debug("dropping response for unknown/late request id %r", request_id) return @@ -680,7 +708,7 @@ async def _handle_request( # since handler return, so a peer cancel can't interleave. # Identity guard: don't evict a duplicate id's newer entry. dctx.close() - key = _coerce_id(req.id) + key = coerce_request_id(req.id) if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: del self._in_flight[key] # A write interrupted by cancellation may still have delivered diff --git a/src/mcp/shared/subscriptions.py b/src/mcp/shared/subscriptions.py new file mode 100644 index 000000000..ba50917fa --- /dev/null +++ b/src/mcp/shared/subscriptions.py @@ -0,0 +1,106 @@ +"""Typed event vocabulary for `subscriptions/listen` (2026-07-28, SEP-2575), shared by server and client. + +Every event is a level trigger ("this changed, refetch if you care"), so both sides bound buffers by dedupe. +""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +from mcp_types import ( + NotificationParams, + PromptListChangedNotification, + ResourceListChangedNotification, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + ServerNotification, + SubscriptionFilter, + ToolListChangedNotification, +) + +__all__ = [ + "SUBSCRIPTION_ID_META_KEY", + "PromptsListChanged", + "ResourceUpdated", + "ResourcesListChanged", + "ServerEvent", + "ToolsListChanged", + "event_from_wire", + "event_matches", + "event_to_notification", +] + +SUBSCRIPTION_ID_META_KEY = "io.modelcontextprotocol/subscriptionId" +"""The `_meta` key on every listen-stream frame; the value is the `subscriptions/listen` request's JSON-RPC id.""" + + +@dataclass(frozen=True) +class ToolsListChanged: + """The server's tool list changed.""" + + +@dataclass(frozen=True) +class PromptsListChanged: + """The server's prompt list changed.""" + + +@dataclass(frozen=True) +class ResourcesListChanged: + """The server's resource list changed.""" + + +@dataclass(frozen=True) +class ResourceUpdated: + """The resource at `uri` changed and may need to be read again.""" + + uri: str + + +ServerEvent = ToolsListChanged | PromptsListChanged | ResourcesListChanged | ResourceUpdated +"""An event a server publishes for delivery to listen subscribers.""" + + +def event_to_notification(event: ServerEvent, meta: dict[str, Any]) -> ServerNotification: + """Build the stamped wire notification for `event` (the server's direction).""" + if isinstance(event, ToolsListChanged): + return ToolListChangedNotification(params=NotificationParams(_meta=meta)) + if isinstance(event, PromptsListChanged): + return PromptListChangedNotification(params=NotificationParams(_meta=meta)) + if isinstance(event, ResourcesListChanged): + return ResourceListChangedNotification(params=NotificationParams(_meta=meta)) + return ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri=event.uri, _meta=meta)) + + +_LIST_CHANGED_EVENTS: dict[str, ServerEvent] = { + "notifications/tools/list_changed": ToolsListChanged(), + "notifications/prompts/list_changed": PromptsListChanged(), + "notifications/resources/list_changed": ResourcesListChanged(), +} + + +def event_from_wire(method: str, params: Mapping[str, Any] | None) -> ServerEvent | None: + """The event a raw listen-stream frame announces, or None if it carries none. + + Takes the raw wire dict: the client demultiplexes before the typed notification parse.""" + if (event := _LIST_CHANGED_EVENTS.get(method)) is not None: + return event + if method == "notifications/resources/updated": + uri = (params or {}).get("uri") + if isinstance(uri, str): + return ResourceUpdated(uri=uri) + return None + + +def event_matches(honored: SubscriptionFilter, uris: frozenset[str], event: ServerEvent) -> bool: + """Whether `event` is within the stream's honored filter (`uris`: the honored resource subscriptions as a set). + + The admission predicate both sides share: server delivery and client intake honor only what was acknowledged.""" + if isinstance(event, ToolsListChanged): + return honored.tools_list_changed is True + if isinstance(event, PromptsListChanged): + return honored.prompts_list_changed is True + if isinstance(event, ResourcesListChanged): + return honored.resources_list_changed is True + return event.uri in uris diff --git a/tests/client/test_client.py b/tests/client/test_client.py index f8c02c973..6c78503b9 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -35,7 +35,7 @@ from mcp_types.version import LATEST_HANDSHAKE_VERSION from pydantic import FileUrl -from mcp import MCPError +from mcp import MCPDeprecationWarning, MCPError from mcp.client._memory import InMemoryTransport from mcp.client._transport import TransportStreams from mcp.client.client import Client @@ -310,13 +310,15 @@ async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotif async def test_client_subscribe_resource(simple_server: Server): async with Client(simple_server, mode="legacy") as client: - result = await client.subscribe_resource("memory://test") + with pytest.warns(MCPDeprecationWarning, match="use Client.listen"): + result = await client.subscribe_resource("memory://test") # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) async def test_client_unsubscribe_resource(simple_server: Server): async with Client(simple_server, mode="legacy") as client: - result = await client.unsubscribe_resource("memory://test") + with pytest.warns(MCPDeprecationWarning, match="use Client.listen"): + result = await client.unsubscribe_resource("memory://test") # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) diff --git a/tests/client/test_send_request_mcp_name.py b/tests/client/test_send_request_mcp_name.py index 408810814..e22ec4015 100644 --- a/tests/client/test_send_request_mcp_name.py +++ b/tests/client/test_send_request_mcp_name.py @@ -22,7 +22,7 @@ from mcp_types.version import LATEST_HANDSHAKE_VERSION, LATEST_MODERN_VERSION from mcp.client.session import ClientSession -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest +from mcp.shared.dispatcher import CallOptions, OnNotify, OnNotifyIntercept, OnRequest from mcp.shared.inbound import MCP_NAME_HEADER, MCP_PROTOCOL_VERSION_HEADER, encode_header_value @@ -36,6 +36,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f76991f65..507c8f69e 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -39,11 +39,13 @@ from mcp.client import ClientRequestContext from mcp.client.client import Client from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.client.subscriptions import ToolsListChanged, listen from mcp.server import Server, ServerRequestContext from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair -from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest +from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnNotifyIntercept, OnRequest from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.subscriptions import SUBSCRIPTION_ID_META_KEY from mcp.shared.transport_context import TransportContext _SendToClient = anyio.streams.memory.MemoryObjectSendStream[SessionMessage | Exception] @@ -1330,43 +1332,44 @@ def test_adopt_raises_when_no_mutual_modern_version_is_supported() -> None: assert session.protocol_version is None -@pytest.mark.anyio -async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): - """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids - cancelling it — and leaves the option unset for every other method.""" +class _OptsRecordingDispatcher: + """Records `send_raw_request` opts and answers from a per-method script (default `{}`).""" - class RecordingDispatcher: - """Records `send_raw_request` opts and answers with canned results.""" + def __init__(self, answers: dict[str, dict[str, Any]] | None = None) -> None: + self.calls: list[tuple[str, CallOptions]] = [] + self._answers = answers or {} - def __init__(self) -> None: - self.calls: list[tuple[str, CallOptions]] = [] + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() - async def run( - self, - on_request: OnRequest, - on_notify: OnNotify, - *, - task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, - ) -> None: - task_status.started() - await anyio.sleep_forever() + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, opts or {})) + return self._answers.get(method, {}) - async def send_raw_request( - self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None - ) -> dict[str, Any]: - self.calls.append((method, opts or {})) - if method == "initialize": - return InitializeResult( - protocol_version=LATEST_HANDSHAKE_VERSION, - capabilities=ServerCapabilities(), - server_info=Implementation(name="mock-server", version="0.1.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True) - return {} + async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: + pass - async def notify(self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None) -> None: - pass - dispatcher = RecordingDispatcher() +@pytest.mark.anyio +async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): + """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids + cancelling it — and leaves the option unset for every other method.""" + init_answer = InitializeResult( + protocol_version=LATEST_HANDSHAKE_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + dispatcher = _OptsRecordingDispatcher({"initialize": init_answer}) with anyio.fail_after(5): async with ClientSession(dispatcher=dispatcher) as session: await session.initialize() @@ -1376,6 +1379,27 @@ async def notify(self, method: str, params: Mapping[str, Any] | None, opts: Call assert "cancel_on_abandon" not in opts_by_method["ping"] +@pytest.mark.anyio +async def test_modern_stamp_leaves_cancel_on_abandon_at_the_dispatcher_default(): + """Post-adopt modern requests leave `cancel_on_abandon` unset (the dispatcher default, + True): the courtesy frame is the abandon signal — the 2026 cancellation spelling on + stream transports, and the streamable-HTTP transport's cue to abort the request's own + POST. The negotiation methods still opt out on every path: `send_discover`'s explicit + opts, and the stamp's own carve-out for a `server/discover` sent through the generic + `send_request`.""" + dispatcher = _OptsRecordingDispatcher({"server/discover": _discover_result_dict()}) + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.discover() + await session.send_ping() + await session.send_request(types.DiscoverRequest(params=types.RequestParams()), types.DiscoverResult) + assert [method for method, _ in dispatcher.calls] == ["server/discover", "ping", "server/discover"] + negotiation_opts, ping_opts, stamped_negotiation_opts = (opts for _, opts in dispatcher.calls) + assert negotiation_opts.get("cancel_on_abandon") is False + assert "cancel_on_abandon" not in ping_opts + assert stamped_negotiation_opts.get("cancel_on_abandon") is False + + def test_constructor_rejects_streams_and_dispatcher_together(): client_side, _server_side = create_direct_dispatcher_pair() s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) @@ -1407,6 +1431,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: @@ -1465,6 +1490,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: @@ -1787,3 +1813,137 @@ async def handler(ctx: ServerRequestContext, params: types.ReadResourceRequestPa result = await client.session.read_resource("memory://r", allow_input_required=True) assert isinstance(result, types.InputRequiredResult) assert result.request_state == "resource-state" + + +@pytest.mark.anyio +async def test_a_late_ack_for_a_closed_driver_listen_reaches_message_handler(): + """Ack consumption is keyed on the live route registry alone: a stray ack for a + closed subscription's id surfaces through message_handler like any other unowned frame.""" + seen: list[object] = [] + follow_up = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + if len(seen) == 2: + follow_up.set() + + async with raw_client_session(message_handler=handler) as (session, to_client, _): + _set_negotiated_version(session, "2026-07-28") + session._register_listen_route("listen-99") # pyright: ignore[reportPrivateUsage] + session._unregister_listen_route("listen-99") # pyright: ignore[reportPrivateUsage] + await to_client.send( + SessionMessage( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/subscriptions/acknowledged", + params={ + "notifications": {"toolsListChanged": True}, + "_meta": {SUBSCRIPTION_ID_META_KEY: "listen-99"}, + }, + ) + ) + ) + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed", params={})) + ) + with anyio.fail_after(5): + await follow_up.wait() + assert [type(message).__name__ for message in seen] == [ + "SubscriptionsAcknowledgedNotification", + "ToolListChangedNotification", + ] + + +@pytest.mark.anyio +async def test_a_graceful_result_does_not_outrun_the_events_that_preceded_it(): + """[ack, event, result] written back-to-back: the event delivers and the wire ack's filter + survives a parked message_handler tee, because routes settle on the dispatcher's receive path in wire order.""" + + async def parked_handler(message: object) -> None: + await anyio.sleep_forever() + + events: list[object] = [] + honored: list[types.SubscriptionFilter] = [] + async with raw_client_session(message_handler=parked_handler) as (session, to_client, from_client): + _set_negotiated_version(session, "2026-07-28") + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def consume() -> None: + async with listen(session, tools_list_changed=True) as sub: # pragma: no branch + honored.append(sub.honored) + events.extend([event async for event in sub]) + + tg.start_soon(consume) + request = await from_client.receive() + assert isinstance(request.message, JSONRPCRequest) + meta = {SUBSCRIPTION_ID_META_KEY: request.message.id} + for message in ( + JSONRPCNotification( + jsonrpc="2.0", + method="notifications/subscriptions/acknowledged", + params={"notifications": {"toolsListChanged": True}, "_meta": meta}, + ), + JSONRPCNotification( + jsonrpc="2.0", method="notifications/tools/list_changed", params={"_meta": meta} + ), + JSONRPCResponse(jsonrpc="2.0", id=request.message.id, result={"_meta": meta}), + ): + await to_client.send(SessionMessage(message)) + assert honored == [types.SubscriptionFilter(tools_list_changed=True)] + assert events == [ToolsListChanged()] + + +def _intercept_only_session() -> ClientSession: + """A never-entered session whose intercept can be driven directly (it is synchronous).""" + dispatcher, _peer = create_direct_dispatcher_pair() + return ClientSession(dispatcher=dispatcher) + + +def test_intercept_settles_only_the_named_listen_route_on_cancelled(): + """SDK demux contract: a server-sent cancel settles exactly the listen route it names and is never consumed.""" + session = _intercept_only_session() + route = session._register_listen_route("listen-1") # pyright: ignore[reportPrivateUsage] + intercept = session._intercept_notification # pyright: ignore[reportPrivateUsage] + assert intercept("notifications/cancelled", {"requestId": "unrelated"}) is False + assert route.end is None + assert intercept("notifications/cancelled", {"requestId": "listen-1"}) is False + assert route.end == "lost" + + +def test_intercept_ignores_frames_without_a_route_or_with_broken_meta(): + """SDK demux contract: frames that correlate to no live route flow through to the normal notification path.""" + session = _intercept_only_session() + intercept = session._intercept_notification # pyright: ignore[reportPrivateUsage] + assert intercept("notifications/tools/list_changed", {"_meta": {SUBSCRIPTION_ID_META_KEY: "listen-1"}}) is False + route = session._register_listen_route("listen-1") # pyright: ignore[reportPrivateUsage] + route.set_acked(types.SubscriptionFilter(tools_list_changed=True)) + assert intercept("notifications/tools/list_changed", None) is False + # A non-mapping `_meta` is constructible on pre-2026 wires. + assert intercept("notifications/tools/list_changed", {"_meta": "oops"}) is False + assert intercept("notifications/tools/list_changed", {"_meta": {SUBSCRIPTION_ID_META_KEY: "other"}}) is False + # A non-string uri is not an event; surface validation owns it. + meta = {"_meta": {SUBSCRIPTION_ID_META_KEY: "listen-1"}} + assert intercept("notifications/resources/updated", {"uri": 7, **meta}) is False + assert route._pending == {} # pyright: ignore[reportPrivateUsage] + + +def test_intercept_consumes_acks_for_live_routes_and_leaves_malformed_ones(): + """SDK demux contract: a well-formed ack for a live route is consumed as driver state; malformed acks pass on.""" + session = _intercept_only_session() + route = session._register_listen_route("listen-1") # pyright: ignore[reportPrivateUsage] + intercept = session._intercept_notification # pyright: ignore[reportPrivateUsage] + meta = {"_meta": {SUBSCRIPTION_ID_META_KEY: "listen-1"}} + assert intercept("notifications/subscriptions/acknowledged", {"notifications": ["nope"], **meta}) is False + assert route.honored is None + # A missing `notifications` field must not be read as an (all-refusing) empty filter. + assert intercept("notifications/subscriptions/acknowledged", dict(meta)) is False + assert route.honored is None + assert ( + intercept("notifications/subscriptions/acknowledged", {"notifications": {"toolsListChanged": True}, **meta}) + is True + ) + assert route.honored == types.SubscriptionFilter(tools_list_changed=True) + # Events deliver but are never consumed - they still tee to message_handler. + assert intercept("notifications/tools/list_changed", meta) is False + assert list(route._pending) == [ToolsListChanged()] # pyright: ignore[reportPrivateUsage] diff --git a/tests/client/test_session_claims.py b/tests/client/test_session_claims.py index 21cf2fa69..94ebd7946 100644 --- a/tests/client/test_session_claims.py +++ b/tests/client/test_session_claims.py @@ -28,7 +28,7 @@ from mcp.client.extension import ClaimContext, ResultClaim, UnexpectedClaimedResult from mcp.client.session import ClientSession, _CallToolResultAdapter -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest +from mcp.shared.dispatcher import CallOptions, OnNotify, OnNotifyIntercept, OnRequest _TASKS_EXT = "com.example/tasks" _AD_ONLY_EXT = "com.example/flags" @@ -75,6 +75,7 @@ async def run( self, on_request: OnRequest, on_notify: OnNotify, + on_notify_intercept: OnNotifyIntercept | None = None, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, ) -> None: diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index 99ff6f03e..9579def0f 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -8,16 +8,44 @@ import base64 import json +from collections.abc import AsyncIterator, Callable, Mapping +from typing import Any import anyio import httpx import pytest from inline_snapshot import snapshot -from mcp_types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse +from mcp_types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, + CONNECTION_CLOSED, + METHOD_NOT_FOUND, + PROTOCOL_VERSION_META_KEY, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, +) +from mcp_types.version import LATEST_MODERN_VERSION +from starlette.types import Receive, Scope, Send -from mcp.client.streamable_http import streamable_http_client -from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER, encode_header_value -from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.client.streamable_http import ( + MAX_RECONNECTION_ATTEMPTS, + RequestContext, + StreamableHTTPTransport, + streamable_http_client, +) +from mcp.server import Server +from mcp.server._streamable_http_modern import handle_modern_request +from mcp.server.subscriptions import InMemorySubscriptionBus, ListenHandler, ServerEvent +from mcp.shared._context_streams import ContextSendStream, create_context_streams +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.inbound import MCP_METHOD_HEADER, MCP_PROTOCOL_VERSION_HEADER, encode_header_value +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from tests.interaction.transports import StreamingASGITransport +from tests.shared.test_dispatcher import Recorder, echo_handlers @pytest.mark.parametrize( @@ -154,3 +182,497 @@ def handler(request: httpx.Request) -> httpx.Response: assert MCP_PROTOCOL_VERSION_HEADER not in recorded[1].headers assert recorded[2].headers[MCP_PROTOCOL_VERSION_HEADER] == "2025-11-25" assert recorded[3].headers[MCP_PROTOCOL_VERSION_HEADER] == "2025-11-25" + + +class _ParkedSSEStream(httpx.AsyncByteStream): + """An SSE response body that emits one comment line, then parks until closed. + + `opened` fires once the transport is iterating the body (the POST is truly in + flight); `closed` fires when httpx tears the body down — the observable proof + that an abort, not a response, ended the stream. + """ + + def __init__(self) -> None: + self.opened = anyio.Event() + self.closed = anyio.Event() + self._release = anyio.Event() + + async def __aiter__(self) -> AsyncIterator[bytes]: + self.opened.set() + yield b": parked\n\n" + await self._release.wait() + + async def aclose(self) -> None: + self.closed.set() + self._release.set() + + +def _sse_or_ack_handler( + parked: _ParkedSSEStream, posted: list[dict[str, Any]], frame_posted: anyio.Event +) -> Callable[[httpx.Request], httpx.Response]: + """Requests get the parked SSE body; notifications get 202 and set `frame_posted`.""" + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + posted.append(body) + if "id" in body: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + frame_posted.set() + return httpx.Response(202) + + return handler + + +@pytest.mark.anyio +async def test_modern_cancelled_frame_aborts_the_matching_in_flight_post() -> None: + """At 2026 an outbound `notifications/cancelled` never POSTs — closing the named + request's response stream IS the wire's cancellation signal — so the transport + aborts the in-flight POST and swallows the frame.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + posted.append(json.loads(request.content)) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (_read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id="listen-1", method="subscriptions/listen", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage( + JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": "listen-1"} + ) + ) + ) + await parked.closed.wait() + assert [body["method"] for body in posted] == ["subscriptions/listen"] + + +@pytest.mark.anyio +@pytest.mark.parametrize("stamped_version", [None, "2025-11-25"], ids=["no-version-yet", "2025-11-25"]) +async def test_legacy_cancelled_frame_posts_and_leaves_the_stream_open(stamped_version: str | None) -> None: + """Below 2026 — or before any stamped POST has revealed the version — the frame is + the spec's cancellation signal: it POSTs, and the request's stream stays open + (a 2025 disconnect is explicitly not a cancel).""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + frame_posted = anyio.Event() + handler = _sse_or_ack_handler(parked, posted, frame_posted) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (_read, write), + ): + metadata = ( + ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: stamped_version}) + if stamped_version is not None + else None + ) + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}), + metadata=metadata, + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage( + JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1}) + ) + ) + await frame_posted.wait() + # Checked before teardown: exiting the transport cancels the parked POST. + assert not parked.closed.is_set() + assert [body["method"] for body in posted] == ["tools/call", "notifications/cancelled"] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "params", + [ + pytest.param({"requestId": 999}, id="unknown-id"), + pytest.param({"requestId": True}, id="bool-must-not-alias-request-id-1"), + pytest.param({"requestId": "1"}, id="string-1-must-not-match-int-1"), + pytest.param({}, id="no-request-id"), + pytest.param(None, id="no-params"), + ], +) +async def test_modern_cancelled_frames_matching_no_post_are_swallowed(params: dict[str, Any] | None) -> None: + """At 2026 the frame is swallowed even when it aborts nothing — the wire defines no + client-to-server notifications, so a late cancel racing the response must not leak + a POST — and a mismatched id must not abort someone else's stream.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + posted.append(body) + if body.get("id") == 1: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="subscriptions/listen", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params=params)) + ) + # A follow-up request completing proves the loop moved past the swallowed frame. + await write.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=2, method="ping", params={}))) + reply = await read.receive() + # Checked before teardown: exiting the transport cancels the parked POST. + assert not parked.closed.is_set() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + assert reply.message.id == 2 + assert [body["method"] for body in posted] == ["subscriptions/listen", "ping"] + + +@pytest.mark.anyio +async def test_handler_scoped_cancelled_frames_are_translated_at_modern_too() -> None: + """A cancel carrying `ServerMessageMetadata` (a handler abandoning its own + back-channel request) still names one of OUR outbound ids — every spec-legal + cancel names a request its sender issued — so at 2026 it aborts that POST and + stays off the wire like any other.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + frame_posted = anyio.Event() + handler = _sse_or_ack_handler(parked, posted, frame_posted) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (_read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + await parked.opened.wait() + await write.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ), + metadata=ServerMessageMetadata(related_request_id=99), + ) + ) + await parked.closed.wait() + assert [body["method"] for body in posted] == ["tools/call"] + assert not frame_posted.is_set() + + +@pytest.mark.anyio +async def test_cancel_for_a_request_sent_under_2025_still_posts_after_modern_adoption() -> None: + """The translation follows the era the NAMED request was sent under, not the + cache at cancel time: a request POSTed under 2025 keeps 2025 cancellation + semantics (frame on the wire, stream left open) even after a later message + flips the negotiated version to 2026.""" + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + frame_posted = anyio.Event() + + def handler(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + posted.append(body) + if body.get("id") == 1: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=parked) + if "id" in body: + return httpx.Response(200, json={"jsonrpc": "2.0", "id": body["id"], "result": {}}) + frame_posted.set() + return httpx.Response(202) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: "2025-11-25"}), + ) + ) + await parked.opened.wait() + # A modern-stamped request flips the cached negotiated version. + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=2, method="ping", params={}), + metadata=ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}), + ) + ) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + await write.send( + SessionMessage( + JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1}) + ) + ) + await frame_posted.wait() + # Checked before teardown: exiting the transport cancels the parked POST. + assert not parked.closed.is_set() + assert [body["method"] for body in posted] == ["tools/call", "ping", "notifications/cancelled"] + + +class _SignalingBus(InMemorySubscriptionBus): + """Signals subscribe/unsubscribe so a test observes the stream lifecycle through + the bus Protocol (the public seam) instead of polling handler internals.""" + + def __init__(self) -> None: + super().__init__() + self.subscribed = anyio.Event() + self.unsubscribed = anyio.Event() + + def subscribe(self, listener: Callable[[ServerEvent], None]) -> Callable[[], None]: + unsubscribe = super().subscribe(listener) + self.subscribed.set() + + def unsubscribe_and_signal() -> None: + unsubscribe() + self.unsubscribed.set() + + return unsubscribe_and_signal + + +@pytest.mark.anyio +async def test_scope_cancel_aborts_a_modern_listen_post_end_to_end() -> None: + """Over a real ASGI bridge: cancelling the caller of a parked `subscriptions/listen` + closes the POST's response stream — the server treats the disconnect as the cancel + and releases the subscription — and no `notifications/cancelled` crosses the wire.""" + bus = _SignalingBus() + server = Server("test", on_subscriptions_listen=ListenHandler(bus)) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + async with server.lifespan(server) as lifespan_state: + await handle_modern_request(server, None, False, lifespan_state, scope, receive, send) + + posted_methods: list[str] = [] + + async def record_request(request: httpx.Request) -> None: + posted_methods.append(json.loads(request.content)["method"]) + + acked = anyio.Event() + + async def on_notify(dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + assert method == "notifications/subscriptions/acknowledged" + acked.set() + + on_request, _ = echo_handlers(Recorder()) + + with anyio.fail_after(15): + async with ( + httpx.AsyncClient( + transport=StreamingASGITransport(app), + base_url="http://testserver", + event_hooks={"request": [record_request]}, + ) as http, + streamable_http_client("http://testserver/mcp", http_client=http) as (read, write), + ): + dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read, write) + async with anyio.create_task_group() as tg: # pragma: no branch + await tg.start(dispatcher.run, on_request, on_notify) + listen_scope = anyio.CancelScope() + + async def send_listen() -> None: + params: dict[str, Any] = { + "_meta": { + PROTOCOL_VERSION_META_KEY: LATEST_MODERN_VERSION, + CLIENT_INFO_META_KEY: {"name": "test-client", "version": "0"}, + CLIENT_CAPABILITIES_META_KEY: {}, + }, + "notifications": {"toolsListChanged": True}, + } + opts: CallOptions = { + "request_id": "listen-1", + "headers": { + MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION, + MCP_METHOD_HEADER: "subscriptions/listen", + }, + } + with listen_scope: + await dispatcher.send_raw_request("subscriptions/listen", params, opts) + + tg.start_soon(send_listen) + await acked.wait() + assert bus.subscribed.is_set() + assert not bus.unsubscribed.is_set() + listen_scope.cancel() + await bus.unsubscribed.wait() + tg.cancel_scope.cancel() + assert posted_methods == ["subscriptions/listen"] + + +class _CompletingSSEStream(httpx.AsyncByteStream): + """An SSE body that delivers one JSON-RPC response, then parks in `aclose`. + + Holding `aclose` keeps the finished POST task alive past its response, so a + test can re-register the same request id underneath it before releasing. + """ + + def __init__(self, response_body: dict[str, Any]) -> None: + self._event = f"data: {json.dumps(response_body)}\n\n".encode() + self.release = anyio.Event() + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._event + + async def aclose(self) -> None: + await self.release.wait() + + +@pytest.mark.anyio +async def test_a_finished_post_task_does_not_evict_a_reused_ids_new_registration() -> None: + """Request ids are reusable once resolved; a finished POST task unwinding late + must not pop the successor's registration, or a cancel for the reused id would + find nothing to abort and the live POST would leak past the cancellation.""" + completing = _CompletingSSEStream({"jsonrpc": "2.0", "id": "dup-1", "result": {}}) + parked = _ParkedSSEStream() + posted: list[dict[str, Any]] = [] + streams = [completing, parked] + + def handler(request: httpx.Request) -> httpx.Response: + posted.append(json.loads(request.content)) + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=streams.pop(0)) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + modern = ClientMessageMetadata(headers={MCP_PROTOCOL_VERSION_HEADER: LATEST_MODERN_VERSION}) + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id="dup-1", method="tools/call", params={}), + metadata=modern, + ) + ) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCResponse) + # The first task is now parked in `aclose`; reuse its id underneath it. + await write.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id="dup-1", method="subscriptions/listen", params={}), + metadata=modern, + ) + ) + await parked.opened.wait() + completing.release.set() + await anyio.wait_all_tasks_blocked() + # The successor's registration survived: a cancel still aborts it. + await write.send( + SessionMessage( + JSONRPCNotification(jsonrpc="2.0", method="notifications/cancelled", params={"requestId": "dup-1"}) + ) + ) + await parked.closed.wait() + assert [body["method"] for body in posted] == ["tools/call", "subscriptions/listen"] + + +class _DyingSSEStream(httpx.AsyncByteStream): + """Emits one id-less comment then breaks - a non-resumable stream dropping.""" + + def __init__(self) -> None: + self.opened = anyio.Event() + + async def __aiter__(self) -> AsyncIterator[bytes]: + self.opened.set() + yield b": hello\n\n" + raise httpx.ReadError("connection reset") + + async def aclose(self) -> None: + pass + + +@pytest.mark.anyio +async def test_a_non_resumable_sse_drop_resolves_the_request_with_an_error() -> None: + """A per-request SSE stream that dies having carried no event ids can never deliver its + response; the transport resolves the waiter with CONNECTION_CLOSED instead of hanging forever.""" + dying = _DyingSSEStream() + + def handler(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, headers={"content-type": "text/event-stream"}, stream=dying) + + with anyio.fail_after(5): + async with ( + httpx.AsyncClient(transport=httpx.MockTransport(handler)) as http, + streamable_http_client("http://test/mcp", http_client=http) as (read, write), + ): + await write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id="listen-1", method="subscriptions/listen", params={})) + ) + reply = await read.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCError) + assert reply.message.id == "listen-1" + assert reply.message.error.code == CONNECTION_CLOSED + + +def _abandoned_request_context( + http: httpx.AsyncClient, send: ContextSendStream[SessionMessage | Exception] +) -> RequestContext: + return RequestContext( + client=http, + session_id=None, + session_message=SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id="listen-1", method="subscriptions/listen", params={}) + ), + metadata=None, + read_stream_writer=send, + ) + + +@pytest.mark.anyio +async def test_exhausted_reconnection_attempts_resolve_the_request_with_an_error() -> None: + """An id-bearing stream that exhausts its reconnection budget also resolves the waiter with CONNECTION_CLOSED.""" + transport = StreamableHTTPTransport("http://test/mcp") + send, receive = create_context_streams[SessionMessage | Exception](1) + async with httpx.AsyncClient() as http: + with anyio.fail_after(5): + await transport._handle_reconnection( # pyright: ignore[reportPrivateUsage] + _abandoned_request_context(http, send), "evt-7", None, MAX_RECONNECTION_ATTEMPTS + ) + reply = await receive.receive() + assert isinstance(reply, SessionMessage) + assert isinstance(reply.message, JSONRPCError) + assert reply.message.id == "listen-1" + assert reply.message.error.code == CONNECTION_CLOSED + send.close() + receive.close() + + +@pytest.mark.anyio +async def test_resolving_an_abandoned_request_after_the_reader_closed_is_contained() -> None: + """Teardown race: a stream dying after the reader closed resolves best-effort and must not crash.""" + transport = StreamableHTTPTransport("http://test/mcp") + send, receive = create_context_streams[SessionMessage | Exception](1) + receive.close() + async with httpx.AsyncClient() as http: + with anyio.fail_after(5): + await transport._handle_reconnection( # pyright: ignore[reportPrivateUsage] + _abandoned_request_context(http, send), "evt-7", None, MAX_RECONNECTION_ATTEMPTS + ) + send.close() diff --git a/tests/client/test_subscriptions.py b/tests/client/test_subscriptions.py new file mode 100644 index 000000000..0cc4f133e --- /dev/null +++ b/tests/client/test_subscriptions.py @@ -0,0 +1,666 @@ +"""Behavioral tests for the client-side `subscriptions/listen` driver (SDK-defined contract). + +Public API only, against in-process servers; wire-shape assertions live in the interaction suite. +""" + +from itertools import count +from typing import Any + +import anyio +import mcp_types as types +import pytest +from mcp_types import SubscriptionFilter + +import mcp.client.subscriptions as subscriptions_module +from mcp import Client, MCPError +from mcp.client.session import ClientSession +from mcp.client.subscriptions import ( + ListenNotSupportedError, + ListenRoute, + PromptsListChanged, + ResourcesListChanged, + ResourceUpdated, + ServerEvent, + Subscription, + SubscriptionLost, + ToolsListChanged, + listen, +) +from mcp.server import Server, ServerRequestContext +from mcp.server.subscriptions import ( + SUBSCRIPTION_ID_META_KEY, + InMemorySubscriptionBus, + ListenHandler, +) +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import CallOptions + +pytestmark = pytest.mark.anyio + + +def _bus_server(bus: InMemorySubscriptionBus, *, max_subscriptions: int | None = None) -> Server[Any]: + """A lowlevel server whose only feature is serving listen streams from `bus`.""" + handler = ( + ListenHandler(bus) if max_subscriptions is None else ListenHandler(bus, max_subscriptions=max_subscriptions) + ) + return Server("subs", on_subscriptions_listen=handler) + + +async def _ack(ctx: ServerRequestContext[Any, Any], honored: SubscriptionFilter) -> dict[str, Any]: + """Send a hand-rolled ack for a scripted listen handler; returns the stamped meta.""" + assert ctx.request_id is not None + meta: dict[str, Any] = {SUBSCRIPTION_ID_META_KEY: ctx.request_id} + await ctx.session.send_notification( + types.SubscriptionsAcknowledgedNotification( + params=types.SubscriptionsAcknowledgedNotificationParams(notifications=honored, _meta=meta) + ), + related_request_id=ctx.request_id, + ) + return meta + + +async def test_listen_surfaces_the_honored_filter_and_subscription_id(): + """Entering waits for the server ack and surfaces the honored filter and subscription id.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with client.listen( # pragma: no branch + tools_list_changed=True, resource_subscriptions=["note://todo"] + ) as sub: + assert isinstance(sub, Subscription) + assert sub.honored.tools_list_changed is True + assert sub.honored.resource_subscriptions == ["note://todo"] + assert isinstance(sub.subscription_id, str) + assert sub.subscription_id.startswith("listen-") + + +async def test_listen_delivers_all_four_typed_event_kinds(): + """Bus publishes come back as the same typed event values, in order.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with client.listen( # pragma: no branch + tools_list_changed=True, + prompts_list_changed=True, + resources_list_changed=True, + resource_subscriptions=["note://todo"], + ) as sub: + for event in ( + ToolsListChanged(), + PromptsListChanged(), + ResourcesListChanged(), + ResourceUpdated(uri="note://todo"), + ): + await bus.publish(event) + assert await anext(sub) == event + + +async def test_unconsumed_duplicate_events_coalesce(): + """Events are level triggers: duplicates pending consumption collapse to one.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with client.listen( # pragma: no branch + tools_list_changed=True, resource_subscriptions=["note://todo"] + ) as sub: + for _ in range(3): + await bus.publish(ToolsListChanged()) + await bus.publish(ResourceUpdated(uri="note://todo")) + await anyio.wait_all_tasks_blocked() + assert await anext(sub) == ToolsListChanged() + assert await anext(sub) == ResourceUpdated(uri="note://todo") + + +async def test_graceful_server_close_ends_the_loop_cleanly(): + """The server's deliberate close ends iteration cleanly, after draining prior events.""" + bus = InMemorySubscriptionBus() + handler = ListenHandler(bus) + server = Server("subs", on_subscriptions_listen=handler) + events: list[object] = [] + async with Client(server) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + await bus.publish(ToolsListChanged()) + handler.close() + events.extend([event async for event in sub]) + assert events == [ToolsListChanged()] + + +async def test_abrupt_stream_end_raises_subscription_lost(): + """A stream dying without the graceful result raises `SubscriptionLost` with the cause chained.""" + proceed = anyio.Event() + + async def dropping_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + await _ack(ctx, params.notifications) + await proceed.wait() + raise MCPError(types.INTERNAL_ERROR, "stream torn down") + + server = Server("subs", on_subscriptions_listen=dropping_listen) + async with Client(server) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + proceed.set() + with pytest.raises(SubscriptionLost) as exc_info: # pragma: no branch + await anext(sub) + assert isinstance(exc_info.value.__cause__, MCPError) + assert exc_info.value.__cause__.error.message == "stream torn down" + + +async def test_listen_on_a_legacy_connection_raises_the_typed_steer(): + """On a 2025 connection `listen` fails fast with the typed error steering to the legacy verbs.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus), mode="legacy") as client: + with anyio.fail_after(5): + # Entering is where the guard fires; __aenter__ directly avoids an unreachable with-body. + with pytest.raises(ListenNotSupportedError) as exc_info: # pragma: no branch + await client.listen(tools_list_changed=True).__aenter__() + assert exc_info.value.negotiated_version == "2025-11-25" + assert "subscribe_resource" in str(exc_info.value) + + +async def test_server_rejection_raises_from_enter_not_from_iteration(): + """A server without the listen handler fails the open from entering the context.""" + server = Server("no-listen") + async with Client(server) as client: + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: # pragma: no branch + await client.listen(tools_list_changed=True).__aenter__() + assert exc_info.value.error.code == types.METHOD_NOT_FOUND + + +async def test_immediate_result_without_ack_opens_already_closed(): + """A bare result with no ack yields a subscription already gracefully over: no filter, no events.""" + + async def degenerate_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + return types.SubscriptionsListenResult(_meta={SUBSCRIPTION_ID_META_KEY: ctx.request_id}) + + server = Server("subs", on_subscriptions_listen=degenerate_listen) + async with Client(server) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + assert sub.honored == SubscriptionFilter() + with pytest.raises(StopAsyncIteration): # pragma: no branch + await anext(sub) + + +async def test_server_sent_cancelled_for_the_listen_id_raises_subscription_lost(): + """Server-sent notifications/cancelled for the listen id surfaces as a lost subscription.""" + proceed = anyio.Event() + + async def cancelling_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + await _ack(ctx, params.notifications) + await proceed.wait() + await ctx.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=ctx.request_id)), + related_request_id=ctx.request_id, + ) + await anyio.sleep_forever() + raise AssertionError("unreachable") # pragma: no cover + + server = Server("subs", on_subscriptions_listen=cancelling_listen) + async with Client(server) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + proceed.set() + with pytest.raises(SubscriptionLost): # pragma: no branch + await anext(sub) + + +async def test_exiting_the_context_frees_the_server_slot(): + """Leaving the block ends the subscription server-side: a one-slot handler admits a second listen.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus, max_subscriptions=1)) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as first: + assert first.honored.tools_list_changed is True + async with client.listen(tools_list_changed=True) as second: # pragma: no branch + assert second.honored.tools_list_changed is True + assert second.subscription_id != first.subscription_id + + +async def test_concurrent_subscriptions_demux_independently(): + """Two open subscriptions each receive only their own filter's events.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with ( # pragma: no branch + client.listen(tools_list_changed=True) as tools_sub, + client.listen(resource_subscriptions=["note://todo"]) as notes_sub, + ): + await bus.publish(ToolsListChanged()) + await bus.publish(ResourceUpdated(uri="note://todo")) + assert await anext(tools_sub) == ToolsListChanged() + assert await anext(notes_sub) == ResourceUpdated(uri="note://todo") + # Neither stream received the other's event. + await bus.publish(ToolsListChanged()) + assert await anext(tools_sub) == ToolsListChanged() + + +async def test_change_notifications_still_reach_message_handler(): + """The demux tees: a delivered event's notification still reaches message_handler; the ack never does.""" + bus = InMemorySubscriptionBus() + seen: list[str] = [] + + async def on_message(message: object) -> None: + assert not isinstance(message, types.SubscriptionsAcknowledgedNotification) + if isinstance(message, types.ToolListChangedNotification): # pragma: no branch + seen.append("tools-changed") + + async with Client(_bus_server(bus), message_handler=on_message) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + await bus.publish(ToolsListChanged()) + assert await anext(sub) == ToolsListChanged() + await anyio.wait_all_tasks_blocked() + assert seen == ["tools-changed"] + + +async def test_enter_times_out_when_the_ack_never_arrives(): + """The ack wait rides the session's read timeout, so a wedged server cannot hang the open.""" + + async def silent_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + await anyio.sleep_forever() + raise AssertionError("unreachable") # pragma: no cover + + server = Server("subs", on_subscriptions_listen=silent_listen) + async with Client(server, read_timeout_seconds=0.05) as client: + with anyio.fail_after(5): + with pytest.raises(TimeoutError): # pragma: no branch + await client.listen(tools_list_changed=True).__aenter__() + + +async def test_an_open_stream_outlives_the_session_read_timeout(): + """The listen request is exempt from the read timeout: the stream delivers after the deadline.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus), read_timeout_seconds=0.05) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + # Real clock on purpose: this pins a timeout feature. + await anyio.sleep(0.2) + await bus.publish(ToolsListChanged()) + assert await anext(sub) == ToolsListChanged() + + +async def test_a_duplicate_ack_does_not_overwrite_the_honored_filter(): + """The first ack wins; a later conflicting ack is a no-op.""" + proceed = anyio.Event() + + async def double_acking_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + await _ack(ctx, params.notifications) + await _ack(ctx, SubscriptionFilter()) + await proceed.wait() + return types.SubscriptionsListenResult(_meta={SUBSCRIPTION_ID_META_KEY: ctx.request_id}) + + server = Server("subs", on_subscriptions_listen=double_acking_listen) + async with Client(server) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + assert sub.honored.tools_list_changed is True + proceed.set() + + +async def test_a_non_event_frame_with_the_subscription_id_is_teed_not_delivered(): + """A stamped non-event notification never surfaces as an event; it flows to message_handler.""" + proceed = anyio.Event() + + async def logging_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + meta = await _ack(ctx, params.notifications) + await ctx.session.send_notification( + types.LoggingMessageNotification( + params=types.LoggingMessageNotificationParams(level="info", data="not an event", _meta=meta) + ), + related_request_id=ctx.request_id, + ) + await proceed.wait() + return types.SubscriptionsListenResult(_meta=meta) + + logged: list[str] = [] + + async def on_message(message: object) -> None: + if isinstance(message, types.LoggingMessageNotification): # pragma: no branch + logged.append(str(message.params.data)) + + server = Server("subs", on_subscriptions_listen=logging_listen) + async with Client(server, message_handler=on_message) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + await anyio.wait_all_tasks_blocked() + proceed.set() + with pytest.raises(StopAsyncIteration): # pragma: no branch + await anext(sub) + assert logged == ["not an event"] + + +async def test_session_teardown_unblocks_a_sibling_consumer_with_subscription_lost(): + """Session teardown settles every open route as lost, unblocking parked consumers.""" + bus = InMemorySubscriptionBus() + outcome: list[str] = [] + entered = anyio.Event() + + async def consume(client: Client) -> None: + with pytest.raises(SubscriptionLost): + async with client.listen(tools_list_changed=True) as sub: + entered.set() + await anext(sub) + outcome.append("lost") + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + async with Client(_bus_server(bus)) as client: # pragma: no branch + tg.start_soon(consume, client) + await entered.wait() + assert outcome == ["lost"] + + +async def test_server_cancel_before_the_ack_raises_subscription_lost_from_enter(): + """A stream torn down before it was ever acknowledged is a failed open: enter raises.""" + + async def cancel_first_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + await ctx.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=ctx.request_id)), + related_request_id=ctx.request_id, + ) + await anyio.sleep_forever() + raise AssertionError("unreachable") # pragma: no cover + + server = Server("subs", on_subscriptions_listen=cancel_first_listen) + async with Client(server) as client: + with anyio.fail_after(5): + with pytest.raises(SubscriptionLost, match="before it was acknowledged"): # pragma: no branch + await client.listen(tools_list_changed=True).__aenter__() + + +async def test_listen_on_an_exited_session_raises_and_leaks_no_route(): + """Opening on an exited session fails loudly and leaves no demux registration behind.""" + bus = InMemorySubscriptionBus() + client = Client(_bus_server(bus)) + async with client: + session = client.session + with pytest.raises(RuntimeError): + await listen(session, tools_list_changed=True).__aenter__() + assert session._listen_routes == {} # pyright: ignore[reportPrivateUsage] + + +async def test_listen_on_a_never_entered_session_raises_runtime_error(): + """An adopted-but-never-entered session has no task group to drive the stream.""" + dispatcher, _peer = create_direct_dispatcher_pair() + session = ClientSession(dispatcher=dispatcher) + session.adopt( + types.DiscoverResult( + supported_versions=["2026-07-28"], + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="stub", version="0"), + ) + ) + with pytest.raises(RuntimeError, match="entered session"): + await listen(session, tools_list_changed=True).__aenter__() + assert session._listen_routes == {} # pyright: ignore[reportPrivateUsage] + + +async def test_a_retained_handle_after_exit_does_not_serve_stale_events(): + """Leaving the block abandons the backlog: a stashed handle must not replay buffered events.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: + await bus.publish(ToolsListChanged()) + await anyio.wait_all_tasks_blocked() + with pytest.raises(StopAsyncIteration): # pragma: no branch + await anext(sub) + + +async def test_a_stray_ack_outside_the_driver_namespace_still_reaches_message_handler(): + """Acks for ids the driver never minted flow to message_handler (the raw-listen escape hatch).""" + proceed = anyio.Event() + + async def stray_acking_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + await _ack(ctx, params.notifications) + await ctx.session.send_notification( + types.SubscriptionsAcknowledgedNotification( + params=types.SubscriptionsAcknowledgedNotificationParams( + notifications=SubscriptionFilter(), _meta={SUBSCRIPTION_ID_META_KEY: 424242} + ) + ), + related_request_id=ctx.request_id, + ) + await proceed.wait() + return types.SubscriptionsListenResult(_meta={SUBSCRIPTION_ID_META_KEY: ctx.request_id}) + + handled: list[str] = [] + + async def on_message(message: object) -> None: + handled.append(type(message).__name__) + + server = Server("subs", on_subscriptions_listen=stray_acking_listen) + async with Client(server, message_handler=on_message) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + await anyio.wait_all_tasks_blocked() + proceed.set() + with pytest.raises(StopAsyncIteration): # pragma: no branch + await anext(sub) + assert "SubscriptionsAcknowledgedNotification" in handled + + +async def test_a_bare_string_for_resource_subscriptions_is_rejected(): + """A bare string would explode into per-character URIs; it is rejected before touching the wire.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with pytest.raises(TypeError, match="sequence of URIs"): + await client.listen(resource_subscriptions="note://todo").__aenter__() # pyright: ignore[reportArgumentType] + + +def test_the_route_admits_only_honored_events_and_only_while_live(): + """Route admission: nothing before the ack, only honored events while live, nothing after the end.""" + route = ListenRoute() + route.deliver(ToolsListChanged()) + assert route._pending == {} # pyright: ignore[reportPrivateUsage] + route.set_acked(SubscriptionFilter(tools_list_changed=True, resource_subscriptions=["note://todo"])) + route.deliver(PromptsListChanged()) # kind not honored + route.deliver(ResourceUpdated(uri="note://todo/draft")) # sub-resource of a subscribed URI: spec says admit + route.deliver(ResourceUpdated(uri="note://todo")) + route.deliver(ToolsListChanged()) + route.deliver(ToolsListChanged()) # duplicate pending consumption collapses + assert list(route._pending) == [ # pyright: ignore[reportPrivateUsage] + ResourceUpdated(uri="note://todo/draft"), + ResourceUpdated(uri="note://todo"), + ToolsListChanged(), + ] + route.settle("graceful") + route.deliver(ResourceUpdated(uri="note://todo")) # post-close noise is refused + assert len(route._pending) == 3 # pyright: ignore[reportPrivateUsage] + + +def test_a_peer_flooding_distinct_uris_costs_the_subscription_not_client_memory(): + """A peer flooding distinct URIs trips the `_MAX_PENDING_EVENTS` backstop: the route + settles lost instead of growing client memory without bound.""" + route = ListenRoute() + route.set_acked(SubscriptionFilter(resource_subscriptions=["note://todo"])) + for n in range(subscriptions_module._MAX_PENDING_EVENTS): # pyright: ignore[reportPrivateUsage] + route.deliver(ResourceUpdated(uri=f"note://todo/{n}")) + assert route.end is None + route.deliver(ResourceUpdated(uri="note://todo/one-too-many")) + assert route.end == "lost" + assert route.error is not None + assert "backlog" in route.error.error.message + # The overflowing event was not queued. + assert len(route._pending) == subscriptions_module._MAX_PENDING_EVENTS # pyright: ignore[reportPrivateUsage] + + +async def test_a_cancelled_on_event_barrier_does_not_lose_the_event(): + """Cancelling `anext` mid-barrier leaves the event queued; the next `anext` re-runs the + idempotent barrier and returns it.""" + bus = InMemorySubscriptionBus() + entered = anyio.Event() + release = anyio.Event() + + async def parked_barrier(event: ServerEvent) -> None: + entered.set() + await release.wait() + + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with listen( + client.session, tools_list_changed=True, on_event=parked_barrier + ) as sub: # pragma: no branch + await bus.publish(ToolsListChanged()) + async with anyio.create_task_group() as tg: + cancel_scope = anyio.CancelScope() + + async def first_attempt() -> None: + with cancel_scope: + await anext(sub) + raise AssertionError("must be cancelled mid-barrier") # pragma: no cover + + tg.start_soon(first_attempt) + await entered.wait() + cancel_scope.cancel() + release.set() + assert await anext(sub) == ToolsListChanged() + + +async def test_events_outside_the_honored_filter_are_never_delivered(): + """A server violating its acknowledged filter cannot reach the consumer or grow the backlog.""" + proceed = anyio.Event() + + async def overreaching_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + meta = await _ack(ctx, params.notifications) # honors exactly what was requested: tools only + await ctx.session.send_notification( + types.ResourceUpdatedNotification( + params=types.ResourceUpdatedNotificationParams(uri="note://uninvited", _meta=meta) + ), + related_request_id=ctx.request_id, + ) + await ctx.session.send_notification( + types.ToolListChangedNotification(params=types.NotificationParams(_meta=meta)), + related_request_id=ctx.request_id, + ) + await proceed.wait() + return types.SubscriptionsListenResult(_meta=meta) + + server = Server("subs", on_subscriptions_listen=overreaching_listen) + async with Client(server) as client: + with anyio.fail_after(5): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + assert await anext(sub) == ToolsListChanged() + proceed.set() + with pytest.raises(StopAsyncIteration): # pragma: no branch + await anext(sub) + + +async def test_the_on_event_barrier_completes_before_each_event_is_returned(): + """`on_event` is awaited before the iterator returns each event (the Client wires cache eviction here).""" + bus = InMemorySubscriptionBus() + order: list[str] = [] + + async def barrier(event: ServerEvent) -> None: + order.append(f"barrier:{type(event).__name__}") + + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with listen(client.session, tools_list_changed=True, on_event=barrier) as sub: # pragma: no branch + await bus.publish(ToolsListChanged()) + event = await anext(sub) + order.append(f"returned:{type(event).__name__}") + assert order == ["barrier:ToolsListChanged", "returned:ToolsListChanged"] + + +async def test_client_listen_installs_the_cache_eviction_barrier_exactly_when_a_cache_exists(): + """`Client.listen` wires the response-cache evictor as the barrier only when a cache exists.""" + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as cached_client: + with anyio.fail_after(5): + async with cached_client.listen(tools_list_changed=True) as sub: # pragma: no branch + assert sub._on_event == cached_client._evict_for_listen_event # pyright: ignore[reportPrivateUsage] + async with Client(_bus_server(bus), cache=False) as uncached_client: + with anyio.fail_after(5): + async with uncached_client.listen(tools_list_changed=True) as sub: # pragma: no branch + assert sub._on_event is None # pyright: ignore[reportPrivateUsage] + + +async def test_the_cache_eviction_barrier_maps_events_and_contains_store_faults( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The barrier evicts through the same notification mapping as the message_handler wrapper; + a raising store costs a log line, not the delivery.""" + client = Client(_bus_server(InMemorySubscriptionBus())) + cache = client._response_cache # pyright: ignore[reportPrivateUsage] + assert cache is not None + evicted: list[types.ServerNotification] = [] + + async def record(notification: types.ServerNotification) -> None: + evicted.append(notification) + + monkeypatch.setattr(cache, "evict_for_notification", record) + await client._evict_for_listen_event(ResourceUpdated(uri="note://x")) # pyright: ignore[reportPrivateUsage] + assert isinstance(evicted[0], types.ResourceUpdatedNotification) + assert evicted[0].params.uri == "note://x" + + async def broken(notification: types.ServerNotification) -> None: + raise RuntimeError("store down") + + monkeypatch.setattr(cache, "evict_for_notification", broken) + # Contained: a cache fault must not block delivery. + await client._evict_for_listen_event(ToolsListChanged()) # pyright: ignore[reportPrivateUsage] + + +async def test_a_raw_request_id_collision_fails_the_subscription_not_the_session( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A raw caller occupying the driver's next minted id fails that one listen from enter; + the session survives and the next listen opens normally.""" + monkeypatch.setattr(subscriptions_module, "_listen_ids", count(7000)) + bus = InMemorySubscriptionBus() + async with Client(_bus_server(bus)) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + raw_scope = anyio.CancelScope() + + async def raw_listen() -> None: + request = types.SubscriptionsListenRequest( + params=types.SubscriptionsListenRequestParams( + notifications=SubscriptionFilter(tools_list_changed=True) + ) + ) + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + opts: CallOptions = {"request_id": "listen-7000"} + client.session._stamp(data, opts) # pyright: ignore[reportPrivateUsage] + with raw_scope: + await client.session._dispatcher.send_raw_request( # pyright: ignore[reportPrivateUsage] + data["method"], data.get("params"), opts + ) + + tg.start_soon(raw_listen) + await anyio.wait_all_tasks_blocked() + with pytest.raises(MCPError) as exc_info: + await client.listen(tools_list_changed=True).__aenter__() + assert "already in flight" in exc_info.value.error.message + # The failed open released the colliding id's demux registration. + assert client.session._listen_routes == {} # pyright: ignore[reportPrivateUsage] + raw_scope.cancel() + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + assert sub.subscription_id == "listen-7001" diff --git a/tests/docs_src/test_client.py b/tests/docs_src/test_client.py index af5e69249..3d70371f5 100644 --- a/tests/docs_src/test_client.py +++ b/tests/docs_src/test_client.py @@ -5,7 +5,7 @@ from mcp_types import Prompt, PromptArgument, PromptReference, TextContent, TextResourceContents, Tool from docs_src.client import tutorial001, tutorial002, tutorial003, tutorial004, tutorial005, tutorial006, tutorial007 -from mcp import Client, MCPError +from mcp import Client, MCPDeprecationWarning, MCPError from mcp.shared.metadata_utils import get_display_name # See test_index.py for why this is a per-module mark and not a conftest hook. @@ -128,7 +128,9 @@ async def test_resource_subscriptions_are_listen_based_on_the_modern_wire() -> N assert client.server_capabilities.resources is not None assert client.server_capabilities.resources.subscribe is True with pytest.raises(MCPError) as exc_info: - await client.subscribe_resource("catalog://genres") + # The verb is itself deprecated; the modern wire also rejects it. + with pytest.warns(MCPDeprecationWarning, match="use Client.listen"): + await client.subscribe_resource("catalog://genres") # pyright: ignore[reportDeprecated] assert exc_info.value.error.code == -32601 assert exc_info.value.error.message == "Method not found" diff --git a/tests/docs_src/test_subscriptions.py b/tests/docs_src/test_subscriptions.py index b664afe98..b74618dcf 100644 --- a/tests/docs_src/test_subscriptions.py +++ b/tests/docs_src/test_subscriptions.py @@ -5,10 +5,19 @@ import anyio import mcp_types as types import pytest +from trio.testing import MockClock -from docs_src.subscriptions import tutorial001, tutorial002 +from docs_src.subscriptions import tutorial001, tutorial002, tutorial003, tutorial004 from mcp import Client -from mcp.server.subscriptions import SUBSCRIPTION_ID_META_KEY, ToolsListChanged +from mcp.server.context import ServerRequestContext +from mcp.server.lowlevel import Server +from mcp.server.subscriptions import ( + SUBSCRIPTION_ID_META_KEY, + InMemorySubscriptionBus, + ListenHandler, + ResourceUpdated, + ToolsListChanged, +) # See test_index.py for why this is a per-module mark and not a conftest hook. pytestmark = [pytest.mark.anyio, pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning")] @@ -136,3 +145,131 @@ async def listen() -> None: assert isinstance(stream.received[2], types.ResourceUpdatedNotification) tg.cancel_scope.cancel() + + +async def test_client_listen_delivers_one_typed_event_then_closes() -> None: + """tutorial003: `Client.listen` yields typed events for the subscribed URI; leaving the block closes the stream.""" + results: list[str] = [] + + async def watch() -> None: + results.append(await tutorial003.watch_todo()) + + with anyio.fail_after(10): + async with anyio.create_task_group() as tg: + tg.start_soon(watch) + # Let the watcher park on its stream (ack complete) before the edit is published. + await anyio.wait_all_tasks_blocked() + async with Client(tutorial001.mcp) as editor: # pragma: no branch + await editor.call_tool("edit_note", {"name": "todo", "text": "water plants"}) + assert results == ["changed: note://todo"] + + +class _Reads: + """Counts server-side resource reads and lets tests await a count.""" + + def __init__(self) -> None: + self.count = 0 + self._bump = anyio.Event() + + def hit(self) -> None: + self.count += 1 + self._bump.set() + self._bump = anyio.Event() + + async def wait_for(self, count: int) -> None: + with anyio.fail_after(5): + while self.count < count: + await self._bump.wait() + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +async def test_watcher_re_listens_after_both_endings() -> None: + """tutorial004: watch() refetches on entry and per event, and re-listens after + a graceful server close and after `SubscriptionLost`. + + Runs on trio's autojumping MockClock so the loop's backoff sleep takes no wall-clock time. + + Steps: + 1. Stream 1: the entry refetch proves the ack arrived; a publish drives an event refetch. + 2. handler.close() ends stream 1 gracefully; the watcher backs off, re-listens (stream 2, + a new subscription id), and refetches. + 3. The drop tool cancels stream 2 abruptly; the watcher swallows SubscriptionLost, + re-listens (stream 3), and refetches on the next publish.""" + DROP_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": {"subscription_id": {"type": "string"}}, + "required": ["subscription_id"], + } + bus = InMemorySubscriptionBus() + handler = ListenHandler(bus) + reads = _Reads() + stream = _Stream() + + async def read_resource( + ctx: ServerRequestContext[Any], params: types.ReadResourceRequestParams + ) -> types.ReadResourceResult: + reads.hit() + return types.ReadResourceResult(contents=[types.TextResourceContents(uri=params.uri, text="fresh")]) + + async def list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="drop", description="End a subscription abruptly.", input_schema=DROP_SCHEMA)] + ) + + async def drop_stream(ctx: ServerRequestContext[Any], params: types.CallToolRequestParams) -> types.CallToolResult: + # The abrupt ending: the server cancels the named subscription without a + # graceful close. Sent request-scoped: the 2026 wire has no standalone stream. + subscription_id = (params.arguments or {})["subscription_id"] + await ctx.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=subscription_id)), + related_request_id=ctx.request_id, + ) + return types.CallToolResult(content=[]) + + server = Server( + "watched", + on_read_resource=read_resource, + on_list_tools=list_tools, + on_call_tool=drop_stream, + on_subscriptions_listen=handler, + ) + + def teed_subscription_id(index: int) -> Any: + updated = stream.received[index] + assert isinstance(updated, types.ResourceUpdatedNotification) + assert updated.params.meta is not None + return updated.params.meta[SUBSCRIPTION_ID_META_KEY] + + async with Client(server, mode="2026-07-28", message_handler=stream.handler) as client: + async with anyio.create_task_group() as tg: + tg.start_soon(tutorial004.watch, client, "note://todo") + + # Stream 1: the entry refetch proves the ack arrived; an event drives one more refetch. + await reads.wait_for(1) + await bus.publish(ResourceUpdated(uri="note://todo")) + await reads.wait_for(2) + await stream.wait_for(1) + + # Graceful close: the watcher backs off, re-listens, and refetches. + handler.close() + await reads.wait_for(3) + await bus.publish(ResourceUpdated(uri="note://todo")) + await reads.wait_for(4) + await stream.wait_for(2) + second_id = teed_subscription_id(1) + assert second_id != teed_subscription_id(0) + + # Abrupt ending: the watcher swallows SubscriptionLost and re-listens again. + await client.call_tool("drop", {"subscription_id": second_id}) + await reads.wait_for(5) + await bus.publish(ResourceUpdated(uri="note://todo")) + await reads.wait_for(6) + await stream.wait_for(3) + assert teed_subscription_id(2) != second_id + + tg.cancel_scope.cancel() diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index ada4b7fa0..95f873b79 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1182,6 +1182,50 @@ def __post_init__(self) -> None: removed_in="2026-07-28", note="removed in 2026-07-28 (SEP-2575); resources/unsubscribe replaced by subscriptions/listen.", ), + "subscriptions:listen:client:honored-surfacing": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/patterns/subscriptions#acknowledgment", + behavior=( + "Entering Client.listen() waits for the server's acknowledgment and surfaces the honored " + "filter subset on the handle, so the client can check it against what it requested (spec SHOULD)." + ), + added_in="2026-07-28", + ), + "subscriptions:listen:client:iteration": Requirement( + source="sdk", + behavior=( + "An open subscription is an async iterator of typed change events; delivered notifications " + "still tee to message_handler so caching and observers keep working." + ), + added_in="2026-07-28", + ), + "subscriptions:listen:client:graceful-close": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/patterns/subscriptions#cancellation", + behavior=( + "The server's empty subscriptions/listen result (its deliberate close) ends iteration cleanly " + "after buffered events drain; no exception is raised." + ), + added_in="2026-07-28", + ), + "subscriptions:listen:client:lost": Requirement( + source="sdk", + behavior=( + "A listen stream that ends without the graceful result raises SubscriptionLost from iteration; " + "there is no automatic re-listen." + ), + added_in="2026-07-28", + ), + "subscriptions:listen:client:era-guard": Requirement( + source="sdk", + behavior=( + "Client.listen() on a pre-2026 connection raises ListenNotSupportedError steering to " + "subscribe_resource/message_handler instead of leaking a wire -32601." + ), + removed_in="2026-07-28", + note=( + "removed_in scopes the matrix to the 2025 cells deliberately: the behavior under test is the " + "guard on connections where the method does not exist." + ), + ), "resources:updated-notification": Requirement( source=f"{SPEC_BASE_URL}/server/resources#subscriptions", behavior=( diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 44ab33e64..db7d4dfe6 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -203,6 +203,7 @@ async def list_resource_templates( ) +@pytest.mark.filterwarnings("ignore::mcp.MCPDeprecationWarning") @requirement("resources:subscribe") async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" @@ -214,11 +215,12 @@ async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeR server = Server("library", on_subscribe_resource=subscribe_resource) async with connect(server) as client: - result = await client.subscribe_resource("file:///watched.txt") + result = await client.subscribe_resource("file:///watched.txt") # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) +@pytest.mark.filterwarnings("ignore::mcp.MCPDeprecationWarning") @requirement("resources:subscribe:capability-required") async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. @@ -237,13 +239,14 @@ async def list_resources( async with connect(server) as client: with pytest.raises(MCPError) as exc_info: - await client.subscribe_resource("file:///watched.txt") + await client.subscribe_resource("file:///watched.txt") # pyright: ignore[reportDeprecated] assert exc_info.value.error == snapshot( ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="resources/subscribe") ) +@pytest.mark.filterwarnings("ignore::mcp.MCPDeprecationWarning") @requirement("resources:unsubscribe") async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" @@ -255,7 +258,7 @@ async def unsubscribe_resource(ctx: ServerRequestContext, params: types.Unsubscr server = Server("library", on_unsubscribe_resource=unsubscribe_resource) async with connect(server) as client: - result = await client.unsubscribe_resource("file:///watched.txt") + result = await client.unsubscribe_resource("file:///watched.txt") # pyright: ignore[reportDeprecated] assert result == snapshot(EmptyResult()) diff --git a/tests/interaction/lowlevel/test_subscriptions.py b/tests/interaction/lowlevel/test_subscriptions.py new file mode 100644 index 000000000..88b2ab465 --- /dev/null +++ b/tests/interaction/lowlevel/test_subscriptions.py @@ -0,0 +1,62 @@ +"""Client.listen stream endings against lowlevel servers over the connect matrix.""" + +from typing import Any + +import anyio +import mcp_types as types +import pytest + +from mcp import MCPError +from mcp.client.subscriptions import SubscriptionLost, ToolsListChanged +from mcp.server import Server, ServerRequestContext +from mcp.server.subscriptions import SUBSCRIPTION_ID_META_KEY, InMemorySubscriptionBus, ListenHandler +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("subscriptions:listen:client:graceful-close") +async def test_a_graceful_server_close_ends_iteration_after_buffered_events(connect: Connect) -> None: + """`ListenHandler.close()` sends the result last; iteration drains published events, then ends cleanly.""" + bus = InMemorySubscriptionBus() + handler = ListenHandler(bus) + server = Server("subs", on_subscriptions_listen=handler) + events: list[object] = [] + async with connect(server) as client: + with anyio.fail_after(10): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + await bus.publish(ToolsListChanged()) + handler.close() + events.extend([event async for event in sub]) + assert events == [ToolsListChanged()] + + +@requirement("subscriptions:listen:client:lost") +async def test_a_stream_dropped_after_the_ack_raises_subscription_lost(connect: Connect) -> None: + """Erroring the listen request after the ack (abrupt, not graceful) raises SubscriptionLost from iteration.""" + proceed = anyio.Event() + + async def dropping_listen( + ctx: ServerRequestContext[Any, Any], params: types.SubscriptionsListenRequestParams + ) -> types.SubscriptionsListenResult: + assert ctx.request_id is not None + await ctx.session.send_notification( + types.SubscriptionsAcknowledgedNotification( + params=types.SubscriptionsAcknowledgedNotificationParams( + notifications=params.notifications, + _meta={SUBSCRIPTION_ID_META_KEY: ctx.request_id}, + ) + ), + related_request_id=ctx.request_id, + ) + await proceed.wait() + raise MCPError(types.INTERNAL_ERROR, "stream torn down") + + server = Server("subs", on_subscriptions_listen=dropping_listen) + async with connect(server) as client: + with anyio.fail_after(10): + async with client.listen(tools_list_changed=True) as sub: # pragma: no branch + proceed.set() + with pytest.raises(SubscriptionLost): # pragma: no branch + await anext(sub) diff --git a/tests/interaction/mcpserver/test_subscriptions.py b/tests/interaction/mcpserver/test_subscriptions.py new file mode 100644 index 000000000..047b049d9 --- /dev/null +++ b/tests/interaction/mcpserver/test_subscriptions.py @@ -0,0 +1,62 @@ +"""Client.listen against MCPServer over the connect matrix (2026-07-28).""" + +import anyio +import pytest + +from mcp.client.subscriptions import ListenNotSupportedError, ResourceUpdated, ToolsListChanged +from mcp.server.mcpserver import Context, MCPServer +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _notebook() -> MCPServer: + mcp = MCPServer("notebook") + + @mcp.tool() + async def touch_tools(ctx: Context) -> str: + await ctx.notify_tools_changed() + return "ok" + + @mcp.tool() + async def edit_note(name: str, ctx: Context) -> str: + await ctx.notify_resource_updated(f"note://{name}") + return "saved" + + return mcp + + +@requirement("subscriptions:listen:client:honored-surfacing") +@requirement("subscriptions:listen:client:iteration") +async def test_listen_surfaces_the_ack_and_iterates_typed_events(connect: Connect) -> None: + """Entering waits for the ack (honored is set before any event); iteration yields + only the typed event kinds this stream opted in to.""" + mcp = _notebook() + async with connect(mcp) as client: + with anyio.fail_after(10): + async with client.listen( # pragma: no branch + tools_list_changed=True, resource_subscriptions=["note://todo"] + ) as sub: + assert sub.honored.tools_list_changed is True + assert sub.honored.resource_subscriptions == ["note://todo"] + + await client.call_tool("edit_note", {"name": "journal"}) # unsubscribed URI: silent + await client.call_tool("edit_note", {"name": "todo"}) + assert await anext(sub) == ResourceUpdated(uri="note://todo") + + await client.call_tool("touch_tools", {}) + assert await anext(sub) == ToolsListChanged() + + +@requirement("subscriptions:listen:client:era-guard") +async def test_listen_on_a_pre_2026_connection_raises_the_typed_steer(connect: Connect) -> None: + """On 2025-era connections the guard fires before anything touches the wire, steering to the legacy verbs.""" + mcp = _notebook() + async with connect(mcp) as client: + with anyio.fail_after(10): + # Entering is where the guard fires; __aenter__ directly avoids an unreachable with-body. + with pytest.raises(ListenNotSupportedError) as exc_info: + await client.listen(tools_list_changed=True).__aenter__() + assert exc_info.value.negotiated_version == client.session.protocol_version + assert "subscribe_resource" in str(exc_info.value) diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 1f8208337..c6ebb401f 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -19,12 +19,13 @@ INVALID_REQUEST, REQUEST_TIMEOUT, ErrorData, + RequestId, Tool, ) from mcp.shared._compat import resync_tracer from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair -from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnNotifyIntercept, OnRequest, Outbound from mcp.shared.exceptions import MCPError from mcp.shared.transport_context import TransportContext @@ -65,6 +66,7 @@ async def running_pair( server_on_notify: OnNotify | None = None, client_on_request: OnRequest | None = None, client_on_notify: OnNotify | None = None, + client_on_notify_intercept: OnNotifyIntercept | None = None, can_send_request: bool = True, ) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: """Yield `(client, server, client_recorder, server_recorder)` with both `run()` loops live.""" @@ -74,7 +76,9 @@ async def running_pair( s_req, s_notify = echo_handlers(server_rec) try: async with anyio.create_task_group() as tg: - await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start( + client.run, client_on_request or c_req, client_on_notify or c_notify, client_on_notify_intercept + ) await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) try: yield client, server, client_rec, server_rec @@ -396,6 +400,179 @@ async def test_direct_close_makes_run_return(): server.close() +@pytest.mark.anyio +async def test_send_raw_request_honors_caller_supplied_request_id_verbatim_typed(pair_factory: PairFactory): + """A caller-supplied `CallOptions["request_id"]` reaches the peer's context verbatim — + "7" stays a string, never the integer 7 — and the next call without one still mints + a dispatcher id as before.""" + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.send_raw_request("first", None, {"request_id": "7"}) + await client.send_raw_request("second", None) + supplied, minted = (ctx.request_id for ctx in srec.contexts) + assert supplied == "7" + assert type(supplied) is str + assert type(minted) is int + + +@pytest.mark.anyio +async def test_send_raw_request_with_in_flight_request_id_raises_and_frees_id_on_completion( + pair_factory: PairFactory, +): + """Reusing an id while it is in flight is a loud `ValueError` — silent reuse would + corrupt response correlation. Once the first request completes, the id is free + again: the reservation is in-flight-scoped, not permanent.""" + entered = anyio.Event() + release = anyio.Event() + + async def parked( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + entered.set() + await release.wait() + return {"served": method} + + async with running_pair(pair_factory, server_on_request=parked) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def first() -> None: + await client.send_raw_request("slow", None, {"request_id": "listen-1"}) + + tg.start_soon(first) + await entered.wait() + with pytest.raises(ValueError, match="already in flight"): + await client.send_raw_request("duplicate", None, {"request_id": "listen-1"}) + release.set() + result = await client.send_raw_request("again", None, {"request_id": "listen-1"}) + assert result == {"served": "again"} + + +@pytest.mark.anyio +async def test_minted_ids_skip_a_caller_supplied_id_still_in_flight(pair_factory: PairFactory): + """The dispatcher mints PAST a key a supplied id occupies — the collision error + is reserved for the caller who chose the id, never an innocent minted request.""" + entered = anyio.Event() + release = anyio.Event() + seen_ids: list[RequestId | None] = [] + + async def maybe_park( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + seen_ids.append(ctx.request_id) + if method == "park": + entered.set() + await release.wait() + return {} + + async with running_pair(pair_factory, server_on_request=maybe_park) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def parked() -> None: + await client.send_raw_request("park", None, {"request_id": "3"}) + + tg.start_soon(parked) + await entered.wait() + # The counter mints 1 and 2, then skips the occupied 3 to 4. + for _ in range(3): + await client.send_raw_request("plain", None) + release.set() + assert [request_id for request_id in seen_ids if request_id != "3"] == [1, 2, 4] + + +@pytest.mark.anyio +async def test_supplied_numeric_string_id_collides_with_its_int_twin(pair_factory: PairFactory): + """ "7" and 7 are one id in the collision domain on BOTH dispatchers, so the + in-memory pair raises exactly where the wire dispatcher (whose pending keys + are coerced for response correlation) would.""" + entered = anyio.Event() + release = anyio.Event() + + async def parked( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + entered.set() + await release.wait() + return {} + + async with running_pair(pair_factory, server_on_request=parked) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def first() -> None: + await client.send_raw_request("slow", None, {"request_id": 7}) + + tg.start_soon(first) + await entered.wait() + with pytest.raises(ValueError, match="already in flight"): + await client.send_raw_request("duplicate", None, {"request_id": "7"}) + release.set() + # Completion frees the id for either spelling. + assert await client.send_raw_request("again", None, {"request_id": "7"}) == {} + + +@pytest.mark.anyio +async def test_notify_intercept_sees_every_notification_and_consumes_on_true(pair_factory: PairFactory): + """The intercept sees every inbound notification; a frame it consumes never reaches `on_notify`, the rest do.""" + intercepted: list[str] = [] + + def intercept(method: str, params: Mapping[str, Any] | None) -> bool: + intercepted.append(method) + return method == "notifications/consumed" + + async with running_pair(pair_factory, client_on_notify_intercept=intercept) as (_client, server, crec, _srec): + with anyio.fail_after(5): + await server.notify("notifications/consumed", None) + await server.notify("notifications/passed", None) + await crec.notified.wait() + assert intercepted == ["notifications/consumed", "notifications/passed"] + assert [method for method, _ in crec.notifications] == ["notifications/passed"] + + +@pytest.mark.anyio +async def test_notify_intercept_completes_before_a_later_response_resolves(pair_factory: PairFactory): + """Notifications written before a response are intercepted before it resolves, whatever spawned handlers do.""" + seen: list[str] = [] + + def intercept(method: str, params: Mapping[str, Any] | None) -> bool: + seen.append(method) + return False + + async def notify_then_answer( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/first", None) + await ctx.notify("notifications/second", None) + return {} + + async with running_pair( + pair_factory, server_on_request=notify_then_answer, client_on_notify_intercept=intercept + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("burst", None) + assert seen == ["notifications/first", "notifications/second"] + + +@pytest.mark.anyio +async def test_a_raising_notify_intercept_is_contained_and_passes_the_frame_through(pair_factory: PairFactory): + """An intercept exception costs only that interception: the frame still reaches `on_notify`, the loop survives.""" + + def broken_intercept(method: str, params: Mapping[str, Any] | None) -> bool: + raise RuntimeError("intercept exploded") + + async with running_pair(pair_factory, client_on_notify_intercept=broken_intercept) as ( + _client, + server, + crec, + _srec, + ): + with anyio.fail_after(5): + await server.notify("notifications/survives", None) + await crec.notified.wait() + assert [method for method, _ in crec.notifications] == ["notifications/survives"] + + if TYPE_CHECKING: _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) _o: Outbound = _d diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 82d16bc4b..e91fc2de2 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -34,11 +34,10 @@ from mcp.server import Server, ServerRequestContext from mcp.shared._compat import resync_tracer from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream -from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.dispatcher import CallOptions, DispatchContext, coerce_request_id from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, - _coerce_id, _OutboundPlan, _Pending, _plan_outbound, @@ -1821,7 +1820,7 @@ async def respond_stringly() -> None: @pytest.mark.anyio async def test_error_response_with_string_id_correlates_to_int_keyed_pending_request(): - """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (same `_coerce_id` path).""" + """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (`coerce_request_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1900,10 +1899,10 @@ async def on_progress(progress: float, total: float | None, message: str | None) assert seen == [0.5] -def test_coerce_id_passes_through_non_numeric_string_and_int(): - assert _coerce_id("7") == 7 - assert _coerce_id("not-an-int") == "not-an-int" - assert _coerce_id(42) == 42 +def test_coerce_request_id_passes_through_non_numeric_string_and_int(): + assert coerce_request_id("7") == 7 + assert coerce_request_id("not-an-int") == "not-an-int" + assert coerce_request_id(42) == 42 @pytest.mark.anyio @@ -2154,7 +2153,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ids=["string-cancel-for-int-request", "int-cancel-for-string-request"], ) async def test_cancelled_correlates_across_string_and_int_request_id_forms(request_id: RequestId, cancel_id: object): - """A peer that stringifies the id between request and cancel still cancels (same `_coerce_id` path).""" + """A peer that stringifies the id between request and cancel still cancels (same `coerce_request_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -2381,3 +2380,38 @@ async def call() -> None: assert observed[0][0] == "notifications/cancelled" assert observed[0][1]["requestId"] == request_id assert observed[0][1]["reason"] == "user clicked stop" + + +@pytest.mark.anyio +async def test_send_raw_request_with_caller_supplied_string_id_is_verbatim_on_the_wire(): + """A supplied "7" goes on the wire as the string "7", and the response still + correlates when the peer echoes it back as the integer 7 — the pending key gets + the same coercion `_resolve_pending` applies to inbound ids.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + result_box: list[dict[str, Any]] = [] + done = anyio.Event() + try: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call() -> None: + result_box.append(await client.send_raw_request("tools/list", None, {"request_id": "7"})) + done.set() + + await tg.start(client.run, on_request, on_notify) + tg.start_soon(call) + wire = await c2s_recv.receive() + assert isinstance(wire, SessionMessage) + assert isinstance(wire.message, JSONRPCRequest) + assert wire.message.id == "7" + assert type(wire.message.id) is str + await s2c_send.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=7, result={"ok": True}))) + await done.wait() + tg.cancel_scope.cancel() + finally: + for stream in (c2s_send, c2s_recv, s2c_send, s2c_recv): + stream.close() + assert result_box == [{"ok": True}]