Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 250 additions & 0 deletions python/samples/02-agents/auto_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "agent-framework",
# "tenacity",
# ]
# ///
# Run with any PEP 723 compatible runner, e.g.:
# uv run samples/02-agents/auto_retry.py

# Copyright (c) Microsoft. All rights reserved.

import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, cast

from agent_framework import ChatContext, ChatMiddleware, SupportsChatGetResponse, chat_middleware
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from openai import RateLimitError
from tenacity import (
AsyncRetrying,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

# Load environment variables from .env file
load_dotenv()

"""
Auto-Retry Rate Limiting Sample

Every model inference API enforces rate limits, so production agents need retry logic
to handle 429 responses gracefully. This sample shows two ways to add automatic retry
using the `tenacity` library, keeping your application code free of boilerplate.

Approach 1 – Class decorator
Apply a class decorator to any client type implementing
SupportsChatGetResponse. The decorator patches get_response() with retry
behavior. Non-streaming responses are retried; streaming is returned as-is
(streaming retry requires more delicate handling).

Approach 2 – Chat middleware
Register middleware on the agent that catches RateLimitError raised inside
call_next() and retries the entire request pipeline. Two styles are shown:
a) Class-based middleware (ChatMiddleware subclass)
b) Function-based middleware (@chat_middleware decorator)

Both approaches use the same tenacity primitives:
- stop_after_attempt – cap the total number of tries
- wait_exponential – exponential back-off between retries
- retry_if_exception_type(RateLimitError) – only retry on 429 errors
- before_sleep_log – log each retry attempt at WARNING level
"""

# Module-level logger; tenacity's before_sleep_log writes its retry warnings here.
logger = logging.getLogger(__name__)

# Default total number of tries (the initial call counts as one) shared by all approaches.
RETRY_ATTEMPTS = 3

# =============================================================================
# Approach 1: Class decorator
# =============================================================================


# Any chat client type whose get_response() matches SupportsChatGetResponse.
ChatClientT = TypeVar("ChatClientT", bound=SupportsChatGetResponse[Any])


def with_rate_limit_retry(*, retry_attempts: int = RETRY_ATTEMPTS) -> Callable[[type[ChatClientT]], type[ChatClientT]]:
    """Class decorator that adds non-streaming retry behavior to get_response().

    Args:
        retry_attempts: Maximum total number of tries (initial call plus
            retries) before the final RateLimitError is re-raised.

    Returns:
        A decorator that patches the class's get_response() in place so that
        non-streaming calls are retried with exponential back-off on
        RateLimitError; streaming calls pass through unchanged.
    """
    import functools  # local import so this sample's dependency header stays untouched

    def decorator(client_cls: type[ChatClientT]) -> type[ChatClientT]:
        original_get_response = client_cls.get_response

        # functools.wraps preserves the wrapped method's metadata
        # (__name__, __doc__, signature for introspection tools).
        @functools.wraps(original_get_response)
        def get_response_with_retry(self, *args, **kwargs):  # type: ignore[no-untyped-def]
            stream = kwargs.get("stream", False)

            if stream:
                # Streaming retry is more complex; fall back to the original behaviour.
                return original_get_response(self, *args, **kwargs)

            async def _with_retry():
                # AsyncRetrying re-runs the `with attempt:` body until it succeeds,
                # the attempt budget is exhausted, or a non-retryable error is raised.
                async for attempt in AsyncRetrying(
                    stop=stop_after_attempt(retry_attempts),
                    wait=wait_exponential(multiplier=1, min=4, max=10),
                    retry=retry_if_exception_type(RateLimitError),
                    reraise=True,  # surface the original RateLimitError, not tenacity's RetryError
                    before_sleep=before_sleep_log(logger, logging.WARNING),
                ):
                    with attempt:
                        return await original_get_response(self, *args, **kwargs)
                return None  # unreachable with reraise=True; keeps type-checkers happy

            return _with_retry()

        client_cls.get_response = cast(Any, get_response_with_retry)
        return client_cls

    return decorator


# The decorator patches get_response() on this subclass, so instances retry
# non-streaming calls on RateLimitError automatically.
@with_rate_limit_retry()
class RetryingAzureOpenAIChatClient(AzureOpenAIChatClient):
    """Azure OpenAI Chat client with class-decorator-based retry behavior."""


# =============================================================================
# Approach 2a: Class-based chat middleware
# =============================================================================


class RateLimitRetryMiddleware(ChatMiddleware):
    """Chat middleware that retries the whole request pipeline on 429 errors.

    Attach an instance to an agent (or to a single run) and any call_next()
    invocation that raises RateLimitError is transparently re-executed with
    exponential back-off, up to the configured attempt budget.
    """

    def __init__(self, *, max_attempts: int = RETRY_ATTEMPTS) -> None:
        """Store the maximum total number of tries before giving up."""
        self.max_attempts = max_attempts

    async def process(
        self,
        context: ChatContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Run call_next(), retrying rate-limited attempts with back-off."""
        # Build the retry policy once, then drive the attempts from it.
        retry_policy = AsyncRetrying(
            stop=stop_after_attempt(self.max_attempts),
            wait=wait_exponential(multiplier=1, min=4, max=10),
            retry=retry_if_exception_type(RateLimitError),
            reraise=True,
            before_sleep=before_sleep_log(logger, logging.WARNING),
        )
        async for current_attempt in retry_policy:
            with current_attempt:
                await call_next()


# =============================================================================
# Approach 2b: Function-based chat middleware
# =============================================================================


@chat_middleware
async def rate_limit_retry_middleware(
    context: ChatContext,
    call_next: Callable[[], Awaitable[None]],
) -> None:
    """Function-based chat middleware that retries on rate limit errors.

    The pipeline continuation is wrapped with tenacity's retry() policy, so a
    RateLimitError raised during model inference triggers an automatic retry
    with exponential back-off before the error is surfaced to the caller.
    """
    retry_policy = retry(
        stop=stop_after_attempt(RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(RateLimitError),
        reraise=True,
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )

    async def _invoke_pipeline() -> None:
        await call_next()

    # Applying the policy explicitly is equivalent to decorating with @retry(...).
    await retry_policy(_invoke_pipeline)()


# =============================================================================
# Demo
# =============================================================================


async def class_decorator_example() -> None:
    """Demonstrate Approach 1: class decorator on a chat client type."""
    banner = "=" * 60
    print("\n" + banner)
    print("Approach 1: Class decorator (applied to client type)")
    print(banner)

    # For authentication, run `az login` command in terminal or replace
    # AzureCliCredential with your preferred authentication option.
    client = RetryingAzureOpenAIChatClient(credential=AzureCliCredential())
    agent = client.as_agent(instructions="You are a helpful assistant.")

    query = "Say hello!"
    print(f"User: {query}")
    response = await agent.run(query)
    print(f"Agent: {response.text}")


async def class_based_middleware_example() -> None:
    """Demonstrate Approach 2a: class-based chat middleware."""
    banner = "=" * 60
    print("\n" + banner)
    print("Approach 2a: Class-based chat middleware")
    print(banner)

    # For authentication, run `az login` command in terminal or replace
    # AzureCliCredential with your preferred authentication option.
    client = AzureOpenAIChatClient(credential=AzureCliCredential())
    agent = client.as_agent(
        instructions="You are a helpful assistant.",
        middleware=[RateLimitRetryMiddleware(max_attempts=3)],
    )

    query = "Say hello!"
    print(f"User: {query}")
    response = await agent.run(query)
    print(f"Agent: {response.text}")


async def function_based_middleware_example() -> None:
    """Demonstrate Approach 2b: function-based chat middleware."""
    banner = "=" * 60
    print("\n" + banner)
    print("Approach 2b: Function-based chat middleware")
    print(banner)

    # For authentication, run `az login` command in terminal or replace
    # AzureCliCredential with your preferred authentication option.
    client = AzureOpenAIChatClient(credential=AzureCliCredential())
    agent = client.as_agent(
        instructions="You are a helpful assistant.",
        middleware=[rate_limit_retry_middleware],
    )

    query = "Say hello!"
    print(f"User: {query}")
    response = await agent.run(query)
    print(f"Agent: {response.text}")


async def main() -> None:
    """Run all auto-retry examples."""
    print("=== Auto-Retry Rate Limiting Sample ===")
    intro = (
        "Demonstrates two approaches for automatic retry on rate limit (429) errors.\n"
        "Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME (and optionally\n"
        "AZURE_OPENAI_API_KEY) before running, or populate a .env file."
    )
    print(intro)

    # Run each approach sequentially so their banners appear in order.
    for example in (
        class_decorator_example,
        class_based_middleware_example,
        function_based_middleware_example,
    ):
        await example()


# Standard script entry point: drive the async demo via asyncio.
if __name__ == "__main__":
    asyncio.run(main())