Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ This project adheres to [Semantic Versioning](https://semver.org/).

### Fixed
- [#3805](https://github.com/plotly/dash/pull/3805) Fix FastAPI POST routes deadlock caused by middleware consuming request body. Fixes [#3801](https://github.com/plotly/dash/issues/3801).
- [#3815](https://github.com/plotly/dash/pull/3815) Fix missing request context (cookies/headers) in websocket callbacks.

## [4.2.0] - 2026-06-01 - *The Freedom Update*

Expand Down
17 changes: 17 additions & 0 deletions dash/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,23 @@ def inputs_to_dict(inputs_list):
return inputs


def populate_request_metadata(g, adapter):
"""Copy request metadata from a request adapter onto a context object.

Shared by the HTTP path (``Dash._initialize_context``) and the WebSocket
path (``dash.backends.ws.create_ws_context``) so that both transports expose
identical request context (cookies, headers, args, path, remote, origin) on
``callback_context``.
"""
g.cookies = dict(adapter.cookies)
g.headers = dict(adapter.headers)
g.args = adapter.args
g.path = adapter.full_path
g.remote = adapter.remote_addr
g.origin = adapter.origin
return g


def convert_to_AttributeDict(nested_list):
new_dict = []
for i in nested_list:
Expand Down
20 changes: 19 additions & 1 deletion dash/backends/_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import contextmanager
from contextvars import copy_context, ContextVar
import asyncio
import concurrent.futures
Expand Down Expand Up @@ -328,7 +329,8 @@ def setup_catchall(self, dash_app: Dash):
and passed through the middleware, which is necessary for features like authentication
and timing to work correctly on all routes. FastAPI will match this catch-all route
for any path that isn't matched by a more specific route, allowing the middleware to
process the request and then return the appropriate response (e.g., 404 if no Dash route matches)."""
process the request and then return the appropriate response (e.g., 404 if no Dash route matches).
"""

def _setup_catchall(self):
try:
Expand Down Expand Up @@ -725,6 +727,21 @@ async def websocket_handler(websocket: WebSocket):

await websocket.accept()

# Activate the WebSocket handshake request using the same
# request-context machinery as HTTP callbacks (set_current_request)
# so that callbacks running over the WebSocket transport can access
# cookies/headers (e.g. for authentication helpers such as
# dash_enterprise_auth.get_user_data). ContextVars do not propagate
# into the executor threads, so activation happens inside the worker
# thread via this callable.
@contextmanager
def activate_request():
token = set_current_request(websocket)
try:
yield FastAPIRequestAdapter()
finally:
reset_current_request(token)

# Create janus queue for outbound messages (main loop context)
outbound_queue: janus.Queue[str] = janus.Queue()
# Track pending get_props requests with standard queue.Queue for responses
Expand Down Expand Up @@ -788,6 +805,7 @@ async def websocket_handler(websocket: WebSocket):
payload,
ws_cb,
FastAPIResponseAdapter(),
activate_request,
)

# Set up done callback to send response
Expand Down
22 changes: 22 additions & 0 deletions dash/backends/_quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import threading
from urllib.parse import urlparse

from contextlib import contextmanager
from logging.config import dictConfig
from contextvars import copy_context
from types import SimpleNamespace
from typing import Any, Dict, TYPE_CHECKING

from importlib_metadata import version as _get_distribution_version
Expand Down Expand Up @@ -545,6 +547,25 @@ async def websocket_handler(): # pylint: disable=too-many-branches

await ws.accept()

# Quart's request/websocket context cannot cross into the executor
# threads where callbacks run, so snapshot the handshake metadata
# here (where the ``websocket`` proxy is valid) into an adapter-shaped
# object. It is funnelled through the same ``populate_request_metadata``
# helper as HTTP callbacks so the context (cookies, headers, args,
# path, remote, origin) is populated identically.
request_snapshot = SimpleNamespace(
cookies=ws.cookies,
headers=ws.headers,
args=ws.args,
full_path=ws.full_path,
remote_addr=ws.remote_addr,
origin=ws.headers.get("origin"),
)

@contextmanager
def activate_request():
yield request_snapshot

# Track this connection for graceful shutdown
try:
ws_obj = ws._get_current_object()
Expand Down Expand Up @@ -623,6 +644,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches
payload,
ws_cb,
QuartResponseAdapter(),
activate_request,
)

# Set up done callback to send response
Expand Down
73 changes: 53 additions & 20 deletions dash/backends/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
This module provides the WebSocket callback infrastructure for real-time
bidirectional communication between Dash backends and the renderer.
"""

from __future__ import annotations

import asyncio
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
import inspect
import json
import queue
Expand Down Expand Up @@ -189,19 +191,28 @@ def create_ws_context(
payload: dict,
response_adapter: "ResponseAdapter",
websocket_callback: DashWebsocketCallback,
request_adapter: Any = None,
):
"""Create callback context from WebSocket message.

Args:
payload: The callback payload
response_adapter: The response adapter instance for the backend
websocket_callback: The websocket callback instance for the backend
request_adapter: Optional request adapter (or any object exposing
``cookies``/``headers``/``args``/``full_path``/``remote_addr``/
``origin``) captured from the WebSocket handshake. When provided,
the request metadata is copied onto the context the same way as for
regular HTTP callbacks so that ``callback_context.cookies``/
``headers`` (and downstream helpers such as
``dash_enterprise_auth.get_user_data``) work inside WebSocket
callbacks.

Returns:
AttributeDict with callback context
"""
# pylint: disable=import-outside-toplevel
from dash._utils import AttributeDict, inputs_to_dict
from dash._utils import AttributeDict, inputs_to_dict, populate_request_metadata

g = AttributeDict({})
g.inputs_list = payload.get("inputs", [])
Expand All @@ -217,6 +228,16 @@ def create_ws_context(
g.updated_props = {}
g.dash_websocket = websocket_callback

if request_adapter is not None:
populate_request_metadata(g, request_adapter)
else:
g.cookies = {}
g.headers = {}
g.args = {}
g.path = ""
g.remote = ""
g.origin = ""

return g


Expand Down Expand Up @@ -396,6 +417,7 @@ def run_callback_in_executor(
payload: dict,
ws_callback: DashWebsocketCallback,
response_adapter: "ResponseAdapter",
activate_request: "Callable[[], Any] | None" = None,
) -> concurrent.futures.Future:
"""Submit callback to executor for thread pool execution.

Expand All @@ -408,36 +430,47 @@ def run_callback_in_executor(
payload: The callback payload from WebSocket message
ws_callback: WebSocket callback instance for set_prop/get_prop
response_adapter: Response adapter for the backend
activate_request: Optional zero-argument callable returning a context
manager that activates the WebSocket handshake request *inside the
worker thread* and yields a request adapter (or ``None``). This
lets each backend reuse its own request-context machinery (e.g.
``set_current_request`` for FastAPI) so the callback context is
populated the same way as for HTTP callbacks. ContextVars do not
propagate into executor threads, so activation must happen here.

Returns:
Future representing the pending callback execution
"""

def execute() -> dict:
try:
cb_ctx = create_ws_context(payload, response_adapter, ws_callback)
# pylint: disable=protected-access
func = dash_app._prepare_callback(cb_ctx, payload)
args = dash_app._inputs_to_vals( # pylint: disable=protected-access
cb_ctx.inputs_list + cb_ctx.states_list
)
request_cm = activate_request() if activate_request else nullcontext()
with request_cm as request_adapter:
cb_ctx = create_ws_context(
payload, response_adapter, ws_callback, request_adapter
)
# pylint: disable=protected-access
func = dash_app._prepare_callback(cb_ctx, payload)
args = dash_app._inputs_to_vals( # pylint: disable=protected-access
cb_ctx.inputs_list + cb_ctx.states_list
)

ctx = copy_context()
partial_func = (
dash_app._execute_callback( # pylint: disable=protected-access
func, args, cb_ctx.outputs_list, cb_ctx
ctx = copy_context()
partial_func = (
dash_app._execute_callback( # pylint: disable=protected-access
func, args, cb_ctx.outputs_list, cb_ctx
)
)
)

# Run in new event loop (handles both sync and async callbacks)
def run_callback():
result = partial_func()
if inspect.iscoroutine(result):
return asyncio.run(result)
return result
# Run in new event loop (handles both sync and async callbacks)
def run_callback():
result = partial_func()
if inspect.iscoroutine(result):
return asyncio.run(result)
return result

response_data = ctx.run(run_callback)
return {"status": "ok", "data": json.loads(response_data)}
response_data = ctx.run(run_callback)
return {"status": "ok", "data": json.loads(response_data)}

except PreventUpdate:
return {"status": "prevent_update"}
Expand Down
8 changes: 2 additions & 6 deletions dash/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
inputs_to_vals,
interpolate_str,
patch_collections_abc,
populate_request_metadata,
split_callback_id,
to_json,
convert_to_AttributeDict,
Expand Down Expand Up @@ -1494,12 +1495,7 @@ def _initialize_context(self, body):
for x in body.get("changedPropIds", [])
]
g.dash_response = self.backend.response_adapter()
g.cookies = dict(adapter.cookies)
g.headers = dict(adapter.headers)
g.args = adapter.args
g.path = adapter.full_path
g.remote = adapter.remote_addr
g.origin = adapter.origin
populate_request_metadata(g, adapter)
g.updated_props = {}
return g

Expand Down
28 changes: 28 additions & 0 deletions tests/websocket/test_ws_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,31 @@ def update_output(value):
dash_duo.wait_for_text_to_equal("#output", "Slider value: 50")

assert dash_duo.get_logs() == []


def test_ws008_websocket_request_context_cookies(dash_duo):
"""WebSocket callbacks should expose request cookies/headers on ctx (FastAPI)."""
app = Dash(__name__, backend="fastapi")

app.layout = html.Div(
[
dcc.Input(id="ws-input", type="text"),
html.Div(id="ws-output"),
]
)

@app.callback(
Output("ws-output", "children"), Input("ws-input", "value"), websocket=True
)
def show_context(value):
return f"cookie={ctx.cookies.get('wscookie', '')} headers={bool(ctx.headers)}"

dash_duo.start_server(app)

# Set a cookie, then reload so the WebSocket handshake carries it.
dash_duo.driver.add_cookie({"name": "wscookie", "value": "wsval"})
dash_duo.driver.refresh()

dash_duo.wait_for_text_to_equal("#ws-output", "cookie=wsval headers=True")

assert dash_duo.get_logs() == []
Loading