diff --git a/dotnet/src/Agents/Orchestration/AgentActor.cs b/dotnet/src/Agents/Orchestration/AgentActor.cs index 962dc4327241..06a4d950e2a0 100644 --- a/dotnet/src/Agents/Orchestration/AgentActor.cs +++ b/dotnet/src/Agents/Orchestration/AgentActor.cs @@ -92,23 +92,31 @@ protected ValueTask InvokeAsync(ChatMessageContent input, Ca /// A task that returns the response . protected async ValueTask InvokeAsync(IList input, CancellationToken cancellationToken) { - this.Context.Cancellation.ThrowIfCancellationRequested(); + try + { + this.Context.Cancellation.ThrowIfCancellationRequested(); - this._lastResponse = null; + this._lastResponse = null; - AgentInvokeOptions options = this.GetInvokeOptions(HandleMessageAsync); - if (this.Context.StreamingResponseCallback == null) - { - // No need to utilize streaming if no callback is provided - await this.InvokeAsync(input, options, cancellationToken).ConfigureAwait(false); + AgentInvokeOptions options = this.GetInvokeOptions(HandleMessageAsync); + if (this.Context.StreamingResponseCallback == null) + { + // No need to utilize streaming if no callback is provided + await this.InvokeAsync(input, options, cancellationToken).ConfigureAwait(false); + } + else + { + await this.InvokeStreamingAsync(input, options, cancellationToken).ConfigureAwait(false); + } + + return this._lastResponse ?? new ChatMessageContent(AuthorRole.Assistant, string.Empty); } - else + catch (Exception exception) { - await this.InvokeStreamingAsync(input, options, cancellationToken).ConfigureAwait(false); + this.Context.FailureCallback.Invoke(exception); + throw; } - return this._lastResponse ?? new ChatMessageContent(AuthorRole.Assistant, string.Empty); - async Task HandleMessageAsync(ChatMessageContent message) { this._lastResponse = message; // Keep track of most recent response for both invocation modes diff --git a/dotnet/src/Agents/Orchestration/AgentOrchestration.cs b/dotnet/src/Agents/Orchestration/AgentOrchestration.cs index 63267fd3ec5d..70eb4d7ef0a7 100644 --- a/dotnet/src/Agents/Orchestration/AgentOrchestration.cs +++ b/dotnet/src/Agents/Orchestration/AgentOrchestration.cs @@ -114,18 +114,19 @@ public async ValueTask> InvokeAsync( CancellationTokenSource orchestrationCancelSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + TaskCompletionSource completion = new(); + OrchestrationContext context = new(this.OrchestrationLabel, topic, this.ResponseCallback, this.StreamingResponseCallback, + exception => completion.SetException(exception), this.LoggerFactory, cancellationToken); ILogger logger = this.LoggerFactory.CreateLogger(this.GetType()); - TaskCompletionSource completion = new(); - AgentType orchestrationType = await this.RegisterAsync(runtime, context, completion, handoff: null).ConfigureAwait(false); cancellationToken.ThrowIfCancellationRequested(); diff --git a/dotnet/src/Agents/Orchestration/OrchestrationContext.cs b/dotnet/src/Agents/Orchestration/OrchestrationContext.cs index 7b798f9f2659..1946fbba0273 100644 --- a/dotnet/src/Agents/Orchestration/OrchestrationContext.cs +++ b/dotnet/src/Agents/Orchestration/OrchestrationContext.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Threading; using Microsoft.Extensions.Logging; using Microsoft.SemanticKernel.Agents.Runtime; @@ -16,11 +17,13 @@ internal OrchestrationContext( TopicId topic, OrchestrationResponseCallback? responseCallback, OrchestrationStreamingCallback? streamingCallback, + Action failureCallback, ILoggerFactory loggerFactory, CancellationToken cancellation) { this.Orchestration = orchestration; this.Topic = topic; + this.FailureCallback = failureCallback; this.ResponseCallback = responseCallback; this.StreamingResponseCallback = streamingCallback; this.LoggerFactory = loggerFactory; @@ -59,4 +62,9 @@ internal OrchestrationContext( /// Optional callback that is invoked for every agent response. /// public OrchestrationStreamingCallback? StreamingResponseCallback { get; } + + /// + /// Gets the callback that is invoked when an operation fails due to an exception. + /// + public Action FailureCallback { get; } } diff --git a/dotnet/src/Agents/UnitTests/Orchestration/OrchestrationResultTests.cs b/dotnet/src/Agents/UnitTests/Orchestration/OrchestrationResultTests.cs index d65a2c5e8b8c..908304814071 100644 --- a/dotnet/src/Agents/UnitTests/Orchestration/OrchestrationResultTests.cs +++ b/dotnet/src/Agents/UnitTests/Orchestration/OrchestrationResultTests.cs @@ -16,7 +16,8 @@ public class OrchestrationResultTests public void Constructor_InitializesPropertiesCorrectly() { // Arrange - OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None); + Exception? captureException = null; + OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None); TaskCompletionSource tcs = new(); // Act @@ -24,6 +25,7 @@ public void Constructor_InitializesPropertiesCorrectly() using OrchestrationResult result = new(context, tcs, cancelSource, NullLogger.Instance); // Assert + Assert.Null(captureException); Assert.Equal("TestOrchestration", result.Orchestration); Assert.Equal(new TopicId("testTopic"), result.Topic); } @@ -32,7 +34,8 @@ public void Constructor_InitializesPropertiesCorrectly() public async Task GetValueAsync_ReturnsCompletedValue_WhenTaskIsCompletedAsync() { // Arrange - OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None); + Exception? captureException = null; + OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None); TaskCompletionSource tcs = new(); using CancellationTokenSource cancelSource = new(); using OrchestrationResult result = new(context, tcs, cancelSource, NullLogger.Instance); @@ -43,6 +46,7 @@ public async Task GetValueAsync_ReturnsCompletedValue_WhenTaskIsCompletedAsync() string actualValue = await result.GetValueAsync(); // Assert + Assert.Null(captureException); Assert.Equal(expectedValue, actualValue); } @@ -50,7 +54,8 @@ public async Task GetValueAsync_ReturnsCompletedValue_WhenTaskIsCompletedAsync() public async Task GetValueAsync_WithTimeout_ReturnsCompletedValue_WhenTaskCompletesWithinTimeoutAsync() { // Arrange - OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None); + Exception? captureException = null; + OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None); TaskCompletionSource tcs = new(); using CancellationTokenSource cancelSource = new(); using OrchestrationResult result = new(context, tcs, cancelSource, NullLogger.Instance); @@ -62,6 +67,7 @@ public async Task GetValueAsync_WithTimeout_ReturnsCompletedValue_WhenTaskComple string actualValue = await result.GetValueAsync(timeout); // Assert + Assert.Null(captureException); Assert.Equal(expectedValue, actualValue); } @@ -69,7 +75,8 @@ public async Task GetValueAsync_WithTimeout_ReturnsCompletedValue_WhenTaskComple public async Task GetValueAsync_WithTimeout_ThrowsTimeoutException_WhenTaskDoesNotCompleteWithinTimeoutAsync() { // Arrange - OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None); + Exception? captureException = null; + OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None); TaskCompletionSource tcs = new(); using CancellationTokenSource cancelSource = new(); using OrchestrationResult result = new(context, tcs, cancelSource, NullLogger.Instance); @@ -77,13 +84,15 @@ public async Task GetValueAsync_WithTimeout_ThrowsTimeoutException_WhenTaskDoesN // Act & Assert TimeoutException exception = await Assert.ThrowsAsync(() => result.GetValueAsync(timeout).AsTask()); + Assert.Null(captureException); } [Fact] public async Task GetValueAsync_ReturnsCompletedValue_WhenCompletionIsDelayedAsync() { // Arrange - OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, NullLoggerFactory.Instance, CancellationToken.None); + Exception? captureException = null; + OrchestrationContext context = new("TestOrchestration", new TopicId("testTopic"), null, null, exception => captureException = exception, NullLoggerFactory.Instance, CancellationToken.None); TaskCompletionSource tcs = new(); using CancellationTokenSource cancelSource = new(); using OrchestrationResult result = new(context, tcs, cancelSource, NullLogger.Instance); @@ -100,6 +109,7 @@ public async Task GetValueAsync_ReturnsCompletedValue_WhenCompletionIsDelayedAsy int actualValue = await result.GetValueAsync(); // Assert + Assert.Null(captureException); Assert.Equal(expectedValue, actualValue); } }