diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 5a5d34dcd..a4479b75c 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -85,6 +85,13 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private List? _modelsCache; private ServerRpc? _serverRpc; + /// + /// Client-global RPC handlers (e.g. the LLM inference provider adapter), + /// built once at construction when the corresponding option is configured and + /// registered on every connection. Null when no client-global API is enabled. + /// + private readonly ClientGlobalApiHandlers? _clientGlobalApis; + private sealed record LifecycleSubscription(Type EventType, Action Handler); /// @@ -165,6 +172,8 @@ public CopilotClient(CopilotClientOptions? options = null) _logger = _options.Logger ?? NullLogger.Instance; _onListModels = _options.OnListModels; + _clientGlobalApis = BuildClientGlobalApis(); + // Empty mode: validate at construction time that the app supplied a // per-session persistence location. The runtime is mode-agnostic, so // without this check it would silently fall back to ~/.copilot, which @@ -276,6 +285,8 @@ async Task StartCoreAsync(CancellationToken ct) sessionFsTimestamp); } + await ConfigureLlmInferenceAsync(ct); + LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.StartAsync complete. Elapsed={Elapsed}", startTimestamp); @@ -426,7 +437,6 @@ private async Task CleanupConnectionAsync(List? errors, bool graceful private async Task CleanupConnectionAsync(Connection ctx, List? errors, bool gracefulRuntimeShutdown) { - var runtimeShutdownCompleted = false; if (gracefulRuntimeShutdown && ctx.CliProcess is not null) { var runtimeShutdownTimestamp = Stopwatch.GetTimestamp(); @@ -434,7 +444,6 @@ private async Task CleanupConnectionAsync(Connection ctx, List? 
error { using var cancellation = new CancellationTokenSource(s_runtimeShutdownTimeout); await ctx.Server.Runtime.ShutdownAsync(cancellation.Token); - runtimeShutdownCompleted = true; LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.StopAsync runtime shutdown complete. Elapsed={Elapsed}", runtimeShutdownTimestamp); @@ -466,11 +475,11 @@ or IOException if (ctx.CliProcess is { } childProcess) { - await CleanupCliProcessAsync(childProcess, ctx.StderrPump, errors, _logger, runtimeShutdownCompleted); + await CleanupCliProcessAsync(childProcess, ctx.StderrPump, errors, _logger); } } - private static async Task CleanupCliProcessAsync(Process childProcess, ProcessStderrPump? stderrPump, List? errors, ILogger? logger, bool waitForGracefulExit = false) + private static async Task CleanupCliProcessAsync(Process childProcess, ProcessStderrPump? stderrPump, List? errors, ILogger? logger) { stderrPump?.Cancel(); @@ -478,30 +487,12 @@ private static async Task CleanupCliProcessAsync(Process childProcess, ProcessSt { if (!childProcess.HasExited) { - if (waitForGracefulExit) - { - var shutdownWaitTimestamp = Stopwatch.GetTimestamp(); - try - { - await childProcess.WaitForExitAsync().WaitAsync(s_runtimeShutdownTimeout); - } - catch (TimeoutException ex) - { - if (logger is not null) - { - LoggingHelpers.LogTiming(logger, LogLevel.Debug, ex, - "Timed out waiting for runtime process to exit after graceful shutdown. Elapsed={Elapsed}, Timeout={Timeout}", - shutdownWaitTimestamp, - s_runtimeShutdownTimeout); - } - } - } - - if (childProcess.HasExited) - { - return; - } - + // The runtime completes all cleanup before responding to + // runtime.shutdown and then leaves termination to us; it + // deliberately keeps its JSON-RPC server alive to send the + // response and never self-exits. Waiting for a self-exit that + // will never come just wastes time, so terminate the child + // immediately and only wait to reap it. 
childProcess.Kill(entireProcessTree: true); // Kill is asynchronous; wait for the root CLI process to exit so cleanup callers // do not observe StopAsync/DisposeAsync completion while it is still tearing down. @@ -1678,6 +1669,39 @@ await Rpc.SessionFs.SetProviderAsync( cancellationToken: cancellationToken); } + /// + /// Builds the client-global RPC handler bag at construction time. Currently + /// only the LLM inference provider adapter is registered; returns null when no + /// client-global API is configured so the registration is skipped entirely. + /// + private ClientGlobalApiHandlers? BuildClientGlobalApis() + { + var handler = _options.RequestHandler; + if (handler is null) + { + return null; + } + + return new ClientGlobalApiHandlers + { + LlmInference = new LlmInferenceAdapter(handler, () => _serverRpc), + }; + } + + /// + /// Tells the runtime to route its outbound model-layer requests through this + /// client's LLM inference provider. No-op when interception is not configured. + /// + private async Task ConfigureLlmInferenceAsync(CancellationToken cancellationToken) + { + if (_clientGlobalApis?.LlmInference is null) + { + return; + } + + await Rpc.LlmInference.SetProviderAsync(cancellationToken); + } + private void ConfigureSessionFsHandlers(CopilotSession session, Func? createSessionFsHandler) { if (_options.SessionFs is null) @@ -2072,6 +2096,10 @@ private async Task ConnectToServerAsync(Process? cliProcess, string? var session = GetSession(sessionId) ?? throw new ArgumentException($"Unknown session {sessionId}"); return session.ClientSessionApis; }); + if (_clientGlobalApis is not null) + { + ClientGlobalApiRegistration.RegisterClientGlobalApiHandlers(rpc, _clientGlobalApis); + } rpc.StartListening(); LoggingHelpers.LogTiming(_logger, LogLevel.Debug, null, "CopilotClient.ConnectToServerAsync transport setup complete. 
Elapsed={Elapsed}", diff --git a/dotnet/src/CopilotRequestHandler.cs b/dotnet/src/CopilotRequestHandler.cs new file mode 100644 index 000000000..8d407e9c2 --- /dev/null +++ b/dotnet/src/CopilotRequestHandler.cs @@ -0,0 +1,1023 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.Rpc; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Net.WebSockets; +using System.Runtime.CompilerServices; +using System.Text; +using System.Threading.Channels; + +namespace GitHub.Copilot; + +/// +/// Transport the runtime would otherwise use to issue an intercepted +/// model-layer request. +/// +[Experimental(Diagnostics.Experimental)] +public enum CopilotRequestTransport +{ + /// + /// Plain HTTP or a streamed SSE response. Each body chunk is an opaque + /// byte range. + /// + Http, + + /// + /// Full-duplex WebSocket channel. Each request-body chunk is one inbound + /// WebSocket message and each response-body write is one outbound message. + /// + WebSocket, +} + +/// +/// Per-request context handed to every hook. +/// Exposes the routing and cancellation details of a single intercepted request +/// so overrides can observe or rewrite it. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class CopilotRequestContext +{ + /// Opaque runtime-minted id, stable across the request lifecycle. + public required string RequestId { get; init; } + + /// Runtime session id that triggered the request, if any. + public string? SessionId { get; init; } + + /// Transport the runtime would otherwise use. + public CopilotRequestTransport Transport { get; init; } + + /// Original request URL. + public required string Url { get; init; } + + /// Original request headers. 
+ public required IReadOnlyDictionary> Headers { get; init; } + + /// + /// Cancelled when the runtime aborts this in-flight request. Subclasses that + /// issue their own I/O should pass this through so the upstream call is torn + /// down too. + /// + public CancellationToken CancellationToken { get; init; } + + internal LlmWebSocketResponseBridge? WebSocketResponse { get; set; } +} + +/// A single WebSocket message exchanged through a hook. +[Experimental(Diagnostics.Experimental)] +public readonly struct CopilotWebSocketMessage(ReadOnlyMemory data, bool isBinary) +{ + /// The message payload bytes. + public ReadOnlyMemory Data { get; } = data; + + /// True for a binary frame; false for a UTF-8 text frame. + public bool IsBinary { get; } = isBinary; + + /// Decodes the payload as UTF-8 text. + public string GetText() => Encoding.UTF8.GetString(Data.ToArray()); + + /// Creates a text message from a UTF-8 string. + public static CopilotWebSocketMessage Text(string text) => new(Encoding.UTF8.GetBytes(text), isBinary: false); + + /// Creates a binary message from raw bytes. + public static CopilotWebSocketMessage Binary(ReadOnlyMemory data) => new(data, isBinary: true); +} + +/// +/// Terminal status for a callback-owned WebSocket connection. +/// +[Experimental(Diagnostics.Experimental)] +public sealed class CopilotWebSocketCloseStatus +{ + /// The close description, if any. + public string? Description { get; init; } + + /// + /// Optional error code surfaced to the runtime when the close is a failure + /// rather than a clean end-of-stream. + /// + public string? ErrorCode { get; init; } + + /// The error that terminated the connection, if any. + public Exception? Error { get; init; } + + /// Shared normal-closure instance. + public static CopilotWebSocketCloseStatus NormalClosure { get; } = new(); +} + +/// +/// Lower-level WebSocket handler with no upstream connection. 
This is the
+/// abstract base shared by all WebSocket handlers; it does not open or forward
+/// to any upstream server on its own. Subclass it directly only to service a
+/// fully synthetic connection yourself. For the common case of mutating and
+/// forwarding traffic to the real upstream, subclass
+/// <see cref="CopilotWebSocketHandler"/> instead, which connects upstream and
+/// forwards by default.
+///
+[Experimental(Diagnostics.Experimental)]
+public abstract class CopilotWebSocketHandlerBase : IAsyncDisposable
+{
+    // Resolved with the status handed to the first CloseAsync call that gets
+    // past the _closed gate. Continuations run asynchronously so awaiting
+    // callers never execute inline on the closing thread.
+    private readonly TaskCompletionSource<CopilotWebSocketCloseStatus> _completion =
+        new(TaskCreationOptions.RunContinuationsAsynchronously);
+
+    // 0 while open, 1 once CloseAsync has been entered (flipped via Interlocked).
+    private int _closed;
+    private bool _suppressCloseOnDispose;
+
+    /// <summary>Request context for this WebSocket connection.</summary>
+    protected CopilotRequestContext Context { get; }
+
+    internal Task<CopilotWebSocketCloseStatus> Completion => _completion.Task;
+
+    /// <summary>
+    /// Initializes a per-connection handler for the supplied request context.
+    /// </summary>
+    /// <param name="context">Request context; its WebSocket response bridge must already be attached.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
+    /// <exception cref="InvalidOperationException">No WebSocket response bridge is attached to <paramref name="context"/>.</exception>
+    protected CopilotWebSocketHandlerBase(CopilotRequestContext context)
+    {
+        ArgumentNullException.ThrowIfNull(context);
+
+        Context = context;
+        _ = context.WebSocketResponse ?? throw new InvalidOperationException("WebSocket response bridge is not attached.");
+    }
+
+    /// <summary>
+    /// Send a message from the runtime to the upstream connection.
+    /// </summary>
+    public abstract Task SendRequestMessageAsync(CopilotWebSocketMessage message);
+
+    /// <summary>
+    /// Send a message from the upstream connection back to the runtime.
+    /// Override to mutate or duplicate messages; call base to emit.
+    /// </summary>
+    public virtual Task SendResponseMessageAsync(CopilotWebSocketMessage message) =>
+        Context.WebSocketResponse!.WriteAsync(message);
+
+    /// <summary>
+    /// Close the connection and finalise the runtime-facing response.
+    /// Only the first call has any effect; later calls return immediately.
+    /// </summary>
+    /// <param name="status">Terminal status; a non-null <c>Error</c> finalises the response as a failure.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="status"/> is null.</exception>
+    public virtual async Task CloseAsync(CopilotWebSocketCloseStatus status)
+    {
+        // Validate before flipping _closed so a bad call cannot consume the
+        // one-shot close and leave the connection permanently un-finalised.
+        ArgumentNullException.ThrowIfNull(status);
+
+        if (Interlocked.Exchange(ref _closed, 1) != 0)
+        {
+            return;
+        }
+
+        try
+        {
+            if (status.Error is not null)
+            {
+                await Context.WebSocketResponse!
+                    .ErrorAsync(status.Description ?? status.Error.Message, status.ErrorCode)
+                    .ConfigureAwait(false);
+            }
+            else
+            {
+                await Context.WebSocketResponse!.EndAsync().ConfigureAwait(false);
+            }
+        }
+        finally
+        {
+            // Resolve Completion even when emitting the terminal frame throws.
+            // _closed is already 1, so no later CloseAsync call could resolve it
+            // and anything awaiting Completion would otherwise never resume.
+            // The original exception still propagates to this caller.
+            _completion.TrySetResult(status);
+        }
+    }
+
+    internal void SuppressCloseOnDispose() => _suppressCloseOnDispose = true;
+
+    internal virtual Task OpenAsync() => Task.CompletedTask;
+
+    /// <inheritdoc />
+    public virtual async ValueTask DisposeAsync()
+    {
+        GC.SuppressFinalize(this);
+        if (!_suppressCloseOnDispose && Volatile.Read(ref _closed) == 0)
+        {
+            await CloseAsync(CopilotWebSocketCloseStatus.NormalClosure).ConfigureAwait(false);
+        }
+    }
+}
+
+/// <summary>
+/// WebSocket handler that connects to the real upstream and forwards traffic by
+/// default. This is the type returned by the default
+/// <see cref="CopilotRequestHandler.OpenWebSocketAsync"/>. Override nothing to
+/// get full pass-through. To mutate traffic, subclass this type and override a
+/// send method, then call the base implementation to keep forwarding upstream.
+/// (Subclassing <see cref="CopilotWebSocketHandlerBase"/> instead would drop
+/// forwarding entirely.)
+/// </summary>
+[Experimental(Diagnostics.Experimental)]
+public class CopilotWebSocketHandler : CopilotWebSocketHandlerBase
+{
+    private readonly string _url;
+    private readonly IReadOnlyDictionary<string, IReadOnlyList<string>> _headers;
+    private WebSocket? _upstream;
+    private CancellationTokenSource? _pumpCts;
+    private Task? _responsePump;
+
+    /// <summary>
+    /// Initializes a forwarding handler that will open the upstream socket on
+    /// demand using the supplied URL/headers (or the values from
+    /// <paramref name="context"/> when omitted).
+    /// </summary>
+    public CopilotWebSocketHandler(
+        CopilotRequestContext context,
+        string? url = null,
+        IReadOnlyDictionary<string, IReadOnlyList<string>>? headers = null)
+        : base(context)
+    {
+        _url = url ?? context.Url;
+        _headers = headers ?? context.Headers;
+    }
+
+    ///
+    /// Opens the upstream socket and starts the built-in response pump.
+    ///
+    internal override async Task OpenAsync()
+    {
+        // Already connected: opening is a one-time operation per handler.
+        if (_upstream is not null)
+        {
+            return;
+        }
+
+        var socket = new ClientWebSocket();
+        try
+        {
+            foreach (var (name, values) in _headers)
+            {
+                if (LlmInferenceHeaders.Forbidden.Contains(name))
+                {
+                    continue;
+                }
+
+                try
+                {
+                    socket.Options.SetRequestHeader(name, string.Join(", ", values));
+                }
+                catch
+                {
+                    // Some headers are managed by the handshake; ignore rejections.
+                }
+            }
+
+            await socket.ConnectAsync(ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false);
+        }
+        catch
+        {
+            // The socket is only handed to _upstream after a successful connect,
+            // and DisposeAsync disposes nothing but _upstream. Release it here so
+            // a failed or cancelled handshake does not leak the socket.
+            socket.Dispose();
+            throw;
+        }
+
+        _upstream = socket;
+        _pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken);
+        _responsePump = Task.Run(() => PumpResponsesAsync(_pumpCts.Token), _pumpCts.Token);
+    }
+
+    /// <summary>
+    /// Sends a message from the runtime to the upstream connection. Subclasses may override to mutate messages.
+    /// </summary>
+    /// <param name="message">The message to send.</param>
+    /// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
+    public override Task SendRequestMessageAsync(CopilotWebSocketMessage message)
+    {
+        if (_upstream?.State != WebSocketState.Open)
+        {
+            return Task.CompletedTask;
+        }
+
+        var type = message.IsBinary ? 
WebSocketMessageType.Binary : WebSocketMessageType.Text; + return _upstream.SendAsync( + new ArraySegment(message.Data.ToArray()), + type, + endOfMessage: true, + Context.CancellationToken); + } + + /// + public override async Task CloseAsync(CopilotWebSocketCloseStatus status) + { + _pumpCts?.Cancel(); + if (_upstream is not null) + { + await CloseWebSocketQuietlyAsync(_upstream).ConfigureAwait(false); + } + await base.CloseAsync(status).ConfigureAwait(false); + } + + /// + public override async ValueTask DisposeAsync() + { + GC.SuppressFinalize(this); + try + { + await base.DisposeAsync().ConfigureAwait(false); + } + finally + { + _pumpCts?.Cancel(); + _pumpCts?.Dispose(); + _upstream?.Dispose(); + if (_responsePump is not null) + { + await ObserveQuietlyAsync(_responsePump).ConfigureAwait(false); + } + } + } + + private async Task PumpResponsesAsync(CancellationToken cancellationToken) + { + if (_upstream is null) + { + return; + } + + try + { + while (_upstream.State == WebSocketState.Open) + { + var message = await ReceiveMessageAsync(_upstream, cancellationToken).ConfigureAwait(false); + if (message is null) + { + break; + } + + await SendResponseMessageAsync(message.Value).ConfigureAwait(false); + } + + await CloseAsync(CopilotWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + } + catch (OperationCanceledException) when (Context.CancellationToken.IsCancellationRequested) + { + // Runtime-side cancellation aborts the request pump; the outer + // handler rethrows that cancellation rather than finalising here. 
+ } + catch (Exception ex) + { + await CloseAsync(new CopilotWebSocketCloseStatus + { + Description = ex.Message, + Error = ex, + }).ConfigureAwait(false); + } + } + + private static async Task ReceiveMessageAsync(WebSocket socket, CancellationToken cancellationToken) + { + var buffer = new byte[16 * 1024]; + using var assembled = new MemoryStream(); + WebSocketReceiveResult result; + do + { + try + { + result = await socket.ReceiveAsync(new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return null; + } + catch (WebSocketException) + { + return null; + } + + if (result.MessageType == WebSocketMessageType.Close) + { + return null; + } + + assembled.Write(buffer, 0, result.Count); + } + while (!result.EndOfMessage); + + return new CopilotWebSocketMessage(assembled.ToArray(), result.MessageType == WebSocketMessageType.Binary); + } + + private static async Task CloseWebSocketQuietlyAsync(WebSocket socket) + { + try + { + if (socket.State is WebSocketState.Open or WebSocketState.CloseReceived) + { + await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, statusDescription: null, CancellationToken.None).ConfigureAwait(false); + } + } + catch + { + // Best-effort; the socket may already be closed. + } + } + + [SuppressMessage("Usage", "CA1031:Do not catch general exception types", Justification = "Best-effort teardown of the losing pump.")] + private static async Task ObserveQuietlyAsync(Task task) + { + try + { + await task.ConfigureAwait(false); + } + catch + { + // Best-effort teardown only. 
+ } + } + + private static Uri ToWebSocketUri(string url) + { + var builder = new UriBuilder(url); + if (builder.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase)) + { + builder.Scheme = "wss"; + } + else if (builder.Scheme.Equals("http", StringComparison.OrdinalIgnoreCase)) + { + builder.Scheme = "ws"; + } + + return builder.Uri; + } +} + +/// +/// Base class for SDK consumers who want to observe or mutate the LLM inference +/// requests the runtime issues (for both CAPI and BYOK providers). Subclass and +/// override or . +/// +[Experimental(Diagnostics.Experimental)] +public class CopilotRequestHandler +{ + private static readonly HttpClient s_sharedHttpClient = new(); + + /// + /// Issue the upstream HTTP request. Override to mutate the request before + /// calling base, mutate the returned response after, or replace the + /// call entirely. + /// + protected virtual Task SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx) => + s_sharedHttpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, ctx.CancellationToken); + + /// + /// Open the upstream WebSocket connection. Override to return a custom + /// or to construct a + /// against a rewritten URL. + /// + protected virtual Task OpenWebSocketAsync(CopilotRequestContext ctx) => + Task.FromResult(new CopilotWebSocketHandler(ctx)); + + /// + /// Entry point invoked by the adapter once per intercepted request. Routes to + /// the HTTP or WebSocket flow and drives the consumer's overridable hooks. + /// + internal Task HandleAsync(LlmInferenceExchange exchange) => + exchange.Context.Transport == CopilotRequestTransport.WebSocket + ? 
HandleWebSocketAsync(exchange) + : HandleHttpAsync(exchange); + + private async Task HandleHttpAsync(LlmInferenceExchange exchange) + { + using var request = await BuildHttpRequestAsync(exchange).ConfigureAwait(false); + using var response = await SendRequestAsync(request, exchange.Context).ConfigureAwait(false); + await StreamResponseAsync(response, exchange).ConfigureAwait(false); + } + + private static async Task BuildHttpRequestAsync(LlmInferenceExchange exchange) + { + var method = new HttpMethod(exchange.Method.ToUpperInvariant()); + var message = new HttpRequestMessage(method, exchange.Context.Url); + + var hasBody = method != HttpMethod.Get && method != HttpMethod.Head; + var body = await DrainAsync(exchange.RequestBody).ConfigureAwait(false); + if (hasBody && body.Length > 0) + { + message.Content = new ByteArrayContent(body); + } + + foreach (var (name, values) in exchange.Context.Headers) + { + if (LlmInferenceHeaders.Forbidden.Contains(name)) + { + continue; + } + + if (!message.Headers.TryAddWithoutValidation(name, values)) + { + message.Content ??= new ByteArrayContent([]); + message.Content.Headers.TryAddWithoutValidation(name, values); + } + } + + return message; + } + + private static async Task StreamResponseAsync(HttpResponseMessage response, LlmInferenceExchange exchange) + { + await exchange.StartResponseAsync( + (int)response.StatusCode, + response.ReasonPhrase, + HeadersToMultiMap(response)).ConfigureAwait(false); + + var ct = exchange.Context.CancellationToken; +#if NETSTANDARD2_0 + using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); +#else + using var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false); +#endif + var buffer = new byte[16 * 1024]; + int read; +#if NETSTANDARD2_0 + while ((read = await stream.ReadAsync(buffer, 0, buffer.Length, ct).ConfigureAwait(false)) > 0) +#else + while ((read = await stream.ReadAsync(buffer.AsMemory(), ct).ConfigureAwait(false)) > 0) +#endif + { + 
await exchange.WriteResponseAsync(new ReadOnlyMemory(buffer, 0, read)).ConfigureAwait(false); + } + + await exchange.EndResponseAsync().ConfigureAwait(false); + } + + private async Task HandleWebSocketAsync(LlmInferenceExchange exchange) + { + var ctx = exchange.Context; + var bridge = new LlmWebSocketResponseBridge(exchange); + ctx.WebSocketResponse = bridge; + + var handler = await OpenWebSocketAsync(ctx).ConfigureAwait(false); + try + { + await handler.OpenAsync().ConfigureAwait(false); + + // The runtime blocks the WebSocket connect until it receives the + // 101 response head (the upgrade acknowledgement) and only then + // begins forwarding inbound messages as request-body chunks. Emit + // it eagerly here — waiting for the first upstream message would + // deadlock, since the upstream stays silent until it receives a + // request message the runtime won't send before the upgrade + // completes. + await bridge.StartAsync().ConfigureAwait(false); + + var clientPump = Task.Run(async () => + { + await foreach (var chunk in exchange.RequestBody.WithCancellation(ctx.CancellationToken).ConfigureAwait(false)) + { + await handler.SendRequestMessageAsync(new CopilotWebSocketMessage(chunk, isBinary: false)).ConfigureAwait(false); + } + }, ctx.CancellationToken); + + var first = await Task.WhenAny(clientPump, handler.Completion).ConfigureAwait(false); + if (first == clientPump) + { + if (clientPump.IsFaulted || clientPump.IsCanceled) + { + handler.SuppressCloseOnDispose(); + await clientPump.ConfigureAwait(false); + } + + await handler.CloseAsync(CopilotWebSocketCloseStatus.NormalClosure).ConfigureAwait(false); + await handler.Completion.ConfigureAwait(false); + return; + } + + var closeStatus = await handler.Completion.ConfigureAwait(false); + if (closeStatus.Error is not null) + { + throw closeStatus.Error; + } + } + finally + { + await handler.DisposeAsync().ConfigureAwait(false); + } + } + + private static async Task DrainAsync(IAsyncEnumerable> stream) + { + using 
var buffer = new MemoryStream(); + await foreach (var chunk in stream.ConfigureAwait(false)) + { + if (chunk.Length > 0) + { + buffer.Write(chunk.ToArray(), 0, chunk.Length); + } + } + + return buffer.ToArray(); + } + + private static Dictionary> HeadersToMultiMap(HttpResponseMessage response) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + foreach (var header in response.Headers) + { + result[header.Key] = [.. header.Value]; + } + + if (response.Content is not null) + { + foreach (var header in response.Content.Headers) + { + result[header.Key] = [.. header.Value]; + } + } + + return result; + } +} + +/// +/// One intercepted request in flight. Carries the request context plus the body +/// byte stream the runtime feeds in via httpRequestChunk frames, and +/// emits the consumer's response straight back to the runtime through the +/// generated llmInference server API. Replaces the former +/// provider/sink/response-channel indirection with a single object the adapter +/// owns and the handler writes to. 
+/// +internal sealed class LlmInferenceExchange +{ + private readonly Func _getServerRpc; + private readonly Channel _body = Channel.CreateUnbounded( + new UnboundedChannelOptions { SingleReader = true, SingleWriter = true }); + + private bool _started; + private bool _finished; + private bool _cancelled; + + internal LlmInferenceExchange(string requestId, Func getServerRpc) + { + RequestId = requestId; + _getServerRpc = getServerRpc; + } + + internal string RequestId { get; } + + internal string Method { get; set; } = "GET"; + + internal CopilotRequestContext Context { get; set; } = null!; + + internal CancellationTokenSource Abort { get; } = new(); + + internal bool Started => _started; + + internal bool Finished => _finished; + + internal bool Cancelled => _cancelled; + + // --- Request body feed (driven by the adapter as chunk frames arrive) --- + + internal void PushChunk(byte[] data) => _body.Writer.TryWrite(new BodyItem { Chunk = data }); + + internal void PushEnd() => _body.Writer.TryWrite(new BodyItem { End = true }); + + internal void PushCancel(string? reason) + { + _cancelled = true; + Abort.Cancel(); + _body.Writer.TryWrite(new BodyItem { Cancel = true, CancelReason = reason }); + } + + /// + /// Request body bytes, yielded as they arrive. A cancel frame surfaces as an + /// so the consumer's upstream call + /// is torn down. + /// + internal IAsyncEnumerable> RequestBody => ReadBodyAsync(Abort.Token); + + private async IAsyncEnumerable> ReadBodyAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + while (await _body.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + while (_body.Reader.TryRead(out var item)) + { + if (item.Cancel) + { + _body.Writer.TryComplete(); + throw new OperationCanceledException( + item.CancelReason is null + ? 
"Request cancelled by runtime" + : $"Request cancelled by runtime: {item.CancelReason}"); + } + + if (item.End) + { + _body.Writer.TryComplete(); + yield break; + } + + if (item.Chunk is { Length: > 0 }) + { + yield return item.Chunk; + } + } + } + } + + // --- Response emit (driven by the handler). Strict state machine: --- + // StartResponseAsync once -> zero or more WriteResponseAsync -> exactly one + // of EndResponseAsync / ErrorResponseAsync. + + internal async Task StartResponseAsync(int status, string? statusText, IReadOnlyDictionary>? headers) + { + if (_started) + { + throw new InvalidOperationException("LLM inference response StartAsync() called twice."); + } + + if (_finished) + { + throw new InvalidOperationException("LLM inference response already finished."); + } + + _started = true; + await ServerRpc() + .LlmInference.HttpResponseStartAsync(RequestId, status, ToWireHeaders(headers), statusText) + .ConfigureAwait(false); + } + + internal Task WriteResponseAsync(ReadOnlyMemory data) => + WriteChunkAsync(Convert.ToBase64String(data.ToArray()), binary: true); + + internal Task WriteResponseAsync(string text) + { + ArgumentNullException.ThrowIfNull(text); + return WriteChunkAsync(text, binary: false); + } + + internal async Task EndResponseAsync() + { + if (_finished) + { + return; + } + + _finished = true; + await ServerRpc().LlmInference.HttpResponseChunkAsync(RequestId, string.Empty, end: true).ConfigureAwait(false); + } + + internal async Task ErrorResponseAsync(string message, string? 
code = null) + { + ArgumentNullException.ThrowIfNull(message); + + if (_finished) + { + return; + } + + _finished = true; + await ServerRpc() + .LlmInference.HttpResponseChunkAsync( + RequestId, + string.Empty, + end: true, + error: new LlmInferenceHttpResponseChunkError { Message = message, Code = code }) + .ConfigureAwait(false); + } + + private async Task WriteChunkAsync(string data, bool binary) + { + if (_cancelled) + { + throw new InvalidOperationException("LLM inference request was cancelled by the runtime."); + } + + if (!_started) + { + throw new InvalidOperationException("LLM inference response WriteAsync() called before StartAsync()."); + } + + if (_finished) + { + throw new InvalidOperationException("LLM inference response WriteAsync() called after EndAsync()/ErrorAsync()."); + } + + await ServerRpc() + .LlmInference.HttpResponseChunkAsync(RequestId, data, binary: binary, end: false) + .ConfigureAwait(false); + } + + private ServerRpc ServerRpc() => + _getServerRpc() ?? throw new InvalidOperationException("LLM inference response used after RPC connection closed."); + + private static Dictionary> ToWireHeaders(IReadOnlyDictionary>? headers) + { + var result = new Dictionary>(StringComparer.OrdinalIgnoreCase); + if (headers is null) + { + return result; + } + + foreach (var (name, values) in headers) + { + result[name] = values as IList ?? [.. values]; + } + + return result; + } + + private struct BodyItem + { + public byte[]? Chunk; + public bool End; + public bool Cancel; + public string? CancelReason; + } +} + +/// +/// Adapts the generated RPC entry points onto +/// a consumer's . Each httpRequestStart +/// allocates an and runs the handler in the +/// background; subsequent httpRequestChunk frames feed its body stream. +/// +internal sealed class LlmInferenceAdapter(CopilotRequestHandler handler, Func getServerRpc) : ILlmInferenceHandler +{ + private readonly CopilotRequestHandler _handler = handler ?? 
throw new ArgumentNullException(nameof(handler)); + private readonly Func _getServerRpc = getServerRpc ?? throw new ArgumentNullException(nameof(getServerRpc)); + private readonly ConcurrentDictionary _pending = new(StringComparer.Ordinal); + + public Task HttpRequestStartAsync(LlmInferenceHttpRequestStartRequest request, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + var transport = request.Transport == LlmInferenceHttpRequestStartTransport.Websocket + ? CopilotRequestTransport.WebSocket + : CopilotRequestTransport.Http; + + // The runtime dispatches httpRequestStart and httpRequestChunk frames + // concurrently, so body chunks (including the terminal end frame) can + // arrive before this start frame runs. GetOrAdd adopts any exchange a + // racing chunk already created — with its buffered body — instead of + // dropping those frames and hanging the body drain. + var exchange = _pending.GetOrAdd(request.RequestId, id => new LlmInferenceExchange(id, _getServerRpc)); + exchange.Method = request.Method; + exchange.Context = new CopilotRequestContext + { + RequestId = request.RequestId, + SessionId = request.SessionId, + Transport = transport, + Url = request.Url, + Headers = ToReadOnlyHeaders(request.Headers), + CancellationToken = exchange.Abort.Token, + }; + + // Return from httpRequestStart immediately (after registering state) so + // the runtime's RPC reply is not gated on the consumer's I/O. The actual + // handler work runs asynchronously, exactly once per request. + _ = RunAsync(exchange); + + return Task.FromResult(new LlmInferenceHttpRequestStartResult()); + } + + public Task HttpRequestChunkAsync(LlmInferenceHttpRequestChunkRequest request, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(request); + + // A chunk may arrive before its matching httpRequestStart (frames are + // dispatched concurrently). 
GetOrAdd buffers the body into the + // exchange's channel so no chunk — in particular the terminal end + // frame — is ever lost; the start frame later adopts this same exchange. + var exchange = _pending.GetOrAdd(request.RequestId, id => new LlmInferenceExchange(id, _getServerRpc)); + RouteChunk(exchange, request); + + return Task.FromResult(new LlmInferenceHttpRequestChunkResult()); + } + + private async Task RunAsync(LlmInferenceExchange exchange) + { + try + { + await _handler.HandleAsync(exchange).ConfigureAwait(false); + if (!exchange.Finished) + { + await FinalizeAsync(exchange, 502, "LLM inference handler returned without finalising the response (call ResponseBody.EndAsync() or .ErrorAsync()).", code: null).ConfigureAwait(false); + } + } + catch (Exception ex) + { + if (exchange.Cancelled || exchange.Abort.IsCancellationRequested) + { + // The runtime already cancelled this request; the handler's throw + // is just the abort propagating out of its upstream call. + await FinalizeAsync(exchange, 499, "Request cancelled by runtime", code: "cancelled").ConfigureAwait(false); + return; + } + + await FinalizeAsync(exchange, 502, ex.Message, code: null).ConfigureAwait(false); + } + finally + { + _pending.TryRemove(exchange.RequestId, out _); + } + } + + private static async Task FinalizeAsync(LlmInferenceExchange exchange, int status, string message, string? code) + { + if (exchange.Finished) + { + return; + } + + try + { + if (!exchange.Started) + { + await exchange.StartResponseAsync(status, statusText: null, headers: null).ConfigureAwait(false); + } + + await exchange.ErrorResponseAsync(message, code).ConfigureAwait(false); + } + catch + { + // Best-effort — the connection may already be dead. 
/// <summary>
/// Applies one request-chunk frame to its exchange. A cancel frame is pushed as a cancel and
/// nothing else in the frame is processed; otherwise non-empty data is decoded and pushed,
/// followed by the end marker when the frame carries one.
/// </summary>
private static void RouteChunk(LlmInferenceExchange exchange, LlmInferenceHttpRequestChunkRequest chunk)
{
    if (chunk.Cancel == true)
    {
        exchange.PushCancel(chunk.CancelReason);
        return;
    }

    if (!string.IsNullOrEmpty(chunk.Data))
    {
        exchange.PushChunk(DecodeChunkData(chunk.Data, chunk.Binary == true));
    }

    if (chunk.End == true)
    {
        exchange.PushEnd();
    }
}

/// <summary>
/// Converts a chunk payload to bytes: base64-decoded when <paramref name="binary"/> is true,
/// UTF-8 encoded otherwise.
/// </summary>
/// <exception cref="FormatException"><paramref name="binary"/> is true and <paramref name="data"/> is not valid base64.</exception>
private static byte[] DecodeChunkData(string data, bool binary) =>
    binary ? Convert.FromBase64String(data) : Encoding.UTF8.GetBytes(data);

/// <summary>
/// Copies the request headers into a dictionary with case-insensitive names and read-only
/// value lists. A list that already implements <see cref="IReadOnlyList{T}"/> is reused;
/// any other list is copied.
/// </summary>
/// <param name="headers">Header name to value list, as received on the start frame.</param>
/// <returns>A new dictionary; when two names differ only by case the later one wins.</returns>
private static Dictionary<string, IReadOnlyList<string>> ToReadOnlyHeaders(IDictionary<string, IList<string>> headers)
{
    var result = new Dictionary<string, IReadOnlyList<string>>(StringComparer.OrdinalIgnoreCase);
    foreach (var (name, values) in headers)
    {
        if (values is null)
        {
            // A header whose value list is null (e.g. JSON `"x": null`) is
            // kept as present-but-empty; spreading null would throw.
            result[name] = [];
        }
        else
        {
            result[name] = values as IReadOnlyList<string> ?? [.. values];
        }
    }

    return result;
}
/// <summary>Header names that must not be copied verbatim onto a forwarded request.</summary>
internal static class LlmInferenceHeaders
{
    /// <summary>
    /// Headers computed or managed by the HTTP/WebSocket stack; forwarding them verbatim
    /// either throws or corrupts the request. Lookups are case-insensitive.
    /// </summary>
    internal static readonly HashSet<string> Forbidden = new(
        [
            "host",
            "connection",
            "content-length",
            "transfer-encoding",
            "keep-alive",
            "upgrade",
            "proxy-connection",
            "te",
            "trailer",
        ],
        StringComparer.OrdinalIgnoreCase);
}
The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. + [JsonPropertyName("requestId")] + public string RequestId { get; set; } = string.Empty; + + /// Id of the runtime session that triggered this request, when one is in scope. Absent for requests issued outside any session (e.g. startup model-catalog or capability resolution). This is a payload field — not a dispatch key — because the client-global API is registered process-wide rather than per session. + [JsonPropertyName("sessionId")] + public string? SessionId { get; set; } + + /// Transport the runtime would otherwise use for this request. `http` (the default when absent) covers plain HTTP and SSE responses; `websocket` indicates a full-duplex message channel where each body chunk maps to one WebSocket message and the `binary` flag distinguishes text from binary frames. The SDK consumer uses this to decide whether to service the request with an HTTP client or a WebSocket client. It is the one piece of request metadata the consumer cannot reliably infer from the URL or headers alone. + [JsonPropertyName("transport")] + public LlmInferenceHttpRequestStartTransport? Transport { get; set; } + + /// Absolute request URL. + [JsonPropertyName("url")] + public string Url { get; set; } = string.Empty; +} + +/// Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestChunkResult +{ +} + +/// A request body chunk or cancellation signal. +[Experimental(Diagnostics.Experimental)] +public sealed class LlmInferenceHttpRequestChunkRequest +{ + /// When true, `data` is base64-encoded bytes. When absent or false, `data` is UTF-8 text. + [JsonPropertyName("binary")] + public bool? Binary { get; set; } + + /// When true, the runtime is cancelling the in-flight request (e.g. upstream consumer aborted). `data` is ignored. 
/// <summary>
/// Transport the runtime would otherwise use for a request: <c>http</c> (plain HTTP and SSE)
/// or <c>websocket</c> (full-duplex channel, one body chunk per WebSocket message).
/// Values compare case-insensitively.
/// </summary>
[Experimental(Diagnostics.Experimental)]
[JsonConverter(typeof(Converter))]
[DebuggerDisplay("{Value,nq}")]
public readonly struct LlmInferenceHttpRequestStartTransport : IEquatable<LlmInferenceHttpRequestStartTransport>
{
    private readonly string? _value;

    /// <summary>Initializes a new instance wrapping <paramref name="value"/>.</summary>
    /// <param name="value">The wire value; must not be null, empty or whitespace.</param>
    /// <exception cref="ArgumentException"><paramref name="value"/> is null, empty or whitespace.</exception>
    [JsonConstructor]
    public LlmInferenceHttpRequestStartTransport(string value)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(value);
        _value = value;
    }

    /// <summary>Gets the wire value; the empty string for a default-initialized instance.</summary>
    public string Value
    {
        get { return _value ?? string.Empty; }
    }

    /// <summary>Plain HTTP or SSE response.</summary>
    public static LlmInferenceHttpRequestStartTransport Http { get; } = new("http");

    /// <summary>Full-duplex WebSocket channel; the <c>binary</c> flag distinguishes text from binary frames.</summary>
    public static LlmInferenceHttpRequestStartTransport Websocket { get; } = new("websocket");

    /// <summary>Returns whether two instances carry the same value, ignoring case.</summary>
    public static bool operator ==(LlmInferenceHttpRequestStartTransport left, LlmInferenceHttpRequestStartTransport right)
    {
        return left.Equals(right);
    }

    /// <summary>Returns whether two instances carry different values, ignoring case.</summary>
    public static bool operator !=(LlmInferenceHttpRequestStartTransport left, LlmInferenceHttpRequestStartTransport right)
    {
        return !left.Equals(right);
    }

    /// <inheritdoc />
    public override bool Equals(object? obj)
    {
        return obj is LlmInferenceHttpRequestStartTransport transport && Equals(transport);
    }

    /// <inheritdoc />
    public bool Equals(LlmInferenceHttpRequestStartTransport other)
    {
        return string.Equals(Value, other.Value, StringComparison.OrdinalIgnoreCase);
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        // Must agree with Equals, which ignores case.
        return StringComparer.OrdinalIgnoreCase.GetHashCode(Value);
    }

    /// <inheritdoc />
    public override string ToString()
    {
        return Value;
    }

    /// <summary>JSON converter that reads and writes the transport as its string value.</summary>
    [EditorBrowsable(EditorBrowsableState.Never)]
    public sealed class Converter : JsonConverter<LlmInferenceHttpRequestStartTransport>
    {
        /// <inheritdoc />
        public override LlmInferenceHttpRequestStartTransport Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) =>
            new(GeneratedStringEnumJson.ReadValue(ref reader, typeToConvert));

        /// <inheritdoc />
        public override void Write(Utf8JsonWriter writer, LlmInferenceHttpRequestStartTransport value, JsonSerializerOptions options) =>
            GeneratedStringEnumJson.WriteValue(writer, value.Value, typeof(LlmInferenceHttpRequestStartTransport));
    }
}
/// <summary>Registers client global API handlers on a JSON-RPC connection.</summary>
internal static class ClientGlobalApiRegistration
{
    /// <summary>
    /// Registers the <c>llmInference.httpRequestStart</c> and
    /// <c>llmInference.httpRequestChunk</c> methods on <paramref name="rpc"/>.
    /// These calls carry no sessionId dispatch key, so the one handler set serves the whole
    /// connection. The handler is looked up on each call; a call made while
    /// <c>handlers.LlmInference</c> is null fails with <see cref="InvalidOperationException"/>.
    /// </summary>
    public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiHandlers handlers)
    {
        Func<LlmInferenceHttpRequestStartRequest, CancellationToken, Task<LlmInferenceHttpRequestStartResult>> onRequestStart =
            async (request, cancellationToken) =>
            {
                var llmInference = ResolveLlmInference();
                return await llmInference.HttpRequestStartAsync(request, cancellationToken);
            };

        Func<LlmInferenceHttpRequestChunkRequest, CancellationToken, Task<LlmInferenceHttpRequestChunkResult>> onRequestChunk =
            async (request, cancellationToken) =>
            {
                var llmInference = ResolveLlmInference();
                return await llmInference.HttpRequestChunkAsync(request, cancellationToken);
            };

        rpc.SetLocalRpcMethod("llmInference.httpRequestStart", onRequestStart, singleObjectParam: true);
        rpc.SetLocalRpcMethod("llmInference.httpRequestChunk", onRequestChunk, singleObjectParam: true);

        // Resolved per call (not captured once) so the property is read at dispatch time.
        ILlmInferenceHandler ResolveLlmInference() =>
            handlers.LlmInference ?? throw new InvalidOperationException("No llmInference client-global handler registered");
    }
}
/// <summary>
/// Cancellation and error coverage for the request handler: the two terminal paths the
/// happy-path session-id and WebSocket tests never reach.
/// </summary>
/// <remarks>
/// <para>
/// Error — the handler throws for an inference request; the turn must end rather than hang.
/// </para>
/// <para>
/// Runtime cancel — the handler blocks an inference request indefinitely; after the session
/// is aborted the handler must observe cancellation on its context token.
/// </para>
/// <para>
/// Non-inference model-layer requests (catalog, policy, model session) are answered with
/// <c>RecordingRequestHandler.BuildNonInferenceResponse</c> so the turn reaches the inference
/// step. No success-path SSE body is needed because neither scenario completes a turn.
/// </para>
/// </remarks>
public class CopilotRequestCancelErrorE2ETests(E2ETestFixture fixture, ITestOutputHelper output)
    : E2ETestBase(fixture, "copilot_request_cancel_error", output)
{
    private CopilotClient CreateClientWith(CopilotRequestHandler handler) =>
        Ctx.CreateClient(options: new CopilotClientOptions
        {
            Connection = RuntimeConnection.ForStdio(),
            RequestHandler = handler,
        });

    [Fact]
    public async Task Reports_A_Thrown_Callback_Error_Instead_Of_Hanging()
    {
        var handler = new ThrowingRequestHandler();
        await using var client = CreateClientWith(handler);
        await client.StartAsync();

        var session = await client.CreateSessionAsync(new SessionConfig
        {
            OnPermissionRequest = PermissionHandler.ApproveAll,
        });

        try
        {
            // The callback throws on inference; the turn surfaces an error (or
            // completes without an assistant message) rather than hanging.
            // Either outcome is acceptable, so the captured exception is ignored.
            await Record.ExceptionAsync(() =>
                session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." }));
        }
        finally
        {
            await session.DisposeAsync();
        }

        Assert.True(handler.InferenceAttempts > 0, "expected the inference callback to be reached and raise");
    }

    [Fact]
    public async Task Observes_Runtime_Cancellation_Of_An_In_Flight_Inference_Request()
    {
        var handler = new CancellingRequestHandler();
        await using var client = CreateClientWith(handler);
        await client.StartAsync();

        var session = await client.CreateSessionAsync(new SessionConfig
        {
            OnPermissionRequest = PermissionHandler.ApproveAll,
        });

        try
        {
            await session.SendAsync(new MessageOptions { Prompt = "Say OK." });
            await WaitForAsync(() => handler.InferenceEntered, TimeSpan.FromSeconds(60));
            await session.AbortAsync();
            await WaitForAsync(() => handler.SawAbort, TimeSpan.FromSeconds(30));
        }
        finally
        {
            await session.DisposeAsync();
        }

        Assert.True(handler.InferenceEntered, "expected the inference callback to be entered");
        Assert.True(handler.SawAbort, "expected the callback to observe runtime cancellation");
    }

    /// <summary>
    /// Polls <paramref name="predicate"/> every 50 ms until it returns true.
    /// </summary>
    /// <param name="predicate">Condition to wait for.</param>
    /// <param name="timeout">Maximum time to keep polling.</param>
    /// <exception cref="TimeoutException">The condition was still false after <paramref name="timeout"/>.</exception>
    private static async Task WaitForAsync(Func<bool> predicate, TimeSpan timeout)
    {
        // Stopwatch is monotonic: a wall-clock adjustment during the wait
        // cannot shorten or extend the deadline the way DateTime.UtcNow can.
        var elapsed = System.Diagnostics.Stopwatch.StartNew();
        while (!predicate())
        {
            if (elapsed.Elapsed > timeout)
            {
                throw new TimeoutException("WaitForAsync timed out");
            }

            await Task.Delay(50);
        }
    }
}
/// <summary>
/// Throws from every inference request to exercise the error-reporting path; non-inference
/// requests get the shared stub response.
/// </summary>
internal sealed class ThrowingRequestHandler : CopilotRequestHandler
{
    private int _inferenceAttempts;

    /// <summary>Number of inference requests that reached the handler.</summary>
    public int InferenceAttempts => Volatile.Read(ref _inferenceAttempts);

    protected override Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx)
    {
        var requestUrl = request.RequestUri!.ToString();
        if (RecordingRequestHandler.IsInferenceUrl(requestUrl))
        {
            // Count first so the test can tell the callback was reached, then
            // fail synchronously (this method is deliberately not async).
            Interlocked.Increment(ref _inferenceAttempts);
            throw new InvalidOperationException("synthetic-callback-transport-failure");
        }

        return Task.FromResult(RecordingRequestHandler.BuildNonInferenceResponse(requestUrl));
    }
}

/// <summary>
/// Blocks every inference request until its context token is cancelled; non-inference
/// requests get the shared stub response.
/// </summary>
internal sealed class CancellingRequestHandler : CopilotRequestHandler
{
    private volatile bool _inferenceEntered;
    private volatile bool _sawAbort;

    /// <summary>True once an inference request has entered the handler.</summary>
    public bool InferenceEntered => _inferenceEntered;

    /// <summary>True once the blocked request observed cancellation.</summary>
    public bool SawAbort => _sawAbort;

    protected override async Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx)
    {
        var requestUrl = request.RequestUri!.ToString();
        var stubResponse = RecordingRequestHandler.IsInferenceUrl(requestUrl);
        if (!stubResponse)
        {
            return RecordingRequestHandler.BuildNonInferenceResponse(requestUrl);
        }

        _inferenceEntered = true;
        try
        {
            // Never produce a response; wait for the runtime to cancel us.
            await Task.Delay(Timeout.Infinite, ctx.CancellationToken).ConfigureAwait(false);
        }
        catch (OperationCanceledException)
        {
            _sawAbort = true;
            throw;
        }

        // Only reachable if the infinite delay completes; present so every
        // path returns a value.
        return RecordingRequestHandler.BuildNonInferenceResponse(requestUrl);
    }
}
/// <summary>
/// Request handler for e2e tests that records every intercepted request (URL plus the session
/// id on its context) and answers each one with a fabricated, well-formed response, so an
/// agent turn completes without any upstream server.
/// </summary>
/// <remarks>
/// All response bodies are raw JSON string literals rather than <c>JsonSerializer</c> output:
/// the test project disables reflection-based serialization on net8.0
/// (<c>JsonSerializerIsReflectionEnabledByDefault=false</c>), so serializing anonymous types
/// would throw at runtime.
/// </remarks>
internal sealed class RecordingRequestHandler : CopilotRequestHandler
{
    internal const string SyntheticText = "OK from the synthetic stream.";

    private static readonly Regex WantsStreamRegex = new("\"stream\"\\s*:\\s*true", RegexOptions.Compiled);

    private readonly ConcurrentQueue<InterceptedRequest> _records = new();

    /// <summary>Every intercepted request, in arrival order.</summary>
    public IReadOnlyCollection<InterceptedRequest> Records => _records;

    /// <summary>Snapshot of the intercepted requests whose URL is an inference endpoint.</summary>
    public IReadOnlyList<InterceptedRequest> InferenceRequests =>
        [.. _records.Where(r => IsInferenceUrl(r.Url))];

    /// <summary>
    /// Records the request, reads its body (if any), and returns the synthetic response
    /// matching the URL.
    /// </summary>
    protected override async Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx)
    {
        var url = request.RequestUri!.ToString();
        _records.Enqueue(new InterceptedRequest(url, ctx.SessionId));

        var bodyText = string.Empty;
        if (request.Content is not null)
        {
#if NET8_0_OR_GREATER
            bodyText = await request.Content.ReadAsStringAsync(ctx.CancellationToken).ConfigureAwait(false);
#else
            bodyText = await request.Content.ReadAsStringAsync().ConfigureAwait(false);
#endif
        }

        return IsInferenceUrl(url)
            ? BuildInferenceResponse(url, bodyText)
            : BuildNonInferenceResponse(url);
    }

    /// <summary>
    /// Returns true when <paramref name="url"/> ends with an inference endpoint path:
    /// <c>/chat/completions</c>, <c>/responses</c>, or <c>/messages</c> (which also covers
    /// <c>/v1/messages</c>). The comparison ignores case.
    /// </summary>
    internal static bool IsInferenceUrl(string url) =>
        url.EndsWith("/chat/completions", StringComparison.OrdinalIgnoreCase)
        || url.EndsWith("/responses", StringComparison.OrdinalIgnoreCase)
        || url.EndsWith("/messages", StringComparison.OrdinalIgnoreCase);

    /// <summary>
    /// Synthesizes a well-formed inference response so the agent turn completes. The body is
    /// streamed (SSE) when the request body contains <c>"stream": true</c>, buffered JSON
    /// otherwise. Any inference URL that is neither <c>/responses</c> nor a streaming
    /// <c>/chat/completions</c> gets the buffered chat-completion JSON.
    /// </summary>
    private static HttpResponseMessage BuildInferenceResponse(string url, string bodyText)
    {
        var wantsStream = WantsStreamRegex.IsMatch(bodyText);

        if (url.Contains("/responses", StringComparison.OrdinalIgnoreCase))
        {
            return wantsStream
                ? Sse(string.Concat(ResponsesStreamEvents))
                : Json(BufferedResponseJson);
        }

        if (wantsStream && url.Contains("/chat/completions", StringComparison.OrdinalIgnoreCase))
        {
            return Sse(string.Concat(ChatCompletionStreamEvents));
        }

        // /chat/completions non-streaming (and any other inference url) — buffered JSON.
        return Json(BufferedChatCompletionJson);
    }

    /// <summary>
    /// Serves the non-inference model-layer requests (catalog, model session, policy).
    /// Shared with the cancel/error e2e handlers so the turn can reach the inference step.
    /// </summary>
    internal static HttpResponseMessage BuildNonInferenceResponse(string url)
    {
        if (url.EndsWith("/models", StringComparison.OrdinalIgnoreCase))
        {
            return Json(ModelCatalogJson);
        }

        if (url.Contains("/models/session", StringComparison.OrdinalIgnoreCase))
        {
            return Json("{}");
        }

        if (url.Contains("/policy", StringComparison.OrdinalIgnoreCase))
        {
            return Json("{\"state\":\"enabled\"}");
        }

        return Json("{}");
    }

    /// <summary>Builds a 200 response with a UTF-8 <c>application/json</c> body.</summary>
    internal static HttpResponseMessage Json(string body) => new(HttpStatusCode.OK)
    {
        Content = new StringContent(body, Encoding.UTF8, "application/json"),
    };

    /// <summary>Builds a 200 response with a UTF-8 <c>text/event-stream</c> body.</summary>
    private static HttpResponseMessage Sse(string body) => new(HttpStatusCode.OK)
    {
        Content = new StringContent(body, Encoding.UTF8, "text/event-stream"),
    };

    private static readonly string[] ResponsesStreamEvents =
    [
        "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"in_progress\",\"output\":[]}}\n\n",
        "event: response.output_item.added\ndata: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[]}}\n\n",
        "event: response.content_part.added\ndata: {\"type\":\"response.content_part.added\",\"output_index\":0,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"text\":\"\"}}\n\n",
        "event: response.output_text.delta\ndata: {\"type\":\"response.output_text.delta\",\"output_index\":0,\"content_index\":0,\"delta\":\"" + SyntheticText + "\"}\n\n",
        "event: response.output_text.done\ndata: {\"type\":\"response.output_text.done\",\"output_index\":0,\"content_index\":0,\"text\":\"" + SyntheticText + "\"}\n\n",
        "event: response.completed\ndata: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + SyntheticText + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}}\n\n",
    ];

    private static readonly string[] ChatCompletionStreamEvents =
    [
        "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"finish_reason\":null}]}\n\n",
        "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"" + SyntheticText + "\"},\"finish_reason\":null}]}\n\n",
        "data: {\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":7,\"total_tokens\":12}}\n\n",
        "data: [DONE]\n\n",
    ];

    private static readonly string BufferedResponseJson =
        "{\"id\":\"resp_stub_1\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + SyntheticText + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}";

    private static readonly string BufferedChatCompletionJson =
        "{\"id\":\"chatcmpl-stub-1\",\"object\":\"chat.completion\",\"created\":1,\"model\":\"claude-sonnet-4.5\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"" + SyntheticText + "\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":7,\"total_tokens\":12}}";

    private const string ModelCatalogJson =
        "{\"data\":[{\"id\":\"claude-sonnet-4.5\",\"name\":\"Claude Sonnet 4.5\",\"object\":\"model\",\"vendor\":\"Anthropic\",\"version\":\"1\",\"preview\":false,\"model_picker_enabled\":true,\"capabilities\":{\"type\":\"chat\",\"family\":\"claude-sonnet-4.5\",\"tokenizer\":\"o200k_base\",\"limits\":{\"max_context_window_tokens\":200000,\"max_output_tokens\":8192},\"supports\":{\"streaming\":true,\"tool_calls\":true,\"parallel_tool_calls\":true,\"vision\":true}}}]}";
}
/// <summary>
/// Asserts that every inference request the recording handler intercepts carries the id of
/// the session that issued it, for both a CAPI session and a BYOK session. The handler alone
/// services every model-layer request, so the session id can only come from the runtime.
/// </summary>
public class CopilotRequestSessionIdE2ETests(E2ETestFixture fixture, ITestOutputHelper output)
    : E2ETestBase(fixture, "llm_inference_session_id", output)
{
    private CopilotClient CreateClientWith(RecordingRequestHandler provider) =>
        Ctx.CreateClient(options: new CopilotClientOptions
        {
            Connection = RuntimeConnection.ForStdio(),
            RequestHandler = provider,
        });

    [Fact]
    public Task Threads_The_Session_Id_Into_A_Capi_Session_Inference_Request() =>
        AssertTurnThreadsSessionIdAsync(() => new SessionConfig
        {
            OnPermissionRequest = PermissionHandler.ApproveAll,
        });

    [Fact]
    public Task Threads_The_Session_Id_Into_A_Byok_Session_Inference_Request() =>
        AssertTurnThreadsSessionIdAsync(() => new SessionConfig
        {
            OnPermissionRequest = PermissionHandler.ApproveAll,
            // BYOK providers require an explicit model id.
            Model = "claude-sonnet-4.5",
            Provider = new ProviderConfig
            {
                Type = "openai",
                WireApi = "responses",
                BaseUrl = "https://byok.invalid/v1",
                ApiKey = "byok-secret",
                ModelId = "claude-sonnet-4.5",
                WireModel = "claude-sonnet-4.5",
            },
        });

    /// <summary>
    /// Runs one "Say OK." turn on a session built from <paramref name="createConfig"/> and
    /// asserts that inference requests were intercepted, that each carries the session's id,
    /// and that the synthetic assistant text came back.
    /// </summary>
    /// <param name="createConfig">Builds the session config; invoked after the client has started.</param>
    private async Task AssertTurnThreadsSessionIdAsync(Func<SessionConfig> createConfig)
    {
        var recorder = new RecordingRequestHandler();
        await using var client = CreateClientWith(recorder);
        await client.StartAsync();

        var session = await client.CreateSessionAsync(createConfig());
        var expectedSessionId = session.SessionId;

        string reply;
        try
        {
            var message = await session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." });
            reply = message?.Data.Content ?? string.Empty;
        }
        finally
        {
            await session.DisposeAsync();
        }

        var inferenceRequests = recorder.InferenceRequests;
        Assert.NotEmpty(inferenceRequests);
        Assert.All(inferenceRequests, r => Assert.Equal(expectedSessionId, r.SessionId));

        // Validate the final assistant response arrived (guards against truncated captures)
        Assert.Contains("OK from the synthetic", reply);
    }
}
+/// A single handler services both
+/// transports against an in-process fake upstream: model-layer GETs and the
+/// single-shot HTTP /responses call are forwarded over HTTP, while the
+/// main turn flows over a real WebSocket opened by a
+/// <see cref="CopilotWebSocketHandler"/>.
+///
+///
+/// This is the regression test for the WebSocket upgrade deadlock: the runtime
+/// blocks the WebSocket connect until it observes the 101 response head, so the
+/// handler must emit it eagerly rather than waiting for the first upstream
+/// message. Without the eager start the turn never completes and this test
+/// times out.
+///
+public class CopilotRequestWebSocketE2ETests(E2ETestFixture fixture, ITestOutputHelper output)
+    : E2ETestBase(fixture, "copilot_request_websocket", output)
+{
+    // Runs one prompt through a client whose request handler forwards both HTTP
+    // and WebSocket traffic to the in-process fake upstream, then checks that
+    // every hook fired and that the upstream's synthetic text reached the reply.
+    [Fact]
+    public async Task Services_A_WebSocket_Turn_End_To_End_Via_The_Request_Handler()
+    {
+        await using var upstream = new FakeCopilotUpstream();
+        var counters = new HandlerCounters();
+        var handler = new ForwardingUpstreamHandler(upstream.BaseUrl, counters);
+
+        // Enable the WebSocket Responses transport in the spawned runtime so the
+        // main agent turn picks the WS path; single-shot calls still go over HTTP
+        // through the same handler.
+        var env = Ctx.GetEnvironment();
+        env["COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES"] = "true";
+
+        await using var client = Ctx.CreateClient(options: new CopilotClientOptions
+        {
+            Connection = RuntimeConnection.ForStdio(),
+            RequestHandler = handler,
+            Environment = env,
+        });
+        await client.StartAsync();
+
+        var session = await client.CreateSessionAsync(new SessionConfig
+        {
+            OnPermissionRequest = PermissionHandler.ApproveAll,
+        });
+
+        string content;
+        try
+        {
+            var msg = await session.SendAndWaitAsync(new MessageOptions { Prompt = "Say OK." });
+            content = msg?.Data.Content ?? string.Empty;
+        }
+        finally
+        {
+            // Dispose the session even when the turn throws.
+            await session.DisposeAsync();
+        }
+
+        // The HTTP hooks fired — the runtime issued model-layer GETs (catalog,
+        // policy) and possibly a single-shot inference, all forwarded over HTTP.
+        Assert.True(counters.HttpRequests > 0, "expected SendRequestAsync to fire");
+
+        // The WebSocket hooks fired — the main agent turn went over the WS path
+        // and we observed messages in both directions.
+        Assert.True(counters.WsRequestMessages > 0, "expected SendRequestMessageAsync (runtime -> upstream) to fire");
+        Assert.True(counters.WsResponseMessages > 0, "expected SendResponseMessageAsync (upstream -> runtime) to fire");
+        Assert.True(upstream.WsRequestMessageCount > 0, "expected upstream WS to receive request messages");
+
+        // The synthetic content surfaced in the assistant turn — proves the full
+        // chain (runtime -> handler -> upstream -> handler -> runtime) over the
+        // WebSocket transport is intact.
+        // Validate the final assistant response arrived (guards against truncated captures)
+        Assert.Contains("OK from synthetic", content);
+    }
+}
+
+/// <summary>Cross-direction message counters shared with the test assertions.</summary>
+internal sealed class HandlerCounters
+{
+    // Plain public fields (not properties) because the handlers bump them with
+    // Interlocked.Increment(ref ...), which needs a field reference.
+    public int HttpRequests;
+    public int WsRequestMessages;
+    public int WsResponseMessages;
+}
+
+///
+/// A <see cref="CopilotRequestHandler"/> that points every intercepted request at
+/// the in-process <see cref="FakeCopilotUpstream"/>: HTTP requests are rewritten
+/// and forwarded by the base class, and WebSocket connections are opened against
+/// the rewritten URL via a counting <see cref="CopilotWebSocketHandler"/>.
+///
+// NOTE(review): generic type arguments (e.g. on the Task return types and on
+// Task.FromResult below) appear to have been stripped from this listing —
+// confirm against the real file before relying on the signatures shown here.
+internal sealed class ForwardingUpstreamHandler(string upstreamBaseUrl, HandlerCounters counters) : CopilotRequestHandler
+{
+    // Parsed once; supplies the scheme/host/port every request is redirected to.
+    private readonly Uri _upstream = new(upstreamBaseUrl);
+
+    // Counts the HTTP request, retargets it at the fake upstream, then lets the
+    // base class perform the actual send.
+    protected override Task SendRequestAsync(HttpRequestMessage request, CopilotRequestContext ctx)
+    {
+        Interlocked.Increment(ref counters.HttpRequests);
+        request.RequestUri = Rewrite(request.RequestUri!);
+        return base.SendRequestAsync(request, ctx);
+    }
+
+    // Returns a counting WebSocket handler aimed at the rewritten URL; the
+    // handler is created synchronously and wrapped in a completed task.
+    protected override Task OpenWebSocketAsync(CopilotRequestContext ctx)
+    {
+        var wsUrl = Rewrite(new Uri(ctx.Url)).ToString();
+        return Task.FromResult(new CountingForwardingWebSocketHandler(ctx, wsUrl, counters));
+    }
+
+    // Swaps scheme, host and port for the upstream's while keeping the rest of
+    // the original URI (path, query) as UriBuilder copied it.
+    private Uri Rewrite(Uri original) => new UriBuilder(original)
+    {
+        Scheme = _upstream.Scheme,
+        Host = _upstream.Host,
+        Port = _upstream.Port,
+    }.Uri;
+}
+
+///
+/// A pass-through forwarding handler that counts messages in both directions.
+///
+internal sealed class CountingForwardingWebSocketHandler(
+    CopilotRequestContext context,
+    string url,
+    HandlerCounters counters)
+    : CopilotWebSocketHandler(context, url)
+{
+    // Runtime -> upstream direction: count, then defer to the base implementation.
+    public override Task SendRequestMessageAsync(CopilotWebSocketMessage message)
+    {
+        Interlocked.Increment(ref counters.WsRequestMessages);
+        return base.SendRequestMessageAsync(message);
+    }
+
+    // Upstream -> runtime direction: count, then defer to the base implementation.
+    public override Task SendResponseMessageAsync(CopilotWebSocketMessage message)
+    {
+        Interlocked.Increment(ref counters.WsResponseMessages);
+        return base.SendResponseMessageAsync(message);
+    }
+}
+
+///
+/// In-process upstream that speaks the CAPI shapes the runtime needs: model
+/// catalog (advertising the WebSocket /responses endpoint), policy, a
+/// single-shot HTTP /responses SSE stream, and a WebSocket endpoint at
+/// /responses that answers each inbound response.create with the
+/// ordered /responses events the reducer expects.
+///
+internal sealed class FakeCopilotUpstream : IAsyncDisposable
+{
+    // Fixed reply texts; the HTTP and WS paths differ so a test can tell which
+    // transport produced the assistant content.
+    private const string HttpText = "OK from synthetic HTTP upstream.";
+    private const string WsText = "OK from synthetic WS upstream.";
+
+    private readonly HttpListener _listener = new();
+    private readonly CancellationTokenSource _cts = new();
+    private readonly Task _loop;
+    private int _wsRequestMessages;
+
+    /// <summary>Base URL (loopback, with trailing slash) the listener is bound to.</summary>
+    public string BaseUrl { get; }
+
+    /// <summary>Number of WebSocket messages received from the client so far.</summary>
+    public int WsRequestMessageCount => Volatile.Read(ref _wsRequestMessages);
+
+    /// <summary>
+    /// Binds an <see cref="HttpListener"/> to a free loopback port and starts the
+    /// background accept loop.
+    /// </summary>
+    public FakeCopilotUpstream()
+    {
+        var port = GetFreePort();
+        BaseUrl = $"http://127.0.0.1:{port}/";
+        _listener.Prefixes.Add(BaseUrl);
+        _listener.Start();
+        _loop = Task.Run(() => AcceptLoopAsync(_cts.Token), _cts.Token);
+    }
+
+    // Accepts connections until cancellation or until the listener is stopped
+    // (GetContextAsync then throws, which ends the loop). Each context is
+    // handled on its own task so a long-lived WebSocket does not block accepts.
+    private async Task AcceptLoopAsync(CancellationToken ct)
+    {
+        while (!ct.IsCancellationRequested)
+        {
+            HttpListenerContext context;
+            try
+            {
+                context = await _listener.GetContextAsync().ConfigureAwait(false);
+            }
+            catch
+            {
+                break;
+            }
+
+            _ = Task.Run(() => HandleContextAsync(context, ct), ct);
+        }
+    }
+
+    // Dispatches one accepted context to the WebSocket or plain-HTTP handler.
+    private async Task HandleContextAsync(HttpListenerContext context, CancellationToken ct)
+    {
+        try
+        {
+            if (context.Request.IsWebSocketRequest)
+            {
+                await HandleWebSocketAsync(context, ct).ConfigureAwait(false);
+            }
+            else
+            {
+                await HandleHttpAsync(context, ct).ConfigureAwait(false);
+            }
+        }
+        catch
+        {
+            // Best-effort: the runtime tears connections down as turns complete.
+        }
+    }
+
+    // Accepts the upgrade, then answers every inbound text message with the full
+    // ordered event sequence, one event per WebSocket text message.
+    private async Task HandleWebSocketAsync(HttpListenerContext context, CancellationToken ct)
+    {
+        var wsContext = await context.AcceptWebSocketAsync(subProtocol: null).ConfigureAwait(false);
+        // Dispose the socket on every exit path, including exceptions.
+        using var socket = wsContext.WebSocket;
+        var buffer = new byte[16 * 1024];
+
+        while (socket.State == WebSocketState.Open && !ct.IsCancellationRequested)
+        {
+            var message = await ReceiveTextAsync(socket, buffer, ct).ConfigureAwait(false);
+            if (message is null)
+            {
+                // Peer sent a close frame.
+                break;
+            }
+
+            Interlocked.Increment(ref _wsRequestMessages);
+
+            foreach (var (_, json) in ResponseEvents(WsText, "resp_stub_ws"))
+            {
+                var bytes = Encoding.UTF8.GetBytes(json);
+                await socket.SendAsync(
+                    new ArraySegment<byte>(bytes),
+                    WebSocketMessageType.Text,
+                    endOfMessage: true,
+                    ct).ConfigureAwait(false);
+            }
+        }
+
+        if (socket.State == WebSocketState.Open || socket.State == WebSocketState.CloseReceived)
+        {
+            try
+            {
+                await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).ConfigureAwait(false);
+            }
+            catch
+            {
+                // Already torn down.
+            }
+        }
+    }
+
+    // Reads one complete message (reassembling fragments) and decodes it as
+    // UTF-8. Returns null when a close frame is received instead.
+    private static async Task<string?> ReceiveTextAsync(WebSocket socket, byte[] buffer, CancellationToken ct)
+    {
+        using var assembled = new MemoryStream();
+        WebSocketReceiveResult result;
+        do
+        {
+            result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), ct).ConfigureAwait(false);
+            if (result.MessageType == WebSocketMessageType.Close)
+            {
+                return null;
+            }
+
+            assembled.Write(buffer, 0, result.Count);
+        }
+        while (!result.EndOfMessage);
+
+        return Encoding.UTF8.GetString(assembled.ToArray());
+    }
+
+    // Drains the request body, then replies 200 with a canned payload chosen by
+    // the (lower-cased) request path.
+    private static async Task HandleHttpAsync(HttpListenerContext context, CancellationToken ct)
+    {
+        if (context.Request.HasEntityBody)
+        {
+            using var input = context.Request.InputStream;
+            var drain = new byte[8 * 1024];
+            while (await input.ReadAsync(drain.AsMemory(), ct).ConfigureAwait(false) > 0)
+            {
+                // Discard the request body; the synthetic response is fixed.
+            }
+        }
+
+        var path = context.Request.Url!.AbsolutePath.ToLowerInvariant();
+        string contentType = "application/json";
+        string body;
+
+        if (path.EndsWith("/models", StringComparison.Ordinal))
+        {
+            body = ModelCatalogJson;
+        }
+        else if (path.Contains("/models/session", StringComparison.Ordinal))
+        {
+            body = "{}";
+        }
+        else if (path.Contains("/policy", StringComparison.Ordinal))
+        {
+            body = "{\"state\":\"enabled\"}";
+        }
+        else if (path.EndsWith("/responses", StringComparison.Ordinal))
+        {
+            contentType = "text/event-stream";
+            body = BuildSse(HttpText, "resp_stub_http");
+        }
+        else
+        {
+            body = "{}";
+        }
+
+        var bytes = Encoding.UTF8.GetBytes(body);
+        context.Response.StatusCode = 200;
+        context.Response.ContentType = contentType;
+        context.Response.ContentLength64 = bytes.Length;
+        await context.Response.OutputStream.WriteAsync(bytes.AsMemory(), ct).ConfigureAwait(false);
+        context.Response.OutputStream.Close();
+    }
+
+    // Renders the event sequence as a server-sent-events body
+    // ("event: <type>\ndata: <json>\n\n" per event).
+    private static string BuildSse(string text, string id)
+    {
+        var sb = new StringBuilder();
+        foreach (var (type, json) in ResponseEvents(text, id))
+        {
+            sb.Append("event: ").Append(type).Append("\ndata: ").Append(json).Append("\n\n");
+        }
+
+        return sb.ToString();
+    }
+
+    // The ordered events for one synthetic response. text and id are spliced
+    // into the JSON without escaping, so callers must pass JSON-safe constants.
+    private static (string Type, string Json)[] ResponseEvents(string text, string id) =>
+    [
+        ("response.created",
+            "{\"type\":\"response.created\",\"response\":{\"id\":\"" + id + "\",\"object\":\"response\",\"status\":\"in_progress\",\"output\":[]}}"),
+        ("response.output_item.added",
+            "{\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[]}}"),
+        ("response.content_part.added",
+            "{\"type\":\"response.content_part.added\",\"output_index\":0,\"content_index\":0,\"part\":{\"type\":\"output_text\",\"text\":\"\"}}"),
+        ("response.output_text.delta",
+            "{\"type\":\"response.output_text.delta\",\"output_index\":0,\"content_index\":0,\"delta\":\"" + text + "\"}"),
+        ("response.output_text.done",
+            "{\"type\":\"response.output_text.done\",\"output_index\":0,\"content_index\":0,\"text\":\"" + text + "\"}"),
+        ("response.completed",
+            "{\"type\":\"response.completed\",\"response\":{\"id\":\"" + id + "\",\"object\":\"response\",\"status\":\"completed\",\"output\":[{\"id\":\"msg_1\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"" + text + "\"}]}],\"usage\":{\"input_tokens\":5,\"output_tokens\":7,\"total_tokens\":12}}}"),
+    ];
+
+    // Single-model catalog; supported_endpoints lists both "/responses" and
+    // "ws:/responses".
+    private const string ModelCatalogJson =
+        "{\"data\":[{\"id\":\"claude-sonnet-4.5\",\"name\":\"Claude Sonnet 4.5\",\"object\":\"model\",\"vendor\":\"Anthropic\",\"version\":\"1\",\"preview\":false,\"model_picker_enabled\":true,\"supported_endpoints\":[\"/responses\",\"ws:/responses\"],\"capabilities\":{\"type\":\"chat\",\"family\":\"claude-sonnet-4.5\",\"tokenizer\":\"o200k_base\",\"limits\":{\"max_context_window_tokens\":200000,\"max_output_tokens\":8192},\"supports\":{\"streaming\":true,\"tool_calls\":true,\"parallel_tool_calls\":true,\"vision\":true}}}]}";
+
+    // Asks the OS for an ephemeral loopback port by briefly binding a
+    // TcpListener to port 0. The probe is released before HttpListener binds,
+    // so another process could in principle take the port in between.
+    private static int GetFreePort()
+    {
+        using var probe = new TcpListener(IPAddress.Loopback, 0);
+        probe.Start();
+        return ((IPEndPoint)probe.LocalEndpoint).Port;
+    }
+
+    /// <summary>
+    /// Cancels in-flight work, stops the listener, and waits for the accept loop
+    /// to unwind. Shutdown errors are swallowed.
+    /// </summary>
+    public async ValueTask DisposeAsync()
+    {
+        _cts.Cancel();
+        try
+        {
+            _listener.Stop();
+            _listener.Close();
+        }
+        catch
+        {
+            // Already stopped.
+        }
+
+        try
+        {
+            await _loop.ConfigureAwait(false);
+        }
+        catch
+        {
+            // Accept loop unwinds on listener shutdown.
+        }
+
+        _cts.Dispose();
+    }
+}
+
+#endif
diff --git a/go/client.go b/go/client.go
index af9044ad9..29dc98427 100644
--- a/go/client.go
+++ b/go/client.go
@@ -371,6 +371,15 @@ func (c *Client) Start(ctx context.Context) error {
 		}
 	}
 
+	// If a request handler was configured, register as the inference provider.
+ if c.options.RequestHandler != nil { + if _, err := c.RPC.LlmInference.SetProvider(ctx); err != nil { + killErr := c.killProcess() + c.state = stateError + return errors.Join(err, killErr) + } + } + c.state = stateConnected return nil } @@ -418,7 +427,6 @@ func (c *Client) Stop() error { c.startStopMux.Lock() defer c.startStopMux.Unlock() - runtimeShutdownCompleted := false if c.process != nil && !c.isExternalServer && c.RPC != nil { rpcClient := c.RPC runtimeShutdownStart := time.Now() @@ -434,7 +442,6 @@ func (c *Client) Stop() error { c.logDebugTiming(runtimeShutdownStart, "CopilotClient.Stop runtime shutdown failed") errs = append(errs, fmt.Errorf("failed to gracefully shut down runtime: %w", err)) } else { - runtimeShutdownCompleted = true c.logDebugTiming(runtimeShutdownStart, "CopilotClient.Stop runtime shutdown complete") } case <-time.After(runtimeShutdownTimeout): @@ -443,29 +450,14 @@ func (c *Client) Stop() error { } } - // Give runtime.shutdown a bounded window to let the child exit on its own - // before falling back to killing it. + // The runtime completes all cleanup before responding to runtime.shutdown + // and then leaves termination to us; it deliberately keeps its JSON-RPC + // server alive to send the response and never self-exits. Waiting for a + // self-exit that will never come just wastes time, so terminate the child + // immediately and only wait to reap it. 
if c.process != nil && !c.isExternalServer { - if c.processDone != nil { - if runtimeShutdownCompleted { - select { - case <-c.processDone: - c.osProcess.Swap(nil) - c.process = nil - case <-time.After(runtimeShutdownTimeout): - if err := c.killProcessAndWait(); err != nil { - errs = append(errs, err) - } - } - } else { - if err := c.killProcessAndWait(); err != nil { - errs = append(errs, err) - } - } - } else { - if err := c.killProcessAndWait(); err != nil { - errs = append(errs, err) - } + if err := c.killProcessAndWait(); err != nil { + errs = append(errs, err) } } c.process = nil @@ -2003,6 +1995,15 @@ func (c *Client) setupNotificationHandler() { } return session.clientSessionAPIs }) + if c.options.RequestHandler != nil { + adapter := newCopilotRequestAdapter(c.options.RequestHandler, func() *rpc.ServerLlmInferenceAPI { + if c.RPC == nil { + return nil + } + return c.RPC.LlmInference + }) + rpc.RegisterClientGlobalAPIHandlers(c.client, &rpc.ClientGlobalAPIHandlers{LlmInference: adapter}) + } } func (c *Client) handleSessionEvent(req sessionEventRequest) { diff --git a/go/copilot_request_handler.go b/go/copilot_request_handler.go new file mode 100644 index 000000000..20621ee1a --- /dev/null +++ b/go/copilot_request_handler.go @@ -0,0 +1,848 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package copilot + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/coder/websocket" + "github.com/github/copilot-sdk/go/rpc" +) + +// Hop-by-hop and length headers the transport recomputes; forwarding them +// verbatim corrupts the request. 
+var forbiddenRequestHeaders = map[string]struct{}{
+	"host":              {},
+	"connection":        {},
+	"content-length":    {},
+	"transfer-encoding": {},
+	"keep-alive":        {},
+	"upgrade":           {},
+	"proxy-connection":  {},
+	"te":                {},
+	"trailer":           {},
+}
+
+// isForbiddenRequestHeader reports whether name (compared case-insensitively)
+// must not be copied onto an outgoing request: it is either listed in
+// forbiddenRequestHeaders or starts with "sec-websocket-".
+func isForbiddenRequestHeader(name string) bool {
+	lower := strings.ToLower(name)
+	if _, ok := forbiddenRequestHeaders[lower]; ok {
+		return true
+	}
+	return strings.HasPrefix(lower, "sec-websocket-")
+}
+
+// sharedHTTPTransport is a clone of http.DefaultTransport with
+// DisableCompression set. It is built once at package initialisation.
+// The type assertion panics if http.DefaultTransport has been replaced with a
+// value that is not an *http.Transport.
+var sharedHTTPTransport = func() http.RoundTripper {
+	t := http.DefaultTransport.(*http.Transport).Clone()
+	t.DisableCompression = true
+	return t
+}()
+
+// CopilotRequestContext is the per-request context handed to every
+// [CopilotRequestHandler] seam.
+type CopilotRequestContext struct {
+	RequestID string
+	SessionID string
+	// Transport is "http" (covering plain HTTP and SSE) or "websocket".
+	Transport string
+	Method    string
+	URL       string
+	Headers   http.Header
+	// Body yields request body frames as they arrive from the runtime. The
+	// channel is closed when the body ends or the request is cancelled. For
+	// WebSocket requests each frame's Binary flag distinguishes a binary frame
+	// from a UTF-8 text frame; for HTTP it is always a body byte chunk.
+	Body <-chan CopilotWebSocketMessage
+	// Context is cancelled when the runtime cancels this in-flight request.
+	Context context.Context
+}
+
+// CopilotWebSocketCloseStatus is the terminal status for a callback-owned
+// WebSocket connection.
+type CopilotWebSocketCloseStatus struct {
+	Description string
+	ErrorCode   string
+	Err         error
+}
+
+// CopilotWebSocketMessage is a single WebSocket frame exchanged through the
+// handler seam. Binary distinguishes a binary frame from a UTF-8 text frame.
+type CopilotWebSocketMessage struct {
+	Data   []byte
+	Binary bool
+}
+
+// Text decodes the frame payload as a UTF-8 string.
+// The bytes are converted as-is; no UTF-8 validation is performed.
+func (m CopilotWebSocketMessage) Text() string { return string(m.Data) }
+
+// NewTextMessage creates a text-frame message from a UTF-8 string.
+func NewTextMessage(text string) CopilotWebSocketMessage {
+	// []byte(text) copies, so the message does not alias the caller's string.
+	return CopilotWebSocketMessage{Data: []byte(text), Binary: false}
+}
+
+// NewBinaryMessage creates a binary-frame message from raw bytes.
+// The slice is stored as-is (not copied), so later writes to data by the
+// caller are visible through the message.
+func NewBinaryMessage(data []byte) CopilotWebSocketMessage {
+	return CopilotWebSocketMessage{Data: data, Binary: true}
+}
+
+// CopilotRequestHandler is the idiomatic handler for intercepting or replacing
+// LLM inference requests. HTTP requests are forwarded through Transport (an
+// [http.RoundTripper]); supply a custom RoundTripper to mutate the request,
+// post-process the response, or replace the call entirely. WebSocket requests
+// are serviced by OpenWebSocket; supply one to return a custom handler.
+//
+// The default behaviour (both fields nil) transparently forwards HTTP through a
+// shared transport and opens a forwarding WebSocket connection to the runtime's
+// original URL.
+type CopilotRequestHandler struct {
+	// Transport forwards HTTP requests. When nil a shared default transport is
+	// used. RoundTrip is called directly, so redirects are not followed.
+	Transport http.RoundTripper
+	// OpenWebSocket returns a per-connection WebSocket handler. When nil a
+	// transparent [ForwardingCopilotWebSocketHandler] to the request URL is opened.
+	OpenWebSocket func(ctx *CopilotRequestContext) (CopilotWebSocketHandler, error)
+}
+
+// WebSocketResponseWriter forwards upstream→runtime WebSocket messages back
+// into the runtime response. A [CopilotWebSocketHandler] receives one in
+// [CopilotWebSocketHandler.Open].
+type WebSocketResponseWriter interface {
+	// SendText forwards an upstream text message to the runtime.
+	SendText(data []byte) error
+	// SendBinary forwards an upstream binary message to the runtime.
+	SendBinary(data []byte) error
+}
+
+// CopilotWebSocketHandler is a per-connection WebSocket handler returned by
+// [CopilotRequestHandler.OpenWebSocket].
+// The default implementation is
+// [ForwardingCopilotWebSocketHandler]; a full transport replacement implements
+// this interface directly.
+type CopilotWebSocketHandler interface {
+	// Open establishes the connection and starts forwarding upstream→runtime
+	// messages into resp. It must not block. ctx is cancelled on teardown.
+	Open(ctx context.Context, resp WebSocketResponseWriter) error
+	// SendRequestMessage forwards one runtime→upstream message.
+	SendRequestMessage(ctx context.Context, msg CopilotWebSocketMessage) error
+	// Done is closed when the upstream connection completes (closed or errored).
+	Done() <-chan struct{}
+	// Err returns the terminal error after Done is closed, or nil on clean close.
+	Err() error
+	// Close tears down the connection.
+	Close() error
+}
+
+// copilotContextKey is used to attach [CopilotRequestContext] to an
+// [http.Request] so custom [http.RoundTripper] implementations can access
+// metadata (e.g. SessionID) without additional parameters.
+// It is an unexported empty struct type, so no other package can construct a
+// colliding context key.
+type copilotContextKey struct{}
+
+// RequestContextFrom returns the [CopilotRequestContext] attached to an
+// http.Request by the adapter, or nil if not present. Call this from a custom
+// [http.RoundTripper] to access metadata such as SessionID.
+func RequestContextFrom(r *http.Request) *CopilotRequestContext {
+	// A failed type assertion leaves v nil, which is the documented "not present" result.
+	v, _ := r.Context().Value(copilotContextKey{}).(*CopilotRequestContext)
+	return v
+}
+
+// handle dispatches one request to the WebSocket or HTTP path based on
+// rctx.Transport; anything other than "websocket" is treated as HTTP.
+func (h *CopilotRequestHandler) handle(rctx *CopilotRequestContext, sink *responseSink) error {
+	if rctx.Transport == "websocket" {
+		return h.handleWebSocket(rctx, sink)
+	}
+	return h.handleHTTP(rctx, sink)
+}
+
+// roundTripper returns the configured Transport, or the shared default when
+// none is set.
+func (h *CopilotRequestHandler) roundTripper() http.RoundTripper {
+	if h.Transport != nil {
+		return h.Transport
+	}
+	return sharedHTTPTransport
+}
+
+// handleHTTP builds the outgoing request, performs a single RoundTrip (no
+// redirect following), and streams the response into sink.
+func (h *CopilotRequestHandler) handleHTTP(rctx *CopilotRequestContext, sink *responseSink) error {
+	httpReq, err := buildHTTPRequest(rctx)
+	if err != nil {
+		return err
+	}
+	resp, err := h.roundTripper().RoundTrip(httpReq)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	return streamResponseToSink(resp, sink)
+}
+
+// buildHTTPRequest buffers the whole request body from rctx.Body and builds an
+// *http.Request bound to rctx.Context. The body is dropped for GET and HEAD,
+// and headers for which isForbiddenRequestHeader is true are not copied.
+func buildHTTPRequest(rctx *CopilotRequestContext) (*http.Request, error) {
+	body := drainBody(rctx.Body)
+	method := strings.ToUpper(rctx.Method)
+	// Left as a nil interface when there is no body to send.
+	var bodyReader io.Reader
+	if len(body) > 0 && method != http.MethodGet && method != http.MethodHead {
+		bodyReader = bytes.NewReader(body)
+	}
+	httpReq, err := http.NewRequestWithContext(rctx.Context, method, rctx.URL, bodyReader)
+	if err != nil {
+		return nil, err
+	}
+	// Attach rctx so custom RoundTripper implementations can read metadata
+	// (e.g. SessionID) via [RequestContextFrom].
+	httpReq = httpReq.WithContext(context.WithValue(httpReq.Context(), copilotContextKey{}, rctx))
+	for name, values := range rctx.Headers {
+		if isForbiddenRequestHeader(name) {
+			continue
+		}
+		for _, v := range values {
+			httpReq.Header.Add(name, v)
+		}
+	}
+	return httpReq, nil
+}
+
+// drainBody concatenates every frame received on ch until the channel is
+// closed. It blocks until then.
+func drainBody(ch <-chan CopilotWebSocketMessage) []byte {
+	var buf bytes.Buffer
+	for frame := range ch {
+		buf.Write(frame.Data)
+	}
+	return buf.Bytes()
+}
+
+// streamResponseToSink emits the response head, then forwards the body in
+// chunks of up to 32 KiB. A read error other than io.EOF is reported through
+// the sink's error path; a clean EOF ends the response.
+func streamResponseToSink(resp *http.Response, sink *responseSink) error {
+	if err := sink.start(resp.StatusCode, statusText(resp), cloneHeader(resp.Header)); err != nil {
+		return err
+	}
+	buf := make([]byte, 32*1024)
+	for {
+		n, readErr := resp.Body.Read(buf)
+		if n > 0 {
+			// Copy: buf is reused on the next iteration.
+			frame := make([]byte, n)
+			copy(frame, buf[:n])
+			if err := sink.writeText(frame); err != nil {
+				return err
+			}
+		}
+		if readErr == io.EOF {
+			break
+		}
+		if readErr != nil {
+			return sink.sinkError(readErr.Error(), "")
+		}
+	}
+	return sink.end()
+}
+
+// statusText returns resp.Status with the leading numeric code removed and
+// surrounding whitespace trimmed (e.g. "200 OK" becomes "OK").
+func statusText(resp *http.Response) string {
+	return strings.TrimSpace(strings.TrimPrefix(resp.Status, strconv.Itoa(resp.StatusCode)))
+}
+
+// cloneHeader returns a copy of h whose value slices do not alias the originals.
+func cloneHeader(h http.Header) http.Header {
+	out := http.Header{}
+	for k, vs := range h {
+		out[k] = append([]string(nil), vs...)
+	}
+	return out
+}
+
+// handleWebSocket services one WebSocket request: it obtains a handler, emits
+// the 101 head, opens the upstream, pumps runtime→upstream frames on a
+// goroutine, and finalises the response when the upstream finishes, the
+// request body ends, or the request context is cancelled.
+func (h *CopilotRequestHandler) handleWebSocket(rctx *CopilotRequestContext, sink *responseSink) error {
+	var handler CopilotWebSocketHandler
+	var err error
+	if h.OpenWebSocket != nil {
+		handler, err = h.OpenWebSocket(rctx)
+	} else {
+		handler = NewForwardingCopilotWebSocketHandler(rctx.URL, rctx.Headers)
+	}
+	if err != nil {
+		return err
+	}
+	// A custom OpenWebSocket may return (nil, nil). Reject it here, before the
+	// 101 head is sent, instead of panicking on the first method call below.
+	if handler == nil {
+		return fmt.Errorf("CopilotRequestHandler.OpenWebSocket returned a nil handler")
+	}
+
+	writer := &wsResponseWriter{sink: sink}
+	// Emit the 101 upgrade head eagerly — the runtime gates connect_via_callback
+	// on receiving httpResponseStart/101 before sending request chunks; a lazy
+	// first-write start deadlocks until timeout.
+	if err := writer.start(); err != nil {
+		return err
+	}
+	if err := handler.Open(rctx.Context, writer); err != nil {
+		return writer.fail(err.Error(), "")
+	}
+	defer func() { _ = handler.Close() }()
+
+	// Pump runtime→upstream frames until the body channel closes, a send
+	// fails, or the request context is cancelled.
+	clientDone := make(chan struct{})
+	go func() {
+		defer close(clientDone)
+		for {
+			select {
+			case frame, ok := <-rctx.Body:
+				if !ok {
+					return
+				}
+				if err := handler.SendRequestMessage(rctx.Context, frame); err != nil {
+					return
+				}
+			case <-rctx.Context.Done():
+				return
+			}
+		}
+	}()
+
+	select {
+	case <-handler.Done():
+		// Upstream finished first.
+		if e := handler.Err(); e != nil {
+			return writer.fail(e.Error(), "")
+		}
+		return writer.end()
+	case <-clientDone:
+		// Runtime side finished first: close the upstream and wait for it to
+		// report completion before finalising.
+		_ = handler.Close()
+		<-handler.Done()
+		if e := handler.Err(); e != nil {
+			return writer.fail(e.Error(), "")
+		}
+		return writer.end()
+	case <-rctx.Context.Done():
+		return writer.fail("Request cancelled by runtime", "cancelled")
+	}
+}
+
+// wsResponseWriter serialises WebSocket response writes into the sink.
+// Once end or fail has run, later sends, ends and fails are silent no-ops.
+type wsResponseWriter struct {
+	mu        sync.Mutex
+	sink      *responseSink
+	started   bool
+	completed bool
+}
+
+// start emits the 101 response head once; repeated calls are no-ops.
+func (w *wsResponseWriter) start() error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	if w.started {
+		return nil
+	}
+	w.started = true
+	return w.sink.start(101, "", http.Header{})
+}
+
+// SendText forwards a text frame to the sink unless the response is complete.
+func (w *wsResponseWriter) SendText(data []byte) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	if w.completed {
+		return nil
+	}
+	return w.sink.writeText(data)
+}
+
+// SendBinary forwards a binary frame to the sink unless the response is complete.
+func (w *wsResponseWriter) SendBinary(data []byte) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	if w.completed {
+		return nil
+	}
+	return w.sink.writeBinary(data)
+}
+
+// end marks the response complete and ends the sink; it runs at most once.
+func (w *wsResponseWriter) end() error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	if w.completed {
+		return nil
+	}
+	w.completed = true
+	return w.sink.end()
+}
+
+// fail marks the response complete and reports message/code through the
+// sink's error path; it is a no-op once the response is complete.
+func (w *wsResponseWriter) fail(message string, code string) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+	if w.completed {
+		return nil
+	}
+	w.completed = true
+	return w.sink.sinkError(message, code)
+}
+
+//
+// ForwardingCopilotWebSocketHandler is the default [CopilotWebSocketHandler]:
+// it dials the real upstream and runs a receive loop forwarding upstream→runtime
+// messages. Set OnSendRequestMessage / OnSendResponseMessage to observe,
+// transform, or drop messages in either direction.
+type ForwardingCopilotWebSocketHandler struct {
+	// URL is the upstream WebSocket endpoint to dial.
+	URL string
+	// Headers are the handshake headers to send with the dial.
+	Headers http.Header
+	// OnSendRequestMessage observes or transforms each runtime→upstream frame.
+	// The frame type (text vs binary) is available via the message's Binary
+	// field and may be changed in the returned message. Return nil to drop the
+	// frame.
+	OnSendRequestMessage func(msg CopilotWebSocketMessage) *CopilotWebSocketMessage
+	// OnSendResponseMessage observes or transforms each upstream→runtime frame.
+	// The frame type (text vs binary) is available via the message's Binary
+	// field and may be changed in the returned message. Return nil to drop the
+	// frame.
+	OnSendResponseMessage func(msg CopilotWebSocketMessage) *CopilotWebSocketMessage
+
+	// Unexported connection state.
+	conn      *websocket.Conn         // upstream connection
+	resp      WebSocketResponseWriter // destination for upstream→runtime frames
+	done      chan struct{}           // completion signal exposed through Done
+	err       error                   // terminal error exposed through Err
+	closeOnce sync.Once               // NOTE(review): presumably guards Close against repeat calls — confirm
+}
+
+// NewForwardingCopilotWebSocketHandler creates a forwarding handler targeting
+// url with the given handshake headers.
+func NewForwardingCopilotWebSocketHandler(url string, headers http.Header) *ForwardingCopilotWebSocketHandler {
+	return &ForwardingCopilotWebSocketHandler{URL: url, Headers: headers, done: make(chan struct{})}
+}
+
+// Open dials the upstream with the filtered handshake headers, removes the
+// read-size limit, and starts the receive loop on its own goroutine. It
+// returns the dial error, if any, without starting the loop.
+func (f *ForwardingCopilotWebSocketHandler) Open(ctx context.Context, resp WebSocketResponseWriter) error {
+	f.resp = resp
+	// Allow handlers built as struct literals rather than via the constructor.
+	if f.done == nil {
+		f.done = make(chan struct{})
+	}
+	opts := &websocket.DialOptions{HTTPHeader: f.dialHeaders()}
+	conn, _, err := websocket.Dial(ctx, f.URL, opts)
+	if err != nil {
+		return err
+	}
+	conn.SetReadLimit(-1)
+	f.conn = conn
+	go f.receiveLoop(ctx)
+	return nil
+}
+
+// dialHeaders returns a copy of f.Headers without the entries for which
+// isForbiddenRequestHeader is true.
+func (f *ForwardingCopilotWebSocketHandler) dialHeaders() http.Header {
+	out := http.Header{}
+	for name, values := range f.Headers {
+		if isForbiddenRequestHeader(name) {
+			continue
+		}
+		for _, v := range values {
+			out.Add(name, v)
+		}
+	}
+	return out
+}
+
+// receiveLoop reads upstream frames until a read fails, passing each through
+// OnSendResponseMessage (when set) and on to the response writer. It records
+// the terminal error and closes f.done on exit. A normal-closure or going-away
+// close status, or a cancelled ctx, counts as a clean finish (nil error).
+func (f *ForwardingCopilotWebSocketHandler) receiveLoop(ctx context.Context) {
+	defer close(f.done)
+	for {
+		typ, data, err := f.conn.Read(ctx)
+		if err != nil {
+			if websocket.CloseStatus(err) == websocket.StatusNormalClosure || websocket.CloseStatus(err) == websocket.StatusGoingAway {
+				f.err = nil
+			} else if ctx.Err() != nil {
+				f.err = nil
+			} else {
+				f.err = err
+			}
+			return
+		}
+		out := CopilotWebSocketMessage{Data: data, Binary: typ == websocket.MessageBinary}
+		if f.OnSendResponseMessage != nil {
+			transformed := f.OnSendResponseMessage(out)
+			if transformed == nil {
+				// Hook asked for the frame to be dropped.
+				continue
+			}
+			out = *transformed
+		}
+		// Writer errors are deliberately ignored; the loop keeps reading.
+		if out.Binary {
+			_ = f.resp.SendBinary(out.Data)
+		} else {
+			_ = f.resp.SendText(out.Data)
+		}
+	}
+}
+
+// SendRequestMessage writes one runtime→upstream frame after applying
+// OnSendRequestMessage (when set). A frame dropped by the hook, or sent before
+// the connection exists, is silently discarded with a nil error.
+func (f *ForwardingCopilotWebSocketHandler) SendRequestMessage(ctx context.Context, msg CopilotWebSocketMessage) error {
+	out := msg
+	if f.OnSendRequestMessage != nil {
+		transformed := f.OnSendRequestMessage(msg)
+		if transformed == nil {
+			return nil
+		}
+		out = *transformed
+	}
+	if f.conn == nil {
+		return nil
+	}
+	msgType := websocket.MessageText
+	if out.Binary {
+		msgType = websocket.MessageBinary
+	}
+	return f.conn.Write(ctx, msgType, out.Data)
+}
+
+// Done returns the channel closed when the receive loop exits.
+func (f *ForwardingCopilotWebSocketHandler) Done() <-chan struct{} { return f.done }
+
+// Err returns the terminal error recorded by the receive loop; read it only
+// after Done is closed.
+func (f *ForwardingCopilotWebSocketHandler) Err() error { return f.err }
+
+// Close closes the upstream connection with a normal-closure status, at most
+// once. It always returns nil.
+func (f *ForwardingCopilotWebSocketHandler) Close() error {
+	f.closeOnce.Do(func() {
+		if f.conn != nil {
+			_ = f.conn.Close(websocket.StatusNormalClosure, "")
+		}
+	})
+	return nil
+}
+
+// --- Internal adapter ---
+
+// frameQueue is an unbounded FIFO of body frames, decoupling the RPC dispatch
+// goroutine (which only pushes) from the consumer goroutine (which pops).
+type frameQueue struct {
+	mu    sync.Mutex
+	cond  *sync.Cond
+	items []CopilotWebSocketMessage
+	done  bool
+}
+
+// newFrameQueue returns an empty queue whose condition variable is bound to
+// its mutex.
+func newFrameQueue() *frameQueue {
+	q := &frameQueue{}
+	q.cond = sync.NewCond(&q.mu)
+	return q
+}
+
+// push appends m and wakes one waiter. Frames pushed after close are dropped.
+func (q *frameQueue) push(m CopilotWebSocketMessage) {
+	q.mu.Lock()
+	if !q.done {
+		q.items = append(q.items, m)
+	}
+	q.cond.Signal()
+	q.mu.Unlock()
+}
+
+// close marks the queue finished and wakes every waiter. Frames already
+// queued can still be popped.
+func (q *frameQueue) close() {
+	q.mu.Lock()
+	q.done = true
+	q.cond.Broadcast()
+	q.mu.Unlock()
+}
+
+// pop blocks until a frame is available or the queue is closed. It returns
+// (frame, true) while frames remain and (zero, false) once the queue is both
+// closed and empty.
+func (q *frameQueue) pop() (CopilotWebSocketMessage, bool) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	for len(q.items) == 0 && !q.done {
+		q.cond.Wait()
+	}
+	if len(q.items) > 0 {
+		m := q.items[0]
+		// Clear the vacated slot: reslicing alone keeps the popped payload
+		// reachable through the backing array until the slice is reallocated.
+		q.items[0] = CopilotWebSocketMessage{}
+		q.items = q.items[1:]
+		return m, true
+	}
+	return CopilotWebSocketMessage{}, false
+}
+
+// pendingExchange is the per-request state shared between the RPC handlers
+// and the goroutine running the user's handler.
+type pendingExchange struct {
+	mu       sync.Mutex
+	queue    *frameQueue
+	ctx      context.Context
+	cancel   context.CancelFunc
+	started  bool
+	finished bool
+}
+
+// copilotRequestAdapter routes runtime RPC frames to a CopilotRequestHandler,
+// tracking in-flight exchanges by request id.
+type copilotRequestAdapter struct {
+	handler *CopilotRequestHandler
+	getRPC  func() *rpc.ServerLlmInferenceAPI
+
+	mu      sync.Mutex
+	pending map[string]*pendingExchange
+}
+
+// newCopilotRequestAdapter wires handler to the RPC surface. getRPC is invoked
+// lazily each time a response frame needs to be sent.
+func newCopilotRequestAdapter(handler *CopilotRequestHandler, getRPC func() *rpc.ServerLlmInferenceAPI) rpc.LlmInferenceHandler {
+	return &copilotRequestAdapter{
+		handler: handler,
+		getRPC:  getRPC,
+		pending: make(map[string]*pendingExchange),
+	}
+}
+
+//
+// getOrCreateExchange returns the exchange for requestID, allocating one if it
+// does not yet exist. The runtime dispatches httpRequestStart and
+// httpRequestChunk frames on separate goroutines (see jsonrpc2.handleRequest),
+// so a body chunk — including the terminal end frame — can arrive before its
+// start frame runs. Creating the exchange (and its buffering frameQueue) on
+// first touch means those chunks are buffered rather than dropped, instead of
+// hanging the body drain forever.
+func (a *copilotRequestAdapter) getOrCreateExchange(requestID string) *pendingExchange {
+	a.mu.Lock()
+	defer a.mu.Unlock()
+	if exchange, ok := a.pending[requestID]; ok {
+		return exchange
+	}
+	// Each exchange gets its own cancellable context rooted at Background.
+	ctx, cancel := context.WithCancel(context.Background())
+	exchange := &pendingExchange{queue: newFrameQueue(), ctx: ctx, cancel: cancel}
+	a.pending[requestID] = exchange
+	return exchange
+}
+
+// HttpRequestStart handles the start frame of a request: it builds the
+// CopilotRequestContext from params, bridges the exchange's frame queue to the
+// context's Body channel, and runs the user's handler on its own goroutine.
+// It returns immediately with an empty result.
+func (a *copilotRequestAdapter) HttpRequestStart(params *rpc.LlmInferenceHTTPRequestStartRequest) (*rpc.LlmInferenceHTTPRequestStartResult, error) {
+	// Adopt any exchange a racing chunk already created — with its buffered
+	// body — rather than dropping those frames.
+	exchange := a.getOrCreateExchange(params.RequestID)
+	ctx := exchange.ctx
+	bodyCh := make(chan CopilotWebSocketMessage)
+
+	// Bridge goroutine: move frames from the queue to the unbuffered Body
+	// channel, closing it when the queue is drained and closed or the
+	// exchange context is cancelled.
+	go func() {
+		defer close(bodyCh)
+		for {
+			m, ok := exchange.queue.pop()
+			if !ok {
+				return
+			}
+			select {
+			case bodyCh <- m:
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	// Optional fields fall back to "http" and an empty session id.
+	transport := "http"
+	if params.Transport != nil {
+		transport = string(*params.Transport)
+	}
+	sessionID := ""
+	if params.SessionID != nil {
+		sessionID = *params.SessionID
+	}
+	// Copy header values so the handler cannot mutate params.
+	headers := http.Header{}
+	for k, v := range params.Headers {
+		headers[k] = append([]string(nil), v...)
+	}
+
+	rctx := &CopilotRequestContext{
+		RequestID: params.RequestID,
+		SessionID: sessionID,
+		Method:    params.Method,
+		URL:       params.URL,
+		Headers:   headers,
+		Transport: transport,
+		Body:      bodyCh,
+		Context:   ctx,
+	}
+	sink := &responseSink{requestID: params.RequestID, adapter: a, exchange: exchange}
+	go a.runHandler(rctx, sink, exchange)
+	return &rpc.LlmInferenceHTTPRequestStartResult{}, nil
+}
+
+// HttpRequestChunk handles one body/cancel/end frame for a request and
+// returns an empty result.
+func (a *copilotRequestAdapter) HttpRequestChunk(params *rpc.LlmInferenceHTTPRequestChunkRequest) (*rpc.LlmInferenceHTTPRequestChunkResult, error) {
+	// May arrive before the matching start frame (frames are dispatched on
+	// separate goroutines); get-or-create so the body is buffered, never lost.
+	exchange := a.getOrCreateExchange(params.RequestID)
+	a.routeChunk(exchange, params)
+	return &rpc.LlmInferenceHTTPRequestChunkResult{}, nil
+}
+
+// routeChunk applies one chunk frame to the exchange. A cancel frame cancels
+// the context and closes the queue, ignoring any data on the same frame.
+// Otherwise non-empty data is decoded and queued, and an end frame closes the
+// queue.
+func (a *copilotRequestAdapter) routeChunk(exchange *pendingExchange, params *rpc.LlmInferenceHTTPRequestChunkRequest) {
+	if params.Cancel != nil && *params.Cancel {
+		exchange.cancel()
+		exchange.queue.close()
+		return
+	}
+	if params.Data != "" {
+		binary := params.Binary != nil && *params.Binary
+		// NOTE(review): a frame whose payload fails to decode is dropped
+		// silently, so the handler sees a body with a gap — confirm this is
+		// the intended behaviour.
+		if data, err := decodeChunkData(params.Data, binary); err == nil {
+			exchange.queue.push(CopilotWebSocketMessage{Data: data, Binary: binary})
+		}
+	}
+	if params.End != nil && *params.End {
+		exchange.queue.close()
+	}
+}
+
+// runHandler invokes the user's handler and guarantees the response is
+// finalised. On error it reports cancellation (when the exchange context is
+// done) or a failure. On success it reports a failure if the handler returned
+// without the exchange being marked finished.
+func (a *copilotRequestAdapter) runHandler(rctx *CopilotRequestContext, sink *responseSink, exchange *pendingExchange) {
+	err := a.handler.handle(rctx, sink)
+	if err != nil {
+		if exchange.ctx.Err() != nil {
+			a.finishCancelled(sink, exchange)
+			return
+		}
+		a.failViaSink(sink, exchange, err.Error())
+		return
+	}
+	exchange.mu.Lock()
+	finished := exchange.finished
+	exchange.mu.Unlock()
+	if !finished {
+		a.failViaSink(sink, exchange, "CopilotRequestHandler returned without finalising the response")
+	}
+}
+
+// failViaSink reports message through the sink's error path unless the
+// exchange is already finished. If no response head was sent yet, a 502 head
+// is emitted first. Sink errors are ignored.
+func (a *copilotRequestAdapter) failViaSink(sink *responseSink, exchange *pendingExchange, message string) {
+	exchange.mu.Lock()
+	finished := exchange.finished
+	started := exchange.started
+	exchange.mu.Unlock()
+	if finished {
+		return
+	}
+	if !started {
+		_ = sink.start(502, "", http.Header{})
+	}
+	_ = sink.sinkError(message, "")
+}
+
+// finishCancelled mirrors failViaSink for a cancelled request: it emits a 499
+// head when none was sent yet, then a "cancelled" error. It is a no-op when
+// the exchange is already finished.
+func (a *copilotRequestAdapter) finishCancelled(sink *responseSink, exchange *pendingExchange) {
+	exchange.mu.Lock()
+	finished := exchange.finished
+	started := exchange.started
+	exchange.mu.Unlock()
+	if finished {
+		return
+	}
+	if !started {
+		_ = sink.start(499, "", http.Header{})
+	}
+	_ = sink.sinkError("Request cancelled by runtime", "cancelled")
+}
+
+// removePending forgets the exchange registered for requestID, if any.
+func (a *copilotRequestAdapter) removePending(requestID string) {
+	a.mu.Lock()
+	delete(a.pending, requestID)
+	a.mu.Unlock()
+}
+
+// decodeChunkData converts a chunk payload to bytes: standard base64 decoding
+// for binary frames, a plain byte conversion for text frames.
+func decodeChunkData(data string, binary bool) ([]byte, error) {
+	if binary {
+		return base64.StdEncoding.DecodeString(data)
+	}
+	return []byte(data), nil
+}
+
+// responseSink writes response frames to the runtime via RPC.
+type responseSink struct {
+	requestID string
+	adapter   *copilotRequestAdapter
+	exchange  *pendingExchange
+}
+
+// rpcAPI returns the current RPC surface, or an error when the adapter's
+// getRPC callback yields nil.
+func (s *responseSink) rpcAPI() (*rpc.ServerLlmInferenceAPI, error) {
+	r := s.adapter.getRPC()
+	if r == nil {
+		return nil, fmt.Errorf("CopilotRequestHandler response sink used after RPC connection closed")
+	}
+	return r, nil
+}
+
+func (s *responseSink) start(status int, statusTxt string, headers http.Header) error {
+	s.exchange.mu.Lock()
+	if s.exchange.started {
+		s.exchange.mu.Unlock()
+		return fmt.Errorf("CopilotRequestHandler response sink Start() called twice")
+	}
+	if s.exchange.finished {
+		s.exchange.mu.Unlock()
+		return fmt.Errorf("CopilotRequestHandler response sink already finished")
+	}
+	s.exchange.started = true
+	s.exchange.mu.Unlock()
+
+	api, err := s.rpcAPI()
+	if err != nil {
+		return err
+	}
+	var st *string
+	if statusTxt != "" {
+		st = &statusTxt
+	}
+	h := map[string][]string(headers)
+	if h == nil {
+		h = map[string][]string{}
+	}
+	_, err =
api.HttpResponseStart(context.Background(), &rpc.LlmInferenceHTTPResponseStartRequest{ + RequestID: s.requestID, + Status: int64(status), + StatusText: st, + Headers: h, + }) + return err +} + +func (s *responseSink) writeText(data []byte) error { + return s.writeRaw(string(data), false) +} + +func (s *responseSink) writeBinary(data []byte) error { + return s.writeRaw(base64.StdEncoding.EncodeToString(data), true) +} + +func (s *responseSink) writeRaw(data string, binary bool) error { + s.exchange.mu.Lock() + started := s.exchange.started + finished := s.exchange.finished + s.exchange.mu.Unlock() + if !started { + return fmt.Errorf("CopilotRequestHandler response sink Write() called before Start()") + } + if finished { + return fmt.Errorf("CopilotRequestHandler response sink Write() called after End()/Error()") + } + api, err := s.rpcAPI() + if err != nil { + return err + } + end := false + chunk := &rpc.LlmInferenceHTTPResponseChunkRequest{ + RequestID: s.requestID, + Data: data, + End: &end, + } + if binary { + b := true + chunk.Binary = &b + } + _, err = api.HttpResponseChunk(context.Background(), chunk) + return err +} + +func (s *responseSink) end() error { + s.exchange.mu.Lock() + if s.exchange.finished { + s.exchange.mu.Unlock() + return nil + } + s.exchange.finished = true + s.exchange.mu.Unlock() + s.adapter.removePending(s.requestID) + api, err := s.rpcAPI() + if err != nil { + return err + } + end := true + _, err = api.HttpResponseChunk(context.Background(), &rpc.LlmInferenceHTTPResponseChunkRequest{ + RequestID: s.requestID, + Data: "", + End: &end, + }) + return err +} + +func (s *responseSink) sinkError(message string, code string) error { + s.exchange.mu.Lock() + if s.exchange.finished { + s.exchange.mu.Unlock() + return nil + } + s.exchange.finished = true + s.exchange.mu.Unlock() + s.adapter.removePending(s.requestID) + api, err := s.rpcAPI() + if err != nil { + return err + } + end := true + chunkErr := 
&rpc.LlmInferenceHTTPResponseChunkError{Message: message} + if code != "" { + c := code + chunkErr.Code = &c + } + _, err = api.HttpResponseChunk(context.Background(), &rpc.LlmInferenceHTTPResponseChunkRequest{ + RequestID: s.requestID, + Data: "", + End: &end, + Error: chunkErr, + }) + return err +} diff --git a/go/go.mod b/go/go.mod index 16114a0ab..586a5d336 100644 --- a/go/go.mod +++ b/go/go.mod @@ -8,6 +8,7 @@ require ( ) require ( + github.com/coder/websocket v1.8.15 github.com/google/uuid v1.6.0 go.opentelemetry.io/otel v1.35.0 go.opentelemetry.io/otel/trace v1.35.0 diff --git a/go/go.sum b/go/go.sum index ec2bbcc1e..e7ac53d5a 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,3 +1,5 @@ +github.com/coder/websocket v1.8.15 h1:6B2JPeOGlpff2Uz6vOEH1Vzpi0iUz20A+lPVhPHtNUA= +github.com/coder/websocket v1.8.15/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= diff --git a/go/internal/e2e/copilot_request_cancel_error_e2e_test.go b/go/internal/e2e/copilot_request_cancel_error_e2e_test.go new file mode 100644 index 000000000..c48a61702 --- /dev/null +++ b/go/internal/e2e/copilot_request_cancel_error_e2e_test.go @@ -0,0 +1,173 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "errors" + "io" + "net/http" + "sync" + "sync/atomic" + "testing" + "time" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// TestCopilotRequestCancelError covers the two terminal paths of +// CopilotRequestHandler that the happy-path handler and session-id tests never +// reach: +// +// - error: the Transport returns an error for an inference request → the +// adapter reports a transport error instead of hanging. +// - cancel: the Transport blocks indefinitely on an inference request; when +// the consumer aborts the turn the runtime cancels the in-flight request, +// firing the request's context cancellation. + +// --- error case --- + +type throwingTransport struct { + mu sync.Mutex + totalCalls int + callsBeforeError int +} + +func (tr *throwingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + tr.mu.Lock() + tr.totalCalls++ + tr.mu.Unlock() + + if isInferenceURL(req.URL.String()) { + // Drain the body so the request is fully consumed before erroring. 
+ if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + tr.mu.Lock() + tr.callsBeforeError++ + tr.mu.Unlock() + return nil, errors.New("synthetic-callback-transport-failure") + } + return buildNonInferenceResponse(req.URL.String()), nil +} + +func TestCopilotRequestError(t *testing.T) { + ctx := testharness.NewTestContext(t) + transport := &throwingTransport{} + handler := &copilot.CopilotRequestHandler{Transport: transport} + client := newCopilotRequestClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // The transport throws on inference; the agent layer surfaces it as an + // error or an event rather than hanging. + _, sendErr := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + _ = session.Disconnect() + + transport.mu.Lock() + total := transport.totalCalls + before := transport.callsBeforeError + transport.mu.Unlock() + + if total == 0 { + t.Fatal("Expected the transport to be invoked") + } + if before == 0 { + t.Fatal("Expected the inference transport call to be reached and raise") + } + if sendErr != nil && len(sendErr.Error()) == 0 { + t.Fatal("Expected a non-empty error string when an error surfaces") + } +} + +// --- cancel case --- + +type cancellingTransport struct { + inferenceEntered atomic.Bool + sawAbort atomic.Bool + abortSeen chan struct{} + once sync.Once +} + +func newCancellingTransport() *cancellingTransport { + return &cancellingTransport{abortSeen: make(chan struct{})} +} + +func (tr *cancellingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if !isInferenceURL(req.URL.String()) { + return buildNonInferenceResponse(req.URL.String()), nil + } + if req.Body 
!= nil { + _, _ = io.Copy(io.Discard, req.Body) + } + tr.inferenceEntered.Store(true) + // Block until the runtime cancels the request (via context cancellation). + <-req.Context().Done() + tr.sawAbort.Store(true) + tr.once.Do(func() { close(tr.abortSeen) }) + return nil, req.Context().Err() +} + +func waitFor(t *testing.T, predicate func() bool, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for !predicate() { + if time.Now().After(deadline) { + t.Fatal("waitFor timed out") + } + time.Sleep(50 * time.Millisecond) + } +} + +func TestCopilotRequestCancel(t *testing.T) { + ctx := testharness.NewTestContext(t) + transport := newCancellingTransport() + handler := &copilot.CopilotRequestHandler{Transport: transport} + client := newCopilotRequestClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if _, err := session.Send(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}); err != nil { + t.Fatalf("send failed: %v", err) + } + waitFor(t, transport.inferenceEntered.Load, 60*time.Second) + if err := session.Abort(t.Context()); err != nil { + t.Fatalf("abort failed: %v", err) + } + + select { + case <-transport.abortSeen: + case <-time.After(30 * time.Second): + t.Fatal("Timed out waiting for the transport to observe runtime cancellation") + } + _ = session.Disconnect() + + if !transport.inferenceEntered.Load() { + t.Fatal("Expected the inference transport call to be entered") + } + if !transport.sawAbort.Load() { + t.Fatal("Expected the transport to observe the runtime-driven cancellation") + } +} diff --git a/go/internal/e2e/copilot_request_handler_e2e_test.go 
b/go/internal/e2e/copilot_request_handler_e2e_test.go new file mode 100644 index 000000000..6d68a5c1e --- /dev/null +++ b/go/internal/e2e/copilot_request_handler_e2e_test.go @@ -0,0 +1,207 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + + "github.com/coder/websocket" + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +const ( + handlerHTTPText = "OK from synthetic HTTP upstream." + handlerWSText = "OK from synthetic WS upstream." +) + +// wsSupportedEndpoints advertises both HTTP /responses and WS /responses so +// the runtime picks the WebSocket path when the ExP flag is set. +var wsSupportedEndpoints = []string{"/responses", "ws:/responses"} + +type handlerCounters struct { + httpRequests atomic.Int32 + httpResponses atomic.Int32 + wsRequestMessages atomic.Int32 + wsResponseMessages atomic.Int32 + upstreamWSRequests atomic.Int32 +} + +func sseBody(text, respID string) string { + return buildResponsesSSEBody(text, respID) +} + +// startFakeUpstreams brings up a real HTTP upstream (catalog / policy / +// responses-SSE) and a real WebSocket upstream that echoes /responses events +// per inbound message. 
+func startFakeUpstreams(t *testing.T, counters *handlerCounters) (httpURL, wsURL string) { + t.Helper() + + httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.ToLower(strings.SplitN(r.URL.Path, "?", 2)[0]) + defer func() { _ = r.Body.Close() }() + switch { + case strings.HasSuffix(path, "/models"): + w.Header().Set("content-type", "application/json") + _, _ = w.Write([]byte(modelCatalogJSON(wsSupportedEndpoints))) + case strings.HasSuffix(path, "/models/session"): + w.Header().Set("content-type", "application/json") + _, _ = w.Write([]byte("{}")) + case strings.Contains(path, "/policy"): + w.Header().Set("content-type", "application/json") + _, _ = w.Write([]byte(`{"state":"enabled"}`)) + case strings.HasSuffix(path, "/responses"): + w.Header().Set("content-type", "text/event-stream") + _, _ = w.Write([]byte(sseBody(handlerHTTPText, "resp_stub_http"))) + default: + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":"not_found"}`)) + } + })) + t.Cleanup(httpSrv.Close) + + wsSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{InsecureSkipVerify: true}) + if err != nil { + return + } + defer c.Close(websocket.StatusNormalClosure, "") + c.SetReadLimit(-1) + bg := context.Background() + for { + _, _, readErr := c.Read(bg) + if readErr != nil { + return + } + counters.upstreamWSRequests.Add(1) + for _, event := range responsesEvents(handlerWSText, "resp_stub_ws") { + raw, _ := json.Marshal(event) + if err := c.Write(bg, websocket.MessageText, raw); err != nil { + return + } + } + } + })) + t.Cleanup(wsSrv.Close) + + return httpSrv.URL, "ws://" + strings.TrimPrefix(wsSrv.URL, "http://") +} + +type rewritingRoundTripper struct { + base *url.URL + counters *handlerCounters + inner http.RoundTripper +} + +func (rt *rewritingRoundTripper) 
RoundTrip(req *http.Request) (*http.Response, error) { + rt.counters.httpRequests.Add(1) + req.URL.Scheme = rt.base.Scheme + req.URL.Host = rt.base.Host + req.Host = rt.base.Host + req.Header.Set("x-test-mutated", "1") + resp, err := rt.inner.RoundTrip(req) + if err != nil { + return nil, err + } + rt.counters.httpResponses.Add(1) + resp.Header.Set("x-test-response-mutated", "1") + return resp, nil +} + +func TestCopilotRequestHandler(t *testing.T) { + ctx := testharness.NewTestContext(t) + counters := &handlerCounters{} + httpURL, wsURL := startFakeUpstreams(t, counters) + + httpBase, err := url.Parse(httpURL) + if err != nil { + t.Fatalf("Failed to parse upstream URL: %v", err) + } + wsBase, err := url.Parse(wsURL) + if err != nil { + t.Fatalf("Failed to parse upstream ws URL: %v", err) + } + + handler := &copilot.CopilotRequestHandler{ + Transport: &rewritingRoundTripper{ + base: httpBase, + counters: counters, + inner: http.DefaultTransport.(*http.Transport).Clone(), + }, + OpenWebSocket: func(rctx *copilot.CopilotRequestContext) (copilot.CopilotWebSocketHandler, error) { + parsed, perr := url.Parse(rctx.URL) + if perr != nil { + return nil, perr + } + parsed.Scheme = wsBase.Scheme + parsed.Host = wsBase.Host + fwd := copilot.NewForwardingCopilotWebSocketHandler(parsed.String(), rctx.Headers) + fwd.OnSendRequestMessage = func(msg copilot.CopilotWebSocketMessage) *copilot.CopilotWebSocketMessage { + counters.wsRequestMessages.Add(1) + return &msg + } + fwd.OnSendResponseMessage = func(msg copilot.CopilotWebSocketMessage) *copilot.CopilotWebSocketMessage { + counters.wsResponseMessages.Add(1) + return &msg + } + return fwd, nil + }, + } + + client := newCopilotRequestClient(ctx, handler, "COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES=true") + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + 
OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + // The HTTP seam fired — the runtime issued model-layer GETs (catalog, + // policy) and possibly a single-shot inference through the RoundTripper. + if counters.httpRequests.Load() == 0 { + t.Fatal("Expected the HTTP RoundTripper to fire") + } + if counters.httpResponses.Load() == 0 { + t.Fatal("Expected the HTTP response mutation to fire") + } + + // The WebSocket seam fired — the main agent turn went over the WS path and + // we observed messages in both directions. + if counters.wsRequestMessages.Load() == 0 { + t.Fatal("Expected runtime → upstream ws messages") + } + if counters.wsResponseMessages.Load() == 0 { + t.Fatal("Expected upstream → runtime ws messages") + } + if counters.upstreamWSRequests.Load() == 0 { + t.Fatal("Expected the upstream WS to receive request messages") + } + + // Validate the final assistant response arrived (guards against truncated captures) + text := assistantText(result) + if !strings.Contains(text, "OK from synthetic") || !strings.Contains(text, "upstream") { + t.Fatalf("Expected synthetic upstream content in assistant reply, got %q", text) + } +} diff --git a/go/internal/e2e/copilot_request_helpers_test.go b/go/internal/e2e/copilot_request_helpers_test.go new file mode 100644 index 000000000..69e82c2ab --- /dev/null +++ b/go/internal/e2e/copilot_request_helpers_test.go @@ -0,0 +1,230 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "encoding/json" + "io" + "net/http" + "regexp" + "strings" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +// Shared synthetic-upstream helpers for the CopilotRequestHandler e2e tests. +// +// These tests have no recorded snapshots: the registered handler fabricates +// well-formed model responses and the runtime routes all of its model-layer +// HTTP/WebSocket traffic through that handler instead of the CAPI proxy. The +// helpers centralise the synthetic CAPI shapes (model catalog, policy, +// /responses SSE, /chat/completions) so each test focuses on the behaviour it +// is exercising. + +const syntheticResponseText = "OK from the synthetic stream." + +var streamTrueRe = regexp.MustCompile(`"stream"\s*:\s*true`) + +func isStreamingRequest(body string) bool { + return streamTrueRe.MatchString(body) +} + +func isInferenceURL(url string) bool { + u := strings.ToLower(url) + return strings.HasSuffix(u, "/chat/completions") || + strings.HasSuffix(u, "/responses") || + strings.HasSuffix(u, "/v1/messages") || + strings.HasSuffix(u, "/messages") +} + +func sseFrame(eventType string, data map[string]any) string { + raw, _ := json.Marshal(data) + return "event: " + eventType + "\ndata: " + string(raw) + "\n\n" +} + +func modelCatalogJSON(supportedEndpoints []string) string { + model := map[string]any{ + "id": "claude-sonnet-4.5", + "name": "Claude Sonnet 4.5", + "object": "model", + "vendor": "Anthropic", + "version": "1", + "preview": false, + "model_picker_enabled": true, + "capabilities": map[string]any{ + "type": "chat", + "family": "claude-sonnet-4.5", + "tokenizer": "o200k_base", + "limits": map[string]any{ + "max_context_window_tokens": 200000, + "max_output_tokens": 8192, + }, + "supports": map[string]any{ + "streaming": true, + "tool_calls": true, + "parallel_tool_calls": 
true, + "vision": true, + }, + }, + } + if supportedEndpoints != nil { + model["supported_endpoints"] = supportedEndpoints + } + raw, _ := json.Marshal(map[string]any{"data": []any{model}}) + return string(raw) +} + +// responsesEvents returns the ordered /responses event objects the runtime's +// reducer expects. Used raw (one object == one WebSocket message) for the WS +// path and SSE-framed for the HTTP path. +func responsesEvents(text, respID string) []map[string]any { + return []map[string]any{ + { + "type": "response.created", + "response": map[string]any{"id": respID, "object": "response", "status": "in_progress", "output": []any{}}, + }, + { + "type": "response.output_item.added", + "output_index": 0, + "item": map[string]any{"id": "msg_1", "type": "message", "role": "assistant", "content": []any{}}, + }, + { + "type": "response.content_part.added", + "output_index": 0, + "content_index": 0, + "part": map[string]any{"type": "output_text", "text": ""}, + }, + {"type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": text}, + {"type": "response.output_text.done", "output_index": 0, "content_index": 0, "text": text}, + { + "type": "response.completed", + "response": map[string]any{ + "id": respID, + "object": "response", + "status": "completed", + "output": []any{ + map[string]any{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": []any{map[string]any{"type": "output_text", "text": text}}, + }, + }, + "usage": map[string]any{"input_tokens": 5, "output_tokens": 7, "total_tokens": 12}, + }, + }, + } +} + +// buildResponsesSSEBody returns a complete SSE body for a /responses streaming response. +func buildResponsesSSEBody(text, respID string) string { + var sb strings.Builder + for _, event := range responsesEvents(text, respID) { + sb.WriteString(sseFrame(event["type"].(string), event)) + } + return sb.String() +} + +// buildInferenceResponse synthesizes a well-formed inference HTTP response. 
+func buildInferenceResponse(url string, bodyText string) *http.Response { + wantsStream := isStreamingRequest(bodyText) + u := strings.ToLower(url) + + if strings.Contains(u, "/responses") { + if wantsStream { + return buildSSEResponse(buildResponsesSSEBody(syntheticResponseText, "resp_stub_1")) + } + events := responsesEvents(syntheticResponseText, "resp_stub_1") + last := events[len(events)-1]["response"] + raw, _ := json.Marshal(last) + return buildJSONResponse(200, string(raw)) + } + + if strings.Contains(u, "/chat/completions") && wantsStream { + base := func() map[string]any { + return map[string]any{ + "id": "chatcmpl-stub-1", "object": "chat.completion.chunk", + "created": 1, "model": "claude-sonnet-4.5", + } + } + c1 := base() + c1["choices"] = []any{map[string]any{"index": 0, "delta": map[string]any{"role": "assistant", "content": ""}, "finish_reason": nil}} + c2 := base() + c2["choices"] = []any{map[string]any{"index": 0, "delta": map[string]any{"content": syntheticResponseText}, "finish_reason": nil}} + c3 := base() + c3["choices"] = []any{map[string]any{"index": 0, "delta": map[string]any{}, "finish_reason": "stop"}} + c3["usage"] = map[string]any{"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12} + var sb strings.Builder + for _, chunk := range []map[string]any{c1, c2, c3} { + raw, _ := json.Marshal(chunk) + sb.WriteString("data: " + string(raw) + "\n\n") + } + sb.WriteString("data: [DONE]\n\n") + return buildSSEResponse(sb.String()) + } + + raw, _ := json.Marshal(map[string]any{ + "id": "chatcmpl-stub-1", "object": "chat.completion", "created": 1, "model": "claude-sonnet-4.5", + "choices": []any{map[string]any{"index": 0, "message": map[string]any{"role": "assistant", "content": syntheticResponseText}, "finish_reason": "stop"}}, + "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + }) + return buildJSONResponse(200, string(raw)) +} + +// buildNonInferenceResponse serves catalog / session / policy 
endpoints. +func buildNonInferenceResponse(url string) *http.Response { + u := strings.ToLower(url) + switch { + case strings.HasSuffix(u, "/models"): + return buildJSONResponse(200, modelCatalogJSON(nil)) + case strings.Contains(u, "/models/session"): + return buildJSONResponse(200, "{}") + case strings.Contains(u, "/policy"): + return buildJSONResponse(200, `{"state":"enabled"}`) + } + return buildJSONResponse(200, "{}") +} + +func buildJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Status: http.StatusText(status), + Header: http.Header{"Content-Type": {"application/json"}}, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func buildSSEResponse(body string) *http.Response { + return &http.Response{ + StatusCode: 200, + Status: "OK", + Header: http.Header{"Content-Type": {"text/event-stream"}, "Cache-Control": {"no-cache"}}, + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func assistantText(msg *copilot.SessionEvent) string { + if msg == nil { + return "" + } + if d, ok := msg.Data.(*copilot.AssistantMessageData); ok { + return d.Content + } + return "" +} + +// newCopilotRequestClient builds a client wired to handler via RequestHandler. +// Each test that needs inference interception owns an isolated client carrying +// its own handler. extraEnv is appended to the spawned runtime's environment +// (e.g. to flip an ExP flag for the WS transport). +func newCopilotRequestClient(ctx *testharness.TestContext, handler *copilot.CopilotRequestHandler, extraEnv ...string) *copilot.Client { + return ctx.NewClient(func(o *copilot.ClientOptions) { + o.RequestHandler = handler + if len(extraEnv) > 0 { + o.Env = append(o.Env, extraEnv...) 
+ } + }) +} diff --git a/go/internal/e2e/copilot_request_session_id_e2e_test.go b/go/internal/e2e/copilot_request_session_id_e2e_test.go new file mode 100644 index 000000000..809f77da7 --- /dev/null +++ b/go/internal/e2e/copilot_request_session_id_e2e_test.go @@ -0,0 +1,153 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package e2e + +import ( + "io" + "strings" + "sync" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" + "net/http" +) + +type interceptedRequest struct { + url string + sessionID string +} + +// recordingTransport intercepts every model-layer request, records its URL and +// session ID (extracted from the CopilotRequestContext attached to the +// http.Request), and synthesizes a well-formed response so turns complete. 
+type recordingTransport struct { + mu sync.Mutex + records []interceptedRequest +} + +func (rt *recordingTransport) RoundTrip(req *http.Request) (*http.Response, error) { + rctx := copilot.RequestContextFrom(req) + sessionID := "" + if rctx != nil { + sessionID = rctx.SessionID + } + rt.mu.Lock() + rt.records = append(rt.records, interceptedRequest{url: req.URL.String(), sessionID: sessionID}) + rt.mu.Unlock() + + bodyBytes := []byte(nil) + if req.Body != nil { + bodyBytes, _ = io.ReadAll(req.Body) + } + bodyText := string(bodyBytes) + + if isInferenceURL(req.URL.String()) { + return buildInferenceResponse(req.URL.String(), bodyText), nil + } + return buildNonInferenceResponse(req.URL.String()), nil +} + +func (rt *recordingTransport) inferenceRecords() []interceptedRequest { + rt.mu.Lock() + defer rt.mu.Unlock() + var out []interceptedRequest + for _, r := range rt.records { + if isInferenceURL(r.url) { + out = append(out, r) + } + } + return out +} + +func TestCopilotRequestSessionID(t *testing.T) { + ctx := testharness.NewTestContext(t) + transport := &recordingTransport{} + handler := &copilot.CopilotRequestHandler{Transport: transport} + client := newCopilotRequestClient(ctx, handler) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + var capiSessionID string + + t.Run("threads session id into a CAPI session", func(t *testing.T) { + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + }) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + capiSessionID = session.SessionID + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + inference := transport.inferenceRecords() + if len(inference) == 0 { + t.Fatal("Expected at least 
one intercepted inference request") + } + for _, r := range inference { + if r.sessionID != capiSessionID { + t.Fatalf("CAPI inference request must carry session id %q, got %q", capiSessionID, r.sessionID) + } + } + + // Validate the final assistant response arrived (guards against truncated captures) + if !strings.Contains(assistantText(result), "OK from the synthetic") { + t.Fatalf("Expected synthetic content in assistant reply, got %q", assistantText(result)) + } + }) + + t.Run("threads session id into a BYOK session", func(t *testing.T) { + before := len(transport.inferenceRecords()) + session, err := client.CreateSession(t.Context(), &copilot.SessionConfig{ + OnPermissionRequest: copilot.PermissionHandler.ApproveAll, + Model: "claude-sonnet-4.5", + Provider: &copilot.ProviderConfig{ + Type: "openai", + WireAPI: "responses", + BaseURL: "https://byok.invalid/v1", + APIKey: "byok-secret", + ModelID: "claude-sonnet-4.5", + WireModel: "claude-sonnet-4.5", + }, + }) + if err != nil { + t.Fatalf("Failed to create BYOK session: %v", err) + } + byokSessionID := session.SessionID + + result, err := session.SendAndWait(t.Context(), copilot.MessageOptions{Prompt: "Say OK."}) + if err != nil { + t.Fatalf("send_and_wait failed: %v", err) + } + _ = session.Disconnect() + + inference := transport.inferenceRecords() + if len(inference) <= before { + t.Fatal("Expected at least one intercepted BYOK inference request") + } + for _, r := range inference[before:] { + if r.sessionID != byokSessionID { + t.Fatalf("BYOK inference request must carry session id %q, got %q", byokSessionID, r.sessionID) + } + } + + if byokSessionID == capiSessionID { + t.Fatal("Expected per-session ids to differ between turns") + } + + // Validate the final assistant response arrived (guards against truncated captures) + if !strings.Contains(assistantText(result), "OK from the synthetic") { + t.Fatalf("Expected synthetic content in assistant reply, got %q", assistantText(result)) + } + }) +} diff --git 
a/go/rpc/zrpc.go b/go/rpc/zrpc.go index b5c10e9b3..3c93bbd34 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -17109,3 +17109,94 @@ func RegisterClientSessionAPIHandlers(client *jsonrpc2.Client, getHandlers func( return raw, nil }) }

// Experimental: LlmInferenceHandler contains experimental APIs that may change or be
// removed.
type LlmInferenceHandler interface {
	// HttpRequestChunk receives one piece of a request body, or a cancellation
	// signal, for a request that was announced earlier through httpRequestStart;
	// the two are correlated by requestId. The runtime sends at least one chunk
	// per request: a body-less request produces a single chunk with empty data
	// and end=true. A chunk with cancel=true may arrive mid-stream to abort the
	// request; the SDK then stops issuing httpResponseChunk frames and may emit
	// a terminal httpResponseChunk with error set.
	//
	// RPC method: llmInference.httpRequestChunk.
	//
	// Parameters: A request body chunk or cancellation signal.
	//
	// Returns: Acknowledgement. The SDK is free to ignore the ack and treat chunk
	// delivery as fire-and-forget.
	HttpRequestChunk(request *LlmInferenceHTTPRequestChunkRequest) (*LlmInferenceHTTPRequestChunkResult, error)
	// HttpRequestStart announces an outbound model-layer HTTP request that the
	// runtime wants the SDK client to service. Only the request head is carried;
	// the body always follows as one or more httpRequestChunk frames with the
	// same requestId, even when it is empty (a single chunk with end=true).
	//
	// RPC method: llmInference.httpRequestStart.
	//
	// Parameters: The head of an outbound model-layer HTTP request.
	//
	// Returns: Acknowledgement. A successful return only means the SDK accepted
	// the start frame; it does not imply the request will succeed.
	HttpRequestStart(request *LlmInferenceHTTPRequestStartRequest) (*LlmInferenceHTTPRequestStartResult, error)
}

// ClientGlobalAPIHandlers provides all client-global API handler groups.
//
// Client-global handlers have no implicit session id dispatch key, unlike
// client-session handlers: one set of handlers serves the whole connection.
type ClientGlobalAPIHandlers struct {
	LlmInference LlmInferenceHandler
}

// clientGlobalHandlerError converts a handler error into a JSON-RPC error. A
// nil error maps to nil, an error that wraps a *jsonrpc2.Error is returned as
// that error, and anything else becomes code -32603 with the error text as
// the message.
func clientGlobalHandlerError(err error) *jsonrpc2.Error {
	var typed *jsonrpc2.Error
	switch {
	case err == nil:
		return nil
	case errors.As(err, &typed):
		return typed
	default:
		return &jsonrpc2.Error{Code: -32603, Message: err.Error()}
	}
}

// RegisterClientGlobalAPIHandlers registers handlers for server-to-client client-global API
// calls.
//
// Each registered method decodes its params first, then looks up the
// llmInference handler group at call time, invokes it, and encodes the result.
func RegisterClientGlobalAPIHandlers(client *jsonrpc2.Client, handlers *ClientGlobalAPIHandlers) {
	// invalidParams builds the -32602 error returned when params cannot be decoded.
	invalidParams := func(cause error) *jsonrpc2.Error {
		return &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", cause)}
	}
	// llmInference resolves the handler group on every call, so a group
	// assigned after registration is still picked up.
	llmInference := func() (LlmInferenceHandler, *jsonrpc2.Error) {
		if handlers != nil && handlers.LlmInference != nil {
			return handlers.LlmInference, nil
		}
		return nil, &jsonrpc2.Error{Code: -32603, Message: "No llmInference client-global handler registered"}
	}
	// encode marshals a handler result into the raw JSON-RPC response payload.
	encode := func(result any) (json.RawMessage, *jsonrpc2.Error) {
		payload, marshalErr := json.Marshal(result)
		if marshalErr != nil {
			return nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", marshalErr)}
		}
		return payload, nil
	}

	client.SetRequestHandler("llmInference.httpRequestChunk", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {
		var chunk LlmInferenceHTTPRequestChunkRequest
		if decodeErr := json.Unmarshal(params, &chunk); decodeErr != nil {
			return nil, invalidParams(decodeErr)
		}
		target, rpcErr := llmInference()
		if rpcErr != nil {
			return nil, rpcErr
		}
		ack, callErr := target.HttpRequestChunk(&chunk)
		if callErr != nil {
			return nil, clientGlobalHandlerError(callErr)
		}
		return encode(ack)
	})

	client.SetRequestHandler("llmInference.httpRequestStart", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {
		var head LlmInferenceHTTPRequestStartRequest
		if decodeErr := json.Unmarshal(params, &head); decodeErr != nil {
			return nil, invalidParams(decodeErr)
		}
		target, rpcErr := llmInference()
		if rpcErr != nil {
			return nil, rpcErr
		}
		ack, callErr := target.HttpRequestStart(&head)
		if callErr != nil {
			return nil, clientGlobalHandlerError(callErr)
		}
		return encode(ack)
	})
}
diff --git a/go/types.go b/go/types.go index ba83c6b6d..7c0b56c12 100644 --- a/go/types.go +++ b/go/types.go @@ -116,6 +116,12 @@ type ClientOptions struct { // on connection, routing session-scoped file I/O through per-session // handlers. SessionFS *SessionFSConfig + // RequestHandler registers a connection-level LLM inference callback. When + // non-nil, the client registers as the inference provider on connect, and + // the runtime routes its model-layer HTTP and WebSocket traffic through + // this handler instead of issuing the calls itself. Works for both CAPI + // and BYOK sessions. + RequestHandler *CopilotRequestHandler // Telemetry configures OpenTelemetry integration for the runtime. // When non-nil, COPILOT_OTEL_ENABLED=true is set and any populated // fields are mapped to the corresponding environment variables. diff --git a/java/src/main/java/com/github/copilot/CopilotClient.java b/java/src/main/java/com/github/copilot/CopilotClient.java index 81b3c5f15..63b70e2df 100644 --- a/java/src/main/java/com/github/copilot/CopilotClient.java +++ b/java/src/main/java/com/github/copilot/CopilotClient.java @@ -248,11 +248,25 @@ private Connection startCoreBody() { RpcHandlerDispatcher dispatcher = new RpcHandlerDispatcher(sessions, lifecycleManager::dispatch, executor); dispatcher.registerHandlers(rpc); + // Register the LLM inference request handler when configured.
+ com.github.copilot.CopilotRequestHandler requestHandler = this.options.getRequestHandler(); + boolean hasLlmInference = requestHandler != null; + if (hasLlmInference) { + LlmInferenceAdapter llmAdapter = new LlmInferenceAdapter(requestHandler, + () -> connection.serverRpc().llmInference, executor); + llmAdapter.registerHandlers(rpc); + } + // Verify protocol version verifyProtocolVersion(connection); LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.start protocol verification complete. Elapsed={Elapsed}", startNanos); + // Register as the runtime's LLM inference provider once connected. + if (hasLlmInference) { + connection.serverRpc().llmInference.setProvider().join(); + } + LoggingHelpers.logTiming(LOG, Level.FINE, "CopilotClient.start complete. Elapsed={Elapsed}", startNanos); return connection; } catch (Exception e) { @@ -439,20 +453,22 @@ private CompletableFuture cleanupConnection(boolean gracefulRuntimeShutdow private static void cleanupCliProcess(Process process, boolean forceImmediately) { try { if (process.isAlive()) { - if (!forceImmediately && process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { - return; - } - + // The runtime completes all cleanup before responding to + // runtime.shutdown and then leaves termination to us; it + // deliberately keeps its JSON-RPC server alive to send the + // response and never self-exits. Waiting for a self-exit that + // will never come just wastes time, so terminate the child + // immediately and only wait to reap it. 
if (forceImmediately) { process.destroyForcibly(); if (!process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { LOG.fine("Process did not terminate within force kill timeout"); } return; - } else { - process.destroy(); } - if (!forceImmediately && process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + + process.destroy(); + if (process.waitFor(FORCE_KILL_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { return; } diff --git a/java/src/main/java/com/github/copilot/CopilotRequestContext.java b/java/src/main/java/com/github/copilot/CopilotRequestContext.java new file mode 100644 index 000000000..705cd903f --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotRequestContext.java @@ -0,0 +1,114 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * The per-request context handed to every {@link CopilotRequestHandler} hook. + * It exposes the routing and cancellation details of a single intercepted + * request so overrides can observe or rewrite it. 
+ * + * @since 1.0.0 + */ +public final class CopilotRequestContext { + + private final String requestId; + private final String sessionId; + private final CopilotRequestTransport transport; + private final String url; + private final Map> headers; + private final CompletableFuture cancellation; + + private LlmWebSocketResponseBridge webSocketResponse; + + CopilotRequestContext(String requestId, String sessionId, CopilotRequestTransport transport, String url, + Map> headers, CompletableFuture cancellation) { + this.requestId = requestId; + this.sessionId = sessionId; + this.transport = transport; + this.url = url; + this.headers = headers; + this.cancellation = cancellation; + } + + /** + * Gets the opaque runtime-minted request id, stable across the request + * lifecycle. + * + * @return the request id + */ + public String requestId() { + return requestId; + } + + /** + * Gets the id of the runtime session that triggered this request, or + * {@code null} when the request was issued outside any session (for example the + * startup model catalog). + * + * @return the session id, or {@code null} + */ + public String sessionId() { + return sessionId; + } + + /** + * Gets the transport the runtime would otherwise use. + * + * @return the transport + */ + public CopilotRequestTransport transport() { + return transport; + } + + /** + * Gets the absolute request URL. + * + * @return the URL + */ + public String url() { + return url; + } + + /** + * Gets the request headers, multi-valued. + * + * @return the headers (never {@code null}) + */ + public Map> headers() { + return headers; + } + + /** + * A future that completes when the runtime cancels this in-flight request (for + * example because the agent turn was aborted upstream). Subclasses that issue + * their own I/O should pass it through so the upstream call is torn down too. 
+ * + * @return the cancellation future + */ + public CompletableFuture cancellation() { + return cancellation; + } + + /** + * Whether the runtime has cancelled this in-flight request. + * + * @return {@code true} once the request has been cancelled + */ + public boolean isCancelled() { + return cancellation.isDone(); + } + + LlmWebSocketResponseBridge webSocketResponse() { + return webSocketResponse; + } + + void setWebSocketResponse(LlmWebSocketResponseBridge webSocketResponse) { + this.webSocketResponse = webSocketResponse; + } +} diff --git a/java/src/main/java/com/github/copilot/CopilotRequestHandler.java b/java/src/main/java/com/github/copilot/CopilotRequestHandler.java new file mode 100644 index 000000000..2de287397 --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotRequestHandler.java @@ -0,0 +1,229 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; + +/** + * The base class for SDK consumers who want to observe or replace the LLM + * inference requests the runtime issues (for both CAPI and BYOK providers). + *

+ * When set as the {@code requestHandler} on + * {@link com.github.copilot.rpc.CopilotClientOptions}, the runtime routes its + * model-layer HTTP and WebSocket traffic through this handler instead of + * issuing the calls itself. Subclass and override {@link #sendRequest} to + * mutate or replace HTTP calls, or {@link #openWebSocket} to mutate the + * handshake or return a fully custom {@link CopilotWebSocketHandlerBase}. + * + * @since 1.0.0 + */
public class CopilotRequestHandler {

    /**
     * Lower-case names of request headers that are never copied onto the
     * forwarded request. {@code expect} is listed because
     * {@link HttpRequest.Builder#header} rejects restricted header names with an
     * {@link IllegalArgumentException}, which would otherwise abort the whole
     * request.
     */
    private static final Set<String> FORBIDDEN_REQUEST_HEADERS = Set.of("host", "connection", "content-length",
            "transfer-encoding", "keep-alive", "upgrade", "proxy-connection", "te", "trailer", "expect");

    /** Default client: shared by all requests and never follows redirects. */
    private static final HttpClient SHARED_HTTP_CLIENT = HttpClient.newBuilder()
            .followRedirects(HttpClient.Redirect.NEVER).build();

    /** Size of the buffer used to relay the upstream response body. */
    private static final int RESPONSE_CHUNK_SIZE = 32 * 1024;

    /**
     * Tells whether a header must be dropped instead of forwarded: any name in
     * {@link #FORBIDDEN_REQUEST_HEADERS} or starting with {@code sec-websocket-},
     * compared case-insensitively.
     *
     * @param name
     *            the header name
     * @return {@code true} if the header must not be forwarded
     */
    static boolean isForbiddenRequestHeader(String name) {
        String lower = name.toLowerCase(Locale.ROOT);
        return FORBIDDEN_REQUEST_HEADERS.contains(lower) || lower.startsWith("sec-websocket-");
    }

    /**
     * The {@link HttpClient} used to forward HTTP requests. Override to supply a
     * custom client (proxy, TLS, timeouts). The default never follows redirects, so
     * 3xx responses are forwarded verbatim.
     *
     * @return the HTTP client
     */
    protected HttpClient httpClient() {
        return SHARED_HTTP_CLIENT;
    }

    /**
     * Forwards an HTTP request and returns the upstream response. The default sends
     * {@code request} through {@link #httpClient()} and cancels the in-flight call
     * when the runtime cancels the request. Override to mutate the request before
     * sending, post-process the response, or replace the call entirely.
     *
     * @param request
     *            the request built from the runtime's inference request
     * @param ctx
     *            the per-request context
     * @return the upstream response, with the body as an {@link InputStream}
     * @throws Exception
     *             if the request could not be completed
     */
    protected HttpResponse<InputStream> sendRequest(HttpRequest request, CopilotRequestContext ctx) throws Exception {
        CompletableFuture<HttpResponse<InputStream>> future = httpClient().sendAsync(request,
                HttpResponse.BodyHandlers.ofInputStream());
        // Propagate a runtime-side cancellation to the in-flight upstream call.
        ctx.cancellation().whenComplete((v, t) -> future.cancel(true));
        return future.get();
    }

    /**
     * Returns a per-connection WebSocket handler for a WebSocket request. The
     * default opens a transparent forwarding connection to the request URL.
     * Override to mutate the handshake (via {@code ctx}) or return a fully custom
     * handler.
     *
     * @param ctx
     *            the per-request context
     * @return the WebSocket handler
     * @throws Exception
     *             if the handler could not be created
     */
    protected CopilotWebSocketHandlerBase openWebSocket(CopilotRequestContext ctx) throws Exception {
        return new CopilotWebSocketHandler(ctx);
    }

    /**
     * Entry point invoked by the adapter once per intercepted request. Routes to
     * the HTTP or WebSocket flow and drives the consumer's overridable hooks.
     */
    void handle(LlmInferenceExchange exchange) throws Exception {
        if (exchange.context().transport() == CopilotRequestTransport.WEBSOCKET) {
            handleWebSocket(exchange);
        } else {
            handleHttp(exchange);
        }
    }

    /** Builds the upstream request, sends it, and relays the response. */
    private void handleHttp(LlmInferenceExchange exchange) throws Exception {
        HttpRequest httpRequest = buildHttpRequest(exchange);
        HttpResponse<InputStream> response = sendRequest(httpRequest, exchange.context());
        streamResponse(response, exchange);
    }

    /**
     * Translates the exchange into an {@link HttpRequest}. The method defaults to
     * {@code GET}; {@code GET} and {@code HEAD} are sent without a body and the
     * request body is not drained for them. Forbidden headers and headers with a
     * {@code null} value list are skipped.
     */
    private static HttpRequest buildHttpRequest(LlmInferenceExchange exchange) throws InterruptedException {
        CopilotRequestContext ctx = exchange.context();
        String method = exchange.method() == null ? "GET" : exchange.method().toUpperCase(Locale.ROOT);
        boolean bodyless = method.equals("GET") || method.equals("HEAD");
        byte[] body = bodyless ? new byte[0] : exchange.drainBody();
        HttpRequest.BodyPublisher publisher = body.length > 0
                ? HttpRequest.BodyPublishers.ofByteArray(body)
                : HttpRequest.BodyPublishers.noBody();

        HttpRequest.Builder builder = HttpRequest.newBuilder().uri(URI.create(ctx.url())).method(method, publisher);
        Map<String, List<String>> headers = ctx.headers();
        if (headers != null) {
            for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
                if (isForbiddenRequestHeader(entry.getKey()) || entry.getValue() == null) {
                    continue;
                }
                for (String value : entry.getValue()) {
                    builder.header(entry.getKey(), value);
                }
            }
        }
        return builder.build();
    }

    /**
     * Relays the upstream response to the runtime: the head first, then the body
     * in frames of at most {@link #RESPONSE_CHUNK_SIZE} bytes, then the end
     * marker. An {@link IOException} while relaying the body is reported through
     * {@code errorResponse} instead of the end marker.
     */
    private static void streamResponse(HttpResponse<InputStream> response, LlmInferenceExchange exchange)
            throws IOException {
        exchange.startResponse(response.statusCode(), null, response.headers().map());
        try (InputStream body = response.body()) {
            byte[] buffer = new byte[RESPONSE_CHUNK_SIZE];
            int n;
            while ((n = body.read(buffer)) != -1) {
                if (n > 0) {
                    // Copy so each frame owns its bytes; the buffer is reused.
                    byte[] frame = new byte[n];
                    System.arraycopy(buffer, 0, frame, 0, n);
                    exchange.writeResponseBinary(frame);
                }
            }
        } catch (IOException e) {
            exchange.errorResponse(e.getMessage(), null);
            return;
        }
        exchange.endResponse();
    }

    /**
     * Drives one WebSocket exchange: opens the consumer's handler, acknowledges
     * the upgrade, pumps runtime-to-upstream messages on a daemon thread, and
     * waits until either the pump or the handler finishes.
     */
    private void handleWebSocket(LlmInferenceExchange exchange) throws Exception {
        CopilotRequestContext ctx = exchange.context();
        LlmWebSocketResponseBridge bridge = new LlmWebSocketResponseBridge(exchange);
        // Must be attached before openWebSocket: the handler base class reads it
        // in its constructor.
        ctx.setWebSocketResponse(bridge);

        CopilotWebSocketHandlerBase handler = openWebSocket(ctx);
        try {
            handler.open();

            // The runtime blocks the WebSocket connect until it receives the 101
            // response head (the upgrade acknowledgement) and only then begins
            // forwarding inbound messages as request-body chunks. Emit it eagerly
            // here — waiting for the first upstream message would deadlock, since the
            // upstream stays silent until it receives a request message the runtime
            // won't send before the upgrade completes.
            bridge.start();

            CompletableFuture<Void> pumpDone = new CompletableFuture<>();
            Thread pump = new Thread(() -> {
                try {
                    LlmInferenceExchange.BodyFrame frame;
                    while ((frame = exchange.readFrame()) != null) {
                        handler.sendRequestMessage(new CopilotWebSocketMessage(frame.data(), frame.binary()));
                    }
                    pumpDone.complete(null);
                } catch (Exception e) {
                    pumpDone.completeExceptionally(e);
                }
            }, "llm-ws-request-pump");
            pump.setDaemon(true);
            pump.start();

            // Wait for whichever side finishes first; failures are inspected below.
            CompletableFuture.anyOf(pumpDone, handler.completion()).handle((v, t) -> null).join();

            if (pumpDone.isDone() && !handler.completion().isDone()) {
                if (isPumpFault(pumpDone)) {
                    // Skip the clean close in finally and rethrow the pump failure.
                    handler.suppressCloseOnDispose();
                    awaitPump(pumpDone);
                    return;
                }
                // The runtime finished sending: close cleanly and wait for it.
                handler.close(CopilotWebSocketCloseStatus.NORMAL_CLOSURE);
                handler.completion().join();
                return;
            }

            CopilotWebSocketCloseStatus status = handler.completion().join();
            if (status.error() != null) {
                throw asException(status.error());
            }
        } finally {
            handler.close();
        }
    }

    /** Whether the request pump ended with an exception. */
    private static boolean isPumpFault(CompletableFuture<Void> pumpDone) {
        return pumpDone.isCompletedExceptionally();
    }

    /** Waits for the pump and rethrows its failure with the wrapper removed. */
    private static void awaitPump(CompletableFuture<Void> pumpDone) throws Exception {
        try {
            pumpDone.join();
        } catch (CancellationException e) {
            throw e;
        } catch (Exception e) {
            throw asException(e.getCause() != null ? e.getCause() : e);
        }
    }

    /** Returns {@code t} itself when it is an exception, otherwise wraps it. */
    private static Exception asException(Throwable t) {
        return t instanceof Exception e ? e : new RuntimeException(t);
    }
}
diff --git a/java/src/main/java/com/github/copilot/CopilotRequestTransport.java b/java/src/main/java/com/github/copilot/CopilotRequestTransport.java new file mode 100644 index 000000000..e1069de0b --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotRequestTransport.java @@ -0,0 +1,44 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +/** + * The transport the runtime would otherwise use to issue an intercepted + * model-layer request. + * + * @since 1.0.0 + */ +public enum CopilotRequestTransport { + + /** + * Plain HTTP or a streamed SSE response. Each request/response body chunk is an + * opaque byte range. + */ + HTTP, + + /** + * Full-duplex WebSocket channel. Each request-body chunk is one inbound + * WebSocket message and each response-body write is one outbound message. + */ + WEBSOCKET; + + /** The wire value for the plain HTTP and SSE transport. */ + static final String WIRE_HTTP = "http"; + + /** The wire value for the full-duplex WebSocket transport. */ + static final String WIRE_WEBSOCKET = "websocket"; + + /** + * Maps a wire transport string onto the enum, defaulting to {@link #HTTP} for + * {@code null} or any unrecognised value. + * + * @param wire + * the wire transport value + * @return the transport + */ + static CopilotRequestTransport fromWire(String wire) { + return WIRE_WEBSOCKET.equals(wire) ?
WEBSOCKET : HTTP; + } +} diff --git a/java/src/main/java/com/github/copilot/CopilotWebSocketCloseStatus.java b/java/src/main/java/com/github/copilot/CopilotWebSocketCloseStatus.java new file mode 100644 index 000000000..6ced0182e --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotWebSocketCloseStatus.java @@ -0,0 +1,66 @@
/*---------------------------------------------------------------------------------------------
 * Copyright (c) Microsoft Corporation. All rights reserved.
 *--------------------------------------------------------------------------------------------*/

package com.github.copilot;

/**
 * Describes how a callback-owned WebSocket connection ended: cleanly, or with
 * an error.
 *
 * @since 1.0.0
 */
public final class CopilotWebSocketCloseStatus {

    /** Shared instance for a clean end-of-stream: all three values are {@code null}. */
    public static final CopilotWebSocketCloseStatus NORMAL_CLOSURE = new CopilotWebSocketCloseStatus(null, null, null);

    private final String description;
    private final String errorCode;
    private final Throwable error;

    /**
     * Builds a close status from its three optional parts.
     *
     * @param description
     *            human-readable close description, or {@code null}
     * @param errorCode
     *            machine-readable error code surfaced to the runtime when the
     *            close is a failure, or {@code null}
     * @param error
     *            the error that ended the connection, or {@code null} when the
     *            close was clean
     */
    public CopilotWebSocketCloseStatus(String description, String errorCode, Throwable error) {
        this.error = error;
        this.errorCode = errorCode;
        this.description = description;
    }

    /**
     * Returns the error that ended the connection.
     *
     * @return the error, or {@code null} for a clean close
     */
    public Throwable error() {
        return this.error;
    }

    /**
     * Returns the optional error code surfaced to the runtime when the close is
     * a failure rather than a clean end-of-stream.
     *
     * @return the error code, or {@code null}
     */
    public String errorCode() {
        return this.errorCode;
    }

    /**
     * Returns the close description, if one was supplied.
     *
     * @return the description, or {@code null}
     */
    public String description() {
        return this.description;
    }
}
diff --git a/java/src/main/java/com/github/copilot/CopilotWebSocketHandler.java b/java/src/main/java/com/github/copilot/CopilotWebSocketHandler.java new file mode 100644 index 000000000..71aac9ecd --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotWebSocketHandler.java @@ -0,0 +1,197 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.ByteArrayOutputStream; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.WebSocket; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletionStage; + +/** + * The default pass-through {@link CopilotWebSocketHandlerBase}: it dials the + * real upstream using {@link java.net.http.WebSocket} and relays + * upstream-to-runtime messages into the runtime response unchanged. + *

+ * Subclass and override {@link #sendRequestMessage} or + * {@link #sendResponseMessage} (calling {@code super}) to observe, transform, + * or drop messages in either direction. + * + * @since 1.0.0 + */ +public class CopilotWebSocketHandler extends CopilotWebSocketHandlerBase { + + private final String url; + private final Map> headers; + + private volatile WebSocket webSocket; + + /** + * Creates a forwarding handler targeting the request URL and headers from + * {@code context}. + * + * @param context + * the per-request context + */ + public CopilotWebSocketHandler(CopilotRequestContext context) { + this(context, context.url(), context.headers()); + } + + /** + * Creates a forwarding handler targeting {@code url} with the handshake headers + * from {@code context}. + * + * @param context + * the per-request context + * @param url + * the upstream WebSocket URL + */ + public CopilotWebSocketHandler(CopilotRequestContext context, String url) { + this(context, url, context.headers()); + } + + /** + * Creates a forwarding handler targeting {@code url} with the given handshake + * headers. 
+ * + * @param context + * the per-request context + * @param url + * the upstream WebSocket URL + * @param headers + * the handshake headers, multi-valued + */ + public CopilotWebSocketHandler(CopilotRequestContext context, String url, Map> headers) { + super(context); + this.url = url; + this.headers = headers; + } + + @Override + void open() throws Exception { + if (webSocket != null) { + return; + } + WebSocket.Builder builder = HttpClient.newHttpClient().newWebSocketBuilder(); + if (headers != null) { + for (Map.Entry> entry : headers.entrySet()) { + if (CopilotRequestHandler.isForbiddenRequestHeader(entry.getKey()) || entry.getValue() == null) { + continue; + } + for (String value : entry.getValue()) { + builder.header(entry.getKey(), value); + } + } + } + try { + this.webSocket = builder.buildAsync(URI.create(normalizeWebSocketScheme(url)), new ForwardingListener()) + .join(); + } catch (Exception e) { + throw unwrap(e); + } + } + + @Override + public void sendRequestMessage(CopilotWebSocketMessage message) throws Exception { + WebSocket ws = this.webSocket; + if (ws == null) { + return; + } + if (message.binary()) { + ws.sendBinary(ByteBuffer.wrap(message.data()), true).join(); + } else { + ws.sendText(message.text(), true).join(); + } + } + + @Override + public void close(CopilotWebSocketCloseStatus status) throws Exception { + WebSocket ws = this.webSocket; + if (ws != null && !ws.isOutputClosed()) { + ws.sendClose(WebSocket.NORMAL_CLOSURE, "").exceptionally(ex -> null); + } + super.close(status); + } + + private void forward(byte[] data, boolean binary) { + try { + sendResponseMessage(new CopilotWebSocketMessage(data, binary)); + } catch (Exception e) { + completion().completeExceptionally(e); + } + } + + private static String normalizeWebSocketScheme(String url) { + if (url.startsWith("http://")) { + return "ws://" + url.substring("http://".length()); + } + if (url.startsWith("https://")) { + return "wss://" + url.substring("https://".length()); + } + 
return url; + } + + private static Exception unwrap(Exception e) { + Throwable cause = e.getCause(); + if (cause instanceof Exception ex) { + return ex; + } + return e; + } + + private final class ForwardingListener implements WebSocket.Listener { + + private final StringBuilder textBuffer = new StringBuilder(); + private final ByteArrayOutputStream binaryBuffer = new ByteArrayOutputStream(); + + @Override + public void onOpen(WebSocket webSocket) { + webSocket.request(Long.MAX_VALUE); + } + + @Override + public CompletionStage onText(WebSocket webSocket, CharSequence data, boolean last) { + textBuffer.append(data); + if (last) { + byte[] message = textBuffer.toString().getBytes(StandardCharsets.UTF_8); + textBuffer.setLength(0); + forward(message, false); + } + return null; + } + + @Override + public CompletionStage onBinary(WebSocket webSocket, ByteBuffer data, boolean last) { + byte[] chunk = new byte[data.remaining()]; + data.get(chunk); + binaryBuffer.writeBytes(chunk); + if (last) { + byte[] message = binaryBuffer.toByteArray(); + binaryBuffer.reset(); + forward(message, true); + } + return null; + } + + @Override + public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { + close(); + return null; + } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + try { + close(new CopilotWebSocketCloseStatus(error.getMessage(), null, error)); + } catch (Exception e) { + completion().completeExceptionally(e); + } + } + } +} diff --git a/java/src/main/java/com/github/copilot/CopilotWebSocketHandlerBase.java b/java/src/main/java/com/github/copilot/CopilotWebSocketHandlerBase.java new file mode 100644 index 000000000..9eba8162a --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotWebSocketHandlerBase.java @@ -0,0 +1,119 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A per-connection WebSocket handler returned by + * {@link CopilotRequestHandler#openWebSocket}. + *

+ * The default implementation is {@link CopilotWebSocketHandler}, which dials + * the real upstream and transparently relays messages in both directions. A + * full transport replacement subclasses this type directly and brings its own + * transport and receive loop, forwarding upstream-to-runtime messages by + * calling {@link #sendResponseMessage} and finishing with + * {@link #close(CopilotWebSocketCloseStatus)}. + * + * @since 1.0.0 + */ +public abstract class CopilotWebSocketHandlerBase implements AutoCloseable { + + private final LlmWebSocketResponseBridge response; + private final CompletableFuture completion = new CompletableFuture<>(); + private final AtomicBoolean closed = new AtomicBoolean(); + private volatile boolean suppressCloseOnDispose; + + /** The request context for this WebSocket connection. */ + protected final CopilotRequestContext context; + + /** + * Initializes a per-connection handler for the supplied request context. + * + * @param context + * the per-request context + */ + protected CopilotWebSocketHandlerBase(CopilotRequestContext context) { + this.context = context; + this.response = Objects.requireNonNull(context.webSocketResponse(), + "WebSocket response bridge is not attached"); + } + + /** + * Sends a message from the runtime to the upstream connection. + * + * @param message + * the message to forward upstream + * @throws Exception + * if the message could not be forwarded + */ + public abstract void sendRequestMessage(CopilotWebSocketMessage message) throws Exception; + + /** + * Sends a message from the upstream connection back to the runtime. Override to + * mutate or duplicate messages; call {@code super} to emit. 
+ * + * @param message + * the upstream-to-runtime message + * @throws Exception + * if the message could not be delivered + */ + public void sendResponseMessage(CopilotWebSocketMessage message) throws Exception { + response.write(message); + } + + /** + * Closes the connection and finalises the runtime-facing response. Idempotent. + * + * @param status + * the terminal status; a non-null + * {@link CopilotWebSocketCloseStatus#error()} surfaces a transport + * failure, otherwise a clean end-of-stream + * @throws Exception + * if the terminal frame could not be delivered + */ + public void close(CopilotWebSocketCloseStatus status) throws Exception { + if (!closed.compareAndSet(false, true)) { + return; + } + if (status.error() != null) { + response.error(status.description() != null ? status.description() : status.error().getMessage(), + status.errorCode()); + } else { + response.end(); + } + completion.complete(status); + } + + /** + * Tears down the connection, finalising with a normal closure unless the + * connection has already been closed or close-on-dispose was suppressed. + */ + @Override + public void close() { + if (!suppressCloseOnDispose && !closed.get()) { + try { + close(CopilotWebSocketCloseStatus.NORMAL_CLOSURE); + } catch (Exception ignored) { + // Best-effort teardown; the connection may already be gone. + } + } + } + + CompletableFuture completion() { + return completion; + } + + void suppressCloseOnDispose() { + suppressCloseOnDispose = true; + } + + void open() throws Exception { + // Default: nothing to establish. CopilotWebSocketHandler dials + // the upstream here. 
+ } +} diff --git a/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java b/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java new file mode 100644 index 000000000..921ab01db --- /dev/null +++ b/java/src/main/java/com/github/copilot/CopilotWebSocketMessage.java @@ -0,0 +1,52 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.nio.charset.StandardCharsets; + +/** + * A single WebSocket message exchanged through a + * {@link CopilotWebSocketHandlerBase} hook. + * + * @param data + * the message payload bytes + * @param binary + * {@code true} for a binary frame, {@code false} for a UTF-8 text + * frame + * @since 1.0.0 + */ +public record CopilotWebSocketMessage(byte[] data, boolean binary) { + + /** + * Decodes the payload as UTF-8 text. + * + * @return the payload as text + */ + public String text() { + return new String(data, StandardCharsets.UTF_8); + } + + /** + * Creates a text message from a UTF-8 string. + * + * @param text + * the text payload + * @return a text message + */ + public static CopilotWebSocketMessage text(String text) { + return new CopilotWebSocketMessage(text.getBytes(StandardCharsets.UTF_8), false); + } + + /** + * Creates a binary message from raw bytes. 
+ * + * @param data + * the binary payload + * @return a binary message + */ + public static CopilotWebSocketMessage binary(byte[] data) { + return new CopilotWebSocketMessage(data, true); + } +} diff --git a/java/src/main/java/com/github/copilot/LlmInferenceAdapter.java b/java/src/main/java/com/github/copilot/LlmInferenceAdapter.java new file mode 100644 index 000000000..9087df6c1 --- /dev/null +++ b/java/src/main/java/com/github/copilot/LlmInferenceAdapter.java @@ -0,0 +1,203 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Base64; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.RejectedExecutionException; +import java.util.function.Supplier; +import java.util.logging.Level; +import java.util.logging.Logger; + +import com.fasterxml.jackson.databind.JsonNode; +import com.github.copilot.generated.rpc.ServerLlmInferenceApi; + +/** + * Adapts the generated {@code llmInference.*} reverse-RPC entry points onto a + * consumer's {@link CopilotRequestHandler}. Each {@code httpRequestStart} + * allocates an {@link LlmInferenceExchange} and runs the handler in the + * background; subsequent {@code httpRequestChunk} frames feed its request body + * stream. 
+ */ +final class LlmInferenceAdapter { + + private static final Logger LOG = Logger.getLogger(LlmInferenceAdapter.class.getName()); + + private final CopilotRequestHandler handler; + private final Supplier rpcSupplier; + private final Executor executor; + + private final Map pending = new ConcurrentHashMap<>(); + + LlmInferenceAdapter(CopilotRequestHandler handler, Supplier rpcSupplier, Executor executor) { + this.handler = handler; + this.rpcSupplier = rpcSupplier; + this.executor = executor; + } + + void registerHandlers(JsonRpcClient rpc) { + rpc.registerMethodHandler("llmInference.httpRequestStart", + (rpcId, params) -> handleRequestStart(rpc, rpcId, params)); + rpc.registerMethodHandler("llmInference.httpRequestChunk", + (rpcId, params) -> handleRequestChunk(rpc, rpcId, params)); + } + + private LlmInferenceExchange getOrCreateExchange(String requestId) { + // The runtime dispatches httpRequestStart and httpRequestChunk frames + // independently. Even though the current reader dispatches them in + // order, get-or-create keeps the adapter correct regardless: a body + // chunk (including the terminal end frame) that races ahead of its + // start frame is buffered into the same exchange rather than dropped, + // which would otherwise hang the body drain forever. + return pending.computeIfAbsent(requestId, id -> new LlmInferenceExchange(id, rpcSupplier)); + } + + private void handleRequestStart(JsonRpcClient rpc, String rpcId, JsonNode params) { + String requestId = params.get("requestId").asText(); + String sessionId = textOrNull(params, "sessionId"); + String method = textOrNull(params, "method"); + String url = textOrNull(params, "url"); + CopilotRequestTransport transport = CopilotRequestTransport.fromWire(textOrNull(params, "transport")); + Map> headers = parseHeaders(params.get("headers")); + + // Adopt any exchange a racing chunk already created — with its buffered + // body — rather than dropping those frames. 
+ LlmInferenceExchange exchange = getOrCreateExchange(requestId); + exchange.setMethod(method); + exchange.setContext( + new CopilotRequestContext(requestId, sessionId, transport, url, headers, exchange.cancellation())); + + // Return from httpRequestStart immediately (after registering state) so the + // runtime's RPC reply is not gated on the consumer's I/O. The actual handler + // work runs asynchronously. + runAsync(() -> runHandler(exchange)); + + ack(rpc, rpcId); + } + + private void handleRequestChunk(JsonRpcClient rpc, String rpcId, JsonNode params) { + String requestId = params.get("requestId").asText(); + // May arrive before the matching start frame; get-or-create so the body + // is buffered, never lost. + LlmInferenceExchange exchange = getOrCreateExchange(requestId); + routeChunk(exchange, params); + ack(rpc, rpcId); + } + + private static void routeChunk(LlmInferenceExchange exchange, JsonNode params) { + if (boolOr(params, "cancel")) { + exchange.pushCancel(); + return; + } + String data = textOr(params, "data", ""); + boolean binary = boolOr(params, "binary"); + if (!data.isEmpty()) { + byte[] bytes = binary ? Base64.getDecoder().decode(data) : data.getBytes(StandardCharsets.UTF_8); + exchange.pushChunk(bytes, binary); + } + if (boolOr(params, "end")) { + exchange.pushEnd(); + } + } + + private void runHandler(LlmInferenceExchange exchange) { + try { + handler.handle(exchange); + if (!exchange.finished()) { + finalizeError(exchange, 502, "LLM inference handler returned without finalising the response " + + "(call endResponse() or errorResponse())", null); + } + } catch (Exception e) { + if (exchange.cancelled() || exchange.cancellation().isDone()) { + // The runtime already cancelled this request; the handler's throw is + // just the abort propagating out of its upstream call. + finalizeError(exchange, 499, "Request cancelled by runtime", "cancelled"); + } else { + String message = e.getMessage() != null ? 
e.getMessage() : e.toString(); + finalizeError(exchange, 502, message, null); + } + } finally { + pending.remove(exchange.requestId()); + } + } + + private static void finalizeError(LlmInferenceExchange exchange, int status, String message, String code) { + if (exchange.finished()) { + return; + } + try { + if (!exchange.started()) { + exchange.startResponse(status, null, null); + } + exchange.errorResponse(message, code); + } catch (IOException e) { + LOG.log(Level.FINE, "Failed to deliver LLM inference failure", e); + } + } + + private void ack(JsonRpcClient rpc, String rpcId) { + long id; + try { + id = Long.parseLong(rpcId); + } catch (NumberFormatException e) { + return; + } + try { + rpc.sendResponse(id, Map.of()); + } catch (IOException e) { + LOG.log(Level.FINE, "Failed to acknowledge LLM inference frame", e); + } + } + + private void runAsync(Runnable task) { + try { + if (executor != null) { + CompletableFuture.runAsync(task, executor); + } else { + CompletableFuture.runAsync(task); + } + } catch (RejectedExecutionException e) { + LOG.log(Level.WARNING, "Executor rejected LLM inference task; running inline", e); + task.run(); + } + } + + private static String textOrNull(JsonNode params, String field) { + return params.has(field) && !params.get(field).isNull() ? params.get(field).asText() : null; + } + + private static String textOr(JsonNode params, String field, String fallback) { + return params.has(field) && !params.get(field).isNull() ? 
params.get(field).asText() : fallback; + } + + private static boolean boolOr(JsonNode params, String field) { + return params.has(field) && !params.get(field).isNull() && params.get(field).asBoolean(); + } + + private static Map> parseHeaders(JsonNode node) { + Map> result = new LinkedHashMap<>(); + if (node != null && node.isObject()) { + node.properties().forEach(entry -> { + List values = new ArrayList<>(); + JsonNode value = entry.getValue(); + if (value.isArray()) { + value.forEach(item -> values.add(item.asText())); + } else if (!value.isNull()) { + values.add(value.asText()); + } + result.put(entry.getKey(), values); + }); + } + return result; + } +} diff --git a/java/src/main/java/com/github/copilot/LlmInferenceExchange.java b/java/src/main/java/com/github/copilot/LlmInferenceExchange.java new file mode 100644 index 000000000..9c2bbe40c --- /dev/null +++ b/java/src/main/java/com/github/copilot/LlmInferenceExchange.java @@ -0,0 +1,256 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.function.Supplier; + +import com.github.copilot.generated.rpc.LlmInferenceHttpResponseChunkError; +import com.github.copilot.generated.rpc.LlmInferenceHttpResponseChunkParams; +import com.github.copilot.generated.rpc.LlmInferenceHttpResponseStartParams; +import com.github.copilot.generated.rpc.ServerLlmInferenceApi; + +/** + * One intercepted request in flight. Carries the request context plus the body + * byte stream the runtime feeds in via {@code httpRequestChunk} frames, and + * emits the consumer's response straight back to the runtime through the + * generated {@code llmInference} server API. + *

+ * This is the single object the {@link LlmInferenceAdapter} owns and the + * {@link CopilotRequestHandler} writes to, replacing the former + * provider/sink/request-body/response-channel indirection. The response state + * machine is strict: {@link #startResponse} once, then zero or more + * {@code writeResponse*} calls, finishing with exactly one of + * {@link #endResponse} or {@link #errorResponse}. + */ +final class LlmInferenceExchange { + + /** + * A single request body frame. + * + * @param data + * the frame bytes + * @param binary + * {@code true} when delivered as binary, {@code false} for UTF-8 + * text + */ + record BodyFrame(byte[] data, boolean binary) { + } + + private enum ItemKind { + CHUNK, END, CANCEL + } + + private record BodyItem(ItemKind kind, byte[] data, boolean binary) { + } + + private final String requestId; + private String method; + private final Supplier rpcSupplier; + + private final BlockingQueue body = new LinkedBlockingQueue<>(); + private final CompletableFuture cancellation = new CompletableFuture<>(); + + private final Object lock = new Object(); + private boolean started; + private boolean finished; + private boolean cancelled; + + private CopilotRequestContext context; + + LlmInferenceExchange(String requestId, Supplier rpcSupplier) { + this.requestId = requestId; + this.rpcSupplier = rpcSupplier; + } + + String requestId() { + return requestId; + } + + String method() { + return method; + } + + void setMethod(String method) { + this.method = method; + } + + CompletableFuture cancellation() { + return cancellation; + } + + CopilotRequestContext context() { + return context; + } + + void setContext(CopilotRequestContext context) { + this.context = context; + } + + boolean started() { + synchronized (lock) { + return started; + } + } + + boolean finished() { + synchronized (lock) { + return finished; + } + } + + boolean cancelled() { + synchronized (lock) { + return cancelled; + } + } + + // --- Request body feed (driven by 
the adapter as chunk frames arrive) --- + + void pushChunk(byte[] data, boolean binary) { + body.add(new BodyItem(ItemKind.CHUNK, data, binary)); + } + + void pushEnd() { + body.add(new BodyItem(ItemKind.END, null, false)); + } + + void pushCancel() { + synchronized (lock) { + cancelled = true; + } + if (!cancellation.isDone()) { + cancellation.complete(null); + } + body.add(new BodyItem(ItemKind.CANCEL, null, false)); + } + + /** + * Reads the next request body frame, blocking until one is available. + * + * @return the next frame, or {@code null} when the body has ended + * @throws InterruptedException + * if interrupted while waiting + * @throws CancellationException + * if the runtime cancelled the request + */ + BodyFrame readFrame() throws InterruptedException { + BodyItem item = body.take(); + switch (item.kind()) { + case CANCEL -> { + // Re-arm the sentinel so subsequent reads keep failing fast. + body.add(item); + throw new CancellationException("Request cancelled by runtime"); + } + case END -> { + body.add(item); + return null; + } + default -> { + return new BodyFrame(item.data(), item.binary()); + } + } + } + + byte[] drainBody() throws InterruptedException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + BodyFrame frame; + while ((frame = readFrame()) != null) { + out.writeBytes(frame.data()); + } + return out.toByteArray(); + } + + // --- Response emit (driven by the handler) --- + + void startResponse(int status, String statusText, Map> headers) throws IOException { + synchronized (lock) { + if (started) { + throw new IOException("LLM inference response startResponse() called twice"); + } + if (finished) { + throw new IOException("LLM inference response already finished"); + } + started = true; + } + var params = new LlmInferenceHttpResponseStartParams(requestId, (long) status, statusText, headers); + join(api().httpResponseStart(params)); + } + + void writeResponseText(String text) throws IOException { + writeChunk(text, false); + } 
+ + void writeResponseBinary(byte[] data) throws IOException { + writeChunk(Base64.getEncoder().encodeToString(data), true); + } + + void endResponse() throws IOException { + synchronized (lock) { + if (finished) { + return; + } + finished = true; + } + var params = new LlmInferenceHttpResponseChunkParams(requestId, "", null, Boolean.TRUE, null); + join(api().httpResponseChunk(params)); + } + + void errorResponse(String message, String code) throws IOException { + synchronized (lock) { + if (finished) { + return; + } + finished = true; + } + var error = new LlmInferenceHttpResponseChunkError(message, code); + var params = new LlmInferenceHttpResponseChunkParams(requestId, "", null, Boolean.TRUE, error); + join(api().httpResponseChunk(params)); + } + + private void writeChunk(String data, boolean binary) throws IOException { + synchronized (lock) { + if (cancelled) { + throw new IOException("LLM inference request was cancelled by the runtime"); + } + if (!started) { + throw new IOException("LLM inference response writeResponse() called before startResponse()"); + } + if (finished) { + throw new IOException( + "LLM inference response writeResponse() called after endResponse()/errorResponse()"); + } + } + var params = new LlmInferenceHttpResponseChunkParams(requestId, data, binary ? Boolean.TRUE : null, + Boolean.FALSE, null); + join(api().httpResponseChunk(params)); + } + + private ServerLlmInferenceApi api() throws IOException { + ServerLlmInferenceApi api = rpcSupplier.get(); + if (api == null) { + throw new IOException("LLM inference response used after RPC connection closed"); + } + return api; + } + + private static T join(CompletableFuture future) throws IOException { + try { + return future.join(); + } catch (CompletionException | CancellationException e) { + Throwable cause = e.getCause() != null ? 
e.getCause() : e; + throw new IOException(cause.getMessage(), cause); + } + } +} diff --git a/java/src/main/java/com/github/copilot/LlmWebSocketResponseBridge.java b/java/src/main/java/com/github/copilot/LlmWebSocketResponseBridge.java new file mode 100644 index 000000000..b7bbbd8c7 --- /dev/null +++ b/java/src/main/java/com/github/copilot/LlmWebSocketResponseBridge.java @@ -0,0 +1,73 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.IOException; + +/** + * Forwards upstream WebSocket messages back to the owning + * {@link LlmInferenceExchange}. The {@code 101} upgrade head is emitted eagerly + * via {@link #start()} (the runtime gates the WebSocket connect on it); + * thereafter writes are serialised so the head always precedes any body or + * terminal frame. + */ +final class LlmWebSocketResponseBridge { + + private final LlmInferenceExchange exchange; + private final Object lock = new Object(); + private boolean started; + private boolean completed; + + LlmWebSocketResponseBridge(LlmInferenceExchange exchange) { + this.exchange = exchange; + } + + /** + * Emits the {@code 101} upgrade head now, acknowledging the WebSocket connect. 
+ */ + void start() throws IOException { + run(false, () -> { + }); + } + + void write(CopilotWebSocketMessage message) throws IOException { + run(false, () -> { + if (message.binary()) { + exchange.writeResponseBinary(message.data()); + } else { + exchange.writeResponseText(message.text()); + } + }); + } + + void end() throws IOException { + run(true, exchange::endResponse); + } + + void error(String message, String code) throws IOException { + run(true, () -> exchange.errorResponse(message, code)); + } + + private void run(boolean terminal, IoAction action) throws IOException { + synchronized (lock) { + if (completed) { + return; + } + if (!started) { + started = true; + exchange.startResponse(101, null, null); + } + if (terminal) { + completed = true; + } + action.run(); + } + } + + @FunctionalInterface + private interface IoAction { + void run() throws IOException; + } +} diff --git a/java/src/main/java/com/github/copilot/rpc/CopilotClientOptions.java b/java/src/main/java/com/github/copilot/rpc/CopilotClientOptions.java index 515d1488b..e9f59aa64 100644 --- a/java/src/main/java/com/github/copilot/rpc/CopilotClientOptions.java +++ b/java/src/main/java/com/github/copilot/rpc/CopilotClientOptions.java @@ -15,6 +15,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.github.copilot.CopilotRequestHandler; import java.util.Optional; import java.util.OptionalInt; @@ -55,6 +56,7 @@ public class CopilotClientOptions { private String logLevel = "info"; private CopilotClientMode mode = CopilotClientMode.COPILOT_CLI; private Supplier>> onListModels; + private CopilotRequestHandler requestHandler; private int port; private TelemetryConfig telemetry; private Integer sessionIdleTimeoutSeconds; @@ -454,6 +456,34 @@ public CopilotClientOptions setOnListModels(Supplier + * When provided, the client registers as the runtime's LLM inference provider + * on connect, and the runtime routes its model-layer HTTP and 
WebSocket traffic + * (both BYOK and CAPI) through the handler instead of issuing the calls itself. + * + * @param requestHandler + * the request handler (must not be {@code null}) + * @return this options instance for method chaining + * @throws IllegalArgumentException + * if {@code requestHandler} is {@code null} + */ + public CopilotClientOptions setRequestHandler(CopilotRequestHandler requestHandler) { + this.requestHandler = Objects.requireNonNull(requestHandler, "requestHandler must not be null"); + return this; + } + /** * Gets the TCP port for the CLI server. * @@ -689,6 +719,7 @@ public CopilotClientOptions clone() { copy.gitHubToken = this.gitHubToken; copy.logLevel = this.logLevel; copy.onListModels = this.onListModels; + copy.requestHandler = this.requestHandler; copy.port = this.port; copy.remote = this.remote; copy.sessionIdleTimeoutSeconds = this.sessionIdleTimeoutSeconds; diff --git a/java/src/main/java/module-info.java b/java/src/main/java/module-info.java index 81dd5ae2f..9f48b3747 100644 --- a/java/src/main/java/module-info.java +++ b/java/src/main/java/module-info.java @@ -12,7 +12,7 @@ requires com.fasterxml.jackson.datatype.jsr310; requires static com.github.spotbugs.annotations; requires static java.compiler; - requires static java.net.http; + requires java.net.http; requires java.logging; exports com.github.copilot; diff --git a/java/src/test/java/com/github/copilot/CopilotClientTest.java b/java/src/test/java/com/github/copilot/CopilotClientTest.java index 7c97a3886..d977563ae 100644 --- a/java/src/test/java/com/github/copilot/CopilotClientTest.java +++ b/java/src/test/java/com/github/copilot/CopilotClientTest.java @@ -58,7 +58,12 @@ void testStopRequestsRuntimeShutdownForOwnedProcess() throws Exception { verify(rpc).invoke(eq("runtime.shutdown"), eq(Map.of()), eq(Void.class)); verify(rpc).close(); - verify(process, never()).destroy(); + // The runtime never self-exits after runtime.shutdown (it keeps its + // JSON-RPC server alive to send 
the response and leaves termination to + // the caller), so stop() terminates the owned process. The mocked + // process exits on the first SIGTERM (waitFor returns true), so we + // never escalate to destroyForcibly(). + verify(process).destroy(); verify(process, never()).destroyForcibly(); } diff --git a/java/src/test/java/com/github/copilot/CopilotRequestCancelErrorE2ETest.java b/java/src/test/java/com/github/copilot/CopilotRequestCancelErrorE2ETest.java new file mode 100644 index 000000000..7d7ae5d70 --- /dev/null +++ b/java/src/test/java/com/github/copilot/CopilotRequestCancelErrorE2ETest.java @@ -0,0 +1,152 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static com.github.copilot.CopilotRequestTestSupport.buildNonInferenceResponse; +import static com.github.copilot.CopilotRequestTestSupport.isInferenceUrl; +import static com.github.copilot.CopilotRequestTestSupport.newLlmClient; +import static com.github.copilot.CopilotRequestTestSupport.setupCapiAuth; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.InputStream; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.SessionConfig; + +/** + * Cancellation and error coverage for {@link CopilotRequestHandler}. 
These two + * scenarios exercise the handler's terminal paths the happy-path session-id and + * forwarding tests never reach: + *

    + *
  • Error — the handler throws from + * {@link CopilotRequestHandler#sendRequest} for an inference request. The base + * adapter reports a transport error back to the runtime rather than + * hanging.
  • + *
  • Runtime cancel — the handler blocks an inference request + * indefinitely; when the consumer aborts the turn the runtime cancels the + * in-flight request, firing {@link CopilotRequestContext#cancellation()}. The + * handler observes the abort instead of leaking a stuck request.
  • + *
+ */ +public class CopilotRequestCancelErrorE2ETest { + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + /** Throws from every inference request to exercise the error-reporting path. */ + private static final class ThrowingRequestHandler extends CopilotRequestHandler { + + private final AtomicInteger inferenceAttempts = new AtomicInteger(); + + @Override + protected HttpResponse sendRequest(HttpRequest request, CopilotRequestContext rctx) { + String url = request.uri().toString(); + if (!isInferenceUrl(url)) { + return buildNonInferenceResponse(url); + } + inferenceAttempts.incrementAndGet(); + throw new IllegalStateException("synthetic-callback-transport-failure"); + } + } + + /** Blocks every inference request until the runtime cancels it. */ + private static final class CancellingRequestHandler extends CopilotRequestHandler { + + private volatile boolean inferenceEntered; + private volatile boolean sawAbort; + + @Override + protected HttpResponse sendRequest(HttpRequest request, CopilotRequestContext rctx) { + String url = request.uri().toString(); + if (!isInferenceUrl(url)) { + return buildNonInferenceResponse(url); + } + inferenceEntered = true; + try { + // Never produce a response; wait for the runtime to cancel us. + rctx.cancellation().join(); + } catch (CancellationException | java.util.concurrent.CompletionException e) { + // The cancellation future completes normally on cancel; this guards + // against any exceptional completion too. 
+ } + sawAbort = true; + throw new CancellationException("Request cancelled by runtime"); + } + } + + @Test + void reportsThrownHandlerErrorInsteadOfHanging() throws Exception { + setupCapiAuth(ctx); + ThrowingRequestHandler handler = new ThrowingRequestHandler(); + + try (CopilotClient client = newLlmClient(ctx, handler)) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(); + + // The handler throws on inference; the turn surfaces an error (or completes + // without an assistant message) rather than hanging. + try { + session.sendAndWait(new MessageOptions().setPrompt("Say OK.")).get(60, TimeUnit.SECONDS); + } catch (Exception ignored) { + // Expected: the inference callback raised. + } + session.close(); + } + + assertTrue(handler.inferenceAttempts.get() > 0, "Expected the inference callback to be reached and raise"); + } + + @Test + void observesRuntimeCancellationOfInFlightInference() throws Exception { + setupCapiAuth(ctx); + CancellingRequestHandler handler = new CancellingRequestHandler(); + + try (CopilotClient client = newLlmClient(ctx, handler)) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(); + + session.send(new MessageOptions().setPrompt("Say OK.")).get(60, TimeUnit.SECONDS); + waitFor(() -> handler.inferenceEntered, 60_000); + session.abort().get(30, TimeUnit.SECONDS); + waitFor(() -> handler.sawAbort, 30_000); + session.close(); + } + + assertTrue(handler.inferenceEntered, "Expected the inference callback to be entered"); + assertTrue(handler.sawAbort, "Expected the callback to observe runtime cancellation"); + } + + private static void waitFor(java.util.function.BooleanSupplier predicate, long timeoutMillis) + throws InterruptedException { + long deadline = System.currentTimeMillis() + timeoutMillis; + while (!predicate.getAsBoolean()) { + if (System.currentTimeMillis() > deadline) { 
+ throw new AssertionError("waitFor timed out"); + } + Thread.sleep(50); + } + } +} diff --git a/java/src/test/java/com/github/copilot/CopilotRequestHandlerE2ETest.java b/java/src/test/java/com/github/copilot/CopilotRequestHandlerE2ETest.java new file mode 100644 index 000000000..fd3460119 --- /dev/null +++ b/java/src/test/java/com/github/copilot/CopilotRequestHandlerE2ETest.java @@ -0,0 +1,177 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static com.github.copilot.CopilotRequestTestSupport.SYNTHETIC_TEXT; +import static com.github.copilot.CopilotRequestTestSupport.assistantText; +import static com.github.copilot.CopilotRequestTestSupport.newLlmClient; +import static com.github.copilot.CopilotRequestTestSupport.setupCapiAuth; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.InputStream; +import java.net.URI; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.github.copilot.CopilotRequestTestSupport.InterceptedRequest; +import com.github.copilot.CopilotRequestTestSupport.RecordingRequestHandler; +import com.github.copilot.generated.AssistantMessageEvent; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.SessionConfig; + +/** + * End-to-end coverage for {@link CopilotRequestHandler}: a synthetic HTTP turn + * that the handler fully fabricates off-network, and a 
forwarding turn that + * relays both the HTTP and WebSocket transports to a real in-process upstream. + */ +public class CopilotRequestHandlerE2ETest { + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @Test + void streamsSyntheticHttpInference() throws Exception { + setupCapiAuth(ctx); + RecordingRequestHandler handler = new RecordingRequestHandler(SYNTHETIC_TEXT); + + try (CopilotClient client = newLlmClient(ctx, handler)) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(); + + AssistantMessageEvent result = session.sendAndWait(new MessageOptions().setPrompt("Say OK.")).get(60, + TimeUnit.SECONDS); + session.close(); + + // The handler intercepted the startup catalog and at least one inference + // request, fully replacing the runtime's outbound model-layer calls. 
+ List records = handler.records(); + assertFalse(records.isEmpty(), "Expected the runtime to invoke the request handler"); + assertTrue(records.stream().anyMatch(r -> r.url().toLowerCase(Locale.ROOT).endsWith("/models")), + "Expected to intercept the /models catalog request"); + assertFalse(handler.inferenceRequests().isEmpty(), + "Expected at least one inference request via the handler"); + + // Validate the final assistant response arrived (guards against truncated + // captures) + assertTrue(assistantText(result).contains("OK from the synthetic"), + "Expected synthetic content in assistant reply, got " + assistantText(result)); + } + } + + @Test + void forwardsHttpAndWebSocketToUpstream() throws Exception { + setupCapiAuth(ctx); + + AtomicInteger httpRequests = new AtomicInteger(); + AtomicInteger httpResponses = new AtomicInteger(); + AtomicInteger wsRequestMessages = new AtomicInteger(); + AtomicInteger wsResponseMessages = new AtomicInteger(); + + try (FakeUpstreamServer upstream = new FakeUpstreamServer("OK from synthetic HTTP upstream.", + "OK from synthetic WS upstream.")) { + + String httpBase = upstream.httpUrl(); + String wsBase = upstream.wsUrl(); + + CopilotRequestHandler handler = new CopilotRequestHandler() { + @Override + protected HttpResponse sendRequest(HttpRequest request, CopilotRequestContext rctx) + throws Exception { + httpRequests.incrementAndGet(); + URI rewritten = URI.create(rewriteHost(httpBase, request.uri())); + HttpRequest.Builder builder = HttpRequest.newBuilder().uri(rewritten); + request.bodyPublisher().ifPresentOrElse(bp -> builder.method(request.method(), bp), + () -> builder.method(request.method(), HttpRequest.BodyPublishers.noBody())); + request.headers().map().forEach((name, values) -> { + for (String value : values) { + try { + builder.header(name, value); + } catch (IllegalArgumentException ignored) { + // Restricted header rejected by java.net.http; skip it. 
+ } + } + }); + builder.header("x-test-mutated", "1"); + HttpResponse response = httpClient() + .sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream()).get(); + httpResponses.incrementAndGet(); + return response; + } + + @Override + protected CopilotWebSocketHandlerBase openWebSocket(CopilotRequestContext rctx) { + String rewritten = rewriteHost(wsBase, URI.create(rctx.url())); + return new CopilotWebSocketHandler(rctx, rewritten) { + @Override + public void sendRequestMessage(CopilotWebSocketMessage message) throws Exception { + wsRequestMessages.incrementAndGet(); + super.sendRequestMessage(message); + } + + @Override + public void sendResponseMessage(CopilotWebSocketMessage message) throws Exception { + wsResponseMessages.incrementAndGet(); + super.sendResponseMessage(message); + } + }; + } + }; + + try (CopilotClient client = newLlmClient(ctx, handler, + "COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES=true")) { + CopilotSession session = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(); + + AssistantMessageEvent result = session.sendAndWait(new MessageOptions().setPrompt("Say OK.")).get(60, + TimeUnit.SECONDS); + session.close(); + + // The HTTP override fired — the runtime issued model-layer GETs (catalog, + // policy) and possibly a single-shot inference through the send override. + assertTrue(httpRequests.get() > 0, "Expected the HTTP send override to fire"); + assertTrue(httpResponses.get() > 0, "Expected the HTTP response mutation to fire"); + + // The WebSocket override fired — the main agent turn went over the WS path + // and we observed messages in both directions. 
+ assertTrue(wsRequestMessages.get() > 0, "Expected runtime -> upstream ws messages"); + assertTrue(wsResponseMessages.get() > 0, "Expected upstream -> runtime ws messages"); + assertTrue(upstream.upstreamWsRequests() > 0, "Expected the upstream WS to receive request messages"); + + // Validate the final assistant response arrived (guards against truncated + // captures) + String text = assistantText(result); + assertTrue(text.contains("OK from synthetic") && text.contains("upstream"), + "Expected synthetic upstream content in assistant reply, got " + text); + } + } + } + + private static String rewriteHost(String base, URI original) { + String path = original.getRawPath() == null ? "" : original.getRawPath(); + String query = original.getRawQuery(); + return base + path + (query != null ? "?" + query : ""); + } +} diff --git a/java/src/test/java/com/github/copilot/CopilotRequestSessionIdE2ETest.java b/java/src/test/java/com/github/copilot/CopilotRequestSessionIdE2ETest.java new file mode 100644 index 000000000..daf524945 --- /dev/null +++ b/java/src/test/java/com/github/copilot/CopilotRequestSessionIdE2ETest.java @@ -0,0 +1,100 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import static com.github.copilot.CopilotRequestTestSupport.SYNTHETIC_TEXT; +import static com.github.copilot.CopilotRequestTestSupport.assistantText; +import static com.github.copilot.CopilotRequestTestSupport.newLlmClient; +import static com.github.copilot.CopilotRequestTestSupport.setupCapiAuth; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import com.github.copilot.CopilotRequestTestSupport.InterceptedRequest; +import com.github.copilot.CopilotRequestTestSupport.RecordingRequestHandler; +import com.github.copilot.generated.AssistantMessageEvent; +import com.github.copilot.rpc.MessageOptions; +import com.github.copilot.rpc.PermissionHandler; +import com.github.copilot.rpc.ProviderConfig; +import com.github.copilot.rpc.SessionConfig; + +/** + * Verifies that the triggering session id is threaded into every inference + * request context, for both CAPI and BYOK sessions, and that per-session ids + * differ. + */ +public class CopilotRequestSessionIdE2ETest { + + private static E2ETestContext ctx; + + @BeforeAll + static void setup() throws Exception { + ctx = E2ETestContext.create(); + } + + @AfterAll + static void teardown() throws Exception { + if (ctx != null) { + ctx.close(); + } + } + + @Test + void threadsSessionIdForCapiAndByok() throws Exception { + setupCapiAuth(ctx); + RecordingRequestHandler handler = new RecordingRequestHandler(SYNTHETIC_TEXT); + + try (CopilotClient client = newLlmClient(ctx, handler)) { + // CAPI session. 
+ CopilotSession capiSession = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL)).get(); + String capiSessionId = capiSession.getSessionId(); + + AssistantMessageEvent capiResult = capiSession.sendAndWait(new MessageOptions().setPrompt("Say OK.")) + .get(60, TimeUnit.SECONDS); + capiSession.close(); + + List capiInference = handler.inferenceRequests(); + assertFalse(capiInference.isEmpty(), "Expected at least one intercepted inference request"); + for (InterceptedRequest r : capiInference) { + assertEquals(capiSessionId, r.sessionId(), "CAPI inference request must carry the session id"); + } + assertTrue(assistantText(capiResult).contains("OK from the synthetic"), + "Expected synthetic content in CAPI assistant reply, got " + assistantText(capiResult)); + + // BYOK session. + int before = handler.inferenceRequests().size(); + ProviderConfig provider = new ProviderConfig().setType("openai").setWireApi("responses") + .setBaseUrl("https://byok.invalid/v1").setApiKey("byok-secret").setModelId("claude-sonnet-4.5") + .setWireModel("claude-sonnet-4.5"); + CopilotSession byokSession = client + .createSession(new SessionConfig().setOnPermissionRequest(PermissionHandler.APPROVE_ALL) + .setModel("claude-sonnet-4.5").setProvider(provider)) + .get(); + String byokSessionId = byokSession.getSessionId(); + + AssistantMessageEvent byokResult = byokSession.sendAndWait(new MessageOptions().setPrompt("Say OK.")) + .get(60, TimeUnit.SECONDS); + byokSession.close(); + + List byokInference = handler.inferenceRequests(); + assertTrue(byokInference.size() > before, "Expected at least one intercepted BYOK inference request"); + for (InterceptedRequest r : byokInference.subList(before, byokInference.size())) { + assertEquals(byokSessionId, r.sessionId(), "BYOK inference request must carry the session id"); + } + assertNotEquals(capiSessionId, byokSessionId, "Expected per-session ids to differ between turns"); + 
assertTrue(assistantText(byokResult).contains("OK from the synthetic"), + "Expected synthetic content in BYOK assistant reply, got " + assistantText(byokResult)); + } + } +} diff --git a/java/src/test/java/com/github/copilot/CopilotRequestTestSupport.java b/java/src/test/java/com/github/copilot/CopilotRequestTestSupport.java new file mode 100644 index 000000000..3b01734bd --- /dev/null +++ b/java/src/test/java/com/github/copilot/CopilotRequestTestSupport.java @@ -0,0 +1,506 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import java.util.regex.Pattern; +import javax.net.ssl.SSLSession; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.copilot.generated.AssistantMessageEvent; +import com.github.copilot.rpc.CopilotClientOptions; + +/** + * Shared synthetic-upstream helpers for the {@link CopilotRequestHandler} e2e + * tests. + * + *

+ * These tests have no recorded snapshots: a {@link CopilotRequestHandler} + * subclass fabricates well-formed model responses and the runtime routes all of + * its model-layer HTTP/WebSocket traffic through that handler instead of the + * CAPI proxy. The helpers centralise the synthetic CAPI shapes (model catalog, + * policy, {@code /responses} SSE, {@code /chat/completions}) so each test + * focuses on the behaviour it is exercising. + *

+ */ +final class CopilotRequestTestSupport { + + static final String SYNTHETIC_TEXT = "OK from the synthetic stream."; + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final Pattern STREAM_TRUE = Pattern.compile("\"stream\"\\s*:\\s*true"); + + private CopilotRequestTestSupport() { + } + + /** + * Builds a client wired to {@code handler} via the {@code requestHandler} + * option. The shared context client has no request handler, so each inference + * test owns an isolated client carrying its own handler. {@code extraEnv} + * entries (formatted {@code KEY=value}) are added to the spawned runtime's + * environment, e.g. to flip an ExP flag for the WebSocket transport. + */ + static CopilotClient newLlmClient(E2ETestContext ctx, CopilotRequestHandler handler, String... extraEnv) { + Map env = new HashMap<>(ctx.getEnvironment()); + for (String entry : extraEnv) { + int eq = entry.indexOf('='); + if (eq > 0) { + env.put(entry.substring(0, eq), entry.substring(eq + 1)); + } + } + return ctx.createClient(new CopilotClientOptions().setEnvironment(env).setRequestHandler(handler)); + } + + /** + * Initializes the proxy state and registers a synthetic CAPI user so the + * runtime can resolve auth for sessions that route their model-layer traffic + * through the handler instead of the proxy. 
+ */ + static void setupCapiAuth(E2ETestContext ctx) throws IOException, InterruptedException { + ctx.initializeProxy(); + ctx.setCopilotUserByToken("fake-token-for-e2e-tests", "e2e-user", "individual_pro", ctx.getProxyUrl(), + "https://localhost:1/telemetry", "e2e-tracking-id"); + } + + static Map> headers(String name, String value) { + Map> headers = new LinkedHashMap<>(); + headers.put(name, List.of(value)); + return headers; + } + + static String json(Object value) { + try { + return MAPPER.writeValueAsString(value); + } catch (JsonProcessingException e) { + throw new UncheckedIOException(e); + } + } + + static boolean wantsStream(String body) { + return STREAM_TRUE.matcher(body).find(); + } + + static boolean isInferenceUrl(String url) { + String u = url.toLowerCase(Locale.ROOT); + return u.endsWith("/chat/completions") || u.endsWith("/responses") || u.endsWith("/v1/messages") + || u.endsWith("/messages"); + } + + static String sse(String eventType, Object data) { + return "event: " + eventType + "\ndata: " + json(data) + "\n\n"; + } + + static String sseBody(String text, String respId) { + StringBuilder sb = new StringBuilder(); + for (Map event : responsesEvents(text, respId)) { + sb.append(sse((String) event.get("type"), event)); + } + return sb.toString(); + } + + // --- Synthetic response builders for the CopilotRequestHandler send override + // --- + + /** + * Drains the body of an outbound {@link HttpRequest} to a UTF-8 string. Mirrors + * the .NET {@code request.Content.ReadAsStringAsync()} the recording handler + * uses to inspect the request the runtime built. 
+ */ + static String requestBodyText(HttpRequest request) { + return request.bodyPublisher().map(CopilotRequestTestSupport::drain).orElse(""); + } + + private static String drain(HttpRequest.BodyPublisher publisher) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + CompletableFuture done = new CompletableFuture<>(); + publisher.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(ByteBuffer item) { + byte[] chunk = new byte[item.remaining()]; + item.get(chunk); + out.writeBytes(chunk); + } + + @Override + public void onError(Throwable throwable) { + done.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + done.complete(null); + } + }); + done.join(); + return out.toString(StandardCharsets.UTF_8); + } + + /** + * Synthesizes a well-formed inference response, dispatching by URL and the + * request body's stream flag exactly as a real reverse proxy would. + */ + static HttpResponse buildInferenceResponse(String url, String bodyText, String text) { + boolean stream = wantsStream(bodyText); + String u = url.toLowerCase(Locale.ROOT); + + if (u.contains("/responses")) { + if (!stream) { + List> events = responsesEvents(text, "resp_stub_1"); + Object last = events.get(events.size() - 1).get("response"); + return jsonResponse(json(last)); + } + return sseResponse(sseBody(text, "resp_stub_1")); + } + + if (u.contains("/chat/completions") && stream) { + StringBuilder sb = new StringBuilder(); + for (Map chunk : chatCompletionChunks(text)) { + sb.append("data: ").append(json(chunk)).append("\n\n"); + } + sb.append("data: [DONE]\n\n"); + return sseResponse(sb.toString()); + } + + return jsonResponse(json(chatCompletion(text))); + } + + /** + * Serves the non-inference model-layer requests the runtime issues (catalog, + * model session, policy), with an empty-JSON fallback for anything else. 
+ */ + static HttpResponse buildNonInferenceResponse(String url) { + String u = url.toLowerCase(Locale.ROOT); + if (u.endsWith("/models")) { + return jsonResponse(modelCatalog(null)); + } + if (u.contains("/models/session")) { + return jsonResponse("{}"); + } + if (u.contains("/policy")) { + return jsonResponse("{\"state\":\"enabled\"}"); + } + return jsonResponse("{}"); + } + + static HttpResponse jsonResponse(String body) { + return new StubHttpResponse(200, "application/json", body); + } + + static HttpResponse sseResponse(String body) { + return new StubHttpResponse(200, "text/event-stream", body); + } + + static String modelCatalog(List supportedEndpoints) { + Map limits = new LinkedHashMap<>(); + limits.put("max_context_window_tokens", 200000); + limits.put("max_output_tokens", 8192); + + Map supports = new LinkedHashMap<>(); + supports.put("streaming", true); + supports.put("tool_calls", true); + supports.put("parallel_tool_calls", true); + supports.put("vision", true); + + Map capabilities = new LinkedHashMap<>(); + capabilities.put("type", "chat"); + capabilities.put("family", "claude-sonnet-4.5"); + capabilities.put("tokenizer", "o200k_base"); + capabilities.put("limits", limits); + capabilities.put("supports", supports); + + Map model = new LinkedHashMap<>(); + model.put("id", "claude-sonnet-4.5"); + model.put("name", "Claude Sonnet 4.5"); + model.put("object", "model"); + model.put("vendor", "Anthropic"); + model.put("version", "1"); + model.put("preview", false); + model.put("model_picker_enabled", true); + model.put("capabilities", capabilities); + if (supportedEndpoints != null) { + model.put("supported_endpoints", supportedEndpoints); + } + + Map root = new LinkedHashMap<>(); + root.put("data", List.of(model)); + return json(root); + } + + /** + * Returns the ordered {@code /responses} event objects the runtime's reducer + * expects. Used raw (one object == one WebSocket message) for the WS path and + * SSE-framed for the HTTP path. 
+ */ + static List> responsesEvents(String text, String respId) { + Map created = new LinkedHashMap<>(); + created.put("type", "response.created"); + created.put("response", responseShell(respId, "in_progress", List.of())); + + Map itemAdded = new LinkedHashMap<>(); + itemAdded.put("type", "response.output_item.added"); + itemAdded.put("output_index", 0); + itemAdded.put("item", message("msg_1", List.of())); + + Map partAdded = new LinkedHashMap<>(); + partAdded.put("type", "response.content_part.added"); + partAdded.put("output_index", 0); + partAdded.put("content_index", 0); + partAdded.put("part", outputText("")); + + Map delta = new LinkedHashMap<>(); + delta.put("type", "response.output_text.delta"); + delta.put("output_index", 0); + delta.put("content_index", 0); + delta.put("delta", text); + + Map done = new LinkedHashMap<>(); + done.put("type", "response.output_text.done"); + done.put("output_index", 0); + done.put("content_index", 0); + done.put("text", text); + + Map completedResponse = responseShell(respId, "completed", + List.of(message("msg_1", List.of(outputText(text))))); + completedResponse.put("usage", usage()); + Map completed = new LinkedHashMap<>(); + completed.put("type", "response.completed"); + completed.put("response", completedResponse); + + return List.of(created, itemAdded, partAdded, delta, done, completed); + } + + private static Map responseShell(String respId, String status, List output) { + Map response = new LinkedHashMap<>(); + response.put("id", respId); + response.put("object", "response"); + response.put("status", status); + response.put("output", output); + return response; + } + + private static Map message(String id, List content) { + Map item = new LinkedHashMap<>(); + item.put("id", id); + item.put("type", "message"); + item.put("role", "assistant"); + item.put("content", content); + return item; + } + + private static Map outputText(String text) { + Map part = new LinkedHashMap<>(); + part.put("type", "output_text"); + 
part.put("text", text); + return part; + } + + private static Map usage() { + Map usage = new LinkedHashMap<>(); + usage.put("input_tokens", 5); + usage.put("output_tokens", 7); + usage.put("total_tokens", 12); + return usage; + } + + private static List> chatCompletionChunks(String text) { + Map c1 = chatChunkBase(); + c1.put("choices", List.of(choice(0, delta("assistant", ""), null))); + Map c2 = chatChunkBase(); + c2.put("choices", List.of(choice(0, delta(null, text), null))); + Map c3 = chatChunkBase(); + c3.put("choices", List.of(choice(0, new LinkedHashMap<>(), "stop"))); + c3.put("usage", chatUsage()); + return List.of(c1, c2, c3); + } + + private static Map chatChunkBase() { + Map base = new LinkedHashMap<>(); + base.put("id", "chatcmpl-stub-1"); + base.put("object", "chat.completion.chunk"); + base.put("created", 1); + base.put("model", "claude-sonnet-4.5"); + return base; + } + + private static Map delta(String role, String content) { + Map delta = new LinkedHashMap<>(); + if (role != null) { + delta.put("role", role); + } + delta.put("content", content); + return delta; + } + + private static Map choice(int index, Map delta, String finishReason) { + Map choice = new LinkedHashMap<>(); + choice.put("index", index); + choice.put("delta", delta); + choice.put("finish_reason", finishReason); + return choice; + } + + private static Map chatUsage() { + Map usage = new LinkedHashMap<>(); + usage.put("prompt_tokens", 5); + usage.put("completion_tokens", 7); + usage.put("total_tokens", 12); + return usage; + } + + private static Map chatCompletion(String text) { + Map message = new LinkedHashMap<>(); + message.put("role", "assistant"); + message.put("content", text); + + Map choice = new LinkedHashMap<>(); + choice.put("index", 0); + choice.put("message", message); + choice.put("finish_reason", "stop"); + + Map root = new LinkedHashMap<>(); + root.put("id", "chatcmpl-stub-1"); + root.put("object", "chat.completion"); + root.put("created", 1); + root.put("model", 
"claude-sonnet-4.5"); + root.put("choices", List.of(choice)); + root.put("usage", chatUsage()); + return root; + } + + static String assistantText(AssistantMessageEvent event) { + if (event == null || event.getData() == null) { + return ""; + } + String content = event.getData().content(); + return content != null ? content : ""; + } + + /** A single request the handler intercepted. */ + record InterceptedRequest(String url, String sessionId) { + } + + /** + * A {@link CopilotRequestHandler} that records every intercepted request and + * fully replaces the upstream call with a fabricated, well-formed response for + * every model-layer endpoint, so an agent turn completes entirely off-network. + */ + static class RecordingRequestHandler extends CopilotRequestHandler { + + private final ConcurrentLinkedQueue records = new ConcurrentLinkedQueue<>(); + private final String text; + + RecordingRequestHandler(String text) { + this.text = text; + } + + List records() { + return new ArrayList<>(records); + } + + List inferenceRequests() { + List out = new ArrayList<>(); + for (InterceptedRequest r : records) { + if (isInferenceUrl(r.url())) { + out.add(r); + } + } + return out; + } + + @Override + protected HttpResponse sendRequest(HttpRequest request, CopilotRequestContext ctx) + throws Exception { + String url = request.uri().toString(); + records.add(new InterceptedRequest(url, ctx.sessionId())); + if (isInferenceUrl(url)) { + return buildInferenceResponse(url, requestBodyText(request), text); + } + return buildNonInferenceResponse(url); + } + } + + /** + * A minimal {@link HttpResponse} over an in-memory body for the send override. 
+ */ + private static final class StubHttpResponse implements HttpResponse { + + private final int status; + private final HttpHeaders headers; + private final byte[] body; + + StubHttpResponse(int status, String contentType, String body) { + this.status = status; + this.body = body.getBytes(StandardCharsets.UTF_8); + this.headers = HttpHeaders.of(Map.of("content-type", List.of(contentType)), (k, v) -> true); + } + + @Override + public int statusCode() { + return status; + } + + @Override + public HttpRequest request() { + return null; + } + + @Override + public Optional> previousResponse() { + return Optional.empty(); + } + + @Override + public HttpHeaders headers() { + return headers; + } + + @Override + public InputStream body() { + return new ByteArrayInputStream(body); + } + + @Override + public Optional sslSession() { + return Optional.empty(); + } + + @Override + public URI uri() { + return null; + } + + @Override + public HttpClient.Version version() { + return HttpClient.Version.HTTP_1_1; + } + } +} diff --git a/java/src/test/java/com/github/copilot/FakeUpstreamServer.java b/java/src/test/java/com/github/copilot/FakeUpstreamServer.java new file mode 100644 index 000000000..7af60d4d3 --- /dev/null +++ b/java/src/test/java/com/github/copilot/FakeUpstreamServer.java @@ -0,0 +1,299 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +package com.github.copilot; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A minimal raw-socket HTTP/1.1 + RFC 6455 WebSocket upstream used by the + * idiomatic-handler e2e test. + *

+ * It serves the synthetic CAPI HTTP endpoints (model catalog, model session, + * policy, {@code /responses} SSE) and, on a WebSocket upgrade, echoes the + * ordered {@code /responses} events as one batch of text messages per inbound + * message. It avoids any third-party server dependency so the test exercises + * the real {@link java.net.http.WebSocket} forwarding path against a genuine + * upstream. + *

+ */ +final class FakeUpstreamServer implements AutoCloseable { + + private static final String WS_MAGIC = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + private final ServerSocket serverSocket; + private final Thread acceptThread; + private final AtomicInteger upstreamWsRequests = new AtomicInteger(); + private final String httpText; + private final String wsText; + private volatile boolean running = true; + + FakeUpstreamServer(String httpText, String wsText) throws IOException { + this.httpText = httpText; + this.wsText = wsText; + this.serverSocket = new ServerSocket(0, 50, InetAddress.getByName("127.0.0.1")); + this.acceptThread = new Thread(this::acceptLoop, "fake-upstream-accept"); + this.acceptThread.setDaemon(true); + this.acceptThread.start(); + } + + int port() { + return serverSocket.getLocalPort(); + } + + String httpUrl() { + return "http://127.0.0.1:" + port(); + } + + String wsUrl() { + return "ws://127.0.0.1:" + port(); + } + + int upstreamWsRequests() { + return upstreamWsRequests.get(); + } + + private void acceptLoop() { + while (running) { + try { + Socket socket = serverSocket.accept(); + Thread t = new Thread(() -> handle(socket), "fake-upstream-conn"); + t.setDaemon(true); + t.start(); + } catch (IOException e) { + return; + } + } + } + + private void handle(Socket socket) { + try (socket) { + InputStream in = socket.getInputStream(); + OutputStream out = socket.getOutputStream(); + + String requestLine = readLine(in); + if (requestLine == null || requestLine.isEmpty()) { + return; + } + String[] parts = requestLine.split(" "); + String path = parts.length > 1 ? 
parts[1] : "/"; + + Map headers = new java.util.LinkedHashMap<>(); + String line; + while ((line = readLine(in)) != null && !line.isEmpty()) { + int colon = line.indexOf(':'); + if (colon > 0) { + headers.put(line.substring(0, colon).trim().toLowerCase(Locale.ROOT), + line.substring(colon + 1).trim()); + } + } + + if ("websocket".equalsIgnoreCase(headers.get("upgrade"))) { + serveWebSocket(in, out, headers); + return; + } + serveHttp(in, out, path, headers); + } catch (Exception ignored) { + // Connection error; drop it. + } + } + + private void serveHttp(InputStream in, OutputStream out, String path, Map headers) + throws IOException { + String contentLength = headers.get("content-length"); + if (contentLength != null) { + int len; + try { + len = Integer.parseInt(contentLength.trim()); + } catch (NumberFormatException e) { + len = 0; + } + byte[] body = new byte[len]; + int read = 0; + while (read < len) { + int n = in.read(body, read, len - read); + if (n < 0) { + break; + } + read += n; + } + } + + String lower = path.toLowerCase(Locale.ROOT); + String contentType = "application/json"; + String body; + int status = 200; + if (lower.endsWith("/models")) { + body = CopilotRequestTestSupport.modelCatalog(List.of("/responses", "ws:/responses")); + } else if (lower.contains("/models/session")) { + body = "{}"; + } else if (lower.contains("/policy")) { + body = "{\"state\":\"enabled\"}"; + } else if (lower.endsWith("/responses")) { + contentType = "text/event-stream"; + body = CopilotRequestTestSupport.sseBody(httpText, "resp_stub_http"); + } else { + status = 404; + body = "{\"error\":\"not_found\"}"; + } + + byte[] bodyBytes = body.getBytes(StandardCharsets.UTF_8); + String header = "HTTP/1.1 " + status + " " + (status == 200 ? 
"OK" : "Not Found") + "\r\n" + "content-type: " + + contentType + "\r\n" + "content-length: " + bodyBytes.length + "\r\n" + "connection: close\r\n\r\n"; + out.write(header.getBytes(StandardCharsets.US_ASCII)); + out.write(bodyBytes); + out.flush(); + } + + private void serveWebSocket(InputStream in, OutputStream out, Map headers) throws Exception { + String key = headers.get("sec-websocket-key"); + MessageDigest sha1 = MessageDigest.getInstance("SHA-1"); + byte[] digest = sha1.digest((key + WS_MAGIC).getBytes(StandardCharsets.US_ASCII)); + String accept = Base64.getEncoder().encodeToString(digest); + String response = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Accept: " + accept + "\r\n\r\n"; + out.write(response.getBytes(StandardCharsets.US_ASCII)); + out.flush(); + + ByteArrayOutputStream message = new ByteArrayOutputStream(); + while (true) { + int b1 = in.read(); + if (b1 < 0) { + return; + } + boolean fin = (b1 & 0x80) != 0; + int opcode = b1 & 0x0F; + + int b2 = in.read(); + if (b2 < 0) { + return; + } + boolean masked = (b2 & 0x80) != 0; + long len = b2 & 0x7F; + if (len == 126) { + len = ((long) in.read() << 8) | in.read(); + } else if (len == 127) { + len = 0; + for (int i = 0; i < 8; i++) { + len = (len << 8) | in.read(); + } + } + + byte[] mask = new byte[4]; + if (masked) { + readFully(in, mask, 4); + } + byte[] payload = new byte[(int) len]; + readFully(in, payload, (int) len); + if (masked) { + for (int i = 0; i < payload.length; i++) { + payload[i] ^= mask[i % 4]; + } + } + + if (opcode == 0x8) { + writeFrame(out, 0x8, new byte[0]); + out.flush(); + return; + } + if (opcode == 0x9) { + writeFrame(out, 0xA, payload); + out.flush(); + continue; + } + if (opcode == 0x0 || opcode == 0x1 || opcode == 0x2) { + message.writeBytes(payload); + if (!fin) { + continue; + } + message.reset(); + upstreamWsRequests.incrementAndGet(); + for (Map event : 
CopilotRequestTestSupport.responsesEvents(wsText, "resp_stub_ws")) { + byte[] raw = CopilotRequestTestSupport.json(event).getBytes(StandardCharsets.UTF_8); + writeFrame(out, 0x1, raw); + } + out.flush(); + } + } + } + + private static void writeFrame(OutputStream out, int opcode, byte[] payload) throws IOException { + List bytes = new ArrayList<>(); + bytes.add(0x80 | opcode); + int len = payload.length; + if (len < 126) { + bytes.add(len); + } else if (len < 65536) { + bytes.add(126); + bytes.add((len >> 8) & 0xFF); + bytes.add(len & 0xFF); + } else { + bytes.add(127); + for (int i = 7; i >= 0; i--) { + bytes.add((int) ((((long) len) >> (8 * i)) & 0xFF)); + } + } + byte[] header = new byte[bytes.size()]; + for (int i = 0; i < bytes.size(); i++) { + header[i] = (byte) (int) bytes.get(i); + } + out.write(header); + out.write(payload); + } + + private static void readFully(InputStream in, byte[] buffer, int len) throws IOException { + int read = 0; + while (read < len) { + int n = in.read(buffer, read, len - read); + if (n < 0) { + throw new IOException("Unexpected end of stream"); + } + read += n; + } + } + + private static String readLine(InputStream in) throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + int c; + while ((c = in.read()) != -1) { + if (c == '\r') { + int next = in.read(); + if (next == '\n' || next == -1) { + break; + } + buffer.write('\r'); + buffer.write(next); + continue; + } + if (c == '\n') { + break; + } + buffer.write(c); + } + if (c == -1 && buffer.size() == 0) { + return null; + } + return buffer.toString(StandardCharsets.US_ASCII); + } + + @Override + public void close() throws IOException { + running = false; + serverSocket.close(); + } +} diff --git a/nodejs/package-lock.json b/nodejs/package-lock.json index 08cd933c4..d75e7af89 100644 --- a/nodejs/package-lock.json +++ b/nodejs/package-lock.json @@ -16,6 +16,7 @@ "devDependencies": { "@platformatic/vfs": "^0.3.0", "@types/node": "^25.2.0", + "@types/ws": 
"^8.18.1", "@typescript-eslint/eslint-plugin": "^8.54.0", "@typescript-eslint/parser": "^8.54.0", "esbuild": "^0.28.1", @@ -29,7 +30,8 @@ "semver": "^7.7.3", "tsx": "^4.20.6", "typescript": "^5.0.0", - "vitest": "^4.0.18" + "vitest": "^4.0.18", + "ws": "^8.21.0" }, "engines": { "node": "^20.19.0 || >=22.12.0" @@ -1289,6 +1291,16 @@ "undici-types": "~7.18.0" } }, + "node_modules/@types/ws": { + "version": "8.18.1", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz", + "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@typescript-eslint/eslint-plugin": { "version": "8.56.1", "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.56.1.tgz", @@ -3906,6 +3918,28 @@ "dev": true, "license": "MIT" }, + "node_modules/ws": { + "version": "8.21.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.21.0.tgz", + "integrity": "sha512-Vsp28b7DRcimFQvrqu2Wek3z1iYxDCWqHYB8Qsnk/S4RfaCQzPGPyBNuVjJV3cd6UiKtUtp6sNM77gWvzcCH+g==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/yaml": { "version": "2.9.0", "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.9.0.tgz", diff --git a/nodejs/package.json b/nodejs/package.json index 140a50fd4..d40474ba8 100644 --- a/nodejs/package.json +++ b/nodejs/package.json @@ -63,6 +63,7 @@ "devDependencies": { "@platformatic/vfs": "^0.3.0", "@types/node": "^25.2.0", + "@types/ws": "^8.18.1", "@typescript-eslint/eslint-plugin": "^8.54.0", "@typescript-eslint/parser": "^8.54.0", "esbuild": "^0.28.1", @@ -76,7 +77,8 @@ "semver": "^7.7.3", "tsx": "^4.20.6", "typescript": "^5.0.0", - "vitest": 
"^4.0.18" + "vitest": "^4.0.18", + "ws": "^8.21.0" }, "engines": { "node": "^20.19.0 || >=22.12.0" diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index a6efb061a..8cebcf341 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -29,12 +29,15 @@ import { import { createServerRpc, createInternalServerRpc, + registerClientGlobalApiHandlers, registerClientSessionApiHandlers, } from "./generated/rpc.js"; import type { OpenCanvasInstance, SessionUpdateOptionsParams } from "./generated/rpc.js"; import { getSdkProtocolVersion } from "./sdkProtocolVersion.js"; import { CopilotSession } from "./session.js"; import { createSessionFsAdapter, type SessionFsProvider } from "./sessionFsProvider.js"; +import { createCopilotRequestAdapter } from "./copilotRequestHandler.js"; +import type { CopilotRequestHandler } from "./copilotRequestHandler.js"; import { getTraceContext } from "./telemetry.js"; import { ToolSet } from "./toolSet.js"; import type { @@ -418,6 +421,8 @@ export class CopilotClient { private negotiatedProtocolVersion: number | null = null; /** Connection-level session filesystem config, set via constructor option. */ private sessionFsConfig: SessionFsConfig | null = null; + private requestHandler: CopilotRequestHandler | null = null; + private llmInferenceHandlers: import("./generated/rpc.js").ClientGlobalApiHandlers = {}; /** * Typed server-scoped RPC methods. @@ -529,6 +534,8 @@ export class CopilotClient { this.onListModels = options.onListModels; this.onGetTraceContext = options.onGetTraceContext; this.sessionFsConfig = options.sessionFs ?? null; + this.requestHandler = options.requestHandler ?? null; + this.setupLlmInference(); const effectiveEnv = options.env ?? 
process.env; this.resolvedEnv = effectiveEnv; @@ -645,6 +652,21 @@ export class CopilotClient { session.clientSessionApis.sessionFs = createSessionFsAdapter(provider); } + private setupLlmInference(): void { + if (!this.requestHandler) { + return; + } + this.llmInferenceHandlers = { + llmInference: createCopilotRequestAdapter(this.requestHandler, () => { + if (!this.connection) { + return undefined; + } + this._rpc ??= createServerRpc(this.connection); + return this._rpc; + }), + }; + } + /** * Starts the CLI server and establishes a connection. * @@ -692,6 +714,13 @@ export class CopilotClient { }); } + // If a request handler was configured, register it. The runtime + // will then route outbound model HTTP requests through the + // registered handler for the duration of each session. + if (this.requestHandler) { + await this.connection!.sendRequest("llmInference.setProvider", {}); + } + this.state = "connected"; } catch (error) { this.state = "error"; @@ -761,7 +790,6 @@ export class CopilotClient { // Ask SDK-owned runtimes to flush and clean up before we tear down // their transport/process. External runtimes may be shared, so only // close our connection to them. - let runtimeShutdownCompleted = false; if (this.connection && this.cliProcess && !this.isExternalServer) { const runtimeShutdownStart = Date.now(); const shutdownPromise = this.rpc.runtime.shutdown(); @@ -772,7 +800,6 @@ export class CopilotClient { RUNTIME_SHUTDOWN_TIMEOUT_MS, `runtime.shutdown timed out after ${RUNTIME_SHUTDOWN_TIMEOUT_MS}ms` ); - runtimeShutdownCompleted = true; this.logDebugTiming( "CopilotClient.stop runtime shutdown complete", runtimeShutdownStart @@ -829,25 +856,24 @@ export class CopilotClient { } } - // Give runtime.shutdown a bounded window to let the child exit on its - // own before falling back to SIGTERM. 
+ // The runtime completes all cleanup before responding to + // runtime.shutdown and then leaves termination to us; it deliberately + // keeps its JSON-RPC server alive to send the response and never + // self-exits. Waiting a grace window for a self-exit that will never + // come just wastes time, so terminate the child immediately and only + // wait to reap it. if (this.cliProcess && !this.isExternalServer) { const child = this.cliProcess; this.cliProcess = null; try { if (child.exitCode == null && child.signalCode == null) { - const exitedGracefully = runtimeShutdownCompleted - ? await waitForChildExit(child, RUNTIME_SHUTDOWN_TIMEOUT_MS) - : false; - if (!exitedGracefully) { - child.kill(); - if (!(await waitForChildExit(child, RUNTIME_SHUTDOWN_TIMEOUT_MS))) { - errors.push( - new Error( - `Timed out waiting for CLI process to exit after kill: ${RUNTIME_SHUTDOWN_TIMEOUT_MS}ms` - ) - ); - } + child.kill(); + if (!(await waitForChildExit(child, RUNTIME_SHUTDOWN_TIMEOUT_MS))) { + errors.push( + new Error( + `Timed out waiting for CLI process to exit after kill: ${RUNTIME_SHUTDOWN_TIMEOUT_MS}ms` + ) + ); } } } catch (error) { @@ -2362,6 +2388,11 @@ export class CopilotClient { return session.clientSessionApis; }); + // Register client *global* API handlers (e.g. LLM inference) on the + // same connection. These methods carry no implicit sessionId dispatch + // — the runtime calls into a single handler for the whole connection. + registerClientGlobalApiHandlers(this.connection, this.llmInferenceHandlers); + this.connection.onClose(() => { this.state = "disconnected"; }); diff --git a/nodejs/src/copilotRequestHandler.ts b/nodejs/src/copilotRequestHandler.ts new file mode 100644 index 000000000..11cee309b --- /dev/null +++ b/nodejs/src/copilotRequestHandler.ts @@ -0,0 +1,820 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +import type { + LlmInferenceHandler, + LlmInferenceHeaders, + LlmInferenceHttpRequestChunkRequest, + LlmInferenceHttpRequestChunkResult, + LlmInferenceHttpRequestStartRequest, + LlmInferenceHttpRequestStartResult, +} from "./generated/rpc.js"; +import type { createServerRpc } from "./generated/rpc.js"; + +type ServerRpc = ReturnType; + +const sharedTextDecoder = new TextDecoder("utf-8", { fatal: false }); +const sharedTextEncoder = new TextEncoder(); + +const kBridge = Symbol("copilotWebSocketResponseBridge"); +const kCompletion = Symbol("copilotWebSocketCompletion"); +const kOpen = Symbol("copilotWebSocketOpen"); +const kSuppressCloseOnDispose = Symbol("copilotWebSocketSuppressCloseOnDispose"); +const kHandle = Symbol("copilotRequestHandle"); + +type InternalContext = CopilotRequestContext & { [kBridge]: CopilotWebSocketResponseBridge }; + +/** + * Per-request context handed to every {@link CopilotRequestHandler} hook. + * + * @experimental + */ +export interface CopilotRequestContext { + readonly requestId: string; + readonly sessionId?: string; + readonly transport: "http" | "websocket"; + readonly url: string; + readonly headers: LlmInferenceHeaders; + readonly signal: AbortSignal; +} + +/** + * Terminal status for a callback-owned WebSocket connection. + * + * @experimental + */ +export class CopilotWebSocketCloseStatus { + static readonly normalClosure = new CopilotWebSocketCloseStatus(); + + constructor( + readonly description?: string, + readonly errorCode?: string, + readonly error?: Error + ) {} +} + +/** + * Lower-level WebSocket handler with no upstream connection. + * + * This is the abstract base shared by all WebSocket handlers. It does not open + * or forward to any upstream server on its own — subclass it directly only when + * you want to service a fully synthetic connection yourself (e.g. answer the + * runtime without any real backend). 
For the common case of mutating and
+ * forwarding traffic to the real upstream, subclass {@link CopilotWebSocketHandler}
+ * instead, which connects upstream and forwards by default.
+ *
+ * @experimental
+ */
+export abstract class CopilotWebSocketHandlerBase implements AsyncDisposable {
+    /** Response bridge taken from the request context; every response frame goes through it. */
+    readonly #response: CopilotWebSocketResponseBridge;
+    /** Settles with the status passed to the first {@link close} call. */
+    readonly #completion: Promise<CopilotWebSocketCloseStatus>;
+    #resolveCompletion!: (status: CopilotWebSocketCloseStatus) => void;
+    /** Set by the first {@link close} call; later calls return immediately. */
+    #closed = false;
+    /** When true, `[Symbol.asyncDispose]` does not issue the implicit normal close. */
+    [kSuppressCloseOnDispose] = false;
+
+    protected readonly context: CopilotRequestContext;
+
+    /**
+     * @param context Request context; it must carry the internal response
+     *   bridge under `kBridge`.
+     * @throws Error when no response bridge is attached to `context`.
+     */
+    protected constructor(context: CopilotRequestContext) {
+        this.context = context;
+        const bridge = (context as Partial<InternalContext>)[kBridge];
+        if (!bridge) {
+            throw new Error("WebSocket response bridge is not attached");
+        }
+        this.#response = bridge;
+        this.#completion = new Promise<CopilotWebSocketCloseStatus>((resolve) => {
+            this.#resolveCompletion = resolve;
+        });
+    }
+
+    /** Writes one message to the response side through the bridge. */
+    async sendResponseMessage(data: string | Uint8Array): Promise<void> {
+        await this.#response.write(data);
+    }
+
+    /**
+     * Terminates the connection exactly once. A status carrying an `error`
+     * is reported through the bridge's `error()`; any other status ends the
+     * response normally. Calls after the first are no-ops.
+     *
+     * @param status Terminal status; defaults to a normal closure.
+     * @throws Whatever the bridge's `end()` / `error()` call rejects with.
+     */
+    async close(
+        status: CopilotWebSocketCloseStatus = CopilotWebSocketCloseStatus.normalClosure
+    ): Promise<void> {
+        if (this.#closed) {
+            return;
+        }
+        this.#closed = true;
+        try {
+            if (status.error) {
+                await this.#response.error({
+                    message: status.description ?? status.error.message,
+                    code: status.errorCode,
+                });
+            } else {
+                await this.#response.end();
+            }
+        } finally {
+            // Settle completion even when the terminal frame could not be
+            // delivered. `#closed` is already true, so no later close() call
+            // would ever resolve it, and anything awaiting the completion
+            // promise would otherwise wait forever.
+            this.#resolveCompletion(status);
+        }
+    }
+
+    /** Receives one request-side message; subclasses decide what to do with it. */
+    abstract sendRequestMessage(data: string | Uint8Array): Promise<void> | void;
+
+    /** Closes with a normal status unless already closed or close-on-dispose is suppressed. */
+    async [Symbol.asyncDispose](): Promise<void> {
+        if (!this[kSuppressCloseOnDispose] && !this.#closed) {
+            await this.close(CopilotWebSocketCloseStatus.normalClosure);
+        }
+    }
+
+    /** @internal */
+    get [kCompletion](): Promise<CopilotWebSocketCloseStatus> {
+        return this.#completion;
+    }
+
+    /** @internal */
+    async [kOpen](): Promise<void> {}
+}
+
+/**
+ * WebSocket handler that connects to the real upstream and forwards traffic by
+ * default. This is the type returned by the default
+ * {@link CopilotRequestHandler.openWebSocket}.
+ * + * Override nothing to get full pass-through. To mutate traffic, subclass this + * type and override a message hook, then call `super` to keep forwarding to the + * upstream. (Subclassing {@link CopilotWebSocketHandlerBase} instead would drop + * forwarding entirely.) + * + * @experimental + */ +export class CopilotWebSocketHandler extends CopilotWebSocketHandlerBase { + readonly #url: string; + #upstream: WebSocket | null = null; + + constructor(context: CopilotRequestContext, url = context.url) { + super(context); + this.#url = url; + } + + override sendRequestMessage(data: string | Uint8Array): void { + if (this.#upstream?.readyState !== WebSocket.OPEN) { + return; + } + this.#upstream.send(data); + } + + /** @internal */ + override async [kOpen](): Promise { + if (this.#upstream) { + return; + } + const upstream = new WebSocket(this.#url); + upstream.binaryType = "arraybuffer"; + this.#upstream = upstream; + upstream.addEventListener("message", (event) => { + void this.sendResponseMessage(normalizeWsData(event.data)).catch( + async (err: unknown) => { + await this.close( + new CopilotWebSocketCloseStatus( + err instanceof Error ? err.message : String(err), + undefined, + err instanceof Error ? 
err : new Error(String(err)) + ) + ); + } + ); + }); + upstream.addEventListener("close", () => { + void this.close(CopilotWebSocketCloseStatus.normalClosure); + }); + upstream.addEventListener("error", () => { + void this.close( + new CopilotWebSocketCloseStatus( + "WebSocket error", + undefined, + new Error("WebSocket error") + ) + ); + }); + await new Promise((resolve, reject) => { + if (upstream.readyState === WebSocket.OPEN) { + resolve(); + return; + } + upstream.addEventListener("open", () => resolve(), { once: true }); + upstream.addEventListener("error", () => reject(new Error("WebSocket error")), { + once: true, + }); + }); + } + + override async close( + status: CopilotWebSocketCloseStatus = CopilotWebSocketCloseStatus.normalClosure + ): Promise { + try { + if ( + this.#upstream?.readyState === WebSocket.OPEN || + this.#upstream?.readyState === WebSocket.CONNECTING + ) { + this.#upstream?.close(); + } + } catch { + // Best-effort; the socket may already be closed. + } + await super.close(status); + } + + override async [Symbol.asyncDispose](): Promise { + try { + await super[Symbol.asyncDispose](); + } finally { + try { + this.#upstream?.close(); + } catch { + // Best-effort. + } + } + } +} + +/** + * Base class for SDK consumers who want to observe or mutate the outbound + * model-layer requests the runtime issues (for both CAPI and BYOK providers). + * Subclass and override {@link sendRequest} or {@link openWebSocket}; an + * instance that overrides nothing is a transparent pass-through. 
+ * + * @experimental + */ +export class CopilotRequestHandler { + protected sendRequest(request: Request, ctx: CopilotRequestContext): Promise { + return fetch(request, { signal: ctx.signal }); + } + + protected openWebSocket(ctx: CopilotRequestContext): Promise { + return Promise.resolve(new CopilotWebSocketHandler(ctx)); + } + + /** @internal */ + async [kHandle](exchange: CopilotRequestExchange): Promise { + const bridge = new CopilotWebSocketResponseBridge(exchange); + const ctx: InternalContext = { + requestId: exchange.requestId, + sessionId: exchange.sessionId, + transport: exchange.transport, + url: exchange.url, + headers: exchange.headers, + signal: exchange.signal, + [kBridge]: bridge, + }; + + if (exchange.transport === "websocket") { + await this.#handleWebSocket(exchange, ctx); + } else { + await this.#handleHttp(exchange, ctx); + } + } + + async #handleHttp(exchange: CopilotRequestExchange, ctx: CopilotRequestContext): Promise { + const request = await buildFetchRequest(exchange); + const response = await this.sendRequest(request, ctx); + await streamResponse(response, exchange); + } + + async #handleWebSocket(exchange: CopilotRequestExchange, ctx: InternalContext): Promise { + const handler = await this.openWebSocket(ctx); + try { + await handler[kOpen](); + + // The runtime blocks the WebSocket connect until it receives the + // 101 response head (the upgrade acknowledgement) and only then + // begins forwarding inbound messages as request-body chunks. Emit + // it eagerly here — waiting for the first upstream message would + // deadlock, since the upstream stays silent until it receives a + // request message the runtime won't send before the upgrade + // completes. 
+ await ctx[kBridge].start(); + + let cancelled: unknown; + const clientSettled = (async () => { + for await (const chunk of exchange.requestBody) { + await handler.sendRequestMessage(decodeFrame(chunk)); + } + return "client-complete" as const; + })().catch((err) => { + cancelled = err; + return "client-error" as const; + }); + + const first = await Promise.race([ + clientSettled, + handler[kCompletion].then(() => "server-done" as const), + ]); + + if (first === "client-error") { + handler[kSuppressCloseOnDispose] = true; + throw cancelled instanceof Error ? cancelled : new Error(String(cancelled)); + } + + if (first === "client-complete") { + await handler.close(CopilotWebSocketCloseStatus.normalClosure); + await handler[kCompletion]; + return; + } + + const status = await handler[kCompletion]; + if (status.error) { + throw status.error; + } + } finally { + await handler[Symbol.asyncDispose](); + } + } +} + +/** + * Adapt a {@link CopilotRequestHandler} into the generated + * {@link LlmInferenceHandler} shape consumed by the SDK's RPC dispatcher. + * + * Maintains a per-`requestId` table of {@link CopilotRequestExchange}: each + * `httpRequestStart` allocates one and fires the handler in the background, + * returning immediately so the runtime's RPC reply is not gated on the + * consumer's I/O. Subsequent `httpRequestChunk` frames are routed into the + * matching exchange's body stream. + * + * @internal + */ +export function createCopilotRequestAdapter( + handler: CopilotRequestHandler, + getServerRpc: () => ServerRpc | undefined +): LlmInferenceHandler { + const pending = new Map(); + + function getOrCreate(requestId: string): CopilotRequestExchange { + // The runtime dispatches httpRequestStart and httpRequestChunk frames + // independently. 
get-or-create keeps the adapter correct regardless of + // arrival order: a body chunk (including the terminal end frame) that + // races ahead of its start frame is buffered into the same exchange + // rather than dropped, which would otherwise hang the body drain. + let exchange = pending.get(requestId); + if (!exchange) { + exchange = new CopilotRequestExchange(requestId, getServerRpc); + pending.set(requestId, exchange); + } + return exchange; + } + + async function run(exchange: CopilotRequestExchange): Promise { + try { + await handler[kHandle](exchange); + if (!exchange.finished) { + await finalize( + exchange, + 502, + "Copilot request handler returned without finalising the response (call responseBody.end() or .error())." + ); + } + } catch (err) { + if (exchange.cancelled || exchange.signal.aborted) { + // The runtime already cancelled this request; the handler's + // throw is just the abort propagating out of its upstream call. + await finalize(exchange, 499, "Request cancelled by runtime", "cancelled"); + return; + } + const message = err instanceof Error ? err.message : String(err); + await finalize(exchange, 502, message); + } finally { + pending.delete(exchange.requestId); + } + } + + return { + async httpRequestStart( + params: LlmInferenceHttpRequestStartRequest + ): Promise { + // Adopt any exchange a racing chunk already created — with its + // buffered body — rather than dropping those frames. + const exchange = getOrCreate(params.requestId); + exchange.setContext(params); + void run(exchange); + return {}; + }, + async httpRequestChunk( + params: LlmInferenceHttpRequestChunkRequest + ): Promise { + // May arrive before the matching start frame; get-or-create so the + // body is buffered, never lost. 
+ routeChunk(getOrCreate(params.requestId), params); + return {}; + }, + }; +} + +async function finalize( + exchange: CopilotRequestExchange, + status: number, + message: string, + code?: string +): Promise { + if (exchange.finished) { + return; + } + try { + if (!exchange.started) { + await exchange.startResponse({ status, headers: {} }); + } + await exchange.errorResponse({ message, code }); + } catch { + // Best-effort — the connection may already be dead. + } +} + +function routeChunk( + exchange: CopilotRequestExchange, + params: LlmInferenceHttpRequestChunkRequest +): void { + if (params.cancel) { + exchange.pushCancel(params.cancelReason); + return; + } + if (params.data && params.data.length > 0) { + exchange.pushChunk(decodeChunkData(params.data, !!params.binary)); + } + if (params.end) { + exchange.pushEnd(); + } +} + +/** Response head emitted to the runtime via {@link CopilotRequestExchange.startResponse}. */ +interface ResponseInit { + status: number; + statusText?: string; + headers?: LlmInferenceHeaders; +} + +interface BodyQueueItem { + chunk?: Uint8Array; + end?: boolean; + cancel?: { reason?: string }; +} + +/** + * One intercepted request in flight. Carries the request context plus the body + * byte stream the runtime feeds in via `httpRequestChunk` frames, and emits the + * handler's response straight back to the runtime through the generated + * `llmInference` server API. Replaces the former provider/sink/response-channel + * indirection with a single object the adapter owns and the handler drives. 
+ */ +class CopilotRequestExchange { + readonly requestId: string; + sessionId?: string; + method = "GET"; + url = ""; + headers: LlmInferenceHeaders = {}; + transport: "http" | "websocket" = "http"; + + readonly #getServerRpc: () => ServerRpc | undefined; + readonly #abort = new AbortController(); + readonly #buffer: BodyQueueItem[] = []; + #waker: (() => void) | null = null; + #drained = false; + #started = false; + #finished = false; + #cancelled = false; + + constructor(requestId: string, getServerRpc: () => ServerRpc | undefined) { + this.requestId = requestId; + this.#getServerRpc = getServerRpc; + } + + /** Fill in the request context once the matching start frame arrives. */ + setContext(params: LlmInferenceHttpRequestStartRequest): void { + this.sessionId = params.sessionId; + this.method = params.method; + this.url = params.url; + this.headers = params.headers; + this.transport = params.transport ?? "http"; + } + + get signal(): AbortSignal { + return this.#abort.signal; + } + + get started(): boolean { + return this.#started; + } + + get finished(): boolean { + return this.#finished; + } + + get cancelled(): boolean { + return this.#cancelled; + } + + // --- Request body feed (driven by the adapter as chunk frames arrive) --- + + pushChunk(chunk: Uint8Array): void { + this.#push({ chunk }); + } + + pushEnd(): void { + this.#push({ end: true }); + } + + pushCancel(reason?: string): void { + this.#cancelled = true; + this.#abort.abort(); + this.#push({ cancel: { reason } }); + } + + #push(item: BodyQueueItem): void { + this.#buffer.push(item); + const w = this.#waker; + this.#waker = null; + w?.(); + } + + /** + * Request body bytes, yielded as they arrive. A cancel frame surfaces as a + * thrown error so the handler's upstream call is torn down. 
+ */ + get requestBody(): AsyncIterable { + return { + [Symbol.asyncIterator]: (): AsyncIterator => ({ + next: async (): Promise> => { + if (this.#drained) { + return { value: undefined, done: true }; + } + while (this.#buffer.length === 0) { + await new Promise((resolve) => { + this.#waker = resolve; + }); + } + const item = this.#buffer.shift()!; + if (item.cancel) { + this.#drained = true; + throw new Error( + item.cancel.reason + ? `Request cancelled by runtime: ${item.cancel.reason}` + : "Request cancelled by runtime" + ); + } + if (item.end) { + this.#drained = true; + return { value: undefined, done: true }; + } + return { value: item.chunk ?? new Uint8Array(), done: false }; + }, + }), + }; + } + + // --- Response emit (driven by the handler). Strict state machine: --- + // startResponse once -> 0..N writeResponse -> exactly one of + // endResponse / errorResponse. + + async startResponse(init: ResponseInit): Promise { + if (this.#started) { + throw new Error("Copilot request response start() called twice."); + } + if (this.#finished) { + throw new Error("Copilot request response already finished."); + } + this.#started = true; + await this.#rpc().llmInference.httpResponseStart({ + requestId: this.requestId, + status: init.status, + statusText: init.statusText, + headers: init.headers ?? {}, + }); + } + + async writeResponse(data: string | Uint8Array): Promise { + if (this.#cancelled) { + throw new Error("Copilot request was cancelled by the runtime."); + } + if (!this.#started) { + throw new Error("Copilot request response write() called before start()."); + } + if (this.#finished) { + throw new Error("Copilot request response write() called after end()/error()."); + } + const isString = typeof data === "string"; + await this.#rpc().llmInference.httpResponseChunk({ + requestId: this.requestId, + data: isString ? 
data : Buffer.from(data).toString("base64"), + binary: !isString, + end: false, + }); + } + + async endResponse(): Promise { + if (this.#finished) { + return; + } + this.#finished = true; + await this.#rpc().llmInference.httpResponseChunk({ + requestId: this.requestId, + data: "", + end: true, + }); + } + + async errorResponse(error: { message: string; code?: string }): Promise { + if (this.#finished) { + return; + } + this.#finished = true; + await this.#rpc().llmInference.httpResponseChunk({ + requestId: this.requestId, + data: "", + end: true, + error: { message: error.message, code: error.code }, + }); + } + + #rpc(): ServerRpc { + const r = this.#getServerRpc(); + if (!r) { + throw new Error("Copilot request response used after RPC connection closed."); + } + return r; + } +} + +const FORBIDDEN_REQUEST_HEADERS = new Set([ + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", +]); + +async function buildFetchRequest(exchange: CopilotRequestExchange): Promise { + const headers = new Headers(); + for (const [name, values] of Object.entries(exchange.headers)) { + if (!values) { + continue; + } + if (FORBIDDEN_REQUEST_HEADERS.has(name.toLowerCase())) { + continue; + } + for (const value of values) { + headers.append(name, value); + } + } + + const method = exchange.method.toUpperCase(); + const hasBody = method !== "GET" && method !== "HEAD"; + + let body: Uint8Array | undefined; + if (hasBody) { + const buffered = await drainAsync(exchange.requestBody); + if (buffered.length > 0) { + body = buffered; + } + } else { + await drainAsync(exchange.requestBody); + } + + return new Request(exchange.url, { method, headers, body }); +} + +async function drainAsync(stream: AsyncIterable): Promise { + const parts: Uint8Array[] = []; + let total = 0; + for await (const chunk of stream) { + parts.push(chunk); + total += chunk.byteLength; + } + if (parts.length === 0) { + return new 
Uint8Array(0);
+    }
+    if (parts.length === 1) {
+        return parts[0];
+    }
+    const out = new Uint8Array(total);
+    let off = 0;
+    for (const part of parts) {
+        out.set(part, off);
+        off += part.byteLength;
+    }
+    return out;
+}
+
+/**
+ * Relays a fetch `Response` to the runtime through `exchange`: the head goes
+ * out via `startResponse`, every non-empty body chunk via `writeResponse`, and
+ * the stream is terminated with `endResponse`. A response without a body is
+ * ended immediately after the head.
+ *
+ * @param response Response whose status, headers and body are relayed.
+ * @param exchange Exchange the response frames are emitted on.
+ * @throws Whatever reading the body or emitting a frame throws; in that case
+ *   the body stream is cancelled before the error propagates.
+ */
+async function streamResponse(response: Response, exchange: CopilotRequestExchange): Promise<void> {
+    await exchange.startResponse({
+        status: response.status,
+        statusText: response.statusText || undefined,
+        headers: headersToMultiMap(response.headers),
+    });
+
+    const body = response.body;
+    if (!body) {
+        await exchange.endResponse();
+        return;
+    }
+
+    const reader = body.getReader();
+    let completed = false;
+    try {
+        for (;;) {
+            const { value, done } = await reader.read();
+            if (done) {
+                break;
+            }
+            if (value && value.byteLength > 0) {
+                await exchange.writeResponse(value);
+            }
+        }
+        await exchange.endResponse();
+        completed = true;
+    } finally {
+        if (!completed) {
+            // Releasing the lock alone leaves the body stream open. Cancel it
+            // so the source is told to stop; swallow cancellation failures so
+            // the original error is the one that propagates.
+            await reader.cancel().catch(() => {});
+        }
+        reader.releaseLock();
+    }
+}
+
+/**
+ * Converts a `Headers` object into a name -> list-of-values map. `set-cookie`
+ * is skipped while iterating and filled from `getSetCookie()` instead, so each
+ * cookie stays a separate entry.
+ */
+function headersToMultiMap(headers: Headers): LlmInferenceHeaders {
+    const out: Record<string, string[]> = {};
+    headers.forEach((value, name) => {
+        if (name.toLowerCase() === "set-cookie") {
+            return;
+        }
+        const list = out[name] ?? (out[name] = []);
+        list.push(value);
+    });
+    const setCookies = headers.getSetCookie();
+    if (setCookies.length > 0) {
+        out["set-cookie"] = setCookies;
+    }
+    return out;
+}
+
+/** Turns chunk `data` into bytes: base64-decoded when `binary`, otherwise text-encoded. */
+function decodeChunkData(data: string, binary: boolean): Uint8Array {
+    if (binary) {
+        return new Uint8Array(Buffer.from(data, "base64"));
+    }
+    return sharedTextEncoder.encode(data);
+}
+
+/** Decodes a request-body chunk to a string with the shared text decoder. */
+function decodeFrame(chunk: Uint8Array): string {
+    return sharedTextDecoder.decode(chunk);
+}
+
+/**
+ * Normalizes a WebSocket message payload to `string | Uint8Array`. Strings and
+ * `Uint8Array`s pass through, an `ArrayBuffer` is wrapped, and any other value
+ * yields an empty `Uint8Array`.
+ */
+function normalizeWsData(data: unknown): string | Uint8Array {
+    if (typeof data === "string") {
+        return data;
+    }
+    if (data instanceof Uint8Array) {
+        return data;
+    }
+    if (data instanceof ArrayBuffer) {
+        return new Uint8Array(data);
+    }
+    return new Uint8Array();
+}
+
+/**
+ * Forwards upstream WebSocket messages back to the owning
+ * {@link CopilotRequestExchange}.
The 101 upgrade head is emitted eagerly via + * {@link start} (the runtime gates the connect on it); thereafter writes are + * serialised so the head always precedes any body or terminal frame. + */ +class CopilotWebSocketResponseBridge { + readonly #exchange: CopilotRequestExchange; + #started = false; + #completed = false; + #serial: Promise = Promise.resolve(); + + constructor(exchange: CopilotRequestExchange) { + this.#exchange = exchange; + } + + /** Emit the 101 upgrade head now, acknowledging the WebSocket connect. */ + start(): Promise { + return this.#run(false, () => Promise.resolve()); + } + + write(data: string | Uint8Array): Promise { + return this.#run(false, () => this.#exchange.writeResponse(data)); + } + + end(): Promise { + return this.#run(true, () => this.#exchange.endResponse()); + } + + error(error: { message: string; code?: string }): Promise { + return this.#run(true, () => this.#exchange.errorResponse(error)); + } + + #run(terminal: boolean, action: () => Promise): Promise { + const task = this.#serial.then(async () => { + if (this.#completed) { + return; + } + if (!this.#started) { + this.#started = true; + await this.#exchange.startResponse({ status: 101, headers: {} }); + } + if (terminal) { + this.#completed = true; + } + await action(); + }); + this.#serial = task.catch(() => {}); + return task; + } +} diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index 6303f9db2..e265846fd 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -507,6 +507,7 @@ export type InstructionSourceLocation = * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema * via the `definition` "LlmInferenceHttpRequestStartTransport". */ +/** @experimental */ export type LlmInferenceHttpRequestStartTransport = /** Plain HTTP or SSE response. Each body chunk is an opaque byte range; the response is a status line, headers, and a (possibly streamed) body. 
*/ | "http" @@ -4359,6 +4360,7 @@ export interface LlmInferenceHeaders { * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema * via the `definition` "LlmInferenceHttpRequestChunkRequest". */ +/** @experimental */ export interface LlmInferenceHttpRequestChunkRequest { /** * Matches the requestId from the originating httpRequestStart frame. @@ -4391,6 +4393,7 @@ export interface LlmInferenceHttpRequestChunkRequest { * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema * via the `definition` "LlmInferenceHttpRequestChunkResult". */ +/** @experimental */ export interface LlmInferenceHttpRequestChunkResult {} /** * The head of an outbound model-layer HTTP request. @@ -4398,6 +4401,7 @@ export interface LlmInferenceHttpRequestChunkResult {} * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema * via the `definition` "LlmInferenceHttpRequestStartRequest". */ +/** @experimental */ export interface LlmInferenceHttpRequestStartRequest { /** * Opaque runtime-minted id, unique per in-flight request. The SDK uses this to correlate httpRequestChunk frames and to address its httpResponseStart / httpResponseChunk replies back to the runtime. @@ -4424,6 +4428,7 @@ export interface LlmInferenceHttpRequestStartRequest { * This interface was referenced by `_RpcSchemaRoot`'s JSON-Schema * via the `definition` "LlmInferenceHttpRequestStartResult". */ +/** @experimental */ export interface LlmInferenceHttpRequestStartResult {} /** * Set to terminate the response with a transport-level failure. Implies end-of-stream; any further chunks for this requestId are ignored. @@ -15852,3 +15857,52 @@ export function registerClientSessionApiHandlers( return handler.invoke(params); }); } + +/** Handler for `llmInference` client global API methods. */ +/** @experimental */ +export interface LlmInferenceHandler { + /** + * Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. 
Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true). + * + * @param params The head of an outbound model-layer HTTP request. + * + * @returns Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed. + */ + httpRequestStart(params: LlmInferenceHttpRequestStartRequest): Promise; + /** + * Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set. + * + * @param params A request body chunk or cancellation signal. + * + * @returns Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget. + */ + httpRequestChunk(params: LlmInferenceHttpRequestChunkRequest): Promise; +} + +/** All client global API handler groups. */ +export interface ClientGlobalApiHandlers { + llmInference?: LlmInferenceHandler; +} + +/** + * Register client global API handlers on a JSON-RPC connection. + * The server calls these methods to delegate work to the client. + * Unlike session-scoped client APIs, these methods carry no implicit + * `sessionId` dispatch key — a single set of handlers serves the entire + * connection. 
+ */ +export function registerClientGlobalApiHandlers( + connection: MessageConnection, + handlers: ClientGlobalApiHandlers, +): void { + connection.onRequest("llmInference.httpRequestStart", async (params: LlmInferenceHttpRequestStartRequest) => { + const handler = handlers.llmInference; + if (!handler) throw new Error("No llmInference client-global handler registered"); + return handler.httpRequestStart(params); + }); + connection.onRequest("llmInference.httpRequestChunk", async (params: LlmInferenceHttpRequestChunkRequest) => { + const handler = handlers.llmInference; + if (!handler) throw new Error("No llmInference client-global handler registered"); + return handler.httpRequestChunk(params); + }); +} diff --git a/nodejs/src/index.ts b/nodejs/src/index.ts index 9b266fc9c..861c27fa9 100644 --- a/nodejs/src/index.ts +++ b/nodejs/src/index.ts @@ -28,6 +28,10 @@ export { approveAll, convertMcpCallToolResult, createSessionFsAdapter, + CopilotRequestHandler, + CopilotWebSocketHandlerBase, + CopilotWebSocketCloseStatus, + CopilotWebSocketHandler, SYSTEM_MESSAGE_SECTIONS, } from "./types.js"; // Re-export the generated session-event types (every *Event interface and @@ -125,6 +129,7 @@ export type { SessionFsSqliteQueryResult, SessionFsSqliteQueryType, SessionFsSqliteProvider, + CopilotRequestContext, SystemMessageAppendConfig, SystemMessageConfig, SystemMessageCustomizeConfig, diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index f198a88b3..902ae6fcf 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -9,6 +9,7 @@ // Import and re-export generated session event types import type { Canvas } from "./canvas.js"; import type { SessionFsProvider } from "./sessionFsProvider.js"; +import type { CopilotRequestHandler } from "./copilotRequestHandler.js"; import type { ReasoningSummary, SessionEvent as GeneratedSessionEvent, @@ -33,6 +34,14 @@ export type { SessionFsFileInfo } from "./sessionFsProvider.js"; export type { SessionFsSqliteQueryResult } from 
"./sessionFsProvider.js"; export type { SessionFsSqliteQueryType } from "./sessionFsProvider.js"; export type { SessionFsSqliteProvider } from "./sessionFsProvider.js"; +export type { LlmInferenceHeaders } from "./generated/rpc.js"; +export type { CopilotRequestContext } from "./copilotRequestHandler.js"; +export { + CopilotRequestHandler, + CopilotWebSocketHandlerBase, + CopilotWebSocketCloseStatus, + CopilotWebSocketHandler, +} from "./copilotRequestHandler.js"; /** * Options for creating a CopilotClient @@ -305,6 +314,30 @@ export interface CopilotClientOptions { */ sessionFs?: SessionFsConfig; + /** + * Custom handler for outbound model-layer requests (experimental). + * + * When provided, the client registers as the runtime's request handler + * on connection: every outbound model-layer request the runtime would + * otherwise have issued itself — plain HTTP, streaming SSE, and + * WebSocket — is dispatched back to the handler over JSON-RPC. The + * handler returns the response verbatim, exactly as if the runtime had + * issued the request itself. + * + * Subclass {@link CopilotRequestHandler} and override the hooks you need; + * an instance that overrides nothing is a transparent pass-through. + * + * v1 notes: + * - HTTP (buffered and streaming SSE) and WebSocket transports are all + * intercepted. The handler receives a `transport` discriminator on the + * {@link CopilotRequestContext} for both. + * - The handler is set process-globally on the runtime; the same + * handler is invoked for every session created on this client. + * + * @experimental + */ + requestHandler?: CopilotRequestHandler; + /** * Server-wide idle timeout for sessions in seconds. * Sessions without activity for this duration are automatically cleaned up. 
diff --git a/nodejs/test/e2e/copilot_request_cancel_error.e2e.test.ts b/nodejs/test/e2e/copilot_request_cancel_error.e2e.test.ts new file mode 100644 index 000000000..5a9cad5a5 --- /dev/null +++ b/nodejs/test/e2e/copilot_request_cancel_error.e2e.test.ts @@ -0,0 +1,184 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, CopilotRequestHandler, type CopilotRequestContext } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +/** + * Cancellation and error coverage for {@link CopilotRequestHandler}. These two + * scenarios exercise the handler's terminal paths that the happy-path session-id + * and HTTP/WebSocket tests never reach: + * + * - **Error** — the handler throws from {@link CopilotRequestHandler.sendRequest} + * for an inference request. The base adapter reports a transport error back to + * the runtime (`errorResponse`) rather than hanging. + * - **Runtime cancel** — the handler blocks an inference request indefinitely; + * when the consumer aborts the turn the runtime cancels the in-flight request, + * firing `ctx.signal`. The handler observes the abort (the `cancel`-frame + * path) instead of leaking a stuck request. + * + * Non-inference model-layer requests (catalog, policy, model session) are served + * with minimal stubs so the turn reaches the inference step. The success-path + * SSE body is intentionally omitted — neither scenario completes a turn. 
+ */ + +function isInferenceUrl(url: string): boolean { + const u = url.toLowerCase(); + return ( + u.endsWith("/chat/completions") || + u.endsWith("/responses") || + u.endsWith("/v1/messages") || + u.endsWith("/messages") + ); +} + +function json(body: string): Response { + return new Response(body, { status: 200, headers: { "content-type": "application/json" } }); +} + +/** Serve the non-inference GETs/POSTs (catalog, policy, model session). */ +function serveNonInference(url: string): Response { + const u = url.toLowerCase(); + if (u.endsWith("/models")) { + return json(MODEL_CATALOG_JSON); + } + if (u.includes("/models/session")) { + return json("{}"); + } + if (u.includes("/policy")) { + return json(JSON.stringify({ state: "enabled" })); + } + return json("{}"); +} + +const MODEL_CATALOG_JSON = JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], +}); + +async function waitFor(predicate: () => boolean, timeoutMs: number): Promise<void> { + const deadline = Date.now() + timeoutMs; + while (!predicate()) { + if (Date.now() > deadline) { + throw new Error("waitFor timed out"); + } + await new Promise((resolve) => setTimeout(resolve, 50)); + } +} + +/** Throws from every inference request to exercise the error-reporting path.
*/ +class ThrowingRequestHandler extends CopilotRequestHandler { + inferenceAttempts = 0; + + protected override async sendRequest( + request: Request, + _ctx: CopilotRequestContext + ): Promise<Response> { + if (!isInferenceUrl(request.url)) { + return serveNonInference(request.url); + } + this.inferenceAttempts++; + throw new Error("synthetic-callback-transport-failure"); + } +} + +/** Blocks every inference request until the runtime cancels it. */ +class CancellingRequestHandler extends CopilotRequestHandler { + inferenceEntered = false; + sawAbort = false; + + protected override async sendRequest( + request: Request, + ctx: CopilotRequestContext + ): Promise<Response> { + if (!isInferenceUrl(request.url)) { + return serveNonInference(request.url); + } + this.inferenceEntered = true; + await new Promise<void>((resolve) => { + if (ctx.signal.aborted) { + resolve(); + return; + } + ctx.signal.addEventListener("abort", () => resolve(), { once: true }); + }); + this.sawAbort = true; + // The runtime already dropped the request; throwing simply propagates + // the abort out of the (here, simulated) upstream call. + throw new Error("cancelled by runtime"); + } +} + +describe("CopilotRequestHandler surfaces inference errors", async () => { + const handler = new ThrowingRequestHandler(); + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { requestHandler: handler }, + }); + + it("reports a thrown callback error instead of hanging the turn", async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + try { + // The callback throws on inference; the turn surfaces an error (or + // completes without an assistant message) rather than hanging. + await session.sendAndWait({ prompt: "Say OK."
}).catch(() => undefined); + } finally { + await session.disconnect(); + } + + expect( + handler.inferenceAttempts, + "expected the inference callback to be reached and raise" + ).toBeGreaterThan(0); + }, 90_000); +}); + +describe("CopilotRequestHandler observes runtime cancellation", async () => { + const handler = new CancellingRequestHandler(); + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { requestHandler: handler }, + }); + + it("fires ctx.signal when the consumer aborts an in-flight inference request", async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + try { + await session.send("Say OK."); + await waitFor(() => handler.inferenceEntered, 60_000); + await session.abort(); + await waitFor(() => handler.sawAbort, 30_000); + } finally { + await session.disconnect(); + } + + expect(handler.inferenceEntered, "expected the inference callback to be entered").toBe( + true + ); + expect(handler.sawAbort, "expected the callback to observe runtime cancellation").toBe( + true + ); + }, 90_000); +}); diff --git a/nodejs/test/e2e/copilot_request_handler.e2e.test.ts b/nodejs/test/e2e/copilot_request_handler.e2e.test.ts new file mode 100644 index 000000000..6b761984e --- /dev/null +++ b/nodejs/test/e2e/copilot_request_handler.e2e.test.ts @@ -0,0 +1,394 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +import { createServer, IncomingMessage, Server as HttpServer, ServerResponse } from "http"; +import { AddressInfo } from "net"; +import { afterAll, describe, expect, it } from "vitest"; +import { WebSocket as WsClient, WebSocketServer } from "ws"; +import { + approveAll, + CopilotRequestHandler, + CopilotWebSocketHandlerBase, + CopilotWebSocketCloseStatus, + type CopilotRequestContext, +} from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const HTTP_TEXT = "OK from synthetic HTTP upstream."; +const WS_TEXT = "OK from synthetic WS upstream."; + +/** + * Stand up an in-process upstream that speaks the real CAPI shapes the + * runtime needs: model catalog, policy, `/responses` SSE for HTTP + * inference, and a WebSocket endpoint at `/responses` that answers each + * inbound `response.create` with the ordered `/responses` events the + * reducer expects. + * + * Returned `url` is what the handler subclass rewrites every + * intercepted request to point at — the runtime never talks to this + * server directly; the handler does, on the runtime's behalf. + */ +async function startFakeUpstream(): Promise<{ + url: string; + server: HttpServer; + wsRequestCount: () => number; + close: () => Promise; +}> { + let wsRequests = 0; + + const httpServer = createServer((req, res) => { + const url = new URL(req.url ?? "/", `http://${req.headers.host ?? 
"localhost"}`); + if (url.pathname === "/models" && req.method === "GET") { + sendJson(res, 200, { + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + supported_endpoints: ["/responses", "ws:/responses"], + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { + max_context_window_tokens: 200000, + max_output_tokens: 8192, + }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], + }); + return; + } + if (url.pathname.endsWith("/models/session")) { + sendJson(res, 200, {}); + return; + } + if (url.pathname.includes("/policy")) { + sendJson(res, 200, { state: "enabled" }); + return; + } + if (url.pathname.endsWith("/responses") && req.method === "POST") { + // Single-shot HTTP inference (e.g. title generation). SSE + // events the `responses-client.ts` reducer accepts. + drainBody(req) + .then(() => { + res.writeHead(200, { + "content-type": "text/event-stream", + "cache-control": "no-cache", + }); + for (const event of buildResponsesEvents(HTTP_TEXT, "resp_stub_http")) { + res.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`); + } + res.end(); + }) + .catch(() => { + res.writeHead(500).end(); + }); + return; + } + // Anything else: not found. + res.writeHead(404, { "content-type": "application/json" }); + res.end(JSON.stringify({ error: "not_found", path: url.pathname })); + }); + + const wss = new WebSocketServer({ server: httpServer, path: "/responses" }); + wss.on("connection", (socket) => { + socket.on("message", (raw) => { + wsRequests++; + // For each `response.create` request the runtime sends, + // answer with the ordered `/responses` event objects — one + // event per outbound WS message, raw JSON (NOT SSE-framed). 
+ for (const event of buildResponsesEvents(WS_TEXT, "resp_stub_ws")) { + socket.send(JSON.stringify(event)); + } + void raw; + }); + }); + + await new Promise<void>((resolve) => httpServer.listen(0, "127.0.0.1", resolve)); + const port = (httpServer.address() as AddressInfo).port; + const url = `http://127.0.0.1:${port}`; + + return { + url, + server: httpServer, + wsRequestCount: () => wsRequests, + async close() { + wss.clients.forEach((c) => c.terminate()); + await new Promise<void>((resolve) => wss.close(() => resolve())); + await new Promise<void>((resolve) => httpServer.close(() => resolve())); + }, + }; +} + +function sendJson(res: ServerResponse, status: number, body: unknown): void { + res.writeHead(status, { "content-type": "application/json" }); + res.end(JSON.stringify(body)); +} + +async function drainBody(req: IncomingMessage): Promise<Buffer> { + const parts: Buffer[] = []; + for await (const chunk of req) { + parts.push(chunk as Buffer); + } + return Buffer.concat(parts); +} + +function buildResponsesEvents(text: string, id: string): Array<Record<string, unknown>> { + return [ + { + type: "response.created", + response: { id, object: "response", status: "in_progress", output: [] }, + }, + { + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + }, + { + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + }, + { type: "response.output_text.delta", output_index: 0, content_index: 0, delta: text }, + { type: "response.output_text.done", output_index: 0, content_index: 0, text }, + { + type: "response.completed", + response: { + id, + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + }, + ]; +} + +/** Tallies of the HTTP and WebSocket traffic observed by the test handler. */ +interface Counters { + httpRequests: number; + httpResponses: number;
+ wsRequestMessages: number; + wsResponseMessages: number; +} + +/** + * Single handler subclass that services BOTH transports against the + * per-test fake upstream. Demonstrates mutation in each direction: + * + * - HTTP: rewrites the URL to point at the test server, adds an + * `X-Test-Mutated` header to the outbound request, and adds an + * `X-Test-Response-Mutated` header on the way back. The test server + * echoes the request header into a counter so we can assert it + * actually arrived upstream. + * - WebSocket: rewrites the WS URL similarly, opens with the `ws` + * package inside a custom per-connection handler, and observes + * message counts in both directions. + */ +class TestHandler extends CopilotRequestHandler { + constructor( + private readonly upstreamUrl: string, + private readonly counters: Counters + ) { + super(); + } + + private rewriteUrl(originalUrl: string): string { + const parsed = new URL(originalUrl); + const upstream = new URL(this.upstreamUrl); + parsed.protocol = upstream.protocol; + parsed.host = upstream.host; + return parsed.toString(); + } + + private rewriteWsUrl(originalUrl: string): string { + const parsed = new URL(originalUrl); + const upstream = new URL(this.upstreamUrl); + // The upstream URL is http(s); flip to ws(s) for the WS open. + parsed.protocol = upstream.protocol === "https:" ? 
"wss:" : "ws:"; + parsed.host = upstream.host; + return parsed.toString(); + } + + protected override async sendRequest( + request: Request, + _ctx: CopilotRequestContext + ): Promise { + this.counters.httpRequests++; + const rewritten = this.rewriteUrl(request.url); + const requestHeaders = new Headers(request.headers); + requestHeaders.set("x-test-mutated", "1"); + const rewrittenRequest = new Request(rewritten, { + method: request.method, + headers: requestHeaders, + body: request.body, + // @ts-expect-error duplex is required by undici when streaming a body + duplex: "half", + }); + const response = await fetch(rewrittenRequest, { signal: _ctx.signal }); + this.counters.httpResponses++; + const responseHeaders = new Headers(response.headers); + responseHeaders.set("x-test-response-mutated", "1"); + return new Response(response.body, { + status: response.status, + statusText: response.statusText, + headers: responseHeaders, + }); + } + + protected override async openWebSocket( + ctx: CopilotRequestContext + ): Promise { + return TestSocketHandler.connect(this.rewriteWsUrl(ctx.url), ctx, this.counters); + } +} + +class TestSocketHandler extends CopilotWebSocketHandlerBase { + static async connect( + url: string, + ctx: CopilotRequestContext, + counters: Counters + ): Promise { + const client = new WsClient(url); + await new Promise((resolve, reject) => { + client.once("open", () => resolve()); + client.once("error", (err) => reject(err)); + }); + return new TestSocketHandler(client, ctx, counters); + } + + private constructor( + private readonly client: WsClient, + ctx: CopilotRequestContext, + private readonly counters: Counters + ) { + super(ctx); + this.client.on("message", (data, isBinary) => { + this.counters.wsResponseMessages++; + void this.sendResponseMessage(isBinary ? 
(data as Buffer) : data.toString("utf-8")); + }); + this.client.once("close", () => { + void this.close(); + }); + this.client.once("error", (err) => { + void this.close(new CopilotWebSocketCloseStatus(err.message, undefined, err as Error)); + }); + const onAbort = (): void => { + try { + this.client.close(); + } catch { + /* best-effort */ + } + }; + ctx.signal.addEventListener("abort", onAbort, { once: true }); + this.client.once("close", () => ctx.signal.removeEventListener("abort", onAbort)); + } + + override sendRequestMessage(data: string | Uint8Array): void { + this.counters.wsRequestMessages++; + if (this.client.readyState !== WsClient.OPEN) { + return; + } + this.client.send(data); + } + + override async [Symbol.asyncDispose](): Promise { + try { + await super[Symbol.asyncDispose](); + } finally { + try { + this.client.close(); + } catch { + /* best-effort */ + } + } + } +} + +describe("CopilotRequestHandler — single subclass handles HTTP + WebSocket", async () => { + const upstream = await startFakeUpstream(); + const counters: Counters = { + httpRequests: 0, + httpResponses: 0, + wsRequestMessages: 0, + wsResponseMessages: 0, + }; + + const { copilotClient: client, env } = await createSdkTestContext({ + copilotClientOptions: { + requestHandler: new TestHandler(upstream.url, counters), + }, + }); + + // Enable the WebSocket Responses transport in the spawned runtime so + // the main agent turn picks the WS path; single-shot calls (title + // generation) still go over HTTP through the same subclass. + env.COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES = "true"; + + afterAll(async () => { + await upstream.close(); + }); + + it("services both an HTTP turn and a WebSocket turn end-to-end via one handler", async () => { + await client.start(); + const session = await client.createSession({ onPermissionRequest: approveAll }); + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." 
}); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + // The HTTP hooks fired — the runtime issued model-layer GETs + // (catalog, policy) and possibly a single-shot inference. + expect(counters.httpRequests, "expected sendRequest to fire").toBeGreaterThan(0); + expect( + counters.httpResponses, + "expected sendRequest response mutation to fire" + ).toBeGreaterThan(0); + + // The WebSocket hooks fired — the main agent turn went over + // the WS path and we observed messages in both directions. + expect( + counters.wsRequestMessages, + "expected sendRequestMessage (runtime → upstream) to fire" + ).toBeGreaterThan(0); + expect( + counters.wsResponseMessages, + "expected sendResponseMessage (upstream → runtime) to fire" + ).toBeGreaterThan(0); + expect( + upstream.wsRequestCount(), + "expected upstream WS to receive request messages" + ).toBeGreaterThan(0); + + // The synthetic content from the upstream surfaced in the + // assistant turn — proves the full chain (runtime → handler + // → upstream → handler → runtime) is intact for the + // transport the main agent turn used. + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from synthetic (HTTP|WS) upstream/); + }, 90_000); +}); diff --git a/nodejs/test/e2e/copilot_request_session_id.e2e.test.ts b/nodejs/test/e2e/copilot_request_session_id.e2e.test.ts new file mode 100644 index 000000000..3f01475aa --- /dev/null +++ b/nodejs/test/e2e/copilot_request_session_id.e2e.test.ts @@ -0,0 +1,325 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ *--------------------------------------------------------------------------------------------*/ + +import { describe, expect, it } from "vitest"; +import { approveAll, CopilotRequestHandler, type CopilotRequestContext } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +const SYNTHETIC_TEXT = "OK from the synthetic stream."; + +interface InterceptedRequest { + url: string; + sessionId?: string; +} + +function isInferenceUrl(url: string): boolean { + const u = url.toLowerCase(); + return ( + u.endsWith("/chat/completions") || + u.endsWith("/responses") || + u.endsWith("/v1/messages") || + u.endsWith("/messages") + ); +} + +/** + * A {@link CopilotRequestHandler} that records every intercepted request + * (url + threaded session id) and fully replaces the upstream call with a + * fabricated, well-formed response for every model-layer endpoint, so an + * agent turn completes entirely off-network — no upstream server and no CAPI + * proxy acting as the inference endpoint. + * + * This exercises the public extension surface end to end: a consumer + * subclasses {@link CopilotRequestHandler} and overrides {@link sendRequest} + * to short-circuit the upstream HTTP call with any {@link Response} it likes. + * The base adapter streams that response back to the runtime. + */ +class RecordingRequestHandler extends CopilotRequestHandler { + readonly records: InterceptedRequest[] = []; + + protected override async sendRequest( + request: Request, + ctx: CopilotRequestContext + ): Promise { + const url = request.url; + this.records.push({ url, sessionId: ctx.sessionId }); + const bodyText = request.body ? await request.text() : ""; + return isInferenceUrl(url) + ? 
buildInferenceResponse(url, bodyText) + : buildNonInferenceResponse(url); + } +} + +function json(body: string): Response { + return new Response(body, { + status: 200, + headers: { "content-type": "application/json" }, + }); +} + +function sse(body: string): Response { + return new Response(body, { + status: 200, + headers: { "content-type": "text/event-stream", "cache-control": "no-cache" }, + }); +} + +/** + * Synthesize a well-formed inference response so the agent turn completes. + * The runtime selects `/responses` for both the CAPI and BYOK sessions here; + * `/chat/completions` is handled too for robustness. + */ +function buildInferenceResponse(url: string, bodyText: string): Response { + const wantsStream = /"stream"\s*:\s*true/.test(bodyText); + const u = url.toLowerCase(); + + if (u.includes("/responses")) { + return wantsStream ? sse(RESPONSES_STREAM_EVENTS.join("")) : json(BUFFERED_RESPONSE_JSON); + } + + if (u.includes("/chat/completions") && wantsStream) { + return sse(CHAT_COMPLETION_STREAM_EVENTS.join("")); + } + + // /chat/completions non-streaming (and any other inference url) — buffered JSON. + return json(BUFFERED_CHAT_COMPLETION_JSON); +} + +/** + * Serve the non-inference model-layer GETs/POSTs the runtime issues (catalog, + * model session, policy). These flow through the same handler but carry no + * session id (they happen outside an agent turn). 
+ */ +function buildNonInferenceResponse(url: string): Response { + const u = url.toLowerCase(); + if (u.endsWith("/models")) { + return json(MODEL_CATALOG_JSON); + } + if (u.includes("/models/session")) { + return json("{}"); + } + if (u.includes("/policy")) { + return json(JSON.stringify({ state: "enabled" })); + } + return json("{}"); +} + +const RESPONSES_STREAM_EVENTS: string[] = [ + `event: response.created\ndata: ${JSON.stringify({ + type: "response.created", + response: { id: "resp_stub_1", object: "response", status: "in_progress", output: [] }, + })}\n\n`, + `event: response.output_item.added\ndata: ${JSON.stringify({ + type: "response.output_item.added", + output_index: 0, + item: { id: "msg_1", type: "message", role: "assistant", content: [] }, + })}\n\n`, + `event: response.content_part.added\ndata: ${JSON.stringify({ + type: "response.content_part.added", + output_index: 0, + content_index: 0, + part: { type: "output_text", text: "" }, + })}\n\n`, + `event: response.output_text.delta\ndata: ${JSON.stringify({ + type: "response.output_text.delta", + output_index: 0, + content_index: 0, + delta: SYNTHETIC_TEXT, + })}\n\n`, + `event: response.output_text.done\ndata: ${JSON.stringify({ + type: "response.output_text.done", + output_index: 0, + content_index: 0, + text: SYNTHETIC_TEXT, + })}\n\n`, + `event: response.completed\ndata: ${JSON.stringify({ + type: "response.completed", + response: { + id: "resp_stub_1", + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: SYNTHETIC_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, + }, + })}\n\n`, +]; + +const CHAT_COMPLETION_STREAM_EVENTS: string[] = (() => { + const base = { + id: "chatcmpl-stub-1", + object: "chat.completion.chunk", + created: 1, + model: "claude-sonnet-4.5", + }; + return [ + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { role: 
"assistant", content: "" }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: { content: SYNTHETIC_TEXT }, finish_reason: null }], + })}\n\n`, + `data: ${JSON.stringify({ + ...base, + choices: [{ index: 0, delta: {}, finish_reason: "stop" }], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, + })}\n\n`, + `data: [DONE]\n\n`, + ]; +})(); + +const BUFFERED_RESPONSE_JSON = JSON.stringify({ + id: "resp_stub_1", + object: "response", + status: "completed", + output: [ + { + id: "msg_1", + type: "message", + role: "assistant", + content: [{ type: "output_text", text: SYNTHETIC_TEXT }], + }, + ], + usage: { input_tokens: 5, output_tokens: 7, total_tokens: 12 }, +}); + +const BUFFERED_CHAT_COMPLETION_JSON = JSON.stringify({ + id: "chatcmpl-stub-1", + object: "chat.completion", + created: 1, + model: "claude-sonnet-4.5", + choices: [ + { + index: 0, + message: { role: "assistant", content: SYNTHETIC_TEXT }, + finish_reason: "stop", + }, + ], + usage: { prompt_tokens: 5, completion_tokens: 7, total_tokens: 12 }, +}); + +const MODEL_CATALOG_JSON = JSON.stringify({ + data: [ + { + id: "claude-sonnet-4.5", + name: "Claude Sonnet 4.5", + object: "model", + vendor: "Anthropic", + version: "1", + preview: false, + model_picker_enabled: true, + capabilities: { + type: "chat", + family: "claude-sonnet-4.5", + tokenizer: "o200k_base", + limits: { max_context_window_tokens: 200000, max_output_tokens: 8192 }, + supports: { + streaming: true, + tool_calls: true, + parallel_tool_calls: true, + vision: true, + }, + }, + }, + ], +}); + +/** + * Asserts the runtime threads its session id into the request handler for + * BOTH a CAPI session and a BYOK session. The handler alone services every + * model-layer request — no upstream server, no CAPI proxy acting as the + * inference endpoint — so the only source of `ctx.sessionId` is the runtime's + * own per-client threading. 
+ */ +describe("CopilotRequestHandler threads the runtime session id (CAPI + BYOK)", async () => { + const handler = new RecordingRequestHandler(); + + const { copilotClient: client } = await createSdkTestContext({ + copilotClientOptions: { + requestHandler: handler, + }, + }); + + let capiSessionId: string | undefined; + + it("threads the session id into a CAPI session's inference request", async () => { + await client.start(); + const baseline = handler.records.length; + const session = await client.createSession({ onPermissionRequest: approveAll }); + capiSessionId = session.sessionId; + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." }); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + const inference = handler.records.slice(baseline).filter((r) => isInferenceUrl(r.url)); + expect( + inference.length, + "expected at least one intercepted inference request" + ).toBeGreaterThan(0); + for (const r of inference) { + expect(r.sessionId, "CAPI inference request must carry the runtime session id").toBe( + session.sessionId + ); + } + + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from the synthetic/); + }, 90_000); + + it("threads the session id into a BYOK session's inference request", async () => { + await client.start(); + const baseline = handler.records.length; + const session = await client.createSession({ + onPermissionRequest: approveAll, + // BYOK providers require an explicit model id. + model: "claude-sonnet-4.5", + provider: { + type: "openai", + wireApi: "responses", + baseUrl: "https://byok.invalid/v1", + apiKey: "byok-secret", + modelId: "claude-sonnet-4.5", + wireModel: "claude-sonnet-4.5", + }, + }); + const byokSessionId = session.sessionId; + let resultJson = ""; + try { + const result = await session.sendAndWait({ prompt: "Say OK." 
}); + resultJson = JSON.stringify(result); + } finally { + await session.disconnect(); + } + + const inference = handler.records.slice(baseline).filter((r) => isInferenceUrl(r.url)); + expect( + inference.length, + "expected at least one intercepted BYOK inference request" + ).toBeGreaterThan(0); + for (const r of inference) { + expect(r.sessionId, "BYOK inference request must carry the runtime session id").toBe( + byokSessionId + ); + } + + // Session ids are per-session, so the two turns must differ — proves + // we assert against a real, request-specific id, not a constant. + expect(byokSessionId).not.toBe(capiSessionId); + + // Validate the final assistant response arrived (guards against truncated captures) + expect(resultJson).toMatch(/OK from the synthetic/); + }, 90_000); +}); diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 1bda91072..4f56b1361 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -64,6 +64,14 @@ TelemetryConfig, UriRuntimeConnection, ) +from .copilot_request_handler import ( + CopilotRequestContext, + CopilotRequestHandler, + CopilotWebSocketCloseStatus, + CopilotWebSocketHandler, + CopilotWebSocketHandlerBase, + LlmInferenceHeaders, +) from .generated.rpc import ( ModelBillingTokenPrices, ModelBillingTokenPricesLongContext, @@ -186,6 +194,10 @@ "CopilotClient", "CopilotClientMode", "CopilotSession", + "CopilotRequestContext", + "CopilotRequestHandler", + "CopilotWebSocketCloseStatus", + "CopilotWebSocketHandlerBase", "CreateSessionFsHandler", "ElicitationContext", "ElicitationHandler", @@ -198,11 +210,13 @@ "ExitPlanModeRequest", "ExitPlanModeResult", "ExtensionInfo", + "CopilotWebSocketHandler", "GetAuthStatusResponse", "GetStatusResponse", "InfiniteSessionConfig", "InputOptions", "LargeToolOutputConfig", + "LlmInferenceHeaders", "LogLevel", "MCPHTTPServerConfig", "MCPServerConfig", diff --git a/python/copilot/client.py b/python/copilot/client.py index 2c407149c..69aacc8dc 100644 --- 
a/python/copilot/client.py +++ b/python/copilot/client.py @@ -61,7 +61,9 @@ CanvasHandler, ExtensionInfo, ) +from .copilot_request_handler import CopilotRequestHandler, create_copilot_request_adapter from .generated.rpc import ( + ClientGlobalApiHandlers, ClientSessionApiHandlers, ModelBillingTokenPrices, ModelBillingTokenPricesLongContext, # noqa: F401 @@ -71,6 +73,7 @@ _ConnectRequest, _InternalServerRpc, from_datetime, + register_client_global_api_handlers, register_client_session_api_handlers, ) from .generated.session_events import ( @@ -352,6 +355,7 @@ class _CopilotClientOptions: use_logged_in_user: bool | None = None telemetry: TelemetryConfig | None = None session_fs: SessionFsConfig | None = None + request_handler: CopilotRequestHandler | None = None session_idle_timeout_seconds: int | None = None enable_remote_sessions: bool = False on_list_models: Callable[[], list[ModelInfo] | Awaitable[list[ModelInfo]]] | None = None @@ -1049,6 +1053,7 @@ def __init__( use_logged_in_user: bool | None = None, telemetry: TelemetryConfig | None = None, session_fs: SessionFsConfig | None = None, + request_handler: CopilotRequestHandler | None = None, session_idle_timeout_seconds: int | None = None, enable_remote_sessions: bool = False, on_list_models: Callable[[], list[ModelInfo] | Awaitable[list[ModelInfo]]] | None = None, @@ -1083,6 +1088,9 @@ def __init__( telemetry. session_fs: Connection-level session filesystem provider configuration. + request_handler: Connection-level request handler. When set, the + supplied handler services every model-layer HTTP/WebSocket + request the runtime would otherwise issue (both BYOK and CAPI). session_idle_timeout_seconds: Server-wide session idle timeout in seconds. Sessions without activity for this duration are automatically cleaned up. Set to ``None`` or ``0`` to disable. 
@@ -1119,6 +1127,7 @@ def __init__( use_logged_in_user=use_logged_in_user, telemetry=telemetry, session_fs=session_fs, + request_handler=request_handler, session_idle_timeout_seconds=session_idle_timeout_seconds, enable_remote_sessions=enable_remote_sessions, on_list_models=on_list_models, @@ -1209,6 +1218,7 @@ def __init__( if options.session_fs is not None: _validate_session_fs_config(options.session_fs) self._session_fs_config = options.session_fs + self._request_handler = options.request_handler @property def rpc(self) -> ServerRpc: @@ -1361,6 +1371,9 @@ async def start(self) -> None: session_fs_start, ) + if self._request_handler is not None: + await self._set_llm_inference_provider() + self._state = "connected" log_timing( logger, @@ -1445,12 +1458,10 @@ async def stop(self) -> None: StopError(message=f"Failed to disconnect session {session.session_id}: {e}") ) - runtime_shutdown_completed = False if self._rpc is not None and self._cli_process is not None and not self._is_external_server: runtime_shutdown_start = time.perf_counter() try: await self._rpc.runtime.shutdown(timeout=_RUNTIME_SHUTDOWN_TIMEOUT_SECONDS) - runtime_shutdown_completed = True log_timing( logger, logging.DEBUG, @@ -1485,62 +1496,40 @@ async def stop(self) -> None: logger.debug("Error while closing Copilot runtime transport", exc_info=True) self._process = None - # Terminate CLI process (only if we spawned it) + # Terminate CLI process (only if we spawned it). + # + # Per the runtime.shutdown contract, the runtime completes all cleanup + # *before* responding and then leaves termination to the caller ("callers + # may then terminate the owned runtime process"). It deliberately keeps + # its JSON-RPC server alive to send the response and does not self-exit, + # so there is no point waiting a grace window for a self-exit that will + # never come. Once shutdown has completed (or failed) we terminate the + # child immediately and only wait to reap it. 
if self._cli_process and not self._is_external_server: poll = getattr(self._cli_process, "poll", None) is_running = poll is None or poll() is None if is_running: - if runtime_shutdown_completed: - try: - await asyncio.to_thread( - self._cli_process.wait, - timeout=_RUNTIME_SHUTDOWN_TIMEOUT_SECONDS, - ) - except subprocess.TimeoutExpired: - self._cli_process.terminate() - try: - await asyncio.to_thread( - self._cli_process.wait, - timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, - ) - except subprocess.TimeoutExpired: - self._cli_process.kill() - try: - await asyncio.to_thread( - self._cli_process.wait, - timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, - ) - except subprocess.TimeoutExpired as e: - errors.append( - StopError( - message=( - "Timed out waiting for CLI process to exit after kill: " - f"{e}" - ) - ) - ) - else: - self._cli_process.terminate() + self._cli_process.terminate() + try: + await asyncio.to_thread( + self._cli_process.wait, + timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, + ) + except subprocess.TimeoutExpired: + self._cli_process.kill() try: await asyncio.to_thread( self._cli_process.wait, timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, ) - except subprocess.TimeoutExpired: - self._cli_process.kill() - try: - await asyncio.to_thread( - self._cli_process.wait, - timeout=_CLI_PROCESS_EXIT_TIMEOUT_SECONDS, - ) - except subprocess.TimeoutExpired as e: - errors.append( - StopError( - message=( - f"Timed out waiting for CLI process to exit after kill: {e}" - ) + except subprocess.TimeoutExpired as e: + errors.append( + StopError( + message=( + f"Timed out waiting for CLI process to exit after kill: {e}" ) ) + ) if self._process is self._cli_process: self._process = None self._cli_process = None @@ -3532,6 +3521,7 @@ def handle_notification(method: str, params: dict): "systemMessage.transform", self._handle_system_message_transform ) register_client_session_api_handlers(self._client, self._get_client_session_handlers) + self._register_llm_inference_handlers() # Start 
listening for messages loop = asyncio.get_running_loop() @@ -3651,6 +3641,7 @@ def handle_notification(method: str, params: dict): "systemMessage.transform", self._handle_system_message_transform ) register_client_session_api_handlers(self._client, self._get_client_session_handlers) + self._register_llm_inference_handlers() # Start listening for messages loop = asyncio.get_running_loop() @@ -3723,6 +3714,22 @@ async def _set_session_fs_provider(self) -> None: await self._client.request("sessionFs.setProvider", params) + def _register_llm_inference_handlers(self) -> None: + if self._request_handler is None or not self._client: + return + adapter = create_copilot_request_adapter( + self._request_handler, + lambda: self._rpc.llm_inference if self._rpc is not None else None, + ) + register_client_global_api_handlers( + self._client, ClientGlobalApiHandlers(llm_inference=adapter) + ) + + async def _set_llm_inference_provider(self) -> None: + if self._request_handler is None or self._rpc is None: + return + await self._rpc.llm_inference.set_provider() + def _get_client_session_handlers(self, session_id: str) -> ClientSessionApiHandlers: with self._sessions_lock: session = self._sessions.get(session_id) diff --git a/python/copilot/copilot_request_handler.py b/python/copilot/copilot_request_handler.py new file mode 100644 index 000000000..80cd4a90b --- /dev/null +++ b/python/copilot/copilot_request_handler.py @@ -0,0 +1,734 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""CopilotRequestHandler: observe or replace outbound model-layer HTTP/WebSocket requests. 
+ +The SDK consumer subclasses :class:`CopilotRequestHandler` and overrides one or +both seams: + +* HTTP — override :meth:`CopilotRequestHandler.send_request` to mutate the + :class:`httpx.Request`, post-process the :class:`httpx.Response`, or replace + the call entirely. The default forwards via a shared :class:`httpx.AsyncClient`. +* WebSocket — override :meth:`CopilotRequestHandler.open_websocket` to return + a per-connection :class:`CopilotWebSocketHandlerBase`. The default opens a + transparent forwarding connection via the ``websockets`` library. + +:func:`create_copilot_request_adapter` converts a handler into the generated +:class:`~copilot.generated.rpc.LlmInferenceHandler` shape so the RPC dispatcher +can route inbound ``httpRequestStart`` / ``httpRequestChunk`` frames through it. +""" + +from __future__ import annotations + +import asyncio +import base64 +from collections.abc import AsyncIterator, Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from .generated.rpc import ( + LlmInferenceHTTPRequestChunkRequest, + LlmInferenceHTTPRequestChunkResult, + LlmInferenceHTTPRequestStartRequest, + LlmInferenceHTTPRequestStartResult, + LlmInferenceHTTPResponseChunkError, + LlmInferenceHTTPResponseChunkRequest, + LlmInferenceHTTPResponseStartRequest, + ServerLlmInferenceApi, +) + +if TYPE_CHECKING: + import httpx + +# Multi-valued headers: header name → list of values. +LlmInferenceHeaders = dict[str, list[str]] + +# Hop-by-hop and length headers the transport recomputes; forwarding them +# verbatim corrupts the request. 
class CopilotWebSocketHandlerBase:
    """Per-connection WebSocket handler returned by
    :meth:`CopilotRequestHandler.open_websocket`.

    Subclass and override :meth:`send_request_message` (runtime → upstream) to
    mutate, drop, or inject messages, and :meth:`send_response_message`
    (upstream → runtime) for the reverse direction. A full transport replacement
    overrides :meth:`open` to stand up its own connection and receive loop.

    Instances must be created while an event loop is running (for example
    inside ``open_websocket``): the completion future is bound to the running
    loop so that awaiting it later cannot fail with a loop mismatch.
    """

    def __init__(self, context: CopilotRequestContext) -> None:
        """Bind the handler to ``context`` and its response bridge.

        Raises:
            RuntimeError: if ``context`` carries no response bridge, or if no
                event loop is running.
        """
        bridge = context._bridge
        if bridge is None:
            raise RuntimeError("WebSocket response bridge is not attached")
        self.context = context
        self._response = bridge
        # get_running_loop() (not the deprecated get_event_loop()) guarantees the
        # future belongs to the loop that will actually await it.
        self._completion: asyncio.Future[CopilotWebSocketCloseStatus] = (
            asyncio.get_running_loop().create_future()
        )
        self._closed = False
        self._suppress_close_on_dispose = False

    async def send_response_message(self, data: str | bytes) -> None:
        """Forward an upstream message to the runtime response."""
        await self._response.write(data)

    async def send_request_message(self, data: str | bytes) -> None:
        """Forward a runtime message to the upstream connection. Override to mutate."""
        raise NotImplementedError

    async def close(self, status: CopilotWebSocketCloseStatus | None = None) -> None:
        """Initiate close: end the runtime response and resolve completion.

        Idempotent: only the first call has any effect. ``status`` defaults to
        a normal closure. When ``status.error`` is set the response is ended
        with an error frame, otherwise with a plain end frame.
        """
        if self._closed:
            return
        self._closed = True
        status = status or CopilotWebSocketCloseStatus.normal_closure()
        try:
            if status.error is not None:
                await self._response.error(
                    status.description or str(status.error), status.error_code
                )
            else:
                await self._response.end()
        finally:
            # Resolve completion even when emitting the terminal frame fails.
            # _closed is already set, so a retry would return early and anyone
            # awaiting completion would otherwise block indefinitely.
            if not self._completion.done():
                self._completion.set_result(status)

    async def open(self) -> None:
        """Establish the connection. Default is a no-op for custom transports."""

    async def aclose(self) -> None:
        """Final resource cleanup; closes normally if not already closed."""
        if not self._suppress_close_on_dispose and not self._closed:
            await self.close(CopilotWebSocketCloseStatus.normal_closure())
class CopilotRequestHandler:
    """Base class for consumers that observe or replace LLM inference requests.

    Override :meth:`send_request` to intercept HTTP model-layer requests, or
    :meth:`open_websocket` to intercept WebSocket connections. An instance
    that overrides nothing is a transparent pass-through.
    """

    async def send_request(
        self, request: httpx.Request, ctx: CopilotRequestContext
    ) -> httpx.Response:
        """Send an HTTP request. Override to mutate request/response or replace the call.

        The default forwards ``request`` through the shared client with
        ``stream=True``, so the response body is not read eagerly.
        """
        return await _get_shared_http_client().send(request, stream=True)

    async def open_websocket(self, ctx: CopilotRequestContext) -> CopilotWebSocketHandlerBase:
        """Open a per-connection WebSocket handler. Override to mutate or replace."""
        return CopilotWebSocketHandler(ctx)

    async def _dispatch(self, exchange: _CopilotRequestExchange) -> None:
        """Build the per-request context and route the exchange by transport."""
        bridge = _CopilotWebSocketResponseBridge(exchange)
        ctx = CopilotRequestContext(
            request_id=exchange.request_id,
            session_id=exchange.session_id,
            transport=exchange.transport,
            url=exchange.url,
            headers=exchange.headers,
            cancel_event=exchange.cancel_event,
            _bridge=bridge,
        )
        if exchange.transport == "websocket":
            await self._handle_web_socket(exchange, ctx)
        else:
            await self._handle_http(exchange, ctx)

    async def _handle_http(
        self, exchange: _CopilotRequestExchange, ctx: CopilotRequestContext
    ) -> None:
        """Build the outbound request and forward it, aborting on runtime cancel."""
        request = await _build_httpx_request(exchange)
        await _run_cancellable(self._forward_http(request, exchange, ctx), exchange.cancel_event)

    async def _forward_http(
        self,
        request: httpx.Request,
        exchange: _CopilotRequestExchange,
        ctx: CopilotRequestContext,
    ) -> None:
        """Send the request via :meth:`send_request` and stream the response back."""
        response = await self.send_request(request, ctx)
        try:
            await _stream_response_to_exchange(response, exchange)
        finally:
            await response.aclose()

    async def _handle_web_socket(
        self, exchange: _CopilotRequestExchange, ctx: CopilotRequestContext
    ) -> None:
        """Drive one WebSocket exchange until either side finishes.

        Runtime → upstream frames are pumped by a background task while the
        handler's completion future tracks the upstream side. Whichever
        finishes first decides the outcome; the pump task is always cancelled
        and reaped on exit so it cannot outlive the exchange.
        """
        handler = await self.open_websocket(ctx)
        assert ctx._bridge is not None
        client_task: asyncio.Task[str] | None = None
        try:
            await handler.open()
            # Emit the 101 upgrade head eagerly. The runtime blocks the WS
            # connect until it receives this acknowledgement, and only then
            # starts forwarding inbound messages as request-body chunks.
            # Waiting for the first upstream message would deadlock.
            await ctx._bridge.start()

            async def pump_client() -> str:
                async for chunk in exchange.request_body:
                    await handler.send_request_message(_decode_frame(chunk))
                return "client-complete"

            client_task = asyncio.create_task(pump_client())
            completion = asyncio.ensure_future(handler._completion)
            done, _ = await asyncio.wait(
                {client_task, completion}, return_when=asyncio.FIRST_COMPLETED
            )

            if client_task in done:
                pump_error = client_task.exception()
                if pump_error is not None:
                    # The caller reports the failure; skip the normal-closure
                    # frame that aclose() would otherwise emit.
                    handler._suppress_close_on_dispose = True
                    raise pump_error
                await handler.close(CopilotWebSocketCloseStatus.normal_closure())
                await handler._completion
                return

            status = await handler._completion
            if status.error is not None:
                raise status.error
        finally:
            try:
                # If the upstream completed first (or this coroutine is being
                # cancelled) the pump is still blocked on the body queue;
                # cancel and reap it instead of leaking the task.
                if client_task is not None and not client_task.done():
                    client_task.cancel()
                    await asyncio.gather(client_task, return_exceptions=True)
            finally:
                await handler.aclose()
_CopilotRequestExchange: + """One intercepted request in flight. + + Carries the request body stream the runtime feeds via ``httpRequestChunk`` + frames, and emits the handler's response directly to the runtime through + the generated ``llmInference`` RPC. Replaces the former provider / sink / + response-channel indirection with a single object the adapter owns. + """ + + def __init__( + self, + request_id: str, + get_server_rpc: Callable[[], ServerLlmInferenceApi | None], + ) -> None: + self.request_id = request_id + self.session_id: str | None = None + self.method: str = "GET" + self.url: str = "" + self.headers: dict[str, list[str]] = {} + self.transport: str = "http" + self._get_server_rpc = get_server_rpc + self._queue = _BodyQueue() + self.cancel_event: asyncio.Event = asyncio.Event() + self.started: bool = False + self.finished: bool = False + self.cancelled: bool = False + self.task: asyncio.Task[None] | None = None + + def set_context(self, params: LlmInferenceHTTPRequestStartRequest) -> None: + """Fill in the request context once the matching start frame arrives.""" + self.session_id = params.session_id + self.method = params.method + self.url = params.url + self.headers = params.headers + transport = params.transport + self.transport = transport.value if transport is not None else "http" + + @property + def request_body(self) -> _BodyQueue: + return self._queue + + def _require_rpc(self) -> ServerLlmInferenceApi: + rpc = self._get_server_rpc() + if rpc is None: + raise RuntimeError("Copilot request response used after RPC connection closed.") + return rpc + + async def start_response( + self, + status: int, + status_text: str | None = None, + headers: LlmInferenceHeaders | None = None, + ) -> None: + if self.started: + raise RuntimeError("Copilot request response start() called twice.") + if self.finished: + raise RuntimeError("Copilot request response already finished.") + self.started = True + await self._require_rpc().http_response_start( + 
LlmInferenceHTTPResponseStartRequest( + headers=headers or {}, + request_id=self.request_id, + status=status, + status_text=status_text, + ) + ) + + async def write_response(self, data: str | bytes) -> None: + if self.cancelled: + raise RuntimeError("Copilot request was cancelled by the runtime.") + if not self.started: + raise RuntimeError("Copilot request response write() called before start().") + if self.finished: + raise RuntimeError("Copilot request response write() called after end()/error().") + is_binary = isinstance(data, (bytes, bytearray)) + payload = base64.b64encode(bytes(data)).decode("ascii") if is_binary else str(data) + await self._require_rpc().http_response_chunk( + LlmInferenceHTTPResponseChunkRequest( + data=payload, + request_id=self.request_id, + binary=is_binary or None, + end=False, + ) + ) + + async def end_response(self) -> None: + if self.finished: + return + self.finished = True + await self._require_rpc().http_response_chunk( + LlmInferenceHTTPResponseChunkRequest(data="", request_id=self.request_id, end=True) + ) + + async def error_response(self, message: str, code: str | None = None) -> None: + if self.finished: + return + self.finished = True + await self._require_rpc().http_response_chunk( + LlmInferenceHTTPResponseChunkRequest( + data="", + request_id=self.request_id, + end=True, + error=LlmInferenceHTTPResponseChunkError(message=message, code=code), + ) + ) + + +# --------------------------------------------------------------------------- +# Adapter: wires the handler into the generated RPC handler shape +# --------------------------------------------------------------------------- + + +def create_copilot_request_adapter( + handler: CopilotRequestHandler, + get_server_rpc: Callable[[], ServerLlmInferenceApi | None], +) -> _CopilotRequestAdapterHandler: + """Adapt a :class:`CopilotRequestHandler` into the generated handler shape. 
+ + Maintains a per-``request_id`` table of :class:`_CopilotRequestExchange`: + each ``httpRequestStart`` allocates one and fires the handler in the + background, returning immediately so the runtime's RPC reply is not gated + on the consumer's I/O. Subsequent ``httpRequestChunk`` frames are routed + into the matching exchange's body stream. + """ + return _CopilotRequestAdapterHandler(handler, get_server_rpc) + + +class _CopilotRequestAdapterHandler: + def __init__( + self, + handler: CopilotRequestHandler, + get_server_rpc: Callable[[], ServerLlmInferenceApi | None], + ) -> None: + self._handler = handler + self._get_server_rpc = get_server_rpc + self._pending: dict[str, _CopilotRequestExchange] = {} + + def _route_chunk( + self, + exchange: _CopilotRequestExchange, + params: LlmInferenceHTTPRequestChunkRequest, + ) -> None: + if params.cancel: + exchange.cancelled = True + exchange.cancel_event.set() + exchange._queue.push(_BodyItem(cancel=True, cancel_reason=params.cancel_reason)) + return + if params.data: + exchange._queue.push( + _BodyItem(chunk=_decode_chunk_data(params.data, bool(params.binary))) + ) + if params.end: + exchange._queue.push(_BodyItem(end=True)) + + async def _run(self, exchange: _CopilotRequestExchange) -> None: + try: + await self._handler._dispatch(exchange) + if not exchange.finished: + await _finalize( + exchange, + 502, + "Copilot request handler returned without finalising the response.", + ) + except Exception as exc: + if exchange.cancelled or exchange.cancel_event.is_set(): + await _finalize(exchange, 499, "Request cancelled by runtime", "cancelled") + return + await _finalize(exchange, 502, str(exc)) + finally: + self._pending.pop(exchange.request_id, None) + + def _get_or_create(self, request_id: str) -> _CopilotRequestExchange: + # The runtime dispatches httpRequestStart and httpRequestChunk frames + # independently. 
get-or-create keeps the adapter correct regardless of + # arrival order: a body chunk (including the terminal end frame) that + # races ahead of its start frame is buffered into the same exchange + # rather than dropped, which would otherwise hang the body drain. + exchange = self._pending.get(request_id) + if exchange is None: + exchange = _CopilotRequestExchange(request_id, self._get_server_rpc) + self._pending[request_id] = exchange + return exchange + + async def http_request_start( + self, params: LlmInferenceHTTPRequestStartRequest + ) -> LlmInferenceHTTPRequestStartResult: + # Adopt any exchange a racing chunk already created — with its buffered + # body — rather than dropping those frames. + exchange = self._get_or_create(params.request_id) + exchange.set_context(params) + exchange.task = asyncio.create_task(self._run(exchange)) + return LlmInferenceHTTPRequestStartResult() + + async def http_request_chunk( + self, params: LlmInferenceHTTPRequestChunkRequest + ) -> LlmInferenceHTTPRequestChunkResult: + # May arrive before the matching start frame; get-or-create so the body + # is buffered, never lost. + exchange = self._get_or_create(params.request_id) + self._route_chunk(exchange, params) + return LlmInferenceHTTPRequestChunkResult() + + +async def _finalize( + exchange: _CopilotRequestExchange, + status: int, + message: str, + code: str | None = None, +) -> None: + if exchange.finished: + return + try: + if not exchange.started: + await exchange.start_response(status) + await exchange.error_response(message, code) + except Exception: + # Best-effort — the connection may already be dead. + pass + + +# --------------------------------------------------------------------------- +# WebSocket response bridge +# --------------------------------------------------------------------------- + + +class _CopilotWebSocketResponseBridge: + """Serialises WebSocket response writes into the exchange. 
+ + The 101 upgrade head is emitted eagerly via :meth:`start` (the runtime + gates the WS connect on it); subsequent writes and the terminal frame are + serialised via a lock so the head always precedes them. The lazy-start + path in :meth:`write` acts as a no-op backstop when ``start`` is called + first (the normal case). + """ + + def __init__(self, exchange: _CopilotRequestExchange) -> None: + self._exchange = exchange + self._started = False + self._completed = False + self._lock = asyncio.Lock() + + async def start(self) -> None: + """Emit the 101 upgrade acknowledgement now.""" + async with self._lock: + if self._started: + return + self._started = True + await self._exchange.start_response(101, headers={}) + + async def write(self, data: str | bytes) -> None: + async with self._lock: + if not self._started: + # Lazy-start backstop: emits the 101 head if a subclass calls + # write before start(). In normal usage start() is called + # eagerly in _handle_web_socket so this branch is never taken. 
+ self._started = True + await self._exchange.start_response(101, headers={}) + if not self._completed: + await self._exchange.write_response(data) + + async def end(self) -> None: + async with self._lock: + if self._completed: + return + self._completed = True + await self._exchange.end_response() + + async def error(self, message: str, code: str | None = None) -> None: + async with self._lock: + if self._completed: + return + self._completed = True + await self._exchange.error_response(message, code) + + +# --------------------------------------------------------------------------- +# HTTP helpers +# --------------------------------------------------------------------------- + + +async def _run_cancellable(coro: Any, cancel_event: asyncio.Event) -> None: + """Run ``coro`` but abort it (and raise) when ``cancel_event`` fires.""" + task = asyncio.ensure_future(coro) + waiter = asyncio.ensure_future(cancel_event.wait()) + try: + done, _ = await asyncio.wait({task, waiter}, return_when=asyncio.FIRST_COMPLETED) + if task in done: + exc = task.exception() + if exc is not None: + raise exc + return + # Cancellation fired first. + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + # The awaited task was cancelled; its unwind exception is expected + # and irrelevant — we raise the cancellation result below. 
+ pass + raise RuntimeError("Request cancelled by runtime") + finally: + if not waiter.done(): + waiter.cancel() + + +async def _build_httpx_request(exchange: _CopilotRequestExchange) -> httpx.Request: + import httpx + + header_pairs = [ + (name, value) + for name, values in exchange.headers.items() + if name.lower() not in _FORBIDDEN_REQUEST_HEADERS + for value in (values or []) + ] + method = exchange.method.upper() + has_body = method not in ("GET", "HEAD") + body = await _drain_async(exchange.request_body) + content = body if (has_body and body) else None + return httpx.Request(method, exchange.url, headers=header_pairs, content=content) + + +async def _drain_async(stream: AsyncIterator[bytes]) -> bytes: + parts: list[bytes] = [] + async for chunk in stream: + if chunk: + parts.append(chunk) + return b"".join(parts) + + +async def _stream_response_to_exchange( + response: httpx.Response, exchange: _CopilotRequestExchange +) -> None: + await exchange.start_response( + response.status_code, + status_text=response.reason_phrase or None, + headers=_headers_to_multi_map(response.headers), + ) + if response.is_stream_consumed: + # An in-memory response (built with ``content=``) has already buffered its + # body, so its raw stream cannot be iterated; forward the buffered bytes. 
+ body = response.content + if body: + await exchange.write_response(body) + else: + async for chunk in response.aiter_raw(): + if chunk: + await exchange.write_response(chunk) + await exchange.end_response() + + +def _headers_to_multi_map(headers: Any) -> LlmInferenceHeaders: + out: dict[str, list[str]] = {} + for name, value in headers.multi_items(): + out.setdefault(name, []).append(value) + return out + + +def _decode_chunk_data(data: str, binary: bool) -> bytes: + if binary: + return base64.b64decode(data) + return data.encode("utf-8") + + +def _decode_frame(chunk: bytes) -> str: + return chunk.decode("utf-8", errors="replace") diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index de42808b4..cb9f2f5f9 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -24971,6 +24971,44 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: return result.value if hasattr(result, 'value') else result client.set_request_handler("canvas.action.invoke", handle_canvas_action_invoke) +# Experimental: this API group is experimental and may change or be removed. +class LlmInferenceHandler(Protocol): + async def http_request_start(self, params: LlmInferenceHTTPRequestStartRequest) -> LlmInferenceHTTPRequestStartResult: + "Announces an outbound model-layer HTTP request the runtime wants the SDK client to service. Carries the request head only; the body always follows as one or more httpRequestChunk frames keyed by the same requestId, even when the body is empty (a single chunk with end=true).\n\nArgs:\n params: The head of an outbound model-layer HTTP request.\n\nReturns:\n Acknowledgement. Returning successfully simply means the SDK accepted the start frame; it does not imply the request will succeed." 
+ pass + async def http_request_chunk(self, params: LlmInferenceHTTPRequestChunkRequest) -> LlmInferenceHTTPRequestChunkResult: + "Delivers a body byte range (or a cancellation signal) for a request previously announced via httpRequestStart, correlated by requestId. The runtime fires at least one chunk per request — when there is no body, a single chunk with empty data and end=true. Mid-stream the runtime may send a chunk with cancel=true to abort the request; the SDK then stops issuing httpResponseChunk frames and may emit a terminal httpResponseChunk with error set.\n\nArgs:\n params: A request body chunk or cancellation signal.\n\nReturns:\n Acknowledgement. The SDK is free to ignore the ack and treat chunk delivery as fire-and-forget." + pass + +@dataclass +class ClientGlobalApiHandlers: + llm_inference: LlmInferenceHandler | None = None + +def register_client_global_api_handlers( + client: "JsonRpcClient", + handlers: ClientGlobalApiHandlers, +) -> None: + """Register client-global request handlers on a JSON-RPC connection. + + Unlike client-session handlers these methods carry no implicit + session_id dispatch key; a single set of handlers serves the entire + connection. 
+ """ + async def handle_llm_inference_http_request_start(params: dict) -> dict | None: + request = LlmInferenceHTTPRequestStartRequest.from_dict(params) + handler = handlers.llm_inference + if handler is None: raise RuntimeError("No llm_inference client-global handler registered") + result = await handler.http_request_start(request) + return result.to_dict() + client.set_request_handler("llmInference.httpRequestStart", handle_llm_inference_http_request_start) + async def handle_llm_inference_http_request_chunk(params: dict) -> dict | None: + request = LlmInferenceHTTPRequestChunkRequest.from_dict(params) + handler = handlers.llm_inference + if handler is None: raise RuntimeError("No llm_inference client-global handler registered") + result = await handler.http_request_chunk(request) + return result.to_dict() + client.set_request_handler("llmInference.httpRequestChunk", handle_llm_inference_http_request_chunk) + __all__ = [ "APIKeyAuthInfo", "APIKeyAuthInfoType", @@ -25040,6 +25078,7 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "CanvasProviderOpenRequest", "CanvasProviderOpenResult", "CanvasSessionContext", + "ClientGlobalApiHandlers", "ClientSessionApiHandlers", "CommandList", "CommandsApi", @@ -25164,6 +25203,7 @@ async def handle_canvas_action_invoke(params: dict) -> dict | None: "LlmInferenceHTTPResponseChunkResult", "LlmInferenceHTTPResponseStartRequest", "LlmInferenceHTTPResponseStartResult", + "LlmInferenceHandler", "LlmInferenceHeaders", "LlmInferenceSetProviderResult", "LocalSessionMetadataValue", diff --git a/python/e2e/_copilot_request_helpers.py b/python/e2e/_copilot_request_helpers.py new file mode 100644 index 000000000..c3c6a06dd --- /dev/null +++ b/python/e2e/_copilot_request_helpers.py @@ -0,0 +1,291 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# -------------------------------------------------------------------------------------------- + +"""Shared fixtures and response-builder helpers for the CopilotRequestHandler e2e tests. + +The ``copilot_request_*`` tests have no recorded snapshots: the registered +handler fabricates well-formed model responses and the runtime routes all of +its model-layer HTTP/WebSocket traffic through that handler instead of the +CAPI proxy. These helpers centralise the synthetic CAPI shapes (model catalog, +policy, ``/responses`` SSE, ``/chat/completions``) so each test file can focus +on the behaviour it is exercising. + +The leading underscore keeps pytest from collecting this module as a test. +""" + +from __future__ import annotations + +import json +import os +import re + +import httpx +import pytest_asyncio + +from copilot import CopilotClient, CopilotRequestHandler, RuntimeConnection +from copilot.generated.session_events import AssistantMessageData + +from .testharness import E2ETestContext + +SYNTHETIC_TEXT = "OK from the synthetic stream." + + +def sse(event: str, data: dict) -> str: + """Frame a single Server-Sent Events message: ``event:``/``data:`` + blank line.""" + return f"event: {event}\ndata: {json.dumps(data)}\n\n" + + +def is_inference_url(url: str) -> bool: + """Return True if ``url`` is a model inference endpoint. + + Strips query parameters before matching so URLs like + ``/chat/completions?api-version=2024-02`` are handled correctly. 
+ """ + path = url.lower().split("?", 1)[0] + return ( + path.endswith("/chat/completions") + or path.endswith("/responses") + or path.endswith("/v1/messages") + or path.endswith("/messages") + ) + + +def _wants_stream(body: bytes) -> bool: + return re.search(rb'"stream"\s*:\s*true', body) is not None + + +def model_catalog(supported_endpoints: list[str] | None = None) -> dict: + """The synthetic ``/models`` catalog payload.""" + model: dict = { + "id": "claude-sonnet-4.5", + "name": "Claude Sonnet 4.5", + "object": "model", + "vendor": "Anthropic", + "version": "1", + "preview": False, + "model_picker_enabled": True, + "capabilities": { + "type": "chat", + "family": "claude-sonnet-4.5", + "tokenizer": "o200k_base", + "limits": {"max_context_window_tokens": 200000, "max_output_tokens": 8192}, + "supports": { + "streaming": True, + "tool_calls": True, + "parallel_tool_calls": True, + "vision": True, + }, + }, + } + if supported_endpoints is not None: + model["supported_endpoints"] = supported_endpoints + return {"data": [model]} + + +def responses_events(text: str, resp_id: str = "resp_stub_1") -> list[dict]: + """The ordered ``/responses`` event objects the runtime's reducer expects.""" + return [ + { + "type": "response.created", + "response": { + "id": resp_id, + "object": "response", + "status": "in_progress", + "output": [], + }, + }, + { + "type": "response.output_item.added", + "output_index": 0, + "item": {"id": "msg_1", "type": "message", "role": "assistant", "content": []}, + }, + { + "type": "response.content_part.added", + "output_index": 0, + "content_index": 0, + "part": {"type": "output_text", "text": ""}, + }, + { + "type": "response.output_text.delta", + "output_index": 0, + "content_index": 0, + "delta": text, + }, + { + "type": "response.output_text.done", + "output_index": 0, + "content_index": 0, + "text": text, + }, + { + "type": "response.completed", + "response": { + "id": resp_id, + "object": "response", + "status": "completed", + "output": 
[ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + ], + "usage": {"input_tokens": 5, "output_tokens": 7, "total_tokens": 12}, + }, + }, + ] + + +def build_non_inference_response( + url: str, supported_endpoints: list[str] | None = None +) -> httpx.Response: + """Build a minimal ``httpx.Response`` for non-inference model-layer requests.""" + path = url.lower().split("?", 1)[0] # strip query params before matching + if path.endswith("/models"): + return httpx.Response( + 200, + headers={"content-type": "application/json"}, + content=json.dumps(model_catalog(supported_endpoints)).encode(), + ) + if "/models/session" in path: + return httpx.Response(200, headers={"content-type": "application/json"}, content=b"{}") + if "/policy" in path: + return httpx.Response( + 200, + headers={"content-type": "application/json"}, + content=json.dumps({"state": "enabled"}).encode(), + ) + return httpx.Response(200, headers={"content-type": "application/json"}, content=b"{}") + + +def build_inference_response(request: httpx.Request, text: str = SYNTHETIC_TEXT) -> httpx.Response: + """Build a synthetic inference response for ``/responses`` or ``/chat/completions``. + + Dispatches by URL and the request body's ``stream`` flag: ``/responses`` + streams an SSE event sequence (or returns a buffered Responses object when + ``stream`` is false), ``/chat/completions`` streams chat-completion chunks + (or returns a buffered completion). 
+ """ + body = request.content # already drained when send_request is called + wants_stream = _wants_stream(body) + url = str(request.url).lower() + + if "/responses" in url: + if not wants_stream: + return httpx.Response( + 200, + headers={"content-type": "application/json"}, + content=json.dumps(responses_events(text)[-1]["response"]).encode(), + ) + stream_body = "".join(sse(e["type"], e) for e in responses_events(text)) + return httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_body.encode(), + ) + + if "/chat/completions" in url and wants_stream: + base = { + "id": "chatcmpl-stub-1", + "object": "chat.completion.chunk", + "created": 1, + "model": "claude-sonnet-4.5", + } + chunks = [ + { + **base, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None, + } + ], + }, + { + **base, + "choices": [{"index": 0, "delta": {"content": text}, "finish_reason": None}], + }, + { + **base, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + }, + ] + stream_body = ( + "".join("data: " + json.dumps(c) + "\n\n" for c in chunks) + "data: [DONE]\n\n" + ) + return httpx.Response( + 200, + headers={"content-type": "text/event-stream"}, + content=stream_body.encode(), + ) + + return httpx.Response( + 200, + headers={"content-type": "application/json"}, + content=json.dumps( + { + "id": "chatcmpl-stub-1", + "object": "chat.completion", + "created": 1, + "model": "claude-sonnet-4.5", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": text}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}, + } + ).encode(), + ) + + +def assistant_text(event) -> str: + if event is not None and isinstance(event.data, AssistantMessageData): + return event.data.content + return "" + + +def build_isolated_client( + ctx: 
E2ETestContext, + handler: CopilotRequestHandler, + extra_env: dict[str, str] | None = None, +) -> CopilotClient: + """Build a CopilotClient wired to ``handler`` via ``request_handler``.""" + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + env = ctx.get_env() + if extra_env: + env = {**env, **extra_env} + return CopilotClient( + connection=RuntimeConnection.for_stdio(path=ctx.cli_path), + working_directory=ctx.work_dir, + env=env, + github_token=github_token, + request_handler=handler, + ) + + +def isolated_client_fixture(make_handler, extra_env: dict[str, str] | None = None): + """Build a module-scoped pytest-asyncio fixture yielding ``(client, handler)``.""" + + @pytest_asyncio.fixture(loop_scope="module") + async def _fixture(ctx: E2ETestContext): + handler = make_handler() + client = build_isolated_client(ctx, handler, extra_env) + try: + yield client, handler + finally: + try: + await client.stop() + except Exception: + # Best-effort teardown during fixture cleanup. + pass + + return _fixture diff --git a/python/e2e/test_copilot_request_cancel_error_e2e.py b/python/e2e/test_copilot_request_cancel_error_e2e.py new file mode 100644 index 000000000..f32884a0e --- /dev/null +++ b/python/e2e/test_copilot_request_cancel_error_e2e.py @@ -0,0 +1,130 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""Cancellation and error coverage for CopilotRequestHandler. + +Mirrors ``nodejs/test/e2e/copilot_request_cancel_error.e2e.test.ts``. These +two scenarios exercise the handler's terminal paths that the happy-path +session-id and HTTP/WebSocket tests never reach: + +* **Error** — the handler throws from :meth:`CopilotRequestHandler.send_request` + for an inference request. 
The adapter reports a transport error back to the + runtime rather than hanging. +* **Runtime cancel** — the handler blocks an inference request indefinitely; + when the consumer aborts the turn the runtime cancels the in-flight request, + firing ``ctx.cancel_event``. The handler observes the abort (the ``cancel``-frame + path) instead of leaking a stuck request. + +Non-inference model-layer requests (catalog, policy, model session) are served +with minimal stubs so the turn reaches the inference step. +""" + +from __future__ import annotations + +import asyncio + +import httpx +import pytest + +from copilot import CopilotRequestContext, CopilotRequestHandler +from copilot.session import PermissionHandler + +from ._copilot_request_helpers import ( + is_inference_url, + isolated_client_fixture, +) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +async def _wait_for(predicate, timeout_s: float) -> None: + loop = asyncio.get_event_loop() + start = loop.time() + while not predicate(): + if loop.time() - start > timeout_s: + raise TimeoutError("wait_for timed out") + await asyncio.sleep(0.05) + + +class _ThrowingHandler(CopilotRequestHandler): + """Throws from every inference request to exercise the error-reporting path.""" + + def __init__(self) -> None: + self.inference_attempts = 0 + + async def send_request( + self, request: httpx.Request, ctx: CopilotRequestContext + ) -> httpx.Response: + url = str(request.url) + if not is_inference_url(url): + return await super().send_request(request, ctx) + self.inference_attempts += 1 + raise RuntimeError("synthetic-callback-transport-failure") + + +class _CancellingHandler(CopilotRequestHandler): + """Blocks every inference request until the runtime cancels it.""" + + def __init__(self) -> None: + self.inference_entered = False + self.saw_abort = False + self.abort_seen = asyncio.Event() + + async def send_request( + self, request: httpx.Request, ctx: CopilotRequestContext + ) -> httpx.Response: + url = 
str(request.url) + if not is_inference_url(url): + return await super().send_request(request, ctx) + self.inference_entered = True + await ctx.cancel_event.wait() + self.saw_abort = True + self.abort_seen.set() + raise RuntimeError("cancelled by runtime") + + +throwing_client = isolated_client_fixture(_ThrowingHandler) +cancelling_client = isolated_client_fixture(_CancellingHandler) + + +class TestCopilotRequestHandlerError: + async def test_reports_thrown_callback_error_instead_of_hanging(self, throwing_client): + client, handler = throwing_client + await client.start() + session = await client.create_session(on_permission_request=PermissionHandler.approve_all) + try: + # The callback throws on inference; the turn surfaces an error (or + # completes without an assistant message) rather than hanging. + await session.send_and_wait("Say OK.") + except Exception: # noqa: BLE001 + # Any turn-level error is expected here; we only assert the callback + # was reached below. + pass + finally: + await session.disconnect() + + assert handler.inference_attempts > 0, ( + "expected the inference callback to be reached and raise" + ) + + +class TestCopilotRequestHandlerCancel: + async def test_fires_cancel_event_when_consumer_aborts_in_flight_request( + self, cancelling_client + ): + client, handler = cancelling_client + await client.start() + session = await client.create_session(on_permission_request=PermissionHandler.approve_all) + try: + await session.send("Say OK.") + await _wait_for(lambda: handler.inference_entered, 60.0) + await session.abort() + await asyncio.wait_for(handler.abort_seen.wait(), timeout=30.0) + finally: + await session.disconnect() + + assert handler.inference_entered is True, "expected the inference callback to be entered" + assert handler.saw_abort is True, ( + "expected the callback to observe runtime cancellation via cancel_event" + ) diff --git a/python/e2e/test_copilot_request_handler_e2e.py b/python/e2e/test_copilot_request_handler_e2e.py new file 
mode 100644 index 000000000..eec0571d4 --- /dev/null +++ b/python/e2e/test_copilot_request_handler_e2e.py @@ -0,0 +1,283 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""E2E test for the idiomatic ``CopilotRequestHandler`` forwarding seams. + +Mirrors ``nodejs/test/e2e/copilot_request_handler.e2e.test.ts``. A single +handler subclass services BOTH transports against a per-test fake upstream: + +* HTTP — :meth:`send_request` rewrites the request to the local HTTP upstream, + mutates an outbound and a response header, and forwards via httpx. +* WebSocket — :meth:`open_websocket` rewrites the URL to the local WebSocket + upstream and returns a forwarding handler that counts messages in both + directions. + +Unlike the other inference tests (which fabricate responses inline), this one +exercises the default httpx / ``websockets`` forwarding machinery against a +real socket, proving the full chain runtime → handler → upstream → handler → +runtime is intact for whichever transport the agent turn selects. +""" + +from __future__ import annotations + +import json +import os +import threading +from dataclasses import dataclass, field +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +import httpx +import pytest +import pytest_asyncio +from websockets.asyncio.server import serve as ws_serve + +from copilot import ( + CopilotClient, + CopilotRequestContext, + CopilotRequestHandler, + CopilotWebSocketHandler, + RuntimeConnection, +) +from copilot.session import PermissionHandler + +from ._copilot_request_helpers import assistant_text, model_catalog, responses_events +from .testharness import E2ETestContext + +pytestmark = pytest.mark.asyncio(loop_scope="module") + +HTTP_TEXT = "OK from synthetic HTTP upstream." 
+WS_TEXT = "OK from synthetic WS upstream." + + +@dataclass +class _Counters: + http_requests: int = 0 + http_responses: int = 0 + ws_request_messages: int = 0 + ws_response_messages: int = 0 + + +@dataclass +class _Upstream: + http_url: str + ws_url: str + _http_server: ThreadingHTTPServer + _http_thread: threading.Thread + _ws_server: object + ws_requests: list[int] = field(default_factory=lambda: [0]) + + @property + def ws_request_count(self) -> int: + return self.ws_requests[0] + + async def close(self) -> None: + self._http_server.shutdown() + self._http_thread.join(timeout=5) + self._http_server.server_close() + self._ws_server.close() # type: ignore[attr-defined] + await self._ws_server.wait_closed() # type: ignore[attr-defined] + + +def _sse_body(text: str, resp_id: str) -> bytes: + out = "".join( + f"event: {event['type']}\ndata: {json.dumps(event)}\n\n" + for event in responses_events(text, resp_id) + ) + return out.encode("utf-8") + + +async def _start_fake_upstream() -> _Upstream: + class _Handler(BaseHTTPRequestHandler): + def log_message(self, *_args): # noqa: ANN002 - silence default logging + pass + + def _send(self, status: int, content_type: str, body: bytes) -> None: + self.send_response(status) + self.send_header("content-type", content_type) + self.send_header("content-length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def _route(self) -> None: + path = self.path.split("?", 1)[0].lower() + length = int(self.headers.get("content-length") or 0) + if length: + self.rfile.read(length) + if path.endswith("/models"): + self._send( + 200, + "application/json", + json.dumps( + model_catalog(supported_endpoints=["/responses", "ws:/responses"]) + ).encode("utf-8"), + ) + return + if path.endswith("/models/session"): + self._send(200, "application/json", b"{}") + return + if "/policy" in path: + self._send( + 200, + "application/json", + json.dumps({"state": "enabled"}).encode("utf-8"), + ) + return + if 
path.endswith("/responses"): + self._send(200, "text/event-stream", _sse_body(HTTP_TEXT, "resp_stub_http")) + return + self._send( + 404, + "application/json", + json.dumps({"error": "not_found", "path": path}).encode("utf-8"), + ) + + def do_GET(self): # noqa: N802 + self._route() + + def do_POST(self): # noqa: N802 + self._route() + + http_server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler) + http_port = http_server.server_address[1] + http_thread = threading.Thread(target=http_server.serve_forever, daemon=True) + http_thread.start() + + ws_requests = [0] + + async def ws_handler(connection) -> None: + async for _raw in connection: + ws_requests[0] += 1 + for event in responses_events(WS_TEXT, "resp_stub_ws"): + await connection.send(json.dumps(event)) + + ws_server = await ws_serve(ws_handler, "127.0.0.1", 0) + ws_port = ws_server.sockets[0].getsockname()[1] + + return _Upstream( + http_url=f"http://127.0.0.1:{http_port}", + ws_url=f"ws://127.0.0.1:{ws_port}", + _http_server=http_server, + _http_thread=http_thread, + _ws_server=ws_server, + ws_requests=ws_requests, + ) + + +class _CountingSocketHandler(CopilotWebSocketHandler): + """Forwarding WebSocket handler that counts messages in both directions.""" + + def __init__(self, ctx: CopilotRequestContext, url: str, counters: _Counters) -> None: + super().__init__(ctx, url=url) + self._counters = counters + + async def send_request_message(self, data: str | bytes) -> None: + self._counters.ws_request_messages += 1 + await super().send_request_message(data) + + async def send_response_message(self, data: str | bytes) -> None: + self._counters.ws_response_messages += 1 + await super().send_response_message(data) + + +class _TestHandler(CopilotRequestHandler): + def __init__(self, upstream: _Upstream, counters: _Counters) -> None: + self._upstream = upstream + self._counters = counters + self._client = httpx.AsyncClient(timeout=None, follow_redirects=False) + + def _rewrite_http(self, url: httpx.URL) -> 
httpx.URL: + up = httpx.URL(self._upstream.http_url) + return url.copy_with(scheme=up.scheme, host=up.host, port=up.port) + + def _rewrite_ws(self, url: str) -> str: + parsed = httpx.URL(url) + up = httpx.URL(self._upstream.ws_url) + return str(parsed.copy_with(scheme=up.scheme, host=up.host, port=up.port)) + + async def send_request( + self, request: httpx.Request, ctx: CopilotRequestContext + ) -> httpx.Response: + self._counters.http_requests += 1 + headers = dict(request.headers) + headers["x-test-mutated"] = "1" + rewritten = httpx.Request( + request.method, + self._rewrite_http(request.url), + headers=headers, + content=request.content, + ) + response = await self._client.send(rewritten, stream=True) + self._counters.http_responses += 1 + response.headers["x-test-response-mutated"] = "1" + return response + + async def open_websocket(self, ctx: CopilotRequestContext): + return _CountingSocketHandler(ctx, self._rewrite_ws(ctx.url), self._counters) + + async def aclose(self) -> None: + await self._client.aclose() + + +@dataclass +class _HandlerFixture: + client: CopilotClient + upstream: _Upstream + counters: _Counters + + +@pytest_asyncio.fixture(loop_scope="module") +async def handler_fixture(ctx: E2ETestContext): + upstream = await _start_fake_upstream() + counters = _Counters() + handler = _TestHandler(upstream, counters) + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + env = {**ctx.get_env(), "COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES": "true"} + client = CopilotClient( + connection=RuntimeConnection.for_stdio(path=ctx.cli_path), + working_directory=ctx.work_dir, + env=env, + github_token=github_token, + request_handler=handler, + ) + try: + yield _HandlerFixture(client=client, upstream=upstream, counters=counters) + finally: + try: + await client.stop() + except Exception: + # Best-effort teardown during fixture cleanup. 
+ pass + await handler.aclose() + await upstream.close() + + +class TestCopilotRequestHandler: + async def test_services_http_and_websocket_via_one_handler(self, handler_fixture): + fx = handler_fixture + await fx.client.start() + session = await fx.client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + # The HTTP seam fired — the runtime issued model-layer GETs (catalog, + # policy) and possibly a single-shot inference through send_request. + assert fx.counters.http_requests > 0, "expected send_request to fire" + assert fx.counters.http_responses > 0, "expected send_request response mutation to fire" + + # The WebSocket seam fired — the main agent turn went over the WS path + # and we observed messages in both directions. + assert fx.counters.ws_request_messages > 0, "expected runtime → upstream ws messages" + assert fx.counters.ws_response_messages > 0, "expected upstream → runtime ws messages" + assert fx.upstream.ws_request_count > 0, "expected upstream WS to receive request messages" + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from synthetic" in text and "upstream" in text diff --git a/python/e2e/test_copilot_request_session_id_e2e.py b/python/e2e/test_copilot_request_session_id_e2e.py new file mode 100644 index 000000000..e40af13a1 --- /dev/null +++ b/python/e2e/test_copilot_request_session_id_e2e.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# -------------------------------------------------------------------------------------------- + +"""E2E tests asserting the runtime threads its session id into the +CopilotRequestHandler for both CAPI and BYOK sessions. 
+ +Mirrors ``nodejs/test/e2e/copilot_request_session_id.e2e.test.ts``. The handler +alone services every model-layer request (no upstream server, no CAPI proxy +acting as the inference endpoint), so the only source of ``ctx.session_id`` is +the runtime's own per-client threading. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import httpx +import pytest + +from copilot import CopilotRequestContext, CopilotRequestHandler +from copilot.session import PermissionHandler + +from ._copilot_request_helpers import ( + assistant_text, + build_inference_response, + build_non_inference_response, + is_inference_url, + isolated_client_fixture, +) + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +@dataclass +class _InterceptedRequest: + url: str + session_id: str | None + + +class _SessionIdHandler(CopilotRequestHandler): + def __init__(self) -> None: + self.records: list[_InterceptedRequest] = [] + + async def send_request( + self, request: httpx.Request, ctx: CopilotRequestContext + ) -> httpx.Response: + url = str(request.url) + self.records.append(_InterceptedRequest(url=url, session_id=ctx.session_id)) + if is_inference_url(url): + return build_inference_response(request) + # Force /responses transport so the inference URL is predictable. 
+ return build_non_inference_response(url, supported_endpoints=["/responses"]) + + +session_id_client = isolated_client_fixture(_SessionIdHandler) + + +class TestCopilotRequestSessionId: + capi_session_id: str | None = None + + async def test_threads_session_id_into_capi_session(self, session_id_client): + client, handler = session_id_client + await client.start() + baseline = len(handler.records) + session = await client.create_session(on_permission_request=PermissionHandler.approve_all) + TestCopilotRequestSessionId.capi_session_id = session.session_id + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + inference = [r for r in handler.records[baseline:] if is_inference_url(r.url)] + assert len(inference) > 0, "expected at least one intercepted inference request" + for r in inference: + assert r.session_id == session.session_id, ( + "CAPI inference request must carry the runtime session id" + ) + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from the synthetic" in text + + async def test_threads_session_id_into_byok_session(self, session_id_client): + client, handler = session_id_client + await client.start() + baseline = len(handler.records) + session = await client.create_session( + on_permission_request=PermissionHandler.approve_all, + model="claude-sonnet-4.5", + provider={ + "type": "openai", + "wire_api": "responses", + "base_url": "https://byok.invalid/v1", + "api_key": "byok-secret", + "model_id": "claude-sonnet-4.5", + "wire_model": "claude-sonnet-4.5", + }, + ) + byok_session_id = session.session_id + text = "" + try: + result = await session.send_and_wait("Say OK.") + text = assistant_text(result) + finally: + await session.disconnect() + + inference = [r for r in handler.records[baseline:] if is_inference_url(r.url)] + assert len(inference) > 0, "expected at least one intercepted BYOK inference request" + for 
r in inference: + assert r.session_id == byok_session_id, ( + "BYOK inference request must carry the runtime session id" + ) + + # Session ids are per-session, so the two turns must differ. + assert byok_session_id != TestCopilotRequestSessionId.capi_session_id + + # Validate the final assistant response arrived (guards against truncated captures) + assert "OK from the synthetic" in text diff --git a/python/pyproject.toml b/python/pyproject.toml index 596e07be2..ea15b2d71 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ dependencies = [ "python-dateutil>=2.9.0.post0", "pydantic>=2.0", + "httpx>=0.24.0", ] [project.urls] @@ -41,7 +42,7 @@ dev = [ "pytest>=7.0.0", "pytest-asyncio>=0.21.0", "pytest-timeout>=2.0.0", - "httpx>=0.24.0", + "websockets>=12.0", "opentelemetry-sdk>=1.0.0", ] telemetry = [ diff --git a/python/test_client.py b/python/test_client.py index 6af4450de..20bfa1e1b 100644 --- a/python/test_client.py +++ b/python/test_client.py @@ -51,7 +51,12 @@ async def shutdown(self, *, timeout=None): await client.stop() assert calls == ["runtime.shutdown"] - process.terminate.assert_not_called() + # The runtime never self-exits after runtime.shutdown (it keeps its + # JSON-RPC server alive to send the response and leaves termination to + # the caller), so stop() terminates the owned process. The mocked + # process exits on terminate() (wait returns immediately), so we never + # escalate to kill(). 
+ process.terminate.assert_called_once() process.kill.assert_not_called() @pytest.mark.asyncio diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 8d8efffca..fb7b66e19 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -43,6 +43,12 @@ dependencies = [ "syn", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "base64" version = "0.22.1" @@ -70,6 +76,12 @@ version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" @@ -92,6 +104,22 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.17" @@ -126,6 +154,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4ae5f15dda3c708c0ade84bfee31ccab44a3da4f88015ed22f63732abe300c8" + [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -246,12 +280,33 @@ dependencies = [ 
"miniz_oxide", ] +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "foldhash" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -261,6 +316,15 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-core" version = "0.3.32" @@ -278,6 +342,23 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -297,7 +378,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "slab", ] @@ -341,11 +426,16 @@ name = "github-copilot-sdk" version = "0.0.0-dev" dependencies = [ "async-trait", + "base64", + "bytes", "dirs", "flate2", + "futures-util", "getrandom 0.2.17", + "http", "parking_lot", "regex", + "reqwest", "rusqlite", "schemars", "serde", @@ -356,6 +446,7 @@ dependencies = [ "tempfile", "tokio", "tokio-stream", + "tokio-tungstenite", "tokio-util", "tracing", "ureq", @@ -363,6 +454,25 @@ dependencies = [ "zip", ] +[[package]] +name = "h2" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb093c84e8bd9b188d4c4a8cb6579fc016968d14c99882163cd3ff402a4f155" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -393,6 +503,120 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "http" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6970f50e31d6fc17d3fa27329444bfa74e196cf62e95052a3f6fee181dba6425" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + 
"pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "hyper" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ca68d021ef39cf6463ab54c1d0f5daf03377b70561305bb89a8f83aab66e0f" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "icu_collections" version = "2.2.0" @@ -514,12 +738,29 @@ dependencies = [ "serde_core", ] +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + [[package]] name = "itoa" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +[[package]] +name = "js-sys" +version = "0.3.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03d04c30968dffe80775bd4d7fb676131cd04a1fb46d2686dbffbaec2d9dfd31" +dependencies = [ + "cfg-if", + "futures-util", + "wasm-bindgen", +] + [[package]] name = "leb128fmt" version = "0.1.0" @@ -609,12 +850,72 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "once_cell" version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +[[package]] +name = "openssl" +version = "0.10.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77823a27f0babb03091cb9ed9ef80af3b39dbc82f97e8fa530374b7dafd87a45" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = 
"openssl-sys" +version = "0.9.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b47e7e6bb2c38cd930d25a23b40fa52e068c10e85f3e03a7f5ba5aaca5713695" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -677,6 +978,15 @@ dependencies = [ "zerovec", ] +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -711,6 +1021,36 @@ version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" +[[package]] +name = "rand" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -789,6 +1129,47 @@ version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", +] + [[package]] name = "ring" version = "0.17.14" @@ -865,6 +1246,18 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "scc" version = "2.4.0" @@ -874,6 +1267,15 @@ dependencies = [ "sdd", ] +[[package]] +name = "schannel" +version = "0.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "schemars" version = "1.2.1" @@ -911,6 +1313,29 @@ version = "3.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] 
+ +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.28" @@ -971,6 +1396,18 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serial_test" version = "3.4.0" @@ -997,6 +1434,17 @@ dependencies = [ "syn", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1075,6 +1523,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + [[package]] name = "synstructure" version = "0.13.2" @@ -1187,6 +1644,26 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = 
"0.1.18" @@ -1199,6 +1676,20 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "native-tls", + "tokio", + "tokio-native-tls", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1212,6 +1703,51 @@ dependencies = [ "tokio", ] +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", + "url", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.44" @@ -1243,6 +1779,31 @@ dependencies = [ "once_cell", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "tungstenite" +version = "0.24.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "native-tls", + "rand", + "sha1", + "thiserror 1.0.69", + "utf-8", +] + [[package]] name = "typenum" version = "1.20.0" @@ -1294,6 +1855,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -1321,6 +1888,15 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -1345,6 +1921,61 @@ dependencies = [ "wit-bindgen 0.51.0", ] +[[package]] +name = "wasm-bindgen" +version = "0.2.125" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ddb3f79143bced6de84270411622a2699cee572fc0875aeaf1e7867cf9fca1a" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "503b14d284f2c8dac03b819967e155ea753f573586193b2b2c95990cb5d69280" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.125" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e21a184b13fb19e157296e2c46056aec9092264fab83e4ba59e68c61b323c3d" +dependencies = [ 
+ "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.125" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fecefd9c35bd935a20fc3fc344b5f29138961e4f47fb03297d88f2587afb5ebd" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.125" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23939e44bb9a5d7576fa2b563dc2e136628f1224e88a8deed09e04858b77871f" +dependencies = [ + "unicode-ident", +] + [[package]] name = "wasm-encoder" version = "0.244.0" @@ -1367,6 +1998,19 @@ dependencies = [ "wasmparser", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wasmparser" version = "0.244.0" @@ -1379,6 +2023,16 @@ dependencies = [ "semver", ] +[[package]] +name = "web-sys" +version = "0.3.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6430a72df5eb332242960fe84b3002a241163998241eb596d4f739b9757061d" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -1684,6 +2338,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce1022995ff5ff5d841ad7d994facc23098cd40152f2c1d11cd607c6f530653f" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ae7f38b72ec2a254e2b87ef277cf2cd4fb97cbebf944faa6f33354da0867930" +dependencies = [ + "proc-macro2", + "quote", + "syn", 
+] + [[package]] name = "zerofrom" version = "0.1.7" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d835ed276..66ef69ad2 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -52,6 +52,14 @@ parking_lot = "0.12" regex = "1" getrandom = "0.2" uuid = { version = "1", default-features = false, features = ["v4"] } +# LLM inference callback transport: idiomatic HTTP/WebSocket forwarding for the +# `CopilotRequestHandler`, plus base64/byte/stream plumbing for the chunk protocol. +base64 = "0.22" +bytes = "1" +http = "1" +futures-util = "0.3" +reqwest = { version = "0.12", default-features = false, features = ["stream", "http2", "default-tls"] } +tokio-tungstenite = { version = "0.24", default-features = false, features = ["connect", "native-tls"] } [target.'cfg(windows)'.dependencies] zip = { version = "2", default-features = false, features = ["deflate"], optional = true } diff --git a/rust/src/copilot_request_handler.rs b/rust/src/copilot_request_handler.rs new file mode 100644 index 000000000..57e6db8be --- /dev/null +++ b/rust/src/copilot_request_handler.rs @@ -0,0 +1,1208 @@ +//! Connection-level interception of the model-layer HTTP and WebSocket traffic +//! the runtime issues — for both CAPI and BYOK sessions. +//! +//! When [`ClientOptions::request_handler`](crate::ClientOptions::request_handler) +//! is set, the SDK registers itself as the runtime's request handler on +//! [`Client::start`](crate::Client::start). From then on, whenever the runtime +//! would issue a model-layer request (inference, `/models`, `/policy`, …) it +//! asks the registered [`CopilotRequestHandler`] to service it instead of making +//! the call itself. +//! +//! [`CopilotRequestHandler`] is the single seam consumers implement: one HTTP +//! send method and one WebSocket factory, each defaulting to transparent +//! pass-through to the real upstream. Override +//! [`send_request`](CopilotRequestHandler::send_request) to mutate / replace HTTP +//! 
requests, or [`open_websocket`](CopilotRequestHandler::open_websocket) to +//! mutate the handshake or return a custom [`CopilotWebSocketHandler`]. +//! +//! # Cancellation +//! +//! [`CopilotRequestContext::cancel`] fires when the runtime cancels the +//! in-flight request (for example because the agent turn was aborted). Forward +//! it to the upstream call so it is torn down too, and stop writing the response. + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::{Arc, LazyLock, OnceLock, Weak}; + +use async_trait::async_trait; +use base64::Engine; +use bytes::Bytes; +use futures_util::{SinkExt, Stream, StreamExt}; +use http::HeaderMap; +use http::header::{HeaderName, HeaderValue}; +use parking_lot::Mutex; +use tokio::net::TcpStream; +use tokio::sync::{Mutex as AsyncMutex, mpsc}; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async}; +use tokio_util::sync::CancellationToken; +use tracing::warn; + +use crate::generated::api_types::{ + LlmInferenceHttpRequestChunkRequest, LlmInferenceHttpRequestStartRequest, + LlmInferenceHttpRequestStartTransport, LlmInferenceHttpResponseChunkError, + LlmInferenceHttpResponseChunkRequest, LlmInferenceHttpResponseStartRequest, +}; +use crate::{ + Client, ClientInner, JsonRpcRequest, JsonRpcResponse, RequestId, SessionId, error_codes, +}; + +const METHOD_HTTP_REQUEST_START: &str = "llmInference.httpRequestStart"; +const METHOD_HTTP_REQUEST_CHUNK: &str = "llmInference.httpRequestChunk"; + +/// Transport the runtime would otherwise use for an intercepted request. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CopilotRequestTransport { + /// Plain HTTP or SSE. Each response body frame is an opaque byte range. + #[default] + Http, + /// Full-duplex WebSocket. Each request/response body frame maps to exactly + /// one WebSocket message. 
+ Websocket, +} + +impl CopilotRequestTransport { + fn from_wire(value: Option) -> Self { + match value { + Some(LlmInferenceHttpRequestStartTransport::Websocket) => Self::Websocket, + _ => Self::Http, + } + } +} + +/// Error returned by a [`CopilotRequestHandler`] hook or the response stream. +#[derive(Debug)] +#[non_exhaustive] +pub enum CopilotRequestError { + /// The response was used after the RPC connection to the runtime closed. + ConnectionClosed, + + /// The response state machine was violated (for example `start` called + /// twice, or a write before `start`). + InvalidState(String), + + /// An upstream transport failure while forwarding the request. + Upstream(String), + + /// A failure surfaced by the consumer's own handler. + Handler(String), + + /// An RPC error talking to the runtime. + Rpc(crate::Error), +} + +impl CopilotRequestError { + /// Construct a handler-level error from a message — the idiomatic way for a + /// consumer to fail an intercepted request. + pub fn message(message: impl Into) -> Self { + Self::Handler(message.into()) + } +} + +impl std::fmt::Display for CopilotRequestError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ConnectionClosed => { + f.write_str("Copilot request response used after RPC connection closed") + } + Self::InvalidState(message) | Self::Upstream(message) | Self::Handler(message) => { + f.write_str(message) + } + Self::Rpc(err) => write!(f, "{err}"), + } + } +} + +impl std::error::Error for CopilotRequestError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Rpc(err) => Some(err), + _ => None, + } + } +} + +impl From for CopilotRequestError { + fn from(err: crate::Error) -> Self { + Self::Rpc(err) + } +} + +/// Context describing an intercepted request, shared by the HTTP and WebSocket +/// seams. 
+#[derive(Clone)] +#[non_exhaustive] +pub struct CopilotRequestContext { + /// Opaque runtime-minted request id, stable across the request lifecycle. + pub request_id: String, + /// Id of the runtime session that triggered this request, or `None` when it + /// was issued outside any session (for example the startup model catalog). + pub session_id: Option, + /// Transport the runtime would otherwise use. + pub transport: CopilotRequestTransport, + /// Absolute request URL. + pub url: String, + /// Request headers, multi-valued. + pub headers: HeaderMap, + /// Fires when the runtime cancels this in-flight request. + pub cancel: CancellationToken, +} + +/// Streaming response body: a sequence of byte chunks or a terminal error. +pub type CopilotHttpResponseBody = + Pin> + Send>>; + +/// A buffered HTTP request handed to [`CopilotRequestHandler::send_request`]. +#[non_exhaustive] +pub struct CopilotHttpRequest { + /// HTTP method (`GET`, `POST`, …). + pub method: String, + /// Absolute request URL. + pub url: String, + /// Request headers. + pub headers: HeaderMap, + /// Fully-buffered request body. + pub body: Vec, + /// Fires when the runtime cancels the request. + pub cancel: CancellationToken, +} + +/// A streaming HTTP response returned by [`CopilotRequestHandler::send_request`]. +#[non_exhaustive] +pub struct CopilotHttpResponse { + /// HTTP status code. + pub status: u16, + /// Optional status reason phrase. + pub status_text: Option, + /// Response headers. + pub headers: HeaderMap, + /// Streaming response body. + pub body: CopilotHttpResponseBody, +} + +impl CopilotHttpResponse { + /// Build a response with the given parts. + pub fn new( + status: u16, + status_text: Option, + headers: HeaderMap, + body: CopilotHttpResponseBody, + ) -> Self { + Self { + status, + status_text, + headers, + body, + } + } +} + +/// A single WebSocket message flowing through a [`CopilotWebSocketHandler`]. 
+#[derive(Clone)] +pub struct CopilotWebSocketMessage { + /// Message payload. + pub data: Vec, + /// Whether the payload is a binary frame (`true`) or a text frame (`false`). + pub binary: bool, +} + +impl CopilotWebSocketMessage { + /// A UTF-8 text message. + pub fn text(data: impl Into) -> Self { + Self { + data: data.into().into_bytes(), + binary: false, + } + } + + /// A binary message. + pub fn binary(data: Vec) -> Self { + Self { data, binary: true } + } +} + +/// The runtime-facing side of a WebSocket: a [`CopilotWebSocketHandler`] writes +/// upstream→runtime messages here. +#[derive(Clone)] +pub struct CopilotWebSocketResponse { + exchange: Arc, +} + +impl CopilotWebSocketResponse { + fn new(exchange: Arc) -> Self { + Self { exchange } + } + + /// Forward one upstream message to the runtime. + pub async fn send_message( + &self, + message: CopilotWebSocketMessage, + ) -> Result<(), CopilotRequestError> { + self.exchange.ensure_ws_started().await?; + if message.binary { + self.exchange.write_binary(&message.data).await + } else { + let text = String::from_utf8_lossy(&message.data); + self.exchange.write_text(&text).await + } + } + + /// End the runtime response stream (the upstream connection closed). + pub async fn close(&self) -> Result<(), CopilotRequestError> { + self.exchange.end_response().await + } + + async fn fail( + &self, + message: impl Into, + code: Option, + ) -> Result<(), CopilotRequestError> { + self.exchange.error_response(message, code).await + } +} + +/// A per-connection WebSocket handler. The default implementation +/// ([`ForwardingCopilotWebSocketHandler`]) bridges to the real upstream; +/// override [`CopilotRequestHandler::open_websocket`] to supply a custom one. +#[async_trait] +pub trait CopilotWebSocketHandler: Send + Sync { + /// Forward one runtime→upstream message. + async fn send_request_message( + &self, + message: CopilotWebSocketMessage, + ) -> Result<(), CopilotRequestError>; + + /// Tear down the upstream connection. 
+ async fn close(&self) -> Result<(), CopilotRequestError>; +} + +/// The connection-level Copilot request seam. +/// +/// One implementor services both transports. Defaults forward transparently to +/// the real upstream, so overriding nothing yields a pass-through; override a +/// method to mutate or replace traffic. +#[async_trait] +pub trait CopilotRequestHandler: Send + Sync + 'static { + /// Service one intercepted HTTP request. Default: forward to the real + /// upstream via [`forward_http`]. Override to mutate the request before + /// forwarding, mutate the response after, or replace the call entirely. + async fn send_request( + &self, + request: CopilotHttpRequest, + _ctx: &CopilotRequestContext, + ) -> Result { + forward_http(request).await + } + + /// Open a per-connection WebSocket handler. Default: a + /// [`ForwardingCopilotWebSocketHandler`] wired to the real upstream. + /// Override to mutate the handshake (URL / headers via `ctx`) or return a + /// custom handler. + /// + /// Unlike the other SDKs, Rust passes `response` — the runtime-facing sink + /// for upstream→runtime messages — as a second argument here rather than + /// exposing a base-class `send_response_message` helper. A custom handler + /// must store this `CopilotWebSocketResponse` in the returned handler struct + /// and call [`CopilotWebSocketResponse::send_message`] on it to push + /// upstream messages back to the runtime. + async fn open_websocket( + &self, + ctx: &CopilotRequestContext, + response: CopilotWebSocketResponse, + ) -> Result, CopilotRequestError> { + let handler = + ForwardingCopilotWebSocketHandler::builder(ctx.url.clone(), ctx.headers.clone()) + .connect(response) + .await?; + Ok(Box::new(handler)) + } +} + +/// Forward through a shared handler, so an `Arc` can be registered while the +/// consumer retains a handle (for example to read state the handler records). 
+#[async_trait] +impl CopilotRequestHandler for Arc { + async fn send_request( + &self, + request: CopilotHttpRequest, + ctx: &CopilotRequestContext, + ) -> Result { + (**self).send_request(request, ctx).await + } + + async fn open_websocket( + &self, + ctx: &CopilotRequestContext, + response: CopilotWebSocketResponse, + ) -> Result, CopilotRequestError> { + (**self).open_websocket(ctx, response).await + } +} + +/// Headers tied to the inbound connection that must not be replayed onto a +/// fresh upstream connection. +const FORBIDDEN_HEADERS: &[&str] = &[ + "host", + "connection", + "content-length", + "transfer-encoding", + "keep-alive", + "upgrade", + "proxy-connection", + "te", + "trailer", +]; + +fn is_forbidden_header(name: &HeaderName) -> bool { + let name = name.as_str(); + FORBIDDEN_HEADERS.contains(&name) || name.starts_with("sec-websocket") +} + +/// Drop headers that belong to the inbound connection rather than the request. +fn strip_forbidden_headers(headers: &mut HeaderMap) { + let forbidden: Vec = headers + .keys() + .filter(|name| is_forbidden_header(name)) + .cloned() + .collect(); + for name in forbidden { + headers.remove(&name); + } +} + +static SHARED_HTTP_CLIENT: LazyLock = LazyLock::new(|| { + reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("default reqwest client must build") +}); + +/// Forward an HTTP request to its real upstream and stream the response back. +/// +/// This is the default behaviour of [`CopilotRequestHandler::send_request`]; +/// consumers that mutate a request can call it to forward the mutated request. 
+pub async fn forward_http( + request: CopilotHttpRequest, +) -> Result { + let method = reqwest::Method::from_bytes(request.method.as_bytes()) + .map_err(|e| CopilotRequestError::InvalidState(format!("invalid HTTP method: {e}")))?; + + let mut headers = request.headers; + strip_forbidden_headers(&mut headers); + + let mut builder = SHARED_HTTP_CLIENT + .request(method, &request.url) + .headers(headers); + if !request.body.is_empty() { + builder = builder.body(request.body); + } + + let response = tokio::select! { + _ = request.cancel.cancelled() => { + return Err(CopilotRequestError::message("Request cancelled by runtime")); + } + result = builder.send() => result.map_err(|e| CopilotRequestError::Upstream(e.to_string()))?, + }; + + let status = response.status().as_u16(); + let status_text = response.status().canonical_reason().map(str::to_string); + let headers = response.headers().clone(); + let body = response + .bytes_stream() + .map(|item| item.map_err(|e| CopilotRequestError::Upstream(e.to_string()))); + + Ok(CopilotHttpResponse { + status, + status_text, + headers, + body: Box::pin(body), + }) +} + +type UpstreamWrite = + futures_util::stream::SplitSink>, Message>; + +/// Transform applied to a WebSocket message; return `None` to drop it. +pub type WebSocketTransform = + Arc Option + Send + Sync>; + +/// Builder for a [`ForwardingCopilotWebSocketHandler`]. +pub struct ForwardingCopilotWebSocketHandlerBuilder { + url: String, + headers: HeaderMap, + on_send_request_message: Option, + on_send_response_message: Option, +} + +impl ForwardingCopilotWebSocketHandlerBuilder { + /// Hook runtime→upstream messages (mutate or drop before forwarding). + pub fn on_send_request_message(mut self, transform: WebSocketTransform) -> Self { + self.on_send_request_message = Some(transform); + self + } + + /// Hook upstream→runtime messages (mutate or drop before forwarding). 
+ pub fn on_send_response_message(mut self, transform: WebSocketTransform) -> Self { + self.on_send_response_message = Some(transform); + self + } + + /// Dial the upstream WebSocket and begin pumping upstream→runtime messages + /// into `response`. + pub async fn connect( + self, + response: CopilotWebSocketResponse, + ) -> Result { + let mut request = + self.url.as_str().into_client_request().map_err(|e| { + CopilotRequestError::Upstream(format!("invalid websocket url: {e}")) + })?; + for (name, value) in &self.headers { + if is_forbidden_header(name) { + continue; + } + request.headers_mut().append(name.clone(), value.clone()); + } + + let (stream, _) = connect_async(request) + .await + .map_err(|e| CopilotRequestError::Upstream(format!("websocket connect failed: {e}")))?; + let (write, mut read) = stream.split(); + + let cancel = CancellationToken::new(); + let loop_cancel = cancel.clone(); + let on_response = self.on_send_response_message.clone(); + tokio::spawn(async move { + loop { + tokio::select! 
{ + _ = loop_cancel.cancelled() => break, + msg = read.next() => match msg { + Some(Ok(Message::Text(text))) => { + let message = CopilotWebSocketMessage::text(text); + if let Some(out) = apply_transform(&on_response, message) { + let _ = response.send_message(out).await; + } + } + Some(Ok(Message::Binary(data))) => { + let message = CopilotWebSocketMessage::binary(data); + if let Some(out) = apply_transform(&on_response, message) { + let _ = response.send_message(out).await; + } + } + Some(Ok(Message::Close(_))) | None => break, + Some(Ok(_)) => continue, + Some(Err(e)) => { + let _ = response.fail(e.to_string(), None).await; + return; + } + } + } + } + let _ = response.close().await; + }); + + Ok(ForwardingCopilotWebSocketHandler { + write: AsyncMutex::new(Some(write)), + on_send_request_message: self.on_send_request_message, + cancel, + }) + } +} + +/// The default WebSocket handler: forwards each runtime message to the real +/// upstream and each upstream message back to the runtime. Mutate by supplying +/// transforms on the [builder](ForwardingCopilotWebSocketHandler::builder). +pub struct ForwardingCopilotWebSocketHandler { + write: AsyncMutex>, + on_send_request_message: Option, + cancel: CancellationToken, +} + +impl ForwardingCopilotWebSocketHandler { + /// Start building a forwarding handler for `url` with the given upstream + /// handshake headers. 
+    pub fn builder(url: String, headers: HeaderMap) -> ForwardingCopilotWebSocketHandlerBuilder {
+        // Both transforms start unset; messages are forwarded unchanged until
+        // the caller installs one on the builder.
+        ForwardingCopilotWebSocketHandlerBuilder {
+            url,
+            headers,
+            on_send_request_message: None,
+            on_send_response_message: None,
+        }
+    }
+}
+
+#[async_trait]
+impl CopilotWebSocketHandler for ForwardingCopilotWebSocketHandler {
+    /// Forward one runtime→upstream message.
+    ///
+    /// The request transform may drop the message (returns `Ok(())` without
+    /// sending). After `close()` has taken the write half the message is
+    /// silently discarded. A failed upstream send surfaces as
+    /// [`CopilotRequestError::Upstream`].
+    async fn send_request_message(
+        &self,
+        message: CopilotWebSocketMessage,
+    ) -> Result<(), CopilotRequestError> {
+        let Some(message) = apply_transform(&self.on_send_request_message, message) else {
+            return Ok(());
+        };
+        let ws_message = if message.binary {
+            Message::Binary(message.data)
+        } else {
+            // Text frames must be valid UTF-8; invalid sequences are replaced
+            // rather than failing the send.
+            Message::Text(String::from_utf8_lossy(&message.data).into_owned())
+        };
+        let mut guard = self.write.lock().await;
+        if let Some(write) = guard.as_mut() {
+            write
+                .send(ws_message)
+                .await
+                .map_err(|e| CopilotRequestError::Upstream(e.to_string()))?;
+        }
+        Ok(())
+    }
+
+    /// Stop the upstream reader task and close the write half. Idempotent:
+    /// the sink is taken on the first call, so later calls only re-cancel.
+    async fn close(&self) -> Result<(), CopilotRequestError> {
+        self.cancel.cancel();
+        let mut guard = self.write.lock().await;
+        if let Some(mut write) = guard.take() {
+            // Best effort: the peer may already be gone.
+            let _ = write.send(Message::Close(None)).await;
+            let _ = write.close().await;
+        }
+        Ok(())
+    }
+}
+
+/// Run `message` through an optional transform. With no transform the message
+/// passes through unchanged; a transform returning `None` drops it.
+fn apply_transform(
+    transform: &Option<WebSocketTransform>,
+    message: CopilotWebSocketMessage,
+) -> Option<CopilotWebSocketMessage> {
+    match transform {
+        Some(f) => f(message),
+        None => Some(message),
+    }
+}
+
+/// Mutable response state machine for a single exchange.
+#[derive(Default)]
+struct ResponseState {
+    started: bool,
+    finished: bool,
+}
+
+/// Request context populated when the matching `httpRequestStart` frame
+/// arrives. Held behind a `OnceLock` so the owning [`CopilotRequestExchange`]
+/// can be created bare by a body chunk that races ahead of its start frame.
+#[derive(Default)]
+struct RequestMeta {
+    session_id: Option<String>,
+    method: String,
+    url: String,
+    headers: HeaderMap,
+    transport: CopilotRequestTransport,
+}
+
+/// One intercepted request in flight.
+///
+/// Carries the request metadata plus the body byte stream the runtime feeds in
+/// via `httpRequestChunk` frames, and emits the handler's response straight back
+/// to the runtime through the generated `llmInference` server API — a single
+/// object the dispatcher owns and the handler drives.
+struct CopilotRequestExchange {
+    request_id: String,
+    meta: OnceLock<RequestMeta>,
+    cancel: CancellationToken,
+    client: Weak<crate::ClientInner>,
+    /// Sender feeding the request body stream. Dropped (set to `None`) on `end`
+    /// or `cancel` to close the stream.
+    body_tx: Mutex<Option<mpsc::UnboundedSender<Vec<u8>>>>,
+    body_rx: AsyncMutex<mpsc::UnboundedReceiver<Vec<u8>>>,
+    state: Mutex<ResponseState>,
+}
+
+impl CopilotRequestExchange {
+    /// Create a bare exchange: no request context yet, an open body channel
+    /// and a response that has neither started nor finished.
+    fn new(request_id: String, client: Weak<crate::ClientInner>) -> Self {
+        let (body_tx, body_rx) = mpsc::unbounded_channel();
+        Self {
+            request_id,
+            meta: OnceLock::new(),
+            cancel: CancellationToken::new(),
+            client,
+            body_tx: Mutex::new(Some(body_tx)),
+            body_rx: AsyncMutex::new(body_rx),
+            state: Mutex::new(ResponseState::default()),
+        }
+    }
+
+    /// Fill in the request context once the matching start frame arrives.
+    /// Only the first call takes effect; later calls are ignored.
+    fn set_context(&self, params: LlmInferenceHttpRequestStartRequest) {
+        let _ = self.meta.set(RequestMeta {
+            session_id: params.session_id.map(SessionId::into_inner),
+            method: params.method,
+            url: params.url,
+            headers: headers_from_wire(&params.headers),
+            transport: CopilotRequestTransport::from_wire(params.transport),
+        });
+    }
+
+    /// Request metadata. Always populated before the handler runs; the
+    /// defaulted fallback only guards the (contract-impossible) case of a body
+    /// chunk with no preceding start frame.
+    fn meta(&self) -> &RequestMeta {
+        self.meta.get_or_init(RequestMeta::default)
+    }
+
+    /// Snapshot of the request context handed to the handler. The cancel token
+    /// is a clone, so it observes `push_cancel` on this exchange.
+    fn context(&self) -> CopilotRequestContext {
+        let meta = self.meta();
+        CopilotRequestContext {
+            request_id: self.request_id.clone(),
+            session_id: meta.session_id.clone(),
+            transport: meta.transport,
+            url: meta.url.clone(),
+            headers: meta.headers.clone(),
+            cancel: self.cancel.clone(),
+        }
+    }
+
+    /// Upgrade the weak client pointer, failing with
+    /// [`CopilotRequestError::ConnectionClosed`] once the client is gone.
+    fn client(&self) -> Result<Client, CopilotRequestError> {
+        self.client
+            .upgrade()
+            .map(Client::from_inner)
+            .ok_or(CopilotRequestError::ConnectionClosed)
+    }
+
+    /// The request id in the wire type expected by the generated RPC structs.
+    fn request_id(&self) -> RequestId {
+        RequestId::new(self.request_id.clone())
+    }
+
+    // --- Request body feed (driven by the dispatcher as frames arrive) ---
+
+    /// Queue one body frame. Frames pushed after `push_end` / `push_cancel`
+    /// are dropped because the sender is already gone.
+    fn push_chunk(&self, data: Vec<u8>) {
+        if let Some(tx) = self.body_tx.lock().as_ref() {
+            let _ = tx.send(data);
+        }
+    }
+
+    /// Close the body stream: dropping the sender makes the receiver yield
+    /// `None` once the queued frames are consumed.
+    fn push_end(&self) {
+        *self.body_tx.lock() = None;
+    }
+
+    /// Signal cancellation to the handler and close the body stream.
+    fn push_cancel(&self) {
+        self.cancel.cancel();
+        *self.body_tx.lock() = None;
+    }
+
+    /// Next body frame, or `None` once the stream is closed and drained.
+    async fn recv_body(&self) -> Option<Vec<u8>> {
+        self.body_rx.lock().await.recv().await
+    }
+
+    /// Concatenate every body frame until the stream closes. Does not return
+    /// until `push_end` or `push_cancel` has dropped the sender.
+    async fn drain_body(&self) -> Vec<u8> {
+        let mut buf = Vec::new();
+        let mut rx = self.body_rx.lock().await;
+        while let Some(frame) = rx.recv().await {
+            buf.extend_from_slice(&frame);
+        }
+        buf
+    }
+
+    // --- Response emit (driven by the handler). Strict state machine: ---
+    // start_response once -> 0..N write -> exactly one of
+    // end_response / error_response.
+
+    /// Whether the response head has been emitted (or is being emitted).
+    fn started(&self) -> bool {
+        self.state.lock().started
+    }
+
+    /// Whether a terminal `end` / `error` chunk has been emitted (or is being
+    /// emitted).
+    fn finished(&self) -> bool {
+        self.state.lock().finished
+    }
+
+    /// Emit the response head (`status`, optional `status_text`, `headers`).
+    ///
+    /// Returns [`CopilotRequestError::InvalidState`] when the head was already
+    /// started or the response already finished, and
+    /// [`CopilotRequestError::ConnectionClosed`] when the client is gone.
+    async fn start_response(
+        &self,
+        status: u16,
+        status_text: Option<String>,
+        headers: HeaderMap,
+    ) -> Result<(), CopilotRequestError> {
+        // Check and flip the flag in one critical section, and release the
+        // lock before awaiting the RPC below.
+        {
+            let mut state = self.state.lock();
+            if state.started {
+                return Err(CopilotRequestError::InvalidState(
+                    "response start() called twice".to_string(),
+                ));
+            }
+            if state.finished {
+                return Err(CopilotRequestError::InvalidState(
+                    "response already finished".to_string(),
+                ));
+            }
+            state.started = true;
+        }
+        let request = LlmInferenceHttpResponseStartRequest {
+            headers: headers_to_wire(&headers),
+            request_id: self.request_id(),
+            status: i64::from(status),
+            status_text,
+        };
+        self.client()?
+            .rpc()
+            .llm_inference()
+            .http_response_start(request)
+            .await?;
+        Ok(())
+    }
+
+    /// Start the WebSocket upgrade head (status 101) once, ignoring repeat
+    /// calls. The dispatcher emits it eagerly before pumping; later writes call
+    /// this as a harmless no-op backstop.
+    async fn ensure_ws_started(&self) -> Result<(), CopilotRequestError> {
+        if self.started() {
+            return Ok(());
+        }
+        self.start_response(101, None, HeaderMap::new()).await
+    }
+
+    /// Emit one text response chunk.
+    async fn write_text(&self, text: &str) -> Result<(), CopilotRequestError> {
+        self.write(text.to_string(), false).await
+    }
+
+    /// Emit one binary response chunk, base64-encoded for the wire.
+    async fn write_binary(&self, data: &[u8]) -> Result<(), CopilotRequestError> {
+        let encoded = base64::engine::general_purpose::STANDARD.encode(data);
+        self.write(encoded, true).await
+    }
+
+    /// Emit one non-terminal response chunk. `data` must already be in wire
+    /// form (base64 when `binary` is true).
+    ///
+    /// Returns [`CopilotRequestError::InvalidState`] when called before the
+    /// head was started or after the response finished.
+    async fn write(&self, data: String, binary: bool) -> Result<(), CopilotRequestError> {
+        // Scope the guard so the lock is released before awaiting the RPC.
+        {
+            let state = self.state.lock();
+            if !state.started {
+                return Err(CopilotRequestError::InvalidState(
+                    "response write called before start()".to_string(),
+                ));
+            }
+            if state.finished {
+                return Err(CopilotRequestError::InvalidState(
+                    "response write called after end()/error()".to_string(),
+                ));
+            }
+        }
+        let request = LlmInferenceHttpResponseChunkRequest {
+            binary: binary.then_some(true),
+            data,
+            end: Some(false),
+            error: None,
+            request_id: self.request_id(),
+        };
+        self.client()?
+            .rpc()
+            .llm_inference()
+            .http_response_chunk(request)
+            .await?;
+        Ok(())
+    }
+
+    /// Finish the response successfully. A no-op when it already finished.
+    async fn end_response(&self) -> Result<(), CopilotRequestError> {
+        self.send_terminal_chunk(None).await
+    }
+
+    /// Finish the response with an error. A no-op when it already finished.
+    async fn error_response(
+        &self,
+        message: impl Into<String>,
+        code: Option<String>,
+    ) -> Result<(), CopilotRequestError> {
+        self.send_terminal_chunk(Some(LlmInferenceHttpResponseChunkError {
+            code,
+            message: message.into(),
+        }))
+        .await
+    }
+
+    /// Mark the response finished and emit the terminal (`end: true`) chunk,
+    /// carrying `error` when the response failed. Only the first caller sends;
+    /// every later call returns `Ok(())` without touching the wire.
+    async fn send_terminal_chunk(
+        &self,
+        error: Option<LlmInferenceHttpResponseChunkError>,
+    ) -> Result<(), CopilotRequestError> {
+        {
+            let mut state = self.state.lock();
+            if state.finished {
+                return Ok(());
+            }
+            state.finished = true;
+        }
+        let request = LlmInferenceHttpResponseChunkRequest {
+            binary: None,
+            data: String::new(),
+            end: Some(true),
+            error,
+            request_id: self.request_id(),
+        };
+        self.client()?
+            .rpc()
+            .llm_inference()
+            .http_response_chunk(request)
+            .await?;
+        Ok(())
+    }
+}
+
+/// Drive one exchange through the registered handler, dispatching by transport.
+///
+/// HTTP: buffer the whole request body, call the handler, stream its response
+/// back. WebSocket: acknowledge the upgrade, open the handler's socket, pump
+/// runtime messages into it until the runtime closes or cancels.
+async fn drive_exchange(
+    exchange: &Arc<CopilotRequestExchange>,
+    handler: &Arc<dyn CopilotRequestHandler>,
+) -> Result<(), CopilotRequestError> {
+    let ctx = exchange.context();
+    let meta = exchange.meta();
+    match meta.transport {
+        CopilotRequestTransport::Http => {
+            let body = exchange.drain_body().await;
+            let request = CopilotHttpRequest {
+                method: meta.method.clone(),
+                url: meta.url.clone(),
+                headers: meta.headers.clone(),
+                body,
+                cancel: ctx.cancel.clone(),
+            };
+            let response = handler.send_request(request, &ctx).await?;
+            stream_http_response(response, exchange, &ctx.cancel).await
+        }
+        CopilotRequestTransport::Websocket => {
+            // The runtime blocks the WebSocket connect until it receives the 101
+            // response head (the upgrade acknowledgement) and only then forwards
+            // inbound messages as request-body chunks. Emit it eagerly here —
+            // waiting for the first upstream message would deadlock, since the
+            // upstream stays silent until it receives a request message the
+            // runtime won't send before the upgrade completes.
+            exchange.ensure_ws_started().await?;
+            let response = CopilotWebSocketResponse::new(exchange.clone());
+            let ws = handler.open_websocket(&ctx, response).await?;
+            let result = pump_websocket_requests(ws.as_ref(), exchange, &ctx.cancel).await;
+            // Always release the upstream socket, whatever the pump returned.
+            let _ = ws.close().await;
+            match result {
+                Ok(()) => exchange.end_response().await,
+                // A pump failure during runtime cancellation is reported as a
+                // cancellation, not as the underlying error.
+                Err(_) if ctx.cancel.is_cancelled() => {
+                    exchange
+                        .error_response(
+                            "Request cancelled by runtime",
+                            Some("cancelled".to_string()),
+                        )
+                        .await?;
+                    Ok(())
+                }
+                Err(err) => Err(err),
+            }
+        }
+    }
+}
+
+/// Stream an HTTP response into the runtime, honouring cancellation.
+///
+/// Emits the head, then every body chunk as binary data, then the terminal
+/// chunk. Cancellation and upstream body errors finish the response with an
+/// error chunk instead.
+async fn stream_http_response(
+    response: CopilotHttpResponse,
+    exchange: &CopilotRequestExchange,
+    cancel: &CancellationToken,
+) -> Result<(), CopilotRequestError> {
+    /// Upper bound on the raw bytes carried by one response chunk frame.
+    const MAX_CHUNK_BYTES: usize = 32 * 1024;
+
+    exchange
+        .start_response(response.status, response.status_text, response.headers)
+        .await?;
+
+    let mut body = response.body;
+    loop {
+        tokio::select! {
+            _ = cancel.cancelled() => {
+                return exchange
+                    .error_response("Request cancelled by runtime", Some("cancelled".to_string()))
+                    .await;
+            }
+            next = body.next() => match next {
+                Some(Ok(chunk)) => {
+                    // Split large upstream chunks into bounded frames.
+                    for piece in chunk.chunks(MAX_CHUNK_BYTES) {
+                        exchange.write_binary(piece).await?;
+                    }
+                }
+                Some(Err(e)) => {
+                    return exchange.error_response(e.to_string(), None).await;
+                }
+                None => break,
+            }
+        }
+    }
+    exchange.end_response().await
+}
+
+/// Forward runtime→upstream WebSocket messages until the runtime closes its side
+/// or cancels.
+async fn pump_websocket_requests(
+    handler: &dyn CopilotWebSocketHandler,
+    exchange: &CopilotRequestExchange,
+    cancel: &CancellationToken,
+) -> Result<(), CopilotRequestError> {
+    loop {
+        tokio::select! {
+            _ = cancel.cancelled() => {
+                return Err(CopilotRequestError::message("Request cancelled by runtime"));
+            }
+            frame = exchange.recv_body() => match frame {
+                Some(data) => {
+                    // NOTE(review): every frame is forwarded as text because
+                    // the body channel carries bytes only — the `binary` flag
+                    // of the originating chunk is not preserved. Confirm the
+                    // runtime never sends binary WebSocket request messages.
+                    handler
+                        .send_request_message(CopilotWebSocketMessage { data, binary: false })
+                        .await?;
+                }
+                None => return Ok(()),
+            }
+        }
+    }
+}
+
+/// Drive the exchange's response to a terminal state once the handler returns,
+/// covering handlers that error, get cancelled, or forget to finalize.
+async fn finalize_exchange(
+    exchange: &CopilotRequestExchange,
+    result: Result<(), CopilotRequestError>,
+) {
+    match result {
+        Ok(()) => {
+            // The handler returned success but never emitted a terminal chunk.
+            if !exchange.finished() {
+                fail_via_response(
+                    exchange,
+                    502,
+                    "Copilot request handler returned without finalising the response".to_string(),
+                )
+                .await;
+            }
+        }
+        Err(err) => {
+            if exchange.finished() {
+                return;
+            }
+            if exchange.cancel.is_cancelled() {
+                // Emit a head first if the handler never started one, then
+                // finish with the cancellation error.
+                if !exchange.started() {
+                    let _ = exchange.start_response(499, None, HeaderMap::new()).await;
+                }
+                let _ = exchange
+                    .error_response(
+                        "Request cancelled by runtime",
+                        Some("cancelled".to_string()),
+                    )
+                    .await;
+            } else {
+                fail_via_response(exchange, 502, err.to_string()).await;
+            }
+        }
+    }
+}
+
+/// Best-effort failure: emit a head with `status` if none was started, then a
+/// terminal error chunk carrying `message`. Send errors are ignored.
+async fn fail_via_response(exchange: &CopilotRequestExchange, status: u16, message: String) {
+    if !exchange.started() {
+        let _ = exchange
+            .start_response(status, None, HeaderMap::new())
+            .await;
+    }
+    let _ = exchange.error_response(message, None).await;
+}
+
+/// Routes inbound `llmInference.*` requests to the registered handler,
+/// reassembling each request's streaming body and acking every frame.
+pub(crate) struct CopilotRequestDispatcher { + handler: Arc, + client: OnceLock>, + pending: Mutex>>, +} + +impl CopilotRequestDispatcher { + pub(crate) fn new(handler: Arc) -> Self { + Self { + handler, + client: OnceLock::new(), + pending: Mutex::new(HashMap::new()), + } + } + + pub(crate) fn set_client(&self, client: Weak) { + let _ = self.client.set(client); + } + + fn client(&self) -> Option { + self.client + .get() + .and_then(Weak::upgrade) + .map(Client::from_inner) + } + + fn client_weak(&self) -> Weak { + self.client.get().cloned().unwrap_or_else(Weak::new) + } + + pub(crate) async fn dispatch(self: &Arc, request: JsonRpcRequest) { + match request.method.as_str() { + METHOD_HTTP_REQUEST_START => self.handle_start(request).await, + METHOD_HTTP_REQUEST_CHUNK => self.handle_chunk(request).await, + other => { + warn!(method = other, "unknown llmInference request method"); + self.send_error(request.id, "unknown llmInference method") + .await; + } + } + } + + fn get_or_create_exchange(&self, request_id: String) -> Arc { + // The runtime dispatches httpRequestStart and httpRequestChunk frames + // independently. get-or-create keeps the adapter correct regardless of + // arrival order: a body chunk (including the terminal end frame) that + // races ahead of its start frame is buffered into the same exchange + // rather than dropped, which would otherwise hang the body drain. + self.pending + .lock() + .entry(request_id.clone()) + .or_insert_with(|| { + Arc::new(CopilotRequestExchange::new(request_id, self.client_weak())) + }) + .clone() + } + + async fn handle_start(self: &Arc, request: JsonRpcRequest) { + let id = request.id; + let Some(params) = parse_params::(&request) else { + self.send_error(id, "invalid llmInference.httpRequestStart params") + .await; + return; + }; + + // Adopt any exchange a racing chunk already created — with its buffered + // body — rather than dropping those frames. 
+ let request_id = params.request_id.clone().into_inner(); + let exchange = self.get_or_create_exchange(request_id.clone()); + exchange.set_context(params); + + let handler = self.handler.clone(); + let dispatcher = Arc::clone(self); + let exchange_for_task = exchange.clone(); + tokio::spawn(async move { + let result = drive_exchange(&exchange_for_task, &handler).await; + finalize_exchange(&exchange_for_task, result).await; + dispatcher.remove_pending(&request_id); + }); + + self.ack(id).await; + } + + async fn handle_chunk(&self, request: JsonRpcRequest) { + let id = request.id; + let Some(params) = parse_params::(&request) else { + self.send_error(id, "invalid llmInference.httpRequestChunk params") + .await; + return; + }; + + // May arrive before the matching start frame; get-or-create so the body + // is buffered, never lost. + let exchange = self.get_or_create_exchange(params.request_id.to_string()); + apply_chunk(&exchange, ¶ms); + + self.ack(id).await; + } + + fn remove_pending(&self, request_id: &str) { + self.pending.lock().remove(request_id); + } + + async fn ack(&self, id: u64) { + let Some(client) = self.client() else { + return; + }; + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: Some(serde_json::json!({})), + error: None, + }) + .await; + } + + async fn send_error(&self, id: u64, message: &str) { + let Some(client) = self.client() else { + return; + }; + let _ = client + .send_response(&JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(crate::JsonRpcError { + code: error_codes::INTERNAL_ERROR, + message: message.to_string(), + data: None, + }), + }) + .await; + } +} + +/// Apply one body chunk to a pending request: route data into the body stream, +/// or terminate it on `end` / `cancel`. 
+fn apply_chunk(exchange: &CopilotRequestExchange, params: &LlmInferenceHttpRequestChunkRequest) {
+    // Cancellation wins over everything else carried by the frame.
+    if params.cancel == Some(true) {
+        exchange.push_cancel();
+        return;
+    }
+
+    if !params.data.is_empty() {
+        let decoded = if params.binary == Some(true) {
+            match base64::engine::general_purpose::STANDARD.decode(params.data.as_bytes()) {
+                Ok(bytes) => Some(bytes),
+                Err(e) => {
+                    // Drop the undecodable payload but keep going: the `end`
+                    // flag below must still be honoured, otherwise a corrupt
+                    // terminal frame leaves the body stream open forever and
+                    // the body drain never returns.
+                    warn!(error = %e, "failed to decode base64 llmInference body chunk");
+                    None
+                }
+            }
+        } else {
+            Some(params.data.clone().into_bytes())
+        };
+        if let Some(bytes) = decoded {
+            exchange.push_chunk(bytes);
+        }
+    }
+
+    if params.end == Some(true) {
+        exchange.push_end();
+    }
+}
+
+/// Deserialize a request's `params` into `T`, returning `None` when the
+/// params are absent or do not match `T`'s shape.
+fn parse_params<T: serde::de::DeserializeOwned>(request: &JsonRpcRequest) -> Option<T> {
+    request
+        .params
+        .as_ref()
+        .and_then(|p| serde_json::from_value(p.clone()).ok())
+}
+
+/// Convert a wire header map into an [`http::HeaderMap`], skipping any entry the
+/// `http` crate rejects.
+fn headers_from_wire(wire: &HashMap<String, Vec<String>>) -> HeaderMap {
+    let mut headers = HeaderMap::new();
+    for (name, values) in wire {
+        // An invalid name drops every value listed under it.
+        let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
+            continue;
+        };
+        for value in values {
+            let Ok(header_value) = HeaderValue::from_str(value) else {
+                continue;
+            };
+            // `append` keeps repeated headers instead of overwriting them.
+            headers.append(header_name.clone(), header_value);
+        }
+    }
+    headers
+}
+
+/// Convert an [`http::HeaderMap`] into the wire header map, dropping values that
+/// are not valid UTF-8.
+fn headers_to_wire(headers: &HeaderMap) -> HashMap> { + let mut wire: HashMap> = HashMap::new(); + for (name, value) in headers { + let Ok(value) = value.to_str() else { + continue; + }; + wire.entry(name.as_str().to_string()) + .or_default() + .push(value.to_string()); + } + wire +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c04fe19e1..a0986182f 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -11,6 +11,10 @@ mod canvas_dispatch; pub(crate) mod embeddedcli; mod errors; pub use errors::*; +/// Connection-level Copilot request handler — intercept and replace the +/// model-layer HTTP and WebSocket traffic the runtime issues for both CAPI and +/// BYOK sessions. +pub mod copilot_request_handler; /// Event handler traits for session lifecycle. pub mod handler; /// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end). @@ -238,6 +242,15 @@ pub struct ClientOptions { /// [`SessionFsProvider`] via /// [`SessionConfig::with_session_fs_provider`](crate::SessionConfig::with_session_fs_provider). pub session_fs: Option, + /// Connection-level Copilot request handler configuration. + /// + /// When set, the SDK registers itself as the runtime's request handler + /// during [`Client::start`], so the runtime routes its model-layer HTTP and + /// WebSocket traffic — for both CAPI and BYOK sessions — through the + /// configured + /// [`CopilotRequestHandler`] + /// instead of issuing the calls itself. + pub request_handler: Option>, /// Optional [`TraceContextProvider`] used to inject W3C Trace Context /// headers (`traceparent` / `tracestate`) on outbound `session.create`, /// `session.resume`, and `session.send` requests. 
@@ -313,6 +326,10 @@ impl std::fmt::Debug for ClientOptions { &self.on_list_models.as_ref().map(|_| ""), ) .field("session_fs", &self.session_fs) + .field( + "request_handler", + &self.request_handler.as_ref().map(|_| ""), + ) .field( "on_get_trace_context", &self.on_get_trace_context.as_ref().map(|_| ""), @@ -560,6 +577,7 @@ impl Default for ClientOptions { session_idle_timeout_seconds: None, on_list_models: None, session_fs: None, + request_handler: None, on_get_trace_context: None, telemetry: None, base_directory: None, @@ -692,6 +710,18 @@ impl ClientOptions { self } + /// Register a connection-level Copilot request handler. The runtime will + /// route its model-layer HTTP and WebSocket traffic through the handler + /// configured here instead of issuing the calls itself. The handler is + /// wrapped in `Arc` internally. + pub fn with_request_handler(mut self, handler: H) -> Self + where + H: crate::copilot_request_handler::CopilotRequestHandler, + { + self.request_handler = Some(Arc::new(handler)); + self + } + /// Set the [`TraceContextProvider`] used to inject W3C Trace Context /// headers on outbound `session.create` / `session.resume` / /// `session.send` requests. The provider is wrapped in `Arc` internally. @@ -814,6 +844,9 @@ struct ClientInner { models_cache: parking_lot::Mutex>>>, session_fs_configured: bool, session_fs_sqlite_declared: bool, + /// Inbound `llmInference.*` dispatcher, installed when + /// [`ClientOptions::request_handler`] is set. + llm_inference: OnceLock>, on_get_trace_context: Option>, /// Token sent in the `connect` handshake. 
Auto-generated when the /// SDK spawns its own CLI in TCP mode and no explicit token is set; @@ -909,6 +942,7 @@ impl Client { } => connection_token.clone(), }; let session_fs_config = options.session_fs.clone(); + let request_handler = options.request_handler.clone(); let session_fs_sqlite_declared = session_fs_config .as_ref() .and_then(|c| c.capabilities.as_ref()) @@ -1044,6 +1078,26 @@ impl Client { "Client::start session filesystem setup complete" ); } + if let Some(handler) = request_handler { + let llm_inference_start = Instant::now(); + let dispatcher = Arc::new(copilot_request_handler::CopilotRequestDispatcher::new( + handler, + )); + dispatcher.set_client(Arc::downgrade(&client.inner)); + let _ = client.inner.llm_inference.set(dispatcher.clone()); + // Start the router early (before any session is registered) so the + // startup model catalog request is dispatched to the handler. + client.inner.router.ensure_started( + &client.inner.notification_tx, + &client.inner.request_rx, + Some(dispatcher.clone()), + ); + client.rpc().llm_inference().set_provider().await?; + debug!( + elapsed_ms = llm_inference_start.elapsed().as_millis(), + "Client::start Copilot request handler registration complete" + ); + } debug!( elapsed_ms = start_time.elapsed().as_millis(), "Client::start complete" @@ -1176,6 +1230,7 @@ impl Client { models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured, session_fs_sqlite_declared, + llm_inference: OnceLock::new(), on_get_trace_context, effective_connection_token, mode, @@ -1557,6 +1612,11 @@ impl Client { self.inner.rpc.write(response).await } + /// Reconstruct a [`Client`] handle from a shared inner pointer. + pub(crate) fn from_inner(inner: Arc) -> Self { + Self { inner } + } + /// Take the receiver for incoming JSON-RPC requests from the CLI. /// /// Can only be called once — subsequent calls return `None`. 
@@ -1576,9 +1636,11 @@ impl Client { &self, session_id: &SessionId, ) -> crate::router::SessionChannels { - self.inner - .router - .ensure_started(&self.inner.notification_tx, &self.inner.request_rx); + self.inner.router.ensure_started( + &self.inner.notification_tx, + &self.inner.request_rx, + self.inner.llm_inference.get().cloned(), + ); self.inner.router.register(session_id) } @@ -1911,14 +1973,12 @@ impl Client { } let should_shutdown_runtime = self.inner.child.lock().is_some(); - let mut runtime_shutdown_completed = false; if should_shutdown_runtime { let runtime_shutdown_start = Instant::now(); match tokio::time::timeout(RUNTIME_SHUTDOWN_TIMEOUT, self.rpc().runtime().shutdown()) .await { Ok(Ok(())) => { - runtime_shutdown_completed = true; debug!( elapsed_ms = runtime_shutdown_start.elapsed().as_millis(), "Client::stop runtime shutdown complete" @@ -1955,17 +2015,13 @@ impl Client { match child.try_wait() { Ok(Some(_status)) => {} Ok(None) => { - if runtime_shutdown_completed { - match tokio::time::timeout(RUNTIME_SHUTDOWN_TIMEOUT, child.wait()).await { - Ok(Ok(_status)) => {} - Ok(Err(e)) => errors.push(e.into()), - Err(_) => { - if let Err(e) = child.kill().await { - errors.push(e.into()); - } - } - } - } else if let Err(e) = child.kill().await { + // The runtime completes all cleanup before responding to + // runtime.shutdown and then leaves termination to us; it + // deliberately keeps its JSON-RPC server alive to send the + // response and never self-exits. Waiting for a self-exit + // that will never come just wastes time, so terminate the + // child immediately. 
+ if let Err(e) = child.kill().await { errors.push(e.into()); } } @@ -2669,6 +2725,7 @@ mod tests { models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())), session_fs_configured: false, session_fs_sqlite_declared: false, + llm_inference: OnceLock::new(), on_get_trace_context: None, effective_connection_token: None, mode: ClientMode::default(), diff --git a/rust/src/router.rs b/rust/src/router.rs index e14630e03..cc621c287 100644 --- a/rust/src/router.rs +++ b/rust/src/router.rs @@ -85,6 +85,7 @@ impl SessionRouter { &self, notification_tx: &broadcast::Sender, request_rx: &Mutex>>, + llm_inference: Option>, ) { let mut started = self.started.lock(); if *started { @@ -145,6 +146,20 @@ impl SessionRouter { let sessions = self.sessions.clone(); tokio::spawn(async move { while let Some(request) = rx.recv().await { + // Client-global `llmInference.*` requests carry no routable + // session and are handled by the inference dispatcher. + if request.method.starts_with("llmInference.") { + if let Some(dispatcher) = &llm_inference { + dispatcher.dispatch(request).await; + } else { + warn!( + method = %request.method, + "llmInference request with no provider registered" + ); + } + continue; + } + let session_id = request .params .as_ref() diff --git a/rust/src/types.rs b/rust/src/types.rs index 5c1c0ddf3..9728a6942 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -13,6 +13,12 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use crate::canvas::{CanvasDeclaration, CanvasHandler}; +pub use crate::copilot_request_handler::{ + CopilotHttpRequest, CopilotHttpResponse, CopilotHttpResponseBody, CopilotRequestContext, + CopilotRequestError, CopilotRequestHandler, CopilotRequestTransport, CopilotWebSocketHandler, + CopilotWebSocketMessage, CopilotWebSocketResponse, ForwardingCopilotWebSocketHandler, + ForwardingCopilotWebSocketHandlerBuilder, WebSocketTransform, forward_http, +}; use crate::generated::api_types::OpenCanvasInstance; /// 
Context window tier for models that support tiered context windows. pub use crate::generated::session_events::ContextTier; diff --git a/rust/tests/e2e.rs b/rust/tests/e2e.rs index 04fe0b2ee..c46630e69 100644 --- a/rust/tests/e2e.rs +++ b/rust/tests/e2e.rs @@ -21,6 +21,8 @@ mod client_options; mod commands; #[path = "e2e/compaction.rs"] mod compaction; +#[path = "e2e/copilot_request_handler.rs"] +mod copilot_request_handler; #[path = "e2e/elicitation.rs"] mod elicitation; #[path = "e2e/error_resilience.rs"] diff --git a/rust/tests/e2e/copilot_request_handler.rs b/rust/tests/e2e/copilot_request_handler.rs new file mode 100644 index 000000000..3fb7c1da0 --- /dev/null +++ b/rust/tests/e2e/copilot_request_handler.rs @@ -0,0 +1,821 @@ +//! End-to-end coverage for the Copilot request handler. +//! +//! These tests register a [`CopilotRequestHandler`] that either fabricates +//! well-formed model responses or forwards to a local upstream, then drive a +//! real agent turn and assert the runtime routed its model-layer HTTP/WebSocket +//! traffic through the handler. No recorded CAPI snapshot is used — the handler +//! replaces every outbound model call. +//! +//! Coverage mirrors the consolidated Node e2e set: +//! - `services_http_and_websocket_via_handler` — a single handler forwards both +//! HTTP and WebSocket traffic to local upstreams (streaming round-trip). +//! - `threads_session_id_into_inference` — the runtime threads its session id +//! into inference requests for both CAPI and BYOK sessions. +//! - `surfaces_handler_errors` — a handler that returns `Err` surfaces a +//! transport error rather than hanging the turn. +//! - `observes_runtime_driven_cancel` — a handler that blocks until the consumer +//! aborts observes the runtime-driven cancellation via `ctx.cancel`. 
+ +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use bytes::Bytes; +use futures_util::{SinkExt, StreamExt}; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::session_events::AssistantMessageData; +use github_copilot_sdk::{ + CopilotHttpRequest, CopilotHttpResponse, CopilotRequestContext, CopilotRequestError, + CopilotRequestHandler, CopilotWebSocketHandler, CopilotWebSocketResponse, + ForwardingCopilotWebSocketHandler, MessageOptions, ProviderConfig, SessionConfig, SessionEvent, + forward_http, +}; +use http::header::{HeaderName, HeaderValue}; +use http::{HeaderMap, Uri}; +use serde_json::{Value, json}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::tungstenite::Message; + +use super::support::with_e2e_context_no_snapshot; + +const SYNTHETIC_TEXT: &str = "OK from the synthetic stream."; +const HANDLER_HTTP_TEXT: &str = "OK from synthetic HTTP upstream."; +const HANDLER_WS_TEXT: &str = "OK from synthetic WS upstream."; +const WS_SUPPORTED_ENDPOINTS: &[&str] = &["/responses", "ws:/responses"]; + +fn say_ok() -> MessageOptions { + MessageOptions::new("Say OK.").with_wait_timeout(Duration::from_secs(120)) +} + +fn header_map(pairs: &[(&str, &str)]) -> HeaderMap { + let mut headers = HeaderMap::new(); + for (name, value) in pairs { + headers.insert( + HeaderName::from_bytes(name.as_bytes()).unwrap(), + HeaderValue::from_str(value).unwrap(), + ); + } + headers +} + +fn json_headers() -> HeaderMap { + header_map(&[("content-type", "application/json")]) +} + +fn sse_headers() -> HeaderMap { + header_map(&[("content-type", "text/event-stream")]) +} + +fn assistant_text(event: &Option) -> String { + event + .as_ref() + .and_then(|e| e.typed_data::()) + .map(|data| data.content) + .unwrap_or_default() +} + +fn is_inference_url(url: &str) -> bool { + let url = url.to_lowercase(); + 
url.ends_with("/chat/completions") + || url.ends_with("/responses") + || url.ends_with("/v1/messages") + || url.ends_with("/messages") +} + +/// Detect `"stream": true` in a request body without depending on exact JSON +/// whitespace. +fn stream_true(body: &[u8]) -> bool { + let text = String::from_utf8_lossy(body); + let compact: String = text.chars().filter(|c| !c.is_whitespace()).collect(); + compact.contains("\"stream\":true") +} + +fn sse(event_type: &str, data: &Value) -> String { + format!( + "event: {event_type}\ndata: {}\n\n", + serde_json::to_string(data).unwrap() + ) +} + +fn model_catalog(supported_endpoints: Option<&[&str]>) -> String { + let mut model = json!({ + "id": "claude-sonnet-4.5", + "name": "Claude Sonnet 4.5", + "object": "model", + "vendor": "Anthropic", + "version": "1", + "preview": false, + "model_picker_enabled": true, + "capabilities": { + "type": "chat", + "family": "claude-sonnet-4.5", + "tokenizer": "o200k_base", + "limits": { + "max_context_window_tokens": 200000, + "max_output_tokens": 8192, + }, + "supports": { + "streaming": true, + "tool_calls": true, + "parallel_tool_calls": true, + "vision": true, + }, + }, + }); + if let Some(endpoints) = supported_endpoints { + model["supported_endpoints"] = json!(endpoints); + } + serde_json::to_string(&json!({ "data": [model] })).unwrap() +} + +/// The ordered `/responses` event objects the runtime's reducer expects. Used +/// raw (one object == one WebSocket message) for the WS path and SSE-framed for +/// the HTTP path. 
+fn responses_events(text: &str, resp_id: &str) -> Vec { + vec![ + json!({ + "type": "response.created", + "response": { "id": resp_id, "object": "response", "status": "in_progress", "output": [] }, + }), + json!({ + "type": "response.output_item.added", + "output_index": 0, + "item": { "id": "msg_1", "type": "message", "role": "assistant", "content": [] }, + }), + json!({ + "type": "response.content_part.added", + "output_index": 0, + "content_index": 0, + "part": { "type": "output_text", "text": "" }, + }), + json!({ "type": "response.output_text.delta", "output_index": 0, "content_index": 0, "delta": text }), + json!({ "type": "response.output_text.done", "output_index": 0, "content_index": 0, "text": text }), + json!({ + "type": "response.completed", + "response": { + "id": resp_id, + "object": "response", + "status": "completed", + "output": [{ + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{ "type": "output_text", "text": text }], + }], + "usage": { "input_tokens": 5, "output_tokens": 7, "total_tokens": 12 }, + }, + }), + ] +} + +/// Build a streaming HTTP response from a sequence of body chunks. +fn http_response(status: u16, headers: HeaderMap, chunks: Vec>) -> CopilotHttpResponse { + let body = futures_util::stream::iter( + chunks + .into_iter() + .map(|chunk| Ok::(Bytes::from(chunk))), + ); + CopilotHttpResponse::new(status, None, headers, Box::pin(body)) +} + +/// Serve the model catalog, model session and policy endpoints with an +/// empty-JSON fallback for anything unrecognised. 
+fn synth_non_inference_response( + url: &str, + supported_endpoints: Option<&[&str]>, +) -> CopilotHttpResponse { + let lower = url.to_lowercase(); + if lower.ends_with("/models") { + return http_response( + 200, + json_headers(), + vec![model_catalog(supported_endpoints).into_bytes()], + ); + } + if lower.contains("/models/session") { + return http_response(200, HeaderMap::new(), vec![b"{}".to_vec()]); + } + if lower.contains("/policy") { + return http_response( + 200, + HeaderMap::new(), + vec![br#"{"state":"enabled"}"#.to_vec()], + ); + } + http_response(200, json_headers(), vec![b"{}".to_vec()]) +} + +/// Synthesize a well-formed inference response, dispatching by URL and the +/// request body's stream flag exactly as a real reverse proxy would. +fn synth_inference_response(url: &str, body: &[u8], text: &str) -> CopilotHttpResponse { + let wants_stream = stream_true(body); + let lower = url.to_lowercase(); + + if lower.contains("/responses") { + let events = responses_events(text, "resp_stub_1"); + if !wants_stream { + let last = serde_json::to_string(&events[events.len() - 1]["response"]).unwrap(); + return http_response(200, json_headers(), vec![last.into_bytes()]); + } + let chunks = events + .iter() + .map(|event| sse(event["type"].as_str().unwrap(), event).into_bytes()) + .collect(); + return http_response(200, sse_headers(), chunks); + } + + if lower.contains("/chat/completions") && wants_stream { + let base = || { + json!({ + "id": "chatcmpl-stub-1", + "object": "chat.completion.chunk", + "created": 1, + "model": "claude-sonnet-4.5", + }) + }; + let mut c1 = base(); + c1["choices"] = json!([{ "index": 0, "delta": { "role": "assistant", "content": "" }, "finish_reason": null }]); + let mut c2 = base(); + c2["choices"] = + json!([{ "index": 0, "delta": { "content": text }, "finish_reason": null }]); + let mut c3 = base(); + c3["choices"] = json!([{ "index": 0, "delta": {}, "finish_reason": "stop" }]); + c3["usage"] = json!({ "prompt_tokens": 5, 
"completion_tokens": 7, "total_tokens": 12 }); + let mut chunks: Vec> = [c1, c2, c3] + .iter() + .map(|chunk| { + format!("data: {}\n\n", serde_json::to_string(chunk).unwrap()).into_bytes() + }) + .collect(); + chunks.push(b"data: [DONE]\n\n".to_vec()); + return http_response(200, sse_headers(), chunks); + } + + let buffered = json!({ + "id": "chatcmpl-stub-1", + "object": "chat.completion", + "created": 1, + "model": "claude-sonnet-4.5", + "choices": [{ + "index": 0, + "message": { "role": "assistant", "content": text }, + "finish_reason": "stop", + }], + "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }, + }); + http_response( + 200, + json_headers(), + vec![serde_json::to_string(&buffered).unwrap().into_bytes()], + ) +} + +async fn wait_for_flag(flag: &AtomicBool, what: &str) { + let deadline = Instant::now() + Duration::from_secs(60); + while !flag.load(Ordering::SeqCst) { + assert!(Instant::now() < deadline, "timed out waiting for {what}"); + tokio::time::sleep(Duration::from_millis(50)).await; + } +} + +async fn session_send(session: &github_copilot_sdk::session::Session) -> Option { + session + .send_and_wait(say_ok()) + .await + .expect("send_and_wait") +} + +// --------------------------------------------------------------------------- +// Scenario 1: handler — one handler forwards both HTTP and WebSocket traffic to +// local upstreams, mutating traffic on the way through. 
+// --------------------------------------------------------------------------- + +#[derive(Clone, Default)] +struct HandlerCounters { + http_requests: Arc, + http_responses: Arc, + ws_request_messages: Arc, + ws_response_messages: Arc, + upstream_ws_requests: Arc, +} + +struct ForwardingHandler { + http_authority: String, + ws_authority: String, + counters: HandlerCounters, +} + +fn rewrite_authority( + url: &str, + scheme: &str, + authority: &str, +) -> Result { + let uri: Uri = url + .parse() + .map_err(|e| CopilotRequestError::message(format!("invalid url {url}: {e}")))?; + let path_and_query = uri.path_and_query().map(|p| p.as_str()).unwrap_or("/"); + Ok(format!("{scheme}://{authority}{path_and_query}")) +} + +#[async_trait] +impl CopilotRequestHandler for ForwardingHandler { + async fn send_request( + &self, + mut request: CopilotHttpRequest, + _ctx: &CopilotRequestContext, + ) -> Result { + self.counters.http_requests.fetch_add(1, Ordering::SeqCst); + request.url = rewrite_authority(&request.url, "http", &self.http_authority)?; + request + .headers + .insert("x-test-mutated", HeaderValue::from_static("1")); + let mut response = forward_http(request).await?; + self.counters.http_responses.fetch_add(1, Ordering::SeqCst); + response + .headers + .insert("x-test-response-mutated", HeaderValue::from_static("1")); + Ok(response) + } + + async fn open_websocket( + &self, + ctx: &CopilotRequestContext, + response: CopilotWebSocketResponse, + ) -> Result, CopilotRequestError> { + let ws_url = rewrite_authority(&ctx.url, "ws", &self.ws_authority)?; + let request_counter = self.counters.ws_request_messages.clone(); + let response_counter = self.counters.ws_response_messages.clone(); + let handler = ForwardingCopilotWebSocketHandler::builder(ws_url, ctx.headers.clone()) + .on_send_request_message(Arc::new(move |message| { + request_counter.fetch_add(1, Ordering::SeqCst); + Some(message) + })) + .on_send_response_message(Arc::new(move |message| { + 
response_counter.fetch_add(1, Ordering::SeqCst); + Some(message) + })) + .connect(response) + .await?; + Ok(Box::new(handler)) + } +} + +fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option { + haystack + .windows(needle.len()) + .position(|window| window == needle) +} + +fn route_http_upstream(path: &str) -> (u16, &'static str, String) { + if path.ends_with("/models") { + ( + 200, + "application/json", + model_catalog(Some(WS_SUPPORTED_ENDPOINTS)), + ) + } else if path.ends_with("/models/session") { + (200, "application/json", "{}".to_string()) + } else if path.contains("/policy") { + ( + 200, + "application/json", + r#"{"state":"enabled"}"#.to_string(), + ) + } else if path.ends_with("/responses") { + let mut body = String::new(); + for event in responses_events(HANDLER_HTTP_TEXT, "resp_stub_http") { + body.push_str(&sse(event["type"].as_str().unwrap(), &event)); + } + (200, "text/event-stream", body) + } else { + ( + 404, + "application/json", + r#"{"error":"not_found"}"#.to_string(), + ) + } +} + +async fn serve_http_conn(socket: &mut TcpStream) -> std::io::Result<()> { + let mut buf = Vec::new(); + let mut tmp = [0u8; 4096]; + let header_end = loop { + let n = socket.read(&mut tmp).await?; + if n == 0 { + return Ok(()); + } + buf.extend_from_slice(&tmp[..n]); + if let Some(pos) = find_subsequence(&buf, b"\r\n\r\n") { + break pos + 4; + } + }; + let head = String::from_utf8_lossy(&buf[..header_end]).to_string(); + let content_length = head + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + if name.trim().eq_ignore_ascii_case("content-length") { + value.trim().parse::().ok() + } else { + None + } + }) + .unwrap_or(0); + let mut remaining = content_length.saturating_sub(buf.len() - header_end); + while remaining > 0 { + let n = socket.read(&mut tmp).await?; + if n == 0 { + break; + } + remaining = remaining.saturating_sub(n); + } + + let request_path = head + .lines() + .next() + .and_then(|line| 
line.split_whitespace().nth(1)) + .unwrap_or("/") + .split('?') + .next() + .unwrap_or("/") + .to_lowercase(); + let (status, content_type, body) = route_http_upstream(&request_path); + let reason = if status == 200 { "OK" } else { "Not Found" }; + let head = format!( + "HTTP/1.1 {status} {reason}\r\ncontent-type: {content_type}\r\ncontent-length: {}\r\nconnection: close\r\n\r\n", + body.len() + ); + socket.write_all(head.as_bytes()).await?; + socket.write_all(body.as_bytes()).await?; + socket.flush().await?; + let _ = socket.shutdown().await; + Ok(()) +} + +async fn start_http_upstream() -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let authority = listener.local_addr().unwrap().to_string(); + tokio::spawn(async move { + while let Ok((mut socket, _)) = listener.accept().await { + tokio::spawn(async move { + let _ = serve_http_conn(&mut socket).await; + }); + } + }); + authority +} + +async fn start_ws_upstream(counters: HandlerCounters) -> String { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let authority = listener.local_addr().unwrap().to_string(); + tokio::spawn(async move { + while let Ok((socket, _)) = listener.accept().await { + let counters = counters.clone(); + tokio::spawn(async move { + let ws = match tokio_tungstenite::accept_async(socket).await { + Ok(ws) => ws, + Err(_) => return, + }; + let (mut write, mut read) = ws.split(); + while let Some(Ok(message)) = read.next().await { + match message { + Message::Text(_) | Message::Binary(_) => { + counters.upstream_ws_requests.fetch_add(1, Ordering::SeqCst); + for event in responses_events(HANDLER_WS_TEXT, "resp_stub_ws") { + let raw = serde_json::to_string(&event).unwrap(); + if write.send(Message::Text(raw)).await.is_err() { + return; + } + } + } + Message::Close(_) => break, + _ => {} + } + } + }); + } + }); + authority +} + +#[tokio::test] +async fn services_http_and_websocket_via_handler() { + with_e2e_context_no_snapshot(|ctx| { + 
Box::pin(async move { + ctx.set_default_copilot_user(); + let counters = HandlerCounters::default(); + let http_authority = start_http_upstream().await; + let ws_authority = start_ws_upstream(counters.clone()).await; + + let handler = ForwardingHandler { + http_authority, + ws_authority, + counters: counters.clone(), + }; + let client = ctx + .start_llm_client( + handler, + &[("COPILOT_EXP_COPILOT_CLI_WEBSOCKET_RESPONSES", "true")], + ) + .await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + let result = session + .send_and_wait(say_ok()) + .await + .expect("send_and_wait"); + let _ = session.disconnect().await; + + assert!( + counters.http_requests.load(Ordering::SeqCst) > 0, + "expected the HTTP forwarder to fire" + ); + assert!( + counters.http_responses.load(Ordering::SeqCst) > 0, + "expected the HTTP response mutation to fire" + ); + assert!( + counters.ws_request_messages.load(Ordering::SeqCst) > 0, + "expected runtime → upstream ws messages" + ); + assert!( + counters.ws_response_messages.load(Ordering::SeqCst) > 0, + "expected upstream → runtime ws messages" + ); + assert!( + counters.upstream_ws_requests.load(Ordering::SeqCst) > 0, + "expected the upstream WS to receive request messages" + ); + + // Validate the final assistant response arrived (guards against truncated captures) + let text = assistant_text(&result); + assert!( + text.contains("OK from synthetic") && text.contains("upstream"), + "expected synthetic upstream content in assistant reply, got {text:?}" + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} + +// --------------------------------------------------------------------------- +// Scenario 2: session id — the runtime threads the session id into CAPI and +// BYOK inference requests serviced entirely by the handler. 
+// ---------------------------------------------------------------------------
+
+/// Request handler that records every request it sees as a
+/// `(url, session id from the request context)` pair and answers all of them
+/// with synthetic responses.
+#[derive(Default)]
+struct RecordingHandler {
+    // Appended to on every `send_request` call, in arrival order.
+    records: std::sync::Mutex)>>,
+}
+
+impl RecordingHandler {
+    /// Snapshot of the recorded pairs whose URL is an inference URL, in the
+    /// order they were recorded.
+    fn inference_records(&self) -> Vec<(String, Option)> {
+        self.records
+            .lock()
+            .unwrap()
+            .iter()
+            .filter(|(url, _)| is_inference_url(url))
+            .cloned()
+            .collect()
+    }
+}
+
+#[async_trait]
+impl CopilotRequestHandler for RecordingHandler {
+    /// Record the request, then fabricate the reply: an inference response
+    /// carrying `SYNTHETIC_TEXT` for inference URLs, a non-inference stub
+    /// otherwise.
+    async fn send_request(
+        &self,
+        request: CopilotHttpRequest,
+        ctx: &CopilotRequestContext,
+    ) -> Result {
+        self.records
+            .lock()
+            .unwrap()
+            .push((request.url.clone(), ctx.session_id.clone()));
+        if is_inference_url(&request.url) {
+            Ok(synth_inference_response(
+                &request.url,
+                &request.body,
+                SYNTHETIC_TEXT,
+            ))
+        } else {
+            Ok(synth_non_inference_response(&request.url, None))
+        }
+    }
+}
+
+/// Runs one CAPI turn and one BYOK turn against the same `RecordingHandler`
+/// and checks that every inference request recorded during each turn carries
+/// that turn's session id.
+#[tokio::test]
+async fn threads_session_id_into_inference() {
+    with_e2e_context_no_snapshot(|ctx| {
+        Box::pin(async move {
+            ctx.set_default_copilot_user();
+            let handler = Arc::new(RecordingHandler::default());
+            let client = ctx.start_llm_client(handler.clone(), &[]).await;
+
+            // CAPI session.
+            let capi_session = client
+                .create_session(ctx.approve_all_session_config())
+                .await
+                .expect("create CAPI session");
+            let capi_session_id = capi_session.id().as_str().to_string();
+            let result = session_send(&capi_session).await;
+            let _ = capi_session.disconnect().await;
+
+            let inference = handler.inference_records();
+            assert!(
+                !inference.is_empty(),
+                "expected at least one intercepted inference request"
+            );
+            for (_, session_id) in &inference {
+                assert_eq!(
+                    session_id.as_deref(),
+                    Some(capi_session_id.as_str()),
+                    "CAPI inference request must carry the session id"
+                );
+            }
+            assert!(
+                assistant_text(&result).contains("OK from the synthetic"),
+                "expected synthetic content in CAPI reply, got {:?}",
+                assistant_text(&result)
+            );
+
+            // BYOK session.
+            // Remember how many inference records the CAPI turn produced so
+            // that only the records added by the BYOK turn are checked below.
+            let before = handler.inference_records().len();
+            let byok_config = SessionConfig::default()
+                .with_permission_handler(Arc::new(ApproveAllHandler))
+                .with_model("claude-sonnet-4.5")
+                .with_provider(
+                    ProviderConfig::new("https://byok.invalid/v1")
+                        .with_provider_type("openai")
+                        .with_wire_api("responses")
+                        .with_api_key("byok-secret")
+                        .with_model_id("claude-sonnet-4.5")
+                        .with_wire_model("claude-sonnet-4.5"),
+                );
+            let byok_session = client
+                .create_session(byok_config)
+                .await
+                .expect("create BYOK session");
+            let byok_session_id = byok_session.id().as_str().to_string();
+            let result = session_send(&byok_session).await;
+            let _ = byok_session.disconnect().await;
+
+            let inference = handler.inference_records();
+            assert!(
+                inference.len() > before,
+                "expected at least one intercepted BYOK inference request"
+            );
+            for (_, session_id) in &inference[before..] {
+                assert_eq!(
+                    session_id.as_deref(),
+                    Some(byok_session_id.as_str()),
+                    "BYOK inference request must carry the session id"
+                );
+            }
+            assert_ne!(
+                byok_session_id, capi_session_id,
+                "expected per-session ids to differ between turns"
+            );
+            assert!(
+                assistant_text(&result).contains("OK from the synthetic"),
+                "expected synthetic content in BYOK reply, got {:?}",
+                assistant_text(&result)
+            );
+
+            client.stop().await.expect("stop client");
+        })
+    })
+    .await;
+}
+
+// ---------------------------------------------------------------------------
+// Scenario 3a: errors — a handler that returns `Err` on an inference request
+// surfaces a transport error rather than hanging the turn.
+// ---------------------------------------------------------------------------
+
+/// Request handler that fails every inference request and counts how many
+/// times it did so; non-inference requests are answered with stubs.
+#[derive(Default)]
+struct ThrowingHandler {
+    // Number of inference requests that reached the handler.
+    inference_attempts: AtomicU32,
+}
+
+#[async_trait]
+impl CopilotRequestHandler for ThrowingHandler {
+    /// Answer non-inference URLs with a synthetic stub; for inference URLs,
+    /// bump `inference_attempts` and return a synthetic transport error.
+    async fn send_request(
+        &self,
+        request: CopilotHttpRequest,
+        _ctx: &CopilotRequestContext,
+    ) -> Result {
+        if !is_inference_url(&request.url) {
+            return Ok(synth_non_inference_response(&request.url, None));
+        }
+        self.inference_attempts.fetch_add(1, Ordering::SeqCst);
+        Err(CopilotRequestError::message(
+            "synthetic-callback-transport-failure",
+        ))
+    }
+}
+
+/// A turn whose inference request fails in the handler must complete (with or
+/// without an error) instead of hanging, and the handler must have been
+/// reached at least once.
+#[tokio::test]
+async fn surfaces_handler_errors() {
+    with_e2e_context_no_snapshot(|ctx| {
+        Box::pin(async move {
+            ctx.set_default_copilot_user();
+            let handler = Arc::new(ThrowingHandler::default());
+            let client = ctx.start_llm_client(handler.clone(), &[]).await;
+            let session = client
+                .create_session(ctx.approve_all_session_config())
+                .await
+                .expect("create session");
+
+            // The handler returns Err from the inference seam; the agent layer
+            // surfaces it as an error rather than hanging.
+            let send_result = session.send_and_wait(say_ok()).await;
+            let _ = session.disconnect().await;
+
+            assert!(
+                handler.inference_attempts.load(Ordering::SeqCst) > 0,
+                "expected the inference callback to be reached and raise"
+            );
+            // NOTE(review): an `Ok` result is accepted here as well; only the
+            // error message is checked when `send_and_wait` does return `Err`.
+            if let Err(err) = send_result {
+                assert!(
+                    !err.to_string().is_empty(),
+                    "expected a non-empty error string when an error surfaces"
+                );
+            }
+
+            client.stop().await.expect("stop client");
+        })
+    })
+    .await;
+}
+
+// ---------------------------------------------------------------------------
+// Scenario 3b: runtime-driven cancel — the handler blocks an inference request
+// until the consumer aborts the turn; the runtime cancels the in-flight request
+// and the handler observes it via `ctx.cancel`.
+// --------------------------------------------------------------------------- + +#[derive(Default)] +struct CancellingHandler { + inference_entered: AtomicBool, + saw_abort: AtomicBool, +} + +#[async_trait] +impl CopilotRequestHandler for CancellingHandler { + async fn send_request( + &self, + request: CopilotHttpRequest, + ctx: &CopilotRequestContext, + ) -> Result { + if !is_inference_url(&request.url) { + return Ok(synth_non_inference_response(&request.url, None)); + } + + // Inference: never produce a response. Wait for the runtime to cancel + // us, recording the abort, then propagate it as an error. + self.inference_entered.store(true, Ordering::SeqCst); + ctx.cancel.cancelled().await; + self.saw_abort.store(true, Ordering::SeqCst); + Err(CopilotRequestError::message("cancelled by runtime")) + } +} + +#[tokio::test] +async fn observes_runtime_driven_cancel() { + with_e2e_context_no_snapshot(|ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let handler = Arc::new(CancellingHandler::default()); + let client = ctx.start_llm_client(handler.clone(), &[]).await; + let session = client + .create_session(ctx.approve_all_session_config()) + .await + .expect("create session"); + + session.send(say_ok()).await.expect("send"); + wait_for_flag(&handler.inference_entered, "inference entered").await; + session.abort().await.expect("abort"); + wait_for_flag(&handler.saw_abort, "consumer observed cancellation").await; + let _ = session.disconnect().await; + + assert!( + handler.inference_entered.load(Ordering::SeqCst), + "expected the inference callback to be entered" + ); + assert!( + handler.saw_abort.load(Ordering::SeqCst), + "expected the consumer to observe the runtime-driven cancellation" + ); + + client.stop().await.expect("stop client"); + }) + }) + .await; +} diff --git a/rust/tests/e2e/support.rs b/rust/tests/e2e/support.rs index 5554c5a06..1805eb145 100644 --- a/rust/tests/e2e/support.rs +++ b/rust/tests/e2e/support.rs @@ -12,8 +12,8 @@ use 
github_copilot_sdk::handler::ApproveAllHandler; use github_copilot_sdk::session::Session; use github_copilot_sdk::subscription::{EventSubscription, LifecycleSubscription}; use github_copilot_sdk::{ - CliProgram, Client, ClientOptions, SessionConfig, SessionEvent, SessionId, - SessionLifecycleEvent, Transport, + CliProgram, Client, ClientOptions, CopilotRequestHandler, SessionConfig, SessionEvent, + SessionId, SessionLifecycleEvent, Transport, }; use serde_json::json; use tokio::sync::Semaphore; @@ -49,6 +49,35 @@ where ); } +/// Like [`with_e2e_context`] but starts the CapiProxy without loading a +/// recorded snapshot. Used by the LLM inference callback tests, whose +/// registered provider fabricates every model-layer response so no CAPI +/// replay is needed — only the auth/user endpoints are served by the proxy. +pub async fn with_e2e_context_no_snapshot(test: F) +where + F: for<'a> FnOnce(&'a mut E2eContext) -> TestFuture<'a>, +{ + let _permit = E2E_CONCURRENCY + .acquire() + .await + .expect("E2E concurrency semaphore should stay open"); + let mut ctx = E2eContext::new_no_snapshot() + .await + .unwrap_or_else(|err| panic!("create E2E context: {err}")); + + let timed_out = tokio::time::timeout(default_test_timeout(), test(&mut ctx)) + .await + .is_err(); + ctx.cleanup(timed_out) + .await + .unwrap_or_else(|err| panic!("clean up E2E context: {err}")); + assert!( + !timed_out, + "timed out after {:?} running no-snapshot E2E test", + default_test_timeout() + ); +} + pub struct E2eContext { repo_root: PathBuf, cli_path: PathBuf, @@ -78,6 +107,35 @@ impl E2eContext { Ok(ctx) } + async fn new_no_snapshot() -> std::io::Result { + let repo_root = repo_root(); + let cli_path = cli_path(&repo_root)?; + let home_dir = tempfile::tempdir()?; + let work_dir = tempfile::tempdir()?; + let proxy_root = repo_root.clone(); + let proxy = tokio::task::spawn_blocking(move || CapiProxy::start(&proxy_root)) + .await + .map_err(|err| std::io::Error::other(format!("proxy startup task 
failed: {err}")))??; + let ctx = Self { + repo_root, + cli_path, + home_dir, + work_dir, + proxy: Some(proxy), + }; + // Initialize proxy state without replaying any recorded exchanges: the + // snapshot path intentionally does not exist, so `/copilot_internal/user` + // and the default `/models` catalog are served while all model-layer + // traffic is fabricated by the registered inference callback. + let dummy_snapshot = ctx.work_dir.path().join("__no_snapshot__.yaml"); + ctx.proxy() + .configure(&dummy_snapshot, ctx.work_dir.path()) + .map_err(|err| { + std::io::Error::other(format!("configure proxy without snapshot failed: {err}")) + })?; + Ok(ctx) + } + pub fn repo_root(&self) -> &Path { &self.repo_root } @@ -117,6 +175,29 @@ impl E2eContext { .expect("start E2E client") } + /// Start a client wired to a Copilot request handler, appending `extra_env` + /// to the spawned runtime's environment (used to flip the WebSocket ExP + /// flag for the WS transport tests). + pub async fn start_llm_client(&self, handler: H, extra_env: &[(&str, &str)]) -> Client + where + H: CopilotRequestHandler, + { + let mut env = self.environment(); + env.extend( + extra_env + .iter() + .map(|(key, value)| (OsString::from(*key), OsString::from(*value))), + ); + let options = ClientOptions::new() + .with_program(CliProgram::Path(PathBuf::from(node_program()))) + .with_prefix_args([self.cli_path.as_os_str().to_owned()]) + .with_cwd(self.work_dir.path()) + .with_env(env) + .with_use_logged_in_user(false) + .with_request_handler(handler); + Client::start(options).await.expect("start E2E LLM client") + } + #[expect(dead_code, reason = "used by follow-on E2E ports")] pub async fn start_tcp_client(&self, port: u16, token: &str) -> Client { Client::start(self.client_options_with_transport(Transport::Tcp { diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index 7b0f64a77..e1ceea5b1 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -2297,6 +2297,142 @@ 
function emitClientSessionApiRegistration(clientSchema: Record, return lines; } +/** + * Emit C# handler interfaces + a process-wide registration for client + * *global* API groups. + * + * Unlike client-session APIs, these methods carry no implicit `sessionId` + * dispatch key. The SDK consumer registers a single process-wide handler set + * via `RegisterClientGlobalApiHandlers`; the runtime dispatcher routes each + * incoming call to the registered handler regardless of which (if any) + * runtime session triggered it. + */ +function emitClientGlobalApiRegistration(clientSchema: Record, classes: string[]): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const { methods } of groups) { + for (const method of methods) { + const resultSchema = getMethodResultSchema(method); + if (!isVoidSchema(resultSchema) && !isOpaqueJson(resultSchema)) { + emitRpcResultType(resultTypeName(method), resultSchema!, "public", classes); + } + + const effectiveParams = resolveMethodParamsSchema(method); + if (effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0) { + const paramsClass = emitRpcClass(paramsTypeName(method), effectiveParams, "public", classes); + if (paramsClass) classes.push(paramsClass); + } + } + } + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + const groupDeprecated = isNodeFullyDeprecated(groupNode); + lines.push(`/// Handles \`${groupName}\` client global API methods.`); + if (groupExperimental) { + pushExperimentalAttribute(lines); + } + if (groupDeprecated) { + pushObsoleteAttributes(lines); + } + lines.push(`public interface ${interfaceName}`); + lines.push(`{`); + for (const method of methods) { + const effectiveParams = resolveMethodParamsSchema(method); + const hasParams = !!effectiveParams?.properties && 
Object.keys(effectiveParams.properties).length > 0; + const resultSchema = getMethodResultSchema(method); + const taskType = resultTaskType(method); + pushRpcMethodXmlDocs( + lines, + method, + " ", + [ + ...(hasParams ? [{ name: "request", description: rpcParamsDescription(method, effectiveParams) }] : []), + { name: "cancellationToken", description: CANCELLATION_TOKEN_DESCRIPTION, escapeDescription: false }, + ], + resultSchema, + `Handles "${method.rpcMethod}".` + ); + if (method.stability === "experimental" && !groupExperimental) { + pushExperimentalAttribute(lines, " "); + } + if (method.deprecated && !groupDeprecated) { + pushObsoleteAttributes(lines, " "); + } + if (hasParams) { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(${paramsTypeName(method)} request, CancellationToken cancellationToken = default);`); + } else { + lines.push(` ${taskType} ${clientHandlerMethodName(method.rpcMethod)}(CancellationToken cancellationToken = default);`); + } + } + lines.push(`}`); + lines.push(""); + } + + lines.push(`/// Provides all client global API handler groups for a connection.`); + lines.push(`public sealed class ClientGlobalApiHandlers`); + lines.push(`{`); + for (const { groupName } of groups) { + lines.push(` /// Optional handler for ${toPascalCase(groupName)} client global API methods.`); + lines.push(` public ${clientHandlerInterfaceName(groupName)}? 
${toPascalCase(groupName)} { get; set; }`); + lines.push(""); + } + if (lines[lines.length - 1] === "") lines.pop(); + lines.push(`}`); + lines.push(""); + + lines.push(`/// Registers client global API handlers on a JSON-RPC connection.`); + lines.push(`internal static class ClientGlobalApiRegistration`); + lines.push(`{`); + lines.push(` /// `); + lines.push(` /// Registers handlers for server-to-client global API calls.`); + lines.push(` /// Unlike client session APIs, these methods carry no implicit`); + lines.push(` /// sessionId dispatch key — a single set of handlers serves the`); + lines.push(` /// entire connection.`); + lines.push(` /// `); + lines.push(` public static void RegisterClientGlobalApiHandlers(JsonRpc rpc, ClientGlobalApiHandlers handlers)`); + lines.push(` {`); + for (const { groupName, methods } of groups) { + for (const method of methods) { + const handlerProperty = toPascalCase(groupName); + const handlerMethod = clientHandlerMethodName(method.rpcMethod); + const effectiveParams = resolveMethodParamsSchema(method); + const hasParams = !!effectiveParams?.properties && Object.keys(effectiveParams.properties).length > 0; + const resultSchema = getMethodResultSchema(method); + const paramsClass = paramsTypeName(method); + const taskType = handlerTaskType(method); + + if (hasParams) { + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func<${paramsClass}, CancellationToken, ${taskType}>)(async (request, cancellationToken) =>`); + lines.push(` {`); + lines.push(` var handler = handlers.${handlerProperty} ?? 
throw new InvalidOperationException("No ${groupName} client-global handler registered");`); + if (!isVoidSchema(resultSchema)) { + lines.push(` return await handler.${handlerMethod}(request, cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(request, cancellationToken);`); + } + lines.push(` }), singleObjectParam: true);`); + } else { + lines.push(` rpc.SetLocalRpcMethod("${method.rpcMethod}", (Func)(async cancellationToken =>`); + lines.push(` {`); + lines.push(` var handler = handlers.${handlerProperty} ?? throw new InvalidOperationException("No ${groupName} client-global handler registered");`); + if (!isVoidSchema(resultSchema)) { + lines.push(` return await handler.${handlerMethod}(cancellationToken);`); + } else { + lines.push(` await handler.${handlerMethod}(cancellationToken);`); + } + lines.push(` }));`); + } + } + } + lines.push(` }`); + lines.push(`}`); + + return lines; +} + function generateRpcCode( schema: ApiSchema, externalJsonSerializableRefs: Map> = new Map(), @@ -2315,6 +2451,7 @@ function generateRpcCode( ...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {}), ...collectRpcMethods(schema.clientSession || {}), + ...collectRpcMethods(schema.clientGlobal || {}), ]; for (const name of collectRpcMethodReferencedDefinitionNames( allMethods.filter((method) => method.stability !== "experimental"), @@ -2343,6 +2480,9 @@ function generateRpcCode( let clientSessionParts: string[] = []; if (schema.clientSession) clientSessionParts = emitClientSessionApiRegistration(schema.clientSession, classes); + let clientGlobalParts: string[] = []; + if (schema.clientGlobal) clientGlobalParts = emitClientGlobalApiRegistration(schema.clientGlobal, classes); + const lines: string[] = []; lines.push(`${COPYRIGHT} @@ -2368,6 +2508,7 @@ namespace GitHub.Copilot.Rpc; for (const part of serverRpcParts) lines.push(part, ""); for (const part of sessionRpcParts) lines.push(part, ""); if (clientSessionParts.length > 0) 
lines.push(...clientSessionParts, ""); + if (clientGlobalParts.length > 0) lines.push(...clientGlobalParts, ""); // Add JsonSerializerContext for AOT/trimming support const typeNames = [...emittedRpcClassSchemas.keys(), ...emittedRpcEnumResultTypes].sort(); diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 9c74977fb..5403fb444 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -3952,7 +3952,7 @@ async function generateRpc(schemaPath?: string): Promise { if (generatedTypeCode.includes("time.Time")) { imports.push(`"time"`); } - if (schema.clientSession) { + if (schema.clientSession || schema.clientGlobal) { imports.push(`"errors"`, `"fmt"`); } imports.push(`"github.com/github/copilot-sdk/go/internal/jsonrpc2"`); @@ -3987,6 +3987,10 @@ async function generateRpc(schemaPath?: string): Promise { emitClientSessionApiRegistration(lines, schema.clientSession, resolveType, generatedRpcCode.discriminatedUnions); } + if (schema.clientGlobal) { + emitClientGlobalApiRegistration(lines, schema.clientGlobal, resolveType, generatedRpcCode.discriminatedUnions); + } + const outPath = await writeGeneratedFile("go/rpc/zrpc.go", wrapGeneratedGoComments(lines.join("\n"))); console.log(` ✓ ${outPath}`); @@ -4348,7 +4352,106 @@ function emitClientSessionApiRegistration(lines: string[], clientSchema: Record< lines.push(``); } -// ── Main ──────────────────────────────────────────────────────────────────── +/** Appends the Go client-global API surface to lines: one handler interface per group, the ClientGlobalAPIHandlers struct, the clientGlobalHandlerError helper, and RegisterClientGlobalAPIHandlers. */ function emitClientGlobalApiRegistration(lines: string[], clientSchema: Record, resolveType: (name: string) => string, unionInfos: Map): void { + const groups = collectClientGroups(clientSchema); + + for (const { groupName, groupNode, methods } of groups) { + const interfaceName = clientHandlerInterfaceName(groupName); + const groupExperimental = isNodeFullyExperimental(groupNode); + const groupDeprecated = isNodeFullyDeprecated(groupNode); + if (groupDeprecated) { + pushGoComment(lines, `Deprecated: ${interfaceName} contains deprecated APIs that will be 
removed in a future version.`); + } + if (groupExperimental) { + pushGoExperimentalApiComment(lines, interfaceName); + } + lines.push(`type ${interfaceName} interface {`); + for (const method of methods) { + const resultSchema = getMethodResultSchema(method); + pushGoRpcMethodComment( + lines, + clientHandlerMethodName(method.rpcMethod), + method, + resultSchema, + goRpcParamsDescription(method, getMethodParamsSchema(method)), + "\t", + "handles" + ); + if (method.deprecated && !groupDeprecated) { /* per-method notice is skipped when the whole group is already marked deprecated */ + pushGoComment(lines, `Deprecated: ${clientHandlerMethodName(method.rpcMethod)} is deprecated and will be removed in a future version.`, "\t"); + } + if (method.stability === "experimental" && !groupExperimental) { + pushGoExperimentalMethodComment(lines, clientHandlerMethodName(method.rpcMethod), "\t"); + } + const paramsType = resolveType(goParamsTypeName(method)); + const nullableInner = resultSchema ? getNullableInner(resultSchema) : undefined; + let returnType: string; /* opaque JSON -> any; result types listed in unionInfos are used as-is; every other result type becomes a pointer */ + if (isOpaqueJson(resultSchema)) { + returnType = "any"; + } else { + const resultType = nullableInner + ? resolveType(goNullableResultTypeName(method, nullableInner)) + : resolveType(goResultTypeName(method)); + returnType = unionInfos.has(resultType) ? 
resultType : `*${resultType}`; + } + lines.push(`\t${clientHandlerMethodName(method.rpcMethod)}(request *${paramsType}) (${returnType}, error)`); + } + lines.push(`}`); + lines.push(``); + } + + lines.push(`// ClientGlobalAPIHandlers provides all client-global API handler groups.`); + lines.push(`//`); + lines.push(`// Unlike client-session handlers these carry no implicit session id dispatch`); + lines.push(`// key; a single set of handlers serves the entire connection.`); + lines.push(`type ClientGlobalAPIHandlers struct {`); + for (const { groupName } of groups) { + lines.push(`\t${toGoFieldName(groupName)} ${clientHandlerInterfaceName(groupName)}`); + } + lines.push(`}`); + lines.push(``); + /* Generated helper: returns a *jsonrpc2.Error found via errors.As unchanged and wraps any other non-nil error with code -32603. */ + lines.push(`func clientGlobalHandlerError(err error) *jsonrpc2.Error {`); + lines.push(`\tif err == nil {`); + lines.push(`\t\treturn nil`); + lines.push(`\t}`); + lines.push(`\tvar rpcErr *jsonrpc2.Error`); + lines.push(`\tif errors.As(err, &rpcErr) {`); + lines.push(`\t\treturn rpcErr`); + lines.push(`\t}`); + lines.push(`\treturn &jsonrpc2.Error{Code: -32603, Message: err.Error()}`); + lines.push(`}`); + lines.push(``); + + lines.push(`// RegisterClientGlobalAPIHandlers registers handlers for server-to-client client-global API calls.`); + lines.push(`func RegisterClientGlobalAPIHandlers(client *jsonrpc2.Client, handlers *ClientGlobalAPIHandlers) {`); + for (const { groupName, methods } of groups) { + const handlerField = toGoFieldName(groupName); + for (const method of methods) { /* each generated handler: unmarshal params (-32602 on failure), require a non-nil group handler (-32603), call it, then marshal the result (-32603 on failure) */ + const paramsType = resolveType(goParamsTypeName(method)); + lines.push(`\tclient.SetRequestHandler("${method.rpcMethod}", func(params json.RawMessage) (json.RawMessage, *jsonrpc2.Error) {`); + lines.push(`\t\tvar request ${paramsType}`); + lines.push(`\t\tif err := json.Unmarshal(params, &request); err != nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32602, Message: fmt.Sprintf("Invalid params: %v", err)}`); + lines.push(`\t\t}`); + lines.push(`\t\tif handlers == nil || 
handlers.${handlerField} == nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32603, Message: "No ${groupName} client-global handler registered"}`); + lines.push(`\t\t}`); + lines.push(`\t\tresult, err := handlers.${handlerField}.${clientHandlerMethodName(method.rpcMethod)}(&request)`); + lines.push(`\t\tif err != nil {`); + lines.push(`\t\t\treturn nil, clientGlobalHandlerError(err)`); + lines.push(`\t\t}`); + lines.push(`\t\traw, err := json.Marshal(result)`); + lines.push(`\t\tif err != nil {`); + lines.push(`\t\t\treturn nil, &jsonrpc2.Error{Code: -32603, Message: fmt.Sprintf("Failed to marshal response: %v", err)}`); + lines.push(`\t\t}`); + lines.push(`\t\treturn raw, nil`); + lines.push(`\t})`); + } + } + lines.push(`}`); + lines.push(``); +} async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { let apiSchemaForSharing: ApiSchema | undefined; diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index 783ac8244..e0c4e2141 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -3190,6 +3190,9 @@ def _patch_model_capabilities(data: dict) -> dict: if (schema.clientSession) { emitClientSessionApiRegistration(lines, schema.clientSession, resolveType); } + if (schema.clientGlobal) { + emitClientGlobalApiRegistration(lines, schema.clientGlobal, resolveType); + } // Patch models.list to normalize capabilities before deserialization let finalCode = lines.join("\n"); @@ -3712,7 +3715,107 @@ function emitClientSessionRegistrationMethod( lines.push(` client.set_request_handler("${method.rpcMethod}", ${handlerVariableName})`); } -// ── Main ──────────────────────────────────────────────────────────────────── +/** Appends the Python client-global API surface to lines: a Protocol class per group, the ClientGlobalApiHandlers dataclass, and register_client_global_api_handlers. */ function emitClientGlobalApiRegistration( + lines: string[], + node: Record, + resolveType: (name: string) => string +): void { + const groups = Object.entries(node).filter(([, value]) => typeof value === "object" && value !== null && !isRpcMethod(value)); /* a group is any non-null object entry that is not itself an RPC method */ + + for (const [groupName, 
groupNode] of groups) { + const handlerName = `${toPascalCase(groupName)}Handler`; + const groupExperimental = isNodeFullyExperimental(groupNode as Record); + const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); + if (groupDeprecated) { + lines.push(`# Deprecated: this API group is deprecated and will be removed in a future version.`); + } + if (groupExperimental) { + pushPyExperimentalApiGroupComment(lines); + } + lines.push(`class ${handlerName}(Protocol):`); /* NOTE(review): if a group has no RPC methods the loop below emits nothing and the class header may be left without a body - confirm groups always contain at least one method */ + const methods = collectRpcMethods(groupNode as Record); + for (const method of methods) { + // Client-global handler methods reuse the session handler shape; the + // only difference is dispatch (no implicit session_id key). + emitClientSessionHandlerMethod(lines, method, resolveType, groupExperimental, groupDeprecated); + } + lines.push(``); + } + /* One field per group, typed <Group>Handler | None and defaulting to None; pass keeps the class body valid when there are no groups. */ + lines.push(`@dataclass`); + lines.push(`class ClientGlobalApiHandlers:`); + if (groups.length === 0) { + lines.push(` pass`); + } else { + for (const [groupName] of groups) { + lines.push(` ${toSnakeCase(groupName)}: ${toPascalCase(groupName)}Handler | None = None`); + } + } + lines.push(``); + + lines.push(`def register_client_global_api_handlers(`); + lines.push(` client: "JsonRpcClient",`); + lines.push(` handlers: ClientGlobalApiHandlers,`); + lines.push(`) -> None:`); + lines.push(` """Register client-global request handlers on a JSON-RPC connection.`); + lines.push(``); + lines.push(` Unlike client-session handlers these methods carry no implicit`); + lines.push(` session_id dispatch key; a single set of handlers serves the entire`); + lines.push(` connection.`); + lines.push(` """`); + if (groups.length === 0) { + lines.push(` return`); + } else { + for (const [groupName, groupNode] of groups) { + const methods = collectRpcMethods(groupNode as Record); + for (const method of methods) { + emitClientGlobalRegistrationMethod(lines, groupName, method, resolveType); + } + } + } + lines.push(``); +} +/** Emits one async Python request handler for method and registers it on client under method.rpcMethod; the generated handler raises RuntimeError when the group's handler is None. */ +function 
emitClientGlobalRegistrationMethod( + lines: string[], + groupName: string, + method: RpcMethod, + resolveType: (name: string) => string +): void { + const rpcSegments = method.rpcMethod.split("."); + const handlerVariableName = `handle_${rpcSegments.map(toSnakeCase).join("_")}`; + const paramsType = resolveType(pythonParamsTypeName(method)); + const resultSchema = getMethodResultSchema(method); + const nullableInner = resultSchema ? getNullableInner(resultSchema) : undefined; + const hasResult = !isVoidSchema(resultSchema) && !nullableInner; /* true only for non-void, non-nullable results; nullable results take the else-if branch below */ + const handlerField = toSnakeCase(groupName); + const handlerMethod = clientSessionHandlerMethodName(method.rpcMethod); + + lines.push(` async def ${handlerVariableName}(params: dict) -> dict | None:`); + lines.push(` request = ${paramsType}.from_dict(params)`); + lines.push(` handler = handlers.${handlerField}`); + lines.push(` if handler is None: raise RuntimeError("No ${handlerField} client-global handler registered")`); + if (hasResult) { + lines.push(` result = await handler.${handlerMethod}(request)`); + if (isObjectSchema(resultSchema)) { + lines.push(` return result.to_dict()`); + } else { /* non-object result: return result.value when that attribute exists (presumably enum members - TODO confirm), otherwise the result itself */ + lines.push(` return result.value if hasattr(result, 'value') else result`); + } + } else if (nullableInner) { + lines.push(` result = await handler.${handlerMethod}(request)`); + const resolvedInner = resolveSchema(nullableInner, rpcDefinitions) ?? 
nullableInner; + if (isObjectSchema(resolvedInner) || nullableInner.$ref) { + lines.push(` return result.to_dict() if result is not None else None`); + } else { + lines.push(` return result`); + } + } else { /* void result: await the handler and answer with None */ + lines.push(` await handler.${handlerMethod}(request)`); + lines.push(` return None`); + } + lines.push(` client.set_request_handler("${method.rpcMethod}", ${handlerVariableName})`); +} async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { await generateSessionEvents(sessionSchemaPath); diff --git a/scripts/codegen/typescript.ts b/scripts/codegen/typescript.ts index bba360b47..1303a4979 100644 --- a/scripts/codegen/typescript.ts +++ b/scripts/codegen/typescript.ts @@ -516,7 +516,8 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; const allMethods = [...collectRpcMethods(schema.server || {}), ...collectRpcMethods(schema.session || {})]; const clientSessionMethods = collectRpcMethods(schema.clientSession || {}); - const rpcMethods = [...allMethods, ...clientSessionMethods]; + const clientGlobalMethods = collectRpcMethods(schema.clientGlobal || {}); + const rpcMethods = [...allMethods, ...clientSessionMethods, ...clientGlobalMethods]; const seenBlocks = new Map(); // Build a single combined schema with shared definitions and all method types. @@ -717,6 +718,13 @@ function hasInternalMethods(node: Record): boolean { lines.push(...emitClientSessionApiRegistration(schema.clientSession)); } + // Generate client *global* API handler interfaces and registration function. + // Unlike client-session APIs, these methods do not carry a `sessionId` dispatch + // key — the SDK consumer registers a single process-wide handler per group. 
+ if (schema.clientGlobal) { + lines.push(...emitClientGlobalApiRegistration(schema.clientGlobal)); + } + const outPath = await writeGeneratedFile("nodejs/src/generated/rpc.ts", lines.join("\n")); console.log(` ✓ ${outPath}`); } @@ -926,6 +934,105 @@ function emitClientSessionApiRegistration(clientSchema: Record) return lines; } +/** + * Generate handler interfaces and a registration function for client *global* + * API groups. + * + * Unlike client-session APIs, these methods carry no implicit `sessionId` + * dispatch key. The SDK consumer registers a single process-wide handler set + * via `registerClientGlobalApiHandlers`; the runtime dispatcher routes each + * incoming call to the registered handler regardless of which (if any) + * runtime session triggered it. Returns the generated TypeScript source lines. + */ +function emitClientGlobalApiRegistration(clientSchema: Record): string[] { + const lines: string[] = []; + const groups = collectClientGroups(clientSchema); + + for (const [groupName, methods] of groups) { + const interfaceName = toPascalCase(groupName) + "Handler"; + const groupDeprecated = isNodeFullyDeprecated(clientSchema[groupName] as Record); + const groupExperimental = isNodeFullyExperimental(clientSchema[groupName] as Record); + if (groupDeprecated) { + lines.push(`/** @deprecated Handler for \`${groupName}\` client global API methods. */`); + } else if (groupExperimental) { /* NOTE(review): this branch is skipped when the group is also fully deprecated, so such a group gets no experimental tag - confirm intended */ + lines.push(`/** Handler for \`${groupName}\` client global API methods. */`); + lines.push(TS_EXPERIMENTAL_JSDOC); + } else { + lines.push(`/** Handler for \`${groupName}\` client global API methods. */`); + } + lines.push(`export interface ${interfaceName} {`); + for (const method of methods) { + const name = handlerMethodName(method.rpcMethod); + const hasParams = hasSchemaPayload(getMethodParamsSchema(method)); + const pType = hasParams ? 
paramsTypeName(method) : ""; + const rType = tsResultType(method); + + pushTsRpcMethodJsDoc(lines, " ", method, { + summaryFallback: `Handles \`${method.rpcMethod}\`.`, + paramsName: hasParams ? "params" : undefined, + paramsDescription: rpcParamsDescription(method, getMethodParamsSchema(method)), + includeDeprecated: method.deprecated && !groupDeprecated, + includeExperimental: method.stability === "experimental" && !groupExperimental, + }); + if (hasParams) { + lines.push(` ${name}(params: ${pType}): Promise<${rType}>;`); + } else { + lines.push(` ${name}(): Promise<${rType}>;`); + } + } + lines.push(`}`); + lines.push(""); + } + /* Aggregate interface: one optional property per group, named with the raw group name. */ + lines.push(`/** All client global API handler groups. */`); + lines.push(`export interface ClientGlobalApiHandlers {`); + for (const [groupName] of groups) { + const interfaceName = toPascalCase(groupName) + "Handler"; + lines.push(` ${groupName}?: ${interfaceName};`); + } + lines.push(`}`); + lines.push(""); + + lines.push(`/**`); + lines.push(` * Register client global API handlers on a JSON-RPC connection.`); + lines.push(` * The server calls these methods to delegate work to the client.`); + lines.push(` * Unlike session-scoped client APIs, these methods carry no implicit`); + lines.push(` * \`sessionId\` dispatch key — a single set of handlers serves the entire`); + lines.push(` * connection.`); + lines.push(` */`); + lines.push(`export function registerClientGlobalApiHandlers(`); + lines.push(` connection: MessageConnection,`); + lines.push(` handlers: ClientGlobalApiHandlers,`); + lines.push(`): void {`); + /* One onRequest registration per method; each generated callback reads handlers.<group> at call time and throws when it is unset. */ + for (const [groupName, methods] of groups) { + for (const method of methods) { + const name = handlerMethodName(method.rpcMethod); + const pType = paramsTypeName(method); + const hasParams = hasSchemaPayload(getMethodParamsSchema(method)); + + if (hasParams) { + lines.push(` connection.onRequest("${method.rpcMethod}", async (params: ${pType}) => {`); + lines.push(` const handler = handlers.${groupName};`); + lines.push(` if 
(!handler) throw new Error("No ${groupName} client-global handler registered");`); + lines.push(` return handler.${name}(params);`); + lines.push(` });`); + } else { /* method has no params payload: register a zero-argument callback */ + lines.push(` connection.onRequest("${method.rpcMethod}", async () => {`); + lines.push(` const handler = handlers.${groupName};`); + lines.push(` if (!handler) throw new Error("No ${groupName} client-global handler registered");`); + lines.push(` return handler.${name}();`); + lines.push(` });`); + } + } + } + + lines.push(`}`); + lines.push(""); + + return lines; +} + // ── Main ──────────────────────────────────────────────────────────────────── async function generate(sessionSchemaPath?: string, apiSchemaPath?: string): Promise { diff --git a/scripts/codegen/utils.ts b/scripts/codegen/utils.ts index b9bfb9730..c63f9732c 100644 --- a/scripts/codegen/utils.ts +++ b/scripts/codegen/utils.ts @@ -506,6 +506,7 @@ export interface ApiSchema { server?: Record; session?: Record; clientSession?: Record; + clientGlobal?: Record; } export function isRpcMethod(node: unknown): node is RpcMethod { @@ -555,6 +556,7 @@ export function fixNullableRequiredRefsInApiSchema(schema: ApiSchema): ApiSchema server: walkApiNode(schema.server), session: walkApiNode(schema.session), clientSession: walkApiNode(schema.clientSession), + clientGlobal: walkApiNode(schema.clientGlobal), }; }