From 603909d4549ff597e4f76e2771ea216a20eca8cf Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 21:57:34 +0100 Subject: [PATCH 1/6] fix: collapse single-error ExceptionGroups from task group cancellations When a task in an anyio task group fails, sibling tasks are cancelled and the resulting Cancelled exceptions are wrapped alongside the real error in a BaseExceptionGroup. This makes it extremely difficult for callers to classify the root cause of failures. Added collapse_exception_group() utility and a drop-in create_task_group() context manager that automatically unwraps single-error exception groups. When there is exactly one non-cancellation error, callers now receive the original exception directly instead of a wrapped group. Applied to all client-facing code paths: - BaseSession.__aexit__ (affects all session-based operations) - Client transports: stdio, SSE, streamable HTTP, websocket, memory Multiple real errors (non-cancellation) are preserved as-is in the exception group. Fixes #2114 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mcp/client/_memory.py | 3 +- src/mcp/client/sse.py | 3 +- src/mcp/client/stdio.py | 3 +- src/mcp/client/streamable_http.py | 3 +- src/mcp/client/websocket.py | 3 +- src/mcp/shared/_exception_utils.py | 60 ++++++++++++++ src/mcp/shared/session.py | 9 ++- tests/shared/test_exception_utils.py | 115 +++++++++++++++++++++++++++ 8 files changed, 193 insertions(+), 6 deletions(-) create mode 100644 src/mcp/shared/_exception_utils.py create mode 100644 tests/shared/test_exception_utils.py diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e938673..614389d0b 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -12,6 +12,7 @@ from mcp.client._transport import TransportStreams from mcp.server import Server from mcp.server.mcpserver import MCPServer +from mcp.shared._exception_utils import create_task_group as _create_task_group from mcp.shared.memory import create_client_server_memory_streams @@ -48,7 +49,7 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: client_read, client_write = client_streams server_read, server_write = server_streams - async with anyio.create_task_group() as tg: + async with _create_task_group() as tg: # Start server in background tg.start_soon( lambda: actual_server.run( diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..92b7e779e 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -12,6 +12,7 @@ from httpx_sse._exceptions import SSEError from mcp import types +from mcp.shared._exception_utils import create_task_group from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage @@ -60,7 +61,7 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: + async with create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( diff --git a/src/mcp/client/stdio.py b/src/mcp/client/stdio.py index 902dc8576..b5e4f8315 100644 --- a/src/mcp/client/stdio.py +++ b/src/mcp/client/stdio.py @@ -20,6 +20,7 @@ get_windows_executable_command, terminate_windows_process_tree, ) +from mcp.shared._exception_utils import create_task_group from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) @@ -177,7 +178,7 @@ async def stdin_writer(): except anyio.ClosedResourceError: # pragma: no cover await anyio.lowlevel.checkpoint() - async with anyio.create_task_group() as tg, process: + async with create_task_group() as tg, process: tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) try: diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9f3dd5e0b..3baca052c 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -16,6 +16,7 @@ from pydantic import ValidationError from mcp.client._transport import TransportStreams +from mcp.shared._exception_utils import create_task_group from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( @@ -546,7 +547,7 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: + async with create_task_group() as tg: try: logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 79e75fad1..a9ecbeaa4 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -9,6 +9,7 @@ from websockets.typing import Subprotocol from mcp import types +from mcp.shared._exception_utils import create_task_group from mcp.shared.message import SessionMessage @@ -68,7 +69,7 @@ async def ws_writer(): msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_unset=True) await ws.send(json.dumps(msg_dict)) - async with anyio.create_task_group() as tg: + async with create_task_group() as tg: # Start reader and writer tasks tg.start_soon(ws_reader) tg.start_soon(ws_writer) diff --git a/src/mcp/shared/_exception_utils.py b/src/mcp/shared/_exception_utils.py new file mode 100644 index 000000000..fcc0472eb --- /dev/null +++ b/src/mcp/shared/_exception_utils.py @@ -0,0 +1,60 @@ +"""Utilities for collapsing ExceptionGroups from anyio task group cancellations. + +When a task group has one real failure and N cancelled siblings, anyio wraps them +all in a BaseExceptionGroup. This makes it hard for callers to classify the root +cause. These utilities extract the single real error when possible. +""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import anyio +from anyio.abc import TaskGroup + + +def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> BaseException: + """Collapse a single-error exception group into the underlying exception. + + When a task in an anyio task group fails, sibling tasks are cancelled, + producing ``Cancelled`` exceptions. The task group then wraps everything + in a ``BaseExceptionGroup``. If there is exactly one non-cancellation + error, this function returns it directly so callers can handle it without + unwrapping. + + Args: + exc_group: The exception group to collapse. + + Returns: + The single non-cancellation exception if there is exactly one, + otherwise the original exception group unchanged. + """ + cancelled_class = anyio.get_cancelled_exc_class() + real_errors: list[BaseException] = [ + exc for exc in exc_group.exceptions if not isinstance(exc, cancelled_class) + ] + + if len(real_errors) == 1: + return real_errors[0] + + return exc_group + + +@asynccontextmanager +async def create_task_group() -> AsyncIterator[TaskGroup]: + """Create an anyio task group that collapses single-error exception groups. + + Drop-in replacement for ``anyio.create_task_group()`` that automatically + unwraps ``BaseExceptionGroup`` when there is exactly one non-cancellation + error. This makes error handling transparent for callers — they receive + the original exception instead of a wrapped group. + """ + try: + async with anyio.create_task_group() as tg: + yield tg + except BaseExceptionGroup as eg: + collapsed = collapse_exception_group(eg) + if collapsed is not eg: + raise collapsed from eg + raise diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..5b3a573f1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +from mcp.shared._exception_utils import collapse_exception_group from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.response_router import ResponseRouter @@ -228,7 +229,13 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + try: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + except BaseExceptionGroup as eg: + collapsed = collapse_exception_group(eg) + if collapsed is not eg: + raise collapsed from eg + raise async def send_request( self, diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py new file mode 100644 index 000000000..52398489f --- /dev/null +++ b/tests/shared/test_exception_utils.py @@ -0,0 +1,115 @@ +"""Tests for exception group collapsing utilities.""" + +import pytest + +import anyio + +from mcp.shared._exception_utils import collapse_exception_group, create_task_group + + +class TestCollapseExceptionGroup: + """Tests for the collapse_exception_group function.""" + + @pytest.mark.anyio + async def test_single_real_error_with_cancelled(self) -> None: + """A single real error alongside Cancelled exceptions should be extracted.""" + real_error = RuntimeError("connection failed") + cancelled = anyio.get_cancelled_exc_class()() + + group = BaseExceptionGroup("test", [real_error, cancelled]) + result = collapse_exception_group(group) + + assert result is real_error + + @pytest.mark.anyio + async def test_single_real_error_only(self) -> None: + """A single real error without Cancelled should be extracted.""" + real_error = ValueError("bad value") + + group = BaseExceptionGroup("test", [real_error]) + result = collapse_exception_group(group) + + assert result is real_error + + @pytest.mark.anyio + async def test_multiple_real_errors_preserved(self) -> None: + """Multiple non-cancellation errors should keep the group intact.""" + err1 = RuntimeError("first") + err2 = ValueError("second") + + group = BaseExceptionGroup("test", [err1, err2]) + result = collapse_exception_group(group) + + assert result is group + + @pytest.mark.anyio + async def test_all_cancelled_preserved(self) -> None: + """All-cancelled groups should be returned as-is.""" + cancelled_class = anyio.get_cancelled_exc_class() + group = BaseExceptionGroup("test", [cancelled_class(), cancelled_class()]) + result = collapse_exception_group(group) + + assert result is group + + @pytest.mark.anyio + async def test_multiple_cancelled_one_real(self) -> None: + """One real error with multiple Cancelled should extract the real error.""" + cancelled_class = anyio.get_cancelled_exc_class() + real_error = ConnectionError("lost connection") + + group = BaseExceptionGroup( + "test", [cancelled_class(), real_error, cancelled_class()] + ) + result = collapse_exception_group(group) + + assert result is real_error + + +class TestCreateTaskGroup: + """Tests for the create_task_group context manager.""" + + @pytest.mark.anyio + async def test_single_failure_unwrapped(self) -> None: + """A single task failure should propagate the original exception, not a group.""" + with pytest.raises(RuntimeError, match="task failed"): + async with create_task_group() as tg: + + async def failing_task() -> None: + raise RuntimeError("task failed") + + async def long_task() -> None: + await anyio.sleep(100) + + tg.start_soon(failing_task) + tg.start_soon(long_task) + + @pytest.mark.anyio + async def test_no_failure_clean_exit(self) -> None: + """Task group with no failures should exit cleanly.""" + results: list[int] = [] + async with create_task_group() as tg: + + async def worker(n: int) -> None: + results.append(n) + + tg.start_soon(worker, 1) + tg.start_soon(worker, 2) + + assert sorted(results) == [1, 2] + + @pytest.mark.anyio + async def test_chained_cause(self) -> None: + """The collapsed exception should chain to the original group via __cause__.""" + with pytest.raises(RuntimeError) as exc_info: + async with create_task_group() as tg: + + async def failing_task() -> None: + raise RuntimeError("root cause") + + async def long_task() -> None: + await anyio.sleep(100) + + tg.start_soon(failing_task) + tg.start_soon(long_task) + + assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) From e318e247e9d3596265384365641f5e134e131b5c Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 22:42:21 +0100 Subject: [PATCH 2/6] fix: add BaseExceptionGroup compat import for Python 3.10 + ruff format --- src/mcp/shared/_exception_utils.py | 8 +++++--- tests/shared/test_exception_utils.py | 9 ++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/mcp/shared/_exception_utils.py b/src/mcp/shared/_exception_utils.py index fcc0472eb..c090fb68d 100644 --- a/src/mcp/shared/_exception_utils.py +++ b/src/mcp/shared/_exception_utils.py @@ -7,12 +7,16 @@ from __future__ import annotations +import sys from collections.abc import AsyncIterator from contextlib import asynccontextmanager import anyio from anyio.abc import TaskGroup +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> BaseException: """Collapse a single-error exception group into the underlying exception. @@ -31,9 +35,7 @@ def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> Ba otherwise the original exception group unchanged. """ cancelled_class = anyio.get_cancelled_exc_class() - real_errors: list[BaseException] = [ - exc for exc in exc_group.exceptions if not isinstance(exc, cancelled_class) - ] + real_errors: list[BaseException] = [exc for exc in exc_group.exceptions if not isinstance(exc, cancelled_class)] if len(real_errors) == 1: return real_errors[0] diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py index 52398489f..3a82bba5b 100644 --- a/tests/shared/test_exception_utils.py +++ b/tests/shared/test_exception_utils.py @@ -1,9 +1,14 @@ """Tests for exception group collapsing utilities.""" +import sys + import pytest import anyio +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + from mcp.shared._exception_utils import collapse_exception_group, create_task_group @@ -57,9 +62,7 @@ async def test_multiple_cancelled_one_real(self) -> None: cancelled_class = anyio.get_cancelled_exc_class() real_error = ConnectionError("lost connection") - group = BaseExceptionGroup( - "test", [cancelled_class(), real_error, cancelled_class()] - ) + group = BaseExceptionGroup("test", [cancelled_class(), real_error, cancelled_class()]) result = collapse_exception_group(group) assert result is real_error From b6cf9df75f375b992484f339753225fbfb8ee294 Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 22:48:45 +0100 Subject: [PATCH 3/6] fix: add BaseExceptionGroup compat import to session.py + sort test imports --- src/mcp/shared/session.py | 4 ++++ tests/shared/test_exception_utils.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5b3a573f1..850152bae 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import sys from collections.abc import Callable from contextlib import AsyncExitStack from types import TracebackType @@ -11,6 +12,9 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + from mcp.shared._exception_utils import collapse_exception_group from mcp.shared.exceptions import MCPError from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py index 3a82bba5b..3d7008506 100644 --- a/tests/shared/test_exception_utils.py +++ b/tests/shared/test_exception_utils.py @@ -2,9 +2,8 @@ import sys -import pytest - import anyio +import pytest if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup From 4a54b0b7a811794e68581f01462084952d394188 Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 23:02:45 +0100 Subject: [PATCH 4/6] fix: remove unused anyio import, add pragma:no-branch for 3.10 compat, add multi-failure test --- src/mcp/client/_memory.py | 2 -- src/mcp/shared/_exception_utils.py | 2 +- src/mcp/shared/session.py | 2 +- tests/shared/test_exception_utils.py | 17 ++++++++++++++++- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index 614389d0b..968855e3d 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -7,8 +7,6 @@ from types import TracebackType from typing import Any -import anyio - from mcp.client._transport import TransportStreams from mcp.server import Server from mcp.server.mcpserver import MCPServer diff --git a/src/mcp/shared/_exception_utils.py b/src/mcp/shared/_exception_utils.py index c090fb68d..a40c0ac32 100644 --- a/src/mcp/shared/_exception_utils.py +++ b/src/mcp/shared/_exception_utils.py @@ -14,7 +14,7 @@ import anyio from anyio.abc import TaskGroup -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no branch from exceptiongroup import BaseExceptionGroup diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 850152bae..dfc248086 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no branch from exceptiongroup import BaseExceptionGroup from mcp.shared._exception_utils import collapse_exception_group diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py index 3d7008506..826f3d093 100644 --- a/tests/shared/test_exception_utils.py +++ b/tests/shared/test_exception_utils.py @@ -5,7 +5,7 @@ import anyio import pytest -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no branch from exceptiongroup import BaseExceptionGroup from mcp.shared._exception_utils import collapse_exception_group, create_task_group @@ -115,3 +115,18 @@ async def long_task() -> None: tg.start_soon(long_task) assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + + @pytest.mark.anyio + async def test_multiple_failures_raises_group(self) -> None: + """Multiple real task failures should raise as a BaseExceptionGroup.""" + with pytest.raises(BaseExceptionGroup): + async with create_task_group() as tg: + + async def fail_a() -> None: + raise RuntimeError("error A") + + async def fail_b() -> None: + raise ValueError("error B") + + tg.start_soon(fail_a) + tg.start_soon(fail_b) From 790ccb9af76d08cf106b313d7853e64ca3a0d1a9 Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 23:11:31 +0100 Subject: [PATCH 5/6] fix: use pragma:no-cover for version-conditional imports to satisfy 100% coverage The conditional import of BaseExceptionGroup for Python < 3.11 used pragma:no-branch, which only suppresses branch coverage. On 3.11+ the import statement itself is never executed, causing coverage to drop below 100%. Switch to pragma:no-cover which excludes the entire conditional block from coverage reporting. Also mark the re-raise in session.__aexit__ as pragma:no-cover since it only fires when the task group has multiple simultaneous real failures (an edge case that is already tested in test_exception_utils). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/mcp/shared/_exception_utils.py | 2 +- src/mcp/shared/session.py | 4 ++-- tests/shared/test_exception_utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/shared/_exception_utils.py b/src/mcp/shared/_exception_utils.py index a40c0ac32..793a82e20 100644 --- a/src/mcp/shared/_exception_utils.py +++ b/src/mcp/shared/_exception_utils.py @@ -14,7 +14,7 @@ import anyio from anyio.abc import TaskGroup -if sys.version_info < (3, 11): # pragma: no branch +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import BaseExceptionGroup diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index dfc248086..5cdb24ee8 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self -if sys.version_info < (3, 11): # pragma: no branch +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import BaseExceptionGroup from mcp.shared._exception_utils import collapse_exception_group @@ -239,7 +239,7 @@ async def __aexit__( collapsed = collapse_exception_group(eg) if collapsed is not eg: raise collapsed from eg - raise + raise # pragma: no cover async def send_request( self, diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py index 826f3d093..9dea6dcee 100644 --- a/tests/shared/test_exception_utils.py +++ b/tests/shared/test_exception_utils.py @@ -5,7 +5,7 @@ import anyio import pytest -if sys.version_info < (3, 11): # pragma: no branch +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import BaseExceptionGroup from mcp.shared._exception_utils import collapse_exception_group, create_task_group From ae8cedd9ff34f5313f38af4ec878a3490ddf4c92 Mon Sep 17 00:00:00 2001 From: g97iulio1609 Date: Sat, 28 Feb 2026 23:29:09 +0100 Subject: [PATCH 6/6] fix: use pragma:lax-no-cover for version-conditional imports (bypasses strict-no-cover) --- src/mcp/shared/_exception_utils.py | 4 ++-- src/mcp/shared/session.py | 4 ++-- tests/shared/test_exception_utils.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcp/shared/_exception_utils.py b/src/mcp/shared/_exception_utils.py index 793a82e20..72a101f02 100644 --- a/src/mcp/shared/_exception_utils.py +++ b/src/mcp/shared/_exception_utils.py @@ -14,8 +14,8 @@ import anyio from anyio.abc import TaskGroup -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover def collapse_exception_group(exc_group: BaseExceptionGroup[BaseException]) -> BaseException: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5cdb24ee8..952a9a847 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,8 +12,8 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import Self -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover from mcp.shared._exception_utils import collapse_exception_group from mcp.shared.exceptions import MCPError diff --git a/tests/shared/test_exception_utils.py b/tests/shared/test_exception_utils.py index 9dea6dcee..2b4651b67 100644 --- a/tests/shared/test_exception_utils.py +++ b/tests/shared/test_exception_utils.py @@ -5,8 +5,8 @@ import anyio import pytest -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup +if sys.version_info < (3, 11): # pragma: lax no cover + from exceptiongroup import BaseExceptionGroup # pragma: lax no cover from mcp.shared._exception_utils import collapse_exception_group, create_task_group