diff --git a/dotnet/README.md b/dotnet/README.md index 567b1c33b..dff35beaa 100644 --- a/dotnet/README.md +++ b/dotnet/README.md @@ -228,7 +228,11 @@ subscription.Dispose(); ##### `AbortAsync(): Task` -Abort the currently processing message in this session. +Abort the currently processing message in this session. This also cancels the `CancellationToken` passed to any in-flight tool handlers (see [Cancelling Tool Handlers](#cancelling-tool-handlers)). + +##### `CancelToolCall(toolCallId: string): bool` + +Cooperatively cancel a single in-flight tool handler by cancelling its `CancellationToken`, without aborting the broader agentic loop. Returns `true` if a matching in-flight tool call was found and signalled, `false` otherwise. ##### `GetEventsAsync(): Task>` @@ -544,6 +548,68 @@ var lookupIssue = CopilotTool.DefineTool( }); ``` +#### Cancelling Tool Handlers + +Long-running tool handlers can opt in to cooperative cancellation. The SDK passes a `CancellationToken` that is cancelled when `session.AbortAsync()` (which cancels the whole agentic loop) or `session.CancelToolCall(toolCallId)` (which cancels a single in-flight handler) is invoked. + +You can receive the token in two ways: + +**Option 1 — direct `CancellationToken` parameter** (standard .NET pattern, automatically bound by `AIFunctionFactory`): + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Tools = [ + CopilotTool.DefineTool( + async ([Description("URL to fetch")] string url, CancellationToken cancellationToken) => + { + // The request is cancelled automatically when the session/tool is cancelled + using var http = new HttpClient(); + return await http.GetStringAsync(url, cancellationToken); + }, + factoryOptions: new AIFunctionFactoryOptions + { + Name = "fetch_data", + Description = "Fetch a remote URL", + }), + ] +}); +``` + +**Option 2 — via `ToolInvocation`** (useful when you already use `ToolInvocation` for the session ID or tool call ID): + +```csharp +var session = await client.CreateSessionAsync(new SessionConfig +{ + Tools = [ + CopilotTool.DefineTool( + async ([Description("URL to fetch")] string url, ToolInvocation invocation) => + { + using var http = new HttpClient(); + return await http.GetStringAsync(url, invocation.CancellationToken); + }, + factoryOptions: new AIFunctionFactoryOptions + { + Name = "fetch_data", + Description = "Fetch a remote URL", + }), + ] +}); +``` + +Cancel a specific in-flight handler without aborting the rest of the turn: + +```csharp +session.On(e => +{ + // Cancel this specific tool call after a deadline + _ = Task.Delay(TimeSpan.FromSeconds(5)).ContinueWith(_ => + session.CancelToolCall(e.Data.ToolCallId)); +}); +``` + +Handlers that ignore the token continue to run to completion, so existing handlers keep working unchanged. + ## Commands Register slash commands so that users of the CLI's TUI can invoke custom actions via `/commandName`. Each command has a `Name`, optional `Description`, and a `Handler` called when the user executes it. diff --git a/dotnet/src/Session.cs b/dotnet/src/Session.cs index 095c1abf7..300aad647 100644 --- a/dotnet/src/Session.cs +++ b/dotnet/src/Session.cs @@ -68,6 +68,13 @@ public sealed partial class CopilotSession : IAsyncDisposable private volatile Func>? _autoModeSwitchHandler; private ImmutableArray _eventHandlers = ImmutableArray.Empty; + // Guards _inFlightToolCalls — accessed from the event-processing loop and from + // AbortAsync / CancelToolCall which may be called from any thread. + private readonly object _inFlightToolCallsLock = new(); + // Keyed by requestId (unique per RPC request) to avoid collisions on toolCallId reuse. + // The tuple stores the toolCallId for lookup by CancelToolCall and the CTS to cancel. + private readonly Dictionary _inFlightToolCalls = []; + private sealed record EventSubscription(Type EventType, Action Handler); private SessionHooks? _hooks; @@ -710,6 +717,11 @@ await HandleElicitationRequestAsync( /// private async Task ExecuteToolAndRespondAsync(string requestId, string toolName, string toolCallId, JsonElement? arguments, AIFunction tool) { + using var cts = new CancellationTokenSource(); + lock (_inFlightToolCallsLock) + { + _inFlightToolCalls[requestId] = (toolCallId, cts); + } try { var invocation = new ToolInvocation @@ -717,7 +729,8 @@ private async Task ExecuteToolAndRespondAsync(string requestId, string toolName, SessionId = SessionId, ToolCallId = toolCallId, ToolName = toolName, - Arguments = arguments + Arguments = arguments, + CancellationToken = cts.Token }; var aiFunctionArgs = new AIFunctionArguments @@ -737,7 +750,7 @@ private async Task ExecuteToolAndRespondAsync(string requestId, string toolName, } var toolTimestamp = Stopwatch.GetTimestamp(); - var result = await tool.InvokeAsync(aiFunctionArgs); + var result = await tool.InvokeAsync(aiFunctionArgs, cts.Token); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotSession.ExecuteToolAndRespondAsync tool dispatch. Elapsed={Elapsed}, SessionId={SessionId}, RequestId={RequestId}, ToolCallId={ToolCallId}, Tool={ToolName}", toolTimestamp, @@ -773,6 +786,17 @@ private async Task ExecuteToolAndRespondAsync(string requestId, string toolName, // Connection already disposed — nothing we can do } } + finally + { + // Only remove if this is still the active CTS for this requestId. + lock (_inFlightToolCallsLock) + { + if (_inFlightToolCalls.TryGetValue(requestId, out var entry) && entry.Cts == cts) + { + _inFlightToolCalls.Remove(requestId); + } + } + } } /// @@ -1580,9 +1604,87 @@ public async Task AbortAsync(CancellationToken cancellationToken = default) { ThrowIfDisposed(); + // Cooperatively cancel any in-flight tool handlers that opted in to the + // CancellationToken exposed on their ToolInvocation (or as a direct parameter). + // Handlers that ignore the token continue to run to completion. + AbortInFlightToolCalls(); + await InvokeRpcAsync("session.abort", [new SessionAbortRequest { SessionId = SessionId }], cancellationToken); } + /// + /// Cooperatively cancels a single in-flight tool handler by signalling its + /// , without aborting the broader agentic loop. + /// + /// The toolCallId of the in-flight tool invocation to cancel. + /// + /// if a matching in-flight tool call was found and its cancellation + /// token was signalled; if no matching in-flight tool call exists. + /// + /// + /// This only affects handlers that opted in to the cancellation token (e.g. by passing it to + /// a cancellable API or by checking ). + /// Handlers that ignore the token continue to run to completion, preserving existing behavior. + /// + /// + /// + /// session.On<ToolExecutionStartEvent>(e => + /// { + /// // Cancel this specific tool call after 5 seconds + /// _ = Task.Delay(TimeSpan.FromSeconds(5)).ContinueWith(_ => + /// session.CancelToolCall(e.Data.ToolCallId)); + /// }); + /// + /// + public bool CancelToolCall(string toolCallId) + { + ArgumentNullException.ThrowIfNull(toolCallId); + CancellationTokenSource? found = null; + lock (_inFlightToolCallsLock) + { + foreach (var (_, (tid, cts)) in _inFlightToolCalls) + { + if (tid == toolCallId) + { + found = cts; + break; + } + } + } + // Cancel outside the lock to avoid running CancellationToken callbacks + // while holding _inFlightToolCallsLock, which could cause deadlocks if + // a callback (directly or indirectly) touches session state. + if (found is not null) + { + found.Cancel(); + return true; + } + return false; + } + + /// + /// Cancels the for every currently in-flight tool handler. + /// + private void AbortInFlightToolCalls() + { + List? snapshot = null; + lock (_inFlightToolCallsLock) + { + if (_inFlightToolCalls.Count > 0) + { + snapshot = [.. _inFlightToolCalls.Values.Select(e => e.Cts)]; + } + } + + if (snapshot is not null) + { + foreach (var cts in snapshot) + { + cts.Cancel(); + } + } + } + /// /// Changes the model for this session. /// The new model takes effect for the next message. Conversation history is preserved. @@ -1709,6 +1811,11 @@ public async ValueTask DisposeAsync() _eventChannel.Writer.TryComplete(); + // Abort any in-flight tool handlers so they can release resources before the + // session connection is torn down. + AbortInFlightToolCalls(); + lock (_inFlightToolCallsLock) { _inFlightToolCalls.Clear(); } + try { await InvokeRpcAsync( diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 706a1ec6b..01715ff74 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -739,6 +739,21 @@ public sealed class ToolInvocation /// Arguments passed to the tool by the language model. /// public JsonElement? Arguments { get; set; } + /// + /// A that is cancelled when + /// or + /// is called while this handler is in flight. Handlers may opt in to cooperative + /// cancellation by passing it to cancellable APIs or by checking + /// . + /// Handlers that ignore it continue to run to completion, preserving existing behavior. + /// + /// + /// Note that a parameter can also be + /// declared directly on the tool handler delegate — the SDK binds the same token to it + /// automatically via Microsoft.Extensions.AI.AIFunctionFactory. + /// + [JsonIgnore] + public CancellationToken CancellationToken { get; set; } } /// diff --git a/dotnet/test/Unit/CopilotToolTests.cs b/dotnet/test/Unit/CopilotToolTests.cs index 9c9e2a93b..5ff743d89 100644 --- a/dotnet/test/Unit/CopilotToolTests.cs +++ b/dotnet/test/Unit/CopilotToolTests.cs @@ -136,6 +136,60 @@ public void DefineTool_Preserves_Additional_Properties_And_ToolOptions_Take_Prec Assert.True((bool)skipPermission!); } + [Fact] + public async Task DefineTool_Binds_CancellationToken_Parameter() + { + CancellationToken receivedToken = default; + var function = CopilotTool.DefineTool( + async (string value, CancellationToken cancellationToken) => + { + receivedToken = cancellationToken; + await Task.CompletedTask; + return value; + }, + factoryOptions: new() { Name = "echo", Description = "Echo a value" }); + + var schema = function.JsonSchema.GetRawText(); + Assert.Contains("\"value\"", schema); + Assert.DoesNotContain("\"cancellationToken\"", schema); + + using var cts = new CancellationTokenSource(); + using var document = JsonDocument.Parse("\"hello\""); + await function.InvokeAsync(new AIFunctionArguments + { + ["value"] = document.RootElement.Clone(), + }, cts.Token); + + Assert.Equal(cts.Token, receivedToken); + } + + [Fact] + public async Task DefineTool_Exposes_CancellationToken_On_ToolInvocation() + { + CancellationToken receivedToken = default; + var function = CopilotTool.DefineTool( + async (string value, ToolInvocation invocation) => + { + receivedToken = invocation.CancellationToken; + await Task.CompletedTask; + return value; + }, + factoryOptions: new() { Name = "echo", Description = "Echo a value" }); + + using var cts = new CancellationTokenSource(); + using var document = JsonDocument.Parse("\"hello\""); + await function.InvokeAsync(new AIFunctionArguments + { + ["value"] = document.RootElement.Clone(), + Context = new Dictionary + { + [typeof(ToolInvocation)] = new ToolInvocation { ToolName = "echo", CancellationToken = cts.Token } + } + }); + + Assert.Equal(cts.Token, receivedToken); + } + [DisplayName("test_tool")] [Description("Test tool")] private static string ReturnsOk() => "ok";