Skip to content

Commit 76fe731

Browse files
authored
.NET: feat: Refactor Handoff Orchestration and add HITL support (#5174)
* feat: Refactor Handoff Orchestration and add HITL support * Change HandoffAgentExecutor to use factory-based instantiation * Extract shared request collection logic in AIAgentUnservicedRequestsCollector * Refactor HandoffAgentExecutor to use the "ContinueTurn" pattern as in AIAgentHostExecutor * fix: Remove '$' from exception strings
1 parent 39b560f commit 76fe731

10 files changed

Lines changed: 749 additions & 192 deletions

File tree

dotnet/src/Microsoft.Agents.AI.Workflows/HandoffWorkflowBuilder.cs

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
using Microsoft.Extensions.AI;
99
using Microsoft.Shared.Diagnostics;
1010

11+
using ExecutorFactoryFunc = System.Func<Microsoft.Agents.AI.Workflows.ExecutorConfig<Microsoft.Agents.AI.Workflows.ExecutorOptions>,
12+
string,
13+
System.Threading.Tasks.ValueTask<Microsoft.Agents.AI.Workflows.Specialized.HandoffAgentExecutor>>;
14+
1115
namespace Microsoft.Agents.AI.Workflows;
1216

1317
internal static class DiagnosticConstants
@@ -233,24 +237,70 @@ public TBuilder WithHandoff(AIAgent from, AIAgent to, string? handoffReason = nu
233237
return (TBuilder)this;
234238
}
235239

240+
private Dictionary<string, ExecutorBinding> CreateExecutorBindings(WorkflowBuilder builder)
241+
{
242+
HandoffAgentExecutorOptions options = new(this.HandoffInstructions,
243+
this._emitAgentResponseEvents,
244+
this._emitAgentResponseUpdateEvents,
245+
this._toolCallFilteringBehavior);
246+
247+
// There are two types of ids being used in this method, and it is critical that we are clear about
248+
// which one we are using, and where.
249+
// AgentId...: comes from AIAgent.Id, is often an unreadable machine identifier (e.g. a Guid), and is used to address
250+
// the handoffs
251+
// ExecutorId: uses AIAgent.GetDescriptiveId() to use a friendlier name in telemetry, and is used for ExecutorBinding,
252+
// which are subsequently used in building the workflow
253+
254+
// The outgoing dictionary maps from AgentId => ExecutorBinding
255+
return this._allAgents.ToDictionary(keySelector: a => a.Id, elementSelector: CreateFactoryBinding);
256+
257+
ExecutorBinding CreateFactoryBinding(AIAgent agent)
258+
{
259+
if (!this._targets.TryGetValue(agent, out HashSet<HandoffTarget>? handoffs))
260+
{
261+
handoffs = new();
262+
}
263+
264+
// Use the ExecutorId as the placeholder id for a (possibly) future-bound factory
265+
builder.AddSwitch(HandoffAgentExecutor.IdFor(agent), (SwitchBuilder sb) =>
266+
{
267+
foreach (HandoffTarget handoff in handoffs)
268+
{
269+
sb.AddCase<HandoffState>(state => state?.RequestedHandoffTargetAgentId == handoff.Target.Id, // Use AgentId for target matching
270+
HandoffAgentExecutor.IdFor(handoff.Target)); // Use ExecutorId in for routing at the workflow level
271+
}
272+
273+
sb.WithDefault(HandoffEndExecutor.ExecutorId);
274+
});
275+
276+
ExecutorFactoryFunc factory =
277+
(config, sessionId) => new(
278+
new HandoffAgentExecutor(agent,
279+
handoffs,
280+
options));
281+
282+
// Make sure to use ExecutorId when binding the executor, not AgentId
283+
ExecutorBinding binding = factory.BindExecutor(HandoffAgentExecutor.IdFor(agent));
284+
285+
builder.BindExecutor(binding);
286+
287+
return binding;
288+
}
289+
}
290+
236291
/// <summary>
237292
/// Builds a <see cref="Workflow"/> composed of agents that operate via handoffs, with the next
238293
/// agent to process messages selected by the current agent.
239294
/// </summary>
240295
/// <returns>The workflow built based on the handoffs in the builder.</returns>
241296
public Workflow Build()
242297
{
243-
HandoffsStartExecutor start = new(this._returnToPrevious);
244-
HandoffsEndExecutor end = new(this._returnToPrevious);
298+
HandoffStartExecutor start = new(this._returnToPrevious);
299+
HandoffEndExecutor end = new(this._returnToPrevious);
245300
WorkflowBuilder builder = new(start);
246301

247-
HandoffAgentExecutorOptions options = new(this.HandoffInstructions,
248-
this._emitAgentResponseEvents,
249-
this._emitAgentResponseUpdateEvents,
250-
this._toolCallFilteringBehavior);
251-
252-
// Create an AgentExecutor for each agent.
253-
Dictionary<string, HandoffAgentExecutor> executors = this._allAgents.ToDictionary(a => a.Id, a => new HandoffAgentExecutor(a, options));
302+
// Create an factory-based ExecutorBinding for each agent.
303+
Dictionary<string, ExecutorBinding> executors = this.CreateExecutorBindings(builder);
254304

255305
// Connect the start executor to the initial agent (or use dynamic routing when ReturnToPrevious is enabled).
256306
if (this._returnToPrevious)
@@ -263,7 +313,7 @@ public Workflow Build()
263313
if (agent.Id != initialAgentId)
264314
{
265315
string agentId = agent.Id;
266-
sb.AddCase<HandoffState>(state => state?.CurrentAgentId == agentId, executors[agentId]);
316+
sb.AddCase<HandoffState>(state => state?.PreviousAgentId == agentId, executors[agentId]);
267317
}
268318
}
269319

@@ -275,13 +325,6 @@ public Workflow Build()
275325
builder.AddEdge(start, executors[this._initialAgent.Id]);
276326
}
277327

278-
// Initialize each executor with its handoff targets to the other executors.
279-
foreach (var agent in this._allAgents)
280-
{
281-
executors[agent.Id].Initialize(builder, end, executors,
282-
this._targets.TryGetValue(agent, out HashSet<HandoffTarget>? targets) ? targets : []);
283-
}
284-
285328
// Build the workflow.
286329
return builder.WithOutputFrom(end).Build();
287330
}

dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/AIAgentHostExecutor.cs

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ public static bool ShouldEmitStreamingEvents(this TurnToken token, bool? agentSe
1919

2020
public static bool ShouldEmitStreamingEvents(bool? turnTokenSetting, bool? agentSetting)
2121
=> turnTokenSetting ?? agentSetting ?? false;
22+
23+
public static bool ShouldEmitStreamingEvents(this HandoffState handoffState, bool? agentSetting)
24+
=> handoffState.TurnToken.ShouldEmitStreamingEvents(agentSetting);
2225
}
2326

2427
internal sealed class AIAgentHostExecutor : ChatProtocolExecutor
@@ -81,7 +84,11 @@ private ValueTask HandleUserInputResponseAsync(
8184
// resumes can be processed in one invocation.
8285
return this.ProcessTurnMessagesAsync(async (pendingMessages, ctx, ct) =>
8386
{
84-
pendingMessages.Add(new ChatMessage(ChatRole.User, [response]));
87+
pendingMessages.Add(new ChatMessage(ChatRole.User, [response])
88+
{
89+
CreatedAt = DateTimeOffset.UtcNow,
90+
MessageId = Guid.NewGuid().ToString("N"),
91+
});
8592

8693
await this.ContinueTurnAsync(pendingMessages, ctx, this._currentTurnEmitEvents ?? false, ct).ConfigureAwait(false);
8794

@@ -104,7 +111,12 @@ private ValueTask HandleFunctionResultAsync(
104111
// resumes can be processed in one invocation.
105112
return this.ProcessTurnMessagesAsync(async (pendingMessages, ctx, ct) =>
106113
{
107-
pendingMessages.Add(new ChatMessage(ChatRole.Tool, [result]));
114+
pendingMessages.Add(new ChatMessage(ChatRole.Tool, [result])
115+
{
116+
AuthorName = this._agent.Name ?? this._agent.Id,
117+
CreatedAt = DateTimeOffset.UtcNow,
118+
MessageId = Guid.NewGuid().ToString("N"),
119+
});
108120

109121
await this.ContinueTurnAsync(pendingMessages, ctx, this._currentTurnEmitEvents ?? false, ct).ConfigureAwait(false);
110122

@@ -186,16 +198,13 @@ protected override ValueTask TakeTurnAsync(List<ChatMessage> messages, IWorkflow
186198
TurnExtensions.ShouldEmitStreamingEvents(turnTokenSetting: emitEvents, this._options.EmitAgentUpdateEvents),
187199
cancellationToken);
188200

189-
private async ValueTask<AgentResponse> InvokeAgentAsync(IEnumerable<ChatMessage> messages, IWorkflowContext context, bool emitEvents, CancellationToken cancellationToken = default)
201+
private async ValueTask<AgentResponse> InvokeAgentAsync(IEnumerable<ChatMessage> messages, IWorkflowContext context, bool emitUpdateEvents, CancellationToken cancellationToken = default)
190202
{
191-
#pragma warning disable MEAI001
192-
Dictionary<string, ToolApprovalRequestContent> userInputRequests = new();
193-
Dictionary<string, FunctionCallContent> functionCalls = new();
194203
AgentResponse response;
204+
AIAgentUnservicedRequestsCollector collector = new(this._userInputHandler, this._functionCallHandler);
195205

196-
if (emitEvents)
206+
if (emitUpdateEvents)
197207
{
198-
#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
199208
// Run the agent in streaming mode only when agent run update events are to be emitted.
200209
IAsyncEnumerable<AgentResponseUpdate> agentStream = this._agent.RunStreamingAsync(
201210
messages,
@@ -206,7 +215,7 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
206215
await foreach (AgentResponseUpdate update in agentStream.ConfigureAwait(false))
207216
{
208217
await context.YieldOutputAsync(update, cancellationToken).ConfigureAwait(false);
209-
ExtractUnservicedRequests(update.Contents);
218+
collector.ProcessAgentResponseUpdate(update);
210219
updates.Add(update);
211220
}
212221

@@ -220,53 +229,16 @@ await this.EnsureSessionAsync(context, cancellationToken).ConfigureAwait(false),
220229
cancellationToken: cancellationToken)
221230
.ConfigureAwait(false);
222231

223-
ExtractUnservicedRequests(response.Messages.SelectMany(message => message.Contents));
232+
collector.ProcessAgentResponse(response);
224233
}
225234

226235
if (this._options.EmitAgentResponseEvents)
227236
{
228237
await context.YieldOutputAsync(response, cancellationToken).ConfigureAwait(false);
229238
}
230239

231-
if (userInputRequests.Count > 0 || functionCalls.Count > 0)
232-
{
233-
Task userInputTask = this._userInputHandler?.ProcessRequestContentsAsync(userInputRequests, context, cancellationToken) ?? Task.CompletedTask;
234-
Task functionCallTask = this._functionCallHandler?.ProcessRequestContentsAsync(functionCalls, context, cancellationToken) ?? Task.CompletedTask;
235-
236-
await Task.WhenAll(userInputTask, functionCallTask)
237-
.ConfigureAwait(false);
238-
}
240+
await collector.SubmitAsync(context, cancellationToken).ConfigureAwait(false);
239241

240242
return response;
241-
242-
void ExtractUnservicedRequests(IEnumerable<AIContent> contents)
243-
{
244-
foreach (AIContent content in contents)
245-
{
246-
if (content is ToolApprovalRequestContent userInputRequest)
247-
{
248-
// It is an error to simultaneously have multiple outstanding user input requests with the same ID.
249-
userInputRequests.Add(userInputRequest.RequestId, userInputRequest);
250-
}
251-
else if (content is ToolApprovalResponseContent userInputResponse)
252-
{
253-
// If the set of messages somehow already has a corresponding user input response, remove it.
254-
_ = userInputRequests.Remove(userInputResponse.RequestId);
255-
}
256-
else if (content is FunctionCallContent functionCall)
257-
{
258-
// For function calls, we emit an event to notify the workflow.
259-
//
260-
// possibility 1: this will be handled inline by the agent abstraction
261-
// possibility 2: this will not be handled inline by the agent abstraction
262-
functionCalls.Add(functionCall.CallId, functionCall);
263-
}
264-
else if (content is FunctionResultContent functionResult)
265-
{
266-
_ = functionCalls.Remove(functionResult.CallId);
267-
}
268-
}
269-
}
270-
#pragma warning restore MEAI001
271243
}
272244
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Threading;
7+
using System.Threading.Tasks;
8+
using Microsoft.Extensions.AI;
9+
10+
namespace Microsoft.Agents.AI.Workflows.Specialized;
11+
12+
internal sealed class AIAgentUnservicedRequestsCollector(AIContentExternalHandler<ToolApprovalRequestContent, ToolApprovalResponseContent>? userInputHandler,
13+
AIContentExternalHandler<FunctionCallContent, FunctionResultContent>? functionCallHandler)
14+
{
15+
private readonly Dictionary<string, ToolApprovalRequestContent> _userInputRequests = [];
16+
private readonly Dictionary<string, FunctionCallContent> _functionCalls = [];
17+
18+
public Task SubmitAsync(IWorkflowContext context, CancellationToken cancellationToken)
19+
{
20+
Task userInputTask = userInputHandler != null && this._userInputRequests.Count > 0
21+
? userInputHandler.ProcessRequestContentsAsync(this._userInputRequests, context, cancellationToken)
22+
: Task.CompletedTask;
23+
24+
Task functionCallTask = functionCallHandler != null && this._functionCalls.Count > 0
25+
? functionCallHandler.ProcessRequestContentsAsync(this._functionCalls, context, cancellationToken)
26+
: Task.CompletedTask;
27+
28+
return Task.WhenAll(userInputTask, functionCallTask);
29+
}
30+
31+
public void ProcessAgentResponseUpdate(AgentResponseUpdate update, Func<FunctionCallContent, bool>? functionCallFilter = null)
32+
=> this.ProcessAIContents(update.Contents, functionCallFilter);
33+
34+
public void ProcessAgentResponse(AgentResponse response)
35+
=> this.ProcessAIContents(response.Messages.SelectMany(message => message.Contents));
36+
37+
public void ProcessAIContents(IEnumerable<AIContent> contents, Func<FunctionCallContent, bool>? functionCallFilter = null)
38+
{
39+
foreach (AIContent content in contents)
40+
{
41+
if (content is ToolApprovalRequestContent userInputRequest)
42+
{
43+
if (this._userInputRequests.ContainsKey(userInputRequest.RequestId))
44+
{
45+
throw new InvalidOperationException($"ToolApprovalRequestContent with duplicate RequestId: {userInputRequest.RequestId}");
46+
}
47+
48+
// It is an error to simultaneously have multiple outstanding user input requests with the same ID.
49+
this._userInputRequests.Add(userInputRequest.RequestId, userInputRequest);
50+
}
51+
else if (content is ToolApprovalResponseContent userInputResponse)
52+
{
53+
// If the set of messages somehow already has a corresponding user input response, remove it.
54+
_ = this._userInputRequests.Remove(userInputResponse.RequestId);
55+
}
56+
else if (content is FunctionCallContent functionCall)
57+
{
58+
// For function calls, we emit an event to notify the workflow.
59+
//
60+
// possibility 1: this will be handled inline by the agent abstraction
61+
// possibility 2: this will not be handled inline by the agent abstraction
62+
if (functionCallFilter == null || functionCallFilter(functionCall))
63+
{
64+
if (this._functionCalls.ContainsKey(functionCall.CallId))
65+
{
66+
throw new InvalidOperationException($"FunctionCallContent with duplicate CallId: {functionCall.CallId}");
67+
}
68+
69+
this._functionCalls.Add(functionCall.CallId, functionCall);
70+
}
71+
}
72+
else if (content is FunctionResultContent functionResult)
73+
{
74+
_ = this._functionCalls.Remove(functionResult.CallId);
75+
}
76+
}
77+
}
78+
}

0 commit comments

Comments
 (0)