Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/agentex/lib/core/clients/temporal/temporal_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from temporalio.client import Client, WorkflowExecutionStatus
from temporalio.common import RetryPolicy as TemporalRetryPolicy, WorkflowIDReusePolicy
from temporalio.service import RPCError, RPCStatusCode
from temporalio.converter import PayloadCodec
from temporalio.converter import PayloadCodec, DataConverter

from agentex.lib.utils.logging import make_logger
from agentex.lib.utils.model_utils import BaseModel
Expand Down Expand Up @@ -78,11 +78,16 @@

class TemporalClient:
def __init__(
self, temporal_client: Client | None = None, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None
self,
temporal_client: Client | None = None,
plugins: list[Any] = [],
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
):
self._client: Client | None = temporal_client
self._plugins = plugins
self._payload_codec = payload_codec
self._data_converter = data_converter

@property
def client(self) -> Client:
Expand All @@ -92,7 +97,13 @@ def client(self) -> Client:
return self._client

@classmethod
async def create(cls, temporal_address: str, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None):
async def create(
cls,
temporal_address: str,
plugins: list[Any] = [],
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
):
if temporal_address in [
"false",
"False",
Expand All @@ -105,8 +116,13 @@ async def create(cls, temporal_address: str, plugins: list[Any] = [], payload_co
]:
_client = None
else:
_client = await get_temporal_client(temporal_address, plugins=plugins, payload_codec=payload_codec)
return cls(_client, plugins, payload_codec)
_client = await get_temporal_client(
temporal_address,
plugins=plugins,
payload_codec=payload_codec,
data_converter=data_converter,
)
return cls(_client, plugins, payload_codec, data_converter)

async def setup(self, temporal_address: str):
self._client = await self._get_temporal_client(temporal_address=temporal_address)
Expand All @@ -124,7 +140,12 @@ async def _get_temporal_client(self, temporal_address: str) -> Client | None:
]:
return None
else:
return await get_temporal_client(temporal_address, plugins=self._plugins, payload_codec=self._payload_codec)
return await get_temporal_client(
temporal_address,
plugins=self._plugins,
payload_codec=self._payload_codec,
data_converter=self._data_converter,
)

async def start_workflow(
self,
Expand Down
43 changes: 31 additions & 12 deletions src/agentex/lib/core/clients/temporal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from temporalio.client import Client, Plugin as ClientPlugin
from temporalio.worker import Interceptor
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
from temporalio.converter import PayloadCodec
from temporalio.converter import PayloadCodec, DataConverter
from temporalio.contrib.pydantic import pydantic_data_converter

# class DateTimeJSONEncoder(AdvancedJSONEncoder):
Expand Down Expand Up @@ -86,6 +86,7 @@ async def get_temporal_client(
metrics_url: str | None = None,
plugins: list[Any] = [],
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
) -> Client:
"""
Create a Temporal client with plugin integration.
Expand All @@ -94,7 +95,14 @@ async def get_temporal_client(
temporal_address: Temporal server address
metrics_url: Optional metrics endpoint URL
plugins: List of Temporal plugins to include
payload_codec: Optional payload codec for encoding/decoding payloads (e.g. encryption, compression)
payload_codec: Optional payload codec for encoding/decoding payloads
(e.g. encryption, compression). Cannot be combined with the
OpenAIAgentsPlugin via this kwarg — see ``data_converter``.
data_converter: Optional pre-built ``DataConverter``. Use this when
composing the OpenAIAgentsPlugin with a payload codec: build a
``DataConverter(payload_converter_class=OpenAIPayloadConverter,
payload_codec=...)`` and pass it here. Mutually exclusive with
``payload_codec``.

Returns:
Configured Temporal client
Expand All @@ -103,29 +111,40 @@ async def get_temporal_client(
if plugins:
validate_client_plugins(plugins)

# Check if OpenAI plugin is present - it needs to configure its own data converter
if payload_codec is not None and data_converter is not None:
raise ValueError(
"Pass payload_codec inside `data_converter` "
"(DataConverter(..., payload_codec=...)) instead of as a separate "
"kwarg. Specifying both is ambiguous."
)

# Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin

has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []))

if has_openai_plugin and payload_codec is not None:
if has_openai_plugin and payload_codec is not None and data_converter is None:
raise ValueError(
"payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin "
"installs its own data converter and the codec would be silently ignored, "
"leaving payloads unencoded. Remove one or the other."
"payload_codec passed as a kwarg alongside OpenAIAgentsPlugin would "
"be silently dropped by the plugin's data-converter transformer. "
"Build a DataConverter explicitly with "
"`payload_converter_class=OpenAIPayloadConverter` (or a subclass) "
"and `payload_codec=...`, then pass it via the `data_converter` "
"kwarg instead."
)

connect_kwargs = {
connect_kwargs: dict[str, Any] = {
"target_host": temporal_address,
"plugins": plugins,
}

if not has_openai_plugin:
data_converter = pydantic_data_converter
if payload_codec:
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
if data_converter is not None:
connect_kwargs["data_converter"] = data_converter
elif not has_openai_plugin:
dc = pydantic_data_converter
if payload_codec:
dc = dataclasses.replace(dc, payload_codec=payload_codec)
connect_kwargs["data_converter"] = dc
Comment thread
greptile-apps[bot] marked this conversation as resolved.

if not metrics_url:
client = await Client.connect(**connect_kwargs)
Expand Down
37 changes: 25 additions & 12 deletions src/agentex/lib/core/temporal/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,35 +95,45 @@ async def get_temporal_client(
metrics_url: str | None = None,
plugins: list = [],
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
) -> Client:
if plugins != []: # We don't need to validate the plugins if they are empty
_validate_plugins(plugins)

# Check if OpenAI plugin is present - it needs to configure its own data converter
if payload_codec is not None and data_converter is not None:
raise ValueError(
"Pass payload_codec inside `data_converter` "
"(DataConverter(..., payload_codec=...)) instead of as a separate "
"kwarg. Specifying both is ambiguous."
)

# Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin

has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []))

if has_openai_plugin and payload_codec is not None:
if has_openai_plugin and payload_codec is not None and data_converter is None:
raise ValueError(
"payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin "
"installs its own data converter and the codec would be silently ignored, "
"leaving payloads unencoded. Remove one or the other."
"payload_codec passed as a kwarg alongside OpenAIAgentsPlugin would "
"be silently dropped by the plugin's data-converter transformer. "
"Build a DataConverter explicitly with "
"`payload_converter_class=OpenAIPayloadConverter` (or a subclass) "
"and `payload_codec=...`, then pass it via the `data_converter` "
"kwarg instead."
)

# Build connection kwargs
connect_kwargs = {
connect_kwargs: dict[str, Any] = {
"target_host": temporal_address,
"plugins": plugins,
}

# Only set data_converter if OpenAI plugin is not present
if not has_openai_plugin:
data_converter = custom_data_converter
if payload_codec:
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
if data_converter is not None:
connect_kwargs["data_converter"] = data_converter
elif not has_openai_plugin:
dc = custom_data_converter
if payload_codec:
dc = dataclasses.replace(dc, payload_codec=payload_codec)
connect_kwargs["data_converter"] = dc

if not metrics_url:
client = await Client.connect(**connect_kwargs)
Expand All @@ -145,6 +155,7 @@ def __init__(
interceptors: list = [],
metrics_url: str | None = None,
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
):
self.task_queue = task_queue
self.activity_handles = []
Expand All @@ -159,6 +170,7 @@ def __init__(
self.interceptors = interceptors
self.metrics_url = metrics_url
self.payload_codec = payload_codec
self.data_converter = data_converter

@overload
async def run(
Expand Down Expand Up @@ -195,6 +207,7 @@ async def run(
plugins=self.plugins,
metrics_url=self.metrics_url,
payload_codec=self.payload_codec,
data_converter=self.data_converter,
)

# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)
Expand Down
2 changes: 2 additions & 0 deletions src/agentex/lib/sdk/fastacp/fastacp.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def create_async_acp(config: AsyncACPConfig, **kwargs) -> BaseACPServer:
temporal_config["interceptors"] = config.interceptors # type: ignore[attr-defined]
if hasattr(config, "payload_codec"):
temporal_config["payload_codec"] = config.payload_codec # type: ignore[attr-defined]
if hasattr(config, "data_converter"):
temporal_config["data_converter"] = config.data_converter # type: ignore[attr-defined]
return implementation_class.create(**temporal_config)
else:
return implementation_class.create(**kwargs)
Expand Down
16 changes: 13 additions & 3 deletions src/agentex/lib/sdk/fastacp/impl/temporal_acp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI
from temporalio.converter import PayloadCodec
from temporalio.converter import PayloadCodec, DataConverter

from agentex.protocol.acp import (
SendEventParams,
Expand Down Expand Up @@ -33,13 +33,15 @@ def __init__(
plugins: list[Any] | None = None,
interceptors: list[Any] | None = None,
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
):
super().__init__()
self._temporal_task_service = temporal_task_service
self._temporal_address = temporal_address
self._plugins = plugins or []
self._interceptors = interceptors or []
self._payload_codec = payload_codec
self._data_converter = data_converter

@classmethod
@override
Expand All @@ -49,12 +51,17 @@ def create(
plugins: list[Any] | None = None,
interceptors: list[Any] | None = None,
payload_codec: PayloadCodec | None = None,
data_converter: DataConverter | None = None,
) -> "TemporalACP":
logger.info("Initializing TemporalACP instance")

# Create instance without temporal client initially
temporal_acp = cls(
temporal_address=temporal_address, plugins=plugins, interceptors=interceptors, payload_codec=payload_codec
temporal_address=temporal_address,
plugins=plugins,
interceptors=interceptors,
payload_codec=payload_codec,
data_converter=data_converter,
)
temporal_acp._setup_handlers()
logger.info("TemporalACP instance initialized now")
Expand All @@ -71,7 +78,10 @@ async def lifespan(app: FastAPI):
if self._temporal_task_service is None:
env_vars = EnvironmentVariables.refresh()
temporal_client = await TemporalClient.create(
temporal_address=self._temporal_address, plugins=self._plugins, payload_codec=self._payload_codec
temporal_address=self._temporal_address,
plugins=self._plugins,
payload_codec=self._payload_codec,
data_converter=self._data_converter,
)
self._temporal_task_service = TemporalTaskService(
temporal_client=temporal_client,
Expand Down
24 changes: 22 additions & 2 deletions src/agentex/lib/types/fastacp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Literal

from pydantic import Field, BaseModel, field_validator
from pydantic import Field, BaseModel, field_validator, model_validator

from agentex.lib.core.clients.temporal.utils import validate_client_plugins, validate_worker_interceptors

Expand Down Expand Up @@ -56,14 +56,24 @@ class TemporalACPConfig(AsyncACPConfig):
encoding/decoding payloads (e.g. encryption, compression). NOTE:
this only configures the ACP (client) side. The worker side must
be configured separately via ``AgentexWorker(payload_codec=...)``
with the SAME codec, or decode will fail at runtime.
with the SAME codec, or decode will fail at runtime. Cannot be
combined with ``OpenAIAgentsPlugin``; use ``data_converter``
instead in that case.
data_converter: Optional pre-built ``temporalio.converter.DataConverter``.
Use this when composing the ``OpenAIAgentsPlugin`` with a payload
codec: build a ``DataConverter(payload_converter_class=
OpenAIPayloadConverter, payload_codec=...)`` and pass it here.
Mutually exclusive with ``payload_codec``. The worker side must
be configured separately via ``AgentexWorker(data_converter=...)``
with the SAME converter, or decode will fail at runtime.
"""

type: Literal["temporal"] = Field(default="temporal", frozen=True)
temporal_address: str = Field(default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True)
plugins: list[Any] = Field(default=[], frozen=True)
interceptors: list[Any] = Field(default=[], frozen=True)
payload_codec: Any = Field(default=None, frozen=True)
data_converter: Any = Field(default=None, frozen=True)
Comment thread
greptile-apps[bot] marked this conversation as resolved.

@field_validator("plugins")
@classmethod
Expand All @@ -79,6 +89,16 @@ def validate_interceptors(cls, v: list[Any]) -> list[Any]:
validate_worker_interceptors(v)
return v

@model_validator(mode="after")
def _validate_codec_and_data_converter_mutually_exclusive(self) -> "TemporalACPConfig":
if self.payload_codec is not None and self.data_converter is not None:
raise ValueError(
"Pass payload_codec inside `data_converter` "
"(DataConverter(..., payload_codec=...)) instead of as a separate "
"field. Specifying both is ambiguous."
)
return self


class AsyncBaseACPConfig(AsyncACPConfig):
"""Configuration for AsyncBaseACP implementation
Expand Down
Loading
Loading