diff --git a/docs/advanced/low-level-server.md b/docs/advanced/low-level-server.md index 12c4532949..6568b76a55 100644 --- a/docs/advanced/low-level-server.md +++ b/docs/advanced/low-level-server.md @@ -181,7 +181,7 @@ The handshake belongs to the runner. `server/discover`, `ping`, and every other Each of these is one idea you now have the vocabulary for; each has its own chapter. -* `on_call_tool`, `on_get_prompt`, and `on_read_resource` may return an `InputRequiredResult` instead of their normal result to pause the call and ask the client for input; see **[Multi-round-trip requests](multi-round-trip.md)**. +* `on_call_tool`, `on_get_prompt`, and `on_read_resource` may return an `InputRequiredResult` instead of their normal result to pause the call and ask the client for input; see **[Multi-round-trip requests](multi-round-trip.md)**. True to this tier, nothing is installed for you: where `MCPServer` seals `requestState` by default, here the `request_state` you set crosses the wire exactly as written until you opt in with `server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[...]), default_audience=server.name))`: one line (both names import from `mcp.server.request_state`) for the identical sealing and verification `MCPServer` performs (**[Protecting `requestState`](multi-round-trip.md#protecting-requeststate)**). * `on_list_resources`, `on_read_resource`, `on_list_prompts`, `on_get_prompt`, `on_completion` are the same `(ctx, params) -> result` shape for the other primitives. * `server.streamable_http_app()` returns the same Starlette app `MCPServer`'s does; deploy it the way **[Running your server](../run/index.md)** deploys any other ASGI app. There is no `server.run(transport=...)` down here: `server.run(read_stream, write_stream, server.create_initialization_options())` drives one connection over a pair of streams, and that one line is the whole story. 
diff --git a/docs/advanced/multi-round-trip.md b/docs/advanced/multi-round-trip.md index 78e567e9d0..62734b38fc 100644 --- a/docs/advanced/multi-round-trip.md +++ b/docs/advanced/multi-round-trip.md @@ -40,6 +40,7 @@ Everything else in that file (the explicit `input_schema`, the hand-built `CallT ``` * The first round returns the `InputRequiredResult`. On the retry, `ctx.input_responses` holds the answers under the same keys and the function returns its ordinary result — prompt messages here, resource content for a template resource. +* A `request_state` you set is sealed before it crosses the wire and verified on the echo, like everything else on the server; **[Protecting `requestState`](#protecting-requeststate)** below covers what the seal gives you and when you need to configure keys. * An `@mcp.tool()` function can return the result directly the same way, when the dependency form doesn't fit. * Static `@mcp.resource()` functions don't participate: they take no `Context`, so they could never read the retry. Only template resources can ask. * The era rules below apply unchanged: returning an `InputRequiredResult` on a pre-2026 session is the same `-32603` the warning describes. @@ -84,6 +85,78 @@ Drop to the underlying session, where `allow_input_required=True` hands you the * For every entry in `input_requests` you put an `InputResponse` under the **same key** in `input_responses`. `fulfil` is where your UI goes; this one hard-codes the answer. * Same tool name, same `arguments`, every leg. The retry is the original call carried out again, not a new method. +## Protecting `requestState` + +Everything above treats `request_state` as an echo, and on the wire that is all it is. But the client holds it between legs (writing it down across processes is exactly what the previous section blessed), so what comes back is **client-supplied input**: it can be modified, expired, or lifted from a different call entirely. 
The spec requires servers to integrity-protect this state and reject the round when verification fails, whenever the state can influence authorization, resource access, or business logic. + +`MCPServer` protects it by default. Every server seals outgoing `requestState` and verifies every echo — resolver state and hand-built state alike — under a key generated at process start. You configure nothing, write plaintext, and read plaintext; the wire only ever carries an opaque encrypted token. + +The default key lives and dies with the process, which is the one thing you must know before deploying beyond a single process: + +```python +from mcp.server.mcpserver import MCPServer, RequestStateSecurity + +# Multi-instance or restart-surviving: one or more shared secret keys (>= 32 bytes each). +mcp = MCPServer("fleet", request_state_security=RequestStateSecurity(keys=[key])) +``` + +* **The default (no configuration)** suits a single process: stdio, or exactly one HTTP worker. A retry that lands on a different worker, a different instance behind a load balancer, or the same server after a restart is sealed under a key that process doesn't have — the client gets the frozen rejection below and must start the flow over. +* **`keys=[...]`** is required whenever a retry can reach a **different instance** (multi-worker `uvicorn`, load-balanced HTTP) or must survive restarts: every instance verifies what any sibling minted. Same machinery, your secret instead of a generated one. +* For your own crypto, such as a KMS or an existing token service, pass `RequestStateSecurity(codec=...)` instead of `keys`; **[Bring your own crypto](#bring-your-own-crypto)** below covers the contract. + +### What the seal carries + +Default or configured, `requestState` on the wire is an encrypted, authenticated token. Your code never sees it: handlers and resolvers write plaintext and read plaintext (`ctx.request_state`); the SDK seals on the way out and verifies on the way in. 
Beyond integrity, each token is bound to: + +* **A time window.** Every round re-seals with a fresh expiry, so `RequestStateSecurity(ttl=...)` (default 600 seconds) bounds per-round think time, not the whole flow. +* **The authenticated principal.** When the request carries an OAuth access token the SDK validated, the state is bound to the token's client, issuer, and subject: state minted for one user fails under another, even when both users share one OAuth client. A verifier that supplies no subject degrades the binding to the client identity alone, which under URL-based client IDs is shared by every user of that client software. When auth is terminated outside the SDK (a fronting proxy), or the transport is unauthenticated, there is no principal to bind and this check is inert, unless `RequestStateSecurity(bind_principal=...)` supplies one from your own identity signal. Whichever components your token verifier supplies, it must supply them consistently: a verifier that includes the subject on some requests and omits it on others changes the principal mid-flow, and in-flight rounds are rejected. +* **The originating request.** The method, the tool or prompt name (or resource URI), and a digest of the arguments. A token replayed against a different tool, different arguments, or a different method fails. +* **The exact question asked.** Every resolver answer is pinned to the rendered question the client was shown, both on the round it first arrives and when a recorded answer is reused later. Redeploy with a reworded message or a changed schema and the server re-asks instead of consuming a stale answer. The same pinning cuts the other way: derive messages from the tool's arguments, not from per-call data. A message built from a timestamp or a live rate renders differently every round, so every recorded answer looks stale and the server re-asks until the client's round limit ends the call. 
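To make the "originating request" binding concrete, here is a minimal sketch of how a digest over the method, the target name, and the arguments could behave. It is not the SDK's actual payload format: `request_binding` is a hypothetical helper, and the canonical-JSON-plus-SHA-256 encoding is an assumption chosen only for illustration.

```python
import hashlib
import json
from typing import Any


def request_binding(method: str, target: str, arguments: dict[str, Any]) -> str:
    """Hypothetical helper: a digest that identifies one originating request."""
    # Canonical JSON (sorted keys, fixed separators) so key order never changes the digest.
    canonical = json.dumps(
        {"method": method, "target": target, "arguments": arguments},
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(canonical.encode()).hexdigest()


# The digest stamped into the state when the first round was minted.
minted = request_binding("tools/call", "deploy", {"env": "staging", "dry_run": False})

# The same call with reordered keys, then two calls that differ from the original.
same_call = request_binding("tools/call", "deploy", {"dry_run": False, "env": "staging"})
other_args = request_binding("tools/call", "deploy", {"env": "production", "dry_run": False})
other_tool = request_binding("tools/call", "rollback", {"env": "staging", "dry_run": False})
```

Under this scheme a retry verifies only when its recomputed digest equals the minted one, so `same_call` would pass while `other_args` and `other_tool` would be rejected.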
+ +All of that is the SDK's job, not yours, and not the codec's if you bring your own. + +### Rotating keys + +`keys[0]` seals new state; every key in the list verifies. Zero-downtime rotation is three phases, each fully rolled out before the next: + +```python +RequestStateSecurity(keys=[OLD, NEW]) # 1: every instance learns to verify NEW; OLD still mints +RequestStateSecurity(keys=[NEW, OLD]) # 2: NEW mints; in-flight OLD state keeps verifying +RequestStateSecurity(keys=[NEW]) # 3: one ttl after phase 2 is fully out, retire OLD +``` + +Never promote the minter first: minting under a key some instance can't yet verify drops in-flight rounds mid-rollout. + +Keys are scoped to one service. The sealed envelope also carries the server's name as an audience claim, so a token minted by a different service that happens to share a secret is rejected anyway. The claim is only as distinctive as the name, so a server given an explicit policy must have a real name or set `RequestStateSecurity(audience=...)` — an unnamed one raises at construction. `audience=` also serves deliberate multi-service topologies where one service must accept state another minted. (The no-configuration default is exempt: its key never leaves the process, so the audience claim has nothing to add.) + +### Bring your own crypto + +`RequestStateSecurity(codec=...)` takes anything with `seal(bytes) -> str` and `unseal(str) -> bytes` that raises `InvalidRequestState` for any token it did not mint. The classic shape is envelope encryption against a KMS, where you unwrap a data key once at startup and keep the per-token crypto local: + +```python title="server.py" hl_lines="12 26-27 34-35 38" +--8<-- "docs_src/mrtr/tutorial005.py" +``` + +TTL, principal binding, and request binding are **not** the codec's job: the SDK stamps them into the payload before `seal` and re-verifies them after `unseal`, for every codec. A codec's only obligations are integrity (tampered means raise) and, ideally, confidentiality. 
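The codec contract above can also be exercised without a KMS or the `cryptography` package. The sketch below is a standard-library, integrity-only codec: it signs the payload with HMAC-SHA256 but does not encrypt it, so it meets the "tampered means raise" obligation and not the confidentiality one. `HmacCodec` is a hypothetical name, and the local `InvalidRequestState` class is a stand-in for the SDK exception so the block runs on its own; a real codec would raise the SDK's exception instead.

```python
import base64
import hashlib
import hmac


class InvalidRequestState(Exception):
    """Local stand-in for the SDK exception of the same name (assumption for this sketch)."""


class HmacCodec:
    """Integrity-only codec: the payload is signed, but stays readable on the wire."""

    def __init__(self, key: bytes) -> None:
        if len(key) < 32:
            raise ValueError("key must be at least 32 bytes")
        self._key = key

    def _tag(self, body: str) -> str:
        return hmac.new(self._key, body.encode(), hashlib.sha256).hexdigest()

    def seal(self, payload: bytes) -> str:
        body = base64.urlsafe_b64encode(payload).decode()
        return f"{body}.{self._tag(body)}"

    def unseal(self, token: str) -> bytes:
        try:
            body, _, tag = token.partition(".")
            # Constant-time comparison over bytes; any edit to body or tag fails here.
            if not hmac.compare_digest(tag.encode(), self._tag(body).encode()):
                raise ValueError("tag mismatch")
            return base64.urlsafe_b64decode(body)
        except ValueError as exc:  # also covers binascii.Error and UnicodeEncodeError
            raise InvalidRequestState("token failed verification") from exc


codec = HmacCodec(b"k" * 32)
token = codec.seal(b"awaiting-confirm")
roundtrip = codec.unseal(token)

# Flip one character, the same way the mrtr story client tampers with its token.
i = len(token) // 2
tampered = token[:i] + ("A" if token[i] != "A" else "B") + token[i + 1 :]
```

Passing an instance as `RequestStateSecurity(codec=...)` would follow the same pattern as the envelope codec in the tutorial file; because this variant leaves the payload readable, it only suits state that is not secret.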
+ +### When verification fails + +Every inbound failure, whether tampered, expired, replayed against a different request or principal, or sealed under a key this server doesn't know, gets the same answer: + +```json +{"code": -32602, "message": "Invalid or expired requestState"} +``` + +One frozen message for every cause, so the wire never reveals which check failed; the real reason goes to the server log. Every inbound `requestState` on `tools/call`, `prompts/get`, and `resources/read` is checked, including one arriving for a handler that never mints state. The most common rejection in practice isn't an attacker — it's the default process-local key meeting a retry from before a restart or from another instance; the client restarts the flow, and `keys=[...]` is the fix when that matters. + +### Hand-built state + +A `request_state` you set yourself (returning `InputRequiredResult` from a tool, prompt, or resource-template function) is sealed and verified by the same machinery as resolver state, with zero code changes: write plaintext, read plaintext, and every binding above applies. + +The one thing the SDK cannot pin for you, even when configured, is question identity: it doesn't know which of *your* questions an answer in your state belongs to. If you store answers keyed by question, include your own question identifier in the state and check it on the retry. + +The low-level `Server` is the no-batteries tier: unlike `MCPServer`, nothing is sealed until you append the boundary yourself, and your `request_state` crosses the wire exactly as written until you do. The one-line opt-in is shown in **[The low-level Server](low-level-server.md#the-other-handlers)**. + ## A 2026-07-28 result `InputRequiredResult` only exists at protocol version **2026-07-28**. The in-memory `Client(server)` negotiates it for you; over the wire, `mode="auto"` discovers it. After connecting, `client.protocol_version` tells you what you got. 
@@ -108,5 +181,6 @@ Drop to the underlying session, where `allow_input_required=True` hands you the * To inspect or persist rounds, use `client.session.call_tool(..., allow_input_required=True)` and own the `while isinstance(result, InputRequiredResult)` loop yourself. * On `@mcp.tool()`, a dependency that asks the user produces this result for you (**[Dependencies](../tutorial/dependencies.md)**); the **low-level** `Server` is the manual form. * Prompts and resources participate too: an `@mcp.prompt()` or template `@mcp.resource()` function returns the `InputRequiredResult` itself and reads `ctx.input_responses` on the retry. +* `requestState` comes back as client-supplied input, so `MCPServer` seals it by default — resolver state and hand-built state alike — under a process-local key; multi-instance deployments pass `RequestStateSecurity(keys=[...])` (or a custom codec) so every instance can verify what a sibling minted. The seal binds every token to a time window, the originating request, and the authenticated principal when the request carries auth the SDK validated or `bind_principal=` supplies your own identity signal (**[Protecting `requestState`](#protecting-requeststate)**). This is the mechanism that replaces server-initiated sampling and the rest of the push-style back-channel; see **[Deprecated features](deprecated.md)**. diff --git a/docs/tutorial/dependencies.md b/docs/tutorial/dependencies.md index b7b18fe763..8d6d91412d 100644 --- a/docs/tutorial/dependencies.md +++ b/docs/tutorial/dependencies.md @@ -131,7 +131,8 @@ That's the right default for a precondition: no answer, no order. When declining its question, an eliciting resolver must derive its question deterministically from the tool's arguments and earlier answers. A per-call generated value (a `default_factory` id, a timestamp) is re-derived on each round and must not appear in a question the answer is meant - to bind to. + to bind to. 
A question built from such volatile data makes every recorded answer look stale, + so the server re-asks it on every round until the client's round limit ends the call. ## Recap diff --git a/docs_src/mrtr/tutorial005.py b/docs_src/mrtr/tutorial005.py new file mode 100644 index 0000000000..a8588b250f --- /dev/null +++ b/docs_src/mrtr/tutorial005.py @@ -0,0 +1,38 @@ +import os + +from cryptography.exceptions import InvalidTag +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from mcp.server import MCPServer +from mcp.server.mcpserver import InvalidRequestState, RequestStateSecurity + +PREFIX = "kms1." # format version; fed to GCM as associated data, so it is bound under the tag + + +def unwrap_data_key() -> bytes: + """One KMS call at process start, kms.decrypt(CiphertextBlob=...); every token after that is local crypto.""" + return os.urandom(32) # stand-in for the unwrapped 32-byte data key + + +class EnvelopeCodec: + def __init__(self, data_key: bytes) -> None: + self._aesgcm = AESGCM(data_key) + + def seal(self, payload: bytes) -> str: + nonce = os.urandom(12) + return PREFIX + (nonce + self._aesgcm.encrypt(nonce, payload, PREFIX.encode())).hex() + + def unseal(self, token: str) -> bytes: + if not token.startswith(PREFIX): + raise InvalidRequestState("unknown token format") + body = token[len(PREFIX) :] + try: + raw = bytes.fromhex(body) + if raw.hex() != body: # only the exact string seal() produced verifies + raise ValueError("non-canonical hex") + return self._aesgcm.decrypt(raw[:12], raw[12:], PREFIX.encode()) + except (ValueError, InvalidTag) as exc: + raise InvalidRequestState("token failed verification") from exc + + +mcp = MCPServer("Deployer", request_state_security=RequestStateSecurity(codec=EnvelopeCodec(unwrap_data_key()))) diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 8621c877a8..218188f50a 100644 --- 
a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -6,16 +6,13 @@ import asyncio import base64 -import binascii -import hashlib -import hmac import json import logging from typing import Annotated, Any import click from mcp.server import ServerRequestContext -from mcp.server.mcpserver import Context, MCPServer +from mcp.server.mcpserver import Context, MCPServer, RequestStateSecurity from mcp.server.mcpserver.prompts.base import UserMessage from mcp.server.streamable_http import EventCallback, EventMessage, EventStore from mcp.shared.exceptions import MCPError @@ -47,7 +44,7 @@ TextResourceContents, UnsubscribeRequestParams, ) -from mcp_types.jsonrpc import INVALID_PARAMS, MISSING_REQUIRED_CLIENT_CAPABILITY +from mcp_types.jsonrpc import MISSING_REQUIRED_CLIENT_CAPABILITY from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -100,8 +97,12 @@ async def replay_events_after(self, last_event_id: EventId, send_callback: Event # Create event store for SSE resumability (SEP-1699) event_store = InMemoryEventStore() +# Fixed fixture key (RequestStateSecurity requires at least 32 bytes); a real deployment would load a shared secret. +_REQUEST_STATE_KEY = b"everything-server-fixture-request-state-key" + mcp = MCPServer( name="mcp-conformance-test-server", + request_state_security=RequestStateSecurity(keys=[_REQUEST_STATE_KEY]), ) @@ -497,30 +498,12 @@ async def test_input_required_result_multi_round(ctx: Context) -> str | InputReq ) -# Fixed key for the conformance fixture; a real server would derive or rotate this. 
-_STATE_HMAC_KEY = b"everything-server-fixture-key" - - -def _seal_state(payload: str) -> str: - encoded = base64.urlsafe_b64encode(payload.encode()).decode() - sig = hmac.new(_STATE_HMAC_KEY, encoded.encode(), hashlib.sha256).hexdigest() - return f"{encoded}.{sig}" - - -def _unseal_state(state: str) -> str: - encoded, _, sig = state.partition(".") - expected = hmac.new(_STATE_HMAC_KEY, encoded.encode(), hashlib.sha256).hexdigest() - if not sig or not hmac.compare_digest(sig, expected): - raise MCPError(code=INVALID_PARAMS, message="requestState failed integrity verification") - try: - return base64.urlsafe_b64decode(encoded).decode() - except (binascii.Error, UnicodeDecodeError) as e: - raise MCPError(code=INVALID_PARAMS, message="requestState failed integrity verification") from e - - @mcp.tool() async def test_input_required_result_tampered_state(ctx: Context) -> str | InputRequiredResult: - """Tests that the server rejects a requestState that fails HMAC verification""" + """Tests that the server rejects a tampered requestState echo. + + The handler stays plaintext; tamper rejection happens in the SDK's request-state boundary. 
+ """ if ctx.request_state is None: confirm = ElicitRequest( params=ElicitRequestFormParams( @@ -528,9 +511,8 @@ async def test_input_required_result_tampered_state(ctx: Context) -> str | Input requested_schema={"type": "object", "properties": {"ok": {"type": "boolean"}}, "required": ["ok"]}, ) ) - return InputRequiredResult(input_requests={"confirm": confirm}, request_state=_seal_state("round-1")) - payload = _unseal_state(ctx.request_state) - return f"state-ok: {payload}" + return InputRequiredResult(input_requests={"confirm": confirm}, request_state="round-1") + return f"state-ok: {ctx.request_state}" @mcp.tool() diff --git a/examples/stories/README.md b/examples/stories/README.md index 8c1cceb5b6..79d7143110 100644 --- a/examples/stories/README.md +++ b/examples/stories/README.md @@ -128,7 +128,7 @@ opens with a banner saying what replaces it. | [`dual_era`](dual_era/) | one server factory serving both protocol eras; era-neutral accessors | current | | **— feature stories —** | | | | [`streaming`](streaming/) | progress notifications, in-flight logging, cancellation | current | -| [`mrtr`](mrtr/) | `InputRequiredResult` round-trip: the `Client` auto-loop and a manual session-level loop | current | +| [`mrtr`](mrtr/) | `InputRequiredResult` round-trip: the `Client` auto-loop, a manual session-level loop, and the default `requestState` sealing (a tampered echo gets one frozen error) | current | | [`legacy_elicitation`](legacy_elicitation/) | server pauses a tool to ask the user (form + url) via a push request | legacy | | [`refund_desk`](refund_desk/) | resolver DI: `Annotated[T, Resolve(fn)]` params filled server-side, hidden from the input schema | current | | [`sampling`](sampling/) | server asks the client's LLM mid-tool (push request) | deprecated | diff --git a/examples/stories/mrtr/README.md b/examples/stories/mrtr/README.md index aaad86ca9d..870db7d298 100644 --- a/examples/stories/mrtr/README.md +++ b/examples/stories/mrtr/README.md @@ -3,15 +3,20 @@ 
Multi-round tool result: on the 2026-07-28 protocol a tool that needs user input mid-call **returns** `resultType: "input_required"` with embedded `inputRequests` and an opaque `requestState`, instead of pushing a -server→client request. The client fulfils the embedded requests and retries the +server-to-client request. The client fulfils the embedded requests and retries the original `tools/call` carrying `inputResponses` and the echoed `requestState`. The story shows both the `Client` auto-loop (one `await call_tool`, callbacks -fired transparently) and a manual `client.session` loop (the persistable form). +fired transparently) and a manual `client.session` loop (the persistable +form). Because `requestState` round-trips through the client, it also shows +the security surface that protects it: `MCPServer` seals state by default +under a process-local key, handlers keep writing plaintext, and the wire only +ever carries an opaque token. The manual loop tampers with the sealed token to +show what a forged echo gets back. ## Run it ```bash -# HTTP — the client self-hosts the server on a free port, runs, then tears it +# HTTP: the client self-hosts the server on a free port, runs, then tears it # down (the InputRequiredResult round-trip is 2026-era only) uv run python -m stories.mrtr.client --http # same, against the lowlevel-API server variant @@ -20,36 +25,55 @@ uv run python -m stories.mrtr.client --http --server server_lowlevel ## What to look at -- `client.py` `main` — the auto-loop is invisible at the call site: +- `server.py` `build_server`: no security configuration at all. The default + seals under a key generated at process start, which is right for a + single-process server like this one; a fleet (multi-worker or load-balanced) + shares keys with `request_state_security=RequestStateSecurity(keys=[...])` + so any instance can verify state another minted. +- `server.py` `deploy`: handlers stay plaintext. 
The first round returns + `InputRequiredResult(input_requests={...}, + request_state="awaiting-confirm")` and the retry asserts + `ctx.request_state == "awaiting-confirm"`. The tool never touches the + crypto; the boundary seals on the way out and unseals the echo on the way + back in. +- `client.py` `main`: the auto-loop is invisible at the call site: `Client(target, mode=mode, elicitation_callback=on_elicit)` then `await client.call_tool("deploy", ...)`. The same `on_elicit` callback the legacy push path uses is dispatched for each embedded `inputRequests` entry. -- `client.py` manual block — `client.session.call_tool(..., +- `client.py` manual block: `client.session.call_tool(..., allow_input_required=True)` returns the raw `InputRequiredResult` so - `request_state` can be persisted between rounds; the retry is just another - `tools/call` with `input_responses=` / `request_state=`. -- `server.py` `deploy` — `ctx.input_responses` / `ctx.request_state` read the - retry payload; the first round returns - `InputRequiredResult(input_requests={...}, request_state=...)`, the second - returns the final string. -- `server_lowlevel.py` — same wire contract via `params.input_responses` / - `params.request_state` and a hand-built `InputRequiredResult`. + `request_state` can be persisted between rounds. The wire value is an opaque + sealed token, **not** the string the server code wrote. The client asserts + exactly that, then retries with one character of the token flipped and gets + the single frozen error every verification failure maps to: `-32602`, + `"Invalid or expired requestState"`, `{"reason": "invalid_request_state"}`. + The specific reason (tampered tag, expiry, wrong request, wrong principal) + appears only in the server's log, never on the wire. The untampered token + then completes the round normally. 
+- `server_lowlevel.py`: the lowlevel tier doesn't seal by default; the same + enforcement is one appended middleware: + `server.middleware.append(RequestStateBoundary(RequestStateSecurity.ephemeral(), + default_audience=server.name))`. ## Caveats - **Loop bound.** The auto-loop gives up after `input_required_max_rounds` (default 10) with `InputRequiredRoundsExceededError`; raise it on the `Client` ctor or drop to the manual loop. -- **`requestState` integrity is the server's job.** The client echoes it - byte-exact and never inspects it; the server MUST treat it as - attacker-controlled. The SDK ships no signing helper yet. +- **The default key dies with the process.** It is generated at startup and + held only in memory, so a server restart (or a retry landing on a different + instance) invalidates in-flight rounds: the client gets the same frozen + rejection and must start the flow over. Use + `RequestStateSecurity(keys=[...])` when state must survive either. ## Spec -[Input required tool results — server features](https://modelcontextprotocol.io/specification/draft/server/tools#input-required-tool-results) +[Input required tool results (server features)](https://modelcontextprotocol.io/specification/draft/server/tools#input-required-tool-results), +[Multi-round-trip requests (security patterns)](https://modelcontextprotocol.io/specification/draft/basic/patterns/mrtr) ## See also -`legacy_elicitation/` and `sampling/` — the handshake-era push equivalents this -mechanism replaces on the 2026 protocol. `refund_desk/` — resolver DI at the -MCPServer tier: the questions a tool can declare instead of pushing by hand. +`legacy_elicitation/` and `sampling/`: the handshake-era push equivalents this +mechanism replaces on the 2026 protocol. `refund_desk/`: resolver DI at the +MCPServer tier: the questions a tool can declare instead of pushing by hand +(its elicited answers ride in the same sealed `requestState`). 
diff --git a/examples/stories/mrtr/client.py b/examples/stories/mrtr/client.py index 5b686c3c9c..7280fd0aed 100644 --- a/examples/stories/mrtr/client.py +++ b/examples/stories/mrtr/client.py @@ -2,6 +2,7 @@ import mcp_types as types +from mcp import MCPError from mcp.client import Client, ClientRequestContext from stories._harness import Target, run_client @@ -27,14 +28,37 @@ async def main(target: Target, *, mode: str = "auto") -> None: first = await client.session.call_tool("deploy", {"env": "staging"}, allow_input_required=True) assert isinstance(first, types.InputRequiredResult) assert first.input_requests is not None and "confirm" in first.input_requests - assert first.request_state == "awaiting-confirm" - # Decline this time so the path diverges from the auto-loop run above. + # The boundary sealed server.py's plaintext "awaiting-confirm"; the wire token is opaque. + token = first.request_state + assert token is not None and token != "awaiting-confirm", token + responses: types.InputResponses = {"confirm": types.ElicitResult(action="decline")} + + # Tamper demo: flipping any one character fails verification, and every failure + # maps to one frozen wire error; the real reason appears only in the server log. + i = len(token) // 2 + tampered = token[:i] + ("A" if token[i] != "A" else "B") + token[i + 1 :] + try: + await client.session.call_tool( + "deploy", + {"env": "staging"}, + input_responses=responses, + request_state=tampered, + allow_input_required=True, + ) + except MCPError as e: + assert e.code == types.INVALID_PARAMS + assert e.message == "Invalid or expired requestState" + assert e.data == {"reason": "invalid_request_state"} + else: + raise AssertionError("expected MCPError for a tampered requestState") + + # The untampered token still completes the round; decline so this path diverges from the auto run. 
second = await client.session.call_tool( "deploy", {"env": "staging"}, input_responses=responses, - request_state=first.request_state, + request_state=token, allow_input_required=True, ) assert isinstance(second, types.CallToolResult) diff --git a/examples/stories/mrtr/server.py b/examples/stories/mrtr/server.py index d83c2e9835..8155b90f4d 100644 --- a/examples/stories/mrtr/server.py +++ b/examples/stories/mrtr/server.py @@ -13,19 +13,19 @@ def build_server() -> MCPServer: + # requestState is sealed by default under a process-local key, which suits this + # single-process server; fleets share keys=[...] so any instance can verify. mcp = MCPServer("mrtr-example") @mcp.tool(description="Deploy to an environment, asking the user to confirm first.") async def deploy(env: str, ctx: Context) -> str | InputRequiredResult: responses = ctx.input_responses if responses is None or "confirm" not in responses: - # First round: ask the client to elicit confirmation. request_state is opaque - # to the client; here it carries the step name so the retry can verify the echo. ask = ElicitRequest( params=ElicitRequestFormParams(message=f"Deploy to {env}?", requested_schema=CONFIRM_SCHEMA) ) + # The boundary seals this plaintext request_state on the way out and unseals the echo on retry. return InputRequiredResult(input_requests={"confirm": ask}, request_state="awaiting-confirm") - # Retry round: the client echoed request_state byte-exact and supplied the answer. 
assert ctx.request_state == "awaiting-confirm", ctx.request_state answer = responses["confirm"] if isinstance(answer, ElicitResult) and answer.action == "accept" and (answer.content or {}).get("confirm"): diff --git a/examples/stories/mrtr/server_lowlevel.py b/examples/stories/mrtr/server_lowlevel.py index 0ed13cea49..6f3f489d8b 100644 --- a/examples/stories/mrtr/server_lowlevel.py +++ b/examples/stories/mrtr/server_lowlevel.py @@ -6,6 +6,7 @@ from mcp.server.context import ServerRequestContext from mcp.server.lowlevel import Server +from mcp.server.request_state import RequestStateBoundary, RequestStateSecurity from stories._hosting import run_server_from_args CONFIRM_SCHEMA: types.ElicitRequestedSchema = { @@ -55,7 +56,11 @@ async def call_tool( return types.CallToolResult(content=[types.TextContent(text=f"deployed to {env}")]) return types.CallToolResult(content=[types.TextContent(text=f"deployment to {env} cancelled")]) - return Server("mrtr-example", on_list_tools=list_tools, on_call_tool=call_tool) + server = Server("mrtr-example", on_list_tools=list_tools, on_call_tool=call_tool) + # Lowlevel opt-in: append the same boundary middleware MCPServer installs by + # default; the server name becomes the token audience. + server.middleware.append(RequestStateBoundary(RequestStateSecurity.ephemeral(), default_audience=server.name)) + return server if __name__ == "__main__": diff --git a/examples/stories/refund_desk/README.md b/examples/stories/refund_desk/README.md index 5b5bb55327..f10363698b 100644 --- a/examples/stories/refund_desk/README.md +++ b/examples/stories/refund_desk/README.md @@ -29,7 +29,9 @@ uv run python -m stories.refund_desk.client --http - `server.py` `refund_order` — the signature is the whole story: `order_id` and `reason` are model-facing; `cents` and `restock` carry `Resolve(...)` markers and never reach the input schema. `client.py` asserts `properties` and - `required` are exactly `{order_id, reason}`. 
+  `required` are exactly `{order_id, reason}`. At 2026 the resolver's elicited
+  answers ride between rounds inside a `requestState` the SDK seals by default;
+  see `mrtr/` for the full security walk-through.
 - `server.py` `refund_scope` — the no-round-trip fast path: a one-line order
   returns `Scope(full=True)` directly; only a multi-line order returns
   `Elicit(...)`. The ORD-7001 call completes with zero elicitations.
diff --git a/examples/stories/refund_desk/server.py b/examples/stories/refund_desk/server.py
index f29a266f0b..a263b93850 100644
--- a/examples/stories/refund_desk/server.py
+++ b/examples/stories/refund_desk/server.py
@@ -103,6 +103,8 @@ def ask_restock(
 
 
 def build_server() -> MCPServer:
+    # Elicited answers ride between rounds in a requestState the SDK seals by default;
+    # see mrtr/ for the full security walk-through.
     mcp = MCPServer("refund-desk")
 
     @mcp.tool(description="Refund an order. The amount comes from the order record, not from the caller.")
diff --git a/src/mcp-types/mcp_types/methods.py b/src/mcp-types/mcp_types/methods.py
index f49c158d92..37e1145386 100644
--- a/src/mcp-types/mcp_types/methods.py
+++ b/src/mcp-types/mcp_types/methods.py
@@ -13,7 +13,7 @@
 from collections.abc import Mapping
 from functools import cache
 from types import MappingProxyType, UnionType
-from typing import Any, Final, Literal, TypeVar, get_args
+from typing import Any, Final, Literal, TypeGuard, TypeVar, cast, get_args
 
 from pydantic import BaseModel, TypeAdapter
 
@@ -28,6 +28,7 @@
     "CLIENT_REQUESTS",
     "CLIENT_RESULTS",
     "CacheableMethod",
+    "INPUT_REQUIRED_METHODS",
     "MONOLITH_NOTIFICATIONS",
     "MONOLITH_REQUESTS",
     "MONOLITH_RESULTS",
@@ -36,6 +37,7 @@
     "SERVER_RESULTS",
     "SPEC_CLIENT_METHODS",
     "SPEC_CLIENT_NOTIFICATION_METHODS",
+    "is_input_required",
     "parse_client_notification",
     "parse_client_request",
     "parse_client_result",
@@ -423,6 +425,22 @@
 )
 """Runtime mirror of `CacheableMethod`, derived from `MONOLITH_RESULTS`."""
 
+INPUT_REQUIRED_METHODS: Final[frozenset[str]] = frozenset(
+    method
+    for method, row in MONOLITH_RESULTS.items()
+    if any(
+        issubclass(arm, types.InputRequiredResult) for arm in (get_args(row) if isinstance(row, UnionType) else (row,))
+    )
+)
+"""Methods whose results may be `InputRequiredResult`, derived from `MONOLITH_RESULTS`."""
+
+
+def is_input_required(result: object) -> TypeGuard[types.InputRequiredResult | dict[str, Any]]:
+    """True when `result` is an `input_required` interim result, typed or wire-shaped."""
+    if isinstance(result, types.InputRequiredResult):
+        return True
+    return isinstance(result, Mapping) and cast("Mapping[str, Any]", result).get("resultType") == "input_required"
+
 
 # --- Parse functions ---
 
diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py
index 638ea63a9d..c2db891ca6 100644
--- a/src/mcp/client/client.py
+++ b/src/mcp/client/client.py
@@ -609,7 +609,9 @@ async def call_tool(
         callbacks and the call is retried automatically (up to
         `input_required_max_rounds`). To drive the loop yourself — e.g. to
         persist `request_state` across process restarts — use
-        `client.session.call_tool(..., allow_input_required=True)`.
+        `client.session.call_tool(..., allow_input_required=True)`. Persisted
+        state is still subject to the server's TTL, request binding, and key
+        lifetime; a server on the default process-local key rejects it after a restart.
 
         Args:
             name: The name of the tool to call.
diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py
index ba66e94226..29413abf2b 100644
--- a/src/mcp/server/auth/middleware/bearer_auth.py
+++ b/src/mcp/server/auth/middleware/bearer_auth.py
@@ -7,7 +7,7 @@
 from starlette.requests import HTTPConnection
 from starlette.types import Receive, Scope, Send
 
-from mcp.server.auth.provider import AccessToken, TokenVerifier
+from mcp.server.auth.provider import AccessToken, TokenVerifier, principal_components
 
 
 class AuthenticatedUser(SimpleUser):
@@ -34,13 +34,8 @@ def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
 
     See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for a verifier that
     populates `subject` and `claims` from an introspection response."""
-    token = user.access_token
-    issuer = (token.claims or {}).get("iss")
-    return AuthorizationContext(
-        client_id=token.client_id,
-        issuer=str(issuer) if issuer is not None else None,
-        subject=token.subject,
-    )
+    client_id, issuer, subject = principal_components(user.access_token)
+    return AuthorizationContext(client_id=client_id, issuer=issuer, subject=subject)
 
 
 class BearerAuthBackend(AuthenticationBackend):
diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py
index eeb371f1c2..644868f3e5 100644
--- a/src/mcp/server/auth/provider.py
+++ b/src/mcp/server/auth/provider.py
@@ -59,6 +59,17 @@ class AccessToken(BaseModel):
     claims: dict[str, Any] | None = None  # additional claims (e.g. `iss`, `act`)
 
 
+def principal_components(token: AccessToken) -> tuple[str, str | None, str | None]:
+    """The (client_id, issuer, subject) triple identifying the principal a token represents.
+
+    The single source for "who is this token's principal": session ownership and
+    request-state binding both build on it. Components the token verifier does
+    not supply are `None`, so comparisons degrade to the remaining components.
+    """
+    issuer = (token.claims or {}).get("iss")
+    return token.client_id, str(issuer) if issuer is not None else None, token.subject
+
+
 RegistrationErrorCode = Literal[
     "invalid_redirect_uri",
     "invalid_client_metadata",
diff --git a/src/mcp/server/mcpserver/__init__.py b/src/mcp/server/mcpserver/__init__.py
index 8ee6c4e4e2..0205df1920 100644
--- a/src/mcp/server/mcpserver/__init__.py
+++ b/src/mcp/server/mcpserver/__init__.py
@@ -3,6 +3,14 @@
 from mcp_types import Icon
 
 from mcp.server.extension import Extension, MethodBinding, ResourceBinding, ToolBinding
+from mcp.server.request_state import (
+    AESGCMRequestStateCodec,
+    InvalidRequestState,
+    RequestStateBoundary,
+    RequestStateCodec,
+    RequestStateSecurity,
+    authenticated_principal,
+)
 
 from .context import Context
 from .resolve import (
@@ -36,4 +44,10 @@
     "require_client_extension",
     "ResourceSecurity",
     "DEFAULT_RESOURCE_SECURITY",
+    "RequestStateSecurity",
+    "RequestStateCodec",
+    "RequestStateBoundary",
+    "AESGCMRequestStateCodec",
+    "InvalidRequestState",
+    "authenticated_principal",
 ]
diff --git a/src/mcp/server/mcpserver/resolve.py b/src/mcp/server/mcpserver/resolve.py
index 9ff8dfeed5..d752afc10c 100644
--- a/src/mcp/server/mcpserver/resolve.py
+++ b/src/mcp/server/mcpserver/resolve.py
@@ -28,7 +28,11 @@
 
 from __future__ import annotations
 
+import base64
+import hashlib
 import inspect
+import json
+import logging
 import types
 import typing
 from collections.abc import Callable, Hashable, Mapping
@@ -43,6 +47,7 @@
     ElicitRequestFormParams,
     ElicitResult,
     FormElicitationCapability,
+    InputRequest,
     InputRequests,
     InputRequiredResult,
     InputResponses,
@@ -61,6 +66,7 @@
 )
 from mcp.server.mcpserver.context import Context
 from mcp.server.mcpserver.exceptions import InvalidSignature, ToolError
+from mcp.server.request_state import compact_json
 from mcp.shared._callable_inspection import is_async_callable
 from mcp.shared.exceptions import MCPError
 
@@ -73,7 +79,9 @@
 # `InputRequiredResult` rather than as a standalone server-to-client request.
 # Pinned (not `LATEST_MODERN_VERSION`, which moves when newer revisions are added).
 _INPUT_REQUIRED_VERSION = "2026-07-28"
-_STATE_VERSION = 1
+_STATE_VERSION = 3  # v3: recorded and pended outcomes pinned to ASCII-canonical question renders
+
+logger = logging.getLogger(__name__)
 
 
 class Resolve:
@@ -369,7 +377,11 @@ def __init__(
         self.context = context
         self.input_required = input_required
         self.answers: InputResponses = context.input_responses or {} if input_required else {}
-        self.state = _decode_state(context.request_state) if input_required else {}
+        decoded = _decode_state(context.request_state if input_required else None)
+        self.state = decoded.outcomes
+        # Digests of the questions asked last round: an answer is accepted only
+        # for the exact rendering the client was shown.
+        self.asked = decoded.asked
         # In-call dedup keyed by resolver identity (distinguishes two instances of
         # the same bound method); `persist` holds the wire-shaped record of each
         # elicited outcome, keyed by its wire key - exactly what the next round's
@@ -431,7 +443,8 @@ async def resolve_arguments(
         injected[name] = outcome if wants_union else _unwrap(outcome, name)
 
     if res.pending:
-        return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.persist))
+        asked = {key: _request_digest(request) for key, request in res.pending.items()}
+        return InputRequiredResult(input_requests=res.pending, request_state=_encode_state(res.persist, asked))
 
     return injected
 
@@ -494,19 +507,25 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio
     if not res.input_required:
         return await res.context.elicit(elicit.message, elicit.schema)
 
+    request = _elicit_request(elicit)
+    q = _request_digest(request)
+
     # A recorded outcome from a prior round is consulted only here, after the body
     # decided to ask, so a `request_state` entry can never stand in for a resolver's
-    # own computation. Re-validate it against the live `Elicit.schema`. A recorded
-    # outcome wins over a re-sent answer; an invalid entry self-deletes and falls
-    # through to the fresh answer (or to re-asking).
-    outcome = _restore_outcome(res, key, elicit.schema)
+    # own computation. A recorded outcome wins over a re-sent answer.
+    outcome = _restore_outcome(res, key, elicit.schema, q)
     if outcome is not None:
         return outcome
 
     answer = res.answers.get(key)
+    # An answer counts only for the rendering recorded when it was asked; an answer to
+    # an unrecorded or differently-worded question re-asks instead of being consumed.
+    if answer is not None and res.asked.get(key) != q:
+        logger.info("Discarding the answer for resolver %r: the question changed since it was asked", key)
+        answer = None
     if answer is None:
         _require_form_elicitation(res.context, key)
-        res.pending[key] = _elicit_request(elicit)
+        res.pending[key] = request
         raise _Pending
     if not isinstance(answer, ElicitResult):
         raise ToolError(f"Resolver {key!r} received a non-elicitation response")
@@ -521,12 +540,12 @@ async def _elicit(elicit: Elicit[Any], key: str, res: _Resolution) -> Elicitatio
             ) from e
         # Persist the exact wire content that just passed validation - never the
         # model - so restoring next round revalidates the same bytes the client sent.
-        res.persist[key] = _StateEntry(action="accept", data=answer.content)
+        res.persist[key] = _StateEntry(action="accept", data=answer.content, q=q)
         return AcceptedElicitation(data=data)
     if answer.action == "decline":
-        res.persist[key] = _StateEntry(action="decline")
+        res.persist[key] = _StateEntry(action="decline", q=q)
         return DeclinedElicitation()
-    res.persist[key] = _StateEntry(action="cancel")
+    res.persist[key] = _StateEntry(action="cancel", q=q)
     return CancelledElicitation()
 
 
@@ -595,37 +614,58 @@ class _StateEntry(BaseModel):
 
     action: Literal["accept", "decline", "cancel"]
     data: Any = None
+    q: str | None = None
+    """Digest of the exact rendered question this outcome answered."""
+
+
+def _request_digest(request: InputRequest) -> str:
+    """Pin an outcome to the exact rendered question the client was shown.
+
+    A redeploy that rewords or reshapes a question re-asks it instead of reusing the recorded answer.
+    """
+    params = request.params
+    rendered = compact_json(params.model_dump(mode="json", by_alias=True, exclude_none=True) if params else None)
+    digest = hashlib.sha256(rendered.encode()).digest()[:16]
+    return base64.urlsafe_b64encode(digest).decode().rstrip("=")
 
 
 class _State(BaseModel):
-    """The decoded `request_state`: resolver outcomes from earlier rounds."""
+    """The decoded `request_state`: resolver progress from earlier rounds."""
 
     v: int
     outcomes: dict[str, _StateEntry] = {}
+    asked: dict[str, str] = {}
+    """Question digest of each elicitation asked last round, keyed by wire key."""
 
 
-def _decode_state(request_state: str | None) -> dict[str, _StateEntry]:
+def _decode_state(request_state: str | None) -> _State:
     """Decode the per-call resolution progress from `request_state`.
 
-    `request_state` is client-trusted (integrity sealing is a follow-up); validate
-    it through `_State` and treat anything malformed as "no progress yet".
+    Parsed with stdlib `json.loads` because `_encode_state` may emit escaped
+    lone surrogates, which pydantic's JSON parser rejects. The string arrives
+    boundary-authenticated, so malformed content or a version mismatch is
+    drift within the operator's own fleet (e.g. a rolling upgrade) and is
+    treated as "no progress yet".
     """
+    empty = _State(v=_STATE_VERSION)
     if not request_state:
-        return {}
+        return empty
     try:
-        state = _State.model_validate_json(request_state)
-    except ValidationError:
-        return {}
-    return state.outcomes if state.v == _STATE_VERSION else {}
+        state = _State.model_validate(json.loads(request_state))
+    except ValueError:
+        return empty
+    return state if state.v == _STATE_VERSION else empty
 
 
-def _encode_state(outcomes: Mapping[str, _StateEntry]) -> str:
-    """Encode recorded elicitation outcomes (keyed by wire key) for the next round.
+def _encode_state(outcomes: Mapping[str, _StateEntry], asked: Mapping[str, str]) -> str:
+    """Encode recorded outcomes and asked-question digests for the next round.
 
-    Entries already hold the client's wire-shaped data exactly as it was sent (and
-    validated), so encoding is pure wrapping: encode-restore is the identity.
+    Outcome entries already hold the client's wire-shaped data exactly as it was
+    sent (and validated), so encoding is pure wrapping: encode-restore is the
+    identity.
     """
-    return _State(v=_STATE_VERSION, outcomes=dict(outcomes)).model_dump_json()
+    state = _State(v=_STATE_VERSION, outcomes=dict(outcomes), asked=dict(asked))
+    return compact_json(state.model_dump(mode="json"))
 
 
 def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel]) -> ElicitationResult[Any]:
@@ -642,12 +682,12 @@ def _outcome_from_state(entry: _StateEntry, schema: type[BaseModel]) -> Elicitat
     return _accepted(schema.model_validate(entry.data))
 
 
-def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel]) -> ElicitationResult[Any] | None:
+def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel], q: str) -> ElicitationResult[Any] | None:
     """Restore `key`'s recorded outcome from a prior round, or `None` when absent.
 
-    `request_state` is client-trusted, so an entry whose data fails validation gets
-    the `_decode_state` treatment - dropped as if no progress was recorded, so the
-    question is asked again - rather than surfacing a validation error.
+    An entry pinned to a question digest other than `q`, or whose accepted
+    data fails validation against the live `schema`, is dropped as if no
+    progress was recorded, so the question is asked again.
 
     Carries the original decoded entry forward unchanged in `res.persist`: if a
     later resolver is still pending, the next round's `request_state` is built from
@@ -657,6 +697,9 @@ def _restore_outcome(res: _Resolution, key: str, schema: type[BaseModel]) -> Eli
     entry = res.state.get(key)
     if entry is None:
         return None
+    if entry.q != q:
+        del res.state[key]
+        return None
     try:
         outcome = _outcome_from_state(entry, schema)
     except ValidationError:
diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py
index 6764709806..3750429cdc 100644
--- a/src/mcp/server/mcpserver/server.py
+++ b/src/mcp/server/mcpserver/server.py
@@ -83,6 +83,7 @@
 from mcp.server.mcpserver.tools import Tool, ToolManager
 from mcp.server.mcpserver.utilities.context_injection import find_context_parameter
 from mcp.server.mcpserver.utilities.logging import configure_logging, get_logger
+from mcp.server.request_state import RequestStateBoundary, RequestStateSecurity
 from mcp.server.sse import SseServerTransport
 from mcp.server.stdio import stdio_server
 from mcp.server.streamable_http import EventStore
@@ -133,6 +134,15 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
     auth: AuthSettings | None
 
 
+_MISSING_AUDIENCE = (
+    "request_state_security is configured but this server has no name. Sealed\n"
+    "requestState carries the server name as an audience claim, so state minted by\n"
+    "another service that shares the same keys is rejected; unnamed servers would\n"
+    "all stamp the same placeholder and the check would mean nothing. Name the\n"
+    'server (MCPServer("my-service", ...)) or set RequestStateSecurity(audience=...).'
+)
+
+
 def lifespan_wrapper(
     app: MCPServer[LifespanResultT],
     lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]],
@@ -170,6 +180,7 @@ def __init__(
         lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None,
         auth: AuthSettings | None = None,
         resource_security: ResourceSecurity = DEFAULT_RESOURCE_SECURITY,
+        request_state_security: RequestStateSecurity | None = None,
         cache_hints: Mapping[CacheableMethod, CacheHint] | None = None,
     ):
         self._resource_security = resource_security
@@ -210,6 +221,17 @@ def __init__(
             # We need to create a Lifespan type that is a generic on the server type, like Starlette does.
             lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),  # type: ignore
         )
+        # Ordering: inside OpenTelemetry (spans record the sealed wire form),
+        # outside extension interceptors (extensions see plaintext).
+        if request_state_security is None:
+            security = RequestStateSecurity.ephemeral()
+        else:
+            # A supplied policy usually means shared keys, where the audience claim is
+            # what separates services; an unnamed server would stamp the placeholder.
+            if not name and request_state_security.audience is None:
+                raise ValueError(_MISSING_AUDIENCE)
+            security = request_state_security
+        self._lowlevel_server.middleware.append(RequestStateBoundary(security, default_audience=self.name))
         # Validate auth configuration
         if self.settings.auth is not None:
             if auth_server_provider and token_verifier:  # pragma: no cover
diff --git a/src/mcp/server/request_state.py b/src/mcp/server/request_state.py
new file mode 100644
index 0000000000..ad1abe8c36
--- /dev/null
+++ b/src/mcp/server/request_state.py
@@ -0,0 +1,454 @@
+"""Integrity protection for the multi-round-trip `requestState` (MCP 2026-07-28).
+
+The spec requires servers to treat the client-echoed `requestState` as
+attacker-controlled: `RequestStateBoundary` seals every outgoing value and
+verifies every inbound echo, so handlers only ever see plaintext they minted.
+"""
+
+from __future__ import annotations
+
+import base64
+import hashlib
+import hmac
+import json
+import logging
+import math
+import os
+import time
+from collections.abc import Callable, Mapping, Sequence
+from dataclasses import replace
+from typing import Any, NoReturn, Protocol, cast
+
+from cryptography.exceptions import InvalidTag
+from cryptography.hazmat.primitives.ciphers.aead import AESGCM
+from cryptography.hazmat.primitives.hashes import SHA256
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from mcp_types import INTERNAL_ERROR, INVALID_PARAMS
+from mcp_types.methods import INPUT_REQUIRED_METHODS, is_input_required
+
+from mcp.server.auth.middleware.auth_context import get_access_token
+from mcp.server.auth.provider import principal_components
+from mcp.server.context import CallNext, HandlerResult, ServerRequestContext
+from mcp.shared.exceptions import MCPError
+
+__all__ = [
+    "AESGCMRequestStateCodec",
+    "InvalidRequestState",
+    "RequestStateBoundary",
+    "RequestStateCodec",
+    "RequestStateSecurity",
+    "authenticated_principal",
+]
+
+logger = logging.getLogger(__name__)
+
+
+class InvalidRequestState(Exception):
+    """A sealed `requestState` token failed verification.
+
+    The message is a log-only reason code; the boundary never puts it on the wire.
+    """
+
+
+class RequestStateCodec(Protocol):
+    """Authenticated crypto over the framework's request-state envelope.
+
+    The framework stamps and re-verifies every envelope claim (expiry, request
+    binding, principal); a codec only provides integrity and, ideally,
+    confidentiality (a sign-only codec leaves the payload client-readable).
+
+    Requirements: `unseal(seal(payload))` round-trips, and `unseal` raises
+    `InvalidRequestState` for any token it did not mint unmodified; tokens
+    never name their algorithm (version with a format prefix bound under the
+    authentication tag, RFC 8725); comparisons are constant-time. Both methods
+    are synchronous, so cache key material rather than calling a KMS per token.
+    """
+
+    def seal(self, payload: bytes) -> str:
+        """Return an opaque URL-safe token protecting `payload`."""
+        ...
+
+    def unseal(self, token: str) -> bytes:
+        """Reverse `seal`.
+
+        Raises:
+            InvalidRequestState: Malformed, unauthentic, or unknown-key token.
+        """
+        ...
+
+
+def authenticated_principal(ctx: ServerRequestContext[Any, Any]) -> str | None:
+    """Default principal binding: the authenticated (client, issuer, subject) identity.
+
+    Uses the same components session ownership uses, so two users of one OAuth
+    client are distinct principals whenever the token verifier supplies a
+    subject, and the binding degrades to the client identity when it does not.
+    Returns `None` (state not principal-bound) on unauthenticated transports.
+    """
+    token = get_access_token()
+    if token is None:
+        return None
+    return compact_json(principal_components(token))
+
+
+class RequestStateSecurity:
+    """Policy for protecting `requestState`: codec, TTL, principal, audience.
+
+    Exactly one of `keys` or `codec`:
+
+        RequestStateSecurity(keys=[secret])  # built-in AES-256-GCM
+        RequestStateSecurity(codec=MyKmsCodec())  # bring your own crypto
+        RequestStateSecurity.ephemeral()  # process-local key
+
+    `keys` is the rotation ring: `keys[0]` seals, every key unseals.
+    Zero-downtime rotation, each phase fully rolled out before the next:
+    `keys=[old, new]`, then `keys=[new, old]`, then `keys=[new]` after one TTL.
+
+    The boundary enforces expiry, request binding, audience, and principal for
+    every codec, fail-closed in both directions. `audience=None` defers to the
+    boundary's `default_audience` (`MCPServer` passes its server name).
+    """
+
+    codec: RequestStateCodec
+    ttl: float
+    bind_principal: Callable[[ServerRequestContext[Any, Any]], str | None] | None
+    audience: str | None
+
+    def __init__(
+        self,
+        *,
+        keys: Sequence[bytes | bytearray | str] | None = None,
+        codec: RequestStateCodec | None = None,
+        ttl: float = 600.0,
+        bind_principal: Callable[[ServerRequestContext[Any, Any]], str | None] | None = authenticated_principal,
+        audience: str | None = None,
+    ) -> None:
+        if (keys is None) == (codec is None):
+            raise ValueError("RequestStateSecurity takes exactly one of keys= or codec=")
+        if not (math.isfinite(ttl) and ttl > 0):
+            raise ValueError(f"request-state ttl must be a positive finite number, got {ttl!r}")
+        if keys is not None:
+            self.codec = AESGCMRequestStateCodec(keys)
+        else:
+            assert codec is not None
+            self.codec = codec
+        self.ttl = ttl
+        self.bind_principal = bind_principal
+        self.audience = audience
+
+    @classmethod
+    def ephemeral(cls, *, ttl: float = 600.0, audience: str | None = None) -> RequestStateSecurity:
+        """Protection under a key generated now and held only by this process.
+
+        This is the policy `MCPServer` installs when `request_state_security=`
+        is omitted; call it yourself on the lowlevel tier or to set `ttl`/
+        `audience`. Suits single-process deployments (stdio, one HTTP worker):
+        state minted before a restart or by another worker is rejected.
+        Multi-instance deployments must share a key via `keys=[...]`.
+        """
+        return cls(keys=[os.urandom(32)], ttl=ttl, audience=audience)
+
+
+_KDF_INFO = b"mcp/request-state/v1/aes-256-gcm"
+_KID_INFO = b"mcp/request-state/v1/kid:"
+_TOKEN_PREFIX = "v1."
+_KID_LEN = 4
+_NONCE_LEN = 12
+
+
+def compact_json(value: Any, *, sort_keys: bool = False) -> str:
+    """Canonical JSON for everything the state path digests or seals.
+
+    ASCII output keeps the encode total: a lone surrogate in client-supplied
+    text escapes instead of raising. Anything consuming this must parse with
+    stdlib `json.loads`, which accepts those escapes (pydantic's JSON parser
+    does not).
+    """
+    return json.dumps(value, sort_keys=sort_keys, separators=(",", ":"))
+
+
+def _b64u(data: bytes) -> str:
+    return base64.urlsafe_b64encode(data).decode().rstrip("=")
+
+
+def _b64u_decode(text: str) -> bytes:
+    """Strict inverse of `_b64u`: only the canonical unpadded encoding decodes."""
+    raw = base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))
+    if _b64u(raw) != text:
+        raise ValueError("non-canonical base64url")
+    return raw
+
+
+def _derive_key(secret: bytes) -> bytes:
+    """Stretch an operator secret (>= 32 bytes, any format) into the AES-256 key."""
+    return HKDF(algorithm=SHA256(), length=32, salt=None, info=_KDF_INFO).derive(secret)
+
+
+class AESGCMRequestStateCodec:
+    """Built-in codec: AES-256-GCM under key(s) derived with HKDF-SHA256.
+
+    Tokens are encrypted, not merely signed, so clients cannot read the state.
+    `keys[0]` seals; all keys unseal (rotation, see `RequestStateSecurity`).
+    Each token carries a 4-byte non-secret key fingerprint for an O(1) ring
+    lookup, and the "v1." prefix and fingerprint are bound into the GCM
+    associated data, so a token cannot be replayed into another format version
+    or ring slot. Key bytes are copied at construction.
+    """
+
+    def __init__(self, keys: Sequence[bytes | bytearray | str]) -> None:
+        for i, key in enumerate(cast("Sequence[object]", keys)):
+            if not isinstance(key, bytes | bytearray | str):
+                # Never coerce: bytes(32) would silently build an all-zero key.
+                raise TypeError(
+                    f"request-state keys must be bytes, bytearray, or str; keys[{i}] is {type(key).__name__}"
+                )
+        material = [k.encode() if isinstance(k, str) else bytes(k) for k in keys]
+        if not material:
+            raise ValueError("AESGCMRequestStateCodec requires at least one key")
+        for i, k in enumerate(material):
+            if len(k) < 32:
+                raise ValueError(
+                    f"request-state keys must be at least 32 bytes of secret randomness; "
+                    f"keys[{i}] is {len(k)} bytes. "
+                    'Generate one with: python -c "import secrets; print(secrets.token_hex(32))"'
+                )
+        self._ring: dict[bytes, AESGCM] = {}
+        self._mint_kid = b""
+        for i, secret in enumerate(material):
+            key = _derive_key(secret)
+            kid = hashlib.sha256(_KID_INFO + key).digest()[:_KID_LEN]
+            if kid in self._ring:
+                raise ValueError(f"keys[{i}] duplicates an earlier ring key")
+            self._ring[kid] = AESGCM(key)
+            if i == 0:
+                self._mint_kid = kid
+
+    def seal(self, payload: bytes) -> str:
+        kid = self._mint_kid
+        nonce = os.urandom(_NONCE_LEN)
+        sealed = self._ring[kid].encrypt(nonce, payload, _TOKEN_PREFIX.encode() + kid)
+        return _TOKEN_PREFIX + _b64u(kid + nonce + sealed)
+
+    def unseal(self, token: str) -> bytes:
+        if not token.startswith(_TOKEN_PREFIX):
+            raise InvalidRequestState("malformed")
+        try:
+            raw = _b64u_decode(token[len(_TOKEN_PREFIX) :])
+        except ValueError as exc:
+            raise InvalidRequestState("malformed") from exc
+        if len(raw) < _KID_LEN + _NONCE_LEN + 16:
+            raise InvalidRequestState("malformed")
+        kid, nonce, sealed = raw[:_KID_LEN], raw[_KID_LEN : _KID_LEN + _NONCE_LEN], raw[_KID_LEN + _NONCE_LEN :]
+        aead = self._ring.get(kid)
+        if aead is None:
+            raise InvalidRequestState("unknown key")
+        try:
+            return aead.decrypt(nonce, sealed, _TOKEN_PREFIX.encode() + kid)
+        except InvalidTag:
+            raise InvalidRequestState("seal") from None
+
+
+# The multi-round-trip carriers: the only methods whose results may carry `requestState`.
+_MRTR_METHODS = INPUT_REQUIRED_METHODS
+_ENVELOPE_VERSION = 1
+_FUTURE_SKEW = 60.0
+_PRINCIPAL_LABEL = b"mcp/request-state/principal:"
+
+_RoundBinding = tuple[str, str, str | None]
+"""The (target, args-digest, principal) one round's envelope binds, computed once per round."""
+
+
+def _reject(method: str, reason: str) -> NoReturn:
+    """Refuse a round: frozen wire error, real reason to the server log only."""
+    logger.warning("requestState rejected on %s: %s", method, reason)
+    raise MCPError(
+        code=INVALID_PARAMS,
+        message="Invalid or expired requestState",
+        data={"reason": "invalid_request_state"},
+    )
+
+
+def _request_identity(method: str, params: Mapping[str, Any] | None) -> tuple[str, str]:
+    """Salient (target, args-digest) for the request a token binds to.
+
+    Per-method allowlist, never a denylist: a future wire field cannot silently join the digest.
+    """
+    p: Mapping[str, Any] = params or {}
+    args: dict[str, Any] = {}
+    if method == "resources/read":
+        target = str(p.get("uri", ""))
+    else:
+        target, args = str(p.get("name", "")), p.get("arguments") or args
+    return target, _b64u(hashlib.sha256(compact_json(args, sort_keys=True).encode()).digest()[:16])
+
+
+def _principal_claim(principal: str) -> str:
+    salt = os.urandom(8)
+    tag = hashlib.sha256(_PRINCIPAL_LABEL + salt + _principal_bytes(principal)).digest()[:16]
+    return _b64u(salt + tag)
+
+
+def _principal_matches(claim: str, principal: str) -> bool:
+    try:
+        raw = _b64u_decode(claim)
+    except ValueError:
+        return False
+    # A wrong-length claim never matches: compare_digest handles mismatched sizes.
+    expected = hashlib.sha256(_PRINCIPAL_LABEL + raw[:8] + _principal_bytes(principal)).digest()[:16]
+    return hmac.compare_digest(raw[8:], expected)
+
+
+def _principal_bytes(principal: str) -> bytes:
+    # The digest input is one-way and never decoded, so surrogatepass keeps it total.
+    return principal.encode("utf-8", "surrogatepass")
+
+
+def _bound_principal(
+    security: RequestStateSecurity,
+    ctx: ServerRequestContext[Any, Any],
+    fail: Callable[[str], NoReturn],
+) -> str | None:
+    """Run `bind_principal` under the deny-on-error discipline, in one place for both directions.
+
+    `fail` converts a failure into the calling direction's wire shape: the
+    frozen rejection when verifying, the sanitized internal error when sealing.
+    """
+    try:
+        principal = security.bind_principal(ctx) if security.bind_principal is not None else None
+    except Exception:  # deny-on-error: a raising principal binding must fail closed
+        logger.exception("bind_principal raised while processing requestState on %s", ctx.method)
+        fail("principal binding error")
+    # The declared return type is str | None, but a user callback can ignore it.
+    if principal is not None and not isinstance(cast("object", principal), str):
+        fail(f"bind_principal returned {type(principal).__name__}, expected str or None")
+    return principal
+
+
+class RequestStateBoundary:
+    """Server middleware sealing/unsealing `requestState` at the wire boundary.
+
+    Acts only on the multi-round-trip carriers (tools/call, prompts/get,
+    resources/read); every other method passes through untouched.
+
+    Inbound state is verified (codec unseal plus claims check) and replaced
+    with the plaintext the server minted before any interceptor or handler
+    runs; failure answers -32602 with the frozen message "Invalid or expired
+    requestState", the real reason going to the server log only. Outbound, an
+    `input_required` result carrying `requestState` is sealed in a fresh
+    claims envelope; handlers and resolvers never call the codec.
+
+    `default_audience` seeds the audience claim when the policy sets none, and
+    must be stated explicitly: it is the service identity that stops state
+    minted by another service sharing the same keys. `MCPServer` installs this
+    middleware with its server name by default (under an ephemeral policy
+    unless `request_state_security=` supplies one); lowlevel `Server` users
+    append one to `server.middleware`, passing their server's name (or `None`
+    to deliberately leave tokens audience-free).
+    """
+
+    def __init__(self, security: RequestStateSecurity, *, default_audience: str | None) -> None:
+        self._security = security
+        self._audience = security.audience if security.audience is not None else default_audience
+
+    async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult:
+        if ctx.method not in _MRTR_METHODS:
+            return await call_next(ctx)
+        binding: _RoundBinding | None = None
+        if ctx.params is not None and ctx.params.get("requestState") is not None:
+            # An explicit JSON null counts as absent: stripping the field is already in any client's power.
+            plaintext, binding = self._unseal(ctx)
+            ctx = replace(ctx, params={**ctx.params, "requestState": plaintext})
+        result = await call_next(ctx)
+        return self._seal_result(ctx, result, binding)
+
+    def _unseal(self, ctx: ServerRequestContext[Any, Any]) -> tuple[str, _RoundBinding]:
+        assert ctx.params is not None
+        wire = ctx.params["requestState"]
+        if not isinstance(wire, str):
+            _reject(ctx.method, "non-string requestState")
+        security = self._security
+        try:
+            payload = security.codec.unseal(wire)
+        except InvalidRequestState as exc:
+            _reject(ctx.method, str(exc))
+        except Exception:  # deny-on-error: a buggy custom codec must fail closed
+            logger.exception("requestState codec raised during unseal on %s", ctx.method)
+            _reject(ctx.method, "codec error")
+        try:
+            claims = json.loads(payload)
+            version, iat, exp, inner = claims["v"], claims["iat"], claims["exp"], claims["s"]
+        except (ValueError, KeyError, TypeError):
+            _reject(ctx.method, "malformed")
+        if version != _ENVELOPE_VERSION or not isinstance(inner, str):
+            _reject(ctx.method, "malformed")
+        now = time.time()
+        # Accept-conditions are stated positively so a NaN claim fails the comparison and rejects.
+        if not isinstance(iat, int | float) or not (iat <= now + _FUTURE_SKEW):
+            _reject(ctx.method, "minted in the future")
+        if not isinstance(exp, int | float) or not (now < exp):
+            _reject(ctx.method, "expired")
+        target, args_digest = _request_identity(ctx.method, ctx.params)
+        if claims.get("m") != ctx.method or claims.get("t") != target or claims.get("a") != args_digest:
+            _reject(ctx.method, "request binding")
+        if claims.get("aud") != self._audience:
+            _reject(ctx.method, "audience")
+
+        def fail_verify(reason: str) -> NoReturn:
+            _reject(ctx.method, reason)
+
+        principal = _bound_principal(security, ctx, fail_verify)
+        claim = claims.get("p")
+        if (claim is None) != (principal is None):
+            _reject(ctx.method, "principal drift")
+        if claim is not None and principal is not None:
+            if not isinstance(claim, str) or not _principal_matches(claim, principal):
+                _reject(ctx.method, "principal")
+        return inner, (target, args_digest, principal)
+
+    def _seal_result(
+        self, ctx: ServerRequestContext[Any, Any], result: HandlerResult, binding: _RoundBinding | None
+    ) -> HandlerResult:
+        # Spec-path results arrive as wire mappings; a short-circuiting middleware may return a model.
+        if not is_input_required(result):
+            return result
+        state = result.get("requestState") if isinstance(result, Mapping) else result.request_state
+        if state is None:
+            return result
+        if isinstance(result, Mapping):
+            if not isinstance(state, str):
+                # Only a short-circuiting middleware can put a non-string here; nothing to seal.
+ return result + return {**result, "requestState": self._seal(ctx, state, binding)} + return result.model_copy(update={"request_state": self._seal(ctx, state, binding)}) + + def _seal(self, ctx: ServerRequestContext[Any, Any], state: str, binding: _RoundBinding | None = None) -> str: + security = self._security + if binding is None: + + def fail_seal(reason: str) -> NoReturn: + logger.error("refusing to seal requestState on %s: %s", ctx.method, reason) + raise MCPError(code=INTERNAL_ERROR, message="Internal error") + + target, args_digest = _request_identity(ctx.method, ctx.params) + binding = (target, args_digest, _bound_principal(security, ctx, fail_seal)) + target, args_digest, principal = binding + now = time.time() + claims: dict[str, Any] = { + "v": _ENVELOPE_VERSION, + "iat": now, + "exp": now + security.ttl, + "m": ctx.method, + "t": target, + "a": args_digest, + "s": state, + } + if self._audience is not None: + claims["aud"] = self._audience + if principal is not None: + claims["p"] = _principal_claim(principal) + payload = compact_json(claims).encode() + try: + return security.codec.seal(payload) + except Exception: # deny-on-error: a raising custom codec must not leak its failure + logger.exception("requestState codec raised during seal on %s", ctx.method) + raise MCPError(code=INTERNAL_ERROR, message="Internal error") from None diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 6773fd4de8..6aa9cd6d5c 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -204,7 +204,7 @@ async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult: if (hint := self.server.cache_hints.get(method)) is not None: if isinstance(result, CacheableResult): result = apply_cache_hint(result, hint) - elif isinstance(result, Mapping) and result.get("resultType") != "input_required": + elif isinstance(result, Mapping) and not _methods.is_input_required(result): # Hint keys first so wire keys the handler set win, matching 
`apply_cache_hint` precedence. result = {"ttlMs": hint.ttl_ms, "cacheScope": hint.scope, **result} # Dump and serialize inside the chain so the OpenTelemetry span (the diff --git a/tests/docs_src/test_mrtr.py b/tests/docs_src/test_mrtr.py index 110bd8f781..cf7842b0af 100644 --- a/tests/docs_src/test_mrtr.py +++ b/tests/docs_src/test_mrtr.py @@ -18,9 +18,10 @@ TextContent, ) -from docs_src.mrtr import tutorial001, tutorial002, tutorial003, tutorial004 +from docs_src.mrtr import tutorial001, tutorial002, tutorial003, tutorial004, tutorial005 from mcp import Client, MCPError from mcp.client import ClientRequestContext +from mcp.server.mcpserver import InvalidRequestState # See test_index.py for why this is a per-module mark and not a conftest hook. pytestmark = [pytest.mark.anyio, pytest.mark.filterwarnings("error::mcp.MCPDeprecationWarning")] @@ -161,3 +162,36 @@ async def test_the_prompt_auto_loop_returns_the_final_messages() -> None: ], ) ) + + +def test_a_custom_codec_round_trips_what_it_sealed() -> None: + """tutorial005: `unseal(seal(payload))` returns the payload; the token itself is opaque hex.""" + codec = tutorial005.EnvelopeCodec(tutorial005.unwrap_data_key()) + token = codec.seal(b"round-1") + assert token.startswith(tutorial005.PREFIX) + assert b"round-1" not in token.encode() + assert codec.unseal(token) == b"round-1" + + +def test_a_custom_codec_raises_invalid_request_state_for_any_bad_token() -> None: + """tutorial005: any token the codec did not mint intact raises `InvalidRequestState`.""" + codec = tutorial005.EnvelopeCodec(tutorial005.unwrap_data_key()) + token = codec.seal(b"round-1") + with pytest.raises(InvalidRequestState): + codec.unseal(token + "00") + with pytest.raises(InvalidRequestState): + codec.unseal("not-a-token") + + +def test_a_custom_codec_rejects_every_alias_of_a_minted_token() -> None: + """tutorial005: only the exact minted string verifies; rewritten spellings of it do not.""" + codec = 
tutorial005.EnvelopeCodec(tutorial005.unwrap_data_key()) + token = codec.seal(b"round-1") + body = token.removeprefix(tutorial005.PREFIX) + for alias in ( + body, # prefix stripped + tutorial005.PREFIX + body.upper(), # non-canonical hex case + tutorial005.PREFIX + body[:8] + " " + body[8:], # whitespace bytes.fromhex would skip + ): + with pytest.raises(InvalidRequestState): + codec.unseal(alias) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index bd14e294c2..6ab3436771 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -9,8 +9,18 @@ from starlette.requests import Request from starlette.types import Message, Receive, Scope, Send -from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, BearerAuthBackend, RequireAuthMiddleware -from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, ProviderTokenVerifier +from mcp.server.auth.middleware.bearer_auth import ( + AuthenticatedUser, + BearerAuthBackend, + RequireAuthMiddleware, + authorization_context, +) +from mcp.server.auth.provider import ( + AccessToken, + OAuthAuthorizationServerProvider, + ProviderTokenVerifier, + principal_components, +) class MockOAuthProvider: @@ -446,3 +456,16 @@ async def send(message: Message) -> None: # pragma: no cover assert app.scope == scope assert app.receive == receive assert app.send == send + + +def test_authorization_context_is_built_from_principal_components() -> None: + """Session ownership identifies the principal via the shared principal_components triple.""" + token = AccessToken( + token="t", client_id="client-1", scopes=[], subject="alice", claims={"iss": "https://as.example"} + ) + client_id, issuer, subject = principal_components(token) + assert authorization_context(AuthenticatedUser(token)) == { + "client_id": client_id, + "issuer": issuer, + "subject": subject, + } diff --git 
a/tests/server/auth/test_provider.py b/tests/server/auth/test_provider.py index aaaeb413a4..8c07d02acb 100644 --- a/tests/server/auth/test_provider.py +++ b/tests/server/auth/test_provider.py @@ -1,6 +1,6 @@ """Tests for mcp.server.auth.provider module.""" -from mcp.server.auth.provider import construct_redirect_uri +from mcp.server.auth.provider import AccessToken, construct_redirect_uri, principal_components def test_construct_redirect_uri_no_existing_params(): @@ -77,3 +77,14 @@ def test_construct_redirect_uri_encoded_values(): # urlencode uses + for spaces by default assert "state=test+state+with+spaces" in result + + +def test_principal_components_composes_client_issuer_subject(): + """The triple identifying a token's principal, degrading per missing component.""" + bare = AccessToken(token="t", client_id="client-1", scopes=[]) + assert principal_components(bare) == ("client-1", None, None) + + full = AccessToken( + token="t", client_id="client-1", scopes=[], subject="alice", claims={"iss": "https://as.example"} + ) + assert principal_components(full) == ("client-1", "https://as.example", "alice") diff --git a/tests/server/mcpserver/test_resolve.py b/tests/server/mcpserver/test_resolve.py index 571cefcb6c..c28f12481c 100644 --- a/tests/server/mcpserver/test_resolve.py +++ b/tests/server/mcpserver/test_resolve.py @@ -3,7 +3,7 @@ import json from collections.abc import Callable from datetime import datetime -from typing import Annotated, Any, Literal, TypeVar +from typing import Annotated, Any, Literal, TypeVar, cast import anyio import pytest @@ -23,22 +23,28 @@ from mcp import Client, InputRequiredRoundsExceededError from mcp.client import ClientRequestContext +from mcp.server.context import ServerRequestContext from mcp.server.mcpserver import ( AcceptedElicitation, + AESGCMRequestStateCodec, CancelledElicitation, Context, DeclinedElicitation, Elicit, ElicitationResult, MCPServer, + RequestStateBoundary, + RequestStateSecurity, Resolve, ) from 
mcp.server.mcpserver.exceptions import InvalidSignature from mcp.server.mcpserver.resolve import ( _check_elicit_return, _decode_state, + _elicit_request, _encode_state, _outcome_from_state, + _request_digest, _resolver_key, _state_key, _StateEntry, @@ -50,6 +56,11 @@ from mcp.shared.exceptions import MCPError +def _question_digest(elicit: Elicit[Any]) -> str: + """The digest `_elicit` pins: the rendered request the client would be shown.""" + return _request_digest(_elicit_request(elicit)) + + class Login(BaseModel): username: str @@ -126,9 +137,49 @@ def _answer_round( return responses +# Fixed key shared with servers under test, so tests can unseal minted wire +# state and seal crafted state the server will accept. +_PIN_KEY = b"0123456789abcdef0123456789abcdef" + + +def _unseal_inner(request_state: str | None) -> str: + """Unseal a wire `request_state` minted under `_PIN_KEY` into the inner plaintext state.""" + assert request_state is not None + claims = json.loads(AESGCMRequestStateCodec([_PIN_KEY]).unseal(request_state)) + inner = claims["s"] + assert isinstance(inner, str) + return inner + + +def _outcomes_on_the_wire(request_state: str | None) -> dict[str, Any]: + """Unseal a wire `request_state` minted under `_PIN_KEY` and return its outcomes.""" + return json.loads(_unseal_inner(request_state))["outcomes"] + + +def _sealed_state(inner: str, *, tool: str, args: dict[str, Any], audience: str) -> str: + """Seal a hand-built inner state exactly as the boundary does for a `tools/call` retry. + + The production `RequestStateBoundary._seal` binds method, tool, arguments, and + audience (the server name), so the test must then call exactly `tool` with + exactly `args` on the MCPServer named `audience`. 
+ """ + ctx = ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/call", + params={"name": tool, "arguments": args}, + ) + return RequestStateBoundary(RequestStateSecurity(keys=[_PIN_KEY]), default_audience=audience)._seal(ctx, inner) + + +def _wire_key(fn: Callable[..., Any]) -> str: + return f"{fn.__module__}:{fn.__qualname__}" + + @pytest.mark.anyio async def test_resolver_returns_value_directly_without_eliciting(): - mcp = MCPServer(name="Direct") + mcp = MCPServer(name="Direct", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Login | Elicit[Login]: username = (ctx.headers or {}).get("x-github-user") @@ -149,7 +200,7 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E @pytest.mark.anyio async def test_resolver_elicits_and_injects_unwrapped_model_on_accept(): - mcp = MCPServer(name="Accept") + mcp = MCPServer(name="Accept", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Login | Elicit[Login]: return Elicit("GitHub username?", Login) @@ -164,7 +215,7 @@ async def whoami(login: Annotated[Login, Resolve(login)]) -> str: @pytest.mark.anyio async def test_consumer_receives_result_union_and_branches(): - mcp = MCPServer(name="Union") + mcp = MCPServer(name="Union", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Login | Elicit[Login]: return Elicit("GitHub username?", Login) @@ -183,7 +234,7 @@ async def whoami(login: Annotated[ElicitationResult[Login], Resolve(login)]) -> @pytest.mark.anyio async def test_decline_reaches_union_consumer_without_aborting(): - mcp = MCPServer(name="UnionDecline") + mcp = MCPServer(name="UnionDecline", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Login | Elicit[Login]: return Elicit("GitHub username?", Login) @@ -202,7 +253,7 @@ async def whoami( @pytest.mark.anyio 
async def test_decline_aborts_when_consumer_wants_unwrapped(): - mcp = MCPServer(name="UnwrappedDecline") + mcp = MCPServer(name="UnwrappedDecline", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Login | Elicit[Login]: return Elicit("GitHub username?", Login) @@ -220,7 +271,7 @@ async def whoami(login: Annotated[Login, Resolve(login)]) -> str: @pytest.mark.anyio async def test_nested_resolver_sees_dependency_and_tool_args(): - mcp = MCPServer(name="Nested") + mcp = MCPServer(name="Nested", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Login | Elicit[Login]: return Elicit("GitHub username?", Login) @@ -251,7 +302,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - @pytest.mark.anyio async def test_resolver_runs_once_for_two_consumers(): - mcp = MCPServer(name="ExactlyOnce") + mcp = MCPServer(name="ExactlyOnce", request_state_security=RequestStateSecurity.ephemeral()) elicit_count = 0 async def login(ctx: Context) -> Login | Elicit[Login]: @@ -281,7 +332,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - @pytest.mark.anyio async def test_sync_resolver(): - mcp = MCPServer(name="Sync") + mcp = MCPServer(name="Sync", request_state_security=RequestStateSecurity.ephemeral()) def login(ctx: Context) -> Login: return Login(username="sync-user") @@ -428,7 +479,7 @@ async def tool(login: Annotated[Login, Resolve(BadResolver())]) -> str: @pytest.mark.anyio async def test_by_name_resolver_param_uses_aliased_tool_arg(): - mcp = MCPServer(name="Aliased") + mcp = MCPServer(name="Aliased", request_state_security=RequestStateSecurity.ephemeral()) # `schema` collides with a BaseModel attribute, so func_metadata aliases the field; # the runtime kwarg key is the alias, which is what a by-name resolver must match. 
@@ -448,7 +499,7 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E @pytest.mark.anyio async def test_resolver_may_return_non_basemodel_value(): - mcp = MCPServer(name="NonModel") + mcp = MCPServer(name="NonModel", request_state_security=RequestStateSecurity.ephemeral()) async def get_token(ctx: Context) -> str: return "secret-token" @@ -466,7 +517,7 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E @pytest.mark.anyio async def test_resolver_accepts_optional_context_annotation(): - mcp = MCPServer(name="OptionalContext") + mcp = MCPServer(name="OptionalContext", request_state_security=RequestStateSecurity.ephemeral()) async def whoami(ctx: Context | None) -> str: assert ctx is not None @@ -485,7 +536,7 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E @pytest.mark.anyio async def test_bound_method_resolver_runs_once_across_references(): - mcp = MCPServer(name="BoundMethod") + mcp = MCPServer(name="BoundMethod", request_state_security=RequestStateSecurity.ephemeral()) calls = 0 class Service: @@ -537,7 +588,7 @@ async def tool(value: Annotated[Login, Resolve(service.a)]) -> str: @pytest.mark.anyio async def test_resolver_and_body_see_the_same_validated_default(): - mcp = MCPServer(name="DefaultFactory") + mcp = MCPServer(name="DefaultFactory", request_state_security=RequestStateSecurity.ephemeral()) counter = {"n": 0} def next_id() -> int: @@ -590,7 +641,7 @@ def fn() -> None: ... 
# pragma: no cover def _delete_folder_server() -> tuple[MCPServer, dict[str, list[str]]]: """The `delete_folder` example from docs/migration.md, wired to an in-memory fs.""" - mcp = MCPServer(name="files") + mcp = MCPServer(name="files", request_state_security=RequestStateSecurity.ephemeral()) fs: dict[str, list[str]] = {} async def confirm_delete(path: str) -> Confirm | Elicit[Confirm]: @@ -723,7 +774,7 @@ async def never(context: ClientRequestContext, params: ElicitRequestParams) -> E @pytest.mark.anyio async def test_input_required_asks_each_question_once_while_bodies_rerun(): - mcp = MCPServer(name="ExactlyOnceMRTR") + mcp = MCPServer(name="ExactlyOnceMRTR", request_state_security=RequestStateSecurity.ephemeral()) counts = {"login": 0, "confirm": 0} async def login(ctx: Context) -> Login | Elicit[Login]: @@ -768,7 +819,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - @pytest.mark.anyio async def test_input_required_batches_independent_elicits_in_one_round(): - mcp = MCPServer(name="BatchedMRTR") + mcp = MCPServer(name="BatchedMRTR", request_state_security=RequestStateSecurity.ephemeral()) async def ask_name(ctx: Context) -> Elicit[Login]: return Elicit("Name?", Login) @@ -812,7 +863,7 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: async def test_auto_driver_answers_independent_questions_in_a_single_round(): # The pure `count_round` resolver is never persisted in `request_state`, so it # re-runs on every round: its run count is the number of rounds the call took. 
- mcp = MCPServer(name="AutoBatch") + mcp = MCPServer(name="AutoBatch", request_state_security=RequestStateSecurity.ephemeral()) rounds = 0 async def count_round(ctx: Context) -> int: @@ -871,7 +922,8 @@ def test_uses_input_required_version_gate(): ], ) def test_decode_state_tolerates_malformed_request_state(request_state: str | None): - assert _decode_state(request_state) == {} + state = _decode_state(request_state) + assert state.outcomes == {} and state.asked == {} def test_state_round_trips_accept_decline_cancel(): @@ -881,8 +933,10 @@ def test_state_round_trips_accept_decline_cancel(): "c": _StateEntry(action="cancel"), "d": _StateEntry(action="accept", data="raw-token"), # non-dict wire value } - decoded = _decode_state(_encode_state(entries)) + state = _decode_state(_encode_state(entries, {"e": "asked-digest"})) + decoded = state.outcomes assert decoded == entries # encode-restore is the identity on the stored entries + assert state.asked == {"e": "asked-digest"} accepted = _outcome_from_state(decoded["a"], Login) assert isinstance(accepted, AcceptedElicitation) and accepted.data == Login(username="octocat") @@ -907,7 +961,7 @@ def test_check_elicit_return_allows_one_arm_and_rejects_two(): @pytest.mark.anyio async def test_non_elicitation_response_raises(): - mcp = MCPServer(name="WrongResponse") + mcp = MCPServer(name="WrongResponse", request_state_security=RequestStateSecurity.ephemeral()) async def ask(ctx: Context) -> Elicit[Login]: return Elicit("Name?", Login) @@ -942,7 +996,7 @@ async def test_direct_call_tool_with_non_eliciting_resolver(): # `MCPServer.call_tool()` called directly builds a Context with no request, so # `ctx.protocol_version` is None. A tool whose resolvers never elicit must still # work there (regression: it used to raise "Context is not available"). 
- mcp = MCPServer(name="Direct") + mcp = MCPServer(name="Direct", request_state_security=RequestStateSecurity.ephemeral()) async def whoami(ctx: Context) -> Login: return Login(username="direct") @@ -959,7 +1013,7 @@ async def tool(login: Annotated[Login, Resolve(whoami)]) -> str: @pytest.mark.anyio async def test_two_instances_of_one_method_do_not_collide(): - mcp = MCPServer(name="Instances") + mcp = MCPServer(name="Instances", request_state_security=RequestStateSecurity.ephemeral()) class Service: def __init__(self, name: str) -> None: @@ -985,7 +1039,7 @@ async def both( @pytest.mark.anyio async def test_non_serializable_sibling_resolver_does_not_break_rounds(): - mcp = MCPServer(name="NonSerializable") + mcp = MCPServer(name="NonSerializable", request_state_security=RequestStateSecurity.ephemeral()) async def clock(ctx: Context) -> datetime: return datetime(2026, 1, 1) @@ -1013,7 +1067,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - async def test_bare_elicit_dependency_restored_as_model(): # A `-> Elicit[Login]` (bare, no union) resolver feeds a dependent resolver. After # the round-trip the dependency must come back as a Login model, not a raw dict. - mcp = MCPServer(name="BareElicitDep") + mcp = MCPServer(name="BareElicitDep", request_state_security=RequestStateSecurity.ephemeral()) async def login(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1045,7 +1099,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - async def test_accept_with_no_content_is_an_error_not_a_cancel(mode: Literal["legacy", "auto"]): # Both transports must agree: mode="legacy" elicits synchronously mid-call, # mode="auto" rides the 2026-07-28 input_required loop. 
- mcp = MCPServer(name="AcceptNoContent") + mcp = MCPServer(name="AcceptNoContent", request_state_security=RequestStateSecurity.ephemeral()) async def ask(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1069,7 +1123,7 @@ async def test_eliciting_tool_without_client_capability_is_a_protocol_error(): # The server must not send an `input_requests` entry the client has not declared # capability for: with no `elicitation` declared (no callback), the call fails as # a -32021 protocol error, not a CallToolResult execution failure. - mcp = MCPServer(name="NoElicitationCapability") + mcp = MCPServer(name="NoElicitationCapability", request_state_security=RequestStateSecurity.ephemeral()) async def ask(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1088,7 +1142,7 @@ async def tool(login: Annotated[Login, Resolve(ask)]) -> str: @pytest.mark.anyio async def test_independent_nested_deps_batch_into_one_round(): - mcp = MCPServer(name="NestedBatch") + mcp = MCPServer(name="NestedBatch", request_state_security=RequestStateSecurity.ephemeral()) async def ask_a(ctx: Context) -> Elicit[Login]: return Elicit("A name?", Login) @@ -1135,7 +1189,7 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: async def test_deep_chain_keeps_early_answers_across_rounds(): # A 4-round dependency chain where an early answer (A) must survive in # request_state while later resolvers are asked. It must be asked exactly once. - mcp = MCPServer(name="DeepChain") + mcp = MCPServer(name="DeepChain", request_state_security=RequestStateSecurity.ephemeral()) async def ra(ctx: Context) -> Elicit[Login]: return Elicit("A name?", Login) @@ -1177,7 +1231,7 @@ async def callback(context: ClientRequestContext, params: ElicitRequestParams) - async def test_factory_closures_get_distinct_wire_keys(): # Two resolvers from one factory share module:qualname; they must still get # distinct questions and their own values (regression: they collided on the wire). 
- mcp = MCPServer(name="FactoryClosures") + mcp = MCPServer(name="FactoryClosures", request_state_security=RequestStateSecurity.ephemeral()) def make(label: str): async def resolver(ctx: Context) -> Elicit[Login]: @@ -1222,7 +1276,7 @@ async def test_eliciting_resolver_without_elicit_arm_restores_a_typed_model(): # round flow, must still come back as a Login model (not a raw dict): restore # validates against the live `Elicit.schema` the body produced, not the lying # annotation, so a dependent resolver/tool can use its attributes. - mcp = MCPServer(name="LyingAnnotation") + mcp = MCPServer(name="LyingAnnotation", request_state_security=RequestStateSecurity.ephemeral()) # Annotated without an `Elicit[T]` return arm; the body asks anyway. async def login(ctx: Context) -> object: @@ -1274,7 +1328,7 @@ async def test_declined_outcome_persists_in_request_state_and_is_not_reasked(): # A decline is recorded in `request_state` just like an accept: RB elicits only # after seeing RA's decline, so RA's outcome must survive into the round that # answers RB without RA being asked again. 
- mcp = MCPServer(name="DeclinePersists") + mcp = MCPServer(name="DeclinePersists", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) async def ra(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1308,7 +1362,7 @@ async def act( assert second.input_requests is not None (rb_key,) = second.input_requests # only RB's question; RA is not re-asked assert rb_key != ra_key - assert _decode_state(second.request_state)[ra_key].action == "decline" + assert _outcomes_on_the_wire(second.request_state)[ra_key]["action"] == "decline" final = await client.session.call_tool( "act", @@ -1325,9 +1379,8 @@ async def act( @pytest.mark.anyio async def test_unknown_response_keys_and_ghost_state_entries_are_ignored(): # `input_responses` keys the server never asked for and `request_state` outcome - # entries matching no resolver are tolerated (both are client-supplied), and the - # ghost state entry is not echoed into any later round's `request_state`. - mcp = MCPServer(name="GhostKeys") + # entries matching no resolver are tolerated and not echoed into later rounds. + mcp = MCPServer(name="GhostKeys", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) async def ra(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1349,8 +1402,13 @@ async def act( assert first.request_state is not None (ra_key,) = first.input_requests - spliced = json.loads(first.request_state) - spliced["outcomes"]["ghost"] = {"action": "accept", "data": {"username": "spooky"}} + spliced = json.loads(_unseal_inner(first.request_state)) + # A well-formed v2 entry under an unknown key: dropped as unknown, not as malformed. 
+ spliced["outcomes"]["ghost"] = { + "action": "accept", + "data": {"username": "spooky"}, + "q": _question_digest(Elicit("user?", Login)), + } second = await client.session.call_tool( "act", {}, @@ -1358,13 +1416,13 @@ async def act( ra_key: ElicitResult(action="accept", content={"username": "octocat"}), "ghost": ElicitResult(action="accept", content={"username": "spooky"}), }, - request_state=json.dumps(spliced), + request_state=_sealed_state(json.dumps(spliced), tool="act", args={}, audience="GhostKeys"), allow_input_required=True, ) assert isinstance(second, InputRequiredResult) assert second.input_requests is not None (rb_key,) = second.input_requests - outcomes = _decode_state(second.request_state) + outcomes = _outcomes_on_the_wire(second.request_state) assert ra_key in outcomes assert "ghost" not in outcomes # the spliced entry is dropped, not carried onward @@ -1389,10 +1447,8 @@ async def act( ], ) async def test_forged_state_entry_failing_the_schema_is_reasked_not_an_error(forged_data: str | dict[str, bool]): - # `request_state` is client-trusted JSON: an accept entry whose data does not - # validate against the resolver's schema reads as no recorded progress, so the - # question is asked again (not an error) and a proper answer completes the call. - mcp = MCPServer(name="ForgedState") + # Authenticated state is not schema-trusted: a failing accept entry reads as no progress and is re-asked. 
+ mcp = MCPServer(name="ForgedState", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) async def ask(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1408,15 +1464,23 @@ async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: assert first.request_state is not None (key,) = first.input_requests - forged = json.loads(first.request_state) - forged["outcomes"][key] = {"action": "accept", "data": forged_data} + forged = json.loads(_unseal_inner(first.request_state)) + # The digest matches the live question, so the entry stands or falls on schema alone. + forged["outcomes"][key] = { + "action": "accept", + "data": forged_data, + "q": _question_digest(Elicit("user?", Login)), + } second = await client.session.call_tool( - "whoami", {}, request_state=json.dumps(forged), allow_input_required=True + "whoami", + {}, + request_state=_sealed_state(json.dumps(forged), tool="whoami", args={}, audience="ForgedState"), + allow_input_required=True, ) assert isinstance(second, InputRequiredResult) # re-asked, not an error assert second.input_requests is not None assert set(second.input_requests) == {key} - assert _decode_state(second.request_state) == {} # the forged entry is dropped + assert _outcomes_on_the_wire(second.request_state) == {} # the forged entry is dropped final = await client.session.call_tool( "whoami", @@ -1436,7 +1500,7 @@ async def test_schema_mismatched_fresh_answer_fails_the_call_without_pydantic_le # An accepted answer whose content fails the requested schema fails the call # with the framework's own message on both transports; pydantic's error text # (which carries an "errors.pydantic.dev" link) must not leak to the client. 
- mcp = MCPServer(name="MismatchedAnswer") + mcp = MCPServer(name="MismatchedAnswer", request_state_security=RequestStateSecurity.ephemeral()) async def ask(ctx: Context) -> Elicit[Login]: return Elicit("user?", Login) @@ -1464,7 +1528,7 @@ async def test_auto_driver_gives_up_when_the_chain_outlasts_its_round_budget(): # than the default `input_required_max_rounds`, so `client.call_tool` must raise # rather than loop on. The pure `count_leg` resolver is never persisted, so it # re-runs on every server leg: its final value is the exact number of legs. - mcp = MCPServer(name="TooDeep") + mcp = MCPServer(name="TooDeep", request_state_security=RequestStateSecurity.ephemeral()) legs = 0 async def count_leg(ctx: Context) -> int: @@ -1514,7 +1578,7 @@ async def test_aliased_elicitation_model_round_trips_through_request_state(): # the same validation the answer originally passed - aliases and all. A # re-derived (field-name) shape would fail validation on the round after # next, drop the stored answer, and re-ask the user forever. - mcp = MCPServer(name="AliasState") + mcp = MCPServer(name="AliasState", request_state_security=RequestStateSecurity.ephemeral()) async def who(ctx: Context) -> Elicit[Handle]: return Elicit("handle?", Handle) @@ -1566,7 +1630,7 @@ async def test_divergent_validation_and_serialization_aliases_round_trip(): # the validated model (which serializes under the *serialization* alias) would # produce data the schema's own validation rejects, dropping the stored answer # on the round after next and re-asking the user. 
- mcp = MCPServer(name="DivergentAliases") + mcp = MCPServer(name="DivergentAliases", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) async def who(ctx: Context) -> Elicit[Account]: return Elicit("account?", Account) @@ -1602,7 +1666,7 @@ async def act( (go_key,) = second.input_requests # only the dependent question; the stored answer holds assert go_key != who_key # The stored entry is the client's wire content, not a re-serialization of it. - assert _decode_state(second.request_state)[who_key].data == {"vUser": "octocat"} + assert _outcomes_on_the_wire(second.request_state)[who_key]["data"] == {"vUser": "octocat"} final = await client.session.call_tool( "act", @@ -1621,7 +1685,7 @@ async def test_state_entry_never_replaces_a_resolver_computed_value(): # `request_state` is client-echoed: an accept entry under a resolver's wire key # must only satisfy a question the resolver is actually asking, never stand in # for the body's own computation on a branch that does not ask. - mcp = MCPServer(name="StateVsBody") + mcp = MCPServer(name="StateVsBody", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) calls = {"decide": 0} async def decide(ctx: Context) -> Restock | Elicit[Restock]: @@ -1632,11 +1696,17 @@ async def decide(ctx: Context) -> Restock | Elicit[Restock]: async def plan_restock(restock: Annotated[Restock, Resolve(decide)]) -> str: return str(restock.needed) - wire_key = f"{decide.__module__}:{decide.__qualname__}" - crafted = json.dumps({"v": 1, "outcomes": {wire_key: {"action": "accept", "data": {"needed": True}}}}) + # A decodable v2 entry; the resolver never asks, so it must go unconsulted, not dropped as malformed. 
+ entry = {"action": "accept", "data": {"needed": True}, "q": _question_digest(Elicit("Restock?", Restock))} + crafted = json.dumps({"v": 3, "outcomes": {_wire_key(decide): entry}}) async with Client(mcp, elicitation_callback=_never) as client: - result = await client.session.call_tool("plan_restock", {}, request_state=crafted, allow_input_required=True) + result = await client.session.call_tool( + "plan_restock", + {}, + request_state=_sealed_state(crafted, tool="plan_restock", args={}, audience="StateVsBody"), + allow_input_required=True, + ) assert isinstance(result, CallToolResult) assert isinstance(result.content[0], TextContent) # The body ran and its computation won; the crafted entry was never consulted. @@ -1648,7 +1718,7 @@ async def plan_restock(restock: Annotated[Restock, Resolve(decide)]) -> str: async def test_state_decline_entry_for_a_pure_resolver_is_ignored(): # A decline/cancel entry can only answer a question; a resolver with no Elicit # arm never asks one, so such an entry cannot suppress its computed value. - mcp = MCPServer(name="PureVsDecline") + mcp = MCPServer(name="PureVsDecline", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) async def lookup(ctx: Context) -> Login: return Login(username="server-side") @@ -1657,11 +1727,17 @@ async def lookup(ctx: Context) -> Login: async def whoami(login: Annotated[Login, Resolve(lookup)]) -> str: return login.username - wire_key = f"{lookup.__module__}:{lookup.__qualname__}" - crafted = json.dumps({"v": 1, "outcomes": {wire_key: {"action": "decline"}}}) + # A decodable v2 entry: `lookup` never asks, so no digest can make the decline apply. 
+ entry = {"action": "decline", "q": _question_digest(Elicit("user?", Login))} + crafted = json.dumps({"v": 3, "outcomes": {_wire_key(lookup): entry}}) async with Client(mcp, elicitation_callback=_never) as client: - result = await client.session.call_tool("whoami", {}, request_state=crafted, allow_input_required=True) + result = await client.session.call_tool( + "whoami", + {}, + request_state=_sealed_state(crafted, tool="whoami", args={}, audience="PureVsDecline"), + allow_input_required=True, + ) assert isinstance(result, CallToolResult) assert not result.is_error assert isinstance(result.content[0], TextContent) @@ -1673,7 +1749,7 @@ async def test_dynamic_schema_resolver_restores_across_rounds(): # `-> Elicit[BaseModel]` is the natural annotation for `create_model(...)` # schemas; the restored answer must validate against the live question's # schema, so the dynamic shape works across a multi-question chain. - mcp = MCPServer(name="DynamicSchema") + mcp = MCPServer(name="DynamicSchema", request_state_security=RequestStateSecurity.ephemeral()) dyn = create_model("Dyn", token=(str, ...)) async def first(ctx: Context) -> Elicit[BaseModel]: @@ -1729,7 +1805,7 @@ def answer(key: str, params: ElicitRequestFormParams) -> ElicitResult: def test_tool_combining_resolvers_with_input_required_return_is_rejected(annotation: Any): # A call has one input_responses/request_state channel: resolver elicitation # and a hand-rolled InputRequiredResult body cannot share it. - mcp = MCPServer(name="ChannelOwnership") + mcp = MCPServer(name="ChannelOwnership", request_state_security=RequestStateSecurity.ephemeral()) async def lookup(ctx: Context) -> Login: return Login(username="x") # pragma: no cover - registration is rejected @@ -1755,7 +1831,7 @@ def test_unevaluable_alias_and_parameterized_generics_declare_no_arm(): # can see and must not break registration (the in-call guard still covers a # body that returns an InputRequiredResult anyway). 
A parameterized generic # return is never the InputRequiredResult class either. - mcp = MCPServer(name="RegistrationTolerance") + mcp = MCPServer(name="RegistrationTolerance", request_state_security=RequestStateSecurity.ephemeral()) async def lookup(ctx: Context) -> Login: return Login(username="x") # pragma: no cover - only registration is exercised @@ -1778,7 +1854,7 @@ async def test_tool_returning_input_required_dynamically_with_resolvers_is_an_er # The annotated form of this combination is rejected at registration; a body # that returns an InputRequiredResult without declaring it fails loudly at the # same boundary instead of silently fighting the resolvers for the channel. - mcp = MCPServer(name="DynamicChannelClash") + mcp = MCPServer(name="DynamicChannelClash", request_state_security=RequestStateSecurity.ephemeral()) async def lookup(ctx: Context) -> Login: return Login(username="x") @@ -1792,3 +1868,500 @@ async def sneaky(login: Annotated[Login, Resolve(lookup)]): assert result.is_error assert isinstance(result.content[0], TextContent) assert "the multi-round flow is driven either by resolvers or by the tool body" in result.content[0].text + + +def test_question_digest_pins_the_rendered_question(): + # Computed over the rendered wire question: identical Elicits agree, any change diverges. + digest = _question_digest(Elicit("Name?", Login)) + assert digest == _question_digest(Elicit("Name?", Login)) + assert digest != _question_digest(Elicit("Your name, please?", Login)) + assert digest != _question_digest(Elicit("Name?", Confirm)) + # A 16-byte sha256 prefix, base64url without padding. + assert len(digest) == 22 and "=" not in digest + + +def test_state_round_trips_question_digests_at_v3(): + # v3 carries digests for every action and round-trips exactly; v1 (mid rolling deploy) reads as no progress. 
+ entries = { + "a": _StateEntry(action="accept", data={"username": "octocat"}, q="qa"), + "b": _StateEntry(action="decline", q="qb"), + "c": _StateEntry(action="cancel", q="qc"), + } + encoded = _encode_state(entries, {}) + assert json.loads(encoded)["v"] == 3 + assert _decode_state(encoded).outcomes == entries + v1 = json.dumps({"v": 1, "outcomes": {"a": {"action": "decline"}}}) + assert _decode_state(v1).outcomes == {} + + +@pytest.mark.anyio +async def test_restored_answer_with_matching_digest_completes_without_reasking(): + mcp = MCPServer(name="PinHappyPath", request_state_security=RequestStateSecurity.ephemeral()) + + async def who(ctx: Context) -> Elicit[Login]: + return Elicit("Who?", Login) + + async def check(login: Annotated[Login, Resolve(who)]) -> Elicit[Confirm]: + return Elicit(f"Go as {login.username}?", Confirm) + + @mcp.tool() + async def act( + login: Annotated[Login, Resolve(who)], + confirm: Annotated[Confirm, Resolve(check)], + ) -> str: + return f"{login.username}:{confirm.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + assert set(first.input_requests) == {_wire_key(who)} + + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(who): ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + # Only the dependent question; the stored answer holds, "Who?" is not re-asked. 
+ assert set(second.input_requests) == {_wire_key(check)} + + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(check): ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat:True" + + +@pytest.mark.anyio +async def test_restored_entry_is_repersisted_with_its_question_digest_intact(): + # A restored entry must ride into the next round's state digest-intact, or it would be re-asked next round. + mcp = MCPServer(name="RepersistPin", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + + async def who(ctx: Context) -> Elicit[Login]: + return Elicit("Who?", Login) + + async def check(login: Annotated[Login, Resolve(who)]) -> Elicit[Confirm]: + return Elicit(f"Go as {login.username}?", Confirm) + + async def plan(confirm: Annotated[Confirm, Resolve(check)], ctx: Context) -> Elicit[Restock]: + return Elicit("Restock too?", Restock) + + # The body never runs (a question always pends); a bare `...` costs no coverage. + @mcp.tool() + async def act(restock: Annotated[Restock, Resolve(plan)]) -> str: ... 
+ + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(who): ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + third = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(check): ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(third, InputRequiredResult) + + round_two = _outcomes_on_the_wire(second.request_state) + round_three = _outcomes_on_the_wire(third.request_state) + # Accept entries are pinned to the exact rendered question they answered. + assert round_two[_wire_key(who)]["q"] == _question_digest(Elicit("Who?", Login)) + assert round_three[_wire_key(check)]["q"] == _question_digest(Elicit("Go as octocat?", Confirm)) + assert round_three[_wire_key(who)] == round_two[_wire_key(who)] + + +@pytest.mark.anyio +async def test_decline_and_cancel_entries_carry_the_question_digest(): + mcp = MCPServer(name="PinAllActions", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + async def ask_restock(ctx: Context) -> Elicit[Restock]: + return Elicit("Restock?", Restock) + + # The body never runs (a question always pends); a bare `...` costs no coverage. + @mcp.tool() + async def act( + name: Annotated[ElicitationResult[Login], Resolve(ask_name)], + confirm: Annotated[ElicitationResult[Confirm], Resolve(ask_confirm)], + restock: Annotated[Restock, Resolve(ask_restock)], + ) -> str: ... 
+ + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + # The third question stays unanswered, so the call pends and outcomes hit the wire. + second = await client.session.call_tool( + "act", + {}, + input_responses={ + _wire_key(ask_name): ElicitResult(action="decline"), + _wire_key(ask_confirm): ElicitResult(action="cancel"), + }, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + + outcomes = _outcomes_on_the_wire(second.request_state) + assert outcomes[_wire_key(ask_name)]["action"] == "decline" + assert outcomes[_wire_key(ask_name)]["q"] == _question_digest(Elicit("Name?", Login)) + assert outcomes[_wire_key(ask_confirm)]["action"] == "cancel" + assert outcomes[_wire_key(ask_confirm)]["q"] == _question_digest(Elicit("Confirm?", Confirm)) + + +@pytest.mark.anyio +async def test_state_entry_without_a_question_digest_is_dropped_and_reasked(): + # An entry with no digest cannot prove its question, so it reads as no progress and is re-asked. + mcp = MCPServer(name="UnpinnedEntry", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + + async def ask(ctx: Context) -> Elicit[Login]: + return Elicit("user?", Login) + + @mcp.tool() + async def whoami(login: Annotated[Login, Resolve(ask)]) -> str: + return login.username + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("whoami", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.input_requests is not None + (key,) = first.input_requests + + # Schema-valid accept data under the live key, but no "q" pin. 
+ entry = {"action": "accept", "data": {"username": "spooky"}} + crafted = json.dumps({"v": 3, "outcomes": {key: entry}}) + second = await client.session.call_tool( + "whoami", + {}, + request_state=_sealed_state(crafted, tool="whoami", args={}, audience="UnpinnedEntry"), + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) # re-asked, not honored and not an error + assert second.input_requests is not None + assert set(second.input_requests) == {key} + assert _outcomes_on_the_wire(second.request_state) == {} # the unpinned entry is dropped + + final = await client.session.call_tool( + "whoami", + {}, + input_responses={key: ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "octocat" + + +@pytest.mark.anyio +async def test_reworded_question_drops_the_stored_answer_and_reasks(): + # An answer holds only while its question is byte-identical: a reword (redeploy) drops it and re-asks. 
+ mcp = MCPServer(name="RewordAccept", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + wording = {"deploy": "Deploy to prod?"} + + async def ask_deploy(ctx: Context) -> Elicit[Confirm]: + return Elicit(wording["deploy"], Confirm) + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + @mcp.tool() + async def act( + deploy: Annotated[Confirm, Resolve(ask_deploy)], + name: Annotated[Login, Resolve(ask_name)], + ) -> str: + return f"{deploy.ok}:{name.username}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_deploy): ElicitResult(action="accept", content={"ok": True})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert _outcomes_on_the_wire(second.request_state)[_wire_key(ask_deploy)]["q"] == _question_digest( + Elicit("Deploy to prod?", Confirm) + ) + + wording["deploy"] = "Deploy to staging?" + + third = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_name): ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=second.request_state, + allow_input_required=True, + ) + # The stale answer is dropped and the reworded question is asked, not an error. + assert isinstance(third, InputRequiredResult) + assert third.input_requests is not None + assert set(third.input_requests) == {_wire_key(ask_deploy)} + question = third.input_requests[_wire_key(ask_deploy)].params + assert isinstance(question, ElicitRequestFormParams) + assert question.message == "Deploy to staging?" + # The sibling answer recorded in the same state survives the drop. 
+ outcomes = _outcomes_on_the_wire(third.request_state) + assert _wire_key(ask_deploy) not in outcomes + assert outcomes[_wire_key(ask_name)]["action"] == "accept" + + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_deploy): ElicitResult(action="accept", content={"ok": True})}, + request_state=third.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "True:octocat" + + +@pytest.mark.anyio +async def test_decline_of_a_reworded_question_does_not_suppress_the_new_question(): + # A decline pinned to the old wording must not suppress the reworded question. + mcp = MCPServer(name="RewordDecline", request_state_security=RequestStateSecurity.ephemeral()) + wording = {"q": "Use defaults?"} + + async def ask(ctx: Context) -> Elicit[Confirm]: + return Elicit(wording["q"], Confirm) + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + @mcp.tool() + async def act( + choice: Annotated[ElicitationResult[Confirm], Resolve(ask)], + name: Annotated[Login, Resolve(ask_name)], + ) -> str: + kind = "accepted" if isinstance(choice, AcceptedElicitation) else "declined" + return f"{kind}:{name.username}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask): ElicitResult(action="decline")}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + + wording["q"] = "Use the new defaults?" 
+ + third = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_name): ElicitResult(action="accept", content={"username": "octocat"})}, + request_state=second.request_state, + allow_input_required=True, + ) + # The stale decline is dropped and the reworded question is asked again. + assert isinstance(third, InputRequiredResult) + assert third.input_requests is not None + assert set(third.input_requests) == {_wire_key(ask)} + question = third.input_requests[_wire_key(ask)].params + assert isinstance(question, ElicitRequestFormParams) + assert question.message == "Use the new defaults?" + + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask): ElicitResult(action="accept", content={"ok": True})}, + request_state=third.request_state, + allow_input_required=True, + ) + # Accepting the new question proves the old decline did not stick. + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "accepted:octocat" + + +@pytest.mark.anyio +async def test_reworded_question_reasks_even_when_the_answer_first_arrives(): + # The pend round records each question's digest in the state, so an answer that + # first arrives after a reword (redeploy between ask and retry) re-asks instead + # of being consumed as consent to the new wording. 
+ mcp = MCPServer(name="RewordArrival", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + wording = {"deploy": "Deploy to prod?"} + + async def ask_deploy(ctx: Context) -> Elicit[Confirm]: + return Elicit(wording["deploy"], Confirm) + + @mcp.tool() + async def act(deploy: Annotated[Confirm, Resolve(ask_deploy)]) -> str: + return f"deployed:{deploy.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + pended = json.loads(_unseal_inner(first.request_state))["asked"] + assert pended == {_wire_key(ask_deploy): _question_digest(Elicit("Deploy to prod?", Confirm))} + + wording["deploy"] = "Deploy to staging?" + + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_deploy): ElicitResult(action="accept", content={"ok": True})}, + request_state=first.request_state, + allow_input_required=True, + ) + # The stale answer to the old wording is not consumed; the reworded question is asked. + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + question = second.input_requests[_wire_key(ask_deploy)].params + assert isinstance(question, ElicitRequestFormParams) + assert question.message == "Deploy to staging?" + + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_deploy): ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "deployed:True" + + +@pytest.mark.anyio +async def test_an_answer_without_the_echoed_state_is_reasked_not_consumed(): + # Without the echoed state there is no record of which question the client was + # shown, so an answer arriving stateless re-asks instead of being consumed. 
+ mcp = MCPServer(name="Stateless", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + + async def ask(ctx: Context) -> Elicit[Confirm]: + return Elicit("Proceed?", Confirm) + + @mcp.tool() + async def act(go: Annotated[Confirm, Resolve(ask)]) -> str: + return f"went:{go.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask): ElicitResult(action="accept", content={"ok": True})}, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask): ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "went:True" + + +@pytest.mark.anyio +async def test_recorded_answer_containing_a_lone_surrogate_survives_to_later_rounds(): + # The state encoder escapes lone surrogates, so the decoder must parse them back: + # a recorded answer with one must restore on the next round, not silently re-ask. 
+ mcp = MCPServer(name="Surrogate", request_state_security=RequestStateSecurity(keys=[_PIN_KEY])) + + async def ask_name(ctx: Context) -> Elicit[Login]: + return Elicit("Name?", Login) + + async def ask_confirm(ctx: Context) -> Elicit[Confirm]: + return Elicit("Confirm?", Confirm) + + @mcp.tool() + async def act( + name: Annotated[Login, Resolve(ask_name)], + go: Annotated[Confirm, Resolve(ask_confirm)], + ) -> str: + return f"{name.username}:{go.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + + second = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_name): ElicitResult(action="accept", content={"username": "oc\ud800t"})}, + request_state=first.request_state, + allow_input_required=True, + ) + # The surrogate-bearing answer is recorded; only the unanswered question remains. + assert isinstance(second, InputRequiredResult) + assert second.input_requests is not None + assert set(second.input_requests) == {_wire_key(ask_confirm)} + + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask_confirm): ElicitResult(action="accept", content={"ok": True})}, + request_state=second.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "oc\ud800t:True" + + +@pytest.mark.anyio +async def test_resolver_elicitation_seals_and_completes_on_a_fully_default_server(): + # The headline default-posture invariant: a resolver tool on a bare MCPServer() - + # no name, no security configuration - mints sealed state and completes the round. 
+ mcp = MCPServer() + + async def ask(ctx: Context) -> Elicit[Confirm]: + return Elicit("Go?", Confirm) + + @mcp.tool() + async def act(go: Annotated[Confirm, Resolve(ask)]) -> str: + return f"went:{go.ok}" + + async with Client(mcp, elicitation_callback=_never) as client: + first = await client.session.call_tool("act", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + assert first.request_state.startswith("v1.") + final = await client.session.call_tool( + "act", + {}, + input_responses={_wire_key(ask): ElicitResult(action="accept", content={"ok": True})}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(final, CallToolResult) + assert isinstance(final.content[0], TextContent) + assert final.content[0].text == "went:True" diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index b4a1184580..2ae9d5ff74 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1866,7 +1866,8 @@ def get_user(user_id: str) -> str: assert exc_info.value.error.data == {"uri": "resource://users/999"} -async def test_tool_returning_input_required_result_reaches_client_unchanged(): +async def test_tool_returning_input_required_result_reaches_client_sealed(): + # Default posture: the wire carries an opaque sealed token, never the handler's plaintext. mcp = MCPServer() @mcp.tool() @@ -1878,7 +1879,7 @@ async def ask(ctx: Context) -> str | InputRequiredResult: result = await client.session.call_tool("ask", allow_input_required=True) assert isinstance(result, InputRequiredResult) - assert result.request_state == "round-1" + _assert_sealed(result.request_state, "round-1") assert result.input_requests is not None assert result.input_requests["roots"].method == "roots/list" @@ -1927,6 +1928,13 @@ async def greet(ctx: Context) -> str | InputRequiredResult: assert block.text == "Hello, Alice! 
(state=r1)" +def _assert_sealed(state: str | None, plaintext: str) -> None: + """The wire form is an opaque sealed token, never the handler's plaintext.""" + assert state is not None + assert state != plaintext + assert state.startswith("v1.") + + def _ask_who() -> ElicitRequest: return ElicitRequest( params=ElicitRequestFormParams( @@ -1940,9 +1948,9 @@ def _ask_who() -> ElicitRequest: ) -async def test_prompt_returning_input_required_result_reaches_client_unchanged(): - """A prompt function may return an InputRequiredResult and the pipeline passes it - through to the client (spec-mandated: SEP-2322 allows it on prompts/get).""" +async def test_prompt_returning_input_required_result_reaches_client_sealed(): + """A prompt function may return an InputRequiredResult and the pipeline delivers it + to the client with the state sealed (spec-mandated: SEP-2322 allows it on prompts/get).""" mcp = MCPServer() @mcp.prompt() @@ -1954,7 +1962,7 @@ async def briefing(ctx: Context) -> list[UserMessage] | InputRequiredResult: result = await client.session.get_prompt("briefing", allow_input_required=True) assert isinstance(result, InputRequiredResult) - assert result.request_state == "round-1" + _assert_sealed(result.request_state, "round-1") assert result.input_requests is not None assert result.input_requests["who"].method == "elicitation/create" @@ -2023,9 +2031,9 @@ async def ask(topic: str, ctx: Context) -> str | InputRequiredResult: assert exc.value.error.message == "Handler returned an invalid result" -async def test_resource_template_returning_input_required_result_reaches_client_unchanged(): +async def test_resource_template_returning_input_required_result_reaches_client_sealed(): """A resource template function may return an InputRequiredResult and the pipeline - passes it through to the client (spec-mandated: SEP-2322 allows it on resources/read).""" + delivers it with the state sealed (spec-mandated: SEP-2322 allows it on resources/read).""" mcp = MCPServer() 
@mcp.resource("ask://{topic}") @@ -2037,7 +2045,7 @@ async def ask(topic: str, ctx: Context) -> str | InputRequiredResult: result = await client.session.read_resource("ask://databases", allow_input_required=True) assert isinstance(result, InputRequiredResult) - assert result.request_state == "round-1" + _assert_sealed(result.request_state, "round-1") assert result.input_requests is not None assert result.input_requests["who"].method == "elicitation/create" @@ -2121,22 +2129,26 @@ async def ask(topic: str, ctx: Context) -> str: return f"{topic} content" @mcp.tool() - async def outer(ctx: Context) -> str: + async def outer(ctx: Context) -> str | InputRequiredResult: + if ctx.input_responses is None: + return InputRequiredResult(input_requests={"who": _ask_who()}, request_state="outer-state") contents = list(await ctx.read_resource("ask://databases")) assert isinstance(contents[0].content, str) - return contents[0].content + return f"{contents[0].content} (state={ctx.request_state})" with anyio.fail_after(5): async with Client(mcp, mode="2026-07-28") as client: + r1 = await client.session.call_tool("outer", allow_input_required=True) + assert isinstance(r1, InputRequiredResult) result = await client.session.call_tool( "outer", input_responses={"who": ElicitResult(action="accept", content={"name": "Alice"})}, - request_state="outer-state", + request_state=r1.request_state, ) assert isinstance(result, CallToolResult) block = result.content[0] assert isinstance(block, TextContent) - assert block.text == "databases content" + assert block.text == "databases content (state=outer-state)" assert seen_responses == [None] assert seen_state == [None] diff --git a/tests/server/test_request_state.py b/tests/server/test_request_state.py new file mode 100644 index 0000000000..590c046e94 --- /dev/null +++ b/tests/server/test_request_state.py @@ -0,0 +1,479 @@ +"""Unit tests for `mcp.server.request_state`: codec, security policy, and default principal binding.""" + +import base64 
+import string +from collections.abc import Callable +from typing import Any, cast + +import pytest +from inline_snapshot import snapshot + +from mcp.server.auth.middleware.auth_context import auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, authorization_context +from mcp.server.auth.provider import AccessToken, principal_components +from mcp.server.context import ServerRequestContext +from mcp.server.request_state import ( + AESGCMRequestStateCodec, + InvalidRequestState, + RequestStateSecurity, + authenticated_principal, +) + +_TOKEN_PREFIX = "v1." +_KID_LEN = 4 +_NONCE_LEN = 12 +_GCM_TAG_LEN = 16 +_BODY_FLOOR = _KID_LEN + _NONCE_LEN + _GCM_TAG_LEN +_B64URL_ALPHABET = set(string.ascii_letters + string.digits + "-_") + +_KEY_A = b"a" * 32 +_KEY_B = b"b" * 32 +_KEY_OLD = b"o" * 32 +_KEY_NEW = b"n" * 32 + +# Distinctive plaintext: opacity and log-secrecy assertions search for it. +_PAYLOAD = b"sentinel-plaintext-3f9c" +# `InvalidRequestState` messages are short log-only reason codes, never payload. 
+_REASON_CODE_MAX_LEN = 40 + + +def _b64u_nopad(data: bytes) -> str: + return base64.urlsafe_b64encode(data).decode().rstrip("=") + + +def _decode_body(token: str) -> bytes: + body = token.removeprefix(_TOKEN_PREFIX) + return base64.urlsafe_b64decode(body + "=" * (-len(body) % 4)) + + +def _flip_body_byte(token: str, index: int) -> str: + raw = bytearray(_decode_body(token)) + raw[index] ^= 0xFF + return _TOKEN_PREFIX + _b64u_nopad(bytes(raw)) + + +def _flip_prefix_char(token: str) -> str: + return "x" + token[1:] + + +def _flip_kid_byte(token: str) -> str: + return _flip_body_byte(token, 0) + + +def _flip_nonce_byte(token: str) -> str: + return _flip_body_byte(token, _KID_LEN) + + +def _flip_ciphertext_byte(token: str) -> str: + return _flip_body_byte(token, _KID_LEN + _NONCE_LEN) + + +def _flip_tag_byte(token: str) -> str: + return _flip_body_byte(token, -1) + + +def _inject_junk_chars(body: str) -> str: + return body[:10] + "!@\n*" + body[10:] + + +def _append_newline(body: str) -> str: + return body + "\n" + + +def _append_padding(body: str) -> str: + return body + "=" * (-len(body) % 4 or 4) + + +def _bare_context() -> ServerRequestContext[Any, Any]: + return ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/call", + ) + + +class _StaticCodec: + """Minimal `RequestStateCodec` stand-in for policy tests; no real crypto.""" + + def seal(self, payload: bytes) -> str: + return payload.hex() + + def unseal(self, token: str) -> bytes: + return bytes.fromhex(token) + + +# -- AESGCMRequestStateCodec -------------------------------------------------- + + +@pytest.mark.parametrize( + "payload", + [ + pytest.param(b"", id="empty"), + pytest.param(b"plain ascii state", id="ascii"), + pytest.param("ünïcødé – 状態".encode(), id="multi-byte-utf8"), + pytest.param(bytes(range(256)), id="raw-binary"), + pytest.param(bytes(range(256)) * 256, id="64KiB"), + ], +) +def 
test_seal_unseal_round_trips_any_payload(payload: bytes) -> None: + """SDK-defined: the codec is byte-transparent, so any payload survives seal/unseal unchanged.""" + codec = AESGCMRequestStateCodec([_KEY_A]) + assert codec.unseal(codec.seal(payload)) == payload + + +def test_a_sealed_token_is_v1_plus_unpadded_b64url_over_kid_nonce_and_ciphertext() -> None: + """SDK-defined token format: "v1." plus unpadded base64url over kid(4) || nonce(12) || ciphertext+tag.""" + token = AESGCMRequestStateCodec([_KEY_A]).seal(_PAYLOAD) + assert token.startswith(_TOKEN_PREFIX) + body = token.removeprefix(_TOKEN_PREFIX) + assert "=" not in body + assert set(body) <= _B64URL_ALPHABET + assert len(_decode_body(token)) == _KID_LEN + _NONCE_LEN + len(_PAYLOAD) + _GCM_TAG_LEN + + +def test_two_seals_of_the_same_payload_produce_distinct_tokens_that_both_unseal() -> None: + """SDK-defined: every seal draws a fresh nonce, so identical payloads yield distinct tokens that both verify.""" + codec = AESGCMRequestStateCodec([_KEY_A]) + first = codec.seal(_PAYLOAD) + second = codec.seal(_PAYLOAD) + assert first != second + assert codec.unseal(first) == _PAYLOAD + assert codec.unseal(second) == _PAYLOAD + + +@pytest.mark.parametrize( + "corrupt", + [ + pytest.param(_flip_prefix_char, id="prefix-char"), + pytest.param(_flip_kid_byte, id="kid-byte"), + pytest.param(_flip_nonce_byte, id="nonce-byte"), + pytest.param(_flip_ciphertext_byte, id="ciphertext-byte"), + pytest.param(_flip_tag_byte, id="tag-byte"), + ], +) +def test_a_token_corrupted_in_any_region_is_rejected_without_echoing_the_payload( + corrupt: Callable[[str], str], +) -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): any corrupted token region is rejected.""" + codec = AESGCMRequestStateCodec([_KEY_A]) + token = codec.seal(_PAYLOAD) + with pytest.raises(InvalidRequestState) as exc: + codec.unseal(corrupt(token)) + message = str(exc.value) + assert len(message) <= _REASON_CODE_MAX_LEN + assert _PAYLOAD.decode() 
not in message + + +@pytest.mark.parametrize( + "token", + [ + pytest.param("", id="empty-string"), + pytest.param(_b64u_nopad(b"\x00" * 64), id="missing-prefix"), + pytest.param(_TOKEN_PREFIX + "!!!not-base64!!!", id="garbage-after-prefix"), + pytest.param(_TOKEN_PREFIX + _b64u_nopad(b"\x00" * (_BODY_FLOOR - 1)), id="below-floor"), + ], +) +def test_a_structurally_malformed_token_is_rejected(token: str) -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): tokens this codec never minted fail.""" + with pytest.raises(InvalidRequestState): + AESGCMRequestStateCodec([_KEY_A]).unseal(token) + + +def test_a_token_minted_under_a_key_outside_the_ring_is_rejected_as_unknown_key() -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): a foreign-key token fails as "unknown key".""" + token = AESGCMRequestStateCodec([_KEY_A]).seal(_PAYLOAD) + with pytest.raises(InvalidRequestState) as exc: + AESGCMRequestStateCodec([_KEY_B]).unseal(token) + assert str(exc.value) == "unknown key" + + +@pytest.mark.parametrize( + "ring", + [ + pytest.param([_KEY_OLD, _KEY_NEW], id="rotation-phase-1"), + pytest.param([_KEY_NEW, _KEY_OLD], id="rotation-phase-2"), + ], +) +def test_a_token_minted_under_the_old_key_unseals_under_any_ring_containing_it(ring: list[bytes]) -> None: + """SDK-defined rotation: every ring key verifies, so old-key state survives both rollout phases.""" + token = AESGCMRequestStateCodec([_KEY_OLD]).seal(_PAYLOAD) + assert AESGCMRequestStateCodec(ring).unseal(token) == _PAYLOAD + + +def test_the_first_ring_key_mints_and_later_ring_keys_only_verify() -> None: + """SDK-defined rotation: keys[0] is the minter, so [new, old] state verifies under [new] but not [old].""" + token = AESGCMRequestStateCodec([_KEY_NEW, _KEY_OLD]).seal(_PAYLOAD) + assert AESGCMRequestStateCodec([_KEY_NEW]).unseal(token) == _PAYLOAD + with pytest.raises(InvalidRequestState): + AESGCMRequestStateCodec([_KEY_OLD]).unseal(token) + + +def 
test_a_token_minted_under_a_retired_key_is_rejected() -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): retired-key state fails verification.""" + token = AESGCMRequestStateCodec([_KEY_OLD]).seal(_PAYLOAD) + with pytest.raises(InvalidRequestState): + AESGCMRequestStateCodec([_KEY_NEW]).unseal(token) + + +def test_an_empty_key_ring_is_rejected_at_construction() -> None: + """SDK-defined: an empty ring is a configuration error caught at construction.""" + with pytest.raises(ValueError) as exc: + AESGCMRequestStateCodec([]) + assert str(exc.value) == snapshot("AESGCMRequestStateCodec requires at least one key") + + +def test_a_key_shorter_than_32_bytes_is_rejected_with_generation_guidance() -> None: + """SDK-defined: keys must carry at least 32 bytes; the error includes generation guidance.""" + with pytest.raises(ValueError) as exc: + AESGCMRequestStateCodec([b"k" * 31]) + assert str(exc.value) == snapshot( + "request-state keys must be at least 32 bytes of secret randomness; keys[0] is 31 bytes. 
" + 'Generate one with: python -c "import secrets; print(secrets.token_hex(32))"' + ) + + +def test_a_duplicate_key_in_the_ring_is_rejected_at_construction() -> None: + """SDK-defined: duplicate ring keys are a rotation mistake caught at construction.""" + with pytest.raises(ValueError) as exc: + AESGCMRequestStateCodec([_KEY_A, _KEY_A]) + assert str(exc.value) == snapshot("keys[1] duplicates an earlier ring key") + + +def test_a_non_key_typed_ring_entry_is_rejected_naming_its_index_and_type() -> None: + """SDK-defined: a non-key ring entry raises a TypeError naming its index and type, in codec and policy.""" + with pytest.raises(TypeError) as exc: + AESGCMRequestStateCodec([_KEY_A, cast("Any", 32)]) + assert str(exc.value) == snapshot("request-state keys must be bytes, bytearray, or str; keys[1] is int") + with pytest.raises(TypeError) as exc: + RequestStateSecurity(keys=[cast("Any", 32)]) + assert str(exc.value) == snapshot("request-state keys must be bytes, bytearray, or str; keys[0] is int") + + +def test_a_mixed_ring_of_bytes_bytearray_and_str_entries_still_works() -> None: + """SDK-defined: bytes, bytearray, and str keys interoperate in one ring.""" + codec = AESGCMRequestStateCodec([_KEY_A, bytearray(_KEY_B), "c" * 32]) + assert codec.unseal(codec.seal(_PAYLOAD)) == _PAYLOAD + assert codec.unseal(AESGCMRequestStateCodec([bytearray(_KEY_B)]).seal(_PAYLOAD)) == _PAYLOAD + assert codec.unseal(AESGCMRequestStateCodec(["c" * 32]).seal(_PAYLOAD)) == _PAYLOAD + + +def test_a_str_key_is_equivalent_to_its_utf8_bytes_form() -> None: + """SDK-defined: a str key is utf-8 encoded, so it is the same ring key as its bytes spelling.""" + token = AESGCMRequestStateCodec(["k" * 32]).seal(_PAYLOAD) + assert AESGCMRequestStateCodec([b"k" * 32]).unseal(token) == _PAYLOAD + + +def test_bytearray_key_material_is_copied_at_construction() -> None: + """SDK-defined: key bytes are copied at construction; mutating the caller's bytearray later has no effect.""" + material = 
bytearray(b"m" * 32) + codec = AESGCMRequestStateCodec([cast("Any", material)]) + minted_before_mutation = codec.seal(_PAYLOAD) + material[:] = b"X" * 32 + assert codec.unseal(minted_before_mutation) == _PAYLOAD + assert AESGCMRequestStateCodec([b"m" * 32]).unseal(codec.seal(_PAYLOAD)) == _PAYLOAD + + +def test_the_token_reveals_the_payload_neither_in_its_text_nor_its_decoded_bytes() -> None: + """SDK-defined: the token is encrypted, not merely signed, so the plaintext appears nowhere in it.""" + token = AESGCMRequestStateCodec([_KEY_A]).seal(_PAYLOAD) + assert _PAYLOAD.decode() not in token + assert _b64u_nopad(_PAYLOAD) not in token + assert _PAYLOAD.hex() not in token + assert _PAYLOAD not in _decode_body(token) + + +def test_every_substitution_of_the_final_token_character_is_rejected() -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): canonical decoding + rejects every final-character substitution despite base64 don't-care padding bits.""" + codec = AESGCMRequestStateCodec([_KEY_A]) + body = codec.seal(_PAYLOAD).removeprefix(_TOKEN_PREFIX) + substitutions = [c for c in sorted(_B64URL_ALPHABET) if c != body[-1]] + assert len(substitutions) == 63 + for c in substitutions: + with pytest.raises(InvalidRequestState): + codec.unseal(_TOKEN_PREFIX + body[:-1] + c) + + +@pytest.mark.parametrize( + "mangle", + [ + pytest.param(_inject_junk_chars, id="junk-chars-injected"), + pytest.param(_append_newline, id="newline-appended"), + pytest.param(_append_padding, id="padding-appended"), + ], +) +def test_a_non_canonical_token_body_is_rejected(mangle: Callable[[str], str]) -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): lax-decoder aliases of a token are rejected.""" + codec = AESGCMRequestStateCodec([_KEY_A]) + body = codec.seal(_PAYLOAD).removeprefix(_TOKEN_PREFIX) + with pytest.raises(InvalidRequestState): + codec.unseal(_TOKEN_PREFIX + mangle(body)) + + +def test_a_token_reprefixed_to_a_future_format_version_is_rejected() 
-> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): the prefix is tag-bound; "v2." replay fails.""" + codec = AESGCMRequestStateCodec([_KEY_A]) + token = codec.seal(_PAYLOAD) + with pytest.raises(InvalidRequestState): + codec.unseal("v2." + token.removeprefix(_TOKEN_PREFIX)) + + +def test_a_kid_transplanted_onto_another_tokens_body_is_rejected() -> None: + """Spec-mandated (basic/patterns/mrtr, server requirement 4): the kid is tag-bound; transplanting it fails.""" + raw_a = _decode_body(AESGCMRequestStateCodec([_KEY_A]).seal(_PAYLOAD)) + raw_b = _decode_body(AESGCMRequestStateCodec([_KEY_B]).seal(_PAYLOAD)) + assert raw_a[:_KID_LEN] != raw_b[:_KID_LEN] + transplanted = _TOKEN_PREFIX + _b64u_nopad(raw_a[:_KID_LEN] + raw_b[_KID_LEN:]) + with pytest.raises(InvalidRequestState): + AESGCMRequestStateCodec([_KEY_A, _KEY_B]).unseal(transplanted) + + +# -- RequestStateSecurity ----------------------------------------------------- + + +def test_keys_and_codec_together_are_rejected_at_policy_construction() -> None: + """SDK-defined: keys= and codec= are mutually exclusive.""" + with pytest.raises(ValueError) as exc: + RequestStateSecurity(keys=[_KEY_A], codec=_StaticCodec()) + assert str(exc.value) == snapshot("RequestStateSecurity takes exactly one of keys= or codec=") + + +def test_a_policy_with_neither_keys_nor_codec_is_rejected() -> None: + """SDK-defined: a policy must name its codec; an empty policy is a mistake, not a posture.""" + with pytest.raises(ValueError) as exc: + RequestStateSecurity() + assert str(exc.value) == snapshot("RequestStateSecurity takes exactly one of keys= or codec=") + + +@pytest.mark.parametrize( + "ttl", + [ + pytest.param(0.0, id="zero"), + pytest.param(-600.0, id="negative"), + pytest.param(float("nan"), id="nan"), + pytest.param(float("inf"), id="inf"), + ], +) +def test_a_non_positive_or_non_finite_ttl_is_rejected_at_policy_construction(ttl: float) -> None: + """SDK-defined: zero, negative, NaN, and infinite ttl fail 
at construction for keys and ephemeral() alike.""" + with pytest.raises(ValueError, match="positive finite"): + RequestStateSecurity(keys=[_KEY_A], ttl=ttl) + with pytest.raises(ValueError, match="positive finite"): + RequestStateSecurity.ephemeral(ttl=ttl) + + +def test_keys_produce_a_working_built_in_codec_on_the_policy() -> None: + """SDK-defined: keys=[...] builds the built-in AES-GCM codec, exposed on .codec.""" + security = RequestStateSecurity(keys=[_KEY_A]) + assert isinstance(security.codec, AESGCMRequestStateCodec) + assert security.codec.unseal(security.codec.seal(_PAYLOAD)) == _PAYLOAD + + +def test_a_custom_codec_is_stored_on_the_policy_as_is() -> None: + """SDK-defined: codec=... stores the caller's object unwrapped.""" + codec = _StaticCodec() + security = RequestStateSecurity(codec=codec) + assert security.codec is codec + assert codec.unseal(codec.seal(_PAYLOAD)) == _PAYLOAD + + +def test_ephemeral_policies_are_protected_and_mutually_unintelligible() -> None: + """SDK-defined: ephemeral() protects under a process-local key, so a sibling instance rejects its tokens.""" + first = RequestStateSecurity.ephemeral() + second = RequestStateSecurity.ephemeral() + token = first.codec.seal(_PAYLOAD) + assert first.codec.unseal(token) == _PAYLOAD + with pytest.raises(InvalidRequestState): + second.codec.unseal(token) + + +def test_the_policy_stores_an_explicit_audience_and_defaults_to_none() -> None: + """SDK-defined: audience is stored as given; None defers to the server tier's `default_audience`.""" + assert RequestStateSecurity(keys=[_KEY_A]).audience is None + assert RequestStateSecurity(keys=[_KEY_A], audience="svc").audience == "svc" + assert RequestStateSecurity.ephemeral(audience="svc").audience == "svc" + + +def test_the_default_principal_binding_is_authenticated_principal() -> None: + """SDK-defined: an unconfigured policy binds state to the authenticated OAuth client by default.""" + assert RequestStateSecurity(keys=[_KEY_A]).bind_principal is 
authenticated_principal + + +def test_an_explicit_principal_binding_callable_is_stored() -> None: + """SDK-defined: a custom bind_principal callable is stored as given.""" + + def tenant_binding(ctx: ServerRequestContext[Any, Any]) -> str | None: + return "tenant-1" + + security = RequestStateSecurity(keys=[_KEY_A], bind_principal=tenant_binding) + assert security.bind_principal is tenant_binding + assert tenant_binding(_bare_context()) == "tenant-1" + + +# -- authenticated_principal ---------------------------------------------------- + + +def test_authenticated_principal_is_none_without_an_auth_context() -> None: + """SDK-defined: without an auth context the default binding derives no principal.""" + assert authenticated_principal(_bare_context()) is None + + +@pytest.mark.parametrize( + ("token", "expected"), + [ + pytest.param( + AccessToken(token="at-1", client_id="client-123", scopes=[]), + '["client-123",null,null]', + id="client-only", + ), + pytest.param( + AccessToken(token="at-2", client_id="client-123", scopes=[], subject="alice"), + '["client-123",null,"alice"]', + id="with-subject", + ), + pytest.param( + AccessToken( + token="at-3", client_id="client-123", scopes=[], subject="alice", claims={"iss": "https://as.example"} + ), + '["client-123","https://as.example","alice"]', + id="with-issuer-and-subject", + ), + ], +) +def test_authenticated_principal_is_the_tokens_client_issuer_subject_identity( + token: AccessToken, expected: str +) -> None: + """SDK-defined: the default binding composes (client_id, issuer, subject), degrading per component.""" + reset = auth_context_var.set(AuthenticatedUser(token)) + try: + assert authenticated_principal(_bare_context()) == expected + finally: + auth_context_var.reset(reset) + + +def test_authenticated_principal_distinguishes_two_subjects_of_one_client() -> None: + """SDK-defined: two users of the same OAuth client are distinct principals when subjects are supplied.""" + alice = AccessToken(token="at-a", 
client_id="https://agent.example/client.json", scopes=[], subject="alice") + bob = AccessToken(token="at-b", client_id="https://agent.example/client.json", scopes=[], subject="bob") + principals: list[str | None] = [] + for token in (alice, bob): + reset = auth_context_var.set(AuthenticatedUser(token)) + try: + principals.append(authenticated_principal(_bare_context())) + finally: + auth_context_var.reset(reset) + assert principals[0] != principals[1] + + +def test_authenticated_principal_uses_the_same_components_as_session_ownership() -> None: + """SDK-defined: the binding and authorization_context derive from one principal_components source.""" + token = AccessToken( + token="at-1", client_id="client-123", scopes=[], subject="alice", claims={"iss": "https://as.example"} + ) + assert authorization_context(AuthenticatedUser(token)) == { + "client_id": "client-123", + "issuer": "https://as.example", + "subject": "alice", + } + assert list(principal_components(token)) == ["client-123", "https://as.example", "alice"] diff --git a/tests/server/test_request_state_boundary.py b/tests/server/test_request_state_boundary.py new file mode 100644 index 0000000000..ed4e5b0662 --- /dev/null +++ b/tests/server/test_request_state_boundary.py @@ -0,0 +1,1297 @@ +"""`RequestStateBoundary` end to end: seal outbound, verify and restore inbound, one frozen error on failure.""" + +import json +import logging +from collections.abc import Awaitable, Callable +from typing import Any, cast + +import anyio +import pytest +from mcp_types import ( + INTERNAL_ERROR, + INVALID_PARAMS, + CallToolRequestParams, + CallToolResult, + ElicitRequest, + ElicitRequestFormParams, + ElicitResult, + InputRequiredResult, + ListToolsResult, + PaginatedRequestParams, + ReadResourceResult, + RequestParams, + TextContent, + TextResourceContents, + Tool, +) + +import mcp.server.request_state as request_state_module +from mcp import Client +from mcp.server import MCPServer, Server, ServerRequestContext +from 
mcp.server.auth.middleware.auth_context import auth_context_var +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken +from mcp.server.context import HandlerResult +from mcp.server.mcpserver import Context +from mcp.server.mcpserver.server import _MISSING_AUDIENCE +from mcp.server.request_state import ( + AESGCMRequestStateCodec, + InvalidRequestState, + RequestStateBoundary, + RequestStateSecurity, +) +from mcp.shared.exceptions import MCPError + +from .test_runner import connected_runner + +pytestmark = pytest.mark.anyio + +_KEY = b"0123456789abcdef0123456789abcdef" # 32 bytes +_T0 = 1_782_345_600.0 # frozen mint instant for clock-controlled tests +_TTL = 600.0 + + +def _ask(message: str) -> ElicitRequest: + """A minimal elicitation request for a manual tool's `input_requests`.""" + return ElicitRequest( + params=ElicitRequestFormParams( + message=message, + requested_schema={ + "type": "object", + "properties": {"confirm": {"type": "boolean"}}, + "required": ["confirm"], + }, + ) + ) + + +def _accept() -> ElicitResult: + return ElicitResult(action="accept", content={"confirm": True}) + + +async def _list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + """`ClientSession.call_tool` consults tools/list, so lowlevel fixtures must answer it.""" + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) + + +class _PassthroughCodec: + """Cryptography-free codec (the token IS the payload) that puts arbitrary bytes behind a successful unseal.""" + + def seal(self, payload: bytes) -> str: + return payload.decode() + + def unseal(self, token: str) -> bytes: + return token.encode() + + +class _CustomMethodParams(RequestParams): + """Params for a custom (non-carrier) method.""" + + request_state: str | None = None + + +class _Clock: + """Stands in for the `time` module inside `mcp.server.request_state`.""" + + def __init__(self, now: 
float) -> None: + self.now = now + + def time(self) -> float: + return self.now + + +def _tamper(token: str) -> str: + """Flip one mid-token character; strict canonical decoding rejects any single-character change.""" + i = len(token) // 2 + return token[:i] + ("A" if token[i] != "A" else "B") + token[i + 1 :] + + +def _assert_frozen_rejection(exc: pytest.ExceptionInfo[MCPError]) -> None: + """Assert the single frozen wire shape for every inbound verification failure.""" + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "Invalid or expired requestState" + assert exc.value.error.data == {"reason": "invalid_request_state"} + + +def _manual_server( + security: RequestStateSecurity | None, *, state: str = "awaiting-confirm", name: str = "manual" +) -> tuple[MCPServer, list[str | None]]: + """MCPServer with one manual MRTR tool: round 1 asks, the retry records the echoed `ctx.request_state`. + + `security=None` exercises the default posture (process-local ephemeral sealing), not plaintext. 
+ """ + seen: list[str | None] = [] + mcp = MCPServer(name, request_state_security=security) + + @mcp.tool() + async def deploy(env: str, ctx: Context) -> str | InputRequiredResult: + if ctx.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask(f"Deploy to {env}?")}, request_state=state) + seen.append(ctx.request_state) + return f"deployed to {env}" + + return mcp, seen + + +async def _first_round(client: Client, name: str, args: dict[str, Any]) -> str: + """Round 1 of the manual loop: call without responses, return the wire token.""" + first = await client.session.call_tool(name, args, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + return first.request_state + + +async def _retry(client: Client, name: str, args: dict[str, Any], token: str) -> CallToolResult | InputRequiredResult: + """The retry round: echo the wire token with the elicited answer attached.""" + return await client.session.call_tool( + name, args, input_responses={"confirm": _accept()}, request_state=token, allow_input_required=True + ) + + +# -- end-to-end seal/unseal through the public surfaces ------------------------------- + + +async def test_request_state_is_sealed_on_the_wire_and_restored_for_the_handler() -> None: + """Spec-mandated (basic/patterns/mrtr server requirements 4-5): the wire carries an + opaque token, never the handler's plaintext, and a faithful echo restores it.""" + plaintext = "awaiting-confirm:9f2e" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY]), state=plaintext) + + with anyio.fail_after(5): + async with Client(mcp) as client: + first = await client.session.call_tool("deploy", {"env": "prod"}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + assert first.request_state != plaintext + assert first.request_state.startswith("v1.") + second = await _retry(client, "deploy", {"env": "prod"}, 
first.request_state) + + assert isinstance(second, CallToolResult) + assert not second.is_error + assert isinstance(second.content[0], TextContent) + assert second.content[0].text == "deployed to prod" + assert seen == [plaintext] + + +async def test_lowlevel_server_gets_identical_sealing_from_the_one_line_middleware_append() -> None: + """Spec-mandated (basic/patterns/mrtr server requirements 4-5): appending the public + `RequestStateBoundary` to `Server.middleware` gives the lowlevel tier the same sealing.""" + plaintext = "lowlevel-round-1" + seen: list[str | None] = [] + + async def call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | InputRequiredResult: + if params.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask("Proceed?")}, request_state=plaintext) + seen.append(params.request_state) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("srv", on_call_tool=call_tool, on_list_tools=_list_tools) + server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), default_audience=server.name)) + + with anyio.fail_after(5): + async with Client(server) as client: + first = await client.session.call_tool("t", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + assert first.request_state != plaintext + assert first.request_state.startswith("v1.") + second = await _retry(client, "t", {}, first.request_state) + + assert isinstance(second, CallToolResult) + assert seen == [plaintext] + claims = json.loads(AESGCMRequestStateCodec([_KEY]).unseal(first.request_state)) + assert claims["aud"] == "srv" + + +async def test_a_resource_template_flow_seals_on_resources_read_and_restores_the_plaintext() -> None: + """Spec-mandated (basic/patterns/mrtr server requirements 4-5): resources/read is an + MRTR carrier, so a template's `requestState` crosses sealed and bound to the uri.""" + 
plaintext = "resource-round-1" + seen: list[str | None] = [] + mcp = MCPServer("templated", request_state_security=RequestStateSecurity(keys=[_KEY])) + + @mcp.resource("deploy://{env}/confirm") + async def confirm(env: str, ctx: Context) -> str | InputRequiredResult: + if ctx.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask(f"Read {env}?")}, request_state=plaintext) + seen.append(ctx.request_state) + return f"confirmed {env}" + + with anyio.fail_after(5): + async with Client(mcp) as client: + first = await client.session.read_resource("deploy://prod/confirm", allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + assert first.request_state != plaintext + assert first.request_state.startswith("v1.") + second = await client.session.read_resource( + "deploy://prod/confirm", + input_responses={"confirm": _accept()}, + request_state=first.request_state, + allow_input_required=True, + ) + + assert isinstance(second, ReadResourceResult) + assert isinstance(second.contents[0], TextResourceContents) + assert second.contents[0].text == "confirmed prod" + claims = json.loads(AESGCMRequestStateCodec([_KEY]).unseal(first.request_state)) + assert (claims["m"], claims["t"], claims["s"]) == ("resources/read", "deploy://prod/confirm", plaintext) + assert seen == [plaintext] + + +# -- verification failures: tamper, expiry, future skew ------------------------------- + + +async def test_tampered_request_state_is_rejected_with_the_frozen_wire_error() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 5): a modified echo is + rejected with the frozen -32602 shape and the handler never runs.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY])) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, 
_tamper(token)) + _assert_frozen_rejection(exc) + + assert seen == [] + + +async def test_expired_request_state_is_rejected_and_just_inside_ttl_is_accepted( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Spec-mandated (basic/patterns/mrtr server requirements 4-5): one second past `ttl` + is rejected, one second inside completes.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], ttl=_TTL)) + clock = _Clock(_T0) + monkeypatch.setattr(request_state_module, "time", clock) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) # minted at _T0 + clock.now = _T0 + _TTL + 1 + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + clock.now = _T0 + _TTL - 1 + second = await _retry(client, "deploy", {"env": "prod"}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +async def test_state_minted_in_the_future_is_rejected_beyond_the_sixty_second_skew( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Spec-mandated (basic/patterns/mrtr server requirements 4-5): a token minted 120 s + ahead of the verifier's clock is rejected, 30 s ahead is inside the skew allowance.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], ttl=_TTL)) + clock = _Clock(_T0) + monkeypatch.setattr(request_state_module, "time", clock) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) # minted at _T0 + clock.now = _T0 - 120 + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + clock.now = _T0 - 30 + second = await _retry(client, "deploy", {"env": "prod"}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +# -- request binding 
------------------------------------------------------------------- + + +async def test_round_one_state_replayed_on_a_different_tool_is_rejected() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 4): a token minted for tool + A is rejected when echoed on tool B of the same server.""" + seen: list[str | None] = [] + + def make_tool(state: str) -> Callable[[Context], Awaitable[str | InputRequiredResult]]: + async def tool(ctx: Context) -> str | InputRequiredResult: + if ctx.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask(state)}, request_state=state) + seen.append(ctx.request_state) + return "done" + + return tool + + mcp = MCPServer("two-tools", request_state_security=RequestStateSecurity(keys=[_KEY])) + mcp.tool(name="alpha")(make_tool("alpha-state")) + mcp.tool(name="beta")(make_tool("beta-state")) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "alpha", {}) + with pytest.raises(MCPError) as exc: + await _retry(client, "beta", {}, token) + second = await _retry(client, "alpha", {}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen == ["alpha-state"] + + +async def test_retry_with_different_arguments_is_rejected_and_the_original_arguments_complete() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 4): the same tool retried + with different arguments is rejected.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY])) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "staging"}, token) + second = await _retry(client, "deploy", {"env": "prod"}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +# -- principal binding 
----------------------------------------------------------------- + + +async def test_state_minted_with_a_principal_is_rejected_when_the_verifier_derives_none() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 4): state sealed for a + principal is rejected when the verifying round derives none.""" + principal: list[str | None] = ["alice"] + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=lambda ctx: principal[0])) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + principal[0] = None + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + _assert_frozen_rejection(exc) + + assert seen == [] + + +async def test_state_minted_without_a_principal_is_rejected_when_the_verifier_derives_one() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 4): unbound state is + rejected once the verifying round derives a principal.""" + principal: list[str | None] = [None] + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=lambda ctx: principal[0])) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + principal[0] = "alice" + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + _assert_frozen_rejection(exc) + + assert seen == [] + + +async def test_state_for_a_different_principal_is_rejected_and_the_same_principal_completes() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 4): one principal's token is + rejected when echoed by another and accepted when the same principal returns.""" + principal: list[str | None] = ["alice"] + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=lambda ctx: principal[0])) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, 
"deploy", {"env": "prod"}) + principal[0] = "bob" + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + principal[0] = "alice" + second = await _retry(client, "deploy", {"env": "prod"}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +async def test_a_principal_binding_that_raises_fails_the_seal_as_an_internal_error( + caplog: pytest.LogCaptureFixture, +) -> None: + """SDK-defined: a raising `bind_principal` fails the seal as a bare internal error, not an unbound mint.""" + + def boom(ctx: ServerRequestContext[Any, Any]) -> str | None: + raise RuntimeError("identity provider down") + + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=boom)) + + with anyio.fail_after(5): + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc: + await client.session.call_tool("deploy", {"env": "prod"}, allow_input_required=True) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "Internal error" + assert exc.value.error.data is None # the reason never reaches the wire + + assert seen == [] + assert any(r.exc_info is not None and r.exc_info[0] is RuntimeError for r in caplog.records) + + +async def test_a_principal_binding_that_raises_fails_the_unseal_with_the_frozen_rejection( + caplog: pytest.LogCaptureFixture, +) -> None: + """SDK-defined: a `bind_principal` that raises while verifying collapses to the frozen rejection.""" + rounds: list[int] = [] + + def flaky(ctx: ServerRequestContext[Any, Any]) -> str | None: + rounds.append(1) + if len(rounds) == 1: + return "alice" # mint round succeeds + raise RuntimeError("identity provider down") # verify round raises + + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=flaky)) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + with 
pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + _assert_frozen_rejection(exc) + + assert seen == [] + assert any(r.exc_info is not None and r.exc_info[0] is RuntimeError for r in caplog.records) + + +async def test_two_mints_for_the_same_principal_carry_different_salted_principal_claims() -> None: + """SDK-defined: the `p` claim is salted per mint, so two tokens for the same principal are not linkable.""" + mcp, _ = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=lambda ctx: "alice")) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token_one = await _first_round(client, "deploy", {"env": "prod"}) + token_two = await _first_round(client, "deploy", {"env": "prod"}) + + codec = AESGCMRequestStateCodec([_KEY]) + claims_one = json.loads(codec.unseal(token_one)) + claims_two = json.loads(codec.unseal(token_two)) + assert "p" in claims_one + assert "p" in claims_two + assert claims_one["p"] != claims_two["p"] + + +# -- audience binding ------------------------------------------------------------------ + + +async def test_two_servers_sharing_a_key_reject_each_others_state_via_the_name_audience() -> None: + """SDK-defined: the server name is the default audience, so servers sharing a key reject each other's state.""" + mcp_billing, seen_billing = _manual_server(RequestStateSecurity(keys=[_KEY]), name="billing") + mcp_payments, seen_payments = _manual_server(RequestStateSecurity(keys=[_KEY]), name="payments") + + with anyio.fail_after(5): + async with Client(mcp_billing) as billing, Client(mcp_payments) as payments: + token = await _first_round(billing, "deploy", {"env": "prod"}) + with pytest.raises(MCPError) as exc: + await _retry(payments, "deploy", {"env": "prod"}, token) + second = await _retry(billing, "deploy", {"env": "prod"}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen_billing == ["awaiting-confirm"] + assert seen_payments == [] + + 
+async def test_audience_presence_drift_is_rejected_in_both_directions() -> None: + """SDK-defined: audience presence drift is rejected in both directions; each boundary accepts its own mint.""" + + def make_server(boundary: RequestStateBoundary) -> Server: + async def call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | InputRequiredResult: + if params.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask("Go?")}, request_state="round-1") + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("srv", on_call_tool=call_tool, on_list_tools=_list_tools) + server.middleware.append(boundary) + return server + + security = RequestStateSecurity(keys=[_KEY]) + bound = make_server(RequestStateBoundary(security, default_audience="svc")) + unbound = make_server(RequestStateBoundary(security, default_audience=None)) + + with anyio.fail_after(5): + async with Client(bound) as on_bound, Client(unbound) as on_unbound: + bound_token = await _first_round(on_bound, "t", {}) + unbound_token = await _first_round(on_unbound, "t", {}) + with pytest.raises(MCPError) as bound_state_on_unbound: + await _retry(on_unbound, "t", {}, bound_token) + with pytest.raises(MCPError) as unbound_state_on_bound: + await _retry(on_bound, "t", {}, unbound_token) + assert isinstance(await _retry(on_bound, "t", {}, bound_token), CallToolResult) + assert isinstance(await _retry(on_unbound, "t", {}, unbound_token), CallToolResult) + + _assert_frozen_rejection(bound_state_on_unbound) + _assert_frozen_rejection(unbound_state_on_bound) + + +async def test_an_explicit_policy_audience_overrides_the_server_name_default() -> None: + """SDK-defined: `RequestStateSecurity(audience=...)` overrides the server-name default.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], audience="prod-fleet"), name="one-box") + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await 
_first_round(client, "deploy", {"env": "prod"}) + second = await _retry(client, "deploy", {"env": "prod"}, token) + + claims = json.loads(AESGCMRequestStateCodec([_KEY]).unseal(token)) + assert claims["aud"] == "prod-fleet" + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +# -- claims envelope (white-box through the public codec) ----------------------------- + + +async def test_claims_envelope_carries_the_documented_fields_and_omits_p_when_unbound() -> None: + """SDK-defined: the sealed payload is the documented claims JSON; no `p` claim when the principal is None.""" + plaintext = "step-one" + mcp, _ = _manual_server( + RequestStateSecurity(keys=[_KEY], ttl=_TTL, bind_principal=lambda ctx: None), state=plaintext + ) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + + claims = json.loads(AESGCMRequestStateCodec([_KEY]).unseal(token)) + assert set(claims) == {"v", "iat", "exp", "m", "t", "a", "s", "aud"} + assert claims["v"] == 1 + assert claims["exp"] == claims["iat"] + int(_TTL) + assert claims["m"] == "tools/call" + assert claims["t"] == "deploy" + assert isinstance(claims["a"], str) and claims["a"] + assert claims["aud"] == "manual" # the MCPServer name, the boundary's default audience + assert claims["s"] == plaintext + + +async def test_each_round_is_resealed_with_a_fresh_token_and_a_restamped_iat( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """SDK-defined: every round reseals with a fresh token and `iat`, so `ttl` bounds per-round think time.""" + mcp = MCPServer("wizard-server", request_state_security=RequestStateSecurity(keys=[_KEY], ttl=_TTL)) + + @mcp.tool() + async def wizard(ctx: Context) -> str | InputRequiredResult: + if ctx.input_responses is None: + return InputRequiredResult(input_requests={"first": _ask("First?")}, request_state="step-1") + if ctx.request_state == "step-1": + return 
InputRequiredResult(input_requests={"second": _ask("Second?")}, request_state="step-2") + return "done" + + clock = _Clock(_T0) + monkeypatch.setattr(request_state_module, "time", clock) + + with anyio.fail_after(5): + async with Client(mcp) as client: + first = await client.session.call_tool("wizard", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + clock.now = _T0 + 5 + second = await client.session.call_tool( + "wizard", + {}, + input_responses={"first": _accept()}, + request_state=first.request_state, + allow_input_required=True, + ) + assert isinstance(second, InputRequiredResult) + assert second.request_state is not None + third = await client.session.call_tool( + "wizard", + {}, + input_responses={"second": _accept()}, + request_state=second.request_state, + allow_input_required=True, + ) + + assert isinstance(third, CallToolResult) + assert first.request_state != second.request_state + codec = AESGCMRequestStateCodec([_KEY]) + claims_one = json.loads(codec.unseal(first.request_state)) + claims_two = json.loads(codec.unseal(second.request_state)) + assert claims_two["iat"] >= claims_one["iat"] + assert (claims_one["iat"], claims_two["iat"]) == (int(_T0), int(_T0) + 5) + + +# -- the default posture: every MCPServer seals under an ephemeral policy --------------- + + +async def test_an_mcpserver_seals_request_state_by_default() -> None: + """SDK-defined: with no `request_state_security=`, an MCPServer seals under a process-local key.""" + plaintext = "plain-wizard-state" + mcp, seen = _manual_server(None, state=plaintext) + + with anyio.fail_after(5): + async with Client(mcp) as client: + first = await client.session.call_tool("deploy", {"env": "prod"}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is not None + assert first.request_state != plaintext + assert first.request_state.startswith("v1.") + with pytest.raises(MCPError) as 
fabricated: + await _retry(client, "deploy", {"env": "prod"}, plaintext) + second = await _retry(client, "deploy", {"env": "prod"}, first.request_state) + + _assert_frozen_rejection(fabricated) + assert isinstance(second, CallToolResult) + assert seen == [plaintext] + + +async def test_the_default_key_is_per_instance_so_servers_never_cross_accept() -> None: + """SDK-defined: each default MCPServer mints its own ephemeral key; another instance rejects its state.""" + one, seen_one = _manual_server(None) + two, seen_two = _manual_server(None) + + with anyio.fail_after(5): + async with Client(one) as on_one, Client(two) as on_two: + token = await _first_round(on_one, "deploy", {"env": "prod"}) + with pytest.raises(MCPError) as exc: + await _retry(on_two, "deploy", {"env": "prod"}, token) + second = await _retry(on_one, "deploy", {"env": "prod"}, token) + + _assert_frozen_rejection(exc) + assert isinstance(second, CallToolResult) + assert seen_one == ["awaiting-confirm"] + assert seen_two == [] + + +async def test_a_boundary_free_lowlevel_server_passes_request_state_through_verbatim() -> None: + """SDK-defined: without a boundary in `Server.middleware`, `requestState` crosses as the handler's plaintext.""" + plaintext = "lowlevel-plain-round-1" + seen: list[str | None] = [] + + async def call_tool( + ctx: ServerRequestContext[Any], params: CallToolRequestParams + ) -> CallToolResult | InputRequiredResult: + if params.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask("Proceed?")}, request_state=plaintext) + seen.append(params.request_state) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("srv", on_call_tool=call_tool, on_list_tools=_list_tools) + + with anyio.fail_after(5): + async with Client(server) as client: + first = await client.session.call_tool("t", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state == plaintext + second = await 
_retry(client, "t", {}, plaintext) + + assert isinstance(second, CallToolResult) + assert seen == [plaintext] + + +# -- malformed wire input -------------------------------------------------------------- + + +async def test_non_string_inbound_request_state_is_rejected_with_the_frozen_error() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 5): a non-string + `requestState` fails at the boundary with the frozen shape.""" + calls: list[str] = [] + + async def call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: + calls.append(params.name) + return CallToolResult(content=[TextContent(text="ran")]) + + server = Server("srv", on_call_tool=call_tool) + server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), default_audience=None)) + + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/call", {"name": "t", "arguments": {}, "requestState": 123}) + assert calls == [] + result = await client.send_raw_request("tools/call", {"name": "t", "arguments": {}}) + + _assert_frozen_rejection(exc) + assert result["content"][0]["text"] == "ran" + assert calls == ["t"] + + +@pytest.mark.parametrize( + "install_boundary", + [ + pytest.param(True, id="boundary-installed"), + pytest.param(False, id="no-boundary"), + ], +) +async def test_an_explicit_null_request_state_is_treated_as_absent(install_boundary: bool) -> None: + """SDK-defined: an explicit `"requestState": null` is the field's absence, so the handler runs and sees None.""" + seen: list[str | None] = [] + + async def call_tool(ctx: ServerRequestContext[Any], params: CallToolRequestParams) -> CallToolResult: + seen.append(params.request_state) + return CallToolResult(content=[TextContent(text="ran")]) + + server = Server("srv", on_call_tool=call_tool) + if install_boundary: + server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), 
default_audience=None)) + + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("tools/call", {"name": "t", "arguments": {}, "requestState": None}) + + assert result["content"][0]["text"] == "ran" + assert seen == [None] + + +# -- boundary scope: only the three carrier methods are touched ------------------------- + + +async def test_inbound_request_state_on_a_non_carrier_method_passes_through_unverified() -> None: + """SDK-defined: only the MRTR carriers are touched; a custom method's `requestState` arrives as sent.""" + calls: list[str] = [] + + async def custom(ctx: ServerRequestContext[Any], params: _CustomMethodParams) -> dict[str, Any]: + calls.append(params.request_state or "fresh") + return {"resultType": "complete"} + + server = Server("srv", on_list_tools=_list_tools) + server.add_request_handler("example/mrtr", _CustomMethodParams, custom) + server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), default_audience=None)) + + async with connected_runner(server) as (client, _): + ok = await client.send_raw_request("example/mrtr", {"requestState": "CLIENT-SENT-VALUE"}) + fresh = await client.send_raw_request("example/mrtr", {}) + + assert ok == {"resultType": "complete"} + assert fresh == {"resultType": "complete"} + assert calls == ["CLIENT-SENT-VALUE", "fresh"] + + +async def test_outbound_request_state_on_a_non_carrier_method_is_not_sealed() -> None: + """SDK-defined: an input_required result on a custom method keeps its `requestState` unsealed.""" + + async def custom(ctx: ServerRequestContext[Any], params: _CustomMethodParams) -> InputRequiredResult: + return InputRequiredResult(input_requests={"confirm": _ask("?")}, request_state="ext-handler-plaintext") + + server = Server("srv", on_list_tools=_list_tools) + server.add_request_handler("example/mrtr", _CustomMethodParams, custom) + server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), default_audience=None)) + 
+ async with connected_runner(server) as (client, _): + result = await client.send_raw_request("example/mrtr", {}) + + assert result["resultType"] == "input_required" + assert result["requestState"] == "ext-handler-plaintext" + + +async def test_an_off_set_input_required_result_without_state_passes_through_untouched() -> None: + """SDK-defined: an input_required result on a non-carrier method minting no state crosses unmodified.""" + + async def custom(ctx: ServerRequestContext[Any], params: _CustomMethodParams) -> InputRequiredResult: + return InputRequiredResult(input_requests={"confirm": _ask("?")}) + + server = Server("srv", on_list_tools=_list_tools) + server.add_request_handler("example/mrtr", _CustomMethodParams, custom) + server.middleware.append(RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), default_audience=None)) + + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("example/mrtr", {}) + + assert result["resultType"] == "input_required" + assert "confirm" in result["inputRequests"] + assert "requestState" not in result + + +# -- custom codec: deny on error ------------------------------------------------------- + + +async def test_a_codec_that_raises_unexpectedly_fails_closed_with_the_frozen_error( + caplog: pytest.LogCaptureFixture, +) -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 5): a codec that raises + unexpectedly denies with the frozen rejection.""" + + class ExplodingCodec: + def seal(self, payload: bytes) -> str: + return "opaque-token" + + def unseal(self, token: str) -> bytes: + raise RuntimeError("codec exploded") + + mcp, seen = _manual_server(RequestStateSecurity(codec=ExplodingCodec())) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + assert token == "opaque-token" + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + 
_assert_frozen_rejection(exc) + + assert seen == [] + assert any(r.exc_info is not None and r.exc_info[0] is RuntimeError for r in caplog.records) + + +async def test_a_codec_reject_reason_reaches_the_log_but_never_the_wire( + caplog: pytest.LogCaptureFixture, +) -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 5): a custom codec's + `InvalidRequestState` reason is logged server-side, never sent on the wire.""" + + class RefusingCodec: + def seal(self, payload: bytes) -> str: + return "opaque-token" + + def unseal(self, token: str) -> bytes: + raise InvalidRequestState("boom") + + mcp, seen = _manual_server(RequestStateSecurity(codec=RefusingCodec())) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + _assert_frozen_rejection(exc) + + assert "boom" in caplog.text + assert seen == [] + + +@pytest.mark.parametrize( + "payload", + [ + pytest.param("not a claims envelope", id="not-json"), + pytest.param(json.dumps({"v": 1, "iat": 1, "exp": 2}), id="json-missing-claims"), + pytest.param(json.dumps({"v": 2, "iat": 1, "exp": 2, "s": "x"}), id="json-wrong-envelope-version"), + pytest.param(json.dumps({"v": 1, "iat": 1, "exp": 2, "s": 7}), id="json-non-string-state"), + ], +) +async def test_codec_authenticated_bytes_that_are_not_a_claims_envelope_are_rejected(payload: str) -> None: + """SDK-defined: codec-authenticated bytes that are not the claims envelope collapse to the frozen rejection.""" + mcp, seen = _manual_server(RequestStateSecurity(codec=_PassthroughCodec(), bind_principal=None)) + + with anyio.fail_after(5): + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, payload) + _assert_frozen_rejection(exc) + + assert seen == [] + + +async def test_a_forged_principal_claim_that_is_not_base64_is_rejected() 
-> None: + """SDK-defined: a `p` claim that does not decode as base64 collapses to the frozen rejection.""" + mcp, seen = _manual_server(RequestStateSecurity(codec=_PassthroughCodec(), bind_principal=lambda ctx: "alice")) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + claims = json.loads(token) # passthrough codec: the token IS the envelope JSON + claims["p"] = "A" # a single base64 char can never pad to a valid quantum + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, json.dumps(claims)) + _assert_frozen_rejection(exc) + + assert seen == [] + + +@pytest.mark.parametrize("forged", [pytest.param(7, id="int"), pytest.param({"x": 1}, id="object")]) +async def test_a_non_string_principal_claim_is_rejected_with_the_frozen_error(forged: Any) -> None: + """SDK-defined: a non-string `p` claim inside a validly-sealed envelope collapses to the frozen rejection.""" + mcp, seen = _manual_server(RequestStateSecurity(codec=_PassthroughCodec(), bind_principal=lambda ctx: "alice")) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + claims = json.loads(token) # passthrough codec: the token IS the envelope JSON + claims["p"] = forged + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, json.dumps(claims)) + _assert_frozen_rejection(exc) + + assert seen == [] + + +# -- log secrecy and the cause-invariant wire error ------------------------------------ + + +async def test_the_wire_error_never_varies_by_cause_and_logs_never_leak_secrets( + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 5): tampered, expired, and rebound + echoes get identical wire errors, with reasons logged but no secrets in any record.""" + plaintext = "secret-plaintext-state-1f9b" + 
principal = "principal-alice-7c3d" + mcp, seen = _manual_server( + RequestStateSecurity(keys=[_KEY], ttl=_TTL, bind_principal=lambda ctx: principal), state=plaintext + ) + clock = _Clock(_T0) + monkeypatch.setattr(request_state_module, "time", clock) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + with pytest.raises(MCPError) as tampered: + await _retry(client, "deploy", {"env": "prod"}, _tamper(token)) + clock.now = _T0 + _TTL + 1 + with pytest.raises(MCPError) as expired: + await _retry(client, "deploy", {"env": "prod"}, token) + clock.now = _T0 + with pytest.raises(MCPError) as rebound: + await _retry(client, "deploy", {"env": "staging"}, token) + _assert_frozen_rejection(tampered) + + shapes = [(e.value.error.code, e.value.error.message, e.value.error.data) for e in (tampered, expired, rebound)] + assert shapes[0] == shapes[1] == shapes[2] + assert seen == [] + + reject_logs = [r for r in caplog.records if r.name == "mcp.server.request_state" and r.levelno == logging.WARNING] + assert len(reject_logs) == 3 + for record in caplog.records: + message = record.getMessage() + assert token not in message + assert plaintext not in message + assert principal not in message + + +# -- pass-through inertness ------------------------------------------------------------ + + +async def test_a_complete_result_crosses_the_boundary_untouched() -> None: + """SDK-defined: a complete tools/call wire result passes the boundary as the identical object.""" + boundary = RequestStateBoundary(RequestStateSecurity(keys=[_KEY], bind_principal=None), default_audience=None) + complete: dict[str, Any] = {"resultType": "complete", "content": []} + + async def call_next(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + return complete + + ctx = ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/call", + params={"name": "t", 
"arguments": {}}, + ) + + assert await boundary(ctx, call_next) is complete + + +async def test_input_required_without_request_state_is_untouched() -> None: + """SDK-defined: an `input_required` result that asks without minting state crosses the boundary unmodified.""" + seen: list[str | None] = [] + mcp = MCPServer("stateless-ask", request_state_security=RequestStateSecurity(keys=[_KEY])) + + @mcp.tool() + async def ask(ctx: Context) -> str | InputRequiredResult: + if ctx.input_responses is None: + return InputRequiredResult(input_requests={"confirm": _ask("Sure?")}) + seen.append(ctx.request_state) + return "done" + + with anyio.fail_after(5): + async with Client(mcp) as client: + first = await client.session.call_tool("ask", {}, allow_input_required=True) + assert isinstance(first, InputRequiredResult) + assert first.request_state is None + second = await client.session.call_tool( + "ask", {}, input_responses={"confirm": _accept()}, allow_input_required=True + ) + + assert isinstance(second, CallToolResult) + assert seen == [None] + + +async def test_an_input_required_mapping_with_a_non_string_state_is_not_sealed() -> None: + """SDK-defined: a non-string `requestState` in a wire mapping is not this module's mint; it crosses unchanged.""" + boundary = RequestStateBoundary(RequestStateSecurity(keys=[_KEY], bind_principal=None), default_audience=None) + malformed: dict[str, Any] = {"resultType": "input_required", "inputRequests": {}, "requestState": 7} + + async def call_next(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + return malformed + + ctx = ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/call", + params={"name": "t", "arguments": {}}, + ) + + assert await boundary(ctx, call_next) is malformed + + +async def test_a_notification_crosses_the_boundary_unharmed() -> None: + """SDK-defined: the boundary is inert for notifications.""" + boundary = 
RequestStateBoundary(RequestStateSecurity(keys=[_KEY], bind_principal=None), default_audience=None) + forwarded: list[ServerRequestContext[Any, Any]] = [] + + async def call_next(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + forwarded.append(ctx) + return None + + ctx = ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="notifications/progress", + params={"progressToken": "p", "progress": 1}, + ) + + assert await boundary(ctx, call_next) is None + assert len(forwarded) == 1 + assert forwarded[0] is ctx + + +async def test_a_non_mrtr_method_with_no_params_is_untouched() -> None: + """SDK-defined: a non-carrier method with absent params passes the boundary inert.""" + boundary = RequestStateBoundary(RequestStateSecurity(keys=[_KEY], bind_principal=None), default_audience=None) + listing: dict[str, Any] = {"tools": [], "resultType": "complete"} + + async def call_next(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + return listing + + ctx = ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/list", + params=None, + ) + + assert await boundary(ctx, call_next) is listing + + +# -- direct chain invocation: the model-path seal -------------------------------------- + + +async def test_a_short_circuited_input_required_model_is_sealed_via_the_model_path() -> None: + """SDK-defined: a short-circuited `InputRequiredResult` model is sealed via the model path, on a copy.""" + boundary = RequestStateBoundary(RequestStateSecurity(keys=[_KEY], bind_principal=None), default_audience=None) + interim = InputRequiredResult(input_requests={"confirm": _ask("Go?")}, request_state="model-plaintext") + + async def call_next(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + return interim + + ctx = ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/call", + 
params={"name": "shortcut", "arguments": {}}, + ) + + result = await boundary(ctx, call_next) + + assert isinstance(result, InputRequiredResult) + assert result.input_requests == interim.input_requests + assert result.request_state is not None + assert result.request_state != "model-plaintext" + assert result.request_state.startswith("v1.") + claims = json.loads(AESGCMRequestStateCodec([_KEY]).unseal(result.request_state)) + assert (claims["m"], claims["t"], claims["s"]) == ("tools/call", "shortcut", "model-plaintext") + assert interim.request_state == "model-plaintext" + + +# -- user-supplied code on the seal path fails closed ----------------------------------- + + +class _RaisingSealCodec: + """Codec whose seal always fails, standing in for a KMS outage in a custom codec.""" + + def seal(self, payload: bytes) -> str: + raise RuntimeError("kms unreachable at 10.0.0.7: wrapped-key-id-42") + + def unseal(self, token: str) -> bytes: + raise InvalidRequestState("never minted") + + +async def test_a_codec_that_raises_during_seal_yields_a_sanitized_internal_error() -> None: + """SDK-defined: a raising custom codec fails the round with a sanitized internal error, never its own text.""" + mcp, _ = _manual_server(RequestStateSecurity(codec=_RaisingSealCodec())) + + with anyio.fail_after(5): + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc: + await client.session.call_tool("deploy", {"env": "prod"}, allow_input_required=True) + # The unseal direction of the same broken codec still maps to the frozen rejection. 
+ with pytest.raises(MCPError) as inbound: + await _retry(client, "deploy", {"env": "prod"}, "token-this-codec-never-minted") + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "Internal error" + _assert_frozen_rejection(inbound) + + +async def test_a_non_string_principal_fails_closed_when_sealing() -> None: + """SDK-defined: a bind_principal returning a non-string denies the round with the sanitized internal error.""" + + def numeric_user_id(ctx: ServerRequestContext[Any, Any]) -> str: + return cast("str", 12345) + + mcp, _ = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=numeric_user_id)) + + with anyio.fail_after(5): + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc: + await client.session.call_tool("deploy", {"env": "prod"}, allow_input_required=True) + assert exc.value.error.code == INTERNAL_ERROR + assert exc.value.error.message == "Internal error" + + +async def test_a_non_string_principal_fails_closed_when_verifying() -> None: + """SDK-defined: a non-string principal on the verify side rejects with the frozen error, not a crash.""" + principal: list[Any] = ["alice"] + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=lambda ctx: principal[0])) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + principal[0] = 12345 + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, token) + _assert_frozen_rejection(exc) + assert seen == [] + + +# -- lone surrogates: every encode on the state path is total ---------------------------- + + +async def test_lone_surrogate_arguments_are_digested_not_crashed() -> None: + """SDK-defined: a lone UTF-16 surrogate in an argument string digests like any other value.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY])) + + with anyio.fail_after(5): + async with Client(mcp)
as client: + token = await _first_round(client, "deploy", {"env": "\ud800-prod"}) + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "\udfff-prod"}, token) + second = await _retry(client, "deploy", {"env": "\ud800-prod"}, token) + + _assert_frozen_rejection(exc) # different args reject as a binding mismatch, not an internal error + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +async def test_lone_surrogate_handler_state_seals_and_restores() -> None: + """SDK-defined: handler-minted state containing a lone surrogate round-trips through the seal exactly.""" + plaintext = "awaiting-\ud800-confirm" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY]), state=plaintext) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + second = await _retry(client, "deploy", {"env": "prod"}, token) + + assert isinstance(second, CallToolResult) + assert seen == [plaintext] + + +async def test_lone_surrogate_principal_binds_and_verifies() -> None: + """SDK-defined: a principal string containing a lone surrogate binds and verifies like any other.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], bind_principal=lambda ctx: "tenant-\ud800")) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + second = await _retry(client, "deploy", {"env": "prod"}, token) + + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +# -- fractional clocks: the configured ttl is the effective ttl -------------------------- + + +async def test_a_fractional_mint_instant_keeps_the_full_ttl(monkeypatch: pytest.MonkeyPatch) -> None: + """SDK-defined: a token minted at a fractional instant lives the full configured ttl.""" + mcp, seen = _manual_server(RequestStateSecurity(keys=[_KEY], ttl=0.5)) + clock = _Clock(_T0 + 0.9) + 
monkeypatch.setattr(request_state_module, "time", clock) + + with anyio.fail_after(5): + async with Client(mcp) as client: + token = await _first_round(client, "deploy", {"env": "prod"}) + clock.now = _T0 + 1.3 # 0.4s after mint, inside the 0.5s ttl + second = await _retry(client, "deploy", {"env": "prod"}, token) + clock.now = _T0 + 2.0 + late = await _first_round(client, "deploy", {"env": "prod"}) + clock.now = _T0 + 2.6 # 0.6s after mint, past the ttl + with pytest.raises(MCPError) as exc: + await _retry(client, "deploy", {"env": "prod"}, late) + _assert_frozen_rejection(exc) + + assert isinstance(second, CallToolResult) + assert seen == ["awaiting-confirm"] + + +async def test_default_principal_distinguishes_two_subjects_of_one_oauth_client() -> None: + """Spec-mandated (basic/patterns/mrtr server requirement 5, cross-user reuse): with the + default binding, state sealed for one user of an OAuth client is rejected for another + user of the same client and restored only for the original subject.""" + boundary = RequestStateBoundary(RequestStateSecurity(keys=[_KEY]), default_audience="svc") + seen: list[str | None] = [] + + async def mint(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + return InputRequiredResult(input_requests={"confirm": _ask("PIN?")}, request_state="alice-secret") + + async def restore(ctx: ServerRequestContext[Any, Any]) -> HandlerResult: + assert ctx.params is not None + seen.append(ctx.params["requestState"]) + return CallToolResult(content=[TextContent(text="done")]) + + def request(token: str | None = None) -> ServerRequestContext[Any, Any]: + params: dict[str, Any] = {"name": "fetch_pin", "arguments": {}} + if token is not None: + params["requestState"] = token + return ServerRequestContext( + session=cast("Any", None), + lifespan_context={}, + protocol_version="2026-07-28", + method="tools/call", + params=params, + ) + + def as_user(subject: str) -> AuthenticatedUser: + shared_client = "https://agent.example/client.json" + 
return AuthenticatedUser( + AccessToken(token=f"at-{subject}", client_id=shared_client, scopes=[], subject=subject) + ) + + reset = auth_context_var.set(as_user("alice")) + try: + sealed = await boundary(request(), mint) + finally: + auth_context_var.reset(reset) + assert isinstance(sealed, InputRequiredResult) + assert sealed.request_state is not None + + reset = auth_context_var.set(as_user("bob")) + try: + with pytest.raises(MCPError) as exc: + await boundary(request(sealed.request_state), restore) + finally: + auth_context_var.reset(reset) + _assert_frozen_rejection(exc) + assert seen == [] + + reset = auth_context_var.set(as_user("alice")) + try: + result = await boundary(request(sealed.request_state), restore) + finally: + auth_context_var.reset(reset) + assert isinstance(result, CallToolResult) + assert seen == ["alice-secret"] + + +@pytest.mark.parametrize("name", [None, ""], ids=["unnamed", "empty-string"]) +def test_a_shared_key_policy_without_a_real_name_must_set_an_audience(name: str | None) -> None: + """SDK-defined: explicit keys usually mean a fleet, where the audience claim is what + separates services; without a real name every server would stamp the same placeholder.""" + with pytest.raises(ValueError) as excinfo: + MCPServer(name, request_state_security=RequestStateSecurity(keys=[_KEY])) + assert str(excinfo.value) == _MISSING_AUDIENCE + + # Every neighboring posture constructs: the default needs no name, and a real + # name or an explicit audience satisfies a shared-key policy. 
+ MCPServer(name) + MCPServer(name, request_state_security=RequestStateSecurity(keys=[_KEY], audience="svc")) + MCPServer("named", request_state_security=RequestStateSecurity(keys=[_KEY])) diff --git a/tests/types/test_methods.py b/tests/types/test_methods.py index 342720c32c..126e06c291 100644 --- a/tests/types/test_methods.py +++ b/tests/types/test_methods.py @@ -553,6 +553,21 @@ def test_cacheable_methods_mirror_the_cacheable_method_literal(): assert methods.CACHEABLE_METHODS == frozenset(get_args(methods.CacheableMethod)) +def test_input_required_methods_mirror_the_monolith_input_required_arms(): + """MRTR weld: the spec's three multi-round-trip carriers are the only input_required methods.""" + assert methods.INPUT_REQUIRED_METHODS == frozenset({"prompts/get", "resources/read", "tools/call"}) + + +def test_is_input_required_matches_typed_and_wire_shapes(): + """SDK-defined predicate: True only for the typed model and the tagged wire mapping.""" + assert methods.is_input_required(types.InputRequiredResult(request_state="s")) + assert methods.is_input_required({"resultType": "input_required", "inputRequests": {}}) + assert not methods.is_input_required({"resultType": "complete", "content": []}) + assert not methods.is_input_required({}) + assert not methods.is_input_required(types.CallToolResult(content=[])) + assert not methods.is_input_required(None) + + def test_minimal_request_bodies_parse_through_every_request_row(): for (method, version), surface_type in methods.CLIENT_REQUESTS.items(): parsed = methods.parse_client_request(method, version, REQUEST_PARAMS_FIXTURES[surface_type])