diff --git a/README.md b/README.md index 6e42a2019..736f04e8a 100644 --- a/README.md +++ b/README.md @@ -1933,6 +1933,7 @@ To build the SDK from source for use as a dependency, the following prerequisite * [uv](https://docs.astral.sh/uv/) * [Rust](https://www.rust-lang.org/) * [Protobuf Compiler](https://protobuf.dev/) +* [Node.js](https://nodejs.org/) Use `uv` to install `poe`: @@ -2074,6 +2075,12 @@ back from this downgrade, restore both of those files and run `uv sync --all-ext run for protobuf version 3 by setting the `TEMPORAL_TEST_PROTO3` env var to `1` prior to running tests. +The local build and lint flows also regenerate Temporal system Nexus models. By default this pulls +in `nexus-rpc-gen@0.1.0-alpha.4` via `npx`. To use an existing checkout instead, set +`TEMPORAL_NEXUS_RPC_GEN_DIR` to the `nexus-rpc-gen` repo root or its `src` directory before +running `poe build-develop`, `poe lint`, or `poe gen-protos`. The local checkout override path +also requires [`pnpm`](https://pnpm.io/) to be installed. + ### Style * Mostly [Google Style Guide](https://google.github.io/styleguide/pyguide.html). Notable exceptions: diff --git a/pyproject.toml b/pyproject.toml index 4bcd3f03e..40bd7189e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,12 +79,14 @@ gen-protos = [ { cmd = "uv run scripts/gen_protos.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] gen-protos-docker = [ { cmd = "uv run scripts/gen_protos_docker.py" }, { cmd = "uv run scripts/gen_payload_visitor.py" }, { cmd = "uv run scripts/gen_bridge_client.py" }, + { ref = "gen-nexus-system-models" }, { ref = "format" }, ] lint = [ @@ -102,6 +104,7 @@ lint-types = [ { cmd = "uv run mypy --namespace-packages --check-untyped-defs ." }, { cmd = "uv run basedpyright" }, ] +gen-nexus-system-models = "uv run scripts/gen_nexus_system_models.py" run-bench = "uv run python scripts/run_bench.py" test = "uv run pytest" @@ -139,14 +142,17 @@ environment = { PATH = "$PATH:$HOME/.cargo/bin", CARGO_NET_GIT_FETCH_WITH_CLI = ignore_missing_imports = true exclude = [ # Ignore generated code + 'build', 'temporalio/api', 'temporalio/bridge/proto', + 'temporalio/nexus/system/_workflow_service_generated.py', ] [tool.pydocstyle] convention = "google" # https://github.com/PyCQA/pydocstyle/issues/363#issuecomment-625563088 -match_dir = "^(?!(docs|scripts|tests|api|proto|\\.)).*" +match_dir = "^(?!(build|docs|scripts|tests|api|proto|\\.)).*" +match = "^(?!_workflow_service_generated\\.py$).*\\.py" add_ignore = [ # We like to wrap at a certain number of chars, even long summary sentences. # https://github.com/PyCQA/pydocstyle/issues/184 diff --git a/scripts/gen_nexus_system_models.py b/scripts/gen_nexus_system_models.py new file mode 100644 index 000000000..2009a43a7 --- /dev/null +++ b/scripts/gen_nexus_system_models.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import os +import subprocess +import sys +from pathlib import Path + +NEXUS_RPC_GEN_ENV_VAR = "TEMPORAL_NEXUS_RPC_GEN_DIR" +NEXUS_RPC_GEN_VERSION = "0.1.0-alpha.4" + + +def main() -> None: + repo_root = Path(__file__).resolve().parent.parent + # TODO: Remove the local .nexusrpc.yaml shim once the upstream API repo + # checks in the Nexus definition we can consume directly. + override_root = normalize_nexus_rpc_gen_root( + Path.cwd(), env_value=NEXUS_RPC_GEN_ENV_VAR + ) + input_schema = ( + repo_root + / "temporalio" + / "nexus" + / "system" + / "_workflow_service.nexusrpc.yaml" + ) + output_file = ( + repo_root / "temporalio" / "nexus" / "system" / "_workflow_service_generated.py" + ) + + if not input_schema.is_file(): + raise RuntimeError(f"Expected Nexus schema at {input_schema}") + + run_nexus_rpc_gen( + override_root=override_root, + output_file=output_file, + input_schema=input_schema, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "check", + "--select", + "I", + "--fix", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + subprocess.run( + [ + "uv", + "run", + "ruff", + "format", + str(output_file), + ], + cwd=repo_root, + check=True, + ) + + +def run_nexus_rpc_gen( + *, override_root: Path | None, output_file: Path, input_schema: Path +) -> None: + common_args = [ + "--lang", + "py", + "--out-file", + str(output_file), + "--temporal-nexus-payload-codec-support", + str(input_schema), + ] + if override_root is None: + subprocess.run( + ["npx", "--yes", f"nexus-rpc-gen@{NEXUS_RPC_GEN_VERSION}", *common_args], + check=True, + ) + return + + subprocess.run( + [ + "node", + "packages/nexus-rpc-gen/dist/index.js", + *common_args, + ], + cwd=override_root, + check=True, + ) + + +def normalize_nexus_rpc_gen_root(base_dir: Path, env_value: str) -> Path | None: + raw_root = env_get(env_value) + if raw_root is None: + return None + candidate = Path(raw_root) + if not candidate.is_absolute(): + candidate = base_dir / candidate + candidate = candidate.resolve() + if (candidate / "package.json").is_file() and (candidate / "packages").is_dir(): + return candidate + if (candidate / "src" / "package.json").is_file(): + return candidate / "src" + raise RuntimeError( + f"{NEXUS_RPC_GEN_ENV_VAR} must point to the nexus-rpc-gen repo root or its src directory" + ) + + +def env_get(name: str) -> str | None: + return os.environ.get(name) + + +if __name__ == "__main__": + try: + main() + except Exception as err: + print(f"Failed to generate Nexus system models: {err}", file=sys.stderr) + raise diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c2e426d28..9f0a43b0e 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -20,7 +20,9 @@ import temporalio.bridge.runtime import temporalio.bridge.temporal_sdk_bridge import temporalio.converter +import temporalio.nexus.system from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.api.enums.v1.command_type_pb2 import CommandType from temporalio.bridge._visitor import VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, @@ -28,6 +30,7 @@ from temporalio.bridge.temporal_sdk_bridge import ( PollShutdownError, # type: ignore # noqa: F401 ) +from temporalio.worker import _command_aware_visitor from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor @@ -279,7 +282,10 @@ async def finalize_shutdown(self) -> None: class _Visitor(VisitorFunctions): - def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]]): + def __init__( + self, + f: Callable[[Sequence[Payload]], Awaitable[list[Payload]]], + ): self._f = f async def visit_payload(self, payload: Payload) -> None: @@ -297,6 +303,42 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: payloads.extend(new_payloads) +async def _encode_completion_payloads( + data_converter: temporalio.converter.DataConverter, + payloads: Sequence[Payload], +) -> list[Payload]: + if len(payloads) != 1: + return await data_converter._encode_payload_sequence(payloads) + + # A single payload may be the outer envelope for a system Nexus operation. + # In that case we leave the envelope itself unencoded so the server can read + # it, but still route any nested Temporal payloads through normal payload + # processing via the generated operation-specific rewriter. + payload = payloads[0] + command_info = _command_aware_visitor.current_command_info.get() + if ( + command_info is None + or command_info.command_type + != CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION + or not command_info.nexus_service + or not command_info.nexus_operation + ): + return await data_converter._encode_payload_sequence(payloads) + + rewrite = temporalio.nexus.system.get_payload_rewriter( + command_info.nexus_service, command_info.nexus_operation + ) + if rewrite is None: + return await data_converter._encode_payload_sequence(payloads) + + new_payload = await rewrite( + payload, + data_converter._encode_payload_sequence, + False, + ) + return [new_payload] + + async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, data_converter: temporalio.converter.DataConverter, @@ -316,4 +358,9 @@ async def encode_completion( """Encode all payloads in the completion.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(data_converter._encode_payload_sequence), completion) + ).visit( + _Visitor( + lambda payloads: _encode_completion_payloads(data_converter, payloads) + ), + completion, + ) diff --git a/temporalio/nexus/system/__init__.py b/temporalio/nexus/system/__init__.py new file mode 100644 index 000000000..bb1a2f37e --- /dev/null +++ b/temporalio/nexus/system/__init__.py @@ -0,0 +1,53 @@ +"""Generated system Nexus service models. + +This package contains code generated from Temporal's system Nexus schemas. +Higher-level ergonomic APIs may wrap these generated types. +""" + +from collections.abc import Awaitable, Callable, Sequence + +import temporalio.api.common.v1 +import temporalio.converter + +from . import _workflow_service_generated as generated +from ._workflow_service_generated import __temporal_nexus_payload_rewriters__ + +TemporalNexusPayloadRewriter = Callable[ + [ + temporalio.api.common.v1.Payload, + Callable[ + [Sequence[temporalio.api.common.v1.Payload]], + Awaitable[list[temporalio.api.common.v1.Payload]], + ], + bool, + ], + Awaitable[temporalio.api.common.v1.Payload], +] + +_SYSTEM_NEXUS_PAYLOAD_CONVERTER = temporalio.converter.default().payload_converter + + +def get_payload_rewriter( + service: str, + operation: str, +) -> TemporalNexusPayloadRewriter | None: + """Return the generated nested-payload rewriter for a system Nexus operation.""" + return __temporal_nexus_payload_rewriters__.get((service, operation)) + + +def is_system_operation(service: str, operation: str) -> bool: + """Return whether a Nexus operation uses the generated system envelope.""" + return get_payload_rewriter(service, operation) is not None + + +def get_payload_converter() -> temporalio.converter.PayloadConverter: + """Return the fixed payload converter for system Nexus outer envelopes.""" + return _SYSTEM_NEXUS_PAYLOAD_CONVERTER + + +__all__ = ( + "generated", + "get_payload_converter", + "get_payload_rewriter", + "is_system_operation", +) diff --git a/temporalio/nexus/system/_workflow_service.nexusrpc.yaml b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml new file mode 100644 index 000000000..edea24b4e --- /dev/null +++ b/temporalio/nexus/system/_workflow_service.nexusrpc.yaml @@ -0,0 +1,11 @@ +# TODO: Remove this local shim once the upstream API repo checks in the Nexus +# definition and the generator can consume it directly. +nexusrpc: 1.0.0 +services: + WorkflowService: + operations: + SignalWithStartWorkflowExecution: + input: + $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionRequest + output: + $ref: ../../bridge/sdk-core/crates/common/protos/api_upstream/openapi/openapiv3.yaml#/components/schemas/SignalWithStartWorkflowExecutionResponse diff --git a/temporalio/nexus/system/_workflow_service_generated.py b/temporalio/nexus/system/_workflow_service_generated.py new file mode 100644 index 000000000..4ac22c76e --- /dev/null +++ b/temporalio/nexus/system/_workflow_service_generated.py @@ -0,0 +1,843 @@ +# Generated by nexus-rpc-gen. DO NOT EDIT! +# pyright: reportDeprecated=false + +from __future__ import annotations + +import collections.abc +import json +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional + +from google.protobuf.json_format import MessageToDict, ParseDict +from nexusrpc import Operation, service + +import temporalio.api.common.v1 + + +@dataclass +class Header: + """Contains metadata that can be attached to a variety of requests, like starting a + workflow, and + can be propagated between, for example, workflows and activities. + """ + + fields: Optional[Dict[str, Any]] = None + + +@dataclass +class Input: + """Serialized arguments to the workflow. These are passed as arguments to the workflow + function. + + See `Payload` + + Serialized value(s) to provide with the signal + """ + + payloads: Optional[List[Any]] = None + + +@dataclass +class BatchJob: + """A link to a built-in batch job. + Batch jobs can be used to perform operations on a set of workflows (e.g. terminate, + signal, cancel, etc). + This link can be put on workflow history events generated by actions taken by a batch job. + """ + + jobId: Optional[str] = None + + +class EventType(Enum): + EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY = ( + "EVENT_TYPE_ACTIVITY_PROPERTIES_MODIFIED_EXTERNALLY" + ) + EVENT_TYPE_ACTIVITY_TASK_CANCELED = "EVENT_TYPE_ACTIVITY_TASK_CANCELED" + EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED = ( + "EVENT_TYPE_ACTIVITY_TASK_CANCEL_REQUESTED" + ) + EVENT_TYPE_ACTIVITY_TASK_COMPLETED = "EVENT_TYPE_ACTIVITY_TASK_COMPLETED" + EVENT_TYPE_ACTIVITY_TASK_FAILED = "EVENT_TYPE_ACTIVITY_TASK_FAILED" + EVENT_TYPE_ACTIVITY_TASK_SCHEDULED = "EVENT_TYPE_ACTIVITY_TASK_SCHEDULED" + EVENT_TYPE_ACTIVITY_TASK_STARTED = "EVENT_TYPE_ACTIVITY_TASK_STARTED" + EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT = "EVENT_TYPE_ACTIVITY_TASK_TIMED_OUT" + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_CANCELED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_COMPLETED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_STARTED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TERMINATED" + ) + EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT = ( + "EVENT_TYPE_CHILD_WORKFLOW_EXECUTION_TIMED_OUT" + ) + EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( + "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_CANCEL_REQUESTED" + ) + EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED = ( + "EVENT_TYPE_EXTERNAL_WORKFLOW_EXECUTION_SIGNALED" + ) + EVENT_TYPE_MARKER_RECORDED = "EVENT_TYPE_MARKER_RECORDED" + EVENT_TYPE_NEXUS_OPERATION_CANCELED = "EVENT_TYPE_NEXUS_OPERATION_CANCELED" + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED" + ) + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED" + ) + EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED = ( + "EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED" + ) + EVENT_TYPE_NEXUS_OPERATION_COMPLETED = "EVENT_TYPE_NEXUS_OPERATION_COMPLETED" + EVENT_TYPE_NEXUS_OPERATION_FAILED = "EVENT_TYPE_NEXUS_OPERATION_FAILED" + EVENT_TYPE_NEXUS_OPERATION_SCHEDULED = "EVENT_TYPE_NEXUS_OPERATION_SCHEDULED" + EVENT_TYPE_NEXUS_OPERATION_STARTED = "EVENT_TYPE_NEXUS_OPERATION_STARTED" + EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT = "EVENT_TYPE_NEXUS_OPERATION_TIMED_OUT" + EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_REQUEST_CANCEL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_SIGNAL_EXTERNAL_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED = ( + "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_FAILED" + ) + EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED = ( + "EVENT_TYPE_START_CHILD_WORKFLOW_EXECUTION_INITIATED" + ) + EVENT_TYPE_TIMER_CANCELED = "EVENT_TYPE_TIMER_CANCELED" + EVENT_TYPE_TIMER_FIRED = "EVENT_TYPE_TIMER_FIRED" + EVENT_TYPE_TIMER_STARTED = "EVENT_TYPE_TIMER_STARTED" + EVENT_TYPE_UNSPECIFIED = "EVENT_TYPE_UNSPECIFIED" + EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES = ( + "EVENT_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED = "EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED" + EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED = "EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED" + EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_FAILED = "EVENT_TYPE_WORKFLOW_EXECUTION_FAILED" + EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_OPTIONS_UPDATED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_PAUSED" + EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED = "EVENT_TYPE_WORKFLOW_EXECUTION_SIGNALED" + EVENT_TYPE_WORKFLOW_EXECUTION_STARTED = "EVENT_TYPE_WORKFLOW_EXECUTION_STARTED" + EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT = "EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT" + EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED = "EVENT_TYPE_WORKFLOW_EXECUTION_UNPAUSED" + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ACCEPTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_ADMITTED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_COMPLETED" + ) + EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED = ( + "EVENT_TYPE_WORKFLOW_EXECUTION_UPDATE_REJECTED" + ) + EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED = "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED" + EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY = ( + "EVENT_TYPE_WORKFLOW_PROPERTIES_MODIFIED_EXTERNALLY" + ) + EVENT_TYPE_WORKFLOW_TASK_COMPLETED = "EVENT_TYPE_WORKFLOW_TASK_COMPLETED" + EVENT_TYPE_WORKFLOW_TASK_FAILED = "EVENT_TYPE_WORKFLOW_TASK_FAILED" + EVENT_TYPE_WORKFLOW_TASK_SCHEDULED = "EVENT_TYPE_WORKFLOW_TASK_SCHEDULED" + EVENT_TYPE_WORKFLOW_TASK_STARTED = "EVENT_TYPE_WORKFLOW_TASK_STARTED" + EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT = "EVENT_TYPE_WORKFLOW_TASK_TIMED_OUT" + + +@dataclass +class EventRef: + """EventReference is a direct reference to a history event through the event ID.""" + + eventId: Optional[str] = None + eventType: Optional[EventType] = None + + +@dataclass +class RequestIDRef: + """RequestIdReference is a indirect reference to a history event through the request ID.""" + + eventType: Optional[EventType] = None + requestId: Optional[str] = None + + +@dataclass +class WorkflowEvent: + eventRef: Optional[EventRef] = None + namespace: Optional[str] = None + requestIdRef: Optional[RequestIDRef] = None + runId: Optional[str] = None + workflowId: Optional[str] = None + + +@dataclass +class Openapiv3: + """Link can be associated with history events. It might contain information about an + external entity + related to the history event. For example, workflow A makes a Nexus call that starts + workflow B: + in this case, a history event in workflow A could contain a Link to the workflow started + event in + workflow B, and vice-versa. + """ + + batchJob: Optional[BatchJob] = None + workflowEvent: Optional[WorkflowEvent] = None + + +@dataclass +class Memo: + """A user-defined set of *unindexed* fields that are exposed when listing/searching workflows""" + + fields: Optional[Dict[str, Any]] = None + + +@dataclass +class Priority: + """Priority metadata + + Priority contains metadata that controls relative ordering of task processing + when tasks are backed up in a queue. Initially, Priority will be used in + matching (workflow and activity) task queues. Later it may be used in history + task queues and in rate limiting decisions. + + Priority is attached to workflows and activities. By default, activities + inherit Priority from the workflow that created them, but may override fields + when an activity is started or modified. + + Despite being named "Priority", this message also contains fields that + control "fairness" mechanisms. + + For all fields, the field not present or equal to zero/empty string means to + inherit the value from the calling workflow, or if there is no calling + workflow, then use the default value. + + For all fields other than fairness_key, the zero value isn't meaningful so + there's no confusion between inherit/default and a meaningful value. For + fairness_key, the empty string will be interpreted as "inherit". This means + that if a workflow has a non-empty fairness key, you can't override the + fairness key of its activity to the empty string. + + The overall semantics of Priority are: + 1. First, consider "priority": higher priority (lower number) goes first. + 2. Then, consider fairness: try to dispatch tasks for different fairness keys + in proportion to their weight. + + Applications may use any subset of mechanisms that are useful to them and + leave the other fields to use default values. + + Not all queues in the system may support the "full" semantics of all priority + fields. (Currently only support in matching task queues is planned.) + """ + + fairnessKey: Optional[str] = None + """Fairness key is a short string that's used as a key for a fairness + balancing mechanism. It may correspond to a tenant id, or to a fixed + string like "high" or "low". The default is the empty string. + + The fairness mechanism attempts to dispatch tasks for a given key in + proportion to its weight. For example, using a thousand distinct tenant + ids, each with a weight of 1.0 (the default) will result in each tenant + getting a roughly equal share of task dispatch throughput. + + (Note: this does not imply equal share of worker capacity! Fairness + decisions are made based on queue statistics, not + current worker load.) + + As another example, using keys "high" and "low" with weight 9.0 and 1.0 + respectively will prefer dispatching "high" tasks over "low" tasks at a + 9:1 ratio, while allowing either key to use all worker capacity if the + other is not present. + + All fairness mechanisms, including rate limits, are best-effort and + probabilistic. The results may not match what a "perfect" algorithm with + infinite resources would produce. The more unique keys are used, the less + accurate the results will be. + + Fairness keys are limited to 64 bytes. + """ + fairnessWeight: Optional[float] = None + """Fairness weight for a task can come from multiple sources for + flexibility. From highest to lowest precedence: + 1. Weights for a small set of keys can be overridden in task queue + configuration with an API. + 2. It can be attached to the workflow/activity in this field. + 3. The default weight of 1.0 will be used. + + Weight values are clamped to the range [0.001, 1000]. + """ + priorityKey: Optional[int] = None + """Priority key is a positive integer from 1 to n, where smaller integers + correspond to higher priorities (tasks run sooner). In general, tasks in + a queue should be processed in close to priority order, although small + deviations are possible. + + The maximum priority value (minimum priority) is determined by server + configuration, and defaults to 5. + + If priority is not present (or zero), then the effective priority will be + the default priority, which is calculated by (min+max)/2. With the + default max of 5, and min of 1, that comes out to 3. + """ + + +@dataclass +class RetryPolicy: + """Retry policy for the workflow + + How retries ought to be handled, usable by both workflows and activities + """ + + backoffCoefficient: Optional[float] = None + """Coefficient used to calculate the next retry interval. + The next retry interval is previous interval multiplied by the coefficient. + Must be 1 or larger. + """ + initialInterval: Optional[str] = None + """Interval of the first retry. If retryBackoffCoefficient is 1.0 then it is used for all + retries. + """ + maximumAttempts: Optional[int] = None + """Maximum number of attempts. When exceeded the retries stop even if not expired yet. + 1 disables retries. 0 means unlimited (up to the timeouts) + """ + maximumInterval: Optional[str] = None + """Maximum interval between retries. Exponential backoff leads to interval increase. + This value is the cap of the increase. Default is 100x of the initial interval. + """ + nonRetryableErrorTypes: Optional[List[str]] = None + """Non-Retryable errors types. Will stop retrying if the error type matches this list. Note + that + this is not a substring match, the error *type* (not message) must match exactly. + """ + + +@dataclass +class SearchAttributes: + """A user-defined set of *indexed* fields that are used/exposed when listing/searching + workflows. + The payload is not serialized in a user-defined way. + """ + + indexedFields: Optional[Dict[str, Any]] = None + + +class Kind(Enum): + """Default: TASK_QUEUE_KIND_NORMAL.""" + + TASK_QUEUE_KIND_NORMAL = "TASK_QUEUE_KIND_NORMAL" + TASK_QUEUE_KIND_STICKY = "TASK_QUEUE_KIND_STICKY" + TASK_QUEUE_KIND_UNSPECIFIED = "TASK_QUEUE_KIND_UNSPECIFIED" + + +@dataclass +class TaskQueue: + """The task queue to start this workflow on, if it will be started + + See https://docs.temporal.io/docs/concepts/task-queues/ + """ + + kind: Optional[Kind] = None + """Default: TASK_QUEUE_KIND_NORMAL.""" + + name: Optional[str] = None + normalName: Optional[str] = None + """Iff kind == TASK_QUEUE_KIND_STICKY, then this field contains the name of + the normal task queue that the sticky worker is running on. + """ + + +@dataclass +class UserMetadata: + """Metadata on the workflow if it is started. This is carried over to the + WorkflowExecutionInfo + for use by user interfaces to display the fixed as-of-start summary and details of the + workflow. + + Information a user can set, often for use by user interfaces. + """ + + details: Any + """Long-form text that provides details. This payload should be a "json/plain"-encoded + payload + that is a single JSON string for use in user interfaces. User interface formatting may + apply to + this text in common use. The payload data section is limited to 20000 bytes by default. + """ + summary: Any + """Short-form text that provides a summary. This payload should be a "json/plain"-encoded + payload + that is a single JSON string for use in user interfaces. User interface formatting may + not + apply to this text when used in "title" situations. The payload data section is limited + to 400 + bytes by default. + """ + + +class VersioningOverrideBehavior(Enum): + """Required. + Deprecated. Use `override`. + """ + + VERSIONING_BEHAVIOR_AUTO_UPGRADE = "VERSIONING_BEHAVIOR_AUTO_UPGRADE" + VERSIONING_BEHAVIOR_PINNED = "VERSIONING_BEHAVIOR_PINNED" + VERSIONING_BEHAVIOR_UNSPECIFIED = "VERSIONING_BEHAVIOR_UNSPECIFIED" + + +@dataclass +class Deployment: + """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. + Identifies the worker deployment to pin the workflow to. + Deprecated. Use `override.pinned.version`. + + `Deployment` identifies a deployment of Temporal workers. The combination of deployment + series + name + build ID serves as the identifier. User can use `WorkerDeploymentOptions` in their + worker + programs to specify these values. + Deprecated. + """ + + buildId: Optional[str] = None + """Build ID changes with each version of the worker when the worker program code and/or + config + changes. + """ + seriesName: Optional[str] = None + """Different versions of the same worker service/application are related together by having + a + shared series name. + Out of all deployments of a series, one can be designated as the current deployment, + which + receives new workflow executions and new tasks of workflows with + `VERSIONING_BEHAVIOR_AUTO_UPGRADE` versioning behavior. + """ + + +class PinnedBehavior(Enum): + """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. + See `PinnedOverrideBehavior` for details. + """ + + PINNED_OVERRIDE_BEHAVIOR_PINNED = "PINNED_OVERRIDE_BEHAVIOR_PINNED" + PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED = "PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED" + + +@dataclass +class Version: + """Specifies the Worker Deployment Version to pin this workflow to. + Required if the target workflow is not already pinned to a version. + + If omitted and the target workflow is already pinned, the effective + pinned version will be the existing pinned version. + + If omitted and the target workflow is not pinned, the override request + will be rejected with a PreconditionFailed error. + + A Worker Deployment Version (Version, for short) represents a + version of workers within a Worker Deployment. (see documentation of + WorkerDeploymentVersionInfo) + Version records are created in Temporal server automatically when their + first poller arrives to the server. + Experimental. Worker Deployment Versions are experimental and might significantly change + in the future. + """ + + buildId: Optional[str] = None + """A unique identifier for this Version within the Deployment it is a part of. + Not necessarily unique within the namespace. + The combination of `deployment_name` and `build_id` uniquely identifies this + Version within the namespace, because Deployment names are unique within a namespace. + """ + deploymentName: Optional[str] = None + """Identifies the Worker Deployment this Version is part of.""" + + +@dataclass +class Pinned: + """Override the workflow to have Pinned behavior.""" + + behavior: Optional[PinnedBehavior] = None + """Defaults to PINNED_OVERRIDE_BEHAVIOR_UNSPECIFIED. + See `PinnedOverrideBehavior` for details. + """ + version: Optional[Version] = None + """Specifies the Worker Deployment Version to pin this workflow to. + Required if the target workflow is not already pinned to a version. + + If omitted and the target workflow is already pinned, the effective + pinned version will be the existing pinned version. + + If omitted and the target workflow is not pinned, the override request + will be rejected with a PreconditionFailed error. + """ + + +@dataclass +class VersioningOverride: + """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task + completion. + To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. + + Used to override the versioning behavior (and pinned deployment version, if applicable) + of a + specific workflow execution. If set, this override takes precedence over worker-sent + values. + See `WorkflowExecutionInfo.VersioningInfo` for more information. + + To remove the override, call `UpdateWorkflowExecutionOptions` with a null + `VersioningOverride`, and use the `update_mask` to indicate that it should be mutated. + + Pinned behavior overrides are automatically inherited by child workflows, workflow + retries, continue-as-new + workflows, and cron workflows. + """ + + autoUpgrade: Optional[bool] = None + """Override the workflow to have AutoUpgrade behavior.""" + + behavior: Optional[VersioningOverrideBehavior] = None + """Required. + Deprecated. Use `override`. + """ + deployment: Optional[Deployment] = None + """Required if behavior is `PINNED`. Must be null if behavior is `AUTO_UPGRADE`. + Identifies the worker deployment to pin the workflow to. + Deprecated. Use `override.pinned.version`. + """ + pinned: Optional[Pinned] = None + """Override the workflow to have Pinned behavior.""" + + pinnedVersion: Optional[str] = None + """Required if behavior is `PINNED`. Must be absent if behavior is not `PINNED`. + Identifies the worker deployment version to pin the workflow to, in the format + ".". + Deprecated. Use `override.pinned.version`. + """ + + +class WorkflowIDConflictPolicy(Enum): + """Defines how to resolve a workflow id conflict with a *running* workflow. + The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. + Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* + workflow. + """ + + WORKFLOW_ID_CONFLICT_POLICY_FAIL = "WORKFLOW_ID_CONFLICT_POLICY_FAIL" + WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING = ( + "WORKFLOW_ID_CONFLICT_POLICY_TERMINATE_EXISTING" + ) + WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED = "WORKFLOW_ID_CONFLICT_POLICY_UNSPECIFIED" + WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING = ( + "WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING" + ) + + +class WorkflowIDReusePolicy(Enum): + """Defines whether to allow re-using the workflow id from a previously *closed* workflow. + The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* + workflow. + """ + + WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE = ( + "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE" + ) + WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY = ( + "WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE_FAILED_ONLY" + ) + WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE = ( + "WORKFLOW_ID_REUSE_POLICY_REJECT_DUPLICATE" + ) + WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING = ( + "WORKFLOW_ID_REUSE_POLICY_TERMINATE_IF_RUNNING" + ) + WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED = "WORKFLOW_ID_REUSE_POLICY_UNSPECIFIED" + + +@dataclass +class WorkflowType: + """Represents the identifier used by a workflow author to define the workflow. Typically, + the + name of a function. This is sometimes referred to as the workflow's "name" + """ + + name: Optional[str] = None + + +@dataclass +class WorkflowServiceSignalWithStartWorkflowExecutionInput: + control: Optional[str] = None + """Deprecated.""" + + cronSchedule: Optional[str] = None + """See https://docs.temporal.io/docs/content/what-is-a-temporal-cron-job/""" + + header: Optional[Header] = None + identity: Optional[str] = None + """The identity of the worker/client""" + + input: Optional[Input] = None + """Serialized arguments to the workflow. These are passed as arguments to the workflow + function. + """ + links: Optional[List[Openapiv3]] = None + """Links to be associated with the WorkflowExecutionStarted and WorkflowExecutionSignaled + events. + """ + memo: Optional[Memo] = None + namespace: Optional[str] = None + priority: Optional[Priority] = None + """Priority metadata""" + + requestId: Optional[str] = None + """Used to de-dupe signal w/ start requests""" + + retryPolicy: Optional[RetryPolicy] = None + """Retry policy for the workflow""" + + searchAttributes: Optional[SearchAttributes] = None + signalInput: Optional[Input] = None + """Serialized value(s) to provide with the signal""" + + signalName: Optional[str] = None + """The workflow author-defined name of the signal to send to the workflow""" + + taskQueue: Optional[TaskQueue] = None + """The task queue to start this workflow on, if it will be started""" + + userMetadata: Optional[UserMetadata] = None + """Metadata on the workflow if it is started. This is carried over to the + WorkflowExecutionInfo + for use by user interfaces to display the fixed as-of-start summary and details of the + workflow. + """ + versioningOverride: Optional[VersioningOverride] = None + """If set, takes precedence over the Versioning Behavior sent by the SDK on Workflow Task + completion. + To unset the override after the workflow is running, use UpdateWorkflowExecutionOptions. + """ + workflowExecutionTimeout: Optional[str] = None + """Total workflow execution timeout including retries and continue as new""" + + workflowId: Optional[str] = None + workflowIdConflictPolicy: Optional[WorkflowIDConflictPolicy] = None + """Defines how to resolve a workflow id conflict with a *running* workflow. + The default policy is WORKFLOW_ID_CONFLICT_POLICY_USE_EXISTING. + Note that WORKFLOW_ID_CONFLICT_POLICY_FAIL is an invalid option. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *closed* + workflow. + """ + workflowIdReusePolicy: Optional[WorkflowIDReusePolicy] = None + """Defines whether to allow re-using the workflow id from a previously *closed* workflow. + The default policy is WORKFLOW_ID_REUSE_POLICY_ALLOW_DUPLICATE. + + See `workflow_id_reuse_policy` for handling a workflow id duplication with a *running* + workflow. + """ + workflowRunTimeout: Optional[str] = None + """Timeout of a single workflow run""" + + workflowStartDelay: Optional[str] = None + """Time to wait before dispatching the first workflow task. Cannot be used with + `cron_schedule`. + Note that the signal will be delivered with the first workflow task. If the workflow + gets + another SignalWithStartWorkflow before the delay a workflow task will be dispatched + immediately + and the rest of the delay period will be ignored, even if that request also had a delay. + Signal via SignalWorkflowExecution will not unblock the workflow. + """ + workflowTaskTimeout: Optional[str] = None + """Timeout of a single workflow task""" + + workflowType: Optional[WorkflowType] = None + + +@dataclass +class WorkflowServiceSignalWithStartWorkflowExecutionOutput: + runId: Optional[str] = None + """The run id of the workflow that was started - or just signaled, if it was already running.""" + + started: Optional[bool] = None + """If true, a new workflow was started.""" + + +@service +class WorkflowService: + signal_with_start_workflow_execution: Operation[ + WorkflowServiceSignalWithStartWorkflowExecutionInput, + WorkflowServiceSignalWithStartWorkflowExecutionOutput, + ] = Operation(name="SignalWithStartWorkflowExecution") + + +class _TemporalNexusPayloadRewriter: + def __init__( + self, + payload_visitor: collections.abc.Callable[ + [collections.abc.Sequence[temporalio.api.common.v1.Payload]], + collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], + ], + visit_search_attributes: bool = False, + ): + self._payload_visitor = payload_visitor + self._visit_search_attributes = visit_search_attributes + + async def _rewrite_payload_json(self, value: dict) -> dict: + payload = ParseDict(value, temporalio.api.common.v1.Payload()) + [rewritten_payload] = await self._payload_visitor([payload]) + return MessageToDict(rewritten_payload) + + async def _rewrite_payloads_json(self, value: dict) -> dict: + payloads = ParseDict(value, temporalio.api.common.v1.Payloads()) + rewritten_payloads = await self._payload_visitor(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(rewritten_payloads) + return MessageToDict(payloads) + + async def _rewrite_payload_map_json(self, message_type: type, value: dict) -> dict: + message = message_type() + keys = list(value.keys()) + rewritten_payloads = await self._payload_visitor( + [ParseDict(value[key], temporalio.api.common.v1.Payload()) for key in keys] + ) + for key, rewritten_payload in zip(keys, rewritten_payloads): + message.fields[key].CopyFrom(rewritten_payload) + return MessageToDict(message).get("fields", {}) + + async def _temporal_nexus_rewrite_header_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("fields") is not None: + rewritten["fields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.Header, rewritten["fields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_input_json(self, value: dict) -> dict: + return await self._rewrite_payloads_json(value) + + async def _temporal_nexus_rewrite_memo_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("fields") is not None: + rewritten["fields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.Memo, rewritten["fields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_search_attributes_json(self, value: dict) -> dict: + if not self._visit_search_attributes: + return value + rewritten = dict(value) + if rewritten.get("indexedFields") is not None: + rewritten["indexedFields"] = await self._rewrite_payload_map_json( + temporalio.api.common.v1.SearchAttributes, rewritten["indexedFields"] + ) + return rewritten + + async def _temporal_nexus_rewrite_user_metadata_json(self, value: dict) -> dict: + rewritten = dict(value) + if rewritten.get("details") is not None: + rewritten["details"] = await self._rewrite_payload_json( + rewritten["details"] + ) + if rewritten.get("summary") is not None: + rewritten["summary"] = await self._rewrite_payload_json( + rewritten["summary"] + ) + return rewritten + + async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + self, value: dict + ) -> dict: + rewritten = dict(value) + if rewritten.get("header") is not None: + rewritten["header"] = await self._temporal_nexus_rewrite_header_json( + rewritten["header"] + ) + if rewritten.get("input") is not None: + rewritten["input"] = await self._temporal_nexus_rewrite_input_json( + rewritten["input"] + ) + if rewritten.get("memo") is not None: + rewritten["memo"] = await self._temporal_nexus_rewrite_memo_json( + rewritten["memo"] + ) + if rewritten.get("searchAttributes") is not None: + rewritten[ + "searchAttributes" + ] = await self._temporal_nexus_rewrite_search_attributes_json( + rewritten["searchAttributes"] + ) + if rewritten.get("signalInput") is not None: + rewritten["signalInput"] = await self._temporal_nexus_rewrite_input_json( + rewritten["signalInput"] + ) + if rewritten.get("userMetadata") is not None: + rewritten[ + "userMetadata" + ] = await self._temporal_nexus_rewrite_user_metadata_json( + rewritten["userMetadata"] + ) + return rewritten + + +async def _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input( + payload: temporalio.api.common.v1.Payload, + payload_visitor: collections.abc.Callable[ + [collections.abc.Sequence[temporalio.api.common.v1.Payload]], + collections.abc.Awaitable[list[temporalio.api.common.v1.Payload]], + ], + visit_search_attributes: bool = False, +) -> temporalio.api.common.v1.Payload: + try: + value = json.loads(payload.data) + except json.JSONDecodeError: + return payload + if not isinstance(value, dict): + return payload + rewriter = _TemporalNexusPayloadRewriter(payload_visitor, visit_search_attributes) + rewritten = await rewriter._temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input_json( + value + ) + return temporalio.api.common.v1.Payload( + metadata=dict(payload.metadata), + data=json.dumps(rewritten, separators=(",", ":"), sort_keys=True).encode(), + ) + + +__temporal_nexus_payload_rewriters__ = { + ( + "WorkflowService", + "SignalWithStartWorkflowExecution", + ): _temporal_nexus_rewrite_workflow_service_signal_with_start_workflow_execution_input, +} diff --git a/temporalio/worker/_command_aware_visitor.py b/temporalio/worker/_command_aware_visitor.py index 2d7f3990b..85c38ff06 100644 --- a/temporalio/worker/_command_aware_visitor.py +++ b/temporalio/worker/_command_aware_visitor.py @@ -31,6 +31,8 @@ class CommandInfo: command_type: CommandType.ValueType command_seq: int + nexus_service: str | None = None + nexus_operation: str | None = None current_command_info: contextvars.ContextVar[CommandInfo | None] = ( @@ -81,7 +83,12 @@ async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( async def _visit_coresdk_workflow_commands_ScheduleNexusOperation( self, fs: VisitorFunctions, o: ScheduleNexusOperation ) -> None: - with current_command(CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, o.seq): + with current_command( + CommandType.COMMAND_TYPE_SCHEDULE_NEXUS_OPERATION, + o.seq, + nexus_service=o.service, + nexus_operation=o.operation, + ): await super()._visit_coresdk_workflow_commands_ScheduleNexusOperation(fs, o) # Workflow activation jobs with payloads @@ -150,11 +157,20 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperation( @contextmanager def current_command( - command_type: CommandType.ValueType, command_seq: int + command_type: CommandType.ValueType, + command_seq: int, + *, + nexus_service: str | None = None, + nexus_operation: str | None = None, ) -> Iterator[None]: """Context manager for setting command info.""" token = current_command_info.set( - CommandInfo(command_type=command_type, command_seq=command_seq) + CommandInfo( + command_type=command_type, + command_seq=command_seq, + nexus_service=nexus_service, + nexus_operation=nexus_operation, + ) ) try: yield diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1bfa77c3c..3454eb2ad 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.nexus.system import temporalio.workflow from temporalio.service import __version__ @@ -3345,14 +3346,18 @@ def _resolve_failure(self, err: BaseException) -> None: self._result_fut.set_result(None) def _apply_schedule_command(self) -> None: - payload = self._payload_converter.to_payload(self._input.input) command = self._instance._add_command() v = command.schedule_nexus_operation v.seq = self._seq v.endpoint = self._input.endpoint v.service = self._input.service v.operation = self._input.operation_name - v.input.CopyFrom(payload) + payload_converter = ( + temporalio.nexus.system.get_payload_converter() + if temporalio.nexus.system.is_system_operation(v.service, v.operation) + else self._payload_converter + ) + v.input.CopyFrom(payload_converter.to_payload(self._input.input)) if self._input.schedule_to_close_timeout is not None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout diff --git a/tests/nexus/test_temporal_system_nexus.py b/tests/nexus/test_temporal_system_nexus.py new file mode 100644 index 000000000..d02849bca --- /dev/null +++ b/tests/nexus/test_temporal_system_nexus.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import dataclasses +import json +import uuid +from collections.abc import Sequence +from typing import cast + +import nexusrpc.handler +import pytest +from google.protobuf.json_format import MessageToDict + +import temporalio.api.common.v1 +import temporalio.converter +from temporalio import workflow +from temporalio.client import Client +from temporalio.converter import ( + DefaultPayloadConverter, + ExternalStorage, + PayloadCodec, +) +from temporalio.nexus.system import generated +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from temporalio.worker._workflow_instance import UnsandboxedWorkflowRunner +from tests.helpers.nexus import make_nexus_endpoint_name +from tests.test_extstore import InMemoryTestDriver + + +@nexusrpc.handler.service_handler(service=generated.WorkflowService) +class WorkflowServicePayloadHandler: + @nexusrpc.handler.sync_operation + async def signal_with_start_workflow_execution( + self, + _ctx: nexusrpc.handler.StartOperationContext, + request: generated.WorkflowServiceSignalWithStartWorkflowExecutionInput, + ) -> generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput: + assert request.workflowId == "system-nexus-workflow-id" + assert request.signalName == "test-signal" + request_dict = dataclasses.asdict(request) + for field_name in ("input", "signalInput"): + payloads = request_dict[field_name]["payloads"] + assert payloads[0]["externalPayloads"] + for field_name in ("memo", "header"): + fields = request_dict[field_name]["fields"] + assert next(iter(fields.values()))["externalPayloads"] + for field_name in ("summary", "details"): + payload = request_dict["userMetadata"][field_name] + assert payload["externalPayloads"] + search_attribute_payload = request_dict["searchAttributes"]["indexedFields"][ + "custom-key" + ] + assert "externalPayloads" not in search_attribute_payload + assert "test-codec" not in search_attribute_payload["metadata"] + return generated.WorkflowServiceSignalWithStartWorkflowExecutionOutput( + runId=f"{request.workflowId}-run" + ) + + +@workflow.defn +class SystemNexusCallerWithPayloadsWorkflow: + @workflow.run + async def run(self, task_queue: str) -> str: + nexus_client = workflow.create_nexus_client( + service=generated.WorkflowService, + endpoint=make_nexus_endpoint_name(task_queue), + ) + request = generated.WorkflowServiceSignalWithStartWorkflowExecutionInput( + namespace="default", + workflowId="system-nexus-workflow-id", + signalName="test-signal", + input=generated.Input( + payloads=[ + MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"workflow-input"', + ) + ) + ] + ), + signalInput=generated.Input( + payloads=[ + MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"signal-input"', + ) + ) + ] + ), + memo=generated.Memo( + fields={ + "memo-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"memo-value"', + ) + ) + } + ), + header=generated.Header( + fields={ + "header-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"header-value"', + ) + ) + } + ), + userMetadata=generated.UserMetadata( + summary=MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"summary-value"', + ) + ), + details=MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"details-value"', + ) + ), + ), + searchAttributes=generated.SearchAttributes( + indexedFields={ + "custom-key": MessageToDict( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'"search-attribute-value"', + ) + ) + } + ), + ) + handle = await nexus_client.start_operation( + generated.WorkflowService.signal_with_start_workflow_execution, + request, + ) + result = await handle + return cast(str, result.runId) + + +class RejectOuterSystemNexusCodec(PayloadCodec): + def __init__(self) -> None: + self.encode_count = 0 + + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + encoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + try: + body = json.loads(payload.data) + except json.JSONDecodeError: + body = None + if isinstance(body, dict) and { + "namespace", + "workflowId", + "signalName", + }.issubset(body): + raise RuntimeError( + "outer system nexus envelope should not be codec encoded" + ) + self.encode_count += 1 + encoded.append( + temporalio.api.common.v1.Payload( + metadata={**payload.metadata, "test-codec": b"true"}, + data=payload.data, + ) + ) + return encoded + + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + decoded: list[temporalio.api.common.v1.Payload] = [] + for payload in payloads: + try: + body = json.loads(payload.data) + except json.JSONDecodeError: + body = None + if isinstance(body, dict) and { + "namespace", + "workflowId", + "signalName", + }.issubset(body): + raise RuntimeError( + "outer system nexus envelope should not be codec decoded" + ) + decoded.append(payload) + return decoded + + +class BadSystemNexusEnvelopePayloadConverter(DefaultPayloadConverter): + def to_payloads( + self, values: Sequence[object] + ) -> list[temporalio.api.common.v1.Payload]: + payloads: list[temporalio.api.common.v1.Payload] = [] + for value in values: + if isinstance( + value, generated.WorkflowServiceSignalWithStartWorkflowExecutionInput + ): + payloads.append( + temporalio.api.common.v1.Payload( + metadata={"encoding": b"json/plain"}, + data=b'{"workflow_id":"bad-envelope"}', + ) + ) + else: + payloads.extend(super().to_payloads([value])) + return payloads + + +async def test_workflow_service_signal_with_start_nested_payloads_use_codec_without_encoding_outer_envelope( + env: WorkflowEnvironment, +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with the Java test server") + + codec = RejectOuterSystemNexusCodec() + driver = InMemoryTestDriver() + caller_config = env.client.config() + caller_config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_converter_class=BadSystemNexusEnvelopePayloadConverter, + payload_codec=codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1, + ), + ) + caller_client = Client(**caller_config) + handler_config = env.client.config() + handler_config["data_converter"] = temporalio.converter.default() + handler_client = Client(**handler_config) + caller_task_queue = str(uuid.uuid4()) + handler_task_queue = str(uuid.uuid4()) + + caller_worker = Worker( + caller_client, + task_queue=caller_task_queue, + workflows=[SystemNexusCallerWithPayloadsWorkflow], + workflow_runner=UnsandboxedWorkflowRunner(), + ) + handler_worker = Worker( + handler_client, + task_queue=handler_task_queue, + nexus_service_handlers=[WorkflowServicePayloadHandler()], + ) + + async with caller_worker, handler_worker: + endpoint_name = make_nexus_endpoint_name(handler_task_queue) + await env.create_nexus_endpoint(endpoint_name, handler_task_queue) + result = await caller_client.execute_workflow( + SystemNexusCallerWithPayloadsWorkflow.run, + handler_task_queue, + id=str(uuid.uuid4()), + task_queue=caller_task_queue, + ) + + assert result == "system-nexus-workflow-id-run" + assert codec.encode_count >= 6 + stored_payloads: list[temporalio.api.common.v1.Payload] = [] + for stored_payload_bytes in driver._storage.values(): + stored_payload = temporalio.api.common.v1.Payload() + stored_payload.ParseFromString(stored_payload_bytes) + stored_payloads.append(stored_payload) + assert stored_payload.metadata["test-codec"] == b"true" + stored_payload_data = {payload.data for payload in stored_payloads} + assert { + b'"workflow-input"', + b'"signal-input"', + b'"memo-value"', + b'"header-value"', + b'"summary-value"', + b'"details-value"', + }.issubset(stored_payload_data)