diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e0404e123..7c990cd053 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). ### Fixed - [#3805](https://github.com/plotly/dash/pull/3805) Fix FastAPI POST routes deadlock caused by middleware consuming request body. Fixes [#3801](https://github.com/plotly/dash/issues/3801). +- [#3815](https://github.com/plotly/dash/pull/3815) Fix missing request context (cookies/headers) in websocket callbacks. ## [4.2.0] - 2026-06-01 - *The Freedom Update* diff --git a/dash/_utils.py b/dash/_utils.py index 1ab2036820..c6a827f10d 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -218,6 +218,23 @@ def inputs_to_dict(inputs_list): return inputs +def populate_request_metadata(g, adapter): + """Copy request metadata from a request adapter onto a context object. + + Shared by the HTTP path (``Dash._initialize_context``) and the WebSocket + path (``dash.backends.ws.create_ws_context``) so that both transports expose + identical request context (cookies, headers, args, path, remote, origin) on + ``callback_context``. + """ + g.cookies = dict(adapter.cookies) + g.headers = dict(adapter.headers) + g.args = adapter.args + g.path = adapter.full_path + g.remote = adapter.remote_addr + g.origin = adapter.origin + return g + + def convert_to_AttributeDict(nested_list): new_dict = [] for i in nested_list: diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 0516b5edcb..6c573e0101 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,5 +1,6 @@ from __future__ import annotations +from contextlib import contextmanager from contextvars import copy_context, ContextVar import asyncio import concurrent.futures @@ -328,7 +329,8 @@ def setup_catchall(self, dash_app: Dash): and passed through the middleware, which is necessary for features like authentication and timing to work correctly on all routes. FastAPI will match this catch-all route for any path that isn't matched by a more specific route, allowing the middleware to - process the request and then return the appropriate response (e.g., 404 if no Dash route matches).""" + process the request and then return the appropriate response (e.g., 404 if no Dash route matches). + """ def _setup_catchall(self): try: @@ -725,6 +727,21 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() + # Activate the WebSocket handshake request using the same + # request-context machinery as HTTP callbacks (set_current_request) + # so that callbacks running over the WebSocket transport can access + # cookies/headers (e.g. for authentication helpers such as + # dash_enterprise_auth.get_user_data). ContextVars do not propagate + # into the executor threads, so activation happens inside the worker + # thread via this callable. + @contextmanager + def activate_request(): + token = set_current_request(websocket) + try: + yield FastAPIRequestAdapter() + finally: + reset_current_request(token) + # Create janus queue for outbound messages (main loop context) outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests with standard queue.Queue for responses @@ -788,6 +805,7 @@ async def websocket_handler(websocket: WebSocket): payload, ws_cb, FastAPIResponseAdapter(), + activate_request, ) # Set up done callback to send response diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index c7634ce93a..16e7a77f6c 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -12,8 +12,10 @@ import threading from urllib.parse import urlparse +from contextlib import contextmanager from logging.config import dictConfig from contextvars import copy_context +from types import SimpleNamespace from typing import Any, Dict, TYPE_CHECKING from importlib_metadata import version as _get_distribution_version @@ -545,6 +547,25 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() + # Quart's request/websocket context cannot cross into the executor + # threads where callbacks run, so snapshot the handshake metadata + # here (where the ``websocket`` proxy is valid) into an adapter-shaped + # object. It is funnelled through the same ``populate_request_metadata`` + # helper as HTTP callbacks so the context (cookies, headers, args, + # path, remote, origin) is populated identically. + request_snapshot = SimpleNamespace( + cookies=ws.cookies, + headers=ws.headers, + args=ws.args, + full_path=ws.full_path, + remote_addr=ws.remote_addr, + origin=ws.headers.get("origin"), + ) + + @contextmanager + def activate_request(): + yield request_snapshot + # Track this connection for graceful shutdown try: ws_obj = ws._get_current_object() @@ -623,6 +644,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches payload, ws_cb, QuartResponseAdapter(), + activate_request, ) # Set up done callback to send response diff --git a/dash/backends/ws.py b/dash/backends/ws.py index a4b302f215..55d5c5ccbb 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -3,11 +3,13 @@ This module provides the WebSocket callback infrastructure for real-time bidirectional communication between Dash backends and the renderer. """ + from __future__ import annotations import asyncio import concurrent.futures from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext import inspect import json import queue @@ -189,6 +191,7 @@ def create_ws_context( payload: dict, response_adapter: "ResponseAdapter", websocket_callback: DashWebsocketCallback, + request_adapter: Any = None, ): """Create callback context from WebSocket message. @@ -196,12 +199,20 @@ def create_ws_context( payload: The callback payload response_adapter: The response adapter instance for the backend websocket_callback: The websocket callback instance for the backend + request_adapter: Optional request adapter (or any object exposing + ``cookies``/``headers``/``args``/``full_path``/``remote_addr``/ + ``origin``) captured from the WebSocket handshake. When provided, + the request metadata is copied onto the context the same way as for + regular HTTP callbacks so that ``callback_context.cookies``/ + ``headers`` (and downstream helpers such as + ``dash_enterprise_auth.get_user_data``) work inside WebSocket + callbacks. Returns: AttributeDict with callback context """ # pylint: disable=import-outside-toplevel - from dash._utils import AttributeDict, inputs_to_dict + from dash._utils import AttributeDict, inputs_to_dict, populate_request_metadata g = AttributeDict({}) g.inputs_list = payload.get("inputs", []) @@ -217,6 +228,16 @@ def create_ws_context( g.updated_props = {} g.dash_websocket = websocket_callback + if request_adapter is not None: + populate_request_metadata(g, request_adapter) + else: + g.cookies = {} + g.headers = {} + g.args = {} + g.path = "" + g.remote = "" + g.origin = "" + return g @@ -396,6 +417,7 @@ def run_callback_in_executor( payload: dict, ws_callback: DashWebsocketCallback, response_adapter: "ResponseAdapter", + activate_request: "Callable[[], Any] | None" = None, ) -> concurrent.futures.Future: """Submit callback to executor for thread pool execution. @@ -408,6 +430,13 @@ def run_callback_in_executor( payload: The callback payload from WebSocket message ws_callback: WebSocket callback instance for set_prop/get_prop response_adapter: Response adapter for the backend + activate_request: Optional zero-argument callable returning a context + manager that activates the WebSocket handshake request *inside the + worker thread* and yields a request adapter (or ``None``). This + lets each backend reuse its own request-context machinery (e.g. + ``set_current_request`` for FastAPI) so the callback context is + populated the same way as for HTTP callbacks. ContextVars do not + propagate into executor threads, so activation must happen here. Returns: Future representing the pending callback execution @@ -415,29 +444,33 @@ def run_callback_in_executor( def execute() -> dict: try: - cb_ctx = create_ws_context(payload, response_adapter, ws_callback) - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals( # pylint: disable=protected-access - cb_ctx.inputs_list + cb_ctx.states_list - ) + request_cm = activate_request() if activate_request else nullcontext() + with request_cm as request_adapter: + cb_ctx = create_ws_context( + payload, response_adapter, ws_callback, request_adapter + ) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) - ctx = copy_context() - partial_func = ( - dash_app._execute_callback( # pylint: disable=protected-access - func, args, cb_ctx.outputs_list, cb_ctx + ctx = copy_context() + partial_func = ( + dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) ) - ) - # Run in new event loop (handles both sync and async callbacks) - def run_callback(): - result = partial_func() - if inspect.iscoroutine(result): - return asyncio.run(result) - return result + # Run in new event loop (handles both sync and async callbacks) + def run_callback(): + result = partial_func() + if inspect.iscoroutine(result): + return asyncio.run(result) + return result - response_data = ctx.run(run_callback) - return {"status": "ok", "data": json.loads(response_data)} + response_data = ctx.run(run_callback) + return {"status": "ok", "data": json.loads(response_data)} except PreventUpdate: return {"status": "prevent_update"} diff --git a/dash/dash.py b/dash/dash.py index f547b95b56..4f1c3a892f 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -46,6 +46,7 @@ inputs_to_vals, interpolate_str, patch_collections_abc, + populate_request_metadata, split_callback_id, to_json, convert_to_AttributeDict, @@ -1494,12 +1495,7 @@ def _initialize_context(self, body): for x in body.get("changedPropIds", []) ] g.dash_response = self.backend.response_adapter() - g.cookies = dict(adapter.cookies) - g.headers = dict(adapter.headers) - g.args = adapter.args - g.path = adapter.full_path - g.remote = adapter.remote_addr - g.origin = adapter.origin + populate_request_metadata(g, adapter) g.updated_props = {} return g diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py index 1d74706a68..136d96db46 100644 --- a/tests/websocket/test_ws_basic.py +++ b/tests/websocket/test_ws_basic.py @@ -252,3 +252,31 @@ def update_output(value): dash_duo.wait_for_text_to_equal("#output", "Slider value: 50") assert dash_duo.get_logs() == [] + + +def test_ws008_websocket_request_context_cookies(dash_duo): + """WebSocket callbacks should expose request cookies/headers on ctx (FastAPI).""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def show_context(value): + return f"cookie={ctx.cookies.get('wscookie', '')} headers={bool(ctx.headers)}" + + dash_duo.start_server(app) + + # Set a cookie, then reload so the WebSocket handshake carries it. + dash_duo.driver.add_cookie({"name": "wscookie", "value": "wsval"}) + dash_duo.driver.refresh() + + dash_duo.wait_for_text_to_equal("#ws-output", "cookie=wsval headers=True") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_context.py b/tests/websocket/test_ws_context.py new file mode 100644 index 0000000000..3ff72ee3ff --- /dev/null +++ b/tests/websocket/test_ws_context.py @@ -0,0 +1,129 @@ +"""Unit tests for WebSocket callback context creation. + +These tests verify that request metadata captured from the WebSocket +handshake (cookies, headers, etc.) is propagated onto the callback +context. This is required so authentication helpers that read +``callback_context.cookies``/``headers`` (such as +``dash_enterprise_auth.get_user_data``) work inside WebSocket callbacks. +""" + +from types import SimpleNamespace + +import pytest + +from dash.backends.ws import create_ws_context + + +def test_create_ws_context_propagates_request_context(): + """Request metadata from the adapter should be copied onto the context.""" + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + request_adapter = SimpleNamespace( + cookies={"kcIdToken": "token-value"}, + headers={"Plotly-User-Data": "{}"}, + args={"foo": "bar"}, + full_path="/_dash-ws-callback", + remote_addr="10.0.0.1", + origin="https://example.com", + ) + + g = create_ws_context( + payload, + response_adapter=None, + websocket_callback=None, + request_adapter=request_adapter, + ) + + assert g.cookies == {"kcIdToken": "token-value"} + assert g.headers == {"Plotly-User-Data": "{}"} + assert g.args == {"foo": "bar"} + assert g.path == "/_dash-ws-callback" + assert g.remote == "10.0.0.1" + assert g.origin == "https://example.com" + + +def test_create_ws_context_defaults_without_request_adapter(): + """Context should expose empty defaults when no request adapter is given.""" + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + + g = create_ws_context(payload, response_adapter=None, websocket_callback=None) + + assert g.cookies == {} + assert g.headers == {} + assert g.args == {} + assert g.path == "" + assert g.remote == "" + assert g.origin == "" + + +def test_run_executor_activates_request_across_thread_boundary(): + """Request activation must populate context inside the executor thread. + + WebSocket callbacks run in a ``ThreadPoolExecutor`` and ContextVars do not + propagate into those threads, so the refactor activates the handshake + request (via FastAPI's ``set_current_request``) *inside* the worker thread. + This guards that seam: a request activated in the worker thread is visible + to ``FastAPIRequestAdapter`` and gets copied onto the callback context. + """ + pytest.importorskip("fastapi") + from concurrent.futures import ThreadPoolExecutor + from contextlib import contextmanager + + from dash.backends._fastapi import ( + FastAPIRequestAdapter, + reset_current_request, + set_current_request, + ) + + # Minimal stand-in for a Starlette ``WebSocket`` handshake connection: it + # only needs the attributes the request adapter reads. + handshake = SimpleNamespace( + cookies={"kcIdToken": "token-value"}, + headers={"origin": "https://example.com"}, + query_params={"foo": "bar"}, + url="http://testserver/_dash-ws-callback", + client=SimpleNamespace(host="10.0.0.1"), + ) + + @contextmanager + def activate_request(): + token = set_current_request(handshake) + try: + yield FastAPIRequestAdapter() + finally: + reset_current_request(token) + + payload = { + "inputs": [], + "state": [], + "outputs": [], + "changedPropIds": [], + } + + def worker(): + with activate_request() as request_adapter: + return create_ws_context( + payload, + response_adapter=None, + websocket_callback=None, + request_adapter=request_adapter, + ) + + with ThreadPoolExecutor(max_workers=1) as executor: + g = executor.submit(worker).result() + + assert g.cookies == {"kcIdToken": "token-value"} + assert g.headers == {"origin": "https://example.com"} + assert g.args == {"foo": "bar"} + assert g.path == "http://testserver/_dash-ws-callback" + assert g.remote == "10.0.0.1" + assert g.origin == "https://example.com" diff --git a/tests/websocket/test_ws_quart.py b/tests/websocket/test_ws_quart.py index 3d40493ba5..f423f6a637 100644 --- a/tests/websocket/test_ws_quart.py +++ b/tests/websocket/test_ws_quart.py @@ -226,3 +226,37 @@ def multi_output(n_clicks): dash_duo.wait_for_text_to_equal("#output3", "Third: 3") assert dash_duo.get_logs() == [] + + +def test_wsq007_websocket_request_context_cookies_quart(dash_duo): + """WebSocket callbacks should expose request cookies/headers on ctx (Quart). + + End-to-end regression test for https://github.com/plotly/dash/issues/3814: + request metadata from the WebSocket handshake (cookies, headers) must be + available on ``callback_context`` for ``websocket=True`` callbacks, just as + it is for HTTP callbacks. + """ + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def show_context(value): + return f"cookie={ctx.cookies.get('wscookie', '')} headers={bool(ctx.headers)}" + + dash_duo.start_server(app) + + # Set a cookie, then reload so the WebSocket handshake carries it. + dash_duo.driver.add_cookie({"name": "wscookie", "value": "wsval"}) + dash_duo.driver.refresh() + + dash_duo.wait_for_text_to_equal("#ws-output", "cookie=wsval headers=True") + + assert dash_duo.get_logs() == []