forked from microsoft/agent-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmiddleware_termination.py
More file actions
177 lines (144 loc) · 6.38 KB
/
middleware_termination.py
File metadata and controls
177 lines (144 loc) · 6.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
AgentMiddleware,
AgentRunContext,
AgentRunResponse,
ChatMessage,
Role,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Middleware Termination Example

This sample demonstrates how middleware can stop an agent run by setting the `context.terminate` flag.

The example includes:
- PreTerminationMiddleware: sets the flag (and a canned reply) before calling next(), so a blocked request
  is not processed by the agent
- PostTerminationMiddleware: lets runs proceed until a response limit is reached, then sets the flag on
  every later run

This is useful for implementing security checks, rate limiting, or early exit conditions.
"""
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Return a randomly generated weather report sentence for ``location``.

    The condition and the high temperature are drawn with ``randint``, so the
    result differs from call to call; no real weather source is consulted.
    """
    conditions = ("sunny", "cloudy", "rainy", "stormy")
    # Draw the condition first, then the temperature, so the sequence of
    # random draws stays the same as in a single f-string evaluation.
    condition = conditions[randint(0, len(conditions) - 1)]
    high_celsius = randint(10, 30)
    return f"The weather in {location} is {condition} with a high of {high_celsius}°C."
class PreTerminationMiddleware(AgentMiddleware):
    """Middleware that rejects a run when the latest message contains a blocked word.

    On a match, a canned assistant reply is stored in ``context.result`` and
    ``context.terminate`` is set to ``True`` before ``next`` is awaited.
    """

    def __init__(self, blocked_words: list[str]):
        """Store the blocked words, lower-cased for case-insensitive matching.

        Args:
            blocked_words: Substrings that cause a request to be rejected.
                Empty strings are dropped, because ``"" in text`` is always
                true and an empty entry would otherwise block every request.
        """
        self.blocked_words = [word.lower() for word in blocked_words if word]

    def _find_blocked_word(self, text: str) -> str | None:
        """Return the first blocked word contained in ``text``, or ``None`` if there is none."""
        lowered = text.lower()
        # The builtin ``next`` is usable here; inside ``process`` it is
        # shadowed by the ``next`` parameter.
        return next((word for word in self.blocked_words if word in lowered), None)

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Check the last message for blocked words, then hand over to ``next``.

        Args:
            context: The run context. Its last message is inspected; on a
                match ``context.result`` and ``context.terminate`` are set.
            next: The next handler in the middleware chain.
        """
        # Only the most recent message is checked.
        last_message = context.messages[-1] if context.messages else None
        blocked_word = None
        if last_message and last_message.text:
            blocked_word = self._find_blocked_word(last_message.text)

        if blocked_word is not None:
            print(f"[PreTerminationMiddleware] Blocked word '{blocked_word}' detected. Terminating request.")
            # Set a custom response
            context.result = AgentRunResponse(
                messages=[
                    ChatMessage(
                        role=Role.ASSISTANT,
                        text=(
                            f"Sorry, I cannot process requests containing '{blocked_word}'. "
                            "Please rephrase your question."
                        ),
                    )
                ]
            )
            # Set terminate flag to prevent further processing
            context.terminate = True

        # ``next`` is awaited even after a block.
        # NOTE(review): this relies on the pipeline honouring
        # ``context.terminate`` and skipping the agent - confirm against the framework.
        await next(context)
class PostTerminationMiddleware(AgentMiddleware):
    """Middleware that counts runs and sets the terminate flag once a response limit is reached.

    The counter is stored on the middleware instance, so it accumulates across
    every run that passes through this instance.
    """

    def __init__(self, max_responses: int = 1):
        """Start with a zero response count and remember the allowed maximum."""
        self.response_count = 0
        self.max_responses = max_responses

    async def process(
        self,
        context: AgentRunContext,
        next: Callable[[AgentRunContext], Awaitable[None]],
    ) -> None:
        """Flag the run for termination if the limit is reached, then continue the chain.

        Args:
            context: The run context; ``context.terminate`` is set to ``True``
                when ``response_count`` has reached ``max_responses``.
            next: The next handler in the middleware chain.
        """
        print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})")

        limit_reached = self.response_count >= self.max_responses
        if limit_reached:
            print(
                f"[PostTerminationMiddleware] Maximum responses ({self.max_responses}) reached. "
                "Terminating further processing."
            )
            context.terminate = True

        # The chain is continued whether or not the flag was set above.
        await next(context)

        # Every run is counted once ``next`` returns, including runs flagged for termination.
        self.response_count += 1
async def pre_termination_middleware() -> None:
    """Run the blocked-word demo: one ordinary query, then one containing a blocked word."""
    print("\n--- Example 1: Pre-termination Middleware ---")

    # (heading, user query) pairs, run in order against the same agent.
    scenarios = (
        ("\n1. Normal query:", "What's the weather like in Seattle?"),
        ("\n2. Query with blocked word:", "What's the bad weather in New York?"),
    )

    async with AzureCliCredential() as credential:
        async with AzureAIAgentClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PreTerminationMiddleware(blocked_words=["bad", "inappropriate"]),
        ) as agent:
            for heading, query in scenarios:
                print(heading)
                print(f"User: {query}")
                result = await agent.run(query)
                print(f"Agent: {result.text}")
async def post_termination_middleware() -> None:
    """Run the response-limit demo: three runs through a middleware capped at one response."""
    print("\n--- Example 2: Post-termination Middleware ---")

    no_response = "No response (terminated)"
    # (heading, user query, text shown when the reply is empty).
    # The first run has no fallback: its reply text is printed as-is.
    scenarios = (
        ("\n1. First run:", "What's the weather in Paris?", None),
        ("\n2. Second run (should be terminated):", "What about the weather in London?", no_response),
        ("\n3. Third run (should also be terminated):", "And New York?", no_response),
    )

    async with AzureCliCredential() as credential:
        async with AzureAIAgentClient(async_credential=credential).create_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=PostTerminationMiddleware(max_responses=1),
        ) as agent:
            for heading, query, fallback in scenarios:
                print(heading)
                print(f"User: {query}")
                result = await agent.run(query)
                shown = result.text
                if fallback is not None and not shown:
                    shown = fallback
                print(f"Agent: {shown}")
async def main() -> None:
    """Print the sample banner, then run the pre- and post-termination demos in that order."""
    print("=== Middleware Termination Example ===")
    for demo in (pre_termination_middleware, post_termination_middleware):
        await demo()
if __name__ == "__main__":
    # Run the demos only when this file is executed as a script, not when it is imported.
    asyncio.run(main())