From b5d368d299ff18c05a9591eb14491f0918f15968 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Wed, 8 Apr 2026 12:36:31 -0400 Subject: [PATCH 01/14] Add TeeEnabled to template generator, and add a new requirements callback --- go.mod | 2 +- go.sum | 4 +- .../v2/protoc/pkg/template_generator.go | 11 ++- pkg/workflows/wasm/host/execution.go | 26 +++++++ .../host/internal/rawsdk/helpers_wasip1.go | 4 + pkg/workflows/wasm/host/module.go | 24 ++++-- .../wasm/host/requirements_gen/main.go | 68 ++++++++++++++++ .../requirements_helper.go.tmpl | 29 +++++++ .../wasm/host/requirements_helper_gen.go | 29 +++++++ .../wasm/host/requirements_helper_gen_test.go | 47 +++++++++++ pkg/workflows/wasm/host/requirements_rerun.go | 16 ++++ pkg/workflows/wasm/host/standard_test.go | 48 +++++++++++- .../standard_tests/tee_runtime/main_wasip1.go | 38 +++++++++ pkg/workflows/wasm/host/tee_provider.go | 20 +++++ pkg/workflows/wasm/host/tee_provider_test.go | 40 ++++++++++ .../invalid_memory/main_wasip1.go | 13 ++++ .../requirements/invalid_proto/main_wasip1.go | 10 +++ pkg/workflows/wasm/host/wasm.go | 2 + pkg/workflows/wasm/host/wasm_nodag_test.go | 78 +++++++++++++++++-- 19 files changed, 491 insertions(+), 18 deletions(-) create mode 100644 pkg/workflows/wasm/host/requirements_gen/main.go create mode 100644 pkg/workflows/wasm/host/requirements_gen/requirements_helper.go.tmpl create mode 100644 pkg/workflows/wasm/host/requirements_helper_gen.go create mode 100644 pkg/workflows/wasm/host/requirements_helper_gen_test.go create mode 100644 pkg/workflows/wasm/host/requirements_rerun.go create mode 100644 pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go create mode 100644 pkg/workflows/wasm/host/tee_provider.go create mode 100644 pkg/workflows/wasm/host/tee_provider_test.go create mode 100644 pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go create mode 100644 pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go diff --git a/go.mod b/go.mod index 863b34085c..e84f119dbf 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260326111235-8c09d1a4491f + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260413164538-e2dad579edbc github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 38a8d0b61e..1971727715 100644 --- a/go.sum +++ b/go.sum @@ -340,8 +340,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260326111235-8c09d1a4491f h1:8p3vE987AHM3Of1JvnNJXNE/AtWtfNvJhk3TeeAG3Qw= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260326111235-8c09d1a4491f/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260413164538-e2dad579edbc h1:R0yNtbt6I1DquC5rERDQEw81hxgcEh3Z+VjY95y5DRI= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260413164538-e2dad579edbc/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/protoc/pkg/template_generator.go b/pkg/capabilities/v2/protoc/pkg/template_generator.go index da6c40e99a..9db41c14d3 100644 --- a/pkg/capabilities/v2/protoc/pkg/template_generator.go +++ b/pkg/capabilities/v2/protoc/pkg/template_generator.go @@ -139,9 +139,9 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial if md == nil { return false, nil - } else { - return md.MapToUntypedApi, nil } + + return md.MapToUntypedApi, nil }, "addImport": func(importPath protogen.GoImportPath, ignore string) string { importName := importPath.String() @@ -259,6 +259,13 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial return line } }, + "TeeEnabled": func(s *protogen.Service) (bool, error) { + md, err := getCapabilityMetadata(s) + if err != nil { + return false, err + } + return md.TeeEnabled, nil + }, }).Funcs(t.ExtraFns) // Register partials diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index ec9fd1bbfd..d538013755 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -32,6 +32,7 @@ type execution[T any] struct { nodeSeed int64 donLogCount uint32 nodeLogCount uint32 + requirementsRerunErr error } // callCapAsync async calls a capability by placing execution results onto a @@ -381,3 +382,28 @@ func (e *execution[T]) pollOneoff(caller *wasmtime.Caller, subscriptionptr int32 return ErrnoSuccess } + +// A trap return will cause the execution to halt. +// This function fails safe and prefers to kill the program than to return an error to the user. +// It does this because a failure here could lead to code running in an environment it's not allowed in +// Although the runtime could protect from this instead, it's safer to fail as early as possible +func (e *execution[T]) requirements(caller *wasmtime.Caller, ptr int32, ptrlen int32) *wasmtime.Trap { + requirements := &sdkpb.Requirements{} + payload, err := wasmRead(caller, ptr, ptrlen) + if err != nil { + e.requirementsRerunErr = fmt.Errorf("error reading requirements: %s", err) + return wasmtime.NewTrap(e.requirementsRerunErr.Error()) + } + + if err = proto.Unmarshal(payload, requirements); err != nil { + e.requirementsRerunErr = fmt.Errorf("error unmarshalling requirements: %s", err) + return wasmtime.NewTrap(e.requirementsRerunErr.Error()) + } + + if !CheckRequirements(e.module.cfg.RequirementsHandler, requirements) { + e.requirementsRerunErr = (*RequirementsRerun)(requirements) + return wasmtime.NewTrap(e.requirementsRerunErr.Error()) + } + + return nil +} diff --git a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go index dfdad8114c..3ad5843589 100644 --- a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go +++ b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go @@ -58,6 +58,7 @@ func SendError(err error) { func SendSubscription(subscriptions *sdk.TriggerSubscriptionRequest) { execResult := &sdk.ExecutionResult{Result: &sdk.ExecutionResult_TriggerSubscriptions{TriggerSubscriptions: subscriptions}} sendResponse(BufferToPointerLen(Must(proto.Marshal(execResult)))) + os.Exit(0) } func Now() time.Time { @@ -251,3 +252,6 @@ func getSecrets(req unsafe.Pointer, reqLen int32, responseBuffer unsafe.Pointer, //go:wasmimport env await_secrets func awaitSecrets(req unsafe.Pointer, reqLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 + +//go:wasmimport env requirements +func Requirements(req unsafe.Pointer, reqLen int32) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index e63e49fd9d..c512359c54 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -50,10 +50,10 @@ var ( defaultMaxLogCountNodeMode = 10_000 ResponseBufferTooSmall = "response buffer too small" - defaultMaxUserMetricPayloadBytes = uint32(4096) // 4 KB - defaultMaxUserMetricNameLength = uint32(128) - defaultMaxUserMetricLabelsPerMetric = uint32(10) - defaultMaxUserMetricLabelValueLength = uint32(256) + defaultMaxUserMetricPayloadBytes = uint32(4096) // 4 KB + defaultMaxUserMetricNameLength = uint32(128) + defaultMaxUserMetricLabelsPerMetric = uint32(10) + defaultMaxUserMetricLabelValueLength = uint32(256) ) type DeterminismConfig struct { @@ -82,7 +82,7 @@ type ModuleConfig struct { MaxLogCountDONMode uint32 MaxLogCountNodeMode uint32 - EnableUserMetricsLimiter limits.GateLimiter + EnableUserMetricsLimiter limits.GateLimiter MaxUserMetricPayloadBytes uint32 MaxUserMetricPayloadLimiter limits.BoundLimiter[config.Size] // supersedes MaxUserMetricPayloadBytes if set MaxUserMetricNameLength uint32 @@ -101,7 +101,8 @@ type ModuleConfig struct { // If Determinism is set, the module will override the random_get function in the WASI API with // the provided seed to ensure deterministic behavior. - Determinism *DeterminismConfig + Determinism *DeterminismConfig + RequirementsHandler RequirementsHandler } type ModuleBase interface { @@ -490,6 +491,13 @@ func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*sdkpb.Executio return nil, fmt.Errorf("error wrapping get_time func: %w", err) } + if err = linker.FuncWrap( + "env", + "requirements", + exec.requirements); err != nil { + return nil, fmt.Errorf("error wrapping requirements func: %w", err) + } + return linker.Instantiate(store, m.module) } @@ -727,6 +735,10 @@ func runWasm[I, O proto.Message]( return o, fmt.Errorf("invariant violation: host errored during sendResponse") } + if exec.requirementsRerunErr != nil { + return o, exec.requirementsRerunErr + } + // If an error has occurred and the deadline has been reached or exceeded, return a deadline exceeded error. // Note - there is no other reliable signal on the error that can be used to infer it is due to epoch deadline // being reached, so if an error is returned after the deadline it is assumed it is due to that and return diff --git a/pkg/workflows/wasm/host/requirements_gen/main.go b/pkg/workflows/wasm/host/requirements_gen/main.go new file mode 100644 index 0000000000..b4b4bd3e6f --- /dev/null +++ b/pkg/workflows/wasm/host/requirements_gen/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "bytes" + _ "embed" + "log" + "os" + "reflect" + "text/template" + + "github.com/smartcontractkit/chainlink-common/pkg/utils/codegen" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +//go:embed requirements_helper.go.tmpl +var tmplSrc string + +type fieldInfo struct { + Name string + Type string +} + +type templateData struct { + Fields []fieldInfo +} + +func main() { + requirementsType := reflect.TypeOf(sdk.Requirements{}) + + var fields []fieldInfo + for i := 0; i < requirementsType.NumField(); i++ { + f := requirementsType.Field(i) + if !f.IsExported() { + continue + } + fields = append(fields, fieldInfo{ + Name: f.Name, + Type: f.Type.String(), + }) + } + + tmpl, err := template.New("requirements_helper").Parse(tmplSrc) + if err != nil { + log.Fatalf("failed to parse template: %v", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, templateData{Fields: fields}); err != nil { + log.Fatalf("failed to execute template: %v", err) + } + + const outFile = "requirements_helper_gen.go" + settings := codegen.PrettySettings{ + Tool: "requirements_gen", + GoPrettySettings: codegen.GoPrettySettings{ + LocalPrefix: "github.com/smartcontractkit/chainlink-common", + }, + } + + content, err := codegen.PrettyFile(outFile, buf.String(), settings) + if err != nil { + log.Fatalf("failed to format generated code: %v\n%s", err, buf.String()) + } + + if err := os.WriteFile(outFile, []byte(content), 0644); err != nil { + log.Fatalf("failed to write output: %v", err) + } +} diff --git a/pkg/workflows/wasm/host/requirements_gen/requirements_helper.go.tmpl b/pkg/workflows/wasm/host/requirements_gen/requirements_helper.go.tmpl new file mode 100644 index 0000000000..2d8f102714 --- /dev/null +++ b/pkg/workflows/wasm/host/requirements_gen/requirements_helper.go.tmpl @@ -0,0 +1,29 @@ +package host + +// RequirementsHandler contains a callback for each public field in sdk.Requirements. +// Each callback receives the field value and returns a list of strings or an error. +type RequirementsHandler struct { +{{- range .Fields}} + {{.Name}} func({{.Type}}) bool +{{- end}} +} + +// CheckRequirements calls each non-nil callback in handler for the corresponding +// non-nil field in req, returning false if any are false, or if the handler is nil. +// Unknown fields on the proto also result in a false return value. +func CheckRequirements(handler RequirementsHandler, req *sdk.Requirements) bool { + if len(req.ProtoReflect().GetUnknown()) != 0 { + return false + } + +{{range .Fields}} + if req.{{.Name}} != nil { + if handler.{{.Name}} == nil || !handler.{{.Name}}(req.{{.Name}}) { + return false + } + + } +{{end}} + + return true +} diff --git a/pkg/workflows/wasm/host/requirements_helper_gen.go b/pkg/workflows/wasm/host/requirements_helper_gen.go new file mode 100644 index 0000000000..875da649fa --- /dev/null +++ b/pkg/workflows/wasm/host/requirements_helper_gen.go @@ -0,0 +1,29 @@ +// Code generated by requirements_gen, DO NOT EDIT. + +package host + +import "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + +// RequirementsHandler contains a callback for each public field in sdk.Requirements. +// Each callback receives the field value and returns a list of strings or an error. +type RequirementsHandler struct { + Tee func(*sdk.Tee) bool +} + +// CheckRequirements calls each non-nil callback in handler for the corresponding +// non-nil field in req, returning false if any are false, or if the handler is nil. +// Unknown fields on the proto also result in a false return value. +func CheckRequirements(handler RequirementsHandler, req *sdk.Requirements) bool { + if len(req.ProtoReflect().GetUnknown()) != 0 { + return false + } + + if req.Tee != nil { + if handler.Tee == nil || !handler.Tee(req.Tee) { + return false + } + + } + + return true +} diff --git a/pkg/workflows/wasm/host/requirements_helper_gen_test.go b/pkg/workflows/wasm/host/requirements_helper_gen_test.go new file mode 100644 index 0000000000..7c1b9049ff --- /dev/null +++ b/pkg/workflows/wasm/host/requirements_helper_gen_test.go @@ -0,0 +1,47 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func Test_CheckRequirements(t *testing.T) { + t.Parallel() + t.Run("unknown proto fields", func(t *testing.T) { + // Encode a field number (99) unknown to Requirements so proto.Unmarshal + // preserves it as unknown bytes. + b := protowire.AppendTag(nil, 99, protowire.VarintType) + b = protowire.AppendVarint(b, 1) + req := &sdk.Requirements{} + require.NoError(t, proto.Unmarshal(b, req)) + + assert.False(t, CheckRequirements(RequirementsHandler{}, req)) + }) + + t.Run("no fields always passes", func(t *testing.T) { + assert.True(t, CheckRequirements(RequirementsHandler{}, &sdk.Requirements{})) + }) + + t.Run("handler not set returns false", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + assert.False(t, CheckRequirements(RequirementsHandler{}, req)) + }) + + t.Run("handler returns false causes false return value", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + handler := RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }} + assert.False(t, CheckRequirements(handler, req)) + }) + + t.Run("handler returns true causes true return value", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + handler := RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }} + assert.True(t, CheckRequirements(handler, req)) + }) +} diff --git a/pkg/workflows/wasm/host/requirements_rerun.go b/pkg/workflows/wasm/host/requirements_rerun.go new file mode 100644 index 0000000000..9b64f4e928 --- /dev/null +++ b/pkg/workflows/wasm/host/requirements_rerun.go @@ -0,0 +1,16 @@ +package host + +import ( + "google.golang.org/protobuf/encoding/protojson" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type RequirementsRerun sdk.Requirements + +func (r *RequirementsRerun) Error() string { + str, _ := protojson.Marshal((*sdk.Requirements)(r)) + return string(str) +} + +var _ error = (*RequirementsRerun)(nil) diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index 76b03c471f..e06c5be31f 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -523,6 +523,47 @@ func TestStandardTimeInterpretation(t *testing.T) { require.Equal(t, "2020-01-02T03:04:05Z", result) } +func TestStandardTeeRuntime(t *testing.T) { + t.Parallel() + + trigger := &basictrigger.Outputs{CoolOutput: anyTestTriggerValue} + + cfg := defaultNoDAGModCfg(t) + var seenTeeRequirement *sdk.Tee + cfg.RequirementsHandler.Tee = func(tee *sdk.Tee) bool { + seenTeeRequirement = tee + return true + } + + m := makeTestModuleWithConfig(t, cfg) + + for _, test := range []struct { + name string + req *sdk.ExecuteRequest + }{ + { + name: "subscribe", + req: &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}}, + }, + { + name: "execute", + req: triggerExecuteRequest(t, 0, trigger), + }, + } { + t.Run(test.name, func(t *testing.T) { + mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() + + _, err := m.Execute(t.Context(), test.req, mockExecutionHelper) + require.NoError(t, err) + require.True(t, proto.Equal(seenTeeRequirement, &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []sdk.TeeType{sdk.TeeType_TEE_TYPE_AWS_NITRO}}}})) + }) + } +} + func triggerExecuteRequest(t *testing.T, id uint64, trigger proto.Message) *sdk.ExecuteRequest { wrappedTrigger, err := anypb.New(trigger) require.NoError(t, err) @@ -549,8 +590,12 @@ func runWithBasicTrigger(t *testing.T, executor ExecutionHelper) *sdk.ExecutionR // To re-use a binary, an outer test can create the module and use t.Run to run subtests using that module. // When subtests have their own binaries, those binaries are expected to be nested in a subfolder. func makeTestModule(t *testing.T) *module { + return makeTestModuleWithConfig(t, nil) +} + +func makeTestModuleWithConfig(t *testing.T, cfg *ModuleConfig) *module { testName := strcase.ToSnake(t.Name()[len("TestStandard"):]) - return makeTestModuleByName(t, testName, nil) + return makeTestModuleByName(t, testName, cfg) } func makeTestModuleByName(t *testing.T, testName string, cfg *ModuleConfig) *module { @@ -559,6 +604,7 @@ func makeTestModuleByName(t *testing.T, testName string, cfg *ModuleConfig) *mod absPath, err := filepath.Abs(testPath) require.NoError(t, err, "Failed to get absolute path for test directory") cmd.Dir = absPath + fmt.Printf("Compiling test module from %s with command %s\n:", cmd.Dir, cmd.String()) output, err := cmd.CombinedOutput() require.NoError(t, err, string(output)) diff --git a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go new file mode 100644 index 0000000000..2d3631d7f5 --- /dev/null +++ b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go @@ -0,0 +1,38 @@ +package main + +import ( + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func main() { + req := rawsdk.GetRequest() + requirements := &sdk.Requirements{Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []sdk.TeeType{sdk.TeeType_TEE_TYPE_AWS_NITRO}}}}} + bytes, err := proto.Marshal(requirements) + if err != nil { + rawsdk.SendError(err) + } + rawsdk.Requirements(rawsdk.BufferToPointerLen(bytes)) + subscription := &sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + })), + Method: "Trigger", + }, + }, + } + switch req.GetRequest().(type) { + case *sdk.ExecuteRequest_Subscribe: + rawsdk.SendSubscription(subscription) + } + + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/tee_provider.go b/pkg/workflows/wasm/host/tee_provider.go new file mode 100644 index 0000000000..64a5c8e61e --- /dev/null +++ b/pkg/workflows/wasm/host/tee_provider.go @@ -0,0 +1,20 @@ +package host + +import sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + +type TeeProvider sdkpb.TeeType + +func (t TeeProvider) Provides(tee *sdkpb.Tee) bool { + switch teet := tee.Type.(type) { + case *sdkpb.Tee_Any: + return true + case *sdkpb.Tee_TypeSelection: + for _, selection := range teet.TypeSelection.Types { + if selection == sdkpb.TeeType(t) { + return true + } + } + } + + return false +} diff --git a/pkg/workflows/wasm/host/tee_provider_test.go b/pkg/workflows/wasm/host/tee_provider_test.go new file mode 100644 index 0000000000..c79858338f --- /dev/null +++ b/pkg/workflows/wasm/host/tee_provider_test.go @@ -0,0 +1,40 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/emptypb" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func TestTeeProvider(t *testing.T) { + t.Parallel() + t.Run("matches any", func(t *testing.T) { + p := TeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO) + tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} + assert.True(t, p.Provides(tee)) + }) + + t.Run("matches type selection", func(t *testing.T) { + p := TeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO) + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []sdkpb.TeeType{sdkpb.TeeType(99), sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("does not match any type", func(t *testing.T) { + // Use a cast to an unknown value so we don't need a second enum variant. + p := TeeProvider(sdkpb.TeeType(99)) + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []sdkpb.TeeType{sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, + }} + assert.False(t, p.Provides(tee)) + }) +} diff --git a/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go b/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go new file mode 100644 index 0000000000..bd31aa32c8 --- /dev/null +++ b/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go @@ -0,0 +1,13 @@ +package main + +import ( + "unsafe" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" +) + +func main() { + buf := make([]byte, 4) + rawsdk.Requirements(unsafe.Pointer(&buf[0]), 100) + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go b/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go new file mode 100644 index 0000000000..1df7e1da8f --- /dev/null +++ b/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" +) + +func main() { + rawsdk.Requirements(rawsdk.BufferToPointerLen([]byte{0x3E, 0x80, 0xFF, 0x0A, 0xFF, 0x01, 0x0C, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01})) + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/wasm.go b/pkg/workflows/wasm/host/wasm.go index d8c4bae7f1..2ec826d4db 100644 --- a/pkg/workflows/wasm/host/wasm.go +++ b/pkg/workflows/wasm/host/wasm.go @@ -1,5 +1,7 @@ package host +//go:generate go run ./requirements_gen + import ( "context" "errors" diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 9f77ba1ebb..2e488734c4 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" @@ -20,12 +22,18 @@ import ( ) const ( - nodagRandomBinaryCmd = "standard_tests/multiple_triggers" - nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" - loggingLimitsBinaryCmd = "test/logging_limits/cmd" - loggingLimitsBinaryLocation = loggingLimitsBinaryCmd + "/testmodule.wasm" - metricLimitsBinaryCmd = "test/metric_limits/cmd" - metricLimitsBinaryLocation = metricLimitsBinaryCmd + "/testmodule.wasm" + nodagRandomBinaryCmd = "standard_tests/multiple_triggers" + nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" + loggingLimitsBinaryCmd = "test/logging_limits/cmd" + loggingLimitsBinaryLocation = loggingLimitsBinaryCmd + "/testmodule.wasm" + metricLimitsBinaryCmd = "test/metric_limits/cmd" + metricLimitsBinaryLocation = metricLimitsBinaryCmd + "/testmodule.wasm" + standardTeeRuntimeBinaryCmd = "standard_tests/tee_runtime" + standardTeeRuntimeBinaryLocation = standardTeeRuntimeBinaryCmd + "/testmodule.wasm" + invalidMemoryForRequirementsCmd = "test/requirements/invalid_memory" + invalidMemoryForRequirementsBinaryLocation = invalidMemoryForRequirementsCmd + "/testmodule.wasm" + invalidProtoForRequirementsCmd = "test/requirements/invalid_proto" + invalidProtoForRequirementsBinaryLocation = invalidProtoForRequirementsCmd + "/testmodule.wasm" ) func Test_Sleep_Timeout(t *testing.T) { @@ -213,6 +221,64 @@ func Test_NoDAG_EmitMetricDisabled(t *testing.T) { // EmitUserMetric should never be called when disabled - no mock expectation set } +func Test_NoDAG_UnparseableRequirements(t *testing.T) { + t.Parallel() + binary := createTestBinary(invalidProtoForRequirementsCmd, invalidProtoForRequirementsBinaryLocation, true, t) + + err := runTeeFailureTest(t, sdk.TeeType_TEE_TYPE_AWS_NITRO, binary) + + assert.Error(t, err) + rerunErr := &RequirementsRerun{} + assert.False(t, errors.As(err, &rerunErr)) +} + +func Test_NoDAG_InvalidMemoryAddressForRequirements(t *testing.T) { + t.Parallel() + binary := createTestBinary(invalidMemoryForRequirementsCmd, invalidMemoryForRequirementsBinaryLocation, true, t) + + err := runTeeFailureTest(t, sdk.TeeType_TEE_TYPE_AWS_NITRO, binary) + + assert.Error(t, err) + rerunErr := &RequirementsRerun{} + assert.False(t, errors.As(err, &rerunErr)) +} + +func Test_NoDAG_RequirementsNotMet(t *testing.T) { + t.Parallel() + + binary := createTestBinary(standardTeeRuntimeBinaryCmd, standardTeeRuntimeBinaryLocation, true, t) + + // Different (non-existent) TEE + err := runTeeFailureTest(t, 999, binary) + + rerunErr := &RequirementsRerun{} + require.True(t, errors.As(err, &rerunErr)) + + expected := &sdk.Requirements{ + Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{ + TypeSelection: &sdk.TeeTypeSelection{Types: []sdk.TeeType{sdk.TeeType_TEE_TYPE_AWS_NITRO}}}, + }, + } + assert.True(t, proto.Equal(expected, (*sdk.Requirements)(rerunErr))) +} + +func runTeeFailureTest(t *testing.T, teeType sdk.TeeType, binary []byte) error { + cfg := defaultNoDAGModCfg(t) + cfg.RequirementsHandler.Tee = TeeProvider(teeType).Provides + m, err := NewModule(t.Context(), cfg, binary) + require.NoError(t, err) + + mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() + subscribe := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}} + + _, err = m.Execute(t.Context(), subscribe, mockExecutionHelper) + return err +} + func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { return &ModuleConfig{ Logger: logger.Test(t), From 660e58a08c23698a5115b5c95ae6f64f97957932 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Wed, 22 Apr 2026 09:28:26 -0400 Subject: [PATCH 02/14] Add region --- go.mod | 2 +- go.sum | 4 +- pkg/workflows/wasm/host/standard_test.go | 2 +- .../standard_tests/tee_runtime/main_wasip1.go | 2 +- pkg/workflows/wasm/host/tee_provider.go | 27 +++- pkg/workflows/wasm/host/tee_provider_test.go | 118 ++++++++++++++++-- pkg/workflows/wasm/host/wasm_nodag_test.go | 6 +- 7 files changed, 140 insertions(+), 21 deletions(-) diff --git a/go.mod b/go.mod index e84f119dbf..fad24559b3 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260413164538-e2dad579edbc + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260421194300-2c8da85a337a github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 1971727715..31e6c3af3c 100644 --- a/go.sum +++ b/go.sum @@ -340,8 +340,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260413164538-e2dad579edbc h1:R0yNtbt6I1DquC5rERDQEw81hxgcEh3Z+VjY95y5DRI= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260413164538-e2dad579edbc/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260421194300-2c8da85a337a h1:2mwWuRputcmFMzehSUlk95q9NQp9cspupb6FZxgCh7w= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260421194300-2c8da85a337a/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index e06c5be31f..0d96d4f195 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -559,7 +559,7 @@ func TestStandardTeeRuntime(t *testing.T) { _, err := m.Execute(t.Context(), test.req, mockExecutionHelper) require.NoError(t, err) - require.True(t, proto.Equal(seenTeeRequirement, &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []sdk.TeeType{sdk.TeeType_TEE_TYPE_AWS_NITRO}}}})) + require.True(t, proto.Equal(seenTeeRequirement, &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}})) }) } } diff --git a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go index 2d3631d7f5..d8455ad3a1 100644 --- a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go +++ b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go @@ -11,7 +11,7 @@ import ( func main() { req := rawsdk.GetRequest() - requirements := &sdk.Requirements{Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []sdk.TeeType{sdk.TeeType_TEE_TYPE_AWS_NITRO}}}}} + requirements := &sdk.Requirements{Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}}} bytes, err := proto.Marshal(requirements) if err != nil { rawsdk.SendError(err) diff --git a/pkg/workflows/wasm/host/tee_provider.go b/pkg/workflows/wasm/host/tee_provider.go index 64a5c8e61e..5905bfdbd2 100644 --- a/pkg/workflows/wasm/host/tee_provider.go +++ b/pkg/workflows/wasm/host/tee_provider.go @@ -2,16 +2,35 @@ package host import sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" -type TeeProvider sdkpb.TeeType +type teeProvider struct { + sdkpb.TeeType + regions map[string]bool +} -func (t TeeProvider) Provides(tee *sdkpb.Tee) bool { +func NewTeeProvider(tpe sdkpb.TeeType, regions []string) func(tee *sdkpb.Tee) bool { + supportedRegions := map[string]bool{} + for _, region := range regions { + supportedRegions[region] = true + } + return (&teeProvider{TeeType: tpe, regions: supportedRegions}).Provides +} + +func (t *teeProvider) Provides(tee *sdkpb.Tee) bool { switch teet := tee.Type.(type) { case *sdkpb.Tee_Any: return true case *sdkpb.Tee_TypeSelection: for _, selection := range teet.TypeSelection.Types { - if selection == sdkpb.TeeType(t) { - return true + if selection.Type == t.TeeType { + if len(selection.Regions) == 0 { + return true + } + + for _, region := range selection.Regions { + if t.regions[region] { + return true + } + } } } } diff --git a/pkg/workflows/wasm/host/tee_provider_test.go b/pkg/workflows/wasm/host/tee_provider_test.go index c79858338f..5ba04bfb81 100644 --- a/pkg/workflows/wasm/host/tee_provider_test.go +++ b/pkg/workflows/wasm/host/tee_provider_test.go @@ -9,32 +9,132 @@ import ( sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) -func TestTeeProvider(t *testing.T) { +func TestNewTeeProvider(t *testing.T) { t.Parallel() t.Run("matches any", func(t *testing.T) { - p := TeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO) + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} assert.True(t, p.Provides(tee)) }) - t.Run("matches type selection", func(t *testing.T) { - p := TeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO) + t.Run("matches type selection with no region constraint", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []sdkpb.TeeType{sdkpb.TeeType(99), sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(99)}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, }, }} assert.True(t, p.Provides(tee)) }) - t.Run("does not match any type", func(t *testing.T) { - // Use a cast to an unknown value so we don't need a second enum variant. - p := TeeProvider(sdkpb.TeeType(99)) + t.Run("does not match different types", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType(99)} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []sdkpb.TeeType{sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, }, }} assert.False(t, p.Provides(tee)) }) + + t.Run("matches type and region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("matches type but not region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches one of multiple requested regions", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"eu-west-1": true}} + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2", "eu-west-1"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("provider has multiple regions and one matches", func(t *testing.T) { + p := teeProvider{ + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + } + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-east-1"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("no matching region across multiple provider regions", func(t *testing.T) { + p := teeProvider{ + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + } + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"ap-southeast-1"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("type mismatch ignores region match", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType(99), regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches any tee", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} + assert.True(t, provides(tee)) + }) + + t.Run("returns a function that checks regions", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ + TypeSelection: &sdkpb.TeeTypeSelection{ + Types: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }, + }, + }} + assert.False(t, provides(tee)) + }) } diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 2e488734c4..02a4d48510 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -256,15 +256,15 @@ func Test_NoDAG_RequirementsNotMet(t *testing.T) { expected := &sdk.Requirements{ Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{ - TypeSelection: &sdk.TeeTypeSelection{Types: []sdk.TeeType{sdk.TeeType_TEE_TYPE_AWS_NITRO}}}, - }, + TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}}, + }}, } assert.True(t, proto.Equal(expected, (*sdk.Requirements)(rerunErr))) } func runTeeFailureTest(t *testing.T, teeType sdk.TeeType, binary []byte) error { cfg := defaultNoDAGModCfg(t) - cfg.RequirementsHandler.Tee = TeeProvider(teeType).Provides + cfg.RequirementsHandler.Tee = NewTeeProvider(teeType, nil) m, err := NewModule(t.Context(), cfg, binary) require.NoError(t, err) From 42946187b814025a4720864c1e7eac02c8c8aff8 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Wed, 22 Apr 2026 15:51:38 -0400 Subject: [PATCH 03/14] Requirements selecting runner --- .../wasm/host/requirement_selecting_module.go | 84 ++++++ .../host/requirement_selecting_module_test.go | 253 ++++++++++++++++++ 2 files changed, 337 insertions(+) create mode 100644 pkg/workflows/wasm/host/requirement_selecting_module.go create mode 100644 pkg/workflows/wasm/host/requirement_selecting_module_test.go diff --git a/pkg/workflows/wasm/host/requirement_selecting_module.go b/pkg/workflows/wasm/host/requirement_selecting_module.go new file mode 100644 index 0000000000..4a6988ef09 --- /dev/null +++ b/pkg/workflows/wasm/host/requirement_selecting_module.go @@ -0,0 +1,84 @@ +package host + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type ModuleAndHandler struct { + ModuleV2 + RequirementsHandler +} + +func NewRequirementSelectingModule(moduleAndHandlers []ModuleAndHandler) ModuleV2 { + return &requirementSelectingModule{ + moduleAndHandler: moduleAndHandlers, + runOn: -1, + } +} + +type requirementSelectingModule struct { + moduleAndHandler []ModuleAndHandler + runOn int + started atomic.Bool + findMutex sync.Mutex +} + +func (r *requirementSelectingModule) Start() { + r.started.Store(true) + r.moduleAndHandler[0].Start() +} + +func (r *requirementSelectingModule) Close() { + r.findMutex.Lock() + defer r.findMutex.Unlock() + if r.runOn == -1 { + r.moduleAndHandler[0].Close() + } else { + r.moduleAndHandler[r.runOn].Close() + } +} + +func (r *requirementSelectingModule) IsLegacyDAG() bool { + return r.moduleAndHandler[0].IsLegacyDAG() +} + +func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + if r.runOn >= 0 { + return r.moduleAndHandler[r.runOn].Execute(ctx, request, handler) + } + + r.findMutex.Lock() + defer r.findMutex.Unlock() + result, err := r.moduleAndHandler[0].Execute(ctx, request, handler) + if err == nil { + r.runOn = 0 + return result, nil + } + + rerun := &RequirementsRerun{} + if !errors.As(err, &rerun) { + return nil, err + } + + numHandlers := len(r.moduleAndHandler) + for i := 1; i < numHandlers; i++ { + item := r.moduleAndHandler[i] + if CheckRequirements(item.RequirementsHandler, (*sdk.Requirements)(rerun)) { + r.runOn = i + if r.started.Load() { + item.Start() + } + return item.Execute(ctx, request, handler) + } + } + + return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements %+v\n", rerun) +} + +var _ ModuleV2 = &requirementSelectingModule{} diff --git a/pkg/workflows/wasm/host/requirement_selecting_module_test.go b/pkg/workflows/wasm/host/requirement_selecting_module_test.go new file mode 100644 index 0000000000..fdaba31e8a --- /dev/null +++ b/pkg/workflows/wasm/host/requirement_selecting_module_test.go @@ -0,0 +1,253 @@ +package host + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type stubModuleV2 struct { + startFn func() + closeFn func() + legacyFn func() bool + executeFn func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) +} + +func (s *stubModuleV2) Start() { s.startFn() } +func (s *stubModuleV2) Close() { s.closeFn() } +func (s *stubModuleV2) IsLegacyDAG() bool { return s.legacyFn() } +func (s *stubModuleV2) Execute(ctx context.Context, req *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + return s.executeFn(ctx, req, h) +} + +func TestRequirementSelectingModule_Start(t *testing.T) { + var started bool + m0 := &stubModuleV2{startFn: func() { started = true }} + m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) + m.Start() + assert.True(t, started) +} + +func TestRequirementSelectingModule_Close(t *testing.T) { + t.Run("before execute closes first module", func(t *testing.T) { + var closedIdx int + m0 := &stubModuleV2{closeFn: func() { closedIdx = 0 }} + m1 := &stubModuleV2{closeFn: func() { closedIdx = 1 }} + m := NewRequirementSelectingModule([]ModuleAndHandler{ + {ModuleV2: m0}, + {ModuleV2: m1}, + }) + closedIdx = -1 + m.Close() + assert.Equal(t, 0, closedIdx) + }) + + t.Run("after execute closes selected module", func(t *testing.T) { + wantResult := &sdk.ExecutionResult{} + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + var closedIdx int + + m0 := &stubModuleV2{ + startFn: func() {}, + closeFn: func() { closedIdx = 0 }, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + } + m1 := &stubModuleV2{ + closeFn: func() { closedIdx = 1 }, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return wantResult, nil + }, + } + + m := NewRequirementSelectingModule([]ModuleAndHandler{ + {ModuleV2: m0}, + {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, + }) + + _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + + closedIdx = -1 + m.Close() + assert.Equal(t, 1, closedIdx) + }) +} + +func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { + t.Run("delegates", func(t *testing.T) { + m0 := &stubModuleV2{legacyFn: func() bool { return true }} + m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) + assert.True(t, m.IsLegacyDAG()) + }) +} + +func TestRequirementSelectingModule_Execute(t *testing.T) { + t.Run("delegates when runOn already set", func(t *testing.T) { + calls := 0 + wantResult := &sdk.ExecutionResult{} + m0 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + calls++ + return wantResult, nil + }, + } + + m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) + + _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + assert.Equal(t, 1, calls) + + got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + assert.Equal(t, wantResult, got) + assert.Equal(t, 2, calls) + }) + + t.Run("first module succeeds sets runOn to zero", func(t *testing.T) { + wantResult := &sdk.ExecutionResult{} + numCalls := 0 + m0 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + numCalls++ + return wantResult, nil + }, + } + + m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) + + got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + assert.Equal(t, 1, numCalls) + assert.Equal(t, wantResult, got) + }) + + t.Run("non-RequirementsRerun error is propagated without additional executions", func(t *testing.T) { + m0 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, assert.AnError + }, + } + + m1 := &stubModuleV2{executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + assert.Fail(t, "second module should not be executed") + return nil, nil + }} + + m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}, {ModuleV2: m1}}) + + _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + assert.ErrorIs(t, err, assert.AnError) + }) + + t.Run("RequirementsRerun with matching handler not started", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + wantResult := &sdk.ExecutionResult{} + var m1Started bool + + m0 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + } + m1 := &stubModuleV2{ + startFn: func() { m1Started = true }, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return wantResult, nil + }, + } + + m := NewRequirementSelectingModule([]ModuleAndHandler{ + {ModuleV2: m0}, + {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, + }) + + got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + assert.Equal(t, wantResult, got) + assert.False(t, m1Started) + }) + + t.Run("RequirementsRerun with matching handler already started", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + wantResult := &sdk.ExecutionResult{} + var m1Started bool + + m0 := &stubModuleV2{ + startFn: func() {}, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + } + m1 := &stubModuleV2{ + startFn: func() { m1Started = true }, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return wantResult, nil + }, + } + + m := NewRequirementSelectingModule([]ModuleAndHandler{ + {ModuleV2: m0}, + {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, + }) + + m.Start() + + got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + assert.Equal(t, wantResult, got) + assert.True(t, m1Started) + }) + + t.Run("RequirementsRerun with no matching handler returns error", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + m0 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + } + m1 := &stubModuleV2{} + + m := NewRequirementSelectingModule([]ModuleAndHandler{ + {ModuleV2: m0}, + {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}}, + }) + + _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner that can satisfy the requirements") + }) + + t.Run("RequirementsRerun skips non-matching selects later match", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + wantResult := &sdk.ExecutionResult{} + + m0 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + } + m1 := &stubModuleV2{} + m2 := &stubModuleV2{ + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return wantResult, nil + }, + } + + m := NewRequirementSelectingModule([]ModuleAndHandler{ + {ModuleV2: m0}, + {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}}, + {ModuleV2: m2, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, + }) + + got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) + require.NoError(t, err) + assert.Equal(t, wantResult, got) + }) +} From a8698611e9fff65bdc83a7de83bc9071352fc89b Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Fri, 24 Apr 2026 15:31:22 -0400 Subject: [PATCH 04/14] update proto and fix requirement not met test --- go.mod | 2 +- go.sum | 4 ++-- pkg/capabilities/v2/protoc/pkg/template_generator.go | 7 ++++++- pkg/workflows/wasm/host/wasm_nodag_test.go | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 268a966728..ce593a074f 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260421194300-2c8da85a337a + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260424192350-c2ff1c3f6163 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 13b2033a6a..ed7fd44978 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260421194300-2c8da85a337a h1:2mwWuRputcmFMzehSUlk95q9NQp9cspupb6FZxgCh7w= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260421194300-2c8da85a337a/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260424192350-c2ff1c3f6163 h1:MfHAshLU/p25XvIafw6sPrBaBKwpeTNVANADiMLzeak= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260424192350-c2ff1c3f6163/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/protoc/pkg/template_generator.go b/pkg/capabilities/v2/protoc/pkg/template_generator.go index 9db41c14d3..90e84da804 100644 --- a/pkg/capabilities/v2/protoc/pkg/template_generator.go +++ b/pkg/capabilities/v2/protoc/pkg/template_generator.go @@ -264,7 +264,12 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial if err != nil { return false, err } - return md.TeeEnabled, nil + for _, env := range md.AdditionalEnvironments { + if env == generator.AdditionalEnironments_ADDITIONAL_ENVIRONMENTS_TEE { + return true, nil + } + } + return false, nil }, }).Funcs(t.ExtraFns) diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 02a4d48510..fd27a99798 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -256,7 +256,7 @@ func Test_NoDAG_RequirementsNotMet(t *testing.T) { expected := &sdk.Requirements{ Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{ - TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}}, + TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}, }}, } assert.True(t, proto.Equal(expected, (*sdk.Requirements)(rerunErr))) From 16dba123303f1cb85e3f2514f937015369008d17 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Mon, 27 Apr 2026 09:51:55 -0400 Subject: [PATCH 05/14] Update proto, allow individual triggers to choose where to run --- .mockery.yaml | 11 +- .../v2/actions/confidentialhttp/client.pb.go | 4 +- pkg/capabilities/v2/actions/http/client.pb.go | 4 +- .../host/mock_execution_helper_test.go | 0 pkg/workflows/host/mocks/execution_helper.go | 396 +++++++++++++++ pkg/workflows/host/mocks/module.go | 207 ++++++++ pkg/workflows/host/module.go | 41 ++ .../host/requirement_selecting_module.go | 105 ++++ .../host/requirement_selecting_module_test.go | 462 ++++++++++++++++++ .../{wasm => }/host/requirements_gen/main.go | 0 .../requirements_helper.go.tmpl | 0 .../host/requirements_helper_gen.go | 0 .../host/requirements_helper_gen_test.go | 0 .../{wasm => }/host/requirements_rerun.go | 0 pkg/workflows/{wasm => }/host/tee_provider.go | 0 .../{wasm => }/host/tee_provider_test.go | 0 pkg/workflows/wasm/host/execution.go | 6 +- pkg/workflows/wasm/host/mocks/module_v2.go | 203 +------- pkg/workflows/wasm/host/module.go | 35 +- pkg/workflows/wasm/host/module_test.go | 4 +- .../wasm/host/requirement_selecting_module.go | 84 ---- .../host/requirement_selecting_module_test.go | 253 ---------- pkg/workflows/wasm/host/standard_test.go | 40 +- pkg/workflows/wasm/host/time_test.go | 11 +- pkg/workflows/wasm/host/wasm.go | 2 - pkg/workflows/wasm/host/wasm_nodag_test.go | 22 +- 26 files changed, 1280 insertions(+), 610 deletions(-) rename pkg/workflows/{wasm => }/host/mock_execution_helper_test.go (100%) create mode 100644 pkg/workflows/host/mocks/execution_helper.go create mode 100644 pkg/workflows/host/mocks/module.go create mode 100644 pkg/workflows/host/module.go create mode 100644 pkg/workflows/host/requirement_selecting_module.go create mode 100644 pkg/workflows/host/requirement_selecting_module_test.go rename pkg/workflows/{wasm => }/host/requirements_gen/main.go (100%) rename pkg/workflows/{wasm => }/host/requirements_gen/requirements_helper.go.tmpl (100%) rename pkg/workflows/{wasm => }/host/requirements_helper_gen.go (100%) rename pkg/workflows/{wasm => }/host/requirements_helper_gen_test.go (100%) rename pkg/workflows/{wasm => }/host/requirements_rerun.go (100%) rename pkg/workflows/{wasm => }/host/tee_provider.go (100%) rename pkg/workflows/{wasm => }/host/tee_provider_test.go (100%) delete mode 100644 pkg/workflows/wasm/host/requirement_selecting_module.go delete mode 100644 pkg/workflows/wasm/host/requirement_selecting_module_test.go diff --git a/.mockery.yaml b/.mockery.yaml index b07b652c93..8771512ed2 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -37,13 +37,12 @@ packages: github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host: interfaces: ModuleV1: {} - ModuleV2: {} + github.com/smartcontractkit/chainlink-common/pkg/workflows/host: + interfaces: + Module: {} ExecutionHelper: config: - inpackage: true - filename: "mock_{{.InterfaceName | snakecase}}_test.go" - mockname: "Mock{{.InterfaceName}}" - dir: "{{.InterfaceDir}}" + mockname: "Mock{{.InterfaceName}}" github.com/smartcontractkit/chainlink-common/pkg/custmsg: interfaces: MessageEmitter: @@ -63,4 +62,4 @@ packages: dir: "{{.InterfaceDir}}/limits" outpkg: limits interfaces: - Getter: \ No newline at end of file + Getter: diff --git a/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go b/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go index 4497ce5c8b..8b1b74b65b 100644 --- a/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go @@ -441,9 +441,9 @@ const file_capabilities_networking_confidentialhttp_v1alpha_client_proto_rawDesc "\x05value\x18\x02 \x01(\v2>.capabilities.networking.confidentialhttp.v1alpha.HeaderValuesR\x05value:\x028\x01\"\xe2\x01\n" + "\x17ConfidentialHTTPRequest\x12n\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2B.capabilities.networking.confidentialhttp.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12W\n" + - "\arequest\x18\x02 \x01(\v2=.capabilities.networking.confidentialhttp.v1alpha.HTTPRequestR\arequest2\xca\x01\n" + + "\arequest\x18\x02 \x01(\v2=.capabilities.networking.confidentialhttp.v1alpha.HTTPRequestR\arequest2\xcd\x01\n" + "\x06Client\x12\x98\x01\n" + - "\vSendRequest\x12I.capabilities.networking.confidentialhttp.v1alpha.ConfidentialHTTPRequest\x1a>.capabilities.networking.confidentialhttp.v1alpha.HTTPResponse\x1a%\x82\xb5\x18!\b\x01\x12\x1dconfidential-http@1.0.0-alphab\x06proto3" + "\vSendRequest\x12I.capabilities.networking.confidentialhttp.v1alpha.ConfidentialHTTPRequest\x1a>.capabilities.networking.confidentialhttp.v1alpha.HTTPResponse\x1a(\x82\xb5\x18$\b\x01\x12\x1dconfidential-http@1.0.0-alpha\"\x01\x01b\x06proto3" var ( file_capabilities_networking_confidentialhttp_v1alpha_client_proto_rawDescOnce sync.Once diff --git a/pkg/capabilities/v2/actions/http/client.pb.go b/pkg/capabilities/v2/actions/http/client.pb.go index 6a93e88293..f9a767596b 100644 --- a/pkg/capabilities/v2/actions/http/client.pb.go +++ b/pkg/capabilities/v2/actions/http/client.pb.go @@ -320,9 +320,9 @@ const file_capabilities_networking_http_v1alpha_client_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\x1as\n" + "\x11MultiHeadersEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12H\n" + - "\x05value\x18\x02 \x01(\v22.capabilities.networking.http.v1alpha.HeaderValuesR\x05value:\x028\x012\x98\x01\n" + + "\x05value\x18\x02 \x01(\v22.capabilities.networking.http.v1alpha.HeaderValuesR\x05value:\x028\x012\x9b\x01\n" + "\x06Client\x12l\n" + - "\vSendRequest\x12-.capabilities.networking.http.v1alpha.Request\x1a..capabilities.networking.http.v1alpha.Response\x1a \x82\xb5\x18\x1c\b\x02\x12\x18http-actions@1.0.0-alphab\x06proto3" + "\vSendRequest\x12-.capabilities.networking.http.v1alpha.Request\x1a..capabilities.networking.http.v1alpha.Response\x1a#\x82\xb5\x18\x1f\b\x02\x12\x18http-actions@1.0.0-alpha\"\x01\x01b\x06proto3" var ( file_capabilities_networking_http_v1alpha_client_proto_rawDescOnce sync.Once diff --git a/pkg/workflows/wasm/host/mock_execution_helper_test.go b/pkg/workflows/host/mock_execution_helper_test.go similarity index 100% rename from pkg/workflows/wasm/host/mock_execution_helper_test.go rename to pkg/workflows/host/mock_execution_helper_test.go diff --git a/pkg/workflows/host/mocks/execution_helper.go b/pkg/workflows/host/mocks/execution_helper.go new file mode 100644 index 0000000000..455fde3041 --- /dev/null +++ b/pkg/workflows/host/mocks/execution_helper.go @@ -0,0 +1,396 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + + time "time" + + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + v2 "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" + mock "github.com/stretchr/testify/mock" +) + +// MockExecutionHelper is an autogenerated mock type for the ExecutionHelper type +type MockExecutionHelper struct { + mock.Mock +} + +type MockExecutionHelper_Expecter struct { + mock *mock.Mock +} + +func (_m *MockExecutionHelper) EXPECT() *MockExecutionHelper_Expecter { + return &MockExecutionHelper_Expecter{mock: &_m.Mock} +} + +// CallCapability provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelper) CallCapability(ctx context.Context, request *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CallCapability") + } + + var r0 *sdk.CapabilityResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.CapabilityRequest) *sdk.CapabilityResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sdk.CapabilityResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.CapabilityRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' +type MockExecutionHelper_CallCapability_Call struct { + *mock.Call +} + +// CallCapability is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.CapabilityRequest +func (_e *MockExecutionHelper_Expecter) CallCapability(ctx interface{}, request interface{}) *MockExecutionHelper_CallCapability_Call { + return &MockExecutionHelper_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} +} + +func (_c *MockExecutionHelper_CallCapability_Call) Run(run func(ctx context.Context, request *sdk.CapabilityRequest)) *MockExecutionHelper_CallCapability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.CapabilityRequest)) + }) + return _c +} + +func (_c *MockExecutionHelper_CallCapability_Call) Return(_a0 *sdk.CapabilityResponse, _a1 error) *MockExecutionHelper_CallCapability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_CallCapability_Call) RunAndReturn(run func(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error)) *MockExecutionHelper_CallCapability_Call { + _c.Call.Return(run) + return _c +} + +// EmitUserLog provides a mock function with given fields: log +func (_m *MockExecutionHelper) EmitUserLog(log string) error { + ret := _m.Called(log) + + if len(ret) == 0 { + panic("no return value specified for EmitUserLog") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(log) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockExecutionHelper_EmitUserLog_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmitUserLog' +type MockExecutionHelper_EmitUserLog_Call struct { + *mock.Call +} + +// EmitUserLog is a helper method to define mock.On call +// - log string +func (_e *MockExecutionHelper_Expecter) EmitUserLog(log interface{}) *MockExecutionHelper_EmitUserLog_Call { + return &MockExecutionHelper_EmitUserLog_Call{Call: _e.mock.On("EmitUserLog", log)} +} + +func (_c *MockExecutionHelper_EmitUserLog_Call) Run(run func(log string)) *MockExecutionHelper_EmitUserLog_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockExecutionHelper_EmitUserLog_Call) Return(_a0 error) *MockExecutionHelper_EmitUserLog_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_EmitUserLog_Call) RunAndReturn(run func(string) error) *MockExecutionHelper_EmitUserLog_Call { + _c.Call.Return(run) + return _c +} + +// EmitUserMetric provides a mock function with given fields: ctx, metric +func (_m *MockExecutionHelper) EmitUserMetric(ctx context.Context, metric *v2.WorkflowUserMetric) error { + ret := _m.Called(ctx, metric) + + if len(ret) == 0 { + panic("no return value specified for EmitUserMetric") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.WorkflowUserMetric) error); ok { + r0 = rf(ctx, metric) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockExecutionHelper_EmitUserMetric_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmitUserMetric' +type MockExecutionHelper_EmitUserMetric_Call struct { + *mock.Call +} + +// EmitUserMetric is a helper method to define mock.On call +// - ctx context.Context +// - metric *v2.WorkflowUserMetric +func (_e *MockExecutionHelper_Expecter) EmitUserMetric(ctx interface{}, metric interface{}) *MockExecutionHelper_EmitUserMetric_Call { + return &MockExecutionHelper_EmitUserMetric_Call{Call: _e.mock.On("EmitUserMetric", ctx, metric)} +} + +func (_c *MockExecutionHelper_EmitUserMetric_Call) Run(run func(ctx context.Context, metric *v2.WorkflowUserMetric)) *MockExecutionHelper_EmitUserMetric_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*v2.WorkflowUserMetric)) + }) + return _c +} + +func (_c *MockExecutionHelper_EmitUserMetric_Call) Return(_a0 error) *MockExecutionHelper_EmitUserMetric_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_EmitUserMetric_Call) RunAndReturn(run func(context.Context, *v2.WorkflowUserMetric) error) *MockExecutionHelper_EmitUserMetric_Call { + _c.Call.Return(run) + return _c +} + +// GetDONTime provides a mock function with no fields +func (_m *MockExecutionHelper) GetDONTime() (time.Time, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetDONTime") + } + + var r0 time.Time + var r1 error + if rf, ok := ret.Get(0).(func() (time.Time, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_GetDONTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDONTime' +type MockExecutionHelper_GetDONTime_Call struct { + *mock.Call +} + +// GetDONTime is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetDONTime() *MockExecutionHelper_GetDONTime_Call { + return &MockExecutionHelper_GetDONTime_Call{Call: _e.mock.On("GetDONTime")} +} + +func (_c *MockExecutionHelper_GetDONTime_Call) Run(run func()) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetDONTime_Call) Return(_a0 time.Time, _a1 error) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_GetDONTime_Call) RunAndReturn(run func() (time.Time, error)) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Return(run) + return _c +} + +// GetNodeTime provides a mock function with no fields +func (_m *MockExecutionHelper) GetNodeTime() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetNodeTime") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +// MockExecutionHelper_GetNodeTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeTime' +type MockExecutionHelper_GetNodeTime_Call struct { + *mock.Call +} + +// GetNodeTime is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetNodeTime() *MockExecutionHelper_GetNodeTime_Call { + return &MockExecutionHelper_GetNodeTime_Call{Call: _e.mock.On("GetNodeTime")} +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) Run(run func()) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) Return(_a0 time.Time) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Return(run) + return _c +} + +// GetSecrets provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelper) GetSecrets(ctx context.Context, request *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for GetSecrets") + } + + var r0 []*sdk.SecretResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest) []*sdk.SecretResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*sdk.SecretResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.GetSecretsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_GetSecrets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSecrets' +type MockExecutionHelper_GetSecrets_Call struct { + *mock.Call +} + +// GetSecrets is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.GetSecretsRequest +func (_e *MockExecutionHelper_Expecter) GetSecrets(ctx interface{}, request interface{}) *MockExecutionHelper_GetSecrets_Call { + return &MockExecutionHelper_GetSecrets_Call{Call: _e.mock.On("GetSecrets", ctx, request)} +} + +func (_c *MockExecutionHelper_GetSecrets_Call) Run(run func(ctx context.Context, request *sdk.GetSecretsRequest)) *MockExecutionHelper_GetSecrets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.GetSecretsRequest)) + }) + return _c +} + +func (_c *MockExecutionHelper_GetSecrets_Call) Return(_a0 []*sdk.SecretResponse, _a1 error) *MockExecutionHelper_GetSecrets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_GetSecrets_Call) RunAndReturn(run func(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error)) *MockExecutionHelper_GetSecrets_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowExecutionID provides a mock function with no fields +func (_m *MockExecutionHelper) GetWorkflowExecutionID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowExecutionID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockExecutionHelper_GetWorkflowExecutionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowExecutionID' +type MockExecutionHelper_GetWorkflowExecutionID_Call struct { + *mock.Call +} + +// GetWorkflowExecutionID is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetWorkflowExecutionID() *MockExecutionHelper_GetWorkflowExecutionID_Call { + return &MockExecutionHelper_GetWorkflowExecutionID_Call{Call: _e.mock.On("GetWorkflowExecutionID")} +} + +func (_c *MockExecutionHelper_GetWorkflowExecutionID_Call) Run(run func()) *MockExecutionHelper_GetWorkflowExecutionID_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetWorkflowExecutionID_Call) Return(_a0 string) *MockExecutionHelper_GetWorkflowExecutionID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetWorkflowExecutionID_Call) RunAndReturn(run func() string) *MockExecutionHelper_GetWorkflowExecutionID_Call { + _c.Call.Return(run) + return _c +} + +// NewMockExecutionHelper creates a new instance of MockExecutionHelper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockExecutionHelper(t interface { + mock.TestingT + Cleanup(func()) +}) *MockExecutionHelper { + mock := &MockExecutionHelper{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/host/mocks/module.go b/pkg/workflows/host/mocks/module.go new file mode 100644 index 0000000000..8576d1238e --- /dev/null +++ b/pkg/workflows/host/mocks/module.go @@ -0,0 +1,207 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + + host "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + mock "github.com/stretchr/testify/mock" +) + +// Module is an autogenerated mock type for the Module type +type Module struct { + mock.Mock +} + +type Module_Expecter struct { + mock *mock.Mock +} + +func (_m *Module) EXPECT() *Module_Expecter { + return &Module_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with no fields +func (_m *Module) Close() { + _m.Called() +} + +// Module_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type Module_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *Module_Expecter) Close() *Module_Close_Call { + return &Module_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *Module_Close_Call) Run(run func()) *Module_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_Close_Call) Return() *Module_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *Module_Close_Call) RunAndReturn(run func()) *Module_Close_Call { + _c.Run(run) + return _c +} + +// Execute provides a mock function with given fields: ctx, request, handler +func (_m *Module) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper) (*sdk.ExecutionResult, error) { + ret := _m.Called(ctx, request, handler) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 *sdk.ExecutionResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)); ok { + return rf(ctx, request, handler) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) *sdk.ExecutionResult); ok { + r0 = rf(ctx, request, handler) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sdk.ExecutionResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) error); ok { + r1 = rf(ctx, request, handler) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Module_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type Module_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.ExecuteRequest +// - handler host.ExecutionHelper +func (_e *Module_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *Module_Execute_Call { + return &Module_Execute_Call{Call: _e.mock.On("Execute", ctx, request, handler)} +} + +func (_c *Module_Execute_Call) Run(run func(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper)) *Module_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.ExecuteRequest), args[2].(host.ExecutionHelper)) + }) + return _c +} + +func (_c *Module_Execute_Call) Return(_a0 *sdk.ExecutionResult, _a1 error) *Module_Execute_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Module_Execute_Call) RunAndReturn(run func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)) *Module_Execute_Call { + _c.Call.Return(run) + return _c +} + +// IsLegacyDAG provides a mock function with no fields +func (_m *Module) IsLegacyDAG() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsLegacyDAG") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Module_IsLegacyDAG_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsLegacyDAG' +type Module_IsLegacyDAG_Call struct { + *mock.Call +} + +// IsLegacyDAG is a helper method to define mock.On call +func (_e *Module_Expecter) IsLegacyDAG() *Module_IsLegacyDAG_Call { + return &Module_IsLegacyDAG_Call{Call: _e.mock.On("IsLegacyDAG")} +} + +func (_c *Module_IsLegacyDAG_Call) Run(run func()) *Module_IsLegacyDAG_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_IsLegacyDAG_Call) Return(_a0 bool) *Module_IsLegacyDAG_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Module_IsLegacyDAG_Call) RunAndReturn(run func() bool) *Module_IsLegacyDAG_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with no fields +func (_m *Module) Start() { + _m.Called() +} + +// Module_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type Module_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +func (_e *Module_Expecter) Start() *Module_Start_Call { + return &Module_Start_Call{Call: _e.mock.On("Start")} +} + +func (_c *Module_Start_Call) Run(run func()) *Module_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_Start_Call) Return() *Module_Start_Call { + _c.Call.Return() + return _c +} + +func (_c *Module_Start_Call) RunAndReturn(run func()) *Module_Start_Call { + _c.Run(run) + return _c +} + +// NewModule creates a new instance of Module. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewModule(t interface { + mock.TestingT + Cleanup(func()) +}) *Module { + mock := &Module{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/host/module.go b/pkg/workflows/host/module.go new file mode 100644 index 0000000000..1db4c086ce --- /dev/null +++ b/pkg/workflows/host/module.go @@ -0,0 +1,41 @@ +//go:generate go run ./requirements_gen + +package host + +import ( + "context" + "time" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" +) + +type ModuleBase interface { + Start() + Close() + IsLegacyDAG() bool +} + +type Module interface { + ModuleBase + + // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution + Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) +} + +// ExecutionHelper Implemented by those running the host, for example the Workflow Engine +type ExecutionHelper interface { + // CallCapability blocking call to the Workflow Engine + CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) + GetSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) + + GetWorkflowExecutionID() string + + GetNodeTime() time.Time + + GetDONTime() (time.Time, error) + + EmitUserLog(log string) error + + EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error +} diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go new file mode 100644 index 0000000000..0cfc7232b9 --- /dev/null +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -0,0 +1,105 @@ +package host + +import ( + "context" + "errors" + "fmt" + "sync" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type ModuleAndHandler struct { + Module + RequirementsHandler +} + +// lazyModule wraps a ModuleAndHandler so that Start is called at most once. +type lazyModule struct { + ModuleAndHandler + startOnce sync.Once + started bool +} + +func (l *lazyModule) ensureStarted() { + l.startOnce.Do(func() { + l.Module.Start() + l.started = true + }) +} + +func NewRequirementSelectingModule(main ModuleAndHandler, additional []ModuleAndHandler) Module { + wrapped := make([]*lazyModule, len(additional)) + for i := range additional { + wrapped[i] = &lazyModule{ModuleAndHandler: additional[i]} + } + return &requirementSelectingModule{ + main: main, + additional: wrapped, + } +} + +type requirementSelectingModule struct { + main ModuleAndHandler + additional []*lazyModule + // triggerID → index into additional + cache sync.Map +} + +func (r *requirementSelectingModule) Start() { + r.main.Start() +} + +func (r *requirementSelectingModule) Close() { + r.main.Close() + for _, m := range r.additional { + if m.started { + m.Close() + } + } +} + +func (r *requirementSelectingModule) IsLegacyDAG() bool { + return r.main.IsLegacyDAG() +} + +func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + triggerID, hasTrigger := extractTriggerID(request) + + if hasTrigger { + if idx, ok := r.cache.Load(triggerID); ok { + return r.additional[idx.(int)].Execute(ctx, request, handler) + } + } + + result, err := r.main.Execute(ctx, request, handler) + if err == nil { + return result, nil + } + + rerun := &RequirementsRerun{} + if !errors.As(err, &rerun) { + return nil, err + } + + for i, m := range r.additional { + if CheckRequirements(m.RequirementsHandler, (*sdk.Requirements)(rerun)) { + m.ensureStarted() + if hasTrigger { + r.cache.Store(triggerID, i) + } + return m.Execute(ctx, request, handler) + } + } + + return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements %+v", rerun) +} + +func extractTriggerID(req *sdk.ExecuteRequest) (uint64, bool) { + if t := req.GetTrigger(); t != nil { + return t.Id, true + } + return 0, false +} + +var _ Module = &requirementSelectingModule{} diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go new file mode 100644 index 0000000000..8641089953 --- /dev/null +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -0,0 +1,462 @@ +package host + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type stubModule struct { + startFn func() + closeFn func() + legacyFn func() bool + executeFn func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) +} + +func (s *stubModule) Start() { s.startFn() } +func (s *stubModule) Close() { s.closeFn() } +func (s *stubModule) IsLegacyDAG() bool { return s.legacyFn() } +func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + return s.executeFn(ctx, req, h) +} + +func noop() {} +func noopClose() {} + +func triggerRequest(id uint64) *sdk.ExecuteRequest { + return &sdk.ExecuteRequest{ + Request: &sdk.ExecuteRequest_Trigger{ + Trigger: &sdk.Trigger{Id: id}, + }, + } +} + +func subscribeRequest() *sdk.ExecuteRequest { + return &sdk.ExecuteRequest{ + Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}, + } +} + +func TestRequirementSelectingModule_Start(t *testing.T) { + t.Run("starts only main module", func(t *testing.T) { + var mainStarted, additionalStarted bool + main := ModuleAndHandler{Module: &stubModule{startFn: func() { mainStarted = true }}} + add := ModuleAndHandler{Module: &stubModule{startFn: func() { additionalStarted = true }}} + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + assert.True(t, mainStarted) + assert.False(t, additionalStarted) + }) +} + +func TestRequirementSelectingModule_Close(t *testing.T) { + t.Run("closes main and no additional when none started", func(t *testing.T) { + var mainClosed, addClosed bool + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, closeFn: func() { mainClosed = true }, + }} + add := ModuleAndHandler{Module: &stubModule{ + startFn: noop, closeFn: func() { addClosed = true }, + }} + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + m.Close() + + assert.True(t, mainClosed) + assert.False(t, addClosed) + }) + + t.Run("closes main and all started additional modules", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + + var mainClosed, add0Closed, add1Closed bool + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + closeFn: func() { mainClosed = true }, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + }} + add0 := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: func() { add0Closed = true }, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + add1 := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: func() { add1Closed = true }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) + m.Start() + + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + + m.Close() + + assert.True(t, mainClosed, "main should be closed") + assert.True(t, add0Closed, "started additional should be closed") + assert.False(t, add1Closed, "never-started additional should not be closed") + }) +} + +func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { + main := ModuleAndHandler{Module: &stubModule{legacyFn: func() bool { return true }}} + m := NewRequirementSelectingModule(main, nil) + assert.True(t, m.IsLegacyDAG()) +} + +func TestRequirementSelectingModule_Execute(t *testing.T) { + t.Run("main succeeds — returns result directly", func(t *testing.T) { + want := &sdk.ExecutionResult{} + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + got, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("main non-RequirementsRerun error propagates", func(t *testing.T) { + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, assert.AnError + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + assert.ErrorIs(t, err, assert.AnError) + }) + + t.Run("RequirementsRerun routes to matching additional", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + got, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("RequirementsRerun with no matching additional returns error", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{startFn: noop}, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner that can satisfy the requirements") + }) + + t.Run("RequirementsRerun skips non-matching and selects later match", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + }} + add0 := ModuleAndHandler{ + Module: &stubModule{startFn: noop}, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + } + add1 := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) + m.Start() + + got, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("additional module started lazily", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + var addStartCount int32 + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: func() { atomic.AddInt32(&addStartCount, 1) }, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + // First execution starts the additional module. + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + + // Second execution with a different trigger still goes through main, + // but the additional module is not started again (sync.Once). + _, err = m.Execute(t.Context(), triggerRequest(2), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + }) + + t.Run("subscribe request goes through main directly", func(t *testing.T) { + want := &sdk.ExecutionResult{} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + got, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) +} + +func TestRequirementSelectingModule_TriggerCache(t *testing.T) { + t.Run("cached trigger skips main on subsequent calls", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + var mainCalls int32 + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + atomic.AddInt32(&mainCalls, 1) + return nil, rerunErr + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + // First call triggers main. + _, err := m.Execute(t.Context(), triggerRequest(42), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainCalls)) + + // Second call with same trigger ID skips main. + _, err = m.Execute(t.Context(), triggerRequest(42), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainCalls)) + }) + + t.Run("different trigger IDs are cached independently", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + var mainCalls int32 + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + atomic.AddInt32(&mainCalls, 1) + return nil, rerunErr + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + // Trigger 1 goes through main. + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainCalls)) + + // Trigger 2 also goes through main (different ID, not cached). + _, err = m.Execute(t.Context(), triggerRequest(2), nil) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&mainCalls)) + + // Both are now cached — neither goes through main. + _, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + _, err = m.Execute(t.Context(), triggerRequest(2), nil) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&mainCalls)) + }) + + t.Run("different triggers can route to different additional modules", func(t *testing.T) { + teeRerun := &RequirementsRerun{Tee: &sdk.Tee{ + Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{ + Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}, + }}, + }} + noReqRerun := &RequirementsRerun{} + + callCount := 0 + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + callCount++ + if req.GetTrigger().Id == 1 { + return nil, teeRerun + } + return nil, noReqRerun + }, + }} + + var addNitroResult, addDefaultResult sdk.ExecutionResult + addNitro := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &addNitroResult, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(tee *sdk.Tee) bool { + return tee.GetTypeSelection() != nil + }}, + } + addDefault := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &addDefaultResult, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{addNitro, addDefault}) + m.Start() + + got, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, &addNitroResult, got) + + // Trigger 1 is now cached to addNitro; verify second call skips main. + got, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, &addNitroResult, got) + assert.Equal(t, 1, callCount, "main should only be called once for trigger 1") + }) + + t.Run("RequirementsRerun with no additional modules returns error", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, rerunErr + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner") + }) +} diff --git a/pkg/workflows/wasm/host/requirements_gen/main.go b/pkg/workflows/host/requirements_gen/main.go similarity index 100% rename from pkg/workflows/wasm/host/requirements_gen/main.go rename to pkg/workflows/host/requirements_gen/main.go diff --git a/pkg/workflows/wasm/host/requirements_gen/requirements_helper.go.tmpl b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl similarity index 100% rename from pkg/workflows/wasm/host/requirements_gen/requirements_helper.go.tmpl rename to pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl diff --git a/pkg/workflows/wasm/host/requirements_helper_gen.go b/pkg/workflows/host/requirements_helper_gen.go similarity index 100% rename from pkg/workflows/wasm/host/requirements_helper_gen.go rename to pkg/workflows/host/requirements_helper_gen.go diff --git a/pkg/workflows/wasm/host/requirements_helper_gen_test.go b/pkg/workflows/host/requirements_helper_gen_test.go similarity index 100% rename from pkg/workflows/wasm/host/requirements_helper_gen_test.go rename to pkg/workflows/host/requirements_helper_gen_test.go diff --git a/pkg/workflows/wasm/host/requirements_rerun.go b/pkg/workflows/host/requirements_rerun.go similarity index 100% rename from pkg/workflows/wasm/host/requirements_rerun.go rename to pkg/workflows/host/requirements_rerun.go diff --git a/pkg/workflows/wasm/host/tee_provider.go b/pkg/workflows/host/tee_provider.go similarity index 100% rename from pkg/workflows/wasm/host/tee_provider.go rename to pkg/workflows/host/tee_provider.go diff --git a/pkg/workflows/wasm/host/tee_provider_test.go b/pkg/workflows/host/tee_provider_test.go similarity index 100% rename from pkg/workflows/wasm/host/tee_provider_test.go rename to pkg/workflows/host/tee_provider_test.go diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index d538013755..97bd3147df 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -10,6 +10,8 @@ import ( "github.com/bytecodealliance/wasmtime-go/v28" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/config" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -400,8 +402,8 @@ func (e *execution[T]) requirements(caller *wasmtime.Caller, ptr int32, ptrlen i return wasmtime.NewTrap(e.requirementsRerunErr.Error()) } - if !CheckRequirements(e.module.cfg.RequirementsHandler, requirements) { - e.requirementsRerunErr = (*RequirementsRerun)(requirements) + if !host.CheckRequirements(e.module.cfg.RequirementsHandler, requirements) { + e.requirementsRerunErr = (*host.RequirementsRerun)(requirements) return wasmtime.NewTrap(e.requirementsRerunErr.Error()) } diff --git a/pkg/workflows/wasm/host/mocks/module_v2.go b/pkg/workflows/wasm/host/mocks/module_v2.go index 4c84a3b4ae..dbcb7aa0c1 100644 --- a/pkg/workflows/wasm/host/mocks/module_v2.go +++ b/pkg/workflows/wasm/host/mocks/module_v2.go @@ -1,207 +1,20 @@ -// Code generated by mockery v2.53.3. DO NOT EDIT. - package mocks import ( - context "context" + "github.com/stretchr/testify/mock" - host "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" - sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" - mock "github.com/stretchr/testify/mock" + hostmocks "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" ) -// ModuleV2 is an autogenerated mock type for the ModuleV2 type -type ModuleV2 struct { - mock.Mock -} - -type ModuleV2_Expecter struct { - mock *mock.Mock -} - -func (_m *ModuleV2) EXPECT() *ModuleV2_Expecter { - return &ModuleV2_Expecter{mock: &_m.Mock} -} - -// Close provides a mock function with no fields -func (_m *ModuleV2) Close() { - _m.Called() -} - -// ModuleV2_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' -type ModuleV2_Close_Call struct { - *mock.Call -} - -// Close is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) Close() *ModuleV2_Close_Call { - return &ModuleV2_Close_Call{Call: _e.mock.On("Close")} -} - -func (_c *ModuleV2_Close_Call) Run(run func()) *ModuleV2_Close_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_Close_Call) Return() *ModuleV2_Close_Call { - _c.Call.Return() - return _c -} - -func (_c *ModuleV2_Close_Call) RunAndReturn(run func()) *ModuleV2_Close_Call { - _c.Run(run) - return _c -} - -// Execute provides a mock function with given fields: ctx, request, handler -func (_m *ModuleV2) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper) (*sdk.ExecutionResult, error) { - ret := _m.Called(ctx, request, handler) - - if len(ret) == 0 { - panic("no return value specified for Execute") - } - - var r0 *sdk.ExecutionResult - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)); ok { - return rf(ctx, request, handler) - } - if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) *sdk.ExecutionResult); ok { - r0 = rf(ctx, request, handler) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*sdk.ExecutionResult) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) error); ok { - r1 = rf(ctx, request, handler) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ModuleV2_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' -type ModuleV2_Execute_Call struct { - *mock.Call -} - -// Execute is a helper method to define mock.On call -// - ctx context.Context -// - request *sdk.ExecuteRequest -// - handler host.ExecutionHelper -func (_e *ModuleV2_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *ModuleV2_Execute_Call { - return &ModuleV2_Execute_Call{Call: _e.mock.On("Execute", ctx, request, handler)} -} - -func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper)) *ModuleV2_Execute_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*sdk.ExecuteRequest), args[2].(host.ExecutionHelper)) - }) - return _c -} - -func (_c *ModuleV2_Execute_Call) Return(_a0 *sdk.ExecutionResult, _a1 error) *ModuleV2_Execute_Call { - _c.Call.Return(_a0, _a1) - return _c -} +// ModuleV2 is a backward-compatible alias for hostmocks.Module. +// The ModuleV2 interface now lives in pkg/workflows/host as Module; +// this alias keeps existing consumers compiling without changes. +type ModuleV2 = hostmocks.Module -func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)) *ModuleV2_Execute_Call { - _c.Call.Return(run) - return _c -} - -// IsLegacyDAG provides a mock function with no fields -func (_m *ModuleV2) IsLegacyDAG() bool { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for IsLegacyDAG") - } - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -// ModuleV2_IsLegacyDAG_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsLegacyDAG' -type ModuleV2_IsLegacyDAG_Call struct { - *mock.Call -} - -// IsLegacyDAG is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) IsLegacyDAG() *ModuleV2_IsLegacyDAG_Call { - return &ModuleV2_IsLegacyDAG_Call{Call: _e.mock.On("IsLegacyDAG")} -} - -func (_c *ModuleV2_IsLegacyDAG_Call) Run(run func()) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_IsLegacyDAG_Call) Return(_a0 bool) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *ModuleV2_IsLegacyDAG_Call) RunAndReturn(run func() bool) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Return(run) - return _c -} - -// Start provides a mock function with no fields -func (_m *ModuleV2) Start() { - _m.Called() -} - -// ModuleV2_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type ModuleV2_Start_Call struct { - *mock.Call -} - -// Start is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) Start() *ModuleV2_Start_Call { - return &ModuleV2_Start_Call{Call: _e.mock.On("Start")} -} - -func (_c *ModuleV2_Start_Call) Run(run func()) *ModuleV2_Start_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_Start_Call) Return() *ModuleV2_Start_Call { - _c.Call.Return() - return _c -} - -func (_c *ModuleV2_Start_Call) RunAndReturn(run func()) *ModuleV2_Start_Call { - _c.Run(run) - return _c -} - -// NewModuleV2 creates a new instance of ModuleV2. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. +// NewModuleV2 creates a new instance of ModuleV2 (alias for hostmocks.NewModule). func NewModuleV2(t interface { mock.TestingT Cleanup(func()) }) *ModuleV2 { - mock := &ModuleV2{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock + return hostmocks.NewModule(t) } diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index c512359c54..9f29c2cf38 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -21,6 +21,8 @@ import ( "github.com/bytecodealliance/wasmtime-go/v28" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -31,7 +33,6 @@ import ( wasmdagpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" "github.com/smartcontractkit/chainlink-protos/cre/go/values" - wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) const v2ImportPrefix = "version_v2" @@ -102,14 +103,10 @@ type ModuleConfig struct { // If Determinism is set, the module will override the random_get function in the WASI API with // the provided seed to ensure deterministic behavior. Determinism *DeterminismConfig - RequirementsHandler RequirementsHandler + RequirementsHandler host.RequirementsHandler } -type ModuleBase interface { - Start() - Close() - IsLegacyDAG() bool -} +type ModuleBase = host.ModuleBase type ModuleV1 interface { ModuleBase @@ -118,29 +115,9 @@ type ModuleV1 interface { Run(ctx context.Context, request *wasmdagpb.Request) (*wasmdagpb.Response, error) } -type ModuleV2 interface { - ModuleBase - - // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution - Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) -} - -// ExecutionHelper Implemented by those running the host, for example the Workflow Engine -type ExecutionHelper interface { - // CallCapability blocking call to the Workflow Engine - CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) - GetSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) +type ModuleV2 = host.Module - GetWorkflowExecutionID() string - - GetNodeTime() time.Time - - GetDONTime() (time.Time, error) - - EmitUserLog(log string) error - - EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error -} +type ExecutionHelper = host.ExecutionHelper type module struct { engine *wasmtime.Engine diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index af5d4fc12c..ed53ab45ce 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/utils/matches" @@ -620,7 +622,7 @@ func Test_SdkLabeler(t *testing.T) { // CallAwaitRace validates that every call can be awaited. func Test_CallAwaitRace(t *testing.T) { ctx := t.Context() - mockExecHelper := NewMockExecutionHelper(t) + mockExecHelper := mocks.NewMockExecutionHelper(t) mockExecHelper.EXPECT(). CallCapability(matches.AnyContext, mock.Anything). Return(&sdkpb.CapabilityResponse{}, nil) diff --git a/pkg/workflows/wasm/host/requirement_selecting_module.go b/pkg/workflows/wasm/host/requirement_selecting_module.go deleted file mode 100644 index 4a6988ef09..0000000000 --- a/pkg/workflows/wasm/host/requirement_selecting_module.go +++ /dev/null @@ -1,84 +0,0 @@ -package host - -import ( - "context" - "errors" - "fmt" - "sync" - "sync/atomic" - - "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" -) - -type ModuleAndHandler struct { - ModuleV2 - RequirementsHandler -} - -func NewRequirementSelectingModule(moduleAndHandlers []ModuleAndHandler) ModuleV2 { - return &requirementSelectingModule{ - moduleAndHandler: moduleAndHandlers, - runOn: -1, - } -} - -type requirementSelectingModule struct { - moduleAndHandler []ModuleAndHandler - runOn int - started atomic.Bool - findMutex sync.Mutex -} - -func (r *requirementSelectingModule) Start() { - r.started.Store(true) - r.moduleAndHandler[0].Start() -} - -func (r *requirementSelectingModule) Close() { - r.findMutex.Lock() - defer r.findMutex.Unlock() - if r.runOn == -1 { - r.moduleAndHandler[0].Close() - } else { - r.moduleAndHandler[r.runOn].Close() - } -} - -func (r *requirementSelectingModule) IsLegacyDAG() bool { - return r.moduleAndHandler[0].IsLegacyDAG() -} - -func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { - if r.runOn >= 0 { - return r.moduleAndHandler[r.runOn].Execute(ctx, request, handler) - } - - r.findMutex.Lock() - defer r.findMutex.Unlock() - result, err := r.moduleAndHandler[0].Execute(ctx, request, handler) - if err == nil { - r.runOn = 0 - return result, nil - } - - rerun := &RequirementsRerun{} - if !errors.As(err, &rerun) { - return nil, err - } - - numHandlers := len(r.moduleAndHandler) - for i := 1; i < numHandlers; i++ { - item := r.moduleAndHandler[i] - if CheckRequirements(item.RequirementsHandler, (*sdk.Requirements)(rerun)) { - r.runOn = i - if r.started.Load() { - item.Start() - } - return item.Execute(ctx, request, handler) - } - } - - return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements %+v\n", rerun) -} - -var _ ModuleV2 = &requirementSelectingModule{} diff --git a/pkg/workflows/wasm/host/requirement_selecting_module_test.go b/pkg/workflows/wasm/host/requirement_selecting_module_test.go deleted file mode 100644 index fdaba31e8a..0000000000 --- a/pkg/workflows/wasm/host/requirement_selecting_module_test.go +++ /dev/null @@ -1,253 +0,0 @@ -package host - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" -) - -type stubModuleV2 struct { - startFn func() - closeFn func() - legacyFn func() bool - executeFn func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) -} - -func (s *stubModuleV2) Start() { s.startFn() } -func (s *stubModuleV2) Close() { s.closeFn() } -func (s *stubModuleV2) IsLegacyDAG() bool { return s.legacyFn() } -func (s *stubModuleV2) Execute(ctx context.Context, req *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { - return s.executeFn(ctx, req, h) -} - -func TestRequirementSelectingModule_Start(t *testing.T) { - var started bool - m0 := &stubModuleV2{startFn: func() { started = true }} - m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) - m.Start() - assert.True(t, started) -} - -func TestRequirementSelectingModule_Close(t *testing.T) { - t.Run("before execute closes first module", func(t *testing.T) { - var closedIdx int - m0 := &stubModuleV2{closeFn: func() { closedIdx = 0 }} - m1 := &stubModuleV2{closeFn: func() { closedIdx = 1 }} - m := NewRequirementSelectingModule([]ModuleAndHandler{ - {ModuleV2: m0}, - {ModuleV2: m1}, - }) - closedIdx = -1 - m.Close() - assert.Equal(t, 0, closedIdx) - }) - - t.Run("after execute closes selected module", func(t *testing.T) { - wantResult := &sdk.ExecutionResult{} - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - var closedIdx int - - m0 := &stubModuleV2{ - startFn: func() {}, - closeFn: func() { closedIdx = 0 }, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr - }, - } - m1 := &stubModuleV2{ - closeFn: func() { closedIdx = 1 }, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return wantResult, nil - }, - } - - m := NewRequirementSelectingModule([]ModuleAndHandler{ - {ModuleV2: m0}, - {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, - }) - - _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - - closedIdx = -1 - m.Close() - assert.Equal(t, 1, closedIdx) - }) -} - -func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { - t.Run("delegates", func(t *testing.T) { - m0 := &stubModuleV2{legacyFn: func() bool { return true }} - m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) - assert.True(t, m.IsLegacyDAG()) - }) -} - -func TestRequirementSelectingModule_Execute(t *testing.T) { - t.Run("delegates when runOn already set", func(t *testing.T) { - calls := 0 - wantResult := &sdk.ExecutionResult{} - m0 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - calls++ - return wantResult, nil - }, - } - - m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) - - _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - assert.Equal(t, 1, calls) - - got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - assert.Equal(t, wantResult, got) - assert.Equal(t, 2, calls) - }) - - t.Run("first module succeeds sets runOn to zero", func(t *testing.T) { - wantResult := &sdk.ExecutionResult{} - numCalls := 0 - m0 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - numCalls++ - return wantResult, nil - }, - } - - m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}}) - - got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - assert.Equal(t, 1, numCalls) - assert.Equal(t, wantResult, got) - }) - - t.Run("non-RequirementsRerun error is propagated without additional executions", func(t *testing.T) { - m0 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, assert.AnError - }, - } - - m1 := &stubModuleV2{executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - assert.Fail(t, "second module should not be executed") - return nil, nil - }} - - m := NewRequirementSelectingModule([]ModuleAndHandler{{ModuleV2: m0}, {ModuleV2: m1}}) - - _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - assert.ErrorIs(t, err, assert.AnError) - }) - - t.Run("RequirementsRerun with matching handler not started", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - wantResult := &sdk.ExecutionResult{} - var m1Started bool - - m0 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr - }, - } - m1 := &stubModuleV2{ - startFn: func() { m1Started = true }, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return wantResult, nil - }, - } - - m := NewRequirementSelectingModule([]ModuleAndHandler{ - {ModuleV2: m0}, - {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, - }) - - got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - assert.Equal(t, wantResult, got) - assert.False(t, m1Started) - }) - - t.Run("RequirementsRerun with matching handler already started", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - wantResult := &sdk.ExecutionResult{} - var m1Started bool - - m0 := &stubModuleV2{ - startFn: func() {}, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr - }, - } - m1 := &stubModuleV2{ - startFn: func() { m1Started = true }, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return wantResult, nil - }, - } - - m := NewRequirementSelectingModule([]ModuleAndHandler{ - {ModuleV2: m0}, - {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, - }) - - m.Start() - - got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - assert.Equal(t, wantResult, got) - assert.True(t, m1Started) - }) - - t.Run("RequirementsRerun with no matching handler returns error", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - m0 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr - }, - } - m1 := &stubModuleV2{} - - m := NewRequirementSelectingModule([]ModuleAndHandler{ - {ModuleV2: m0}, - {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}}, - }) - - _, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "cannot find a runner that can satisfy the requirements") - }) - - t.Run("RequirementsRerun skips non-matching selects later match", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - wantResult := &sdk.ExecutionResult{} - - m0 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr - }, - } - m1 := &stubModuleV2{} - m2 := &stubModuleV2{ - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return wantResult, nil - }, - } - - m := NewRequirementSelectingModule([]ModuleAndHandler{ - {ModuleV2: m0}, - {ModuleV2: m1, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}}, - {ModuleV2: m2, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}}, - }) - - got, err := m.Execute(t.Context(), &sdk.ExecuteRequest{}, nil) - require.NoError(t, err) - assert.Equal(t, wantResult, got) - }) -} diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index 0d96d4f195..d04053f9c9 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -25,6 +25,8 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" @@ -47,7 +49,7 @@ func init() { func TestStandardConfig(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Some languages call time during initiation of the executable before the main is called. // This would be in unknown mode, which would call Node mode by default. @@ -63,7 +65,7 @@ func TestStandardConfig(t *testing.T) { func TestStandardErrors(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -77,7 +79,7 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { // To ensure the calls are actually async, the mock will block the first call until the second call is made. // The first call sets InputThing to true, the second to false. t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -120,7 +122,7 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() }).Maybe() @@ -152,7 +154,7 @@ func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { func TestStandardModeSwitch(t *testing.T) { t.Parallel() t.Run("successful mode switch", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Node calls may occur on initialization depending on the language. var donCall bool @@ -192,7 +194,7 @@ func TestStandardModeSwitch(t *testing.T) { }) t.Run("node runtime in don mode", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -217,7 +219,7 @@ func TestStandardModeSwitch(t *testing.T) { }) t.Run("don runtime in node mode", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -252,7 +254,7 @@ func TestStandardModeSwitch(t *testing.T) { func TestStandardLogging(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -272,7 +274,7 @@ func TestStandardMultipleTriggers(t *testing.T) { t.Parallel() m := makeTestModule(t) t.Run("test registration", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -324,7 +326,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("first callback", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -338,7 +340,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("same trigger as first one but different registration", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -351,7 +353,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("different capability callback", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -370,7 +372,7 @@ func TestStandardRandom(t *testing.T) { // Test binary executes node mode code conditionally based on the value >= 100 anyId := "Id" - gte100Exec := NewMockExecutionHelper(t) + gte100Exec := mocks.NewMockExecutionHelper(t) gte100Exec.EXPECT().GetWorkflowExecutionID().Return(anyId) gte100Exec.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -400,7 +402,7 @@ func TestStandardRandom(t *testing.T) { value1 := executeWithResult[any](t, m, anyRequest, gte100Exec) t.Run("Same execution id gives the same randoms even if random is called in node mode", func(t *testing.T) { - lt100Exec := NewMockExecutionHelper(t) + lt100Exec := mocks.NewMockExecutionHelper(t) lt100Exec.EXPECT().GetWorkflowExecutionID().Return(anyId) lt100Exec.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -424,7 +426,7 @@ func TestStandardRandom(t *testing.T) { t.Run("Different execution id give different randoms", func(t *testing.T) { require.NoError(t, err) - gte100Exec2 := NewMockExecutionHelper(t) + gte100Exec2 := mocks.NewMockExecutionHelper(t) gte100Exec2.EXPECT().GetWorkflowExecutionID().Return("differentId") gte100Exec2.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -470,7 +472,7 @@ func TestStandardSecrets(t *testing.T) { } func TestStandardSecretsFailInNodeMode(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -504,7 +506,7 @@ func TestStandardSecretsFailInNodeMode(t *testing.T) { func TestStandardTimeInterpretation(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Inject fixed timestamp: 1577934245000 milliseconds = 2020-01-02T03:04:05Z fixedTime := time.UnixMilli(1577934245000) @@ -551,7 +553,7 @@ func TestStandardTeeRuntime(t *testing.T) { }, } { t.Run(test.name, func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -695,7 +697,7 @@ func assertProto[T proto.Message](t *testing.T, expected, actual T) { } func runSecretTest(t *testing.T, m *module, secretResponse *sdk.SecretResponse) *sdk.ExecutionResult { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() diff --git a/pkg/workflows/wasm/host/time_test.go b/pkg/workflows/wasm/host/time_test.go index b792588c99..7f7139b230 100644 --- a/pkg/workflows/wasm/host/time_test.go +++ b/pkg/workflows/wasm/host/time_test.go @@ -8,13 +8,14 @@ import ( "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) func TestTimeFetcher_GetTime_NODE(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) expected := time.Now() mockExec.EXPECT().GetNodeTime().Return(expected) @@ -29,7 +30,7 @@ func TestTimeFetcher_GetTime_NODE(t *testing.T) { func TestTimeFetcher_GetTime_DON(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) expected := time.Now() mockExec.EXPECT().GetDONTime().Return(expected, nil) @@ -44,7 +45,7 @@ func TestTimeFetcher_GetTime_DON(t *testing.T) { func TestTimeFetcher_GetTime_DON_Error(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Return(time.Time{}, errors.New("don error")) tf := newTimeFetcher(ctx, mockExec) @@ -58,7 +59,7 @@ func TestTimeFetcher_ContextCancelledBeforeRequest(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) cancel() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Return(time.Time{}, context.Canceled).Maybe() tf := newTimeFetcher(ctx, mockExec) @@ -81,7 +82,7 @@ func TestTimeFetcher_ContextCancelledDuringResponse(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Run(func() { time.Sleep(20 * time.Millisecond) // force timeout }).Return(time.Time{}, nil) diff --git a/pkg/workflows/wasm/host/wasm.go b/pkg/workflows/wasm/host/wasm.go index 2ec826d4db..d8c4bae7f1 100644 --- a/pkg/workflows/wasm/host/wasm.go +++ b/pkg/workflows/wasm/host/wasm.go @@ -1,7 +1,5 @@ package host -//go:generate go run ./requirements_gen - import ( "context" "errors" diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index fd27a99798..5b3a82a2ed 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -14,6 +14,8 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" + generichost "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -51,7 +53,7 @@ func Test_Sleep_Timeout(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -107,7 +109,7 @@ func Test_NoDag_Run(t *testing.T) { func Test_NoDAG_LoggingWithLimits(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -149,7 +151,7 @@ func Test_NoDAG_LoggingWithLimits(t *testing.T) { func Test_NoDAG_EmitMetricWithLimits(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -195,7 +197,7 @@ func Test_NoDAG_EmitMetricWithLimits(t *testing.T) { func Test_NoDAG_EmitMetricDisabled(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -228,7 +230,7 @@ func Test_NoDAG_UnparseableRequirements(t *testing.T) { err := runTeeFailureTest(t, sdk.TeeType_TEE_TYPE_AWS_NITRO, binary) assert.Error(t, err) - rerunErr := &RequirementsRerun{} + rerunErr := &generichost.RequirementsRerun{} assert.False(t, errors.As(err, &rerunErr)) } @@ -239,7 +241,7 @@ func Test_NoDAG_InvalidMemoryAddressForRequirements(t *testing.T) { err := runTeeFailureTest(t, sdk.TeeType_TEE_TYPE_AWS_NITRO, binary) assert.Error(t, err) - rerunErr := &RequirementsRerun{} + rerunErr := &generichost.RequirementsRerun{} assert.False(t, errors.As(err, &rerunErr)) } @@ -251,7 +253,7 @@ func Test_NoDAG_RequirementsNotMet(t *testing.T) { // Different (non-existent) TEE err := runTeeFailureTest(t, 999, binary) - rerunErr := &RequirementsRerun{} + rerunErr := &generichost.RequirementsRerun{} require.True(t, errors.As(err, &rerunErr)) expected := &sdk.Requirements{ @@ -264,11 +266,11 @@ func Test_NoDAG_RequirementsNotMet(t *testing.T) { func runTeeFailureTest(t *testing.T, teeType sdk.TeeType, binary []byte) error { cfg := defaultNoDAGModCfg(t) - cfg.RequirementsHandler.Tee = NewTeeProvider(teeType, nil) + cfg.RequirementsHandler.Tee = generichost.NewTeeProvider(teeType, nil) m, err := NewModule(t.Context(), cfg, binary) require.NoError(t, err) - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -287,7 +289,7 @@ func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { } func getTriggersSpec(t *testing.T, m ModuleV2, config []byte) (*sdk.TriggerSubscriptionRequest, error) { - helper := NewMockExecutionHelper(t) + helper := mocks.NewMockExecutionHelper(t) helper.EXPECT().GetWorkflowExecutionID().Return("Id") helper.EXPECT().GetNodeTime().Return(time.Now()).Maybe() execResult, err := m.Execute(t.Context(), &sdk.ExecuteRequest{ From cf01251557979e268f45c9e9f01d5012abf2ef3a Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Mon, 27 Apr 2026 10:27:03 -0400 Subject: [PATCH 06/14] Add max time for requirements selection --- .../host/requirement_selecting_module.go | 6 ++++ .../host/requirement_selecting_module_test.go | 30 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go index 0cfc7232b9..91ac5fbc7a 100644 --- a/pkg/workflows/host/requirement_selecting_module.go +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "sync" + "time" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) @@ -72,6 +73,7 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E } } + start := time.Now() result, err := r.main.Execute(ctx, request, handler) if err == nil { return result, nil @@ -82,6 +84,10 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E return nil, err } + if time.Now().Sub(start) > time.Second*10 { + return nil, errors.New("rerun requirement specified too late") + } + for i, m := range r.additional { if CheckRequirements(m.RequirementsHandler, (*sdk.Requirements)(rerun)) { m.ensureStarted() diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index 8641089953..7f8d8434aa 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -4,6 +4,7 @@ import ( "context" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -442,6 +443,35 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { assert.Equal(t, 1, callCount, "main should only be called once for trigger 1") }) + t.Run("RequirementsRerun returned too late is rejected", func(t *testing.T) { + rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + time.Sleep(11 * time.Second) + return nil, rerunErr + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "rerun requirement specified too late") + }) + t.Run("RequirementsRerun with no additional modules returns error", func(t *testing.T) { rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} From b8745afcfc871d3c3e311b37bf085fe07d057153 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Mon, 27 Apr 2026 14:21:29 -0400 Subject: [PATCH 07/14] Use subscription for the trigger requirement --- go.mod | 2 +- go.sum | 4 +- .../host/requirement_selecting_module.go | 44 ++-- .../host/requirement_selecting_module_test.go | 247 ++++++++---------- pkg/workflows/host/requirements_rerun.go | 16 -- pkg/workflows/wasm/host/execution.go | 28 -- .../host/internal/rawsdk/helpers_wasip1.go | 3 - pkg/workflows/wasm/host/module.go | 11 - pkg/workflows/wasm/host/standard_test.go | 78 +++--- .../standard_tests/tee_runtime/main_wasip1.go | 22 +- pkg/workflows/wasm/host/wasm_nodag_test.go | 79 +----- 11 files changed, 198 insertions(+), 336 deletions(-) delete mode 100644 pkg/workflows/host/requirements_rerun.go diff --git a/go.mod b/go.mod index ce593a074f..7ea74b93c3 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260424192350-c2ff1c3f6163 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260427170224-3b3204904066 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index ed7fd44978..7c7d9bb6ff 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260424192350-c2ff1c3f6163 h1:MfHAshLU/p25XvIafw6sPrBaBKwpeTNVANADiMLzeak= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260424192350-c2ff1c3f6163/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260427170224-3b3204904066 h1:XgmfrVnD6Z2yf6f+4qcGZlqvdJlffRippMmvqE8Yl3c= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260427170224-3b3204904066/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go index 91ac5fbc7a..d6ac6ad746 100644 --- a/pkg/workflows/host/requirement_selecting_module.go +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -2,10 +2,8 @@ package host import ( "context" - "errors" "fmt" "sync" - "time" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) @@ -65,40 +63,38 @@ func (r *requirementSelectingModule) IsLegacyDAG() bool { } func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { - triggerID, hasTrigger := extractTriggerID(request) - - if hasTrigger { - if idx, ok := r.cache.Load(triggerID); ok { + if triggerID, ok := extractTriggerID(request); ok { + if idx, cached := r.cache.Load(triggerID); cached { return r.additional[idx.(int)].Execute(ctx, request, handler) } + return r.main.Execute(ctx, request, handler) } - start := time.Now() + // Subscribe: run main, then build triggerID→module cache from subscription requirements result, err := r.main.Execute(ctx, request, handler) - if err == nil { - return result, nil - } - - rerun := &RequirementsRerun{} - if !errors.As(err, &rerun) { + if err != nil { return nil, err } - if time.Now().Sub(start) > time.Second*10 { - return nil, errors.New("rerun requirement specified too late") - } - - for i, m := range r.additional { - if CheckRequirements(m.RequirementsHandler, (*sdk.Requirements)(rerun)) { - m.ensureStarted() - if hasTrigger { - r.cache.Store(triggerID, i) + for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() { + if sub.Requirements == nil { + continue + } + matched := false + for j, m := range r.additional { + if CheckRequirements(m.RequirementsHandler, sub.Requirements) { + m.ensureStarted() + r.cache.Store(uint64(i), j) + matched = true + break } - return m.Execute(ctx, request, handler) + } + if !matched { + return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements for trigger %d", i) } } - return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements %+v", rerun) + return result, nil } func extractTriggerID(req *sdk.ExecuteRequest) (uint64, bool) { diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index 7f8d8434aa..4c70d2196e 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -4,10 +4,10 @@ import ( "context" "sync/atomic" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" @@ -44,6 +44,20 @@ func subscribeRequest() *sdk.ExecuteRequest { } } +func subscribeResult(subs ...*sdk.TriggerSubscription) *sdk.ExecutionResult { + return &sdk.ExecutionResult{ + Result: &sdk.ExecutionResult_TriggerSubscriptions{ + TriggerSubscriptions: &sdk.TriggerSubscriptionRequest{ + Subscriptions: subs, + }, + }, + } +} + +func subWithReqs(reqs *sdk.Requirements) *sdk.TriggerSubscription { + return &sdk.TriggerSubscription{Requirements: reqs} +} + func TestRequirementSelectingModule_Start(t *testing.T) { t.Run("starts only main module", func(t *testing.T) { var mainStarted, additionalStarted bool @@ -77,23 +91,20 @@ func TestRequirementSelectingModule_Close(t *testing.T) { }) t.Run("closes main and all started additional modules", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} var mainClosed, add0Closed, add1Closed bool main := ModuleAndHandler{Module: &stubModule{ startFn: noop, closeFn: func() { mainClosed = true }, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil }, }} add0 := ModuleAndHandler{ Module: &stubModule{ startFn: noop, closeFn: func() { add0Closed = true }, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return &sdk.ExecutionResult{}, nil - }, }, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, } @@ -108,7 +119,7 @@ func TestRequirementSelectingModule_Close(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) m.Start() - _, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) m.Close() @@ -126,12 +137,15 @@ func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { } func TestRequirementSelectingModule_Execute(t *testing.T) { - t.Run("main succeeds — returns result directly", func(t *testing.T) { + t.Run("trigger with no cached entry goes to main", func(t *testing.T) { want := &sdk.ExecutionResult{} main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return want, nil + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + return want, nil + } + return subscribeResult(), nil }, }} @@ -143,7 +157,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { assert.Equal(t, want, got) }) - t.Run("main non-RequirementsRerun error propagates", func(t *testing.T) { + t.Run("main error on subscribe propagates", func(t *testing.T) { main := ModuleAndHandler{Module: &stubModule{ startFn: noop, executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { @@ -164,18 +178,18 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - _, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) assert.ErrorIs(t, err, assert.AnError) }) - t.Run("RequirementsRerun routes to matching additional", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + t.Run("subscribe with requirements routes trigger to additional", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} want := &sdk.ExecutionResult{} main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil }, }} add := ModuleAndHandler{ @@ -192,18 +206,21 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - got, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) require.NoError(t, err) assert.Equal(t, want, got) }) - t.Run("RequirementsRerun with no matching additional returns error", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + t.Run("subscribe with unmatched requirements returns error", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil }, }} add := ModuleAndHandler{ @@ -214,19 +231,19 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - _, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.Error(t, err) assert.Contains(t, err.Error(), "cannot find a runner that can satisfy the requirements") }) - t.Run("RequirementsRerun skips non-matching and selects later match", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + t.Run("subscribe skips non-matching and selects later additional", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} want := &sdk.ExecutionResult{} main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil }, }} add0 := ModuleAndHandler{ @@ -247,49 +264,48 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) m.Start() - got, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) require.NoError(t, err) assert.Equal(t, want, got) }) - t.Run("additional module started lazily", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + t.Run("additional module started lazily during subscribe", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} var addStartCount int32 main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil }, }} add := ModuleAndHandler{ Module: &stubModule{ startFn: func() { atomic.AddInt32(&addStartCount, 1) }, closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return &sdk.ExecutionResult{}, nil - }, }, RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() + assert.Equal(t, int32(0), atomic.LoadInt32(&addStartCount)) - // First execution starts the additional module. - _, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) - // Second execution with a different trigger still goes through main, - // but the additional module is not started again (sync.Once). - _, err = m.Execute(t.Context(), triggerRequest(2), nil) + // Second subscribe does not start additional again (sync.Once). + _, err = m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) }) - t.Run("subscribe request goes through main directly", func(t *testing.T) { - want := &sdk.ExecutionResult{} + t.Run("subscribe with no requirements returns main result", func(t *testing.T) { + want := subscribeResult() main := ModuleAndHandler{Module: &stubModule{ startFn: noop, @@ -309,14 +325,16 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { func TestRequirementSelectingModule_TriggerCache(t *testing.T) { t.Run("cached trigger skips main on subsequent calls", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - var mainCalls int32 + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var mainTriggerCalls int32 main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - atomic.AddInt32(&mainCalls, 1) - return nil, rerunErr + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + } + return subscribeResult(subWithReqs(teeReqs)), nil }, }} add := ModuleAndHandler{ @@ -333,26 +351,31 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - // First call triggers main. - _, err := m.Execute(t.Context(), triggerRequest(42), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) require.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&mainCalls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls), "cached trigger should skip main") - // Second call with same trigger ID skips main. - _, err = m.Execute(t.Context(), triggerRequest(42), nil) + _, err = m.Execute(t.Context(), triggerRequest(0), nil) require.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&mainCalls)) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls), "cached trigger should skip main on repeat") }) - t.Run("different trigger IDs are cached independently", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} - var mainCalls int32 + t.Run("trigger not in cache goes to main", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var mainTriggerCalls int32 main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - atomic.AddInt32(&mainCalls, 1) - return nil, rerunErr + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return &sdk.ExecutionResult{}, nil + } + // subscription 0 has requirements; subscription 1 does not + return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil }, }} add := ModuleAndHandler{ @@ -369,123 +392,77 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - // Trigger 1 goes through main. - _, err := m.Execute(t.Context(), triggerRequest(1), nil) - require.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&mainCalls)) - - // Trigger 2 also goes through main (different ID, not cached). - _, err = m.Execute(t.Context(), triggerRequest(2), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) - assert.Equal(t, int32(2), atomic.LoadInt32(&mainCalls)) - // Both are now cached — neither goes through main. + // trigger 1 has no requirements → goes to main _, err = m.Execute(t.Context(), triggerRequest(1), nil) require.NoError(t, err) - _, err = m.Execute(t.Context(), triggerRequest(2), nil) - require.NoError(t, err) - assert.Equal(t, int32(2), atomic.LoadInt32(&mainCalls)) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls)) }) - t.Run("different triggers can route to different additional modules", func(t *testing.T) { - teeRerun := &RequirementsRerun{Tee: &sdk.Tee{ + t.Run("different triggers route to different modules", func(t *testing.T) { + // subscription 0: TEE required → additional; subscription 1: no requirements → main + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{ Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{ Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}, }}, }} - noReqRerun := &RequirementsRerun{} + var mainTriggerCalls int32 + wantAdditional := &sdk.ExecutionResult{} - callCount := 0 main := ModuleAndHandler{Module: &stubModule{ startFn: noop, executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { - callCount++ - if req.GetTrigger().Id == 1 { - return nil, teeRerun + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return &sdk.ExecutionResult{}, nil } - return nil, noReqRerun + return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil }, }} - - var addNitroResult, addDefaultResult sdk.ExecutionResult - addNitro := ModuleAndHandler{ - Module: &stubModule{ - startFn: noop, closeFn: noopClose, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return &addNitroResult, nil - }, - }, - RequirementsHandler: RequirementsHandler{Tee: func(tee *sdk.Tee) bool { - return tee.GetTypeSelection() != nil - }}, - } - addDefault := ModuleAndHandler{ + add := ModuleAndHandler{ Module: &stubModule{ startFn: noop, closeFn: noopClose, executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return &addDefaultResult, nil + return wantAdditional, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, } - m := NewRequirementSelectingModule(main, []ModuleAndHandler{addNitro, addDefault}) + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) m.Start() - got, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.NoError(t, err) - assert.Equal(t, &addNitroResult, got) - // Trigger 1 is now cached to addNitro; verify second call skips main. - got, err = m.Execute(t.Context(), triggerRequest(1), nil) + // trigger 0 has TEE requirements → additional + got, err := m.Execute(t.Context(), triggerRequest(0), nil) require.NoError(t, err) - assert.Equal(t, &addNitroResult, got) - assert.Equal(t, 1, callCount, "main should only be called once for trigger 1") - }) - - t.Run("RequirementsRerun returned too late is rejected", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + assert.Equal(t, wantAdditional, got) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls)) - main := ModuleAndHandler{Module: &stubModule{ - startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - time.Sleep(11 * time.Second) - return nil, rerunErr - }, - }} - add := ModuleAndHandler{ - Module: &stubModule{ - startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - t.Fatal("additional module should not be called") - return nil, nil - }, - }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, - } - - m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) - m.Start() - - _, err := m.Execute(t.Context(), triggerRequest(1), nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "rerun requirement specified too late") + // trigger 1 has no requirements → main + _, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls)) }) - t.Run("RequirementsRerun with no additional modules returns error", func(t *testing.T) { - rerunErr := &RequirementsRerun{Tee: &sdk.Tee{}} + t.Run("no additional modules when subscribe has requirements returns error", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} main := ModuleAndHandler{Module: &stubModule{ startFn: noop, - executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { - return nil, rerunErr + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil }, }} m := NewRequirementSelectingModule(main, nil) m.Start() - _, err := m.Execute(t.Context(), triggerRequest(1), nil) + _, err := m.Execute(t.Context(), subscribeRequest(), nil) require.Error(t, err) assert.Contains(t, err.Error(), "cannot find a runner") }) diff --git a/pkg/workflows/host/requirements_rerun.go b/pkg/workflows/host/requirements_rerun.go deleted file mode 100644 index 9b64f4e928..0000000000 --- a/pkg/workflows/host/requirements_rerun.go +++ /dev/null @@ -1,16 +0,0 @@ -package host - -import ( - "google.golang.org/protobuf/encoding/protojson" - - "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" -) - -type RequirementsRerun sdk.Requirements - -func (r *RequirementsRerun) Error() string { - str, _ := protojson.Marshal((*sdk.Requirements)(r)) - return string(str) -} - -var _ error = (*RequirementsRerun)(nil) diff --git a/pkg/workflows/wasm/host/execution.go b/pkg/workflows/wasm/host/execution.go index 97bd3147df..ec9fd1bbfd 100644 --- a/pkg/workflows/wasm/host/execution.go +++ b/pkg/workflows/wasm/host/execution.go @@ -10,8 +10,6 @@ import ( "github.com/bytecodealliance/wasmtime-go/v28" "google.golang.org/protobuf/proto" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" - "github.com/smartcontractkit/chainlink-common/pkg/config" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -34,7 +32,6 @@ type execution[T any] struct { nodeSeed int64 donLogCount uint32 nodeLogCount uint32 - requirementsRerunErr error } // callCapAsync async calls a capability by placing execution results onto a @@ -384,28 +381,3 @@ func (e *execution[T]) pollOneoff(caller *wasmtime.Caller, subscriptionptr int32 return ErrnoSuccess } - -// A trap return will cause the execution to halt. -// This function fails safe and prefers to kill the program than to return an error to the user. -// It does this because a failure here could lead to code running in an environment it's not allowed in -// Although the runtime could protect from this instead, it's safer to fail as early as possible -func (e *execution[T]) requirements(caller *wasmtime.Caller, ptr int32, ptrlen int32) *wasmtime.Trap { - requirements := &sdkpb.Requirements{} - payload, err := wasmRead(caller, ptr, ptrlen) - if err != nil { - e.requirementsRerunErr = fmt.Errorf("error reading requirements: %s", err) - return wasmtime.NewTrap(e.requirementsRerunErr.Error()) - } - - if err = proto.Unmarshal(payload, requirements); err != nil { - e.requirementsRerunErr = fmt.Errorf("error unmarshalling requirements: %s", err) - return wasmtime.NewTrap(e.requirementsRerunErr.Error()) - } - - if !host.CheckRequirements(e.module.cfg.RequirementsHandler, requirements) { - e.requirementsRerunErr = (*host.RequirementsRerun)(requirements) - return wasmtime.NewTrap(e.requirementsRerunErr.Error()) - } - - return nil -} diff --git a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go index 3ad5843589..e18a5fcd59 100644 --- a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go +++ b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go @@ -252,6 +252,3 @@ func getSecrets(req unsafe.Pointer, reqLen int32, responseBuffer unsafe.Pointer, //go:wasmimport env await_secrets func awaitSecrets(req unsafe.Pointer, reqLen int32, responseBuffer unsafe.Pointer, maxResponseLen int32) int64 - -//go:wasmimport env requirements -func Requirements(req unsafe.Pointer, reqLen int32) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 9f29c2cf38..8c53fbb0f8 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -468,13 +468,6 @@ func linkNoDAG(m *module, store *wasmtime.Store, exec *execution[*sdkpb.Executio return nil, fmt.Errorf("error wrapping get_time func: %w", err) } - if err = linker.FuncWrap( - "env", - "requirements", - exec.requirements); err != nil { - return nil, fmt.Errorf("error wrapping requirements func: %w", err) - } - return linker.Instantiate(store, m.module) } @@ -712,10 +705,6 @@ func runWasm[I, O proto.Message]( return o, fmt.Errorf("invariant violation: host errored during sendResponse") } - if exec.requirementsRerunErr != nil { - return o, exec.requirementsRerunErr - } - // If an error has occurred and the deadline has been reached or exceeded, return a deadline exceeded error. // Note - there is no other reliable signal on the error that can be used to infer it is due to epoch deadline // being reached, so if an error is returned after the deadline it is assumed it is due to that and return diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index d04053f9c9..d4e8ccf660 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -528,42 +528,57 @@ func TestStandardTimeInterpretation(t *testing.T) { func TestStandardTeeRuntime(t *testing.T) { t.Parallel() - trigger := &basictrigger.Outputs{CoolOutput: anyTestTriggerValue} - cfg := defaultNoDAGModCfg(t) - var seenTeeRequirement *sdk.Tee - cfg.RequirementsHandler.Tee = func(tee *sdk.Tee) bool { - seenTeeRequirement = tee - return true - } - m := makeTestModuleWithConfig(t, cfg) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() - for _, test := range []struct { - name string - req *sdk.ExecuteRequest - }{ - { - name: "subscribe", - req: &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}}, - }, - { - name: "execute", - req: triggerExecuteRequest(t, 0, trigger), + subscribe := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}} + actual, err := m.Execute(t.Context(), subscribe, mockExecutionHelper) + require.NoError(t, err) + + payload0, err := anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + }) + require.NoError(t, err) + + payload1, err := anypb.New(&basictrigger.Config{ + Name: "second-trigger", + Number: 200, + }) + require.NoError(t, err) + + expected := &sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: payload0, + Method: "Trigger", + Requirements: &sdk.Requirements{ + Tee: &sdk.Tee{ + Type: &sdk.Tee_TypeSelection{ + TypeSelection: &sdk.TeeTypeSelection{ + Types: []*sdk.TeeTypeAndRegions{ + {Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }, + }, + }, + }, + { + Id: "basic-test-trigger@1.0.0", + Payload: payload1, + Method: "Trigger", + }, }, - } { - t.Run(test.name, func(t *testing.T) { - mockExecutionHelper := mocks.NewMockExecutionHelper(t) - mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") - mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { - return time.Now() - }).Maybe() - - _, err := m.Execute(t.Context(), test.req, mockExecutionHelper) - require.NoError(t, err) - require.True(t, proto.Equal(seenTeeRequirement, &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}})) - }) } + + assertProto(t, expected, actual.GetTriggerSubscriptions()) } func triggerExecuteRequest(t *testing.T, id uint64, trigger proto.Message) *sdk.ExecuteRequest { @@ -685,6 +700,7 @@ func wrapValue(t *testing.T, nodeResponse *nodeaction.NodeOutputs) *valuespb.Val func assertProto[T proto.Message](t *testing.T, expected, actual T) { t.Helper() + require.NotNil(t, actual) diff := cmp.Diff(expected, actual, protocmp.Transform()) var sb strings.Builder diff --git a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go index d8455ad3a1..a17cb37de8 100644 --- a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go +++ b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go @@ -1,7 +1,6 @@ package main import ( - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" @@ -10,13 +9,7 @@ import ( ) func main() { - req := rawsdk.GetRequest() requirements := &sdk.Requirements{Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}}} - bytes, err := proto.Marshal(requirements) - if err != nil { - rawsdk.SendError(err) - } - rawsdk.Requirements(rawsdk.BufferToPointerLen(bytes)) subscription := &sdk.TriggerSubscriptionRequest{ Subscriptions: []*sdk.TriggerSubscription{ { @@ -25,14 +18,19 @@ func main() { Name: "first-trigger", Number: 100, })), + Method: "Trigger", + Requirements: requirements, + }, + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "second-trigger", + Number: 200, + })), Method: "Trigger", }, }, } - switch req.GetRequest().(type) { - case *sdk.ExecuteRequest_Subscribe: - rawsdk.SendSubscription(subscription) - } - rawsdk.SendResponse(0) + rawsdk.SendSubscription(subscription) } diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 5b3a82a2ed..82d57c9667 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -7,14 +7,11 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/emptypb" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" - generichost "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -24,18 +21,12 @@ import ( ) const ( - nodagRandomBinaryCmd = "standard_tests/multiple_triggers" - nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" - loggingLimitsBinaryCmd = "test/logging_limits/cmd" - loggingLimitsBinaryLocation = loggingLimitsBinaryCmd + "/testmodule.wasm" - metricLimitsBinaryCmd = "test/metric_limits/cmd" - metricLimitsBinaryLocation = metricLimitsBinaryCmd + "/testmodule.wasm" - standardTeeRuntimeBinaryCmd = "standard_tests/tee_runtime" - standardTeeRuntimeBinaryLocation = standardTeeRuntimeBinaryCmd + "/testmodule.wasm" - invalidMemoryForRequirementsCmd = "test/requirements/invalid_memory" - invalidMemoryForRequirementsBinaryLocation = invalidMemoryForRequirementsCmd + "/testmodule.wasm" - invalidProtoForRequirementsCmd = "test/requirements/invalid_proto" - invalidProtoForRequirementsBinaryLocation = invalidProtoForRequirementsCmd + "/testmodule.wasm" + nodagRandomBinaryCmd = "standard_tests/multiple_triggers" + nodagRandomBinaryLocation = nodagRandomBinaryCmd + "/testmodule.wasm" + loggingLimitsBinaryCmd = "test/logging_limits/cmd" + loggingLimitsBinaryLocation = loggingLimitsBinaryCmd + "/testmodule.wasm" + metricLimitsBinaryCmd = "test/metric_limits/cmd" + metricLimitsBinaryLocation = metricLimitsBinaryCmd + "/testmodule.wasm" ) func Test_Sleep_Timeout(t *testing.T) { @@ -223,64 +214,6 @@ func Test_NoDAG_EmitMetricDisabled(t *testing.T) { // EmitUserMetric should never be called when disabled - no mock expectation set } -func Test_NoDAG_UnparseableRequirements(t *testing.T) { - t.Parallel() - binary := createTestBinary(invalidProtoForRequirementsCmd, invalidProtoForRequirementsBinaryLocation, true, t) - - err := runTeeFailureTest(t, sdk.TeeType_TEE_TYPE_AWS_NITRO, binary) - - assert.Error(t, err) - rerunErr := &generichost.RequirementsRerun{} - assert.False(t, errors.As(err, &rerunErr)) -} - -func Test_NoDAG_InvalidMemoryAddressForRequirements(t *testing.T) { - t.Parallel() - binary := createTestBinary(invalidMemoryForRequirementsCmd, invalidMemoryForRequirementsBinaryLocation, true, t) - - err := runTeeFailureTest(t, sdk.TeeType_TEE_TYPE_AWS_NITRO, binary) - - assert.Error(t, err) - rerunErr := &generichost.RequirementsRerun{} - assert.False(t, errors.As(err, &rerunErr)) -} - -func Test_NoDAG_RequirementsNotMet(t *testing.T) { - t.Parallel() - - binary := createTestBinary(standardTeeRuntimeBinaryCmd, standardTeeRuntimeBinaryLocation, true, t) - - // Different (non-existent) TEE - err := runTeeFailureTest(t, 999, binary) - - rerunErr := &generichost.RequirementsRerun{} - require.True(t, errors.As(err, &rerunErr)) - - expected := &sdk.Requirements{ - Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{ - TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}, - }}, - } - assert.True(t, proto.Equal(expected, (*sdk.Requirements)(rerunErr))) -} - -func runTeeFailureTest(t *testing.T, teeType sdk.TeeType, binary []byte) error { - cfg := defaultNoDAGModCfg(t) - cfg.RequirementsHandler.Tee = generichost.NewTeeProvider(teeType, nil) - m, err := NewModule(t.Context(), cfg, binary) - require.NoError(t, err) - - mockExecutionHelper := mocks.NewMockExecutionHelper(t) - mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") - mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { - return time.Now() - }).Maybe() - subscribe := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}} - - _, err = m.Execute(t.Context(), subscribe, mockExecutionHelper) - return err -} - func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { return &ModuleConfig{ Logger: logger.Test(t), From 2b8b244c61c0c0588f693d3524ec3a88b904712d Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Tue, 28 Apr 2026 12:48:46 -0400 Subject: [PATCH 08/14] Main module satisfies requirements --- .../host/requirement_selecting_module.go | 2 +- .../host/requirement_selecting_module_test.go | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go index d6ac6ad746..8e650ff591 100644 --- a/pkg/workflows/host/requirement_selecting_module.go +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -77,7 +77,7 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E } for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() { - if sub.Requirements == nil { + if sub.Requirements == nil || CheckRequirements(r.main.RequirementsHandler, sub.Requirements) { continue } matched := false diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index 4c70d2196e..b1f8854552 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -321,6 +321,47 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { require.NoError(t, err) assert.Equal(t, want, got) }) + + t.Run("main module satisfying requirements keeps trigger on main", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + var mainTriggerCalls int32 + main := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return want, nil + } + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called when main satisfies requirements") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls), "trigger should run on main") + }) } func TestRequirementSelectingModule_TriggerCache(t *testing.T) { From 58ec02d5cc0e6e6821ba726a5c133438b6902ac1 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Wed, 29 Apr 2026 10:19:25 -0400 Subject: [PATCH 09/14] remove unused hook from module, the selecting one is now responsible for the decision --- pkg/workflows/wasm/host/module.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 8c53fbb0f8..2c05aec069 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -102,8 +102,7 @@ type ModuleConfig struct { // If Determinism is set, the module will override the random_get function in the WASI API with // the provided seed to ensure deterministic behavior. - Determinism *DeterminismConfig - RequirementsHandler host.RequirementsHandler + Determinism *DeterminismConfig } type ModuleBase = host.ModuleBase From 937bc208136b8ddd5eba4fb5a97c02e745572161 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Thu, 30 Apr 2026 10:33:02 -0400 Subject: [PATCH 10/14] Hook to ask implementation to determine if it supports regions --- go.mod | 2 +- go.sum | 4 +- .../actions/confidentialworkflow/client.pb.go | 87 ++++++++++++++++--- .../server/client_server_gen.go | 16 ++++ .../host/requirement_selecting_module.go | 4 +- .../host/requirement_selecting_module_test.go | 26 +++--- .../requirements_helper.go.tmpl | 10 ++- pkg/workflows/host/requirements_helper_gen.go | 14 +-- .../host/requirements_helper_gen_test.go | 15 ++-- pkg/workflows/host/tee_provider.go | 32 +++++-- pkg/workflows/host/tee_provider_test.go | 51 ++++++----- 11 files changed, 183 insertions(+), 78 deletions(-) diff --git a/go.mod b/go.mod index 7ea74b93c3..94b1873933 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260427170224-3b3204904066 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430150654-a1ec54d2121f github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 7c7d9bb6ff..45b51dc27c 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260427170224-3b3204904066 h1:XgmfrVnD6Z2yf6f+4qcGZlqvdJlffRippMmvqE8Yl3c= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260427170224-3b3204904066/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430150654-a1ec54d2121f h1:Bxqdqu/me/tAMOQIwxY54MAb4AI1UMbqEjC93ws0krw= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430150654-a1ec54d2121f/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go index c9f311fdb1..95fab33b55 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 +// protoc-gen-go v1.36.6 // protoc v5.29.3 // source: capabilities/compute/confidentialworkflow/v1alpha/client.proto @@ -10,6 +10,7 @@ import ( _ "github.com/smartcontractkit/chainlink-protos/cre/go/tools/generator" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" sync "sync" unsafe "unsafe" @@ -96,7 +97,9 @@ type WorkflowExecution struct { ExecutionId string `protobuf:"bytes,6,opt,name=execution_id,json=executionId,proto3" json:"execution_id,omitempty"` // org_id is the organization identifier for the workflow owner. // Used by the enclave when fetching secrets from VaultDON with org-based ownership. - OrgId string `protobuf:"bytes,7,opt,name=org_id,json=orgId,proto3" json:"org_id,omitempty"` + OrgId string `protobuf:"bytes,7,opt,name=org_id,json=orgId,proto3" json:"org_id,omitempty"` + // regions that the workflow is allowed to run in. + Regions []string `protobuf:"bytes,8,rep,name=regions,proto3" json:"regions,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -180,6 +183,13 @@ func (x *WorkflowExecution) GetOrgId() string { return "" } +func (x *WorkflowExecution) GetRegions() []string { + if x != nil { + return x.Regions + } + return nil +} + // ConfidentialWorkflowRequest is the input provided to the confidential workflows capability. // It combines a WorkflowExecution with secrets from VaultDON. type ConfidentialWorkflowRequest struct { @@ -280,16 +290,60 @@ func (x *ConfidentialWorkflowResponse) GetExecutionResult() []byte { return nil } +type GetRegionsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Regions []string `protobuf:"bytes,1,rep,name=regions,proto3" json:"regions,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetRegionsResponse) Reset() { + *x = GetRegionsResponse{} + mi := &file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetRegionsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetRegionsResponse) ProtoMessage() {} + +func (x *GetRegionsResponse) ProtoReflect() protoreflect.Message { + mi := &file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetRegionsResponse.ProtoReflect.Descriptor instead. +func (*GetRegionsResponse) Descriptor() ([]byte, []int) { + return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescGZIP(), []int{4} +} + +func (x *GetRegionsResponse) GetRegions() []string { + if x != nil { + return x.Regions + } + return nil +} + var File_capabilities_compute_confidentialworkflow_v1alpha_client_proto protoreflect.FileDescriptor const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc = "" + "\n" + - ">capabilities/compute/confidentialworkflow/v1alpha/client.proto\x121capabilities.compute.confidentialworkflow.v1alpha\x1a*tools/generator/v1alpha/cre_metadata.proto\"U\n" + + ">capabilities/compute/confidentialworkflow/v1alpha/client.proto\x121capabilities.compute.confidentialworkflow.v1alpha\x1a*tools/generator/v1alpha/cre_metadata.proto\x1a\x1bgoogle/protobuf/empty.proto\"U\n" + "\x10SecretIdentifier\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12!\n" + "\tnamespace\x18\x02 \x01(\tH\x00R\tnamespace\x88\x01\x01B\f\n" + "\n" + - "_namespace\"\xed\x01\n" + + "_namespace\"\x87\x02\n" + "\x11WorkflowExecution\x12\x1f\n" + "\vworkflow_id\x18\x01 \x01(\tR\n" + "workflowId\x12\x1d\n" + @@ -300,14 +354,19 @@ const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDes "\x0fexecute_request\x18\x04 \x01(\fR\x0eexecuteRequest\x12\x14\n" + "\x05owner\x18\x05 \x01(\tR\x05owner\x12!\n" + "\fexecution_id\x18\x06 \x01(\tR\vexecutionId\x12\x15\n" + - "\x06org_id\x18\a \x01(\tR\x05orgId\"\xf2\x01\n" + + "\x06org_id\x18\a \x01(\tR\x05orgId\x12\x18\n" + + "\aregions\x18\b \x03(\tR\aregions\"\xf2\x01\n" + "\x1bConfidentialWorkflowRequest\x12o\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2C.capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12b\n" + "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"I\n" + "\x1cConfidentialWorkflowResponse\x12)\n" + - "\x10execution_result\x18\x01 \x01(\fR\x0fexecutionResult2\xe1\x01\n" + + "\x10execution_result\x18\x01 \x01(\fR\x0fexecutionResult\".\n" + + "\x12GetRegionsResponse\x12\x18\n" + + "\aregions\x18\x01 \x03(\tR\aregions2\xce\x02\n" + "\x06Client\x12\xaa\x01\n" + - "\aExecute\x12N.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest\x1aO.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse\x1a*\x82\xb5\x18&\b\x01\x12\"confidential-workflows@1.0.0-alphab\x06proto3" + "\aExecute\x12N.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest\x1aO.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse\x12k\n" + + "\n" + + "GetRegions\x12\x16.google.protobuf.Empty\x1aE.capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse\x1a*\x82\xb5\x18&\b\x01\x12\"confidential-workflows@1.0.0-alphab\x06proto3" var ( file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescOnce sync.Once @@ -321,20 +380,24 @@ func file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescData } -var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_goTypes = []any{ (*SecretIdentifier)(nil), // 0: capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier (*WorkflowExecution)(nil), // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution (*ConfidentialWorkflowRequest)(nil), // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest (*ConfidentialWorkflowResponse)(nil), // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + (*GetRegionsResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse + (*emptypb.Empty)(nil), // 5: google.protobuf.Empty } var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs = []int32{ 0, // 0: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier 1, // 1: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution 2, // 2: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest - 3, // 3: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse - 3, // [3:4] is the sub-list for method output_type - 2, // [2:3] is the sub-list for method input_type + 5, // 3: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:input_type -> google.protobuf.Empty + 3, // 4: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + 4, // 5: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:output_type -> capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse + 4, // [4:6] is the sub-list for method output_type + 2, // [2:4] is the sub-list for method input_type 2, // [2:2] is the sub-list for extension type_name 2, // [2:2] is the sub-list for extension extendee 0, // [0:2] is the sub-list for field type_name @@ -352,7 +415,7 @@ func file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_init() GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc), len(file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc)), NumEnums: 0, - NumMessages: 4, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go b/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go index 222bc4fa9f..3d30de9c87 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go @@ -21,6 +21,8 @@ var _ = emptypb.Empty{} type ClientCapability interface { Execute(ctx context.Context, metadata capabilities.RequestMetadata, input *confidentialworkflow.ConfidentialWorkflowRequest) (*capabilities.ResponseAndMetadata[*confidentialworkflow.ConfidentialWorkflowResponse], caperrors.Error) + GetRegions(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty) (*capabilities.ResponseAndMetadata[*confidentialworkflow.GetRegionsResponse], caperrors.Error) + Start(ctx context.Context) error Close() error HealthReport() map[string]error @@ -137,6 +139,20 @@ func (c *clientCapability) Execute(ctx context.Context, request capabilities.Cap return output.Response, output.ResponseMetadata, output.OCRAttestation, err } return capabilities.Execute(ctx, request, input, config, wrapped) + case "GetRegions": + input := &emptypb.Empty{} + config := &emptypb.Empty{} + wrapped := func(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty, _ *emptypb.Empty) (*confidentialworkflow.GetRegionsResponse, capabilities.ResponseMetadata, *capabilities.OCRAttestation, error) { + output, err := c.ClientCapability.GetRegions(ctx, metadata, input) + if err != nil { + return nil, capabilities.ResponseMetadata{}, nil, err + } + if output == nil { + return nil, capabilities.ResponseMetadata{}, nil, fmt.Errorf("output and error is nil for method GetRegions(..) (if output is nil error must be present)") + } + return output.Response, output.ResponseMetadata, output.OCRAttestation, err + } + return capabilities.Execute(ctx, request, input, config, wrapped) default: return response, fmt.Errorf("method %s not found", request.Method) } diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go index 8e650ff591..edb45d0e5f 100644 --- a/pkg/workflows/host/requirement_selecting_module.go +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -77,12 +77,12 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E } for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() { - if sub.Requirements == nil || CheckRequirements(r.main.RequirementsHandler, sub.Requirements) { + if sub.Requirements == nil || CheckRequirements(ctx, r.main.RequirementsHandler, sub.Requirements) { continue } matched := false for j, m := range r.additional { - if CheckRequirements(m.RequirementsHandler, sub.Requirements) { + if CheckRequirements(ctx, m.RequirementsHandler, sub.Requirements) { m.ensureStarted() r.cache.Store(uint64(i), j) matched = true diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index b1f8854552..e51496a5c6 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -106,14 +106,14 @@ func TestRequirementSelectingModule_Close(t *testing.T) { startFn: noop, closeFn: func() { add0Closed = true }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } add1 := ModuleAndHandler{ Module: &stubModule{ startFn: noop, closeFn: func() { add1Closed = true }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) @@ -172,7 +172,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { return nil, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -200,7 +200,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { return want, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -225,7 +225,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { }} add := ModuleAndHandler{ Module: &stubModule{startFn: noop}, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -248,7 +248,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { }} add0 := ModuleAndHandler{ Module: &stubModule{startFn: noop}, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, } add1 := ModuleAndHandler{ Module: &stubModule{ @@ -258,7 +258,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { return want, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) @@ -287,7 +287,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { startFn: func() { atomic.AddInt32(&addStartCount, 1) }, closeFn: noopClose, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -338,7 +338,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { return subscribeResult(subWithReqs(teeReqs)), nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } add := ModuleAndHandler{ Module: &stubModule{ @@ -348,7 +348,7 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { return nil, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -386,7 +386,7 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { return &sdk.ExecutionResult{}, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -427,7 +427,7 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { return &sdk.ExecutionResult{}, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) @@ -469,7 +469,7 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { return wantAdditional, nil }, }, - RequirementsHandler: RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, } m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) diff --git a/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl index 2d8f102714..b47f7b3098 100644 --- a/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl +++ b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl @@ -1,24 +1,26 @@ package host +import "ctx" + // RequirementsHandler contains a callback for each public field in sdk.Requirements. // Each callback receives the field value and returns a list of strings or an error. type RequirementsHandler struct { {{- range .Fields}} - {{.Name}} func({{.Type}}) bool + {{.Name}} func(context.Context, {{.Type}}) bool {{- end}} } -// CheckRequirements calls each non-nil callback in handler for the corresponding +// CheckRequirements calls each non-nil callback in the handler for the corresponding // non-nil field in req, returning false if any are false, or if the handler is nil. // Unknown fields on the proto also result in a false return value. -func CheckRequirements(handler RequirementsHandler, req *sdk.Requirements) bool { +func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { if len(req.ProtoReflect().GetUnknown()) != 0 { return false } {{range .Fields}} if req.{{.Name}} != nil { - if handler.{{.Name}} == nil || !handler.{{.Name}}(req.{{.Name}}) { + if handler.{{.Name}} == nil || !handler.{{.Name}}(ctx, req.{{.Name}}) { return false } diff --git a/pkg/workflows/host/requirements_helper_gen.go b/pkg/workflows/host/requirements_helper_gen.go index 875da649fa..f8f2ea1486 100644 --- a/pkg/workflows/host/requirements_helper_gen.go +++ b/pkg/workflows/host/requirements_helper_gen.go @@ -2,24 +2,28 @@ package host -import "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +import ( + "context" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) // RequirementsHandler contains a callback for each public field in sdk.Requirements. // Each callback receives the field value and returns a list of strings or an error. type RequirementsHandler struct { - Tee func(*sdk.Tee) bool + Tee func(context.Context, *sdk.Tee) bool } -// CheckRequirements calls each non-nil callback in handler for the corresponding +// CheckRequirements calls each non-nil callback in the handler for the corresponding // non-nil field in req, returning false if any are false, or if the handler is nil. // Unknown fields on the proto also result in a false return value. -func CheckRequirements(handler RequirementsHandler, req *sdk.Requirements) bool { +func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { if len(req.ProtoReflect().GetUnknown()) != 0 { return false } if req.Tee != nil { - if handler.Tee == nil || !handler.Tee(req.Tee) { + if handler.Tee == nil || !handler.Tee(ctx, req.Tee) { return false } diff --git a/pkg/workflows/host/requirements_helper_gen_test.go b/pkg/workflows/host/requirements_helper_gen_test.go index 7c1b9049ff..68bb23b3d7 100644 --- a/pkg/workflows/host/requirements_helper_gen_test.go +++ b/pkg/workflows/host/requirements_helper_gen_test.go @@ -1,6 +1,7 @@ package host import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -21,27 +22,27 @@ func Test_CheckRequirements(t *testing.T) { req := &sdk.Requirements{} require.NoError(t, proto.Unmarshal(b, req)) - assert.False(t, CheckRequirements(RequirementsHandler{}, req)) + assert.False(t, CheckRequirements(context.Background(), RequirementsHandler{}, req)) }) t.Run("no fields always passes", func(t *testing.T) { - assert.True(t, CheckRequirements(RequirementsHandler{}, &sdk.Requirements{})) + assert.True(t, CheckRequirements(context.Background(), RequirementsHandler{}, &sdk.Requirements{})) }) t.Run("handler not set returns false", func(t *testing.T) { req := &sdk.Requirements{Tee: &sdk.Tee{}} - assert.False(t, CheckRequirements(RequirementsHandler{}, req)) + assert.False(t, CheckRequirements(context.Background(), RequirementsHandler{}, req)) }) t.Run("handler returns false causes false return value", func(t *testing.T) { req := &sdk.Requirements{Tee: &sdk.Tee{}} - handler := RequirementsHandler{Tee: func(*sdk.Tee) bool { return false }} - assert.False(t, CheckRequirements(handler, req)) + handler := RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }} + assert.False(t, CheckRequirements(context.Background(), handler, req)) }) t.Run("handler returns true causes true return value", func(t *testing.T) { req := &sdk.Requirements{Tee: &sdk.Tee{}} - handler := RequirementsHandler{Tee: func(*sdk.Tee) bool { return true }} - assert.True(t, CheckRequirements(handler, req)) + handler := RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }} + assert.True(t, CheckRequirements(context.Background(), handler, req)) }) } diff --git a/pkg/workflows/host/tee_provider.go b/pkg/workflows/host/tee_provider.go index 5905bfdbd2..8bdb9d258a 100644 --- a/pkg/workflows/host/tee_provider.go +++ b/pkg/workflows/host/tee_provider.go @@ -1,21 +1,34 @@ package host -import sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +import ( + "context" + "sync" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) type teeProvider struct { sdkpb.TeeType - regions map[string]bool + regionsFn func(ctx context.Context) map[string]bool + once sync.Once } -func NewTeeProvider(tpe sdkpb.TeeType, regions []string) func(tee *sdkpb.Tee) bool { - supportedRegions := map[string]bool{} - for _, region := range regions { - supportedRegions[region] = true +func NewTeeProvider(tpe sdkpb.TeeType, regionsFn func(ctx context.Context) []string) func(context.Context, *sdkpb.Tee) bool { + p := &teeProvider{ + TeeType: tpe, + regionsFn: func(ctx context.Context) map[string]bool { + regions := regionsFn(ctx) + rMap := make(map[string]bool, len(regions)) + for _, region := range regions { + rMap[region] = true + } + return rMap + }, } - return (&teeProvider{TeeType: tpe, regions: supportedRegions}).Provides + return p.Provides } -func (t *teeProvider) Provides(tee *sdkpb.Tee) bool { +func (t *teeProvider) Provides(ctx context.Context, tee *sdkpb.Tee) bool { switch teet := tee.Type.(type) { case *sdkpb.Tee_Any: return true @@ -26,8 +39,9 @@ func (t *teeProvider) Provides(tee *sdkpb.Tee) bool { return true } + regions := t.regionsFn(ctx) for _, region := range selection.Regions { - if t.regions[region] { + if regions[region] { return true } } diff --git a/pkg/workflows/host/tee_provider_test.go b/pkg/workflows/host/tee_provider_test.go index 5ba04bfb81..d23791c18c 100644 --- a/pkg/workflows/host/tee_provider_test.go +++ b/pkg/workflows/host/tee_provider_test.go @@ -1,6 +1,7 @@ package host import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -9,16 +10,20 @@ import ( sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) +func regionsFn(regions map[string]bool) func(context.Context) map[string]bool { + return func(context.Context) map[string]bool { return regions } +} + func TestNewTeeProvider(t *testing.T) { t.Parallel() t.Run("matches any", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} - assert.True(t, p.Provides(tee)) + assert.True(t, p.Provides(context.Background(), tee)) }) t.Run("matches type selection with no region constraint", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ Types: []*sdkpb.TeeTypeAndRegions{ @@ -27,7 +32,7 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.True(t, p.Provides(tee)) + assert.True(t, p.Provides(context.Background(), tee)) }) t.Run("does not match different types", func(t *testing.T) { @@ -39,11 +44,11 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.False(t, p.Provides(tee)) + assert.False(t, p.Provides(context.Background(), tee)) }) t.Run("matches type and region", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ Types: []*sdkpb.TeeTypeAndRegions{ @@ -51,11 +56,11 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.True(t, p.Provides(tee)) + assert.True(t, p.Provides(context.Background(), tee)) }) t.Run("matches type but not region", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ Types: []*sdkpb.TeeTypeAndRegions{ @@ -63,11 +68,11 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.False(t, p.Provides(tee)) + assert.False(t, p.Provides(context.Background(), tee)) }) t.Run("matches one of multiple requested regions", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"eu-west-1": true}} + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"eu-west-1": true})} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ Types: []*sdkpb.TeeTypeAndRegions{ @@ -75,13 +80,13 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.True(t, p.Provides(tee)) + assert.True(t, p.Provides(context.Background(), tee)) }) t.Run("provider has multiple regions and one matches", func(t *testing.T) { p := teeProvider{ - TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, - regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regionsFn: regionsFn(map[string]bool{"us-west-2": true, "us-east-1": true}), } tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ @@ -90,13 +95,13 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.True(t, p.Provides(tee)) + assert.True(t, p.Provides(context.Background(), tee)) }) t.Run("no matching region across multiple provider regions", func(t *testing.T) { p := teeProvider{ - TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, - regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regionsFn: regionsFn(map[string]bool{"us-west-2": true, "us-east-1": true}), } tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ @@ -105,11 +110,11 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.False(t, p.Provides(tee)) + assert.False(t, p.Provides(context.Background(), tee)) }) t.Run("type mismatch ignores region match", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType(99), regions: map[string]bool{"us-west-2": true}} + p := teeProvider{TeeType: sdkpb.TeeType(99), regionsFn: regionsFn(map[string]bool{"us-west-2": true})} tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ Types: []*sdkpb.TeeTypeAndRegions{ @@ -117,17 +122,17 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.False(t, p.Provides(tee)) + assert.False(t, p.Provides(context.Background(), tee)) }) t.Run("matches any tee", func(t *testing.T) { - provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, func(context.Context) []string { return []string{"us-west-2"} }) tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} - assert.True(t, provides(tee)) + assert.True(t, provides(context.Background(), tee)) }) t.Run("returns a function that checks regions", func(t *testing.T) { - provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, func(context.Context) []string { return []string{"us-west-2"} }) tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ TypeSelection: &sdkpb.TeeTypeSelection{ Types: []*sdkpb.TeeTypeAndRegions{ @@ -135,6 +140,6 @@ func TestNewTeeProvider(t *testing.T) { }, }, }} - assert.False(t, provides(tee)) + assert.False(t, provides(context.Background(), tee)) }) } From e578b91863399f99c1e8e9849a34e8da245dcd6b Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Thu, 30 Apr 2026 13:57:00 -0400 Subject: [PATCH 11/14] Update protos to use full TEE requirement and not use bytes for the embedded request --- go.mod | 2 +- go.sum | 4 +- .../actions/confidentialworkflow/client.pb.go | 64 ++++++++++--------- 3 files changed, 38 insertions(+), 32 deletions(-) diff --git a/go.mod b/go.mod index 94b1873933..0855b4ec32 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430150654-a1ec54d2121f + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430182156-06c6d2165948 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 45b51dc27c..359c168eae 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430150654-a1ec54d2121f h1:Bxqdqu/me/tAMOQIwxY54MAb4AI1UMbqEjC93ws0krw= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430150654-a1ec54d2121f/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430182156-06c6d2165948 h1:qm8OTi3ypzPuQpWmA1gfWcPNCprUkrlba+g0o+C+uT4= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430182156-06c6d2165948/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go index 95fab33b55..d14aadb6ce 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go @@ -7,6 +7,7 @@ package confidentialworkflow import ( + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" _ "github.com/smartcontractkit/chainlink-protos/cre/go/tools/generator" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" @@ -86,9 +87,8 @@ type WorkflowExecution struct { BinaryUrl string `protobuf:"bytes,2,opt,name=binary_url,json=binaryUrl,proto3" json:"binary_url,omitempty"` // binary_hash is the expected SHA-256 hash of the WASM binary, for integrity verification. BinaryHash []byte `protobuf:"bytes,3,opt,name=binary_hash,json=binaryHash,proto3" json:"binary_hash,omitempty"` - // execute_request is a serialized sdk.v1alpha.ExecuteRequest proto. // Contains either a subscribe request or a trigger execution request. - ExecuteRequest []byte `protobuf:"bytes,4,opt,name=execute_request,json=executeRequest,proto3" json:"execute_request,omitempty"` + ExecuteRequest *sdk.ExecuteRequest `protobuf:"bytes,4,opt,name=execute_request,json=executeRequest,proto3" json:"execute_request,omitempty"` // owner is the on-chain owner address of the workflow (hex, 0x-prefixed). // Used by the enclave for runtime secret fetching from VaultDON. Owner string `protobuf:"bytes,5,opt,name=owner,proto3" json:"owner,omitempty"` @@ -99,7 +99,7 @@ type WorkflowExecution struct { // Used by the enclave when fetching secrets from VaultDON with org-based ownership. OrgId string `protobuf:"bytes,7,opt,name=org_id,json=orgId,proto3" json:"org_id,omitempty"` // regions that the workflow is allowed to run in. - Regions []string `protobuf:"bytes,8,rep,name=regions,proto3" json:"regions,omitempty"` + Tee *sdk.Tee `protobuf:"bytes,8,opt,name=tee,proto3" json:"tee,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -155,7 +155,7 @@ func (x *WorkflowExecution) GetBinaryHash() []byte { return nil } -func (x *WorkflowExecution) GetExecuteRequest() []byte { +func (x *WorkflowExecution) GetExecuteRequest() *sdk.ExecuteRequest { if x != nil { return x.ExecuteRequest } @@ -183,9 +183,9 @@ func (x *WorkflowExecution) GetOrgId() string { return "" } -func (x *WorkflowExecution) GetRegions() []string { +func (x *WorkflowExecution) GetTee() *sdk.Tee { if x != nil { - return x.Regions + return x.Tee } return nil } @@ -248,7 +248,7 @@ func (x *ConfidentialWorkflowRequest) GetExecution() *WorkflowExecution { type ConfidentialWorkflowResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // execution_result is a serialized sdk.v1alpha.ExecutionResult proto. - ExecutionResult []byte `protobuf:"bytes,1,opt,name=execution_result,json=executionResult,proto3" json:"execution_result,omitempty"` + ExecutionResult *sdk.ExecutionResult `protobuf:"bytes,1,opt,name=execution_result,json=executionResult,proto3" json:"execution_result,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -283,7 +283,7 @@ func (*ConfidentialWorkflowResponse) Descriptor() ([]byte, []int) { return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescGZIP(), []int{3} } -func (x *ConfidentialWorkflowResponse) GetExecutionResult() []byte { +func (x *ConfidentialWorkflowResponse) GetExecutionResult() *sdk.ExecutionResult { if x != nil { return x.ExecutionResult } @@ -338,29 +338,29 @@ var File_capabilities_compute_confidentialworkflow_v1alpha_client_proto protoref const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc = "" + "\n" + - ">capabilities/compute/confidentialworkflow/v1alpha/client.proto\x121capabilities.compute.confidentialworkflow.v1alpha\x1a*tools/generator/v1alpha/cre_metadata.proto\x1a\x1bgoogle/protobuf/empty.proto\"U\n" + + ">capabilities/compute/confidentialworkflow/v1alpha/client.proto\x121capabilities.compute.confidentialworkflow.v1alpha\x1a*tools/generator/v1alpha/cre_metadata.proto\x1a\x15sdk/v1alpha/sdk.proto\x1a\x1bgoogle/protobuf/empty.proto\"U\n" + "\x10SecretIdentifier\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12!\n" + "\tnamespace\x18\x02 \x01(\tH\x00R\tnamespace\x88\x01\x01B\f\n" + "\n" + - "_namespace\"\x87\x02\n" + + "_namespace\"\xae\x02\n" + "\x11WorkflowExecution\x12\x1f\n" + "\vworkflow_id\x18\x01 \x01(\tR\n" + "workflowId\x12\x1d\n" + "\n" + "binary_url\x18\x02 \x01(\tR\tbinaryUrl\x12\x1f\n" + "\vbinary_hash\x18\x03 \x01(\fR\n" + - "binaryHash\x12'\n" + - "\x0fexecute_request\x18\x04 \x01(\fR\x0eexecuteRequest\x12\x14\n" + + "binaryHash\x12D\n" + + "\x0fexecute_request\x18\x04 \x01(\v2\x1b.sdk.v1alpha.ExecuteRequestR\x0eexecuteRequest\x12\x14\n" + "\x05owner\x18\x05 \x01(\tR\x05owner\x12!\n" + "\fexecution_id\x18\x06 \x01(\tR\vexecutionId\x12\x15\n" + - "\x06org_id\x18\a \x01(\tR\x05orgId\x12\x18\n" + - "\aregions\x18\b \x03(\tR\aregions\"\xf2\x01\n" + + "\x06org_id\x18\a \x01(\tR\x05orgId\x12\"\n" + + "\x03tee\x18\b \x01(\v2\x10.sdk.v1alpha.TeeR\x03tee\"\xf2\x01\n" + "\x1bConfidentialWorkflowRequest\x12o\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2C.capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12b\n" + - "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"I\n" + - "\x1cConfidentialWorkflowResponse\x12)\n" + - "\x10execution_result\x18\x01 \x01(\fR\x0fexecutionResult\".\n" + + "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"g\n" + + "\x1cConfidentialWorkflowResponse\x12G\n" + + "\x10execution_result\x18\x01 \x01(\v2\x1c.sdk.v1alpha.ExecutionResultR\x0fexecutionResult\".\n" + "\x12GetRegionsResponse\x12\x18\n" + "\aregions\x18\x01 \x03(\tR\aregions2\xce\x02\n" + "\x06Client\x12\xaa\x01\n" + @@ -387,20 +387,26 @@ var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_goTypes (*ConfidentialWorkflowRequest)(nil), // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest (*ConfidentialWorkflowResponse)(nil), // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse (*GetRegionsResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse - (*emptypb.Empty)(nil), // 5: google.protobuf.Empty + (*sdk.ExecuteRequest)(nil), // 5: sdk.v1alpha.ExecuteRequest + (*sdk.Tee)(nil), // 6: sdk.v1alpha.Tee + (*sdk.ExecutionResult)(nil), // 7: sdk.v1alpha.ExecutionResult + (*emptypb.Empty)(nil), // 8: google.protobuf.Empty } var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs = []int32{ - 0, // 0: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier - 1, // 1: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution - 2, // 2: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest - 5, // 3: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:input_type -> google.protobuf.Empty - 3, // 4: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse - 4, // 5: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:output_type -> capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse - 4, // [4:6] is the sub-list for method output_type - 2, // [2:4] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 5, // 0: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.execute_request:type_name -> sdk.v1alpha.ExecuteRequest + 6, // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.tee:type_name -> sdk.v1alpha.Tee + 0, // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier + 1, // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution + 7, // 4: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse.execution_result:type_name -> sdk.v1alpha.ExecutionResult + 2, // 5: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest + 8, // 6: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:input_type -> google.protobuf.Empty + 3, // 7: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + 4, // 8: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:output_type -> capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse + 7, // [7:9] is the sub-list for method output_type + 5, // [5:7] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name } func init() { file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_init() } From 98d72c594db363ccc23d4f3dce1baf534682477c Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Thu, 30 Apr 2026 15:30:59 -0400 Subject: [PATCH 12/14] Send full requirements to modules that accept them --- go.mod | 2 +- go.sum | 4 +- .../actions/confidentialworkflow/client.pb.go | 18 ++--- pkg/workflows/host/module.go | 7 ++ .../host/requirement_selecting_module.go | 67 ++++++++++--------- .../host/requirement_selecting_module_test.go | 65 ++++++++++++++++++ .../requirements_helper.go.tmpl | 4 ++ pkg/workflows/host/requirements_helper_gen.go | 4 ++ 8 files changed, 129 insertions(+), 42 deletions(-) diff --git a/go.mod b/go.mod index 0855b4ec32..f4d69ec69d 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430182156-06c6d2165948 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501133512-474ca53a440a github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 359c168eae..614611f7cb 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430182156-06c6d2165948 h1:qm8OTi3ypzPuQpWmA1gfWcPNCprUkrlba+g0o+C+uT4= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260430182156-06c6d2165948/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501133512-474ca53a440a h1:5evxdEH/QG8ti8Cc6+iQtwJheF5U7WRetsIgZW38HMw= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501133512-474ca53a440a/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go index d14aadb6ce..d7f8c08944 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go @@ -98,8 +98,8 @@ type WorkflowExecution struct { // org_id is the organization identifier for the workflow owner. // Used by the enclave when fetching secrets from VaultDON with org-based ownership. OrgId string `protobuf:"bytes,7,opt,name=org_id,json=orgId,proto3" json:"org_id,omitempty"` - // regions that the workflow is allowed to run in. - Tee *sdk.Tee `protobuf:"bytes,8,opt,name=tee,proto3" json:"tee,omitempty"` + // requirements to run this workflow + Requirements *sdk.Requirements `protobuf:"bytes,8,opt,name=requirements,proto3" json:"requirements,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -183,9 +183,9 @@ func (x *WorkflowExecution) GetOrgId() string { return "" } -func (x *WorkflowExecution) GetTee() *sdk.Tee { +func (x *WorkflowExecution) GetRequirements() *sdk.Requirements { if x != nil { - return x.Tee + return x.Requirements } return nil } @@ -343,7 +343,7 @@ const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDes "\x03key\x18\x01 \x01(\tR\x03key\x12!\n" + "\tnamespace\x18\x02 \x01(\tH\x00R\tnamespace\x88\x01\x01B\f\n" + "\n" + - "_namespace\"\xae\x02\n" + + "_namespace\"\xc9\x02\n" + "\x11WorkflowExecution\x12\x1f\n" + "\vworkflow_id\x18\x01 \x01(\tR\n" + "workflowId\x12\x1d\n" + @@ -354,8 +354,8 @@ const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDes "\x0fexecute_request\x18\x04 \x01(\v2\x1b.sdk.v1alpha.ExecuteRequestR\x0eexecuteRequest\x12\x14\n" + "\x05owner\x18\x05 \x01(\tR\x05owner\x12!\n" + "\fexecution_id\x18\x06 \x01(\tR\vexecutionId\x12\x15\n" + - "\x06org_id\x18\a \x01(\tR\x05orgId\x12\"\n" + - "\x03tee\x18\b \x01(\v2\x10.sdk.v1alpha.TeeR\x03tee\"\xf2\x01\n" + + "\x06org_id\x18\a \x01(\tR\x05orgId\x12=\n" + + "\frequirements\x18\b \x01(\v2\x19.sdk.v1alpha.RequirementsR\frequirements\"\xf2\x01\n" + "\x1bConfidentialWorkflowRequest\x12o\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2C.capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12b\n" + "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"g\n" + @@ -388,13 +388,13 @@ var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_goTypes (*ConfidentialWorkflowResponse)(nil), // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse (*GetRegionsResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse (*sdk.ExecuteRequest)(nil), // 5: sdk.v1alpha.ExecuteRequest - (*sdk.Tee)(nil), // 6: sdk.v1alpha.Tee + (*sdk.Requirements)(nil), // 6: sdk.v1alpha.Requirements (*sdk.ExecutionResult)(nil), // 7: sdk.v1alpha.ExecutionResult (*emptypb.Empty)(nil), // 8: google.protobuf.Empty } var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs = []int32{ 5, // 0: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.execute_request:type_name -> sdk.v1alpha.ExecuteRequest - 6, // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.tee:type_name -> sdk.v1alpha.Tee + 6, // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.requirements:type_name -> sdk.v1alpha.Requirements 0, // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier 1, // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution 7, // 4: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse.execution_result:type_name -> sdk.v1alpha.ExecutionResult diff --git a/pkg/workflows/host/module.go b/pkg/workflows/host/module.go index 1db4c086ce..f4debbb922 100644 --- a/pkg/workflows/host/module.go +++ b/pkg/workflows/host/module.go @@ -23,6 +23,13 @@ type Module interface { Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) } +type RequirementEnforcingModule interface { + Module + + // SetRequirements must respect the requirements for the execution until it completes + SetRequirements(executionId string, requirements *sdkpb.Requirements) +} + // ExecutionHelper Implemented by those running the host, for example the Workflow Engine type ExecutionHelper interface { // CallCapability blocking call to the Workflow Engine diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go index edb45d0e5f..7f83aedf94 100644 --- a/pkg/workflows/host/requirement_selecting_module.go +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -27,31 +27,35 @@ func (l *lazyModule) ensureStarted() { }) } +// NewRequirementSelectingModule creates a module that routes trigger executions +// based on subscription requirements. main is prepended as modules[0]; additional +// modules follow. Subscribe always runs on modules[0]. func NewRequirementSelectingModule(main ModuleAndHandler, additional []ModuleAndHandler) Module { - wrapped := make([]*lazyModule, len(additional)) - for i := range additional { - wrapped[i] = &lazyModule{ModuleAndHandler: additional[i]} - } - return &requirementSelectingModule{ - main: main, - additional: wrapped, + modules := make([]*lazyModule, 1+len(additional)) + modules[0] = &lazyModule{ModuleAndHandler: main} + for i, a := range additional { + modules[1+i] = &lazyModule{ModuleAndHandler: a} } + return &requirementSelectingModule{modules: modules} +} + +type triggerInfo struct { + moduleIdx int + requirements *sdk.Requirements } type requirementSelectingModule struct { - main ModuleAndHandler - additional []*lazyModule - // triggerID → index into additional + modules []*lazyModule + // triggerID → triggerInfo cache sync.Map } func (r *requirementSelectingModule) Start() { - r.main.Start() + r.modules[0].ensureStarted() } func (r *requirementSelectingModule) Close() { - r.main.Close() - for _, m := range r.additional { + for _, m := range r.modules { if m.started { m.Close() } @@ -59,32 +63,28 @@ func (r *requirementSelectingModule) Close() { } func (r *requirementSelectingModule) IsLegacyDAG() bool { - return r.main.IsLegacyDAG() + return r.modules[0].IsLegacyDAG() } func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { - if triggerID, ok := extractTriggerID(request); ok { - if idx, cached := r.cache.Load(triggerID); cached { - return r.additional[idx.(int)].Execute(ctx, request, handler) - } - return r.main.Execute(ctx, request, handler) + if request.GetTrigger() == nil { + return r.subscribe(ctx, request, handler) } + return r.trigger(ctx, request, handler) +} - // Subscribe: run main, then build triggerID→module cache from subscription requirements - result, err := r.main.Execute(ctx, request, handler) +func (r *requirementSelectingModule) subscribe(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + result, err := r.modules[0].Execute(ctx, request, handler) if err != nil { return nil, err } for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() { - if sub.Requirements == nil || CheckRequirements(ctx, r.main.RequirementsHandler, sub.Requirements) { - continue - } matched := false - for j, m := range r.additional { + for j, m := range r.modules { if CheckRequirements(ctx, m.RequirementsHandler, sub.Requirements) { m.ensureStarted() - r.cache.Store(uint64(i), j) + r.cache.Store(uint64(i), triggerInfo{moduleIdx: j, requirements: sub.Requirements}) matched = true break } @@ -97,11 +97,18 @@ func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.E return result, nil } -func extractTriggerID(req *sdk.ExecuteRequest) (uint64, bool) { - if t := req.GetTrigger(); t != nil { - return t.Id, true +func (r *requirementSelectingModule) trigger(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + trigger := request.GetTrigger() + if val, cached := r.cache.Load(trigger.Id); cached { + info := val.(triggerInfo) + m := r.modules[info.moduleIdx] + if rem, ok := m.Module.(RequirementEnforcingModule); ok && info.requirements != nil { + rem.SetRequirements(handler.GetWorkflowExecutionID(), info.requirements) + } + + return m.Execute(ctx, request, handler) } - return 0, false + return r.modules[0].Execute(ctx, request, handler) } var _ Module = &requirementSelectingModule{} diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index e51496a5c6..d3417653bd 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -27,6 +27,15 @@ func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h Exe return s.executeFn(ctx, req, h) } +type requirementEnforcingStub struct { + *stubModule + setRequirementsFn func(string, *sdk.Requirements) +} + +func (s *requirementEnforcingStub) SetRequirements(executionID string, requirements *sdk.Requirements) { + s.setRequirementsFn(executionID, requirements) +} + func noop() {} func noopClose() {} @@ -362,6 +371,62 @@ func TestRequirementSelectingModule_Execute(t *testing.T) { assert.Equal(t, want, got) assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls), "trigger should run on main") }) + + t.Run("cached trigger sets requirements before execute", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + executionID := "wf-exec-1" + + main := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + var calls []string + var gotReqs *sdk.Requirements + var gotExecutionID string + enforcingAdd := &requirementEnforcingStub{ + stubModule: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + calls = append(calls, "execute") + return want, nil + }, + }, + setRequirementsFn: func(id string, requirements *sdk.Requirements) { + calls = append(calls, "set") + gotExecutionID = id + gotReqs = requirements + }, + } + add := ModuleAndHandler{ + Module: enforcingAdd, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + helper := &MockExecutionHelper{} + helper.On("GetWorkflowExecutionID").Return(executionID).Once() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), helper) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, []string{"set", "execute"}, calls) + assert.Equal(t, executionID, gotExecutionID) + assert.Same(t, teeReqs, gotReqs) + helper.AssertExpectations(t) + }) } func TestRequirementSelectingModule_TriggerCache(t *testing.T) { diff --git a/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl index b47f7b3098..5ec3a5d9cb 100644 --- a/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl +++ b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl @@ -14,6 +14,10 @@ type RequirementsHandler struct { // non-nil field in req, returning false if any are false, or if the handler is nil. // Unknown fields on the proto also result in a false return value. func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { + if req == nil { + return true + } + if len(req.ProtoReflect().GetUnknown()) != 0 { return false } diff --git a/pkg/workflows/host/requirements_helper_gen.go b/pkg/workflows/host/requirements_helper_gen.go index f8f2ea1486..85fb0e8b7c 100644 --- a/pkg/workflows/host/requirements_helper_gen.go +++ b/pkg/workflows/host/requirements_helper_gen.go @@ -18,6 +18,10 @@ type RequirementsHandler struct { // non-nil field in req, returning false if any are false, or if the handler is nil. // Unknown fields on the proto also result in a false return value. func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { + if req == nil { + return true + } + if len(req.ProtoReflect().GetUnknown()) != 0 { return false } From 8e8598d367460f359697cc000eb2a7ac700397a6 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Fri, 1 May 2026 13:37:53 -0400 Subject: [PATCH 13/14] Regions with any is allowed too, and let the TEE say everything it is, not just region. --- go.mod | 2 +- go.sum | 4 +- .../actions/confidentialworkflow/client.pb.go | 59 ++--- .../server/client_server_gen.go | 10 +- .../host/requirement_selecting_module_test.go | 4 +- pkg/workflows/host/tee_provider.go | 74 +++--- pkg/workflows/host/tee_provider_test.go | 148 ++++++----- pkg/workflows/host/tee_selection_provider.go | 49 ++++ .../host/tee_selection_provider_test.go | 231 ++++++++++++++++++ pkg/workflows/wasm/host/standard_test.go | 6 +- .../standard_tests/tee_runtime/main_wasip1.go | 2 +- 11 files changed, 448 insertions(+), 141 deletions(-) create mode 100644 pkg/workflows/host/tee_selection_provider.go create mode 100644 pkg/workflows/host/tee_selection_provider_test.go diff --git a/go.mod b/go.mod index f4d69ec69d..bb20444699 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501133512-474ca53a440a + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501160256-5806971948f2 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 614611f7cb..464e2ea510 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501133512-474ca53a440a h1:5evxdEH/QG8ti8Cc6+iQtwJheF5U7WRetsIgZW38HMw= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501133512-474ca53a440a/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501160256-5806971948f2 h1:RKmSjhsAHuN4A62fJn/wGj/dXCBfrRTojo5ZQZfL/y8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501160256-5806971948f2/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go index d7f8c08944..28154e1741 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go @@ -290,27 +290,27 @@ func (x *ConfidentialWorkflowResponse) GetExecutionResult() *sdk.ExecutionResult return nil } -type GetRegionsResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - Regions []string `protobuf:"bytes,1,rep,name=regions,proto3" json:"regions,omitempty"` +type ProvidedTeesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Tee []*sdk.TeeTypeAndRegions `protobuf:"bytes,1,rep,name=tee,proto3" json:"tee,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *GetRegionsResponse) Reset() { - *x = GetRegionsResponse{} +func (x *ProvidedTeesResponse) Reset() { + *x = ProvidedTeesResponse{} mi := &file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *GetRegionsResponse) String() string { +func (x *ProvidedTeesResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GetRegionsResponse) ProtoMessage() {} +func (*ProvidedTeesResponse) ProtoMessage() {} -func (x *GetRegionsResponse) ProtoReflect() protoreflect.Message { +func (x *ProvidedTeesResponse) ProtoReflect() protoreflect.Message { mi := &file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -322,14 +322,14 @@ func (x *GetRegionsResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GetRegionsResponse.ProtoReflect.Descriptor instead. -func (*GetRegionsResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use ProvidedTeesResponse.ProtoReflect.Descriptor instead. +func (*ProvidedTeesResponse) Descriptor() ([]byte, []int) { return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescGZIP(), []int{4} } -func (x *GetRegionsResponse) GetRegions() []string { +func (x *ProvidedTeesResponse) GetTee() []*sdk.TeeTypeAndRegions { if x != nil { - return x.Regions + return x.Tee } return nil } @@ -360,13 +360,12 @@ const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDes "\x11vault_don_secrets\x18\x01 \x03(\v2C.capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12b\n" + "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"g\n" + "\x1cConfidentialWorkflowResponse\x12G\n" + - "\x10execution_result\x18\x01 \x01(\v2\x1c.sdk.v1alpha.ExecutionResultR\x0fexecutionResult\".\n" + - "\x12GetRegionsResponse\x12\x18\n" + - "\aregions\x18\x01 \x03(\tR\aregions2\xce\x02\n" + + "\x10execution_result\x18\x01 \x01(\v2\x1c.sdk.v1alpha.ExecutionResultR\x0fexecutionResult\"H\n" + + "\x14ProvidedTeesResponse\x120\n" + + "\x03tee\x18\x01 \x03(\v2\x1e.sdk.v1alpha.TeeTypeAndRegionsR\x03tee2\xd2\x02\n" + "\x06Client\x12\xaa\x01\n" + - "\aExecute\x12N.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest\x1aO.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse\x12k\n" + - "\n" + - "GetRegions\x12\x16.google.protobuf.Empty\x1aE.capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse\x1a*\x82\xb5\x18&\b\x01\x12\"confidential-workflows@1.0.0-alphab\x06proto3" + "\aExecute\x12N.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest\x1aO.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse\x12o\n" + + "\fProvidedTees\x12\x16.google.protobuf.Empty\x1aG.capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse\x1a*\x82\xb5\x18&\b\x01\x12\"confidential-workflows@1.0.0-alphab\x06proto3" var ( file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescOnce sync.Once @@ -386,11 +385,12 @@ var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_goTypes (*WorkflowExecution)(nil), // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution (*ConfidentialWorkflowRequest)(nil), // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest (*ConfidentialWorkflowResponse)(nil), // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse - (*GetRegionsResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse + (*ProvidedTeesResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse (*sdk.ExecuteRequest)(nil), // 5: sdk.v1alpha.ExecuteRequest (*sdk.Requirements)(nil), // 6: sdk.v1alpha.Requirements (*sdk.ExecutionResult)(nil), // 7: sdk.v1alpha.ExecutionResult - (*emptypb.Empty)(nil), // 8: google.protobuf.Empty + (*sdk.TeeTypeAndRegions)(nil), // 8: sdk.v1alpha.TeeTypeAndRegions + (*emptypb.Empty)(nil), // 9: google.protobuf.Empty } var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs = []int32{ 5, // 0: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.execute_request:type_name -> sdk.v1alpha.ExecuteRequest @@ -398,15 +398,16 @@ var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs 0, // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier 1, // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution 7, // 4: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse.execution_result:type_name -> sdk.v1alpha.ExecutionResult - 2, // 5: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest - 8, // 6: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:input_type -> google.protobuf.Empty - 3, // 7: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse - 4, // 8: capabilities.compute.confidentialworkflow.v1alpha.Client.GetRegions:output_type -> capabilities.compute.confidentialworkflow.v1alpha.GetRegionsResponse - 7, // [7:9] is the sub-list for method output_type - 5, // [5:7] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension type_name - 5, // [5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name + 8, // 5: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse.tee:type_name -> sdk.v1alpha.TeeTypeAndRegions + 2, // 6: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest + 9, // 7: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:input_type -> google.protobuf.Empty + 3, // 8: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + 4, // 9: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse + 8, // [8:10] is the sub-list for method output_type + 6, // [6:8] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name } func init() { file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_init() } diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go b/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go index 3d30de9c87..cbebc884e8 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go @@ -21,7 +21,7 @@ var _ = emptypb.Empty{} type ClientCapability interface { Execute(ctx context.Context, metadata capabilities.RequestMetadata, input *confidentialworkflow.ConfidentialWorkflowRequest) (*capabilities.ResponseAndMetadata[*confidentialworkflow.ConfidentialWorkflowResponse], caperrors.Error) - GetRegions(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty) (*capabilities.ResponseAndMetadata[*confidentialworkflow.GetRegionsResponse], caperrors.Error) + ProvidedTees(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty) (*capabilities.ResponseAndMetadata[*confidentialworkflow.ProvidedTeesResponse], caperrors.Error) Start(ctx context.Context) error Close() error @@ -139,16 +139,16 @@ func (c *clientCapability) Execute(ctx context.Context, request capabilities.Cap return output.Response, output.ResponseMetadata, output.OCRAttestation, err } return capabilities.Execute(ctx, request, input, config, wrapped) - case "GetRegions": + case "ProvidedTees": input := &emptypb.Empty{} config := &emptypb.Empty{} - wrapped := func(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty, _ *emptypb.Empty) (*confidentialworkflow.GetRegionsResponse, capabilities.ResponseMetadata, *capabilities.OCRAttestation, error) { - output, err := c.ClientCapability.GetRegions(ctx, metadata, input) + wrapped := func(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty, _ *emptypb.Empty) (*confidentialworkflow.ProvidedTeesResponse, capabilities.ResponseMetadata, *capabilities.OCRAttestation, error) { + output, err := c.ClientCapability.ProvidedTees(ctx, metadata, input) if err != nil { return nil, capabilities.ResponseMetadata{}, nil, err } if output == nil { - return nil, capabilities.ResponseMetadata{}, nil, fmt.Errorf("output and error is nil for method GetRegions(..) (if output is nil error must be present)") + return nil, capabilities.ResponseMetadata{}, nil, fmt.Errorf("output and error is nil for method ProvidedTees(..) (if output is nil error must be present)") } return output.Response, output.ResponseMetadata, output.OCRAttestation, err } diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go index d3417653bd..1ffd5d94f7 100644 --- a/pkg/workflows/host/requirement_selecting_module_test.go +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -510,8 +510,8 @@ func TestRequirementSelectingModule_TriggerCache(t *testing.T) { t.Run("different triggers route to different modules", func(t *testing.T) { // subscription 0: TEE required → additional; subscription 1: no requirements → main teeReqs := &sdk.Requirements{Tee: &sdk.Tee{ - Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{ - Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}, + Item: &sdk.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdk.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}, }}, }} var mainTriggerCalls int32 diff --git a/pkg/workflows/host/tee_provider.go b/pkg/workflows/host/tee_provider.go index 8bdb9d258a..4b9344be8d 100644 --- a/pkg/workflows/host/tee_provider.go +++ b/pkg/workflows/host/tee_provider.go @@ -1,52 +1,52 @@ package host -import ( - "context" - "sync" - - sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" -) +import sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" type teeProvider struct { sdkpb.TeeType - regionsFn func(ctx context.Context) map[string]bool - once sync.Once + regions map[string]bool } -func NewTeeProvider(tpe sdkpb.TeeType, regionsFn func(ctx context.Context) []string) func(context.Context, *sdkpb.Tee) bool { - p := &teeProvider{ - TeeType: tpe, - regionsFn: func(ctx context.Context) map[string]bool { - regions := regionsFn(ctx) - rMap := make(map[string]bool, len(regions)) - for _, region := range regions { - rMap[region] = true - } - return rMap - }, +func NewTeeProvider(tpe sdkpb.TeeType, regions []string) func(tee *sdkpb.Tee) bool { + supportedRegions := map[string]bool{} + for _, region := range regions { + supportedRegions[region] = true } - return p.Provides + return (&teeProvider{TeeType: tpe, regions: supportedRegions}).Provides } -func (t *teeProvider) Provides(ctx context.Context, tee *sdkpb.Tee) bool { - switch teet := tee.Type.(type) { - case *sdkpb.Tee_Any: - return true - case *sdkpb.Tee_TypeSelection: - for _, selection := range teet.TypeSelection.Types { - if selection.Type == t.TeeType { - if len(selection.Regions) == 0 { - return true - } - - regions := t.regionsFn(ctx) - for _, region := range selection.Regions { - if regions[region] { - return true - } - } +func (t *teeProvider) Provides(tee *sdkpb.Tee) bool { + var regions []string + switch teet := tee.Item.(type) { + case *sdkpb.Tee_AnyRegions: + regions = teet.AnyRegions.Regions + case *sdkpb.Tee_TeeTypesAndRegions: + if teet.TeeTypesAndRegions == nil { + return false + } + + found := false + for _, tr := range teet.TeeTypesAndRegions.TeeTypeAndRegions { + if tr.Type == t.TeeType { + found = true + regions = tr.Regions + break } } + + if !found { + return false + } + } + + if len(regions) == 0 { + return true + } + + for _, region := range regions { + if t.regions[region] { + return true + } } return false diff --git a/pkg/workflows/host/tee_provider_test.go b/pkg/workflows/host/tee_provider_test.go index d23791c18c..0ca39602ea 100644 --- a/pkg/workflows/host/tee_provider_test.go +++ b/pkg/workflows/host/tee_provider_test.go @@ -1,145 +1,171 @@ package host import ( - "context" "testing" "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/emptypb" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) -func regionsFn(regions map[string]bool) func(context.Context) map[string]bool { - return func(context.Context) map[string]bool { return regions } -} - func TestNewTeeProvider(t *testing.T) { t.Parallel() t.Run("matches any", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} - assert.True(t, p.Provides(context.Background(), tee)) + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, p.Provides(tee)) }) - t.Run("matches type selection with no region constraint", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + t.Run("matches type selection with matching region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType(99)}, - {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, }, }, }} - assert.True(t, p.Provides(context.Background(), tee)) + assert.True(t, p.Provides(tee)) }) t.Run("does not match different types", func(t *testing.T) { p := teeProvider{TeeType: sdkpb.TeeType(99)} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ - {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, }, }, }} - assert.False(t, p.Provides(context.Background(), tee)) + assert.False(t, p.Provides(tee)) }) t.Run("matches type and region", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, }, }, }} - assert.True(t, p.Provides(context.Background(), tee)) + assert.True(t, p.Provides(tee)) }) t.Run("matches type but not region", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"us-west-2": true})} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, }, }, }} - assert.False(t, p.Provides(context.Background(), tee)) + assert.False(t, p.Provides(tee)) }) t.Run("matches one of multiple requested regions", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regionsFn: regionsFn(map[string]bool{"eu-west-1": true})} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"eu-west-1": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2", "eu-west-1"}}, }, }, }} - assert.True(t, p.Provides(context.Background(), tee)) + assert.True(t, p.Provides(tee)) }) t.Run("provider has multiple regions and one matches", func(t *testing.T) { p := teeProvider{ - TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, - regionsFn: regionsFn(map[string]bool{"us-west-2": true, "us-east-1": true}), + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, } - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-east-1"}}, }, }, }} - assert.True(t, p.Provides(context.Background(), tee)) + assert.True(t, p.Provides(tee)) }) t.Run("no matching region across multiple provider regions", func(t *testing.T) { p := teeProvider{ - TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, - regionsFn: regionsFn(map[string]bool{"us-west-2": true, "us-east-1": true}), + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, } - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"ap-southeast-1"}}, }, }, }} - assert.False(t, p.Provides(context.Background(), tee)) + assert.False(t, p.Provides(tee)) }) t.Run("type mismatch ignores region match", func(t *testing.T) { - p := teeProvider{TeeType: sdkpb.TeeType(99), regionsFn: regionsFn(map[string]bool{"us-west-2": true})} - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + p := teeProvider{TeeType: sdkpb.TeeType(99), regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, }, }, }} - assert.False(t, p.Provides(context.Background(), tee)) + assert.False(t, p.Provides(tee)) }) t.Run("matches any tee", func(t *testing.T) { - provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, func(context.Context) []string { return []string{"us-west-2"} }) - tee := &sdkpb.Tee{Type: &sdkpb.Tee_Any{Any: &emptypb.Empty{}}} - assert.True(t, provides(context.Background(), tee)) + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, provides(tee)) }) t.Run("returns a function that checks regions", func(t *testing.T) { - provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, func(context.Context) []string { return []string{"us-west-2"} }) - tee := &sdkpb.Tee{Type: &sdkpb.Tee_TypeSelection{ - TypeSelection: &sdkpb.TeeTypeSelection{ - Types: []*sdkpb.TeeTypeAndRegions{ + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, }, }, }} - assert.False(t, provides(context.Background(), tee)) + assert.False(t, provides(tee)) + }) + + t.Run("returns false when tee item is nil", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{} + assert.True(t, provides(tee)) + }) + + t.Run("AnyRegions with empty region list returns false", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{}}} + assert.True(t, provides(tee)) + }) + + t.Run("TeeTypesAndRegions with empty region list returns true", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, + }}} + assert.True(t, provides(tee)) + }) + + t.Run("TeeTypesAndRegions with nil regions returns false", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: nil}, + }, + }}} + assert.True(t, provides(tee)) }) } diff --git a/pkg/workflows/host/tee_selection_provider.go b/pkg/workflows/host/tee_selection_provider.go new file mode 100644 index 0000000000..c31938e4db --- /dev/null +++ b/pkg/workflows/host/tee_selection_provider.go @@ -0,0 +1,49 @@ +package host + +import ( + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func NewProviderFromSelection(types []*sdkpb.TeeTypeAndRegions) func(tee *sdkpb.Tee) bool { + if len(types) == 1 { + return NewTeeProvider(types[0].Type, types[0].Regions) + } + + supplies := make(map[sdkpb.TeeType][]string) + for _, t := range types { + supplies[t.Type] = append(supplies[t.Type], t.Regions...) + } + + providers := make(map[sdkpb.TeeType]func(tee *sdkpb.Tee) bool) + for k, v := range supplies { + providers[k] = NewTeeProvider(k, v) + } + + return func(tee *sdkpb.Tee) bool { + switch teet := tee.Item.(type) { + case *sdkpb.Tee_AnyRegions: + for _, provider := range providers { + if provider(tee) { + return true + } + } + + return false + case *sdkpb.Tee_TeeTypesAndRegions: + for _, requestedType := range teet.TeeTypesAndRegions.TeeTypeAndRegions { + provider, ok := providers[requestedType.Type] + if !ok { + continue + } + + if provider(tee) { + return true + } + } + + return false + default: + return false + } + } +} diff --git a/pkg/workflows/host/tee_selection_provider_test.go b/pkg/workflows/host/tee_selection_provider_test.go new file mode 100644 index 0000000000..4cb10d1dbe --- /dev/null +++ b/pkg/workflows/host/tee_selection_provider_test.go @@ -0,0 +1,231 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func TestNewProviderFromSelection(t *testing.T) { + t.Parallel() + + t.Run("returns false for nil selection", func(t *testing.T) { + provider := NewProviderFromSelection(nil) + assert.False(t, provider(&sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}})) + }) + + t.Run("single type selection delegates to tee provider", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types support any tee", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types merges regions for same type", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }) + + regions := []string{"eu-west-1"} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: regions, + }}}}} + assert.True(t, provider(tee)) + regions[0] = "us-west-2" + assert.True(t, provider(tee)) + }) + + t.Run("multiple types returns false when requested type is not supplied", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType(999), + Regions: []string{"us-west-2"}, + }}}}} + assert.False(t, provider(tee)) + }) + + t.Run("returns false for unsupported tee shape", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + }}) + assert.False(t, provider(&sdkpb.Tee{})) + }) + + t.Run("multi-type AnyRegions returns false when no provider matches", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + // AnyRegions with a region that doesn't match any provider's regions + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"ap-southeast-1"}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with nil TeeTypesAndRegions returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: nil}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions returns false when all requested types not in providers", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + // Request types that don't exist in providers + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(999), Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(888), Regions: []string{"eu-west-1"}}, + }, + }}} + assert.False(t, provider(tee)) + }) + + t.Run("single type TeeTypesAndRegions with non-matching region returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"eu-west-1"}, + }}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions partial match skips non-providers", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + // Request both AWS_NITRO and an unknown type; AWS_NITRO should match + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("single type returns directly without closure for AnyRegions", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("single type returns false for non-matching AnyRegions", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with empty TeeTypeAndRegions array returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{}, + }}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type AnyRegions with empty regions list returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{}}} + assert.False(t, provider(tee)) + }) + + t.Run("multiple types with no regions in first type", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with first type not matching then match on second", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(888), Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("unsupported item type in multi-type scenario", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions all types not in providers with continue path", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(555), Regions: []string{"us-west-2"}}, + }) + + // Request a type that is never in providers - forces continue on every iteration + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(777), Regions: []string{"us-west-2"}}, + }, + }}} + assert.False(t, provider(tee)) + }) +} diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index d4e8ccf660..7d2b7f5d88 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -560,9 +560,9 @@ func TestStandardTeeRuntime(t *testing.T) { Method: "Trigger", Requirements: &sdk.Requirements{ Tee: &sdk.Tee{ - Type: &sdk.Tee_TypeSelection{ - TypeSelection: &sdk.TeeTypeSelection{ - Types: []*sdk.TeeTypeAndRegions{ + Item: &sdk.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdk.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{ {Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, }, }, diff --git a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go index a17cb37de8..a6fdfec748 100644 --- a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go +++ b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go @@ -9,7 +9,7 @@ import ( ) func main() { - requirements := &sdk.Requirements{Tee: &sdk.Tee{Type: &sdk.Tee_TypeSelection{TypeSelection: &sdk.TeeTypeSelection{Types: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}}} + requirements := &sdk.Requirements{Tee: &sdk.Tee{Item: &sdk.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdk.TeeTypesAndRegions{TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}}} subscription := &sdk.TriggerSubscriptionRequest{ Subscriptions: []*sdk.TriggerSubscription{ { From f0dfd6b71e78f7e21b85d55d7a858b02b7459024 Mon Sep 17 00:00:00 2001 From: Ryan Tinianov Date: Mon, 4 May 2026 12:14:25 -0400 Subject: [PATCH 14/14] Update protos for unknown TEE Type --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index bb20444699..967940241c 100644 --- a/go.mod +++ b/go.mod @@ -43,7 +43,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501160256-5806971948f2 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260504161322-7061fbfd5189 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 464e2ea510..94226b2d2c 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501160256-5806971948f2 h1:RKmSjhsAHuN4A62fJn/wGj/dXCBfrRTojo5ZQZfL/y8= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260501160256-5806971948f2/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260504161322-7061fbfd5189 h1:Fe3Njnug3v3lXGTctzrHUbQSrEoXocdr9bAsakk5RB4= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260504161322-7061fbfd5189/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8=