From ad3ee78084de709e68598c90978976aebf169867 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 12:16:23 -0700 Subject: [PATCH 1/7] Add Nexus system payload codec rewrite support --- pyproject.toml | 16 +- scripts/gen_nexus_system_test_models.py | 75 ++ temporalio/bridge/worker.py | 46 +- temporalio/nexus/system/__init__.py | 41 + .../system/_workflow_service_generated.py | 821 ++++++++++++++++++ temporalio/worker/_command_aware_visitor.py | 22 +- tests/nexus/test_temporal_system_nexus.py | 178 ++++ 7 files changed, 1192 insertions(+), 7 deletions(-) create mode 100644 scripts/gen_nexus_system_test_models.py create mode 100644 temporalio/nexus/system/__init__.py create mode 100644 temporalio/nexus/system/_workflow_service_generated.py create mode 100644 tests/nexus/test_temporal_system_nexus.py diff --git a/pyproject.toml b/pyproject.toml index 4bcd3f03e..3c810ccc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,8 +67,14 @@ dev = [ ] [tool.poe.tasks] -build-develop = "uv run maturin develop --uv" -build-develop-with-release = { cmd = "uv run maturin develop --release --uv" } +build-develop = [ + { ref = "gen-nexus-system-test-models" }, + { cmd = "uv run maturin develop --uv" }, +] +build-develop-with-release = [ + { ref = "gen-nexus-system-test-models" }, + { cmd = "uv run maturin develop --release --uv" }, +] format = [ { cmd = "uv run ruff check --select I --fix" }, { cmd = "uv run ruff format" }, @@ -79,6 +85,7 @@ gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-test-models" }, { ref = "format" }, ] gen-protos-docker = [ @@ -98,10 +105,12 @@ bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" } # https://github.com/PyCQA/pydocstyle/pull/511? lint-docs = "uv run pydocstyle --ignore-decorators=overload" lint-types = [ + { ref = "gen-nexus-system-test-models" }, { cmd = "uv run pyright" }, { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, ] +gen-nexus-system-test-models = "uv run scripts/gen_nexus_system_test_models.py" run-bench = "uv run python scripts/run_bench.py" test = "uv run pytest" @@ -141,12 +150,14 @@ exclude = [ # Ignore generated code 'temporalio/api', 'temporalio/bridge/proto', + 'temporalio/nexus/system/_workflow_service_generated.py', ] [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match = "^(?!_workflow_service_generated\\.py$).*\\.py" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. # https://github.com/PyCQA/pydocstyle/issues/184 @@ -212,6 +223,7 @@ exclude = [ "temporalio/api", "temporalio/bridge/proto", "temporalio/bridge/_visitor.py", + "temporalio/nexus/system/_workflow_service_generated.py", "tests/worker/workflow_sandbox/testmodules/proto", ] diff --git a/scripts/gen_nexus_system_test_models.py b/scripts/gen_nexus_system_test_models.py new file mode 100644 index 000000000..d707727a5 --- /dev/null +++ b/scripts/gen_nexus_system_test_models.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + + +def main() -> None: + repo_root = Path(__file__).resolve().parent.parent + workspace_root = repo_root.parent + nexus_rpc_gen_root = workspace_root / "nexus-rpc-gen" / "src" + input_schema = ( + workspace_root + / "temporal-api" + / "nexus" + / "temporal-json-schema-models-nexusrpc.yaml" + ) + output_file = ( + repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + + if not nexus_rpc_gen_root.is_dir(): + raise RuntimeError(f"Expected nexus-rpc-gen checkout at {nexus_rpc_gen_root}") + if not input_schema.is_file(): + raise RuntimeError(f"Expected Temporal Nexus schema at {input_schema}") + + subprocess.run( + [ + "npm", + "run", + "cli", + "--", + "--lang", + "py", + "--out-file", + str(output_file), + "--temporal-nexus-payload-codec-support", + str(input_schema), + ], + cwd=nexus_rpc_gen_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + + +if __name__ == "__main__": + try: + main() + except Exception as err: + print(f"Failed to generate Nexus system test models: {err}", file=sys.stderr) + raise diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c2e426d28..52a54ce9c 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -20,7 +20,9 @@ import temporalio.bridge.runtime import temporalio.bridge.temporal_sdk_bridge import temporalio.converter +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.api.enums.v1.command_type_pb2 import CommandType from temporalio.bridge._visitor import VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, @@ -28,6 +30,7 @@ from temporalio.bridge.temporal_sdk_bridge import ( PollShutdownError, # type: ignore # noqa: F401 ) +from temporalio.worker import _command_aware_visitor from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor @@ -279,14 +282,50 @@ async def finalize_shutdown(self) -> None: class _Visitor(VisitorFunctions): - def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]]): + def __init__( + self, + f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + payload_codec: temporalio.converter.PayloadCodec | None = None, + ): self._f = f + self._payload_codec = payload_codec async def visit_payload(self, payload: Payload) -> None: + if self._payload_codec: + rewritten_payload = await self._maybe_rewrite_nexus_payload(payload) + if rewritten_payload is not None: + if rewritten_payload is not payload: + payload.CopyFrom(rewritten_payload) + return new_payload = (await self._f([payload]))[0] if new_payload is not payload: payload.CopyFrom(new_payload) + async def _maybe_rewrite_nexus_payload(self, payload: Payload) -> Payload | None: + command_info = _command_aware_visitor.current_command_info.get() + if ( + command_info is None + or command_info.command_type + != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION + or not command_info.nexus_service + or not command_info.nexus_operation + ): + return None + + rewrite = temporalio.nexus.system.get_payload_codec_rewriter( + command_info.nexus_service, + command_info.nexus_operation, + ) + if rewrite is None: + return None + + rewritten_payload = await rewrite(payload, self._payload_codec) + if not isinstance(rewritten_payload, Payload): + raise TypeError( + "temporal nexus payload codec rewriter must return a Payload" + ) + return rewritten_payload + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: return @@ -316,4 +355,7 @@ async def encode_completion( """Encode all payloads in the completion.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + ).visit( + _Visitor(data_converter._encode_payload_sequence, data_converter.payload_codec), + completion, + ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py new file mode 100644 index 000000000..7eff3a93d --- /dev/null +++ b/temporalio/nexus/system/__init__.py @@ -0,0 +1,41 @@ +"""Generated system Nexus service models. + +This package contains code generated from Temporal's system Nexus schemas. +Higher-level ergonomic APIs may wrap these generated types. +""" + +from collections.abc import Awaitable, Callable + +import temporalio.api.common.v1 +import temporalio.converter + +from ._workflow_service_generated import ( + WorkflowService, + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + __temporal_nexus_payload_codec_rewriters__, +) + +TemporalNexusPayloadCodecRewriter = Callable[ + [ + temporalio.api.common.v1.Payload, + temporalio.converter.PayloadCodec | None, + ], + Awaitable[temporalio.api.common.v1.Payload], +] + + +def get_payload_codec_rewriter( + service: str, + operation: str, +) -> TemporalNexusPayloadCodecRewriter | None: + """Return the generated payload codec rewriter for a system Nexus operation.""" + return __temporal_nexus_payload_codec_rewriters__.get((service, operation)) + + +__all__ = ( + "WorkflowService", + "WorkflowServiceSignalWithStartWorkflowExecutionInput", + "WorkflowServiceSignalWithStartWorkflowExecutionOutput", + "get_payload_codec_rewriter", +) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py new file mode 100644 index 000000000..1723631a1 --- /dev/null +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -0,0 +1,821 @@ +# Generated by nexus-rpc-gen. DO NOT EDIT! + +from __future__ import annotations + +import json +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from google.protobuf.json_format import MessageToDict, ParseDict +from nexusrpc import Operation, service +from pydantic import BaseModel, Field + +import temporalio.api.common.v1 +import temporalio.converter + + +async def _temporal_nexus_encode_payload_json( + value: dict, payload_codec: temporalio.converter.PayloadCodec +) -> dict: + payload = ParseDict(value, temporalio.api.common.v1.Payload()) + [encoded_payload] = await payload_codec.encode([payload]) + return MessageToDict(encoded_payload) + + +async def _temporal_nexus_encode_payloads_json( + value: dict, payload_codec: temporalio.converter.PayloadCodec +) -> dict: + payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) + encoded_payloads = await payload_codec.encode(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(encoded_payloads) + return MessageToDict(payloads) + + +async def _temporal_nexus_encode_payload_map_json( + message_type: type, value: dict, payload_codec: temporalio.converter.PayloadCodec +) -> dict: + message = ParseDict(value, message_type()) + keys = list(message.fields.keys()) + encoded_payloads = await payload_codec.encode([message.fields[key] for key in keys]) + for key, encoded_payload in zip(keys, encoded_payloads): + message.fields[key].CopyFrom(encoded_payload) + return MessageToDict(message) + + +async def _temporal_nexus_encode_json_value( + value: object, payload_codec: temporalio.converter.PayloadCodec +) -> object: + if isinstance(value, list): + return [ + await _temporal_nexus_encode_json_value(item, payload_codec) + for item in value + ] + if not isinstance(value, dict): + return value + if "indexedFields" in value: + return value + if "payloads" in value and isinstance(value["payloads"], list): + return await _temporal_nexus_encode_payloads_json(value, payload_codec) + if "fields" in value and isinstance(value["fields"], dict): + return await _temporal_nexus_encode_payload_map_json( + temporalio.api.common.v1.Header, value, payload_codec + ) + if "data" in value and "metadata" in value: + return await _temporal_nexus_encode_payload_json(value, payload_codec) + rewritten: dict[str, object] = {} + for key, item in value.items(): + rewritten[key] = ( + item + if key == "indexedFields" + else await _temporal_nexus_encode_json_value(item, payload_codec) + ) + return rewritten + + +class Header(BaseModel): + """Contains metadata that can be attached to a variety of requests, like starting a + workflow, and + can be propagated between, for example, workflows and activities. + """ + + fields: Optional[Dict[str, Any]] = None + + +class Input(BaseModel): + """Serialized arguments to the workflow. These are passed as arguments to the workflow + function. + + See `Payload` + + Serialized value(s) to provide with the signal + """ + + payloads: Optional[List[Any]] = None + + +class BatchJob(BaseModel): + """A link to a built-in batch job. + Batch jobs can be used to perform operations on a set of workflows (e.g. terminate, + signal, cancel, etc). + This link can be put on workflow history events generated by actions taken by a batch job. + """ + + job_id: Optional[str] = Field(None, alias="jobId") + + +class EventType(Enum): + EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY = ( + "EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY" + ) + EVENT_TYPE_ACTIVITY_TASK_CANCELED = "EVENT_TYPE_ACTIVITY_TASK_CANCELED" + EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED = ( + "EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED" + ) + EVENT_TYPE_ACTIVITY_TASK_COMPLETED = "EVENT_TYPE_ACTIVITY_TASK_COMPLETED" + EVENT_TYPE_ACTIVITY_TASK_FAILED = "EVENT_TYPE_ACTIVITY_TASK_FAILED" + EVENT_TYPE_ACTIVITY_TASK_SCHEDULED = "EVENT_TYPE_ACTIVITY_TASK_SCHEDULED" + EVENT_TYPE_ACTIVITY_TASK_STARTED = "EVENT_TYPE_ACTIVITY_TASK_STARTED" + EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT = "EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT" + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT" + ) + EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( + "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED" + ) + EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED = ( + "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED" + ) + EVENT_TYPE_MARKER_RECORDED = "EVENT_TYPE_MARKER_RECORDED" + EVENT_TYPE_NEXUS_OPERATION_CANCELED = "EVENT_TYPE_NEXUS_OPERATION_CANCELED" + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED" + ) + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED" + ) + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED" + ) + EVENT_TYPE_NEXUS_OPERATION_COMPLETED = "EVENT_TYPE_NEXUS_OPERATION_COMPLETED" + EVENT_TYPE_NEXUS_OPERATION_FAILED = "EVENT_TYPE_NEXUS_OPERATION_FAILED" + EVENT_TYPE_NEXUS_OPERATION_SCHEDULED = "EVENT_TYPE_NEXUS_OPERATION_SCHEDULED" + EVENT_TYPE_NEXUS_OPERATION_STARTED = "EVENT_TYPE_NEXUS_OPERATION_STARTED" + EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT = "EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT" + EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_TIMER_CANCELED = "EVENT_TYPE_TIMER_CANCELED" + EVENT_TYPE_TIMER_FIRED = "EVENT_TYPE_TIMER_FIRED" + EVENT_TYPE_TIMER_STARTED = "EVENT_TYPE_TIMER_STARTED" + EVENT_TYPE_UNSPECIFIED = "EVENT_TYPE_UNSPECIFIED" + EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES = ( + "EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED = "EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED" + EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED = "EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED" + EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_FAILED = "EVENT_TYPE_WORKFLOW_EXECUTION_FAILED" + EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED" + EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED = "EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED" + EVENT_TYPE_WORKFLOW_EXECUTION_STARTED = "EVENT_TYPE_WORKFLOW_EXECUTION_STARTED" + EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" + EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED" + ) + EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED = "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED" + EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY = ( + "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY" + ) + EVENT_TYPE_WORKFLOW_TASK_COMPLETED = "EVENT_TYPE_WORKFLOW_TASK_COMPLETED" + EVENT_TYPE_WORKFLOW_TASK_FAILED = "EVENT_TYPE_WORKFLOW_TASK_FAILED" + EVENT_TYPE_WORKFLOW_TASK_SCHEDULED = "EVENT_TYPE_WORKFLOW_TASK_SCHEDULED" + EVENT_TYPE_WORKFLOW_TASK_STARTED = "EVENT_TYPE_WORKFLOW_TASK_STARTED" + EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT = "EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT" + + +class EventRef(BaseModel): + """EventReference is a direct reference to a history event through the event ID.""" + + event_id: Optional[str] = Field(None, alias="eventId") + event_type: Optional[EventType] = Field(None, alias="eventType") + + +class RequestIDRef(BaseModel): + """RequestIdReference is a indirect reference to a history event through the request ID.""" + + event_type: Optional[EventType] = Field(None, alias="eventType") + request_id: Optional[str] = Field(None, alias="requestId") + + +class WorkflowEvent(BaseModel): + event_ref: Optional[EventRef] = Field(None, alias="eventRef") + namespace: Optional[str] = None + request_id_ref: Optional[RequestIDRef] = Field(None, alias="requestIdRef") + run_id: Optional[str] = Field(None, alias="runId") + workflow_id: Optional[str] = Field(None, alias="workflowId") + + +class Openapiv3(BaseModel): + """Link can be associated with history events. It might contain information about an + external entity + related to the history event. For example, workflow A makes a Nexus call that starts + workflow B: + in this case, a history event in workflow A could contain a Link to the workflow started + event in + workflow B, and vice-versa. + """ + + batch_job: Optional[BatchJob] = Field(None, alias="batchJob") + workflow_event: Optional[WorkflowEvent] = Field(None, alias="workflowEvent") + + +class Memo(BaseModel): + """A user-defined set of *unindexed* fields that are exposed when listing/searching workflows""" + + fields: Optional[Dict[str, Any]] = None + + +class Priority(BaseModel): + """Priority metadata + + Priority contains metadata that controls relative ordering of task processing + when tasks are backed up in a queue. Initially, Priority will be used in + matching (workflow and activity) task queues. Later it may be used in history + task queues and in rate limiting decisions. + + Priority is attached to workflows and activities. By default, activities + inherit Priority from the workflow that created them, but may override fields + when an activity is started or modified. + + Despite being named "Priority", this message also contains fields that + control "fairness" mechanisms. + + For all fields, the field not present or equal to zero/empty string means to + inherit the value from the calling workflow, or if there is no calling + workflow, then use the default value. + + For all fields other than fairness_key, the zero value isn't meaningful so + there's no confusion between inherit/default and a meaningful value. For + fairness_key, the empty string will be interpreted as "inherit". This means + that if a workflow has a non-empty fairness key, you can't override the + fairness key of its activity to the empty string. + + The overall semantics of Priority are: + 1. First, consider "priority": higher priority (lower number) goes first. + 2. Then, consider fairness: try to dispatch tasks for different fairness keys + in proportion to their weight. + + Applications may use any subset of mechanisms that are useful to them and + leave the other fields to use default values. + + Not all queues in the system may support the "full" semantics of all priority + fields. (Currently only support in matching task queues is planned.) + """ + + fairness_key: Optional[str] = Field(None, alias="fairnessKey") + """Fairness key is a short string that's used as a key for a fairness + balancing mechanism. It may correspond to a tenant id, or to a fixed + string like "high" or "low". The default is the empty string. + + The fairness mechanism attempts to dispatch tasks for a given key in + proportion to its weight. For example, using a thousand distinct tenant + ids, each with a weight of 1.0 (the default) will result in each tenant + getting a roughly equal share of task dispatch throughput. + + (Note: this does not imply equal share of worker capacity! Fairness + decisions are made based on queue statistics, not + current worker load.) + + As another example, using keys "high" and "low" with weight 9.0 and 1.0 + respectively will prefer dispatching "high" tasks over "low" tasks at a + 9:1 ratio, while allowing either key to use all worker capacity if the + other is not present. + + All fairness mechanisms, including rate limits, are best-effort and + probabilistic. The results may not match what a "perfect" algorithm with + infinite resources would produce. The more unique keys are used, the less + accurate the results will be. + + Fairness keys are limited to 64 bytes. + """ + fairness_weight: Optional[float] = Field(None, alias="fairnessWeight") + """Fairness weight for a task can come from multiple sources for + flexibility. From highest to lowest precedence: + 1. Weights for a small set of keys can be overridden in task queue + configuration with an API. + 2. It can be attached to the workflow/activity in this field. + 3. The default weight of 1.0 will be used. + + Weight values are clamped to the range [0.001, 1000]. + """ + priority_key: Optional[int] = Field(None, alias="priorityKey") + """Priority key is a positive integer from 1 to n, where smaller integers + correspond to higher priorities (tasks run sooner). In general, tasks in + a queue should be processed in close to priority order, although small + deviations are possible. + + The maximum priority value (minimum priority) is determined by server + configuration, and defaults to 5. + + If priority is not present (or zero), then the effective priority will be + the default priority, which is calculated by (min+max)/2. With the + default max of 5, and min of 1, that comes out to 3. + """ + + +class RetryPolicy(BaseModel): + """Retry policy for the workflow + + How retries ought to be handled, usable by both workflows and activities + """ + + backoff_coefficient: Optional[float] = Field(None, alias="backoffCoefficient") + """Coefficient used to calculate the next retry interval. + The next retry interval is previous interval multiplied by the coefficient. + Must be 1 or larger. + """ + initial_interval: Optional[str] = Field(None, alias="initialInterval") + """Interval of the first retry. If retryBackoffCoefficient is 1.0 then it is used for all + retries. + """ + maximum_attempts: Optional[int] = Field(None, alias="maximumAttempts") + """Maximum number of attempts. When exceeded the retries stop even if not expired yet. + 1 disables retries. 0 means unlimited (up to the timeouts) + """ + maximum_interval: Optional[str] = Field(None, alias="maximumInterval") + """Maximum interval between retries. Exponential backoff leads to interval increase. + This value is the cap of the increase. Default is 100x of the initial interval. + """ + non_retryable_error_types: Optional[List[str]] = Field( + None, alias="nonRetryableErrorTypes" + ) + """Non-Retryable errors types. Will stop retrying if the error type matches this list. Note + that + this is not a substring match, the error *type* (not message) must match exactly. + """ + + +class SearchAttributes(BaseModel): + """A user-defined set of *indexed* fields that are used/exposed when listing/searching + workflows. + The payload is not serialized in a user-defined way. + """ + + indexed_fields: Optional[Dict[str, Any]] = Field(None, alias="indexedFields") + + +class Kind(Enum): + """Default: TASK_QUEUE_KIND_NORMAL.""" + + TASK_QUEUE_KIND_NORMAL = "TASK_QUEUE_KIND_NORMAL" + TASK_QUEUE_KIND_STICKY = "TASK_QUEUE_KIND_STICKY" + TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" + + +class TaskQueue(BaseModel): + """The task queue to start this workflow on, if it will be started + + See https://docs.temporal.io/docs/concepts/task-queues/ + """ + + kind: Optional[Kind] = None + """Default: TASK_QUEUE_KIND_NORMAL.""" + + name: Optional[str] = None + normal_name: Optional[str] = Field(None, alias="normalName") + """Iff kind == TASK_QUEUE_KIND_STICKY, then this field contains the name of + the normal task queue that the sticky worker is running on. + """ + + +class TimeSkippingConfig(BaseModel): + """Time-skipping configuration. If not set, time skipping is disabled. + + Configuration for time skipping during a workflow execution. + When enabled, virtual time advances automatically whenever there is no in-flight work. + In-flight work includes activities, child workflows, Nexus operations, signal/cancel + external workflow operations, + and possibly other features added in the future. + User timers are not classified as in-flight work and will be skipped over. + When time advances, it skips to the earlier of the next user timer or the configured + bound, if either exists. + """ + + disable_propagation: Optional[bool] = Field(None, alias="disablePropagation") + """If set, the enabled field is not propagated to transitively related workflows.""" + + enabled: Optional[bool] = None + """Enables or disables time skipping for this workflow execution. + By default, this field is propagated to transitively related workflows (child + workflows/start-as-new/reset) + at the time they are started. + Changes made after a transitively related workflow has started are not propagated. + """ + max_elapsed_duration: Optional[str] = Field(None, alias="maxElapsedDuration") + """Maximum elapsed time since time skipping was enabled. + This includes both skipped time and real time elapsing. + """ + max_skipped_duration: Optional[str] = Field(None, alias="maxSkippedDuration") + """Maximum total virtual time that can be skipped.""" + + max_target_time: Optional[datetime] = Field(None, alias="maxTargetTime") + """Absolute virtual timestamp at which time skipping is disabled. + Time skipping will not advance beyond this point. + """ + + +class UserMetadata(BaseModel): + """Metadata on the workflow if it is started. This is carried over to the + WorkflowExecutionInfo + for use by user interfaces to display the fixed as-of-start summary and details of the + workflow. + + Information a user can set, often for use by user interfaces. + """ + + details: Any + """Long-form text that provides details. This payload should be a "json/plain"-encoded + payload + that is a single JSON string for use in user interfaces. User interface formatting may + apply to + this text in common use. The payload data section is limited to 20000 bytes by default. + """ + summary: Any + """Short-form text that provides a summary. This payload should be a "json/plain"-encoded + payload + that is a single JSON string for use in user interfaces. User interface formatting may + not + apply to this text when used in "title" situations. The payload data section is limited + to 400 + bytes by default. + """ + + +class VersioningOverrideBehavior(Enum): + """Required. + Deprecated. Use `override`. + """ + + VERSIONING_BEHAVIOR_AUTO_UPGRADE = "VERSIONING_BEHAVIOR_AUTO_UPGRADE" + VERSIONING_BEHAVIOR_PINNED = "VERSIONING_BEHAVIOR_PINNED" + VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" + + +class Deployment(BaseModel): + """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. + Identifies the worker deployment to pin the workflow to. + Deprecated. Use `override.pinned.version`. + + `Deployment` identifies a deployment of Temporal workers. The combination of deployment + series + name + build ID serves as the identifier. User can use `WorkerDeploymentOptions` in their + worker + programs to specify these values. + Deprecated. + """ + + build_id: Optional[str] = Field(None, alias="buildId") + """Build ID changes with each version of the worker when the worker program code and/or + config + changes. + """ + series_name: Optional[str] = Field(None, alias="seriesName") + """Different versions of the same worker service/application are related together by having + a + shared series name. + Out of all deployments of a series, one can be designated as the current deployment, + which + receives new workflow executions and new tasks of workflows with + `VERSIONING_BEHAVIOR_AUTO_UPGRADE` versioning behavior. + """ + + +class PinnedBehavior(Enum): + """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. + See `PinnedOverrideBehavior` for details. + """ + + PINNED_OVERRIDE_BEHAVIOR_PINNED = "PINNED_OVERRIDE_BEHAVIOR_PINNED" + PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" + + +class Version(BaseModel): + """Specifies the Worker Deployment Version to pin this workflow to. + Required if the target workflow is not already pinned to a version. + + If omitted and the target workflow is already pinned, the effective + pinned version will be the existing pinned version. + + If omitted and the target workflow is not pinned, the override request + will be rejected with a PreconditionFailed error. + + A Worker Deployment Version (Version, for short) represents a + version of workers within a Worker Deployment. (see documentation of + WorkerDeploymentVersionInfo) + Version records are created in Temporal server automatically when their + first poller arrives to the server. + Experimental. Worker Deployment Versions are experimental and might significantly change + in the future. + """ + + build_id: Optional[str] = Field(None, alias="buildId") + """A unique identifier for this Version within the Deployment it is a part of. + Not necessarily unique within the namespace. + The combination of `deployment_name` and `build_id` uniquely identifies this + Version within the namespace, because Deployment names are unique within a namespace. + """ + deployment_name: Optional[str] = Field(None, alias="deploymentName") + """Identifies the Worker Deployment this Version is part of.""" + + +class Pinned(BaseModel): + """Override the workflow to have Pinned behavior.""" + + behavior: Optional[PinnedBehavior] = None + """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. + See `PinnedOverrideBehavior` for details. + """ + version: Optional[Version] = None + """Specifies the Worker Deployment Version to pin this workflow to. + Required if the target workflow is not already pinned to a version. + + If omitted and the target workflow is already pinned, the effective + pinned version will be the existing pinned version. + + If omitted and the target workflow is not pinned, the override request + will be rejected with a PreconditionFailed error. + """ + + +class VersioningOverride(BaseModel): + """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task + completion. + To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. + + Used to override the versioning behavior (and pinned deployment version, if applicable) + of a + specific workflow execution. If set, this override takes precedence over worker-sent + values. + See `WorkflowExecutionInfo.VersioningInfo` for more information. + + To remove the override, call `UpdateWorkflowExecutionOptions` with a null + `VersioningOverride`, and use the `update_mask` to indicate that it should be mutated. + + Pinned behavior overrides are automatically inherited by child workflows, workflow + retries, continue-as-new + workflows, and cron workflows. + """ + + auto_upgrade: Optional[bool] = Field(None, alias="autoUpgrade") + """Override the workflow to have AutoUpgrade behavior.""" + + behavior: Optional[VersioningOverrideBehavior] = None + """Required. + Deprecated. Use `override`. + """ + deployment: Optional[Deployment] = None + """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. + Identifies the worker deployment to pin the workflow to. + Deprecated. Use `override.pinned.version`. + """ + pinned: Optional[Pinned] = None + """Override the workflow to have Pinned behavior.""" + + pinned_version: Optional[str] = Field(None, alias="pinnedVersion") + """Required if behavior is `PINNED`. Must be absent if behavior is not `PINNED`. + Identifies the worker deployment version to pin the workflow to, in the format + ".". + Deprecated. Use `override.pinned.version`. + """ + + +class WorkflowIDConflictPolicy(Enum): + """Defines how to resolve a workflow id conflict with a *running* workflow. + The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. + Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* + workflow. + """ + + WORKFLOW_ID_CONFLICT_POLICY_FAIL = "WORKFLOW_ID_CONFLICT_POLICY_FAIL" + WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING = ( + "WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING" + ) + WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED = "WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED" + WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING = ( + "WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING" + ) + + +class WorkflowIDReusePolicy(Enum): + """Defines whether to allow re-using the workflow id from a previously *closed* workflow. + The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* + workflow. + """ + + WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE = ( + "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE" + ) + WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY = ( + "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY" + ) + WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE = ( + "WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE" + ) + WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING = ( + "WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING" + ) + WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED = "WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED" + + +class WorkflowType(BaseModel): + """Represents the identifier used by a workflow author to define the workflow. Typically, + the + name of a function. This is sometimes referred to as the workflow's "name" + """ + + name: Optional[str] = None + + +class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): + control: Optional[str] = None + """Deprecated.""" + + cron_schedule: Optional[str] = Field(None, alias="cronSchedule") + """See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/""" + + header: Optional[Header] = None + identity: Optional[str] = None + """The identity of the worker/client""" + + input: Optional[Input] = None + """Serialized arguments to the workflow. These are passed as arguments to the workflow + function. + """ + links: Optional[List[Openapiv3]] = None + """Links to be associated with the WorkflowExecutionStarted and WorkflowExecutionSignaled + events. + """ + memo: Optional[Memo] = None + namespace: Optional[str] = None + priority: Optional[Priority] = None + """Priority metadata""" + + request_id: Optional[str] = Field(None, alias="requestId") + """Used to de-dupe signal w/ start requests""" + + retry_policy: Optional[RetryPolicy] = Field(None, alias="retryPolicy") + """Retry policy for the workflow""" + + search_attributes: Optional[SearchAttributes] = Field( + None, alias="searchAttributes" + ) + signal_input: Optional[Input] = Field(None, alias="signalInput") + """Serialized value(s) to provide with the signal""" + + signal_name: Optional[str] = Field(None, alias="signalName") + """The workflow author-defined name of the signal to send to the workflow""" + + task_queue: Optional[TaskQueue] = Field(None, alias="taskQueue") + """The task queue to start this workflow on, if it will be started""" + + time_skipping_config: Optional[TimeSkippingConfig] = Field( + None, alias="timeSkippingConfig" + ) + """Time-skipping configuration. If not set, time skipping is disabled.""" + + user_metadata: Optional[UserMetadata] = Field(None, alias="userMetadata") + """Metadata on the workflow if it is started. This is carried over to the + WorkflowExecutionInfo + for use by user interfaces to display the fixed as-of-start summary and details of the + workflow. + """ + versioning_override: Optional[VersioningOverride] = Field( + None, alias="versioningOverride" + ) + """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task + completion. + To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. + """ + workflow_execution_timeout: Optional[str] = Field( + None, alias="workflowExecutionTimeout" + ) + """Total workflow execution timeout including retries and continue as new""" + + workflow_id: Optional[str] = Field(None, alias="workflowId") + workflow_id_conflict_policy: Optional[WorkflowIDConflictPolicy] = Field( + None, alias="workflowIdConflictPolicy" + ) + """Defines how to resolve a workflow id conflict with a *running* workflow. + The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. + Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* + workflow. + """ + workflow_id_reuse_policy: Optional[WorkflowIDReusePolicy] = Field( + None, alias="workflowIdReusePolicy" + ) + """Defines whether to allow re-using the workflow id from a previously *closed* workflow. + The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* + workflow. + """ + workflow_run_timeout: Optional[str] = Field(None, alias="workflowRunTimeout") + """Timeout of a single workflow run""" + + workflow_start_delay: Optional[str] = Field(None, alias="workflowStartDelay") + """Time to wait before dispatching the first workflow task. Cannot be used with + `cron_schedule`. + Note that the signal will be delivered with the first workflow task. If the workflow + gets + another SignalWithStartWorkflow before the delay a workflow task will be dispatched + immediately + and the rest of the delay period will be ignored, even if that request also had a delay. + Signal via SignalWorkflowExecution will not unblock the workflow. + """ + workflow_task_timeout: Optional[str] = Field(None, alias="workflowTaskTimeout") + """Timeout of a single workflow task""" + + workflow_type: Optional[WorkflowType] = Field(None, alias="workflowType") + + +class WorkflowServiceSignalWithStartWorkflowExecutionOutput(BaseModel): + run_id: Optional[str] = Field(None, alias="runId") + """The run id of the workflow that was started - or just signaled, if it was already running.""" + + started: Optional[bool] = None + """If true, a new workflow was started.""" + + +async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input( + payload: temporalio.api.common.v1.Payload, + payload_codec: temporalio.converter.PayloadCodec | None, +) -> temporalio.api.common.v1.Payload: + if payload_codec is None: + return payload + try: + value = json.loads(payload.data) + except json.JSONDecodeError: + return payload + rewritten = await _temporal_nexus_encode_json_value(value, payload_codec) + return temporalio.api.common.v1.Payload( + metadata=dict(payload.metadata), + data=json.dumps(rewritten, separators=(",", ":"), sort_keys=True).encode(), + ) + + +__temporal_nexus_payload_codec_rewriters__ = { + ( + "WorkflowService", + "SignalWithStartWorkflowExecution", + ): _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input, +} + + +@service +class WorkflowService: + signal_with_start_workflow_execution: Operation[ + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + ] = Operation(name="SignalWithStartWorkflowExecution") diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 2d7f3990b..85c38ff06 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -31,6 +31,8 @@ class CommandInfo: command_type: CommandType.ValueType command_seq: int + nexus_service: str | None = None + nexus_operation: str | None = None current_command_info: contextvars.ContextVar[CommandInfo | None] = ( @@ -81,7 +83,12 @@ async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( self, fs: VisitorFunctions, o: ScheduleNexusOperation ) -> None: - with current_command(CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq): + with current_command( + CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, + o.seq, + nexus_service=o.service, + nexus_operation=o.operation, + ): await super()._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o) # Workflow activation jobs with payloads @@ -150,11 +157,20 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperation( @contextmanager def current_command( - command_type: CommandType.ValueType, command_seq: int + command_type: CommandType.ValueType, + command_seq: int, + *, + nexus_service: str | None = None, + nexus_operation: str | None = None, ) -> Iterator[None]: """Context manager for setting command info.""" token = current_command_info.set( - CommandInfo(command_type=command_type, command_seq=command_seq) + CommandInfo( + command_type=command_type, + command_seq=command_seq, + nexus_service=nexus_service, + nexus_operation=nexus_operation, + ) ) try: yield diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py new file mode 100644 index 000000000..f9dd6b7a9 --- /dev/null +++ b/tests/nexus/test_temporal_system_nexus.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import dataclasses +import json +import uuid +from collections.abc import Sequence +from typing import cast + +import nexusrpc.handler +import pytest +from google.protobuf.json_format import MessageToDict + +import temporalio.api.common.v1 +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.converter import PayloadCodec +from temporalio.nexus.system import ( + WorkflowService, + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, +) +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner +from tests.helpers.nexus import make_nexus_endpoint_name + + +@nexusrpc.handler.service_handler(service=WorkflowService) +class WorkflowServicePayloadHandler: + @nexusrpc.handler.sync_operation + async def signal_with_start_workflow_execution( + self, + _ctx: nexusrpc.handler.StartOperationContext, + request: WorkflowServiceSignalWithStartWorkflowExecutionInput, + ) -> WorkflowServiceSignalWithStartWorkflowExecutionOutput: + for field_name in ("input", "signalInput"): + payloads = request.model_dump(by_alias=True)[field_name]["payloads"] + assert "test-codec" in payloads[0]["metadata"] + for field_name in ("memo", "header"): + fields = request.model_dump(by_alias=True)[field_name]["fields"] + assert "test-codec" in next(iter(fields.values()))["metadata"] + return WorkflowServiceSignalWithStartWorkflowExecutionOutput( + runId=f"{request.workflow_id}-run" + ) + + +@workflow.defn +class SystemNexusCallerWithPayloadsWorkflow: + @workflow.run + async def run(self, task_queue: str) -> str: + nexus_client = workflow.create_nexus_client( + service=WorkflowService, + endpoint=make_nexus_endpoint_name(task_queue), + ) + request = WorkflowServiceSignalWithStartWorkflowExecutionInput.model_validate( + { + "namespace": "default", + "workflowId": "system-nexus-workflow-id", + "signalName": "test-signal", + "input": MessageToDict( + temporalio.api.common.v1.Payloads( + payloads=[ + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"workflow-input"', + ) + ] + ) + ), + "signalInput": MessageToDict( + temporalio.api.common.v1.Payloads( + payloads=[ + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"signal-input"', + ) + ] + ) + ), + "memo": MessageToDict( + temporalio.api.common.v1.Memo( + fields={ + "memo-key": temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"memo-value"', + ) + } + ) + ), + "header": MessageToDict( + temporalio.api.common.v1.Header( + fields={ + "header-key": temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"header-value"', + ) + } + ) + ), + } + ) + handle = await nexus_client.start_operation( + WorkflowService.signal_with_start_workflow_execution, + request, + ) + result = await handle + return cast(str, result.run_id) + + +class RejectOuterSystemNexusCodec(PayloadCodec): + def __init__(self) -> None: + self.encode_count = 0 + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + encoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + try: + body = json.loads(payload.data) + except json.JSONDecodeError: + body = None + if isinstance(body, dict) and { + "namespace", + "workflowId", + "signalName", + }.issubset(body): + raise RuntimeError( + "outer system nexus envelope should not be codec encoded" + ) + self.encode_count += 1 + encoded.append( + temporalio.api.common.v1.Payload( + metadata={**payload.metadata, "test-codec": b"true"}, + data=payload.data, + ) + ) + return encoded + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + return list(payloads) + + +async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( + env: WorkflowEnvironment, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + config = env.client.config() + config["data_converter"] = dataclasses.replace( + pydantic_data_converter, + payload_codec=codec, + ) + client = Client(**config) + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[SystemNexusCallerWithPayloadsWorkflow], + nexus_service_handlers=[WorkflowServicePayloadHandler()], + workflow_runner=UnsandboxedWorkflowRunner(), + ) as worker: + endpoint_name = make_nexus_endpoint_name(worker.task_queue) + await env.create_nexus_endpoint(endpoint_name, worker.task_queue) + result = await client.execute_workflow( + SystemNexusCallerWithPayloadsWorkflow.run, + worker.task_queue, + id=str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + + assert result == "system-nexus-workflow-id-run" + assert codec.encode_count >= 4 From f934a822d0f111333db7cf97c2c7ea6c591a092e Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 12:31:09 -0700 Subject: [PATCH 2/7] Make Nexus system model generation repo-local --- README.md | 7 ++ pyproject.toml | 10 +- scripts/gen_nexus_system_models.py | 118 ++++++++++++++++++ scripts/gen_nexus_system_test_models.py | 75 ----------- .../system/_workflow_service.nexusrpc.yaml | 9 ++ .../system/_workflow_service_generated.py | 45 ------- 6 files changed, 139 insertions(+), 125 deletions(-) create mode 100644 scripts/gen_nexus_system_models.py delete mode 100644 scripts/gen_nexus_system_test_models.py create mode 100644 temporalio/nexus/system/_workflow_service.nexusrpc.yaml diff --git a/README.md b/README.md index 6e42a2019..d1e162ab9 100644 --- a/README.md +++ b/README.md @@ -1933,6 +1933,8 @@ To build the SDK from source for use as a dependency, the following prerequisite * [uv](https://docs.astral.sh/uv/) * [Rust](https://www.rust-lang.org/) * [Protobuf Compiler](https://protobuf.dev/) +* [Node.js](https://nodejs.org/) +* [`pnpm`](https://pnpm.io/) Use `uv` to install `poe`: @@ -2074,6 +2076,11 @@ back from this downgrade, restore both of those files and run `uv sync --all-ext run for protobuf version 3 by setting the `TEMPORAL_TEST_PROTO3` env var to `1` prior to running tests. +The local build and lint flows also regenerate Temporal system Nexus models. By default this pulls +in `nexus-rpc-gen@0.1.0-alpha.4` via `npx`. To use an existing checkout instead, set +`TEMPORAL_NEXUS_RPC_GEN_DIR` to the `nexus-rpc-gen` repo root or its `src` directory before +running `poe build-develop`, `poe lint`, or `poe gen-protos`. + ### Style * Mostly [Google Style Guide](https://google.github.io/styleguide/pyguide.html). Notable exceptions: diff --git a/pyproject.toml b/pyproject.toml index 3c810ccc9..440ed4195 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,11 +68,11 @@ dev = [ [tool.poe.tasks] build-develop = [ - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { cmd = "uv run maturin develop --uv" }, ] build-develop-with-release = [ - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { cmd = "uv run maturin develop --release --uv" }, ] format = [ @@ -85,7 +85,7 @@ gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] gen-protos-docker = [ @@ -105,12 +105,12 @@ bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" } # https://github.com/PyCQA/pydocstyle/pull/511? lint-docs = "uv run pydocstyle --ignore-decorators=overload" lint-types = [ - { ref = "gen-nexus-system-test-models" }, + { ref = "gen-nexus-system-models" }, { cmd = "uv run pyright" }, { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, ] -gen-nexus-system-test-models = "uv run scripts/gen_nexus_system_test_models.py" +gen-nexus-system-models = "uv run scripts/gen_nexus_system_models.py" run-bench = "uv run python scripts/run_bench.py" test = "uv run pytest" diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py new file mode 100644 index 000000000..d90ebdf8f --- /dev/null +++ b/scripts/gen_nexus_system_models.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + +NEXUS_RPC_GEN_ENV_VAR = "TEMPORAL_NEXUS_RPC_GEN_DIR" +NEXUS_RPC_GEN_VERSION = "0.1.0-alpha.4" + + +def main() -> None: + repo_root = Path(__file__).resolve().parent.parent + override_root = normalize_nexus_rpc_gen_root( + Path.cwd(), env_value=NEXUS_RPC_GEN_ENV_VAR + ) + input_schema = ( + repo_root + / "temporalio" + / "nexus" + / "system" + / "_workflow_service.nexusrpc.yaml" + ) + output_file = ( + repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + + if not input_schema.is_file(): + raise RuntimeError(f"Expected Nexus schema at {input_schema}") + + run_nexus_rpc_gen( + override_root=override_root, + output_file=output_file, + input_schema=input_schema, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + + +def run_nexus_rpc_gen( + *, override_root: Path | None, output_file: Path, input_schema: Path +) -> None: + common_args = [ + "--lang", + "py", + "--out-file", + str(output_file), + "--temporal-nexus-payload-codec-support", + str(input_schema), + ] + if override_root is None: + subprocess.run( + ["npx", "--yes", f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}", *common_args], + check=True, + ) + return + + subprocess.run( + [ + "node", + "packages/nexus-rpc-gen/dist/index.js", + *common_args, + ], + cwd=override_root, + check=True, + ) + + +def normalize_nexus_rpc_gen_root(base_dir: Path, env_value: str) -> Path | None: + raw_root = env_get(env_value) + if raw_root is None: + return None + candidate = Path(raw_root) + if not candidate.is_absolute(): + candidate = base_dir / candidate + candidate = candidate.resolve() + if (candidate / "package.json").is_file() and (candidate / "packages").is_dir(): + return candidate + if (candidate / "src" / "package.json").is_file(): + return candidate / "src" + raise RuntimeError( + f"{NEXUS_RPC_GEN_ENV_VAR} must point to the nexus-rpc-gen repo root or its src directory" + ) + + +def env_get(name: str) -> str | None: + return os.environ.get(name) + + +if __name__ == "__main__": + try: + main() + except Exception as err: + print(f"Failed to generate Nexus system models: {err}", file=sys.stderr) + raise diff --git a/scripts/gen_nexus_system_test_models.py b/scripts/gen_nexus_system_test_models.py deleted file mode 100644 index d707727a5..000000000 --- a/scripts/gen_nexus_system_test_models.py +++ /dev/null @@ -1,75 +0,0 @@ -from __future__ import annotations - -import subprocess -import sys -from pathlib import Path - - -def main() -> None: - repo_root = Path(__file__).resolve().parent.parent - workspace_root = repo_root.parent - nexus_rpc_gen_root = workspace_root / "nexus-rpc-gen" / "src" - input_schema = ( - workspace_root - / "temporal-api" - / "nexus" - / "temporal-json-schema-models-nexusrpc.yaml" - ) - output_file = ( - repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" - ) - - if not nexus_rpc_gen_root.is_dir(): - raise RuntimeError(f"Expected nexus-rpc-gen checkout at {nexus_rpc_gen_root}") - if not input_schema.is_file(): - raise RuntimeError(f"Expected Temporal Nexus schema at {input_schema}") - - subprocess.run( - [ - "npm", - "run", - "cli", - "--", - "--lang", - "py", - "--out-file", - str(output_file), - "--temporal-nexus-payload-codec-support", - str(input_schema), - ], - cwd=nexus_rpc_gen_root, - check=True, - ) - subprocess.run( - [ - "uv", - "run", - "ruff", - "check", - "--select", - "I", - "--fix", - str(output_file), - ], - cwd=repo_root, - check=True, - ) - subprocess.run( - [ - "uv", - "run", - "ruff", - "format", - str(output_file), - ], - cwd=repo_root, - check=True, - ) - - -if __name__ == "__main__": - try: - main() - except Exception as err: - print(f"Failed to generate Nexus system test models: {err}", file=sys.stderr) - raise diff --git a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml new file mode 100644 index 000000000..c5b8cc671 --- /dev/null +++ b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml @@ -0,0 +1,9 @@ +nexusrpc: 1.0.0 +services: + WorkflowService: + operations: + SignalWithStartWorkflowExecution: + input: + $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionRequest + output: + $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionResponse diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 1723631a1..085dad9b1 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional @@ -202,9 +201,6 @@ class EventType(Enum): "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" ) EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" - EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED = ( - "EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED" - ) EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" @@ -424,42 +420,6 @@ class TaskQueue(BaseModel): """ -class TimeSkippingConfig(BaseModel): - """Time-skipping configuration. If not set, time skipping is disabled. - - Configuration for time skipping during a workflow execution. - When enabled, virtual time advances automatically whenever there is no in-flight work. - In-flight work includes activities, child workflows, Nexus operations, signal/cancel - external workflow operations, - and possibly other features added in the future. - User timers are not classified as in-flight work and will be skipped over. - When time advances, it skips to the earlier of the next user timer or the configured - bound, if either exists. - """ - - disable_propagation: Optional[bool] = Field(None, alias="disablePropagation") - """If set, the enabled field is not propagated to transitively related workflows.""" - - enabled: Optional[bool] = None - """Enables or disables time skipping for this workflow execution. - By default, this field is propagated to transitively related workflows (child - workflows/start-as-new/reset) - at the time they are started. - Changes made after a transitively related workflow has started are not propagated. - """ - max_elapsed_duration: Optional[str] = Field(None, alias="maxElapsedDuration") - """Maximum elapsed time since time skipping was enabled. - This includes both skipped time and real time elapsing. - """ - max_skipped_duration: Optional[str] = Field(None, alias="maxSkippedDuration") - """Maximum total virtual time that can be skipped.""" - - max_target_time: Optional[datetime] = Field(None, alias="maxTargetTime") - """Absolute virtual timestamp at which time skipping is disabled. - Time skipping will not advance beyond this point. - """ - - class UserMetadata(BaseModel): """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo @@ -718,11 +678,6 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): task_queue: Optional[TaskQueue] = Field(None, alias="taskQueue") """The task queue to start this workflow on, if it will be started""" - time_skipping_config: Optional[TimeSkippingConfig] = Field( - None, alias="timeSkippingConfig" - ) - """Time-skipping configuration. If not set, time skipping is disabled.""" - user_metadata: Optional[UserMetadata] = Field(None, alias="userMetadata") """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo From 9d4ba7831199963913fd851004e4ae637d278418 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 12:33:03 -0700 Subject: [PATCH 3/7] Clarify Nexus generator fallback requirements --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d1e162ab9..736f04e8a 100644 --- a/README.md +++ b/README.md @@ -1934,7 +1934,6 @@ To build the SDK from source for use as a dependency, the following prerequisite * [Rust](https://www.rust-lang.org/) * [Protobuf Compiler](https://protobuf.dev/) * [Node.js](https://nodejs.org/) -* [`pnpm`](https://pnpm.io/) Use `uv` to install `poe`: @@ -2079,7 +2078,8 @@ tests. The local build and lint flows also regenerate Temporal system Nexus models. By default this pulls in `nexus-rpc-gen@0.1.0-alpha.4` via `npx`. To use an existing checkout instead, set `TEMPORAL_NEXUS_RPC_GEN_DIR` to the `nexus-rpc-gen` repo root or its `src` directory before -running `poe build-develop`, `poe lint`, or `poe gen-protos`. +running `poe build-develop`, `poe lint`, or `poe gen-protos`. The local checkout override path +also requires [`pnpm`](https://pnpm.io/) to be installed. ### Style From 7960be50e4daf0b2e3313af67b32dd052ed38a31 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 8 Apr 2026 13:54:13 -0700 Subject: [PATCH 4/7] Refine Nexus system payload rewriting --- pyproject.toml | 3 +- scripts/gen_nexus_system_models.py | 2 + temporalio/bridge/worker.py | 78 +++---- temporalio/nexus/system/__init__.py | 23 +- .../system/_workflow_service.nexusrpc.yaml | 2 + .../system/_workflow_service_generated.py | 204 +++++++++++------- tests/nexus/test_temporal_system_nexus.py | 66 +++++- 7 files changed, 251 insertions(+), 127 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 440ed4195..a1f5a6661 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ environment = { PATH = "$PATH:$HOME/.cargo/bin", CARGO_NET_GIT_FETCH_WITH_CLI = ignore_missing_imports = true exclude = [ # Ignore generated code + 'build', 'temporalio/api', 'temporalio/bridge/proto', 'temporalio/nexus/system/_workflow_service_generated.py', @@ -156,7 +157,7 @@ exclude = [ [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 -match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match_dir = "^(?!(build|docs|scripts|tests|api|proto|\\.)).*" match = "^(?!_workflow_service_generated\\.py$).*\\.py" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py index d90ebdf8f..2009a43a7 100644 --- a/scripts/gen_nexus_system_models.py +++ b/scripts/gen_nexus_system_models.py @@ -11,6 +11,8 @@ def main() -> None: repo_root = Path(__file__).resolve().parent.parent + # TODO: Remove the local .nexusrpc.yaml shim once the upstream API repo + # checks in the Nexus definition we can consume directly. override_root = normalize_nexus_rpc_gen_root( Path.cwd(), env_value=NEXUS_RPC_GEN_ENV_VAR ) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 52a54ce9c..587bbaa12 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -285,46 +285,13 @@ class _Visitor(VisitorFunctions): def __init__( self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], - payload_codec: temporalio.converter.PayloadCodec | None = None, ): self._f = f - self._payload_codec = payload_codec async def visit_payload(self, payload: Payload) -> None: - if self._payload_codec: - rewritten_payload = await self._maybe_rewrite_nexus_payload(payload) - if rewritten_payload is not None: - if rewritten_payload is not payload: - payload.CopyFrom(rewritten_payload) - return - new_payload = (await self._f([payload]))[0] - if new_payload is not payload: - payload.CopyFrom(new_payload) - - async def _maybe_rewrite_nexus_payload(self, payload: Payload) -> Payload | None: - command_info = _command_aware_visitor.current_command_info.get() - if ( - command_info is None - or command_info.command_type - != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION - or not command_info.nexus_service - or not command_info.nexus_operation - ): - return None - - rewrite = temporalio.nexus.system.get_payload_codec_rewriter( - command_info.nexus_service, - command_info.nexus_operation, - ) - if rewrite is None: - return None - - rewritten_payload = await rewrite(payload, self._payload_codec) - if not isinstance(rewritten_payload, Payload): - raise TypeError( - "temporal nexus payload codec rewriter must return a Payload" - ) - return rewritten_payload + rewritten_payload = (await self._f([payload]))[0] + if rewritten_payload is not payload: + payload.CopyFrom(rewritten_payload) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: @@ -336,6 +303,43 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: payloads.extend(new_payloads) +async def _encode_completion_payloads( + data_converter: temporalio.converter.DataConverter, + payloads: Sequence[Payload], +) -> list[Payload]: + if len(payloads) != 1: + return await data_converter._encode_payload_sequence(payloads) + + # A single payload may be the outer envelope for a system Nexus operation. + # In that case we leave the envelope itself unencoded so the server can read + # it, but still route any nested Temporal payloads through normal payload + # processing via the generated operation-specific rewriter. + payload = payloads[0] + command_info = _command_aware_visitor.current_command_info.get() + if ( + command_info is None + or command_info.command_type + != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION + or not command_info.nexus_service + or not command_info.nexus_operation + ): + return await data_converter._encode_payload_sequence(payloads) + + rewrite = temporalio.nexus.system.get_payload_rewriter( + command_info.nexus_service, + command_info.nexus_operation, + ) + if rewrite is None: + return await data_converter._encode_payload_sequence(payloads) + + rewritten_payload = await rewrite( + payload, + data_converter._encode_payload_sequence, + False, + ) + return [rewritten_payload] + + async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, data_converter: temporalio.converter.DataConverter, @@ -356,6 +360,6 @@ async def encode_completion( await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers ).visit( - _Visitor(data_converter._encode_payload_sequence, data_converter.payload_codec), + _Visitor(lambda payloads: _encode_completion_payloads(data_converter, payloads)), completion, ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 7eff3a93d..52a80cda1 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -4,38 +4,41 @@ Higher-level ergonomic APIs may wrap these generated types. """ -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence import temporalio.api.common.v1 -import temporalio.converter from ._workflow_service_generated import ( WorkflowService, WorkflowServiceSignalWithStartWorkflowExecutionInput, WorkflowServiceSignalWithStartWorkflowExecutionOutput, - __temporal_nexus_payload_codec_rewriters__, + __temporal_nexus_payload_rewriters__, ) -TemporalNexusPayloadCodecRewriter = Callable[ +TemporalNexusPayloadRewriter = Callable[ [ temporalio.api.common.v1.Payload, - temporalio.converter.PayloadCodec | None, + Callable[ + [Sequence[temporalio.api.common.v1.Payload]], + Awaitable[list[temporalio.api.common.v1.Payload]], + ], + bool, ], Awaitable[temporalio.api.common.v1.Payload], ] -def get_payload_codec_rewriter( +def get_payload_rewriter( service: str, operation: str, -) -> TemporalNexusPayloadCodecRewriter | None: - """Return the generated payload codec rewriter for a system Nexus operation.""" - return __temporal_nexus_payload_codec_rewriters__.get((service, operation)) +) -> TemporalNexusPayloadRewriter | None: + """Return the generated nested-payload rewriter for a system Nexus operation.""" + return __temporal_nexus_payload_rewriters__.get((service, operation)) __all__ = ( "WorkflowService", "WorkflowServiceSignalWithStartWorkflowExecutionInput", "WorkflowServiceSignalWithStartWorkflowExecutionOutput", - "get_payload_codec_rewriter", + "get_payload_rewriter", ) diff --git a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml index c5b8cc671..edea24b4e 100644 --- a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml +++ b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml @@ -1,3 +1,5 @@ +# TODO: Remove this local shim once the upstream API repo checks in the Nexus +# definition and the generator can consume it directly. nexusrpc: 1.0.0 services: WorkflowService: diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index 085dad9b1..df330d8de 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -2,6 +2,7 @@ from __future__ import annotations +import collections.abc import json from enum import Enum from typing import Any, Dict, List, Optional @@ -11,66 +12,6 @@ from pydantic import BaseModel, Field import temporalio.api.common.v1 -import temporalio.converter - - -async def _temporal_nexus_encode_payload_json( - value: dict, payload_codec: temporalio.converter.PayloadCodec -) -> dict: - payload = ParseDict(value, temporalio.api.common.v1.Payload()) - [encoded_payload] = await payload_codec.encode([payload]) - return MessageToDict(encoded_payload) - - -async def _temporal_nexus_encode_payloads_json( - value: dict, payload_codec: temporalio.converter.PayloadCodec -) -> dict: - payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) - encoded_payloads = await payload_codec.encode(payloads.payloads) - del payloads.payloads[:] - payloads.payloads.extend(encoded_payloads) - return MessageToDict(payloads) - - -async def _temporal_nexus_encode_payload_map_json( - message_type: type, value: dict, payload_codec: temporalio.converter.PayloadCodec -) -> dict: - message = ParseDict(value, message_type()) - keys = list(message.fields.keys()) - encoded_payloads = await payload_codec.encode([message.fields[key] for key in keys]) - for key, encoded_payload in zip(keys, encoded_payloads): - message.fields[key].CopyFrom(encoded_payload) - return MessageToDict(message) - - -async def _temporal_nexus_encode_json_value( - value: object, payload_codec: temporalio.converter.PayloadCodec -) -> object: - if isinstance(value, list): - return [ - await _temporal_nexus_encode_json_value(item, payload_codec) - for item in value - ] - if not isinstance(value, dict): - return value - if "indexedFields" in value: - return value - if "payloads" in value and isinstance(value["payloads"], list): - return await _temporal_nexus_encode_payloads_json(value, payload_codec) - if "fields" in value and isinstance(value["fields"], dict): - return await _temporal_nexus_encode_payload_map_json( - temporalio.api.common.v1.Header, value, payload_codec - ) - if "data" in value and "metadata" in value: - return await _temporal_nexus_encode_payload_json(value, payload_codec) - rewritten: dict[str, object] = {} - for key, item in value.items(): - rewritten[key] = ( - item - if key == "indexedFields" - else await _temporal_nexus_encode_json_value(item, payload_codec) - ) - return rewritten class Header(BaseModel): @@ -743,34 +684,151 @@ class WorkflowServiceSignalWithStartWorkflowExecutionOutput(BaseModel): """If true, a new workflow was started.""" +@service +class WorkflowService: + signal_with_start_workflow_execution: Operation[ + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + ] = Operation(name="SignalWithStartWorkflowExecution") + + +class _TemporalNexusPayloadRewriter: + def __init__( + self, + payload_visitor: collections.abc.Callable[ + [collections.abc.Sequence[temporalio.api.common.v1.Payload]], + collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], + ], + visit_search_attributes: bool = False, + ): + self._payload_visitor = payload_visitor + self._visit_search_attributes = visit_search_attributes + + async def _rewrite_payload_json(self, value: dict) -> dict: + payload = ParseDict(value, temporalio.api.common.v1.Payload()) + [rewritten_payload] = await self._payload_visitor([payload]) + return MessageToDict(rewritten_payload) + + async def _rewrite_payloads_json(self, value: dict) -> dict: + payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) + rewritten_payloads = await self._payload_visitor(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(rewritten_payloads) + return MessageToDict(payloads) + + async def _rewrite_payload_map_json(self, message_type: type, value: dict) -> dict: + message = message_type() + keys = list(value.keys()) + rewritten_payloads = await self._payload_visitor( + [ParseDict(value[key], temporalio.api.common.v1.Payload()) for key in keys] + ) + for key, rewritten_payload in zip(keys, rewritten_payloads): + message.fields[key].CopyFrom(rewritten_payload) + return MessageToDict(message).get("fields", {}) + + async def _temporal_nexus_rewrite_header_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("fields") is not None: + rewritten["fields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.Header, rewritten["fields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_input_json(self, value: dict) -> dict: + return await self._rewrite_payloads_json(value) + + async def _temporal_nexus_rewrite_memo_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("fields") is not None: + rewritten["fields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.Memo, rewritten["fields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_search_attributes_json(self, value: dict) -> dict: + if not self._visit_search_attributes: + return value + rewritten = dict(value) + if rewritten.get("indexedFields") is not None: + rewritten["indexedFields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.SearchAttributes, rewritten["indexedFields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_user_metadata_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("details") is not None: + rewritten["details"] = await self._rewrite_payload_json( + rewritten["details"] + ) + if rewritten.get("summary") is not None: + rewritten["summary"] = await self._rewrite_payload_json( + rewritten["summary"] + ) + return rewritten + + async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + self, value: dict + ) -> dict: + rewritten = dict(value) + if rewritten.get("header") is not None: + rewritten["header"] = await self._temporal_nexus_rewrite_header_json( + rewritten["header"] + ) + if rewritten.get("input") is not None: + rewritten["input"] = await self._temporal_nexus_rewrite_input_json( + rewritten["input"] + ) + if rewritten.get("memo") is not None: + rewritten["memo"] = await self._temporal_nexus_rewrite_memo_json( + rewritten["memo"] + ) + if rewritten.get("searchAttributes") is not None: + rewritten[ + "searchAttributes" + ] = await self._temporal_nexus_rewrite_search_attributes_json( + rewritten["searchAttributes"] + ) + if rewritten.get("signalInput") is not None: + rewritten["signalInput"] = await self._temporal_nexus_rewrite_input_json( + rewritten["signalInput"] + ) + if rewritten.get("userMetadata") is not None: + rewritten[ + "userMetadata" + ] = await self._temporal_nexus_rewrite_user_metadata_json( + rewritten["userMetadata"] + ) + return rewritten + + async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input( payload: temporalio.api.common.v1.Payload, - payload_codec: temporalio.converter.PayloadCodec | None, + payload_visitor: collections.abc.Callable[ + [collections.abc.Sequence[temporalio.api.common.v1.Payload]], + collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], + ], + visit_search_attributes: bool = False, ) -> temporalio.api.common.v1.Payload: - if payload_codec is None: - return payload try: value = json.loads(payload.data) except json.JSONDecodeError: return payload - rewritten = await _temporal_nexus_encode_json_value(value, payload_codec) + if not isinstance(value, dict): + return payload + rewriter = _TemporalNexusPayloadRewriter(payload_visitor, visit_search_attributes) + rewritten = await rewriter._temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + value + ) return temporalio.api.common.v1.Payload( metadata=dict(payload.metadata), data=json.dumps(rewritten, separators=(",", ":"), sort_keys=True).encode(), ) -__temporal_nexus_payload_codec_rewriters__ = { +__temporal_nexus_payload_rewriters__ = { ( "WorkflowService", "SignalWithStartWorkflowExecution", ): _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input, } - - -@service -class WorkflowService: - signal_with_start_workflow_execution: Operation[ - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, - ] = Operation(name="SignalWithStartWorkflowExecution") diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index f9dd6b7a9..ecaf300da 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -14,7 +14,7 @@ from temporalio import workflow from temporalio.client import Client from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.converter import PayloadCodec +from temporalio.converter import ExternalStorage, PayloadCodec from temporalio.nexus.system import ( WorkflowService, WorkflowServiceSignalWithStartWorkflowExecutionInput, @@ -24,6 +24,7 @@ from temporalio.worker import Worker from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner from tests.helpers.nexus import make_nexus_endpoint_name +from tests.test_extstore import InMemoryTestDriver @nexusrpc.handler.service_handler(service=WorkflowService) @@ -34,12 +35,21 @@ async def signal_with_start_workflow_execution( _ctx: nexusrpc.handler.StartOperationContext, request: WorkflowServiceSignalWithStartWorkflowExecutionInput, ) -> WorkflowServiceSignalWithStartWorkflowExecutionOutput: + request_dict = request.model_dump(by_alias=True) for field_name in ("input", "signalInput"): - payloads = request.model_dump(by_alias=True)[field_name]["payloads"] - assert "test-codec" in payloads[0]["metadata"] + payloads = request_dict[field_name]["payloads"] + assert payloads[0]["externalPayloads"] for field_name in ("memo", "header"): - fields = request.model_dump(by_alias=True)[field_name]["fields"] - assert "test-codec" in next(iter(fields.values()))["metadata"] + fields = request_dict[field_name]["fields"] + assert next(iter(fields.values()))["externalPayloads"] + for field_name in ("summary", "details"): + payload = request_dict["userMetadata"][field_name] + assert payload["externalPayloads"] + search_attribute_payload = request_dict["searchAttributes"]["indexedFields"][ + "custom-key" + ] + assert "externalPayloads" not in search_attribute_payload + assert "test-codec" not in search_attribute_payload["metadata"] return WorkflowServiceSignalWithStartWorkflowExecutionOutput( runId=f"{request.workflow_id}-run" ) @@ -98,6 +108,30 @@ async def run(self, task_queue: str) -> str: } ) ), + "userMetadata": { + "summary": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"summary-value"', + ) + ), + "details": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"details-value"', + ) + ), + }, + "searchAttributes": { + "indexedFields": { + "custom-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"search-attribute-value"', + ) + ) + } + }, } ) handle = await nexus_client.start_operation( @@ -151,10 +185,15 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with pytest.skip("Nexus tests don't work with the Java test server") codec = RejectOuterSystemNexusCodec() + driver = InMemoryTestDriver() config = env.client.config() config["data_converter"] = dataclasses.replace( pydantic_data_converter, payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), ) client = Client(**config) @@ -175,4 +214,19 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with ) assert result == "system-nexus-workflow-id-run" - assert codec.encode_count >= 4 + assert codec.encode_count >= 6 + stored_payloads: list[temporalio.api.common.v1.Payload] = [] + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + stored_payloads.append(stored_payload) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data = {payload.data for payload in stored_payloads} + assert { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"header-value"', + b'"summary-value"', + b'"details-value"', + }.issubset(stored_payload_data) From e0523210ab594767cacd0181a12ab65e05f17cba Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 9 Apr 2026 09:37:12 -0700 Subject: [PATCH 5/7] Use generated system nexus module directly --- pyproject.toml | 12 +- temporalio/bridge/worker.py | 7 +- temporalio/nexus/system/__init__.py | 27 ++- .../system/_workflow_service_generated.py | 164 +++++++------ temporalio/worker/_workflow_instance.py | 15 +- tests/nexus/test_temporal_system_nexus.py | 229 +++++++++++------- 6 files changed, 262 insertions(+), 192 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a1f5a6661..e65b0ed54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,14 +67,8 @@ dev = [ ] [tool.poe.tasks] -build-develop = [ - { ref = "gen-nexus-system-models" }, - { cmd = "uv run maturin develop --uv" }, -] -build-develop-with-release = [ - { ref = "gen-nexus-system-models" }, - { cmd = "uv run maturin develop --release --uv" }, -] +build-develop = "uv run maturin develop --uv" +build-develop-with-release = { cmd = "uv run maturin develop --release --uv" } format = [ { cmd = "uv run ruff check --select I --fix" }, { cmd = "uv run ruff format" }, @@ -92,6 +86,7 @@ gen-protos-docker = [ { cmd = "uv run scripts/gen_protos_docker.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] lint = [ @@ -105,7 +100,6 @@ bridge-lint = { cmd = "cargo clippy -- -D warnings", cwd = "temporalio/bridge" } # https://github.com/PyCQA/pydocstyle/pull/511? lint-docs = "uv run pydocstyle --ignore-decorators=overload" lint-types = [ - { ref = "gen-nexus-system-models" }, { cmd = "uv run pyright" }, { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 587bbaa12..8539d695d 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -326,8 +326,7 @@ async def _encode_completion_payloads( return await data_converter._encode_payload_sequence(payloads) rewrite = temporalio.nexus.system.get_payload_rewriter( - command_info.nexus_service, - command_info.nexus_operation, + command_info.nexus_service, command_info.nexus_operation ) if rewrite is None: return await data_converter._encode_payload_sequence(payloads) @@ -360,6 +359,8 @@ async def encode_completion( await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers ).visit( - _Visitor(lambda payloads: _encode_completion_payloads(data_converter, payloads)), + _Visitor( + lambda payloads: _encode_completion_payloads(data_converter, payloads) + ), completion, ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 52a80cda1..2f4dc7926 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -7,13 +7,10 @@ from collections.abc import Awaitable, Callable, Sequence import temporalio.api.common.v1 +import temporalio.converter -from ._workflow_service_generated import ( - WorkflowService, - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, - __temporal_nexus_payload_rewriters__, -) +from . import _workflow_service_generated as generated +from ._workflow_service_generated import __temporal_nexus_payload_rewriters__ TemporalNexusPayloadRewriter = Callable[ [ @@ -27,6 +24,8 @@ Awaitable[temporalio.api.common.v1.Payload], ] +_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.JSONPlainPayloadConverter() + def get_payload_rewriter( service: str, @@ -36,9 +35,19 @@ def get_payload_rewriter( return __temporal_nexus_payload_rewriters__.get((service, operation)) +def is_system_operation(service: str, operation: str) -> bool: + """Return whether a Nexus operation uses the generated system envelope.""" + return get_payload_rewriter(service, operation) is not None + + +def get_payload_converter() -> temporalio.converter.EncodingPayloadConverter: + """Return the fixed payload converter for system Nexus outer envelopes.""" + return _SYSTEM_NEXUS_PAYLOAD_CONVERTER + + __all__ = ( - "WorkflowService", - "WorkflowServiceSignalWithStartWorkflowExecutionInput", - "WorkflowServiceSignalWithStartWorkflowExecutionOutput", + "generated", + "get_payload_converter", "get_payload_rewriter", + "is_system_operation", ) diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index df330d8de..e2273b871 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -4,17 +4,18 @@ import collections.abc import json +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional from google.protobuf.json_format import MessageToDict, ParseDict from nexusrpc import Operation, service -from pydantic import BaseModel, Field import temporalio.api.common.v1 -class Header(BaseModel): +@dataclass +class Header: """Contains metadata that can be attached to a variety of requests, like starting a workflow, and can be propagated between, for example, workflows and activities. @@ -23,7 +24,8 @@ class Header(BaseModel): fields: Optional[Dict[str, Any]] = None -class Input(BaseModel): +@dataclass +class Input: """Serialized arguments to the workflow. These are passed as arguments to the workflow function. @@ -35,14 +37,15 @@ class Input(BaseModel): payloads: Optional[List[Any]] = None -class BatchJob(BaseModel): +@dataclass +class BatchJob: """A link to a built-in batch job. Batch jobs can be used to perform operations on a set of workflows (e.g. terminate, signal, cancel, etc). This link can be put on workflow history events generated by actions taken by a batch job. """ - job_id: Optional[str] = Field(None, alias="jobId") + jobId: Optional[str] = None class EventType(Enum): @@ -166,29 +169,33 @@ class EventType(Enum): EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT = "EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT" -class EventRef(BaseModel): +@dataclass +class EventRef: """EventReference is a direct reference to a history event through the event ID.""" - event_id: Optional[str] = Field(None, alias="eventId") - event_type: Optional[EventType] = Field(None, alias="eventType") + eventId: Optional[str] = None + eventType: Optional[EventType] = None -class RequestIDRef(BaseModel): +@dataclass +class RequestIDRef: """RequestIdReference is a indirect reference to a history event through the request ID.""" - event_type: Optional[EventType] = Field(None, alias="eventType") - request_id: Optional[str] = Field(None, alias="requestId") + eventType: Optional[EventType] = None + requestId: Optional[str] = None -class WorkflowEvent(BaseModel): - event_ref: Optional[EventRef] = Field(None, alias="eventRef") +@dataclass +class WorkflowEvent: + eventRef: Optional[EventRef] = None namespace: Optional[str] = None - request_id_ref: Optional[RequestIDRef] = Field(None, alias="requestIdRef") - run_id: Optional[str] = Field(None, alias="runId") - workflow_id: Optional[str] = Field(None, alias="workflowId") + requestIdRef: Optional[RequestIDRef] = None + runId: Optional[str] = None + workflowId: Optional[str] = None -class Openapiv3(BaseModel): +@dataclass +class Openapiv3: """Link can be associated with history events. It might contain information about an external entity related to the history event. For example, workflow A makes a Nexus call that starts @@ -198,17 +205,19 @@ class Openapiv3(BaseModel): workflow B, and vice-versa. """ - batch_job: Optional[BatchJob] = Field(None, alias="batchJob") - workflow_event: Optional[WorkflowEvent] = Field(None, alias="workflowEvent") + batchJob: Optional[BatchJob] = None + workflowEvent: Optional[WorkflowEvent] = None -class Memo(BaseModel): +@dataclass +class Memo: """A user-defined set of *unindexed* fields that are exposed when listing/searching workflows""" fields: Optional[Dict[str, Any]] = None -class Priority(BaseModel): +@dataclass +class Priority: """Priority metadata Priority contains metadata that controls relative ordering of task processing @@ -245,7 +254,7 @@ class Priority(BaseModel): fields. (Currently only support in matching task queues is planned.) """ - fairness_key: Optional[str] = Field(None, alias="fairnessKey") + fairnessKey: Optional[str] = None """Fairness key is a short string that's used as a key for a fairness balancing mechanism. It may correspond to a tenant id, or to a fixed string like "high" or "low". The default is the empty string. @@ -271,7 +280,7 @@ class Priority(BaseModel): Fairness keys are limited to 64 bytes. """ - fairness_weight: Optional[float] = Field(None, alias="fairnessWeight") + fairnessWeight: Optional[float] = None """Fairness weight for a task can come from multiple sources for flexibility. From highest to lowest precedence: 1. Weights for a small set of keys can be overridden in task queue @@ -281,7 +290,7 @@ class Priority(BaseModel): Weight values are clamped to the range [0.001, 1000]. """ - priority_key: Optional[int] = Field(None, alias="priorityKey") + priorityKey: Optional[int] = None """Priority key is a positive integer from 1 to n, where smaller integers correspond to higher priorities (tasks run sooner). In general, tasks in a queue should be processed in close to priority order, although small @@ -296,45 +305,45 @@ class Priority(BaseModel): """ -class RetryPolicy(BaseModel): +@dataclass +class RetryPolicy: """Retry policy for the workflow How retries ought to be handled, usable by both workflows and activities """ - backoff_coefficient: Optional[float] = Field(None, alias="backoffCoefficient") + backoffCoefficient: Optional[float] = None """Coefficient used to calculate the next retry interval. The next retry interval is previous interval multiplied by the coefficient. Must be 1 or larger. """ - initial_interval: Optional[str] = Field(None, alias="initialInterval") + initialInterval: Optional[str] = None """Interval of the first retry. If retryBackoffCoefficient is 1.0 then it is used for all retries. """ - maximum_attempts: Optional[int] = Field(None, alias="maximumAttempts") + maximumAttempts: Optional[int] = None """Maximum number of attempts. When exceeded the retries stop even if not expired yet. 1 disables retries. 0 means unlimited (up to the timeouts) """ - maximum_interval: Optional[str] = Field(None, alias="maximumInterval") + maximumInterval: Optional[str] = None """Maximum interval between retries. Exponential backoff leads to interval increase. This value is the cap of the increase. Default is 100x of the initial interval. """ - non_retryable_error_types: Optional[List[str]] = Field( - None, alias="nonRetryableErrorTypes" - ) + nonRetryableErrorTypes: Optional[List[str]] = None """Non-Retryable errors types. Will stop retrying if the error type matches this list. Note that this is not a substring match, the error *type* (not message) must match exactly. """ -class SearchAttributes(BaseModel): +@dataclass +class SearchAttributes: """A user-defined set of *indexed* fields that are used/exposed when listing/searching workflows. The payload is not serialized in a user-defined way. """ - indexed_fields: Optional[Dict[str, Any]] = Field(None, alias="indexedFields") + indexedFields: Optional[Dict[str, Any]] = None class Kind(Enum): @@ -345,7 +354,8 @@ class Kind(Enum): TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" -class TaskQueue(BaseModel): +@dataclass +class TaskQueue: """The task queue to start this workflow on, if it will be started See https://docs.temporal.io/docs/concepts/task-queues/ @@ -355,13 +365,14 @@ class TaskQueue(BaseModel): """Default: TASK_QUEUE_KIND_NORMAL.""" name: Optional[str] = None - normal_name: Optional[str] = Field(None, alias="normalName") + normalName: Optional[str] = None """Iff kind == TASK_QUEUE_KIND_STICKY, then this field contains the name of the normal task queue that the sticky worker is running on. """ -class UserMetadata(BaseModel): +@dataclass +class UserMetadata: """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo for use by user interfaces to display the fixed as-of-start summary and details of the @@ -398,7 +409,8 @@ class VersioningOverrideBehavior(Enum): VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" -class Deployment(BaseModel): +@dataclass +class Deployment: """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. Identifies the worker deployment to pin the workflow to. Deprecated. Use `override.pinned.version`. @@ -411,12 +423,12 @@ class Deployment(BaseModel): Deprecated. """ - build_id: Optional[str] = Field(None, alias="buildId") + buildId: Optional[str] = None """Build ID changes with each version of the worker when the worker program code and/or config changes. """ - series_name: Optional[str] = Field(None, alias="seriesName") + seriesName: Optional[str] = None """Different versions of the same worker service/application are related together by having a shared series name. @@ -436,7 +448,8 @@ class PinnedBehavior(Enum): PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" -class Version(BaseModel): +@dataclass +class Version: """Specifies the Worker Deployment Version to pin this workflow to. Required if the target workflow is not already pinned to a version. @@ -455,17 +468,18 @@ class Version(BaseModel): in the future. """ - build_id: Optional[str] = Field(None, alias="buildId") + buildId: Optional[str] = None """A unique identifier for this Version within the Deployment it is a part of. Not necessarily unique within the namespace. The combination of `deployment_name` and `build_id` uniquely identifies this Version within the namespace, because Deployment names are unique within a namespace. """ - deployment_name: Optional[str] = Field(None, alias="deploymentName") + deploymentName: Optional[str] = None """Identifies the Worker Deployment this Version is part of.""" -class Pinned(BaseModel): +@dataclass +class Pinned: """Override the workflow to have Pinned behavior.""" behavior: Optional[PinnedBehavior] = None @@ -484,7 +498,8 @@ class Pinned(BaseModel): """ -class VersioningOverride(BaseModel): +@dataclass +class VersioningOverride: """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task completion. To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. @@ -503,7 +518,7 @@ class VersioningOverride(BaseModel): workflows, and cron workflows. """ - auto_upgrade: Optional[bool] = Field(None, alias="autoUpgrade") + autoUpgrade: Optional[bool] = None """Override the workflow to have AutoUpgrade behavior.""" behavior: Optional[VersioningOverrideBehavior] = None @@ -518,7 +533,7 @@ class VersioningOverride(BaseModel): pinned: Optional[Pinned] = None """Override the workflow to have Pinned behavior.""" - pinned_version: Optional[str] = Field(None, alias="pinnedVersion") + pinnedVersion: Optional[str] = None """Required if behavior is `PINNED`. Must be absent if behavior is not `PINNED`. Identifies the worker deployment version to pin the workflow to, in the format ".". @@ -568,7 +583,8 @@ class WorkflowIDReusePolicy(Enum): WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED = "WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED" -class WorkflowType(BaseModel): +@dataclass +class WorkflowType: """Represents the identifier used by a workflow author to define the workflow. Typically, the name of a function. This is sometimes referred to as the workflow's "name" @@ -577,11 +593,12 @@ class WorkflowType(BaseModel): name: Optional[str] = None -class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): +@dataclass +class WorkflowServiceSignalWithStartWorkflowExecutionInput: control: Optional[str] = None """Deprecated.""" - cron_schedule: Optional[str] = Field(None, alias="cronSchedule") + cronSchedule: Optional[str] = None """See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/""" header: Optional[Header] = None @@ -601,46 +618,38 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): priority: Optional[Priority] = None """Priority metadata""" - request_id: Optional[str] = Field(None, alias="requestId") + requestId: Optional[str] = None """Used to de-dupe signal w/ start requests""" - retry_policy: Optional[RetryPolicy] = Field(None, alias="retryPolicy") + retryPolicy: Optional[RetryPolicy] = None """Retry policy for the workflow""" - search_attributes: Optional[SearchAttributes] = Field( - None, alias="searchAttributes" - ) - signal_input: Optional[Input] = Field(None, alias="signalInput") + searchAttributes: Optional[SearchAttributes] = None + signalInput: Optional[Input] = None """Serialized value(s) to provide with the signal""" - signal_name: Optional[str] = Field(None, alias="signalName") + signalName: Optional[str] = None """The workflow author-defined name of the signal to send to the workflow""" - task_queue: Optional[TaskQueue] = Field(None, alias="taskQueue") + taskQueue: Optional[TaskQueue] = None """The task queue to start this workflow on, if it will be started""" - user_metadata: Optional[UserMetadata] = Field(None, alias="userMetadata") + userMetadata: Optional[UserMetadata] = None """Metadata on the workflow if it is started. This is carried over to the WorkflowExecutionInfo for use by user interfaces to display the fixed as-of-start summary and details of the workflow. """ - versioning_override: Optional[VersioningOverride] = Field( - None, alias="versioningOverride" - ) + versioningOverride: Optional[VersioningOverride] = None """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task completion. To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. """ - workflow_execution_timeout: Optional[str] = Field( - None, alias="workflowExecutionTimeout" - ) + workflowExecutionTimeout: Optional[str] = None """Total workflow execution timeout including retries and continue as new""" - workflow_id: Optional[str] = Field(None, alias="workflowId") - workflow_id_conflict_policy: Optional[WorkflowIDConflictPolicy] = Field( - None, alias="workflowIdConflictPolicy" - ) + workflowId: Optional[str] = None + workflowIdConflictPolicy: Optional[WorkflowIDConflictPolicy] = None """Defines how to resolve a workflow id conflict with a *running* workflow. The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. @@ -648,19 +657,17 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* workflow. """ - workflow_id_reuse_policy: Optional[WorkflowIDReusePolicy] = Field( - None, alias="workflowIdReusePolicy" - ) + workflowIdReusePolicy: Optional[WorkflowIDReusePolicy] = None """Defines whether to allow re-using the workflow id from a previously *closed* workflow. The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* workflow. """ - workflow_run_timeout: Optional[str] = Field(None, alias="workflowRunTimeout") + workflowRunTimeout: Optional[str] = None """Timeout of a single workflow run""" - workflow_start_delay: Optional[str] = Field(None, alias="workflowStartDelay") + workflowStartDelay: Optional[str] = None """Time to wait before dispatching the first workflow task. Cannot be used with `cron_schedule`. Note that the signal will be delivered with the first workflow task. If the workflow @@ -670,14 +677,15 @@ class WorkflowServiceSignalWithStartWorkflowExecutionInput(BaseModel): and the rest of the delay period will be ignored, even if that request also had a delay. Signal via SignalWorkflowExecution will not unblock the workflow. """ - workflow_task_timeout: Optional[str] = Field(None, alias="workflowTaskTimeout") + workflowTaskTimeout: Optional[str] = None """Timeout of a single workflow task""" - workflow_type: Optional[WorkflowType] = Field(None, alias="workflowType") + workflowType: Optional[WorkflowType] = None -class WorkflowServiceSignalWithStartWorkflowExecutionOutput(BaseModel): - run_id: Optional[str] = Field(None, alias="runId") +@dataclass +class WorkflowServiceSignalWithStartWorkflowExecutionOutput: + runId: Optional[str] = None """The run id of the workflow that was started - or just signaled, if it was already running.""" started: Optional[bool] = None diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1bfa77c3c..8fb60667c 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.system import temporalio.workflow from temporalio.service import __version__ @@ -3345,14 +3346,24 @@ def _resolve_failure(self, err: BaseException) -> None: self._result_fut.set_result(None) def _apply_schedule_command(self) -> None: - payload = self._payload_converter.to_payload(self._input.input) command = self._instance._add_command() v = command.schedule_nexus_operation v.seq = self._seq v.endpoint = self._input.endpoint v.service = self._input.service v.operation = self._input.operation_name - v.input.CopyFrom(payload) + payload_converter = ( + temporalio.nexus.system.get_payload_converter() + if temporalio.nexus.system.is_system_operation(v.service, v.operation) + else self._payload_converter + ) + payload = payload_converter.to_payload(self._input.input) + if payload is None: + raise RuntimeError( + "Nexus operation input could not be converted to a payload" + ) + payload_message: temporalio.api.common.v1.Payload = payload + v.input.CopyFrom(payload_message) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py index ecaf300da..d02849bca 100644 --- a/tests/nexus/test_temporal_system_nexus.py +++ b/tests/nexus/test_temporal_system_nexus.py @@ -11,15 +11,15 @@ from google.protobuf.json_format import MessageToDict import temporalio.api.common.v1 +import temporalio.converter from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.converter import ExternalStorage, PayloadCodec -from temporalio.nexus.system import ( - WorkflowService, - WorkflowServiceSignalWithStartWorkflowExecutionInput, - WorkflowServiceSignalWithStartWorkflowExecutionOutput, +from temporalio.converter import ( + DefaultPayloadConverter, + ExternalStorage, + PayloadCodec, ) +from temporalio.nexus.system import generated from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner @@ -27,15 +27,17 @@ from tests.test_extstore import InMemoryTestDriver -@nexusrpc.handler.service_handler(service=WorkflowService) +@nexusrpc.handler.service_handler(service=generated.WorkflowService) class WorkflowServicePayloadHandler: @nexusrpc.handler.sync_operation async def signal_with_start_workflow_execution( self, _ctx: nexusrpc.handler.StartOperationContext, - request: WorkflowServiceSignalWithStartWorkflowExecutionInput, - ) -> WorkflowServiceSignalWithStartWorkflowExecutionOutput: - request_dict = request.model_dump(by_alias=True) + request: generated.WorkflowServiceSignalWithStartWorkflowExecutionInput, + ) -> generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput: + assert request.workflowId == "system-nexus-workflow-id" + assert request.signalName == "test-signal" + request_dict = dataclasses.asdict(request) for field_name in ("input", "signalInput"): payloads = request_dict[field_name]["payloads"] assert payloads[0]["externalPayloads"] @@ -50,8 +52,8 @@ async def signal_with_start_workflow_execution( ] assert "externalPayloads" not in search_attribute_payload assert "test-codec" not in search_attribute_payload["metadata"] - return WorkflowServiceSignalWithStartWorkflowExecutionOutput( - runId=f"{request.workflow_id}-run" + return generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput( + runId=f"{request.workflowId}-run" ) @@ -60,86 +62,84 @@ class SystemNexusCallerWithPayloadsWorkflow: @workflow.run async def run(self, task_queue: str) -> str: nexus_client = workflow.create_nexus_client( - service=WorkflowService, + service=generated.WorkflowService, endpoint=make_nexus_endpoint_name(task_queue), ) - request = WorkflowServiceSignalWithStartWorkflowExecutionInput.model_validate( - { - "namespace": "default", - "workflowId": "system-nexus-workflow-id", - "signalName": "test-signal", - "input": MessageToDict( - temporalio.api.common.v1.Payloads( - payloads=[ - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"workflow-input"', - ) - ] - ) - ), - "signalInput": MessageToDict( - temporalio.api.common.v1.Payloads( - payloads=[ - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"signal-input"', - ) - ] - ) - ), - "memo": MessageToDict( - temporalio.api.common.v1.Memo( - fields={ - "memo-key": temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"memo-value"', - ) - } + request = generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + namespace="default", + workflowId="system-nexus-workflow-id", + signalName="test-signal", + input=generated.Input( + payloads=[ + MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"workflow-input"', + ) ) - ), - "header": MessageToDict( - temporalio.api.common.v1.Header( - fields={ - "header-key": temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"header-value"', - ) - } + ] + ), + signalInput=generated.Input( + payloads=[ + MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"signal-input"', + ) ) - ), - "userMetadata": { - "summary": MessageToDict( + ] + ), + memo=generated.Memo( + fields={ + "memo-key": MessageToDict( temporalio.api.common.v1.Payload( metadata={"encoding": b"json/plain"}, - data=b'"summary-value"', + data=b'"memo-value"', ) - ), - "details": MessageToDict( + ) + } + ), + header=generated.Header( + fields={ + "header-key": MessageToDict( temporalio.api.common.v1.Payload( metadata={"encoding": b"json/plain"}, - data=b'"details-value"', + data=b'"header-value"', ) - ), - }, - "searchAttributes": { - "indexedFields": { - "custom-key": MessageToDict( - temporalio.api.common.v1.Payload( - metadata={"encoding": b"json/plain"}, - data=b'"search-attribute-value"', - ) + ) + } + ), + userMetadata=generated.UserMetadata( + summary=MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"summary-value"', + ) + ), + details=MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"details-value"', + ) + ), + ), + searchAttributes=generated.SearchAttributes( + indexedFields={ + "custom-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"search-attribute-value"', ) - } - }, - } + ) + } + ), ) handle = await nexus_client.start_operation( - WorkflowService.signal_with_start_workflow_execution, + generated.WorkflowService.signal_with_start_workflow_execution, request, ) result = await handle - return cast(str, result.run_id) + return cast(str, result.runId) class RejectOuterSystemNexusCodec(PayloadCodec): @@ -175,7 +175,42 @@ async def encode( async def decode( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - return list(payloads) + decoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + try: + body = json.loads(payload.data) + except json.JSONDecodeError: + body = None + if isinstance(body, dict) and { + "namespace", + "workflowId", + "signalName", + }.issubset(body): + raise RuntimeError( + "outer system nexus envelope should not be codec decoded" + ) + decoded.append(payload) + return decoded + + +class BadSystemNexusEnvelopePayloadConverter(DefaultPayloadConverter): + def to_payloads( + self, values: Sequence[object] + ) -> list[temporalio.api.common.v1.Payload]: + payloads: list[temporalio.api.common.v1.Payload] = [] + for value in values: + if isinstance( + value, generated.WorkflowServiceSignalWithStartWorkflowExecutionInput + ): + payloads.append( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'{"workflow_id":"bad-envelope"}', + ) + ) + else: + payloads.extend(super().to_payloads([value])) + return payloads async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( @@ -186,31 +221,43 @@ async def test_workflow_service_signal_with_start_nested_payloads_use_codec_with codec = RejectOuterSystemNexusCodec() driver = InMemoryTestDriver() - config = env.client.config() - config["data_converter"] = dataclasses.replace( - pydantic_data_converter, + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_converter_class=BadSystemNexusEnvelopePayloadConverter, payload_codec=codec, external_storage=ExternalStorage( drivers=[driver], payload_size_threshold=1, ), ) - client = Client(**config) + caller_client = Client(**caller_config) + handler_config = env.client.config() + handler_config["data_converter"] = temporalio.converter.default() + handler_client = Client(**handler_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) - async with Worker( - client, - task_queue=str(uuid.uuid4()), + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, workflows=[SystemNexusCallerWithPayloadsWorkflow], - nexus_service_handlers=[WorkflowServicePayloadHandler()], workflow_runner=UnsandboxedWorkflowRunner(), - ) as worker: - endpoint_name = make_nexus_endpoint_name(worker.task_queue) - await env.create_nexus_endpoint(endpoint_name, worker.task_queue) - result = await client.execute_workflow( + ) + handler_worker = Worker( + handler_client, + task_queue=handler_task_queue, + nexus_service_handlers=[WorkflowServicePayloadHandler()], + ) + + async with caller_worker, handler_worker: + endpoint_name = make_nexus_endpoint_name(handler_task_queue) + await env.create_nexus_endpoint(endpoint_name, handler_task_queue) + result = await caller_client.execute_workflow( SystemNexusCallerWithPayloadsWorkflow.run, - worker.task_queue, + handler_task_queue, id=str(uuid.uuid4()), - task_queue=worker.task_queue, + task_queue=caller_task_queue, ) assert result == "system-nexus-workflow-id-run" From ccf4ef395e654e72f3da8ce22a89b6aa9a8c7e43 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 9 Apr 2026 10:04:24 -0700 Subject: [PATCH 6/7] Typecheck generated nexus system models --- pyproject.toml | 1 - temporalio/nexus/system/__init__.py | 4 ++-- temporalio/nexus/system/_workflow_service_generated.py | 1 + temporalio/worker/_workflow_instance.py | 8 +------- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e65b0ed54..40bd7189e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -218,7 +218,6 @@ exclude = [ "temporalio/api", "temporalio/bridge/proto", "temporalio/bridge/_visitor.py", - "temporalio/nexus/system/_workflow_service_generated.py", "tests/worker/workflow_sandbox/testmodules/proto", ] diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py index 2f4dc7926..bb1a2f37e 100644 --- a/temporalio/nexus/system/__init__.py +++ b/temporalio/nexus/system/__init__.py @@ -24,7 +24,7 @@ Awaitable[temporalio.api.common.v1.Payload], ] -_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.JSONPlainPayloadConverter() +_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter def get_payload_rewriter( @@ -40,7 +40,7 @@ def is_system_operation(service: str, operation: str) -> bool: return get_payload_rewriter(service, operation) is not None -def get_payload_converter() -> temporalio.converter.EncodingPayloadConverter: +def get_payload_converter() -> temporalio.converter.PayloadConverter: """Return the fixed payload converter for system Nexus outer envelopes.""" return _SYSTEM_NEXUS_PAYLOAD_CONVERTER diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py index e2273b871..4ac22c76e 100644 --- a/temporalio/nexus/system/_workflow_service_generated.py +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -1,4 +1,5 @@ # Generated by nexus-rpc-gen. DO NOT EDIT! +# pyright: reportDeprecated=false from __future__ import annotations diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 8fb60667c..3454eb2ad 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -3357,13 +3357,7 @@ def _apply_schedule_command(self) -> None: if temporalio.nexus.system.is_system_operation(v.service, v.operation) else self._payload_converter ) - payload = payload_converter.to_payload(self._input.input) - if payload is None: - raise RuntimeError( - "Nexus operation input could not be converted to a payload" - ) - payload_message: temporalio.api.common.v1.Payload = payload - v.input.CopyFrom(payload_message) + v.input.CopyFrom(payload_converter.to_payload(self._input.input)) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout From 629567c34d2974e16132c50840648b1f732e19ac Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 9 Apr 2026 10:07:04 -0700 Subject: [PATCH 7/7] Use neutral payload variable names --- temporalio/bridge/worker.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 8539d695d..9f0a43b0e 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -289,9 +289,9 @@ def __init__( self._f = f async def visit_payload(self, payload: Payload) -> None: - rewritten_payload = (await self._f([payload]))[0] - if rewritten_payload is not payload: - payload.CopyFrom(rewritten_payload) + new_payload = (await self._f([payload]))[0] + if new_payload is not payload: + payload.CopyFrom(new_payload) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: @@ -331,12 +331,12 @@ async def _encode_completion_payloads( if rewrite is None: return await data_converter._encode_payload_sequence(payloads) - rewritten_payload = await rewrite( + new_payload = await rewrite( payload, data_converter._encode_payload_sequence, False, ) - return [rewritten_payload] + return [new_payload] async def decode_activation(