diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index bfa684b469..0cdc70b42b 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -154,6 +154,14 @@ handler, ) from ._workflows._function_executor import FunctionExecutor, executor +from ._workflows._functional import ( + FunctionalWorkflow, + FunctionalWorkflowAgent, + RunContext, + StepWrapper, + step, + workflow, +) from ._workflows._request_info_mixin import response_handler from ._workflows._runner import Runner from ._workflows._runner_context import ( @@ -244,6 +252,8 @@ "FunctionMiddleware", "FunctionMiddlewareTypes", "FunctionTool", + "FunctionalWorkflow", + "FunctionalWorkflowAgent", "GeneratedEmbeddings", "GraphConnectivityError", "InMemoryCheckpointStorage", @@ -263,11 +273,13 @@ "ResponseStream", "Role", "RoleLiteral", + "RunContext", "Runner", "RunnerContext", "SecretString", "SessionContext", "SingleEdgeGroup", + "StepWrapper", "SubWorkflowRequestMessage", "SubWorkflowResponseMessage", "SupportsAgentRun", @@ -327,9 +339,11 @@ "register_state_type", "resolve_agent_id", "response_handler", + "step", "tool", "validate_chat_options", "validate_tool_mode", "validate_tools", "validate_workflow_graph", + "workflow", ] diff --git a/python/packages/core/agent_framework/_workflows/_functional.py b/python/packages/core/agent_framework/_workflows/_functional.py new file mode 100644 index 0000000000..80bb9d8cb9 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_functional.py @@ -0,0 +1,1102 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Functional workflow API for writing workflows as plain async functions. + +This module provides the ``@workflow`` and ``@step`` decorators that let users +define workflows using native Python control flow (if/else, loops, +``asyncio.gather``) instead of a graph-based topology. 
+ +A ``@workflow``-decorated async function receives its input as the first +positional argument. If the function needs HITL (``request_info``), custom +events, or key/value state, add a :class:`RunContext` parameter — otherwise it +can be omitted. Inside the function, plain ``async`` calls run normally. +Optionally, ``@step``-decorated functions gain caching, per-step checkpointing, +and event emission. + +Key public symbols: + +* :func:`workflow` / :class:`FunctionalWorkflow` — decorator and runtime. +* :func:`step` / :class:`StepWrapper` — optional step decorator. +* :class:`RunContext` — execution context injected into workflow functions. +* :class:`FunctionalWorkflowAgent` — agent adapter returned by + :meth:`FunctionalWorkflow.as_agent`. +""" + +from __future__ import annotations + +# pyright: reportPrivateUsage=false +# Classes in this module (RunContext, StepWrapper, FunctionalWorkflow) form a +# cohesive unit and intentionally access each other's underscore-prefixed members. +import functools +import hashlib +import inspect +import logging +import typing +import uuid +from collections.abc import AsyncIterable, Awaitable, Callable, Sequence +from contextvars import ContextVar +from copy import deepcopy +from typing import Any, Generic, Literal, TypeVar, overload + +from .._types import AgentResponse, AgentResponseUpdate, ResponseStream +from ..observability import OtelAttr, capture_exception, create_workflow_span +from ._checkpoint import CheckpointStorage, WorkflowCheckpoint +from ._events import ( + WorkflowErrorDetails, + WorkflowEvent, + WorkflowRunState, + _framework_event_origin, # type: ignore[reportPrivateUsage] +) +from ._workflow import WorkflowRunResult + +logger = logging.getLogger(__name__) + +R = TypeVar("R") + +# ContextVar holding the active RunContext during workflow execution. +# ContextVar is per-asyncio-Task, so concurrent workflows each get their own context. 
_active_run_ctx: ContextVar[RunContext | None] = ContextVar("_active_run_ctx", default=None)


# ---------------------------------------------------------------------------
# Internal exception for HITL interruption
# ---------------------------------------------------------------------------


class WorkflowInterrupted(BaseException):
    """Control-flow signal raised by ``RunContext.request_info`` to suspend a run.

    The base class is ``BaseException`` rather than ``Exception`` on purpose:
    a ``@workflow`` function that wraps its body in ``except Exception:`` must
    not be able to swallow the human-in-the-loop interruption by accident.

    Args:
        request_id: Identifier of the request that caused the interruption.
        request_data: Payload describing the information being requested.
        response_type: Python type the eventual response is expected to have.
    """

    def __init__(self, request_id: str, request_data: Any, response_type: type) -> None:
        message = f"Workflow interrupted by request_info (request_id={request_id})"
        super().__init__(message)
        # Keep the request details on the instance so a handler can inspect them.
        self.request_id = request_id
        self.request_data = request_data
        self.response_type = response_type
class RunContext:
    """Execution context injected into ``@workflow`` functions.

    Every ``@workflow`` invocation receives a ``RunContext`` instance that
    provides human-in-the-loop (HITL) requests, custom event emission,
    key/value state, and event collection. The context is available to the
    workflow function via a parameter annotated as ``RunContext``.

    The workflow's return value is automatically emitted as the output.
    Use :meth:`add_event` to emit custom events during execution.

    Args:
        workflow_name: Identifier for the enclosing workflow, used when
            generating events and checkpoint metadata.
        streaming: Whether the current run was started with ``stream=True``.
        run_kwargs: Extra keyword arguments forwarded from
            :meth:`FunctionalWorkflow.run`.

    Examples:

        .. code-block:: python

            @workflow
            async def my_pipeline(data: str) -> str:
                return await some_step(data)


            # Add ctx: RunContext only when you need HITL, state, or custom events:
            @workflow
            async def hitl_pipeline(data: str, ctx: RunContext) -> str:
                feedback = await ctx.request_info({"draft": data}, response_type=str)
                return feedback
    """

    def __init__(
        self,
        workflow_name: str,
        *,
        streaming: bool = False,
        run_kwargs: dict[str, Any] | None = None,
    ) -> None:
        self._workflow_name = workflow_name
        self._streaming = streaming
        self._run_kwargs = run_kwargs or {}

        # Event accumulator
        self._events: list[WorkflowEvent[Any]] = []

        # Step result cache: (step_name, call_index) -> result
        self._step_cache: dict[tuple[str, int], Any] = {}
        # Per-step call counters for deterministic cache keys
        self._step_call_counters: dict[str, int] = {}

        # Number of request_info() calls so far that relied on an
        # auto-generated request id (used to derive replay-stable ids)
        self._auto_request_counter: int = 0

        # HITL responses (set via _set_responses before replay)
        self._responses: dict[str, Any] = {}
        # Pending request_info events (for checkpointing)
        self._pending_requests: dict[str, WorkflowEvent[Any]] = {}

        # User state (simple dict)
        self._state: dict[str, Any] = {}

        # Callback invoked after each step completes (set by FunctionalWorkflow)
        self._on_step_completed: Callable[[], Awaitable[None]] | None = None

    # ------------------------------------------------------------------
    # Public API (for @workflow functions)
    # ------------------------------------------------------------------

    async def request_info(
        self,
        request_data: Any,
        response_type: type,
        *,
        request_id: str | None = None,
    ) -> Any:
        """Request external information (human-in-the-loop).

        If a response for the request id has been supplied (via
        ``run(responses={request_id: value})``), that value is returned
        directly. Otherwise a ``request_info`` event is recorded, the request
        is remembered as pending, and ``WorkflowInterrupted`` is raised to
        suspend the workflow; the framework catches that signal.

        Args:
            request_data: Arbitrary payload describing what information is
                needed (e.g. a Pydantic model, dict, or string prompt).
            response_type: The expected Python type of the response value.
            request_id: Optional stable identifier for this request. If
                omitted, an identifier is derived from the workflow name and
                the order of auto-identified ``request_info`` calls, so the
                same call yields the same id when the function is replayed.
                A random id would never match the id the caller answered.

        Returns:
            The response value supplied for this request id.

        Raises:
            WorkflowInterrupted: When no response is available yet.
        """
        rid = request_id or self._next_auto_request_id()

        # Check if we already have a response for this request
        found, value = self._get_response(rid)
        if found:
            return value

        # No response — emit event and interrupt
        event = WorkflowEvent.request_info(
            request_id=rid,
            source_executor_id=self._workflow_name,
            request_data=request_data,
            response_type=response_type,
        )
        self._add_event(event)
        self._pending_requests[rid] = event
        raise WorkflowInterrupted(rid, request_data, response_type)

    async def add_event(self, event: WorkflowEvent[Any]) -> None:
        """Add a custom event to the workflow event stream.

        Args:
            event: The workflow event to append.
        """
        self._add_event(event)

    def get_state(self, key: str, default: Any = None) -> Any:
        """Retrieve a value from the workflow's key/value state.

        State lives on this context instance. ``FunctionalWorkflow`` copies it
        into checkpoints when checkpoint storage is configured and restores it
        when a run resumes from a ``checkpoint_id``.

        Args:
            key: The state key to look up.
            default: Value returned when *key* is absent.

        Returns:
            The stored value, or *default* if the key does not exist.
        """
        return self._state.get(key, default)

    def set_state(self, key: str, value: Any) -> None:
        """Store a value in the workflow's key/value state.

        Args:
            key: The state key.
            value: The value to store. Must be JSON-serializable if
                checkpoint storage is used.
        """
        self._state[key] = value

    def is_streaming(self) -> bool:
        """Return whether the current run was started with ``stream=True``.

        Returns:
            ``True`` if the workflow is running in streaming mode.
        """
        return self._streaming

    # ------------------------------------------------------------------
    # Internal API (for StepWrapper and FunctionalWorkflow)
    # ------------------------------------------------------------------

    def _next_auto_request_id(self) -> str:
        """Return the next replay-stable id for a ``request_info`` call without an explicit id.

        The id is a name-based UUID of the workflow name plus a per-context
        call index, so it only depends on call order, not on randomness.
        """
        index = self._auto_request_counter
        self._auto_request_counter = index + 1
        seed = f"agent-framework:{self._workflow_name}:request_info:{index}"
        return str(uuid.uuid5(uuid.NAMESPACE_URL, seed))

    def _add_event(self, event: WorkflowEvent[Any]) -> None:
        """Append *event* to the accumulated event list."""
        self._events.append(event)

    def _get_events(self) -> list[WorkflowEvent[Any]]:
        """Return a shallow copy of the events collected so far."""
        return list(self._events)

    def _get_step_cache_key(self, step_name: str) -> tuple[str, int]:
        """Return ``(step_name, call_index)`` and advance that step's call counter."""
        idx = self._step_call_counters.get(step_name, 0)
        self._step_call_counters[step_name] = idx + 1
        return (step_name, idx)

    def _get_cached_result(self, key: tuple[str, int]) -> tuple[bool, Any]:
        """Return ``(True, value)`` when *key* is cached, else ``(False, None)``."""
        if key in self._step_cache:
            return True, self._step_cache[key]
        return False, None

    def _set_cached_result(self, key: tuple[str, int], value: Any) -> None:
        """Store *value* in the step cache under *key*."""
        self._step_cache[key] = value

    def _set_responses(self, responses: dict[str, Any]) -> None:
        """Replace the known HITL responses with a copy of *responses*."""
        self._responses = dict(responses)

    def _get_response(self, request_id: str) -> tuple[bool, Any]:
        """Return ``(True, value)`` when a response exists for *request_id*, else ``(False, None)``."""
        if request_id in self._responses:
            return True, self._responses[request_id]
        return False, None

    def _export_step_cache(self) -> dict[str, Any]:
        """Serialize the step cache for checkpointing.

        Converts tuple keys to ``"name::index"`` strings for JSON compatibility.
        """
        return {f"{name}::{idx}": val for (name, idx), val in self._step_cache.items()}

    def _import_step_cache(self, data: dict[str, Any]) -> None:
        """Restore step cache from checkpoint data.

        Raises:
            ValueError: If a key is not of the form ``"name::index"``.
        """
        self._step_cache = {}
        for k, v in data.items():
            try:
                # rsplit so that step names which themselves contain "::" round-trip
                name, idx_str = k.rsplit("::", 1)
                self._step_cache[name, int(idx_str)] = v
            except (ValueError, TypeError) as exc:
                raise ValueError(
                    f"Corrupted step cache entry in checkpoint: key={k!r}. "
                    f"The checkpoint may be from an incompatible version or corrupted. "
                    f"Original error: {exc}"
                ) from exc
class StepWrapper(Generic[R]):
    """Wrapper returned by the ``@step`` decorator.

    When called inside a running ``@workflow`` function, the wrapper
    intercepts execution to provide:

    * **Caching** — results are cached by ``(step_name, call_index)`` so
      that HITL replay and checkpoint restore skip already-completed work.
    * **Event emission** — ``executor_invoked`` / ``executor_completed`` /
      ``executor_failed`` events are emitted for observability.
    * **Per-step checkpointing** — the context's step-completed callback is
      awaited after each live execution (it is set when checkpoint storage
      is configured).

    Outside a workflow the wrapper is transparent: it delegates directly to
    the original function, making decorated functions fully testable in
    isolation.

    Args:
        func: The async function to wrap.
        name: Optional display name. Defaults to ``func.__name__``.

    Raises:
        TypeError: If *func* is not an async (coroutine) function.
    """

    def __init__(self, func: Callable[..., Awaitable[R]], *, name: str | None = None) -> None:
        if not inspect.iscoroutinefunction(func):
            # Arbitrary callables may lack __name__; do not mask the TypeError with AttributeError.
            label = getattr(func, "__name__", repr(func))
            raise TypeError(
                f"@step can only decorate async functions, but '{label}' is not a coroutine function."
            )
        self._func = func
        self.name: str = name or func.__name__
        functools.update_wrapper(self, func)

    @staticmethod
    def _snapshot_invocation(args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any] | None:
        """Build the payload attached to the ``executor_invoked`` event.

        Returns ``None`` when the step was called without arguments. The
        payload is deep-copied so later mutation of the arguments does not
        alter the recorded event; if an argument cannot be deep-copied, the
        uncopied payload is used instead so that event bookkeeping never
        prevents the step from running.
        """
        if not args and not kwargs:
            return None
        payload: dict[str, Any] = {"args": args, "kwargs": kwargs}
        try:
            return deepcopy(payload)
        except Exception as exc:  # deepcopy may raise arbitrary errors from user objects
            logger.debug("Could not deep-copy step arguments for event payload: %s", exc)
            return payload

    async def __call__(self, *args: Any, **kwargs: Any) -> R:
        """Run the step, using the active run context for caching and events when present."""
        ctx = _active_run_ctx.get()
        if ctx is None:
            # Outside a workflow — pass through directly
            return await self._func(*args, **kwargs)

        cache_key = ctx._get_step_cache_key(self.name)
        found, cached = ctx._get_cached_result(cache_key)
        invocation_data = self._snapshot_invocation(args, kwargs)
        if found:
            # Replay path: emit events and return cached result
            ctx._add_event(WorkflowEvent.executor_invoked(self.name, invocation_data))
            ctx._add_event(WorkflowEvent.executor_completed(self.name, cached))
            return cached  # type: ignore[return-value, no-any-return]

        # Live execution path
        ctx._add_event(WorkflowEvent.executor_invoked(self.name, invocation_data))
        try:
            result = await self._func(*args, **kwargs)
        except Exception as exc:
            # WorkflowInterrupted derives from BaseException, so it is not reported as a failure here.
            ctx._add_event(WorkflowEvent.executor_failed(self.name, WorkflowErrorDetails.from_exception(exc)))
            raise
        ctx._set_cached_result(cache_key, result)
        ctx._add_event(WorkflowEvent.executor_completed(self.name, result))
        if ctx._on_step_completed is not None:
            await ctx._on_step_completed()
        return result


# ---------------------------------------------------------------------------
# @step decorator
# ---------------------------------------------------------------------------


@overload
def step(func: Callable[..., Awaitable[R]]) -> StepWrapper[R]: ...


@overload
def step(*, name: str | None = None) -> Callable[[Callable[..., Awaitable[R]]], StepWrapper[R]]: ...


def step(
    func: Callable[..., Awaitable[Any]] | None = None,
    *,
    name: str | None = None,
) -> StepWrapper[Any] | Callable[[Callable[..., Awaitable[Any]]], StepWrapper[Any]]:
    """Decorator that marks an async function as a tracked workflow step.

    Supports both bare ``@step`` and parameterized ``@step(name="custom")``
    forms. Inside a running ``@workflow`` function, calls to a step are
    intercepted for result caching, event emission, and per-step
    checkpointing. Outside a workflow the decorated function behaves
    identically to the original, making it fully testable in isolation.

    The ``@step`` decorator is **optional**. Plain async functions work
    inside ``@workflow`` without it; use ``@step`` only when you need
    caching, checkpointing, or observability for a particular call.

    Args:
        func: The async function to decorate (when using the bare
            ``@step`` form).
        name: Optional display name for the step. Defaults to the
            function's ``__name__``.

    Returns:
        A :class:`StepWrapper` (bare form) or a decorator that produces
        one (parameterized form).

    Raises:
        TypeError: If the decorated function is not async.

    Examples:

        .. code-block:: python

            @step
            async def fetch_data(url: str) -> dict:
                return await http_get(url)


            @step(name="transform")
            async def transform_data(raw: dict) -> str:
                return json.dumps(raw)
    """
    if func is not None:
        return StepWrapper(func, name=name)

    def _decorator(fn: Callable[..., Awaitable[Any]]) -> StepWrapper[Any]:
        return StepWrapper(fn, name=name)

    return _decorator
class FunctionalWorkflow:
    """A workflow backed by a user-defined async function.

    Created by the :func:`workflow` decorator. Exposes the same ``run()``
    interface as graph-based :class:`Workflow` objects, returning a
    :class:`WorkflowRunResult` (or a :class:`ResponseStream` in streaming
    mode).

    The underlying function is executed directly — no graph compilation or
    edge wiring is involved. Native Python control flow (``if``/``else``,
    ``for``, ``asyncio.gather``) is used for branching and parallelism.

    Args:
        func: The async function that implements the workflow logic.
        name: Display name for the workflow. Defaults to ``func.__name__``.
        description: Optional human-readable description.
        checkpoint_storage: Default :class:`CheckpointStorage` used for
            persisting step results and state between runs. Can be
            overridden per-run via the *checkpoint_storage* parameter of
            :meth:`run`.

    Examples:

        .. code-block:: python

            @workflow
            async def my_pipeline(data: str) -> str:
                return await to_upper(data)


            result = await my_pipeline.run("hello")
            print(result.get_outputs())  # ['HELLO']
    """

    def __init__(
        self,
        func: Callable[..., Awaitable[Any]],
        *,
        name: str | None = None,
        description: str | None = None,
        checkpoint_storage: CheckpointStorage | None = None,
    ) -> None:
        self._func = func
        self.name = name or func.__name__
        self.description = description
        self._checkpoint_storage = checkpoint_storage
        self._is_running = False
        # Last message used to invoke the workflow (for replay on resume)
        self._last_message: Any = None
        # Step cache from the last run (for response-only replay without checkpoint)
        self._last_step_cache: dict[tuple[str, int], Any] = {}

        # Discover step names referenced in the function for signature hash
        self._step_names = self._discover_step_names(func)

        # Compute a stable signature hash
        self.graph_signature_hash = self._compute_signature_hash()

        functools.update_wrapper(self, func)  # type: ignore[arg-type]

    # ------------------------------------------------------------------
    # run() — same overloaded interface as graph Workflow
    # ------------------------------------------------------------------

    @overload
    def run(
        self,
        message: Any | None = None,
        *,
        stream: Literal[True],
        responses: dict[str, Any] | None = None,
        checkpoint_id: str | None = None,
        checkpoint_storage: CheckpointStorage | None = None,
        **kwargs: Any,
    ) -> ResponseStream[WorkflowEvent[Any], WorkflowRunResult]: ...

    @overload
    def run(
        self,
        message: Any | None = None,
        *,
        stream: Literal[False] = ...,
        responses: dict[str, Any] | None = None,
        checkpoint_id: str | None = None,
        checkpoint_storage: CheckpointStorage | None = None,
        include_status_events: bool = False,
        **kwargs: Any,
    ) -> Awaitable[WorkflowRunResult]: ...

    def run(
        self,
        message: Any | None = None,
        *,
        stream: bool = False,
        responses: dict[str, Any] | None = None,
        checkpoint_id: str | None = None,
        checkpoint_storage: CheckpointStorage | None = None,
        include_status_events: bool = False,
        **kwargs: Any,
    ) -> ResponseStream[WorkflowEvent[Any], WorkflowRunResult] | Awaitable[WorkflowRunResult]:
        """Run the functional workflow.

        At least one of *message*, *responses*, or *checkpoint_id* must be
        provided. *message* starts a fresh run and cannot be combined with
        *responses* or *checkpoint_id*; *responses* and *checkpoint_id* may
        be supplied together to answer a request while restoring from a
        checkpoint.

        Args:
            message: Input data passed to the workflow function.
            stream: If ``True``, return a :class:`ResponseStream` of
                :class:`WorkflowEvent` instances instead of an awaitable.
            responses: HITL responses keyed by ``request_id``, used to
                resume a workflow that was suspended by
                :meth:`RunContext.request_info`.
            checkpoint_id: Identifier of a checkpoint to restore from.
                Requires *checkpoint_storage* to be set (here or on the
                decorator).
            checkpoint_storage: Override the default checkpoint storage
                for this run.
            include_status_events: When ``True``, status events are kept in
                the final :class:`WorkflowRunResult` event list.

        Keyword Args:
            **kwargs: Extra keyword arguments stored on the run's
                :class:`RunContext` as ``_run_kwargs``.

        Returns:
            A :class:`WorkflowRunResult` awaitable (non-streaming) or a
            :class:`ResponseStream` (streaming).

        Raises:
            ValueError: If the combination of *message*, *responses*, and
                *checkpoint_id* is invalid.
            RuntimeError: If the workflow is already running (concurrent
                execution is not allowed).
        """
        self._validate_run_params(message, responses, checkpoint_id)
        self._ensure_not_running()

        response_stream: ResponseStream[WorkflowEvent[Any], WorkflowRunResult] = ResponseStream(
            self._run_core(
                message=message,
                responses=responses,
                checkpoint_id=checkpoint_id,
                checkpoint_storage=checkpoint_storage,
                streaming=stream,
                **kwargs,
            ),
            finalizer=functools.partial(self._finalize_events, include_status_events=include_status_events),
            cleanup_hooks=[self._run_cleanup],
        )

        if stream:
            return response_stream
        return response_stream.get_final_response()

    # ------------------------------------------------------------------
    # As agent
    # ------------------------------------------------------------------

    def as_agent(self, name: str | None = None) -> FunctionalWorkflowAgent:
        """Wrap this workflow as an agent-compatible object.

        Args:
            name: Display name for the agent. Defaults to the workflow name.

        Returns:
            A :class:`FunctionalWorkflowAgent` wrapping this workflow.
        """
        return FunctionalWorkflowAgent(workflow=self, name=name)

    # ------------------------------------------------------------------
    # Internal execution
    # ------------------------------------------------------------------

    async def _run_core(
        self,
        message: Any | None = None,
        *,
        responses: dict[str, Any] | None = None,
        checkpoint_id: str | None = None,
        checkpoint_storage: CheckpointStorage | None = None,
        streaming: bool = False,
        **kwargs: Any,
    ) -> AsyncIterable[WorkflowEvent[Any]]:
        """Execute the workflow function and yield the resulting events.

        Events recorded on the :class:`RunContext` are buffered while the
        user function runs and yielded after it returns, is interrupted by a
        HITL request, or fails. Failures are re-raised after the ``failed``
        and ``FAILED`` status events have been yielded.
        """
        # Explicit None check: a storage object that happens to be falsy
        # (e.g. an empty container-like store) must still be honoured.
        storage = checkpoint_storage if checkpoint_storage is not None else self._checkpoint_storage

        # Build context
        ctx = RunContext(self.name, streaming=streaming, run_kwargs=kwargs if kwargs else None)

        # Restore from checkpoint if requested
        prev_checkpoint_id: str | None = None
        if checkpoint_id is not None:
            if storage is None:
                raise ValueError(
                    "Cannot restore from checkpoint without checkpoint_storage. "
                    "Provide checkpoint_storage parameter or set it on the @workflow decorator."
                )
            checkpoint = await storage.load(checkpoint_id)
            if checkpoint.graph_signature_hash != self.graph_signature_hash:
                raise ValueError(
                    f"Checkpoint '{checkpoint_id}' was created by a different version of workflow "
                    f"'{checkpoint.workflow_name}' and is not compatible with the current version. "
                    f"The workflow's step structure may have changed since this checkpoint was saved."
                )
            prev_checkpoint_id = checkpoint_id
            # Restore step cache
            step_cache_data = checkpoint.state.get("_step_cache", {})
            ctx._import_step_cache(step_cache_data)
            # Restore user state (underscore-prefixed keys are framework bookkeeping)
            ctx._state = {k: v for k, v in checkpoint.state.items() if not k.startswith("_")}
            # Restore pending request info events
            ctx._pending_requests = dict(checkpoint.pending_request_info_events)
            # Restore original message for replay
            if message is None:
                message = checkpoint.state.get("_original_message")

        # For response-only replay (no checkpoint), restore cached state
        if checkpoint_id is None and responses:
            if message is None:
                message = self._last_message
            ctx._step_cache = dict(self._last_step_cache)

        # Store message for future replays
        if message is not None:
            self._last_message = message

        # Set responses for replay
        if responses:
            ctx._set_responses(responses)

        # Wire up per-step checkpointing
        # Use a mutable list so the closure can update prev_checkpoint_id
        ckpt_chain: list[str | None] = [prev_checkpoint_id]
        if storage is not None:

            async def _on_step_completed() -> None:
                ckpt_chain[0] = await self._save_checkpoint(ctx, storage, ckpt_chain[0])

            ctx._on_step_completed = _on_step_completed

        # Tracing
        attributes: dict[str, Any] = {OtelAttr.WORKFLOW_NAME: self.name}
        if self.description:
            attributes[OtelAttr.WORKFLOW_DESCRIPTION] = self.description

        with create_workflow_span(OtelAttr.WORKFLOW_RUN_SPAN, attributes) as span:
            saw_request = False
            try:
                span.add_event(OtelAttr.WORKFLOW_STARTED)

                with _framework_event_origin():
                    yield WorkflowEvent.started()
                with _framework_event_origin():
                    yield WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)

                # Execute the user function
                return_value = await self._execute(ctx, message)

                # Emit the return value as the workflow output.
                if return_value is not None:
                    ctx._add_event(WorkflowEvent.output(self.name, return_value))

                # Persist step cache for response-only replay
                self._last_step_cache = dict(ctx._step_cache)

                # Yield collected events
                for event in ctx._get_events():
                    if event.type == "request_info":
                        saw_request = True
                    yield event
                    if event.type == "request_info":
                        with _framework_event_origin():
                            yield WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS)

                # Save final checkpoint if storage is available
                if storage is not None:
                    await self._save_checkpoint(ctx, storage, ckpt_chain[0])

                # Final status
                if saw_request:
                    with _framework_event_origin():
                        yield WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS)
                else:
                    with _framework_event_origin():
                        yield WorkflowEvent.status(WorkflowRunState.IDLE)

                span.add_event(OtelAttr.WORKFLOW_COMPLETED)

            except WorkflowInterrupted:
                # Persist step cache for response-only replay
                self._last_step_cache = dict(ctx._step_cache)

                # HITL interruption — yield events collected so far
                for event in ctx._get_events():
                    if event.type == "request_info":
                        saw_request = True
                    yield event
                    if event.type == "request_info":
                        with _framework_event_origin():
                            yield WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS)

                # Save checkpoint
                if storage is not None:
                    await self._save_checkpoint(ctx, storage, ckpt_chain[0])

                with _framework_event_origin():
                    yield WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS)

                span.add_event(OtelAttr.WORKFLOW_COMPLETED)

            except Exception as exc:
                # Yield any events collected before the failure
                for event in ctx._get_events():
                    yield event

                details = WorkflowErrorDetails.from_exception(exc)
                with _framework_event_origin():
                    yield WorkflowEvent.failed(details)
                with _framework_event_origin():
                    yield WorkflowEvent.status(WorkflowRunState.FAILED)

                span.add_event(
                    name=OtelAttr.WORKFLOW_ERROR,
                    attributes={
                        "error.message": str(exc),
                        "error.type": type(exc).__name__,
                    },
                )
                capture_exception(span, exception=exc)
                raise

    async def _execute(self, ctx: RunContext, message: Any) -> Any:
        """Run the user's async function with *ctx* installed as the active context.

        The :class:`RunContext` is injected into every parameter that is
        annotated as ``RunContext`` or named ``ctx``; the first remaining
        parameter receives *message*. Keyword-only parameters are passed by
        name so signatures such as ``async def wf(data, *, ctx: RunContext)``
        work; all other parameters are passed positionally. Parameters that
        receive neither value are left to their defaults.
        """
        token = _active_run_ctx.set(ctx)
        try:
            sig = inspect.signature(self._func)

            # Resolve string annotations to actual types
            try:
                hints = typing.get_type_hints(self._func)
            except Exception as exc:
                logger.warning(
                    "Failed to resolve type hints for workflow function '%s': %s. "
                    "RunContext injection may not work if annotations are forward references.",
                    self._func.__name__,
                    exc,
                )
                hints = {}

            call_args: list[Any] = []
            call_kwargs: dict[str, Any] = {}
            message_injected = False

            for param in sig.parameters.values():
                resolved = hints.get(param.name, param.annotation)
                if resolved is RunContext or param.name == "ctx":
                    value: Any = ctx
                elif not message_injected:
                    # First non-ctx param gets the message
                    value = message
                    message_injected = True
                else:
                    continue

                if param.kind is inspect.Parameter.KEYWORD_ONLY:
                    # Passing a keyword-only parameter positionally would raise TypeError.
                    call_kwargs[param.name] = value
                else:
                    call_args.append(value)

            return await self._func(*call_args, **call_kwargs)
        finally:
            _active_run_ctx.reset(token)

    # ------------------------------------------------------------------
    # Checkpoint helpers
    # ------------------------------------------------------------------

    async def _save_checkpoint(
        self,
        ctx: RunContext,
        storage: CheckpointStorage,
        previous_checkpoint_id: str | None = None,
    ) -> str:
        """Save user state, step cache, and the original message; return ``storage.save``'s result."""
        state = dict(ctx._state)
        state["_step_cache"] = ctx._export_step_cache()
        state["_original_message"] = self._last_message

        checkpoint = WorkflowCheckpoint(
            workflow_name=self.name,
            graph_signature_hash=self.graph_signature_hash,
            previous_checkpoint_id=previous_checkpoint_id,
            state=state,
            pending_request_info_events=dict(ctx._pending_requests),
        )
        return await storage.save(checkpoint)

    def _compute_signature_hash(self) -> str:
        """Compute a stable SHA-256 hex digest from the workflow name and sorted step names."""
        sig_data = {
            "workflow": self.name,
            "steps": sorted(self._step_names),
        }
        import json

        canonical = json.dumps(sig_data, sort_keys=True, separators=(",", ":"))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()

    @staticmethod
    def _discover_step_names(func: Callable[..., Any]) -> list[str]:
        """Extract step names referenced by the workflow function.

        Looks up each name in the function's ``__code__.co_names`` in its
        ``__globals__`` and collects the ``name`` of every ``StepWrapper``
        found there. Steps reachable only through closures or attribute
        access are not discovered.
        """
        names: list[str] = []
        globs = getattr(func, "__globals__", {})
        code_names = getattr(getattr(func, "__code__", None), "co_names", ())
        for n in code_names:
            obj = globs.get(n)
            if isinstance(obj, StepWrapper):
                names.append(obj.name)
        return names

    # ------------------------------------------------------------------
    # Finalize / cleanup / validation (mirrors Workflow)
    # ------------------------------------------------------------------

    @staticmethod
    def _finalize_events(
        events: Sequence[WorkflowEvent[Any]],
        *,
        include_status_events: bool = False,
    ) -> WorkflowRunResult:
        """Build the run result: drop ``started`` events and separate out ``status`` events."""
        filtered: list[WorkflowEvent[Any]] = []
        status_events: list[WorkflowEvent[Any]] = []

        for ev in events:
            if ev.type == "started":
                continue
            if ev.type == "status":
                status_events.append(ev)
                if include_status_events:
                    filtered.append(ev)
                continue
            filtered.append(ev)

        return WorkflowRunResult(filtered, status_events)

    @staticmethod
    def _validate_run_params(
        message: Any | None,
        responses: dict[str, Any] | None,
        checkpoint_id: str | None,
    ) -> None:
        """Reject invalid combinations of ``run()`` inputs.

        Raises:
            ValueError: If *message* is combined with *responses* or
                *checkpoint_id*, or if all three are ``None``.
        """
        if message is not None and responses is not None:
            raise ValueError("Cannot provide both 'message' and 'responses'. Use one or the other.")

        if message is not None and checkpoint_id is not None:
            raise ValueError("Cannot provide both 'message' and 'checkpoint_id'. Use one or the other.")

        if message is None and responses is None and checkpoint_id is None:
            raise ValueError(
                "Must provide at least one of: 'message' (new run), 'responses' (send responses), "
                "or 'checkpoint_id' (resume from checkpoint)."
            )

    def _ensure_not_running(self) -> None:
        """Mark the workflow as running, raising ``RuntimeError`` if it already is."""
        if self._is_running:
            raise RuntimeError("Workflow is already running. Concurrent executions are not allowed.")
        self._is_running = True

    async def _run_cleanup(self) -> None:
        """Clear the running flag (registered as a cleanup hook on the response stream)."""
        self._is_running = False
# ---------------------------------------------------------------------------
# @workflow decorator
# ---------------------------------------------------------------------------


@overload
def workflow(func: Callable[..., Awaitable[Any]]) -> FunctionalWorkflow: ...


@overload
def workflow(
    *,
    name: str | None = None,
    description: str | None = None,
    checkpoint_storage: CheckpointStorage | None = None,
) -> Callable[[Callable[..., Awaitable[Any]]], FunctionalWorkflow]: ...


def workflow(
    func: Callable[..., Awaitable[Any]] | None = None,
    *,
    name: str | None = None,
    description: str | None = None,
    checkpoint_storage: CheckpointStorage | None = None,
) -> FunctionalWorkflow | Callable[[Callable[..., Awaitable[Any]]], FunctionalWorkflow]:
    """Turn an async function into a :class:`FunctionalWorkflow`.

    Usable bare (``@workflow``) or with options
    (``@workflow(name="my_wf")``). All options are forwarded unchanged to
    the :class:`FunctionalWorkflow` constructor.

    Args:
        func: The async function to decorate (bare ``@workflow`` form).
            ``None`` when the decorator is called with options.
        name: Display name for the workflow. Defaults to ``func.__name__``.
        description: Optional human-readable description.
        checkpoint_storage: Default :class:`CheckpointStorage` for
            persisting step results and workflow state.

    Returns:
        A :class:`FunctionalWorkflow` (bare form) or a decorator that
        produces one (parameterized form).

    Examples:

        .. code-block:: python

            # Bare form
            @workflow
            async def pipeline(data: str) -> str:
                return await process(data)


            # Parameterized form
            @workflow(name="my_pipeline", checkpoint_storage=storage)
            async def pipeline(data: str) -> str: ...
    """

    def _build(target: Callable[..., Awaitable[Any]]) -> FunctionalWorkflow:
        # Single construction site shared by both decorator forms.
        return FunctionalWorkflow(
            target,
            name=name,
            description=description,
            checkpoint_storage=checkpoint_storage,
        )

    if func is None:
        # Parameterized form: hand back the decorator itself.
        return _build
    return _build(func)
+ name: Display name for the workflow. Defaults to ``func.__name__``. + description: Optional human-readable description. + checkpoint_storage: Default :class:`CheckpointStorage` for + persisting step results and workflow state. + + Returns: + A :class:`FunctionalWorkflow` (bare form) or a decorator that + produces one (parameterized form). + + Examples: + + .. code-block:: python + + # Bare form + @workflow + async def pipeline(data: str) -> str: + return await process(data) + + + # Parameterized form + @workflow(name="my_pipeline", checkpoint_storage=storage) + async def pipeline(data: str) -> str: ... + """ + if func is not None: + return FunctionalWorkflow(func, name=name, description=description, checkpoint_storage=checkpoint_storage) + + def _decorator(fn: Callable[..., Awaitable[Any]]) -> FunctionalWorkflow: + return FunctionalWorkflow(fn, name=name, description=description, checkpoint_storage=checkpoint_storage) + + return _decorator + + +# --------------------------------------------------------------------------- +# FunctionalWorkflowAgent +# --------------------------------------------------------------------------- + + +class FunctionalWorkflowAgent: + """Agent adapter for a :class:`FunctionalWorkflow`. + + Provides a ``run()`` method with the same overloaded signature as + :class:`BaseAgent` — returning an :class:`AgentResponse` (non-streaming) + or a :class:`ResponseStream[AgentResponseUpdate, AgentResponse]` + (streaming), making functional workflows usable anywhere an + agent-compatible object is expected. + + Args: + workflow: The :class:`FunctionalWorkflow` to wrap. + name: Display name for the agent. Defaults to the workflow name. 
+ """ + + def __init__(self, workflow: FunctionalWorkflow, *, name: str | None = None) -> None: + self._workflow = workflow + self.name = name or workflow.name + self.id = f"FunctionalWorkflowAgent_{self.name}" + self.description: str | None = workflow.description + + @overload + def run( + self, + messages: Any | None = None, + *, + stream: Literal[True], + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... + + @overload + def run( + self, + messages: Any | None = None, + *, + stream: Literal[False] = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse]: ... + + def run( + self, + messages: Any | None = None, + *, + stream: bool = False, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse] | Awaitable[AgentResponse]: + """Run the underlying workflow and return the result as an agent response. + + Args: + messages: Input data forwarded to :meth:`FunctionalWorkflow.run`. + + Keyword Args: + stream: If ``True``, return a :class:`ResponseStream` of + :class:`AgentResponseUpdate` items. + **kwargs: Extra keyword arguments forwarded to the workflow run. + + Returns: + An :class:`AgentResponse` (non-streaming) or a + :class:`ResponseStream` (streaming). 
+ """ + if stream: + return self._run_streaming(messages, **kwargs) + return self._run_non_streaming(messages, **kwargs) + + async def _run_non_streaming(self, messages: Any | None, **kwargs: Any) -> AgentResponse: + result = await self._workflow.run(messages, **kwargs) + return self._result_to_agent_response(result) + + def _run_streaming(self, messages: Any | None, **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + from .._types import Content + + agent_name = self.name + workflow_stream = self._workflow.run(messages, stream=True, **kwargs) + + async def _generate_updates() -> AsyncIterable[AgentResponseUpdate]: + async for event in workflow_stream: + if event.type != "output": + continue + data = event.data + if isinstance(data, str): + contents = [Content.from_text(text=data)] + elif isinstance(data, Content): + contents = [data] + else: + contents = [Content.from_text(text=str(data))] + yield AgentResponseUpdate( + contents=contents, + role="assistant", + author_name=agent_name, + ) + + return ResponseStream( + _generate_updates(), + finalizer=AgentResponse.from_updates, + ) + + @staticmethod + def _result_to_agent_response(result: WorkflowRunResult) -> AgentResponse: + from .._types import Content + from .._types import Message as Msg + + messages: list[Msg] = [] + for output in result.get_outputs(): + if isinstance(output, str): + contents = [Content.from_text(text=output)] + elif isinstance(output, Content): + contents = [output] + else: + contents = [Content.from_text(text=str(output))] + messages.append(Msg("assistant", contents)) + return AgentResponse(messages=messages) diff --git a/python/packages/core/tests/workflow/test_functional_workflow.py b/python/packages/core/tests/workflow/test_functional_workflow.py new file mode 100644 index 0000000000..7325fcb745 --- /dev/null +++ b/python/packages/core/tests/workflow/test_functional_workflow.py @@ -0,0 +1,1045 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for the functional workflow API (@workflow, @step, RunContext).""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass + +import pytest + +from agent_framework import ( + FunctionalWorkflow, + InMemoryCheckpointStorage, + RunContext, + StepWrapper, + WorkflowRunResult, + WorkflowRunState, + step, + workflow, +) +from agent_framework._workflows._functional import ( + RunContext as _RunContext, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@step +async def add_one(x: int) -> int: + return x + 1 + + +@step +async def double(x: int) -> int: + return x * 2 + + +@step +async def to_upper(s: str) -> str: + return s.upper() + + +@step(name="custom_name") +async def named_step(x: int) -> int: + return x + 10 + + +@step +async def failing_step(x: int) -> int: + raise ValueError(f"step failed with {x}") + + +# --------------------------------------------------------------------------- +# Basic execution +# --------------------------------------------------------------------------- + + +class TestBasicExecution: + async def test_simple_sequential_pipeline(self): + @workflow + async def pipeline(x: int) -> int: + a = await add_one(x) + return await double(a) + + result = await pipeline.run(5) + assert isinstance(result, WorkflowRunResult) + outputs = result.get_outputs() + assert outputs == [12] # (5+1)*2 + + async def test_workflow_with_string_data(self): + @workflow + async def upper_pipeline(text: str) -> str: + return await to_upper(text) + + result = await upper_pipeline.run("hello") + assert result.get_outputs() == ["HELLO"] + + async def test_workflow_returns_result(self): + @workflow + async def simple(x: int) -> int: + return await add_one(x) + + result = await simple.run(10) + assert result.get_outputs() == [11] + + async def test_workflow_name_defaults_to_function_name(self): + 
@workflow + async def my_pipeline(x: int) -> int: + return x + + assert my_pipeline.name == "my_pipeline" + + async def test_workflow_custom_name(self): + @workflow(name="custom_wf", description="A test workflow") + async def wf(x: int) -> int: + return x + + assert wf.name == "custom_wf" + assert wf.description == "A test workflow" + + +# --------------------------------------------------------------------------- +# Event emission +# --------------------------------------------------------------------------- + + +class TestEventEmission: + async def test_step_events_emitted(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + result = await pipeline.run(5) + event_types = [e.type for e in result] + assert "executor_invoked" in event_types + assert "executor_completed" in event_types + assert "output" in event_types + + async def test_step_events_carry_executor_id(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + result = await pipeline.run(5) + invoked_events = [e for e in result if e.type == "executor_invoked"] + assert len(invoked_events) == 1 + assert invoked_events[0].executor_id == "add_one" + + completed_events = [e for e in result if e.type == "executor_completed"] + assert len(completed_events) == 1 + assert completed_events[0].executor_id == "add_one" + assert completed_events[0].data == 6 + + async def test_status_events_in_timeline(self): + @workflow + async def pipeline(x: int) -> int: + return x + + result = await pipeline.run(1) + states = [e.state for e in result.status_timeline()] + assert WorkflowRunState.IN_PROGRESS in states + assert WorkflowRunState.IDLE in states + + async def test_final_state_is_idle(self): + @workflow + async def pipeline(x: int) -> int: + return x + + result = await pipeline.run(1) + assert result.get_final_state() == WorkflowRunState.IDLE + + async def test_custom_event(self): + from agent_framework import WorkflowEvent + + @workflow + async def 
pipeline(x: int, ctx: RunContext) -> int: + await ctx.add_event(WorkflowEvent.emit("pipeline", "custom_data")) + return x + + result = await pipeline.run(1) + data_events = [e for e in result if e.type == "data"] + assert len(data_events) == 1 + assert data_events[0].data == "custom_data" + + +# --------------------------------------------------------------------------- +# Parallel execution +# --------------------------------------------------------------------------- + + +class TestParallelExecution: + async def test_parallel_tasks_with_gather(self): + @step + async def slow_add(x: int) -> int: + await asyncio.sleep(0.01) + return x + 1 + + @step + async def slow_double(x: int) -> int: + await asyncio.sleep(0.01) + return x * 2 + + @workflow + async def parallel_wf(x: int) -> list[int]: + a, b = await asyncio.gather(slow_add(x), slow_double(x)) + return [a, b] + + result = await parallel_wf.run(5) + outputs = result.get_outputs() + assert outputs == [[6, 10]] + + async def test_parallel_events_all_emitted(self): + @step + async def task_a(x: int) -> int: + return x + 1 + + @step + async def task_b(x: int) -> int: + return x * 2 + + @workflow + async def par_wf(x: int) -> tuple[int, int]: + a, b = await asyncio.gather(task_a(x), task_b(x)) + return (a, b) + + result = await par_wf.run(3) + invoked = [e for e in result if e.type == "executor_invoked"] + completed = [e for e in result if e.type == "executor_completed"] + assert len(invoked) == 2 + assert len(completed) == 2 + + +# --------------------------------------------------------------------------- +# HITL (request_info / resume) +# --------------------------------------------------------------------------- + + +class TestHITL: + async def test_request_info_interrupts(self): + @workflow + async def review_wf(doc: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Final: {feedback}" + + # Phase 1: should interrupt with pending 
request + result = await review_wf.run("my doc") + assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + request_events = result.get_request_info_events() + assert len(request_events) == 1 + assert request_events[0].request_id == "req1" + + async def test_request_info_resume(self): + @workflow + async def review_wf(doc: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Final: {feedback}" + + # Phase 1 + result1 = await review_wf.run("my doc") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Phase 2: resume with response + result2 = await review_wf.run(responses={"req1": "Looks great!"}) + outputs = result2.get_outputs() + assert outputs == ["Final: Looks great!"] + assert result2.get_final_state() == WorkflowRunState.IDLE + + async def test_untyped_ctx_parameter(self): + """ctx is injected by parameter name even without a RunContext annotation.""" + + @workflow + async def review_wf(doc: str, ctx) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Final: {feedback}" + + result1 = await review_wf.run("my doc") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + result2 = await review_wf.run(responses={"req1": "LGTM"}) + assert result2.get_outputs() == ["Final: LGTM"] + + async def test_multiple_sequential_interrupts(self): + @workflow + async def multi_hitl(data: str, ctx: RunContext) -> str: + r1 = await ctx.request_info("step1", response_type=str, request_id="r1") + r2 = await ctx.request_info("step2", response_type=str, request_id="r2") + return f"{r1}+{r2}" + + # Phase 1: first interrupt + result1 = await multi_hitl.run("start") + assert len(result1.get_request_info_events()) == 1 + assert result1.get_request_info_events()[0].request_id == "r1" + + # Phase 2: respond to first, hits second + result2 = await 
multi_hitl.run(responses={"r1": "A"}) + assert len(result2.get_request_info_events()) == 1 + assert result2.get_request_info_events()[0].request_id == "r2" + + # Phase 3: respond to second + result3 = await multi_hitl.run(responses={"r1": "A", "r2": "B"}) + assert result3.get_outputs() == ["A+B"] + + async def test_request_info_auto_generates_id(self): + @workflow + async def auto_id_wf(x: int, ctx: RunContext) -> None: + await ctx.request_info("need data", response_type=str) + + result = await auto_id_wf.run(1) + events = result.get_request_info_events() + assert len(events) == 1 + assert events[0].request_id # should be a non-empty uuid string + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + async def test_step_failure_propagates(self): + @workflow + async def failing_wf(x: int) -> None: + await failing_step(x) + + with pytest.raises(ValueError, match="step failed with 42"): + await failing_wf.run(42) + + async def test_step_failure_emits_executor_failed(self): + @workflow + async def failing_wf(x: int) -> None: + await failing_step(x) + + # Use stream to collect events before the raise + stream = failing_wf.run(42, stream=True) + events: list = [] + with pytest.raises(ValueError): + async for event in stream: + events.append(event) + + failed_events = [e for e in events if e.type == "executor_failed"] + assert len(failed_events) == 1 + assert failed_events[0].executor_id == "failing_step" + + async def test_workflow_failure_emits_failed_status(self): + @workflow + async def bad_wf(x: int) -> None: + raise RuntimeError("workflow broke") + + stream = bad_wf.run(42, stream=True) + events: list = [] + with pytest.raises(RuntimeError, match="workflow broke"): + async for event in stream: + events.append(event) + + failed_events = [e for e in events if e.type == "failed"] + assert len(failed_events) == 1 + 
status_events = [e for e in events if e.type == "status"] + assert any(e.state == WorkflowRunState.FAILED for e in status_events) + + async def test_invalid_params_message_and_responses(self): + @workflow + async def wf(x: int) -> None: + pass + + with pytest.raises(ValueError, match="Cannot provide both"): + await wf.run("hello", responses={"r1": "val"}) + + async def test_invalid_params_message_and_checkpoint(self): + @workflow + async def wf(x: int) -> None: + pass + + with pytest.raises(ValueError, match="Cannot provide both"): + await wf.run("hello", checkpoint_id="abc") + + async def test_invalid_params_nothing(self): + @workflow + async def wf(x: int) -> None: + pass + + with pytest.raises(ValueError, match="Must provide at least one"): + await wf.run() + + +# --------------------------------------------------------------------------- +# Streaming +# --------------------------------------------------------------------------- + + +class TestStreaming: + async def test_streaming_yields_events(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + stream = pipeline.run(5, stream=True) + events = [] + async for event in stream: + events.append(event) + + event_types = [e.type for e in events] + assert "started" in event_types + assert "executor_invoked" in event_types + assert "executor_completed" in event_types + assert "output" in event_types + + async def test_streaming_final_response(self): + @workflow + async def pipeline(x: int) -> int: + return await add_one(x) + + stream = pipeline.run(5, stream=True) + final = await stream.get_final_response() + assert isinstance(final, WorkflowRunResult) + assert final.get_outputs() == [6] + + async def test_streaming_context_reports_streaming(self): + streaming_flag = None + + @workflow + async def wf(x: int, ctx: RunContext) -> int: + nonlocal streaming_flag + streaming_flag = ctx.is_streaming() + return x + + stream = wf.run(1, stream=True) + await stream.get_final_response() + assert 
streaming_flag is True + + streaming_flag = None + await wf.run(1) + assert streaming_flag is False + + +# --------------------------------------------------------------------------- +# Step passthrough outside workflow +# --------------------------------------------------------------------------- + + +class TestStepPassthrough: + async def test_step_works_outside_workflow(self): + result = await add_one(10) + assert result == 11 + + async def test_named_step_outside_workflow(self): + result = await named_step(5) + assert result == 15 + + def test_step_wrapper_name(self): + assert add_one.name == "add_one" + assert named_step.name == "custom_name" + + def test_step_wrapper_is_step_wrapper(self): + assert isinstance(add_one, StepWrapper) + assert isinstance(named_step, StepWrapper) + + +# --------------------------------------------------------------------------- +# State management +# --------------------------------------------------------------------------- + + +class TestStateManagement: + async def test_get_set_state(self): + @workflow + async def stateful_wf(x: int, ctx: RunContext) -> int: + ctx.set_state("counter", x) + return ctx.get_state("counter") + + result = await stateful_wf.run(42) + assert result.get_outputs() == [42] + + async def test_get_state_default(self): + @workflow + async def wf(x: int, ctx: RunContext) -> str: + return ctx.get_state("missing", "default_val") + + result = await wf.run(1) + assert result.get_outputs() == ["default_val"] + + +# --------------------------------------------------------------------------- +# Checkpointing +# --------------------------------------------------------------------------- + + +class TestCheckpointing: + async def test_checkpoint_save_and_restore(self): + storage = InMemoryCheckpointStorage() + + @step + async def expensive(x: int) -> int: + return x * 100 + + @workflow(checkpoint_storage=storage) + async def ckpt_wf(x: int) -> int: + return await expensive(x) + + result = await ckpt_wf.run(5) + assert 
result.get_outputs() == [500] + + # Verify checkpoints were saved: 1 per-step + 1 final + checkpoints = await storage.list_checkpoints(workflow_name="ckpt_wf") + assert len(checkpoints) == 2 + + async def test_checkpoint_runtime_storage_override(self): + storage = InMemoryCheckpointStorage() + + @step + async def compute(x: int) -> int: + return x + 1 + + @workflow + async def wf(x: int) -> int: + return await compute(x) + + result = await wf.run(10, checkpoint_storage=storage) + assert result.get_outputs() == [11] + # 1 per-step checkpoint + 1 final checkpoint + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert len(checkpoints) == 2 + + async def test_checkpoint_restore_replays_cached_tasks(self): + storage = InMemoryCheckpointStorage() + call_count = 0 + + @step + async def counting_task(x: int) -> int: + nonlocal call_count + call_count += 1 + return x + 1 + + @workflow(checkpoint_storage=storage) + async def wf(x: int) -> int: + return await counting_task(x) + + # First run + result1 = await wf.run(5) + assert result1.get_outputs() == [6] + assert call_count == 1 + + # Get checkpoint ID + checkpoints = await storage.list_checkpoints(workflow_name="wf") + ckpt_id = checkpoints[0].checkpoint_id + + # Restore — step should replay from cache + result2 = await wf.run(checkpoint_id=ckpt_id) + assert result2.get_outputs() == [6] + assert call_count == 1 # not called again + + async def test_checkpoint_hitl_resume(self): + storage = InMemoryCheckpointStorage() + + @workflow(checkpoint_storage=storage) + async def hitl_wf(doc: str, ctx: RunContext) -> str: + feedback = await ctx.request_info({"draft": doc}, response_type=str, request_id="req1") + return f"Done: {feedback}" + + # Phase 1: interrupt + result1 = await hitl_wf.run("draft text") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Get checkpoint + checkpoints = await storage.list_checkpoints(workflow_name="hitl_wf") + ckpt_id = 
checkpoints[0].checkpoint_id + + # Phase 2: restore and respond + result2 = await hitl_wf.run(checkpoint_id=ckpt_id, responses={"req1": "Approved!"}) + assert result2.get_outputs() == ["Done: Approved!"] + + async def test_checkpoint_without_storage_raises(self): + @workflow + async def wf(x: int) -> int: + return x + + with pytest.raises(ValueError, match="checkpoint_storage"): + await wf.run(checkpoint_id="nonexistent") + + async def test_checkpoint_preserves_state(self): + storage = InMemoryCheckpointStorage() + + @workflow(checkpoint_storage=storage) + async def stateful_wf(x: int, ctx: RunContext) -> str: + ctx.set_state("key", "value") + feedback = await ctx.request_info("need info", response_type=str, request_id="r1") + val = ctx.get_state("key") + return f"{val}:{feedback}" + + # Phase 1 + result1 = await stateful_wf.run(1) + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # Phase 2: restore and respond + checkpoints = await storage.list_checkpoints(workflow_name="stateful_wf") + ckpt_id = checkpoints[0].checkpoint_id + + result2 = await stateful_wf.run(checkpoint_id=ckpt_id, responses={"r1": "hello"}) + assert result2.get_outputs() == ["value:hello"] + + async def test_per_step_checkpoint_enables_crash_recovery(self): + """Simulates crash recovery: step 1 completes and is checkpointed, + then the workflow crashes in step 2. 
Restoring from the per-step + checkpoint should replay step 1 from cache without re-executing it.""" + storage = InMemoryCheckpointStorage() + step1_calls = 0 + step2_calls = 0 + + @step + async def slow_step1(x: int) -> int: + nonlocal step1_calls + step1_calls += 1 + return x + 10 + + @step + async def crashing_step2(x: int) -> int: + nonlocal step2_calls + step2_calls += 1 + if step2_calls == 1: + raise RuntimeError("simulated crash") + return x * 2 + + @workflow(checkpoint_storage=storage) + async def crash_wf(x: int) -> int: + a = await slow_step1(x) + return await crashing_step2(a) + + # First run: step1 succeeds and checkpoints, step2 crashes + with pytest.raises(RuntimeError, match="simulated crash"): + await crash_wf.run(5) + + assert step1_calls == 1 + assert step2_calls == 1 + + # A per-step checkpoint was saved after step1 completed + checkpoints = await storage.list_checkpoints(workflow_name="crash_wf") + assert len(checkpoints) >= 1 + ckpt_id = checkpoints[0].checkpoint_id + + # Restore from checkpoint: step1 replays from cache, step2 runs fresh + result = await crash_wf.run(checkpoint_id=ckpt_id) + assert result.get_outputs() == [30] # (5+10)*2 + assert step1_calls == 1 # NOT called again — replayed from cache + assert step2_calls == 2 # called again, succeeds this time + + async def test_per_step_checkpoint_chain(self): + """Each step creates a new checkpoint chained to the previous one.""" + storage = InMemoryCheckpointStorage() + + @step + async def s1(x: int) -> int: + return x + 1 + + @step + async def s2(x: int) -> int: + return x + 2 + + @step + async def s3(x: int) -> int: + return x + 3 + + @workflow(checkpoint_storage=storage) + async def multi_step_wf(x: int) -> int: + a = await s1(x) + b = await s2(a) + return await s3(b) + + result = await multi_step_wf.run(0) + assert result.get_outputs() == [6] # 0+1+2+3 + + # 3 per-step checkpoints + 1 final = 4 + checkpoints = await storage.list_checkpoints(workflow_name="multi_step_wf") + assert 
len(checkpoints) == 4 + + async def test_no_checkpoint_on_cache_hit(self): + """During replay, cached steps should NOT create additional checkpoints.""" + storage = InMemoryCheckpointStorage() + + @step + async def compute(x: int) -> int: + return x + 1 + + @workflow(checkpoint_storage=storage) + async def wf(x: int) -> int: + return await compute(x) + + # First run: 1 per-step + 1 final = 2 checkpoints + await wf.run(5) + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert len(checkpoints) == 2 + ckpt_id = checkpoints[0].checkpoint_id + + # Restore: step replays from cache (no new per-step checkpoint), + # but final checkpoint still saved = 1 new checkpoint + await wf.run(checkpoint_id=ckpt_id) + checkpoints = await storage.list_checkpoints(workflow_name="wf") + assert len(checkpoints) == 3 # 2 from first run + 1 final from restore + + +# --------------------------------------------------------------------------- +# Branching / control flow +# --------------------------------------------------------------------------- + + +class TestControlFlow: + async def test_if_else_branching(self): + @dataclass + class Classification: + is_spam: bool + + @step + async def classify(text: str) -> Classification: + return Classification(is_spam="spam" in text.lower()) + + @step + async def process_normal(text: str) -> str: + return f"processed: {text}" + + @step + async def quarantine(text: str) -> str: + return f"quarantined: {text}" + + @workflow + async def email_pipeline(email: str) -> str: + cl = await classify(email) + if cl.is_spam: + result = await quarantine(email) + else: + result = await process_normal(email) + return result + + result_spam = await email_pipeline.run("Buy spam now!") + assert result_spam.get_outputs() == ["quarantined: Buy spam now!"] + + result_normal = await email_pipeline.run("Hello friend") + assert result_normal.get_outputs() == ["processed: Hello friend"] + + +# 
--------------------------------------------------------------------------- +# Nested workflow calls +# --------------------------------------------------------------------------- + + +class TestNestedWorkflows: + async def test_nested_workflow_as_task(self): + @step + async def step_a(x: int) -> int: + return x + 1 + + @workflow + async def inner_wf(x: int) -> int: + return await step_a(x) + + @step + async def call_inner(x: int) -> int: + result = await inner_wf.run(x) + return result.get_outputs()[0] + + @workflow + async def outer_wf(x: int) -> int: + return await call_inner(x) + + result = await outer_wf.run(5) + assert result.get_outputs() == [6] + + +# --------------------------------------------------------------------------- +# as_agent() +# --------------------------------------------------------------------------- + + +class TestAsAgent: + async def test_as_agent_returns_agent(self): + @workflow + async def wf(x: int) -> str: + return f"result: {x}" + + agent = wf.as_agent() + assert agent.name == "wf" + + async def test_as_agent_custom_name(self): + @workflow + async def wf(x: int) -> int: + return x + + agent = wf.as_agent(name="my_agent") + assert agent.name == "my_agent" + + async def test_as_agent_run(self): + @workflow + async def wf(x: int) -> int: + return await add_one(x) + + agent = wf.as_agent() + response = await agent.run(10) + assert response.text == "11" + + async def test_as_agent_run_streaming(self): + @workflow + async def wf(x: int) -> str: + return f"result: {x}" + + agent = wf.as_agent() + stream = agent.run(10, stream=True) + updates = [] + async for update in stream: + updates.append(update) + assert len(updates) == 1 + assert updates[0].text == "result: 10" + + response = await stream.get_final_response() + assert len(response.messages) >= 1 + + async def test_as_agent_has_id_and_description(self): + @workflow(description="A test workflow") + async def wf(x: int) -> int: + return x + + agent = wf.as_agent(name="my_agent") + assert 
agent.id == "FunctionalWorkflowAgent_my_agent" + assert agent.description == "A test workflow" + + +# --------------------------------------------------------------------------- +# Concurrent execution guard +# --------------------------------------------------------------------------- + + +class TestConcurrencyGuard: + async def test_concurrent_run_raises(self): + @workflow + async def slow_wf(x: int) -> int: + await asyncio.sleep(0.1) + return x + + # Start first run + stream = slow_wf.run(1, stream=True) + + # Try to start second run while first is active + with pytest.raises(RuntimeError, match="already running"): + slow_wf.run(2, stream=True) + + # Consume the stream to clean up + await stream.get_final_response() + + async def test_run_after_completion(self): + @workflow + async def wf(x: int) -> int: + return x + + result1 = await wf.run(1) + assert result1.get_outputs() == [1] + + # Should be able to run again after first completes + result2 = await wf.run(2) + assert result2.get_outputs() == [2] + + +# --------------------------------------------------------------------------- +# Decorator forms +# --------------------------------------------------------------------------- + + +class TestDecoratorForms: + def test_step_bare_decorator(self): + @step + async def my_step(x: int) -> int: + return x + + assert isinstance(my_step, StepWrapper) + assert my_step.name == "my_step" + + def test_step_with_name(self): + @step(name="renamed") + async def my_step(x: int) -> int: + return x + + assert isinstance(my_step, StepWrapper) + assert my_step.name == "renamed" + + def test_workflow_bare_decorator(self): + @workflow + async def my_wf(x: int) -> None: + pass + + assert isinstance(my_wf, FunctionalWorkflow) + assert my_wf.name == "my_wf" + + def test_workflow_with_params(self): + @workflow(name="custom", description="desc") + async def my_wf(x: int) -> None: + pass + + assert isinstance(my_wf, FunctionalWorkflow) + assert my_wf.name == "custom" + assert 
my_wf.description == "desc" + + +# --------------------------------------------------------------------------- +# include_status_events +# --------------------------------------------------------------------------- + + +class TestIncludeStatusEvents: + async def test_status_events_excluded_by_default(self): + @workflow + async def wf(x: int) -> int: + return x + + result = await wf.run(1) + status_in_list = [e for e in result if e.type == "status"] + assert len(status_in_list) == 0 + + async def test_status_events_included_when_requested(self): + @workflow + async def wf(x: int) -> int: + return x + + result = await wf.run(1, include_status_events=True) + status_in_list = [e for e in result if e.type == "status"] + assert len(status_in_list) > 0 + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + async def test_workflow_with_no_tasks(self): + @workflow + async def no_tasks(x: int) -> int: + return x * 2 + + result = await no_tasks.run(5) + assert result.get_outputs() == [10] + + async def test_workflow_with_no_output(self): + @workflow + async def silent_wf(x: int) -> None: + pass # returns None — no output emitted + + result = await silent_wf.run(5) + assert result.get_outputs() == [] + + async def test_return_value_auto_yields_output(self): + """Returning a non-None value automatically emits it as an output.""" + + @workflow + async def wf(x: int) -> int: + return x * 3 + + result = await wf.run(5) + assert result.get_outputs() == [15] + + async def test_step_called_multiple_times(self): + @workflow + async def wf(x: int) -> int: + a = await add_one(x) + b = await add_one(a) + return await add_one(b) + + result = await wf.run(0) + assert result.get_outputs() == [3] # 0+1+1+1 + + # Should have 3 invoked and 3 completed events for add_one + invoked = [e for e in result if e.type == "executor_invoked"] + completed = 
[e for e in result if e.type == "executor_completed"] + assert len(invoked) == 3 + assert len(completed) == 3 + + +# --------------------------------------------------------------------------- +# Recovery after errors +# --------------------------------------------------------------------------- + + +class TestRecoveryAfterErrors: + async def test_run_after_failure_is_allowed(self): + @workflow + async def wf(x: int) -> int: + if x == 1: + raise RuntimeError("boom") + return x + + with pytest.raises(RuntimeError, match="boom"): + await wf.run(1) + + # Must be able to run again after the failure + result = await wf.run(2) + assert result.get_outputs() == [2] + + async def test_step_sync_function_raises(self): + with pytest.raises(TypeError, match="async functions"): + + @step + def not_async(x: int) -> int: + return x + + +# --------------------------------------------------------------------------- +# WorkflowInterrupted is BaseException +# --------------------------------------------------------------------------- + + +class TestWorkflowInterruptedIsBaseException: + async def test_except_exception_does_not_catch_interrupt(self): + """User code with ``except Exception`` should not catch WorkflowInterrupted.""" + caught = False + + @workflow + async def wf(x: int, ctx: RunContext) -> str: + nonlocal caught + try: + return await ctx.request_info("need review", response_type=str, request_id="r1") + except Exception: + # This should NOT catch WorkflowInterrupted + caught = True + return "caught!" + + result = await wf.run("data") + # Should have a pending request, NOT "caught!" 
+ assert result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + assert result.get_outputs() == [] + assert caught is False + + +# --------------------------------------------------------------------------- +# Checkpoint validation +# --------------------------------------------------------------------------- + + +class TestCheckpointValidation: + async def test_checkpoint_signature_mismatch_raises(self): + from agent_framework import WorkflowCheckpoint + + storage = InMemoryCheckpointStorage() + + @workflow(name="my_wf", checkpoint_storage=storage) + async def wf(x: int) -> int: + return x + + # Manually create a checkpoint with a different signature hash + bad_checkpoint = WorkflowCheckpoint( + workflow_name="my_wf", + graph_signature_hash="totally_different_hash", + state={"_step_cache": {}, "_original_message": 1}, + ) + ckpt_id = await storage.save(bad_checkpoint) + + # Should fail due to hash mismatch + with pytest.raises(ValueError, match="not compatible"): + await wf.run(checkpoint_id=ckpt_id) + + async def test_import_step_cache_malformed_key(self): + ctx = _RunContext("test") + with pytest.raises(ValueError, match="Corrupted step cache"): + ctx._import_step_cache({"invalid_key_no_separator": 42}) + + async def test_import_step_cache_non_integer_index(self): + ctx = _RunContext("test") + with pytest.raises(ValueError, match="Corrupted step cache"): + ctx._import_step_cache({"step_name::abc": 42}) diff --git a/python/samples/01-get-started/05_first_functional_workflow.py b/python/samples/01-get-started/05_first_functional_workflow.py new file mode 100644 index 0000000000..6fcd9dc0fe --- /dev/null +++ b/python/samples/01-get-started/05_first_functional_workflow.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +First Functional Workflow — Orchestrate async functions with @workflow + +The functional API lets you write workflows as plain Python async functions. 
+No graph concepts, no edges, no executor classes — just call functions +and use native control flow (if/else, loops, asyncio.gather). + +This sample builds a minimal pipeline with two steps: +1. Convert text to uppercase +2. Reverse the text + +No external services are required. +""" + +import asyncio + +from agent_framework import workflow + + +# Plain async functions — no decorators needed +async def to_upper_case(text: str) -> str: + """Convert input to uppercase.""" + return text.upper() + + +async def reverse_text(text: str) -> str: + """Reverse the string.""" + return text[::-1] + + +# +@workflow +async def text_pipeline(text: str) -> str: + """Uppercase the text, then reverse it.""" + upper = await to_upper_case(text) + return await reverse_text(upper) +# + + +async def main() -> None: + # + result = await text_pipeline.run("hello world") + print(f"Output: {result.get_outputs()}") + print(f"Final state: {result.get_final_state()}") + # + + """ + Expected output: + Output: ['DLROW OLLEH'] + Final state: WorkflowRunState.IDLE + """ + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/01-get-started/06_functional_workflow_with_agents.py b/python/samples/01-get-started/06_functional_workflow_with_agents.py new file mode 100644 index 0000000000..646689fde2 --- /dev/null +++ b/python/samples/01-get-started/06_functional_workflow_with_agents.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Functional Workflow with Agents — Call agents inside @workflow + +This sample shows how to call agents inside a functional workflow. +Agent calls are just regular async function calls — no special wrappers needed. + +Environment variables: + AZURE_OPENAI_ENDPOINT — Your Azure OpenAI endpoint + AZURE_OPENAI_API_VERSION — API version (e.g. 2025-04-01-preview) + AZURE_OPENAI_CHAT_DEPLOYMENT_NAME — Model deployment name (e.g. 
gpt-4o) +""" + +import asyncio + +from agent_framework import Agent, workflow +from agent_framework.azure import AzureOpenAIChatClient +from dotenv import load_dotenv + +load_dotenv() + +writer = Agent( + name="WriterAgent", + instructions="Write a short poem (4 lines max) about the given topic.", + client=AzureOpenAIChatClient(), +) + +reviewer = Agent( + name="ReviewerAgent", + instructions="Review the given poem in one sentence. Is it good?", + client=AzureOpenAIChatClient(), +) + + +@workflow +async def poem_pipeline(topic: str) -> str: + """Write a poem, then review it.""" + poem = (await writer.run(f"Write a poem about: {topic}")).text + review = (await reviewer.run(f"Review this poem: {poem}")).text + return f"Poem:\n{poem}\n\nReview: {review}" + + +async def main() -> None: + result = await poem_pipeline.run("a cat learning to code") + print(result.get_outputs()[0]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/01-get-started/05_first_workflow.py b/python/samples/01-get-started/07_first_graph_workflow.py similarity index 87% rename from python/samples/01-get-started/05_first_workflow.py rename to python/samples/01-get-started/07_first_graph_workflow.py index 89b4f608b2..8eb44e37b0 100644 --- a/python/samples/01-get-started/05_first_workflow.py +++ b/python/samples/01-get-started/07_first_graph_workflow.py @@ -12,9 +12,12 @@ from typing_extensions import Never """ -First Workflow — Chain executors with edges +First Graph Workflow — Chain executors with edges -This sample builds a minimal workflow with two steps: +The graph API gives you full control over execution topology: edges, +fan-out/fan-in, switch/case, and superstep-based checkpointing. + +This sample builds a minimal graph workflow with two steps: 1. Convert text to uppercase (class-based executor) 2. 
Reverse the text (function-based executor) diff --git a/python/samples/01-get-started/06_host_your_agent.py b/python/samples/01-get-started/08_host_your_agent.py similarity index 100% rename from python/samples/01-get-started/06_host_your_agent.py rename to python/samples/01-get-started/08_host_your_agent.py diff --git a/python/samples/01-get-started/README.md b/python/samples/01-get-started/README.md index 5ba119e016..1d0b41d9c7 100644 --- a/python/samples/01-get-started/README.md +++ b/python/samples/01-get-started/README.md @@ -24,8 +24,10 @@ export AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME="gpt-4o" # optional, defaults to | 2 | [02_add_tools.py](02_add_tools.py) | Define a function tool with `@tool` and attach it to an agent. | | 3 | [03_multi_turn.py](03_multi_turn.py) | Keep conversation history across turns with `AgentThread`. | | 4 | [04_memory.py](04_memory.py) | Add dynamic context with a custom `ContextProvider`. | -| 5 | [05_first_workflow.py](05_first_workflow.py) | Chain executors into a workflow with edges. | -| 6 | [06_host_your_agent.py](06_host_your_agent.py) | Host a single agent with Azure Functions. | +| 5 | [05_first_functional_workflow.py](05_first_functional_workflow.py) | Write a workflow as a plain async function. | +| 6 | [06_functional_workflow_with_agents.py](06_functional_workflow_with_agents.py) | Call agents inside a functional workflow. | +| 7 | [07_first_graph_workflow.py](07_first_graph_workflow.py) | Chain executors into a graph workflow with edges. | +| 8 | [08_host_your_agent.py](08_host_your_agent.py) | Host a single agent with Azure Functions. | Run any sample with: diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index c5abe202ac..8d7874f78d 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -30,6 +30,19 @@ Once comfortable with these, explore the rest of the samples below. 
## Samples Overview (by directory) +### functional + +Write workflows as plain Python async functions — no graph concepts, no executor classes, no edges. Use native control flow (`if`/`else`, loops, `asyncio.gather`) for branching and parallelism. + +| Sample | File | Concepts | +|---|---|---| +| Basic Pipeline | [functional/basic_pipeline.py](./functional/basic_pipeline.py) | Sequential steps as plain async functions | +| Basic Streaming Pipeline | [functional/basic_streaming_pipeline.py](./functional/basic_streaming_pipeline.py) | Stream workflow events in real time with `run(stream=True)` | +| Parallel Pipeline | [functional/parallel_pipeline.py](./functional/parallel_pipeline.py) | Fan-out/fan-in with `asyncio.gather` | +| Steps and Checkpointing | [functional/steps_and_checkpointing.py](./functional/steps_and_checkpointing.py) | `@step` decorator for per-step checkpointing and observability | +| Human-in-the-Loop Review | [functional/hitl_review.py](./functional/hitl_review.py) | HITL with `ctx.request_info()` and replay | +| Agent Integration | [functional/agent_integration.py](./functional/agent_integration.py) | Calling agents inside workflow steps | + ### agents | Sample | File | Concepts | diff --git a/python/samples/03-workflows/functional/agent_integration.py b/python/samples/03-workflows/functional/agent_integration.py new file mode 100644 index 0000000000..317c332219 --- /dev/null +++ b/python/samples/03-workflows/functional/agent_integration.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Calling agents inside functional workflows. + +Agent calls work inside @workflow as plain function calls — no decorator needed. +Just call the agent and use the result. + +If you want per-step caching (so agent calls don't re-execute on HITL resume +or crash recovery), add @step. Since each agent call hits an LLM API (time + +money), @step is often worth it. But it's always opt-in. 
+ +This sample also demonstrates .as_agent() to wrap a workflow as an agent. + +Environment variables: + AZURE_OPENAI_ENDPOINT — Your Azure OpenAI endpoint + AZURE_OPENAI_API_VERSION — API version (e.g. 2025-04-01-preview) + AZURE_OPENAI_CHAT_DEPLOYMENT_NAME — Model deployment name (e.g. gpt-4o) +""" + +import asyncio + +from agent_framework import Agent, step, workflow +from agent_framework.azure import AzureOpenAIChatClient +from dotenv import load_dotenv + +load_dotenv() + +# --------------------------------------------------------------------------- +# Create agents +# --------------------------------------------------------------------------- + +classifier_agent = Agent( + name="ClassifierAgent", + instructions=( + "Classify documents into one category: Technical, Legal, Marketing, or Scientific. " + "Reply with only the category name." + ), + client=AzureOpenAIChatClient(), +) + +writer_agent = Agent( + name="WriterAgent", + instructions="Summarize the given content in one sentence.", + client=AzureOpenAIChatClient(), +) + +reviewer_agent = Agent( + name="ReviewerAgent", + instructions="Review the given summary in one sentence. Is it accurate and complete?", + client=AzureOpenAIChatClient(), +) + +# --------------------------------------------------------------------------- +# Simplest approach: call agents directly inside the workflow. +# No @step, no wrappers — just plain function calls. 
+# --------------------------------------------------------------------------- + + +@workflow +async def simple_pipeline(document: str) -> str: + """Process a document — agents called inline, no @step.""" + classification = (await classifier_agent.run(f"Classify this document: {document}")).text + summary = (await writer_agent.run(f"Summarize: {document}")).text + review = (await reviewer_agent.run(f"Review this summary: {summary}")).text + + return f"Classification: {classification}\nSummary: {summary}\nReview: {review}" + + +# --------------------------------------------------------------------------- +# With @step: agent results are cached. On HITL resume or checkpoint +# recovery, completed steps return their saved result instead of calling +# the LLM again. Worth it for expensive operations. +# --------------------------------------------------------------------------- + + +@step +async def classify_document(doc: str) -> str: + return (await classifier_agent.run(f"Classify this document: {doc}")).text + + +@step +async def generate_summary(doc: str) -> str: + return (await writer_agent.run(f"Summarize: {doc}")).text + + +@step +async def review_summary(summary: str) -> str: + return (await reviewer_agent.run(f"Review this summary: {summary}")).text + + +@workflow +async def cached_pipeline(document: str) -> str: + """Same pipeline, but @step caches each agent call.""" + classification = await classify_document(document) + summary = await generate_summary(document) + review = await review_summary(summary) + + return f"Classification: {classification}\nSummary: {summary}\nReview: {review}" + + +async def main(): + # Simple version — agents called inline + result = await simple_pipeline.run("This is a technical document about machine learning...") + print(result.get_outputs()[0]) + + # .as_agent() wraps the workflow so it can be used anywhere an agent + # is expected — for example, as a node in a graph workflow. 
+ agent = cached_pipeline.as_agent(name="doc_processor") + response = await agent.run("A short story about a robot learning to paint.") + print(f"\nAs agent: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/basic_pipeline.py b/python/samples/03-workflows/functional/basic_pipeline.py new file mode 100644 index 0000000000..81514da53a --- /dev/null +++ b/python/samples/03-workflows/functional/basic_pipeline.py @@ -0,0 +1,58 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Basic sequential pipeline using the functional workflow API. + +The simplest possible workflow: plain async functions orchestrated by @workflow. +No @step decorator needed — just write Python. +""" + +import asyncio + +from agent_framework import workflow + + +# These are plain async functions — no decorators needed. +# They run normally inside the workflow, just like any other Python function. +async def fetch_data(url: str) -> dict[str, str | int]: + """Simulate fetching data from a URL.""" + return {"url": url, "content": f"Data from {url}", "status": 200} + + +async def transform_data(data: dict[str, str | int]) -> str: + """Transform raw data into a summary string.""" + return f"[{data['status']}] {data['content']}" + + +# @workflow turns this async function into a FunctionalWorkflow object. +# Without it, this is just a normal async function. With it, you get: +# - .run() that returns a WorkflowRunResult with events and outputs +# - .run(stream=True) for streaming events in real time +# - .as_agent() to use this workflow anywhere an agent is expected +# +# The function's first parameter receives the input from .run("..."). +# Add a `ctx: RunContext` parameter only if you need HITL, state, or custom events. 
+@workflow +async def data_pipeline(url: str) -> str: + """A simple sequential data pipeline.""" + raw = await fetch_data(url) + summary = await transform_data(raw) + + # This is just a function — plain Python works between calls. + # No need to wrap every operation in a separate async function. + is_valid = len(summary) > 0 and "[200]" in summary + tag = "VALID" if is_valid else "INVALID" + + # Returning a value automatically emits it as an output. + # Callers retrieve it via result.get_outputs(). + return f"[{tag}] {summary}" + + +async def main(): + # .run() is provided by @workflow — a plain async function wouldn't have it + result = await data_pipeline.run("https://example.com/api/data") + print("Output:", result.get_outputs()[0]) + print("State:", result.get_final_state()) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/basic_streaming_pipeline.py b/python/samples/03-workflows/functional/basic_streaming_pipeline.py new file mode 100644 index 0000000000..91fed0ff4c --- /dev/null +++ b/python/samples/03-workflows/functional/basic_streaming_pipeline.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Basic streaming pipeline using the functional workflow API. + +Stream workflow events in real time with run(stream=True). +""" + +import asyncio + +from agent_framework import workflow + + +# Plain async functions — no decorators needed for simple helpers. 
+async def fetch_data(url: str) -> dict[str, str | int]: + """Simulate fetching data from a URL.""" + return {"url": url, "content": f"Data from {url}", "status": 200} + + +async def transform_data(data: dict[str, str | int]) -> str: + """Transform raw data into a summary string.""" + return f"[{data['status']}] {data['content']}" + + +async def validate_result(summary: str) -> bool: + """Validate the transformed result.""" + return len(summary) > 0 and "[200]" in summary + + +# @workflow enables .run(stream=True), which returns a ResponseStream +# you can iterate over with `async for`. Without @workflow, you'd just +# have a normal async function with no streaming capability. +@workflow +async def data_pipeline(url: str) -> str: + """A simple sequential data pipeline.""" + raw = await fetch_data(url) + summary = await transform_data(raw) + is_valid = await validate_result(summary) + + return f"{summary} (valid={is_valid})" + + +async def main(): + # run(stream=True) returns a ResponseStream that yields events as they + # are produced. The raw stream includes lifecycle events (started, status) + # alongside application events — filter by event.type to find what you need. + stream = data_pipeline.run("https://example.com/api/data", stream=True) + async for event in stream: + if event.type == "output": + print(f"Output: {event.data}") + + # After iteration, get_final_response() returns the WorkflowRunResult + result = await stream.get_final_response() + print(f"Final state: {result.get_final_state()}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/hitl_review.py b/python/samples/03-workflows/functional/hitl_review.py new file mode 100644 index 0000000000..e0340f2b03 --- /dev/null +++ b/python/samples/03-workflows/functional/hitl_review.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Human-in-the-loop review pipeline using functional workflows. 
+ +Demonstrates ctx.request_info() for pausing the workflow to wait for +external input and resuming with run(responses={...}). + +HITL works with or without @step. The difference is what happens on resume: +- Without @step: every function re-executes from the top (fine for cheap calls). +- With @step: completed functions return their saved result instantly. + +This sample uses @step on write_draft() because it simulates an expensive +operation that shouldn't re-run just because the workflow was paused. +""" + +import asyncio + +from agent_framework import RunContext, WorkflowRunState, step, workflow + + +# @step saves the result. When the workflow resumes after the HITL pause, +# this returns its saved result instead of running the expensive operation again. +@step +async def write_draft(topic: str) -> str: + """Simulate writing a draft — expensive, shouldn't re-run on resume.""" + print(f" write_draft executing for '{topic}'") + return f"Draft document about '{topic}': Lorem ipsum dolor sit amet..." + + +@step +async def revise_draft(draft: str, feedback: str) -> str: + """Revise the draft based on feedback.""" + return f"Revised: {draft[:50]}... [Applied feedback: {feedback}]" + + +@workflow +async def review_pipeline(topic: str, ctx: RunContext) -> str: + """Write a draft, get human review, then revise.""" + draft = await write_draft(topic) + + # ctx.request_info() suspends the workflow here. The caller gets back + # a WorkflowRunResult with state IDLE_WITH_PENDING_REQUESTS and can + # inspect the pending request via result.get_request_info_events(). + feedback = await ctx.request_info( + {"draft": draft, "instructions": "Please review this draft"}, + response_type=str, + request_id="review_request", + ) + + # This only executes after the caller resumes with run(responses={...}). + # write_draft above returns its saved result (thanks to @step), + # request_info returns the provided response, and we continue here. 
+ return await revise_draft(draft, feedback) + + +async def main(): + # Phase 1: Run until the workflow pauses for human input + print("=== Phase 1: Initial run ===") + result1 = await review_pipeline.run("AI Safety") + + # If request_info() was reached, the state is IDLE_WITH_PENDING_REQUESTS. + # If the workflow completed without hitting request_info(), it would be IDLE. + print(f"State: {result1.get_final_state()}") + assert result1.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + requests = result1.get_request_info_events() + print(f"Pending request: {requests[0].request_id}") + + # Phase 2: Resume with the human's response + print("\n=== Phase 2: Resume with feedback ===") + print("(write_draft should NOT execute again — saved by @step)") + result2 = await review_pipeline.run(responses={"review_request": "Add more details about alignment research"}) + + print(f"State: {result2.get_final_state()}") + print(f"Output: {result2.get_outputs()[0]}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/parallel_pipeline.py b/python/samples/03-workflows/functional/parallel_pipeline.py new file mode 100644 index 0000000000..e88657b2f0 --- /dev/null +++ b/python/samples/03-workflows/functional/parallel_pipeline.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Parallel pipeline using asyncio.gather with functional workflows. + +Fan-out/fan-in uses native Python concurrency via asyncio.gather. +No @step needed — still just plain async functions. +""" + +import asyncio + +from agent_framework import workflow + + +# Plain async functions — asyncio.gather handles the concurrency, +# no framework primitives needed for parallelism. 
+async def research_web(topic: str) -> str: + """Simulate web research.""" + await asyncio.sleep(0.05) + return f"Web results for '{topic}': 10 articles found" + + +async def research_papers(topic: str) -> str: + """Simulate academic paper search.""" + await asyncio.sleep(0.05) + return f"Papers on '{topic}': 3 relevant papers" + + +async def research_news(topic: str) -> str: + """Simulate news search.""" + await asyncio.sleep(0.05) + return f"News about '{topic}': 5 recent articles" + + +async def synthesize(sources: list[str]) -> str: + """Combine research results into a summary.""" + return "Research Summary:\n" + "\n".join(f" - {s}" for s in sources) + + +# @workflow wraps the orchestration logic so you get .run(), streaming, +# and events. The functions it calls are plain Python — no decorators +# needed just because they're inside a workflow. +@workflow +async def research_pipeline(topic: str) -> str: + """Fan-out to three research tasks, then synthesize results.""" + # asyncio.gather runs all three concurrently — this is standard Python, + # not a framework concept. Use it the same way you would anywhere else. + web, papers, news = await asyncio.gather( + research_web(topic), + research_papers(topic), + research_news(topic), + ) + + return await synthesize([web, papers, news]) + + +async def main(): + result = await research_pipeline.run("AI agents") + print(result.get_outputs()[0]) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/03-workflows/functional/steps_and_checkpointing.py b/python/samples/03-workflows/functional/steps_and_checkpointing.py new file mode 100644 index 0000000000..93a9423df0 --- /dev/null +++ b/python/samples/03-workflows/functional/steps_and_checkpointing.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Introducing @step: per-step checkpointing and observability. + +The previous samples used plain functions — and that works. 
Workflows support +HITL (ctx.request_info) and checkpointing regardless of whether you use @step. + +The difference: without @step, a resumed workflow re-executes every function +call from the top. That's fine for cheap functions. But for expensive operations +(API calls, agent runs, etc.) you don't want to pay that cost again. + +@step saves each function's result so it skips re-execution on resume: +- On HITL resume, completed steps return their saved result instantly. +- On crash recovery from a checkpoint, earlier step results are restored. +- Each step emits executor_invoked/executor_completed events for observability. + +@step is opt-in. Plain functions still work alongside @step in the same workflow. +""" + +import asyncio + +from agent_framework import InMemoryCheckpointStorage, step, workflow + +# Track call counts to show which functions actually execute on resume +fetch_calls = 0 +transform_calls = 0 + + +# @step saves this function's result. On resume, it returns the saved +# result instead of re-executing — useful because this is expensive. +@step +async def fetch_data(url: str) -> dict[str, str | int]: + """Expensive operation — @step prevents re-execution on resume.""" + global fetch_calls + fetch_calls += 1 + print(f" fetch_data called (call #{fetch_calls})") + return {"url": url, "content": f"Data from {url}", "status": 200} + + +@step +async def transform_data(data: dict[str, str | int]) -> str: + """Another expensive operation — @step saves the result.""" + global transform_calls + transform_calls += 1 + print(f" transform_data called (call #{transform_calls})") + return f"[{data['status']}] {data['content']}" + + +# No @step — this is cheap, so it just re-runs on resume. That's fine. +async def validate_result(summary: str) -> bool: + """Cheap validation — no @step needed.""" + return len(summary) > 0 and "[200]" in summary + + +storage = InMemoryCheckpointStorage() + + +# checkpoint_storage tells @workflow where to persist step results. 
+# Each @step saves a checkpoint after it completes. +@workflow(checkpoint_storage=storage) +async def data_pipeline(url: str) -> str: + """Mix of @step functions and plain functions.""" + raw = await fetch_data(url) + summary = await transform_data(raw) + is_valid = await validate_result(summary) + + return f"{summary} (valid={is_valid})" + + +async def main(): + # --- Run 1: Everything executes normally --- + print("=== Run 1: Fresh execution ===") + result = await data_pipeline.run("https://example.com/api/data") + print(f"Output: {result.get_outputs()[0]}") + print(f"fetch_calls={fetch_calls}, transform_calls={transform_calls}") + + # @step functions emit executor events; plain functions don't. + print("\nEvents:") + for event in result: + if event.type in ("executor_invoked", "executor_completed"): + print(f" {event.type}: {event.executor_id}") + + # --- Run 2: Restore from checkpoint --- + # The workflow re-executes, but @step functions return saved results. + # Only validate_result() (no @step) actually runs again. + print("\n=== Run 2: Restored from checkpoint ===") + latest = await storage.get_latest(workflow_name="data_pipeline") + assert latest is not None + + result2 = await data_pipeline.run(checkpoint_id=latest.checkpoint_id) + print(f"Output: {result2.get_outputs()[0]}") + print(f"fetch_calls={fetch_calls}, transform_calls={transform_calls}") + print("(call counts unchanged — @step results were restored from checkpoint)") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/README.md b/python/samples/README.md index 148ace320d..51df93212b 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -20,8 +20,10 @@ Start with `01-get-started/` and work through the numbered files: 2. **[02_add_tools.py](./01-get-started/02_add_tools.py)** — Add function tools with `@tool` 3. **[03_multi_turn.py](./01-get-started/03_multi_turn.py)** — Multi-turn conversations with `AgentThread` 4. 
**[04_memory.py](./01-get-started/04_memory.py)** — Agent memory with `ContextProvider` -5. **[05_first_workflow.py](./01-get-started/05_first_workflow.py)** — Build a workflow with executors and edges -6. **[06_host_your_agent.py](./01-get-started/06_host_your_agent.py)** — Host your agent via A2A +5. **[05_first_functional_workflow.py](./01-get-started/05_first_functional_workflow.py)** — Write a workflow as a plain async function +6. **[06_functional_workflow_with_agents.py](./01-get-started/06_functional_workflow_with_agents.py)** — Call agents inside a functional workflow +7. **[07_first_graph_workflow.py](./01-get-started/07_first_graph_workflow.py)** — Build a graph workflow with executors and edges +8. **[08_host_your_agent.py](./01-get-started/08_host_your_agent.py)** — Host your agent via A2A ## Prerequisites