From 609585af921dbcd81981d2e7602748bdb9e201e5 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 3 Feb 2026 15:54:48 +0100 Subject: [PATCH] feat: add `ClientRequestContext` type alias for client-side handlers Introduce a dedicated `ClientRequestContext` type alias in `mcp.client.context` to provide a cleaner API for client-side callback handlers (sampling, elicitation, list_roots). This improves the developer experience by: - Providing a concrete type instead of requiring `RequestContext[ClientSession]` - Making the internal `RequestContext` private by moving it to `_context.py` - Exporting `ClientRequestContext` from `mcp.client` for easy access All examples and conformance tests have been updated to use `ClientRequestContext` instead of the internal `RequestContext` type. The migration guide has been updated to reflect these changes. --- .github/actions/conformance/client.py | 4 ++-- README.md | 4 ++-- README.v2.md | 4 ++-- docs/migration.md | 17 +++++++++-------- .../mcp_simple_task_interactive_client/main.py | 6 +++--- examples/snippets/clients/stdio_client.py | 4 ++-- .../snippets/clients/url_elicitation_client.py | 4 ++-- src/mcp/client/__init__.py | 3 ++- src/mcp/client/context.py | 16 ++++++++++++++++ src/mcp/client/experimental/task_handlers.py | 2 +- src/mcp/client/session.py | 2 +- src/mcp/server/context.py | 2 +- src/mcp/shared/{context.py => _context.py} | 0 src/mcp/shared/progress.py | 2 +- tests/client/test_list_roots_callback.py | 2 +- tests/client/test_sampling_callback.py | 2 +- tests/client/test_session.py | 2 +- .../tasks/client/test_capabilities.py | 2 +- .../experimental/tasks/client/test_handlers.py | 2 +- .../tasks/test_elicitation_scenarios.py | 2 +- tests/server/mcpserver/test_elicitation.py | 2 +- tests/server/mcpserver/test_integration.py | 2 +- tests/server/mcpserver/test_url_elicitation.py | 2 +- tests/shared/test_progress_notifications.py | 2 +- tests/shared/test_streamable_http.py | 2 +- 25 files changed, 55 insertions(+), 37 deletions(-) create mode 100644 src/mcp/client/context.py rename src/mcp/shared/{context.py => _context.py} (100%) diff --git a/.github/actions/conformance/client.py b/.github/actions/conformance/client.py index 87f323132..2e1e7788b 100644 --- a/.github/actions/conformance/client.py +++ b/.github/actions/conformance/client.py @@ -38,9 +38,9 @@ PrivateKeyJWTOAuthProvider, SignedJWTParameters, ) +from mcp.client.context import ClientRequestContext from mcp.client.streamable_http import streamable_http_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken -from mcp.shared.context import RequestContext # Set up logging to stderr (stdout is for conformance test output) logging.basicConfig( @@ -187,7 +187,7 @@ async def run_sse_retry(server_url: str) -> None: async def default_elicitation_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: """Accept elicitation and apply defaults from the schema (SEP-1034).""" diff --git a/README.md b/README.md index 0f0468a19..b255b9eae 100644 --- a/README.md +++ b/README.md @@ -2120,8 +2120,8 @@ import asyncio import os from mcp import ClientSession, StdioServerParameters, types +from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client -from mcp.shared.context import RequestContext # Create server parameters for stdio connection server_params = StdioServerParameters( @@ -2133,7 +2133,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams + context: ClientRequestContext, params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/README.v2.md b/README.v2.md index 4fc110448..4eaced423 100644 --- a/README.v2.md +++ b/README.v2.md @@ -2121,8 +2121,8 @@ import asyncio import os from mcp import ClientSession, StdioServerParameters, types +from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client -from mcp.shared.context import RequestContext # Create server parameters for stdio connection server_params = StdioServerParameters( @@ -2134,7 +2134,7 @@ server_params = StdioServerParameters( # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession], params: types.CreateMessageRequestParams + context: ClientRequestContext, params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/docs/migration.md b/docs/migration.md index 84320ffef..7d30f0ac9 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -387,6 +387,7 @@ The `RequestContext` class has been split to separate shared fields from server- **Before (v1):** ```python +from mcp.client.session import ClientSession from mcp.shared.context import RequestContext, LifespanContextT, RequestT from mcp.shared.progress import ProgressContext @@ -400,19 +401,19 @@ progress_ctx: ProgressContext[SendRequestT, SendNotificationT, SendResultT, Rece **After (v2):** ```python -from mcp.shared.context import RequestContext +from mcp.client.context import ClientRequestContext +from mcp.client.session import ClientSession +from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT from mcp.shared.progress import ProgressContext -# RequestContext with 1 type parameter -ctx: RequestContext[ClientSession] - -# ProgressContext with 1 type parameter -progress_ctx: ProgressContext[ClientSession] +# For client-side context (sampling, elicitation, list_roots callbacks) +ctx: ClientRequestContext # For server-specific context with lifespan and request types -from mcp.server.context import ServerRequestContext, LifespanContextT, RequestT - server_ctx: ServerRequestContext[LifespanContextT, RequestT] + +# ProgressContext with 1 type parameter +progress_ctx: ProgressContext[ClientSession] ``` ### Resource URI type changed from `AnyUrl` to `str` diff --git a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py index a929418fa..ff5f49928 100644 --- a/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py +++ b/examples/clients/simple-task-interactive-client/mcp_simple_task_interactive_client/main.py @@ -10,8 +10,8 @@ import click from mcp import ClientSession +from mcp.client.context import ClientRequestContext from mcp.client.streamable_http import streamable_http_client -from mcp.shared.context import RequestContext from mcp.types import ( CallToolResult, CreateMessageRequestParams, @@ -23,7 +23,7 @@ async def elicitation_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: ElicitRequestParams, ) -> ElicitResult: """Handle elicitation requests from the server.""" @@ -38,7 +38,7 @@ async def elicitation_callback( async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: CreateMessageRequestParams, ) -> CreateMessageResult: """Handle sampling requests from the server.""" diff --git a/examples/snippets/clients/stdio_client.py b/examples/snippets/clients/stdio_client.py index ab3959f09..c1f85f42a 100644 --- a/examples/snippets/clients/stdio_client.py +++ b/examples/snippets/clients/stdio_client.py @@ -6,8 +6,8 @@ import os from mcp import ClientSession, StdioServerParameters, types +from mcp.client.context import ClientRequestContext from mcp.client.stdio import stdio_client -from mcp.shared.context import RequestContext # Create server parameters for stdio connection server_params = StdioServerParameters( @@ -19,7 +19,7 @@ # Optional: create a sampling callback async def handle_sampling_message( - context: RequestContext[ClientSession], params: types.CreateMessageRequestParams + context: ClientRequestContext, params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: print(f"Sampling request: {params.messages}") return types.CreateMessageResult( diff --git a/examples/snippets/clients/url_elicitation_client.py b/examples/snippets/clients/url_elicitation_client.py index b534135e0..9888c588e 100644 --- a/examples/snippets/clients/url_elicitation_client.py +++ b/examples/snippets/clients/url_elicitation_client.py @@ -31,14 +31,14 @@ from urllib.parse import urlparse from mcp import ClientSession, types +from mcp.client.context import ClientRequestContext from mcp.client.sse import sse_client -from mcp.shared.context import RequestContext from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError from mcp.types import URL_ELICITATION_REQUIRED async def handle_elicitation( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: """Handle elicitation requests from the server. diff --git a/src/mcp/client/__init__.py b/src/mcp/client/__init__.py index a1eaf3d7c..59bce03b8 100644 --- a/src/mcp/client/__init__.py +++ b/src/mcp/client/__init__.py @@ -2,6 +2,7 @@ from mcp.client._transport import Transport from mcp.client.client import Client +from mcp.client.context import ClientRequestContext from mcp.client.session import ClientSession -__all__ = ["Client", "ClientSession", "Transport"] +__all__ = ["Client", "ClientRequestContext", "ClientSession", "Transport"] diff --git a/src/mcp/client/context.py b/src/mcp/client/context.py new file mode 100644 index 000000000..2f4404e00 --- /dev/null +++ b/src/mcp/client/context.py @@ -0,0 +1,16 @@ +"""Request context for MCP client handlers.""" + +from mcp.client.session import ClientSession +from mcp.shared._context import RequestContext + +ClientRequestContext = RequestContext[ClientSession] +"""Context for handling incoming requests in a client session. + +This context is passed to client-side callbacks (sampling, elicitation, list_roots) when the server sends requests +to the client. + +Attributes: + request_id: The unique identifier for this request. + meta: Optional metadata associated with the request. + session: The client session handling this request. +""" diff --git a/src/mcp/client/experimental/task_handlers.py b/src/mcp/client/experimental/task_handlers.py index 448322cfb..ea1938a73 100644 --- a/src/mcp/client/experimental/task_handlers.py +++ b/src/mcp/client/experimental/task_handlers.py @@ -19,7 +19,7 @@ from pydantic import TypeAdapter import mcp.types as types -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.session import RequestResponder if TYPE_CHECKING: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b10d02ce6..09d03bdb8 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -10,7 +10,7 @@ import mcp.types as types from mcp.client.experimental import ExperimentalClientFeatures from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 0951a0784..43b9d3800 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -7,7 +7,7 @@ from mcp.server.experimental.request_context import Experimental from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.message import CloseSSEStreamCallback LifespanContextT = TypeVar("LifespanContextT") diff --git a/src/mcp/shared/context.py b/src/mcp/shared/_context.py similarity index 100% rename from src/mcp/shared/context.py rename to src/mcp/shared/_context.py diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 7225ac8d0..510bd8163 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -5,7 +5,7 @@ from pydantic import BaseModel -from mcp.shared.context import RequestContext, SessionT +from mcp.shared._context import RequestContext, SessionT from mcp.types import ProgressToken diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index 40265d57f..6a2f49f39 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -5,7 +5,7 @@ from mcp.client.session import ClientSession from mcp.server.mcpserver import MCPServer from mcp.server.mcpserver.server import Context -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.types import ListRootsResult, Root, TextContent diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 28995e0fb..3357bc921 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -3,7 +3,7 @@ from mcp import Client from mcp.client.session import ClientSession from mcp.server.mcpserver import MCPServer -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.types import ( CreateMessageRequestParams, CreateMessageResult, diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 40bd65b97..5cc685aaf 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -5,7 +5,7 @@ import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index 04561a090..965ec6eea 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -7,7 +7,7 @@ from mcp import ClientCapabilities from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.types import ( LATEST_PROTOCOL_VERSION, diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py index 9061aedc2..05165df24 100644 --- a/tests/experimental/tasks/client/test_handlers.py +++ b/tests/experimental/tasks/client/test_handlers.py @@ -22,7 +22,7 @@ import mcp.types as types from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers from mcp.client.session import ClientSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder diff --git a/tests/experimental/tasks/test_elicitation_scenarios.py b/tests/experimental/tasks/test_elicitation_scenarios.py index f755658c4..57122da7b 100644 --- a/tests/experimental/tasks/test_elicitation_scenarios.py +++ b/tests/experimental/tasks/test_elicitation_scenarios.py @@ -20,7 +20,7 @@ from mcp.server import Server from mcp.server.experimental.task_context import ServerTaskContext from mcp.server.lowlevel import NotificationOptions -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.experimental.tasks.helpers import is_terminal from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore from mcp.shared.message import SessionMessage diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 37e87a1f4..6cf49fbd7 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -9,7 +9,7 @@ from mcp.client.session import ClientSession, ElicitationFnT from mcp.server.mcpserver import Context, MCPServer from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index 40453b89d..c4ea2dad6 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -33,7 +33,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index 667a4279a..1311bd672 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -9,7 +9,7 @@ from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.mcpserver import Context, MCPServer from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ca632148b..a7ed7acb1 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -11,7 +11,7 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.context import RequestContext +from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.progress import progress from mcp.shared.session import RequestResponder diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index c1d0e3062..266162f62 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -43,12 +43,12 @@ ) from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared._context import RequestContext from mcp.shared._httpx_utils import ( MCP_DEFAULT_SSE_READ_TIMEOUT, MCP_DEFAULT_TIMEOUT, create_mcp_http_client, ) -from mcp.shared.context import RequestContext from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.types import InitializeResult, JSONRPCRequest, TextContent, TextResourceContents, Tool