diff --git a/python/copilot/client.py b/python/copilot/client.py index c25e6809..ec85e5e6 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -22,7 +22,7 @@ from collections.abc import Callable from dataclasses import asdict, is_dataclass from pathlib import Path -from typing import Any, cast +from typing import Any, Callable, Optional, cast, overload from .generated.rpc import ServerRpc from .generated.session_events import session_event_from_dict @@ -51,6 +51,8 @@ ToolResult, ) +HandlerUnsubcribe = Callable[[], None] + def _get_bundled_cli_path() -> str | None: """Get the path to the bundled CLI binary, if available.""" @@ -1007,11 +1009,20 @@ async def set_foreground_session_id(self, session_id: str) -> None: error = response.get("error", "Unknown error") raise RuntimeError(f"Failed to set foreground session: {error}") + @overload + def on(self, handler: SessionLifecycleHandler, /) -> HandlerUnsubcribe: ... + + @overload + def on( + self, event_type: SessionLifecycleEventType, /, handler: SessionLifecycleHandler + ) -> HandlerUnsubcribe: ... + def on( self, event_type_or_handler: SessionLifecycleEventType | SessionLifecycleHandler, - handler: SessionLifecycleHandler | None = None, - ) -> Callable[[], None]: + /, + handler: Optional[SessionLifecycleHandler] = None, + ) -> HandlerUnsubcribe: """ Subscribe to session lifecycle events. @@ -1568,9 +1579,10 @@ async def _execute_tool_call( } try: - result = handler(invocation) - if inspect.isawaitable(result): - result = await result + raw_result = handler(invocation) + if inspect.isawaitable(raw_result): + raw_result = await raw_result + result: ToolResult = cast(ToolResult, raw_result) except Exception as exc: # pylint: disable=broad-except # Don't expose detailed error information to the LLM for security reasons. # The actual error is stored in the 'error' field for debugging. diff --git a/python/copilot/session.py b/python/copilot/session.py index a02dcf1e..94ab2fca 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -8,8 +8,9 @@ import asyncio import inspect import threading + from collections.abc import Callable -from typing import Any, cast +from typing import Any, Callable, Optional, cast from .generated.rpc import SessionRpc from .generated.session_events import SessionEvent, SessionEventType, session_event_from_dict