diff --git a/.mockery.yaml b/.mockery.yaml index 91e44a850e..16ee6356dd 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -38,13 +38,12 @@ packages: github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host: interfaces: ModuleV1: {} - ModuleV2: {} + github.com/smartcontractkit/chainlink-common/pkg/workflows/host: + interfaces: + Module: {} ExecutionHelper: config: - inpackage: true - filename: "mock_{{.InterfaceName | snakecase}}_test.go" - mockname: "Mock{{.InterfaceName}}" - dir: "{{.InterfaceDir}}" + mockname: "Mock{{.InterfaceName}}" github.com/smartcontractkit/chainlink-common/pkg/custmsg: interfaces: MessageEmitter: @@ -64,4 +63,4 @@ packages: dir: "{{.InterfaceDir}}/limits" outpkg: limits interfaces: - Getter: \ No newline at end of file + Getter: diff --git a/go.mod b/go.mod index aec001395b..58ccde6a5e 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,7 @@ require ( github.com/smartcontractkit/chain-selectors v1.0.89 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 - github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260420204255-a3f3bdd56877 + github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260504161322-7061fbfd5189 github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b github.com/smartcontractkit/chainlink-protos/storage-service v0.3.0 diff --git a/go.sum b/go.sum index 0014393c5c..b347f60c82 100644 --- a/go.sum +++ b/go.sum @@ -262,8 +262,8 @@ github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10 h1:FJAFgXS9 github.com/smartcontractkit/chainlink-common/pkg/chipingress v0.0.10/go.mod h1:oiDa54M0FwxevWwyAX773lwdWvFYYlYHHQV1LQ5HpWY= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4 h1:GCzrxDWn3b7jFfEA+WiYRi8CKoegsayiDoJBCjYkneE= github.com/smartcontractkit/chainlink-protos/billing/go v0.0.0-20251024234028-0988426d98f4/go.mod h1:HHGeDUpAsPa0pmOx7wrByCitjQ0mbUxf0R9v+g67uCA= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260420204255-a3f3bdd56877 h1:6UueUIbck1Ogarm9rm/9TS6b09mKgMmx+YE8XFg63AQ= -github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260420204255-a3f3bdd56877/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260504161322-7061fbfd5189 h1:Fe3Njnug3v3lXGTctzrHUbQSrEoXocdr9bAsakk5RB4= +github.com/smartcontractkit/chainlink-protos/cre/go v0.0.0-20260504161322-7061fbfd5189/go.mod h1:Jqt53s27Tr0jDl8mdBXg1xhu6F8Fci8JOuq43tgHOM8= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b h1:QuI6SmQFK/zyUlVWEf0GMkiUYBPY4lssn26nKSd/bOM= github.com/smartcontractkit/chainlink-protos/linking-service/go v0.0.0-20251002192024-d2ad9222409b/go.mod h1:qSTSwX3cBP3FKQwQacdjArqv0g6QnukjV4XuzO6UyoY= github.com/smartcontractkit/chainlink-protos/node-platform v0.0.0-20260205130626-db2a2aab956b h1:36knUpKHHAZ86K4FGWXtx8i/EQftGdk2bqCoEu/Cha8= diff --git a/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go b/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go index 4497ce5c8b..8b1b74b65b 100644 --- a/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialhttp/client.pb.go @@ -441,9 +441,9 @@ const file_capabilities_networking_confidentialhttp_v1alpha_client_proto_rawDesc "\x05value\x18\x02 \x01(\v2>.capabilities.networking.confidentialhttp.v1alpha.HeaderValuesR\x05value:\x028\x01\"\xe2\x01\n" + "\x17ConfidentialHTTPRequest\x12n\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2B.capabilities.networking.confidentialhttp.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12W\n" + - "\arequest\x18\x02 \x01(\v2=.capabilities.networking.confidentialhttp.v1alpha.HTTPRequestR\arequest2\xca\x01\n" + + "\arequest\x18\x02 \x01(\v2=.capabilities.networking.confidentialhttp.v1alpha.HTTPRequestR\arequest2\xcd\x01\n" + "\x06Client\x12\x98\x01\n" + - "\vSendRequest\x12I.capabilities.networking.confidentialhttp.v1alpha.ConfidentialHTTPRequest\x1a>.capabilities.networking.confidentialhttp.v1alpha.HTTPResponse\x1a%\x82\xb5\x18!\b\x01\x12\x1dconfidential-http@1.0.0-alphab\x06proto3" + "\vSendRequest\x12I.capabilities.networking.confidentialhttp.v1alpha.ConfidentialHTTPRequest\x1a>.capabilities.networking.confidentialhttp.v1alpha.HTTPResponse\x1a(\x82\xb5\x18$\b\x01\x12\x1dconfidential-http@1.0.0-alpha\"\x01\x01b\x06proto3" var ( file_capabilities_networking_confidentialhttp_v1alpha_client_proto_rawDescOnce sync.Once diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go index c9f311fdb1..28154e1741 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/client.pb.go @@ -1,15 +1,17 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.11 +// protoc-gen-go v1.36.6 // protoc v5.29.3 // source: capabilities/compute/confidentialworkflow/v1alpha/client.proto package confidentialworkflow import ( + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" _ "github.com/smartcontractkit/chainlink-protos/cre/go/tools/generator" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" sync "sync" unsafe "unsafe" @@ -85,9 +87,8 @@ type WorkflowExecution struct { BinaryUrl string `protobuf:"bytes,2,opt,name=binary_url,json=binaryUrl,proto3" json:"binary_url,omitempty"` // binary_hash is the expected SHA-256 hash of the WASM binary, for integrity verification. BinaryHash []byte `protobuf:"bytes,3,opt,name=binary_hash,json=binaryHash,proto3" json:"binary_hash,omitempty"` - // execute_request is a serialized sdk.v1alpha.ExecuteRequest proto. // Contains either a subscribe request or a trigger execution request. - ExecuteRequest []byte `protobuf:"bytes,4,opt,name=execute_request,json=executeRequest,proto3" json:"execute_request,omitempty"` + ExecuteRequest *sdk.ExecuteRequest `protobuf:"bytes,4,opt,name=execute_request,json=executeRequest,proto3" json:"execute_request,omitempty"` // owner is the on-chain owner address of the workflow (hex, 0x-prefixed). // Used by the enclave for runtime secret fetching from VaultDON. Owner string `protobuf:"bytes,5,opt,name=owner,proto3" json:"owner,omitempty"` @@ -96,7 +97,9 @@ type WorkflowExecution struct { ExecutionId string `protobuf:"bytes,6,opt,name=execution_id,json=executionId,proto3" json:"execution_id,omitempty"` // org_id is the organization identifier for the workflow owner. // Used by the enclave when fetching secrets from VaultDON with org-based ownership. - OrgId string `protobuf:"bytes,7,opt,name=org_id,json=orgId,proto3" json:"org_id,omitempty"` + OrgId string `protobuf:"bytes,7,opt,name=org_id,json=orgId,proto3" json:"org_id,omitempty"` + // requirements to run this workflow + Requirements *sdk.Requirements `protobuf:"bytes,8,opt,name=requirements,proto3" json:"requirements,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -152,7 +155,7 @@ func (x *WorkflowExecution) GetBinaryHash() []byte { return nil } -func (x *WorkflowExecution) GetExecuteRequest() []byte { +func (x *WorkflowExecution) GetExecuteRequest() *sdk.ExecuteRequest { if x != nil { return x.ExecuteRequest } @@ -180,6 +183,13 @@ func (x *WorkflowExecution) GetOrgId() string { return "" } +func (x *WorkflowExecution) GetRequirements() *sdk.Requirements { + if x != nil { + return x.Requirements + } + return nil +} + // ConfidentialWorkflowRequest is the input provided to the confidential workflows capability. // It combines a WorkflowExecution with secrets from VaultDON. type ConfidentialWorkflowRequest struct { @@ -238,7 +248,7 @@ func (x *ConfidentialWorkflowRequest) GetExecution() *WorkflowExecution { type ConfidentialWorkflowResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // execution_result is a serialized sdk.v1alpha.ExecutionResult proto. - ExecutionResult []byte `protobuf:"bytes,1,opt,name=execution_result,json=executionResult,proto3" json:"execution_result,omitempty"` + ExecutionResult *sdk.ExecutionResult `protobuf:"bytes,1,opt,name=execution_result,json=executionResult,proto3" json:"execution_result,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -273,41 +283,89 @@ func (*ConfidentialWorkflowResponse) Descriptor() ([]byte, []int) { return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescGZIP(), []int{3} } -func (x *ConfidentialWorkflowResponse) GetExecutionResult() []byte { +func (x *ConfidentialWorkflowResponse) GetExecutionResult() *sdk.ExecutionResult { if x != nil { return x.ExecutionResult } return nil } +type ProvidedTeesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Tee []*sdk.TeeTypeAndRegions `protobuf:"bytes,1,rep,name=tee,proto3" json:"tee,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ProvidedTeesResponse) Reset() { + *x = ProvidedTeesResponse{} + mi := &file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ProvidedTeesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ProvidedTeesResponse) ProtoMessage() {} + +func (x *ProvidedTeesResponse) ProtoReflect() protoreflect.Message { + mi := &file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ProvidedTeesResponse.ProtoReflect.Descriptor instead. +func (*ProvidedTeesResponse) Descriptor() ([]byte, []int) { + return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescGZIP(), []int{4} +} + +func (x *ProvidedTeesResponse) GetTee() []*sdk.TeeTypeAndRegions { + if x != nil { + return x.Tee + } + return nil +} + var File_capabilities_compute_confidentialworkflow_v1alpha_client_proto protoreflect.FileDescriptor const file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc = "" + "\n" + - ">capabilities/compute/confidentialworkflow/v1alpha/client.proto\x121capabilities.compute.confidentialworkflow.v1alpha\x1a*tools/generator/v1alpha/cre_metadata.proto\"U\n" + + ">capabilities/compute/confidentialworkflow/v1alpha/client.proto\x121capabilities.compute.confidentialworkflow.v1alpha\x1a*tools/generator/v1alpha/cre_metadata.proto\x1a\x15sdk/v1alpha/sdk.proto\x1a\x1bgoogle/protobuf/empty.proto\"U\n" + "\x10SecretIdentifier\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12!\n" + "\tnamespace\x18\x02 \x01(\tH\x00R\tnamespace\x88\x01\x01B\f\n" + "\n" + - "_namespace\"\xed\x01\n" + + "_namespace\"\xc9\x02\n" + "\x11WorkflowExecution\x12\x1f\n" + "\vworkflow_id\x18\x01 \x01(\tR\n" + "workflowId\x12\x1d\n" + "\n" + "binary_url\x18\x02 \x01(\tR\tbinaryUrl\x12\x1f\n" + "\vbinary_hash\x18\x03 \x01(\fR\n" + - "binaryHash\x12'\n" + - "\x0fexecute_request\x18\x04 \x01(\fR\x0eexecuteRequest\x12\x14\n" + + "binaryHash\x12D\n" + + "\x0fexecute_request\x18\x04 \x01(\v2\x1b.sdk.v1alpha.ExecuteRequestR\x0eexecuteRequest\x12\x14\n" + "\x05owner\x18\x05 \x01(\tR\x05owner\x12!\n" + "\fexecution_id\x18\x06 \x01(\tR\vexecutionId\x12\x15\n" + - "\x06org_id\x18\a \x01(\tR\x05orgId\"\xf2\x01\n" + + "\x06org_id\x18\a \x01(\tR\x05orgId\x12=\n" + + "\frequirements\x18\b \x01(\v2\x19.sdk.v1alpha.RequirementsR\frequirements\"\xf2\x01\n" + "\x1bConfidentialWorkflowRequest\x12o\n" + "\x11vault_don_secrets\x18\x01 \x03(\v2C.capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifierR\x0fvaultDonSecrets\x12b\n" + - "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"I\n" + - "\x1cConfidentialWorkflowResponse\x12)\n" + - "\x10execution_result\x18\x01 \x01(\fR\x0fexecutionResult2\xe1\x01\n" + + "\texecution\x18\x02 \x01(\v2D.capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecutionR\texecution\"g\n" + + "\x1cConfidentialWorkflowResponse\x12G\n" + + "\x10execution_result\x18\x01 \x01(\v2\x1c.sdk.v1alpha.ExecutionResultR\x0fexecutionResult\"H\n" + + "\x14ProvidedTeesResponse\x120\n" + + "\x03tee\x18\x01 \x03(\v2\x1e.sdk.v1alpha.TeeTypeAndRegionsR\x03tee2\xd2\x02\n" + "\x06Client\x12\xaa\x01\n" + - "\aExecute\x12N.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest\x1aO.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse\x1a*\x82\xb5\x18&\b\x01\x12\"confidential-workflows@1.0.0-alphab\x06proto3" + "\aExecute\x12N.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest\x1aO.capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse\x12o\n" + + "\fProvidedTees\x12\x16.google.protobuf.Empty\x1aG.capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse\x1a*\x82\xb5\x18&\b\x01\x12\"confidential-workflows@1.0.0-alphab\x06proto3" var ( file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescOnce sync.Once @@ -321,23 +379,35 @@ func file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc return file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDescData } -var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_goTypes = []any{ (*SecretIdentifier)(nil), // 0: capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier (*WorkflowExecution)(nil), // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution (*ConfidentialWorkflowRequest)(nil), // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest (*ConfidentialWorkflowResponse)(nil), // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + (*ProvidedTeesResponse)(nil), // 4: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse + (*sdk.ExecuteRequest)(nil), // 5: sdk.v1alpha.ExecuteRequest + (*sdk.Requirements)(nil), // 6: sdk.v1alpha.Requirements + (*sdk.ExecutionResult)(nil), // 7: sdk.v1alpha.ExecutionResult + (*sdk.TeeTypeAndRegions)(nil), // 8: sdk.v1alpha.TeeTypeAndRegions + (*emptypb.Empty)(nil), // 9: google.protobuf.Empty } var file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_depIdxs = []int32{ - 0, // 0: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier - 1, // 1: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution - 2, // 2: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest - 3, // 3: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse - 3, // [3:4] is the sub-list for method output_type - 2, // [2:3] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 5, // 0: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.execute_request:type_name -> sdk.v1alpha.ExecuteRequest + 6, // 1: capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution.requirements:type_name -> sdk.v1alpha.Requirements + 0, // 2: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.vault_don_secrets:type_name -> capabilities.compute.confidentialworkflow.v1alpha.SecretIdentifier + 1, // 3: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest.execution:type_name -> capabilities.compute.confidentialworkflow.v1alpha.WorkflowExecution + 7, // 4: capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse.execution_result:type_name -> sdk.v1alpha.ExecutionResult + 8, // 5: capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse.tee:type_name -> sdk.v1alpha.TeeTypeAndRegions + 2, // 6: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:input_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowRequest + 9, // 7: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:input_type -> google.protobuf.Empty + 3, // 8: capabilities.compute.confidentialworkflow.v1alpha.Client.Execute:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ConfidentialWorkflowResponse + 4, // 9: capabilities.compute.confidentialworkflow.v1alpha.Client.ProvidedTees:output_type -> capabilities.compute.confidentialworkflow.v1alpha.ProvidedTeesResponse + 8, // [8:10] is the sub-list for method output_type + 6, // [6:8] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name } func init() { file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_init() } @@ -352,7 +422,7 @@ func file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_init() GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc), len(file_capabilities_compute_confidentialworkflow_v1alpha_client_proto_rawDesc)), NumEnums: 0, - NumMessages: 4, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go b/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go index 222bc4fa9f..cbebc884e8 100644 --- a/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go +++ b/pkg/capabilities/v2/actions/confidentialworkflow/server/client_server_gen.go @@ -21,6 +21,8 @@ var _ = emptypb.Empty{} type ClientCapability interface { Execute(ctx context.Context, metadata capabilities.RequestMetadata, input *confidentialworkflow.ConfidentialWorkflowRequest) (*capabilities.ResponseAndMetadata[*confidentialworkflow.ConfidentialWorkflowResponse], caperrors.Error) + ProvidedTees(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty) (*capabilities.ResponseAndMetadata[*confidentialworkflow.ProvidedTeesResponse], caperrors.Error) + Start(ctx context.Context) error Close() error HealthReport() map[string]error @@ -137,6 +139,20 @@ func (c *clientCapability) Execute(ctx context.Context, request capabilities.Cap return output.Response, output.ResponseMetadata, output.OCRAttestation, err } return capabilities.Execute(ctx, request, input, config, wrapped) + case "ProvidedTees": + input := &emptypb.Empty{} + config := &emptypb.Empty{} + wrapped := func(ctx context.Context, metadata capabilities.RequestMetadata, input *emptypb.Empty, _ *emptypb.Empty) (*confidentialworkflow.ProvidedTeesResponse, capabilities.ResponseMetadata, *capabilities.OCRAttestation, error) { + output, err := c.ClientCapability.ProvidedTees(ctx, metadata, input) + if err != nil { + return nil, capabilities.ResponseMetadata{}, nil, err + } + if output == nil { + return nil, capabilities.ResponseMetadata{}, nil, fmt.Errorf("output and error is nil for method ProvidedTees(..) (if output is nil error must be present)") + } + return output.Response, output.ResponseMetadata, output.OCRAttestation, err + } + return capabilities.Execute(ctx, request, input, config, wrapped) default: return response, fmt.Errorf("method %s not found", request.Method) } diff --git a/pkg/capabilities/v2/actions/http/client.pb.go b/pkg/capabilities/v2/actions/http/client.pb.go index 6a93e88293..f9a767596b 100644 --- a/pkg/capabilities/v2/actions/http/client.pb.go +++ b/pkg/capabilities/v2/actions/http/client.pb.go @@ -320,9 +320,9 @@ const file_capabilities_networking_http_v1alpha_client_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\x1as\n" + "\x11MultiHeadersEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12H\n" + - "\x05value\x18\x02 \x01(\v22.capabilities.networking.http.v1alpha.HeaderValuesR\x05value:\x028\x012\x98\x01\n" + + "\x05value\x18\x02 \x01(\v22.capabilities.networking.http.v1alpha.HeaderValuesR\x05value:\x028\x012\x9b\x01\n" + "\x06Client\x12l\n" + - "\vSendRequest\x12-.capabilities.networking.http.v1alpha.Request\x1a..capabilities.networking.http.v1alpha.Response\x1a \x82\xb5\x18\x1c\b\x02\x12\x18http-actions@1.0.0-alphab\x06proto3" + "\vSendRequest\x12-.capabilities.networking.http.v1alpha.Request\x1a..capabilities.networking.http.v1alpha.Response\x1a#\x82\xb5\x18\x1f\b\x02\x12\x18http-actions@1.0.0-alpha\"\x01\x01b\x06proto3" var ( file_capabilities_networking_http_v1alpha_client_proto_rawDescOnce sync.Once diff --git a/pkg/capabilities/v2/protoc/pkg/template_generator.go b/pkg/capabilities/v2/protoc/pkg/template_generator.go index eeb674c7f7..1314ad81d7 100644 --- a/pkg/capabilities/v2/protoc/pkg/template_generator.go +++ b/pkg/capabilities/v2/protoc/pkg/template_generator.go @@ -139,9 +139,9 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial if md == nil { return false, nil - } else { - return md.MapToUntypedApi, nil } + + return md.MapToUntypedApi, nil }, "addImport": func(importPath protogen.GoImportPath, ignore string) string { importName := importPath.String() @@ -259,6 +259,18 @@ func (t *TemplateGenerator) runTemplate(name, tmplText string, args any, partial return line } }, + "TeeEnabled": func(s *protogen.Service) (bool, error) { + md, err := getCapabilityMetadata(s) + if err != nil { + return false, err + } + for _, env := range md.AdditionalEnvironments { + if env == generator.AdditionalEnironments_ADDITIONAL_ENVIRONMENTS_TEE { + return true, nil + } + } + return false, nil + }, }).Funcs(t.ExtraFns) // Register partials diff --git a/pkg/workflows/wasm/host/mock_execution_helper_test.go b/pkg/workflows/host/mock_execution_helper_test.go similarity index 100% rename from pkg/workflows/wasm/host/mock_execution_helper_test.go rename to pkg/workflows/host/mock_execution_helper_test.go diff --git a/pkg/workflows/host/mocks/execution_helper.go b/pkg/workflows/host/mocks/execution_helper.go new file mode 100644 index 0000000000..455fde3041 --- /dev/null +++ b/pkg/workflows/host/mocks/execution_helper.go @@ -0,0 +1,396 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + + time "time" + + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + v2 "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" + mock "github.com/stretchr/testify/mock" +) + +// MockExecutionHelper is an autogenerated mock type for the ExecutionHelper type +type MockExecutionHelper struct { + mock.Mock +} + +type MockExecutionHelper_Expecter struct { + mock *mock.Mock +} + +func (_m *MockExecutionHelper) EXPECT() *MockExecutionHelper_Expecter { + return &MockExecutionHelper_Expecter{mock: &_m.Mock} +} + +// CallCapability provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelper) CallCapability(ctx context.Context, request *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for CallCapability") + } + + var r0 *sdk.CapabilityResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.CapabilityRequest) *sdk.CapabilityResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sdk.CapabilityResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.CapabilityRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_CallCapability_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CallCapability' +type MockExecutionHelper_CallCapability_Call struct { + *mock.Call +} + +// CallCapability is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.CapabilityRequest +func (_e *MockExecutionHelper_Expecter) CallCapability(ctx interface{}, request interface{}) *MockExecutionHelper_CallCapability_Call { + return &MockExecutionHelper_CallCapability_Call{Call: _e.mock.On("CallCapability", ctx, request)} +} + +func (_c *MockExecutionHelper_CallCapability_Call) Run(run func(ctx context.Context, request *sdk.CapabilityRequest)) *MockExecutionHelper_CallCapability_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.CapabilityRequest)) + }) + return _c +} + +func (_c *MockExecutionHelper_CallCapability_Call) Return(_a0 *sdk.CapabilityResponse, _a1 error) *MockExecutionHelper_CallCapability_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_CallCapability_Call) RunAndReturn(run func(context.Context, *sdk.CapabilityRequest) (*sdk.CapabilityResponse, error)) *MockExecutionHelper_CallCapability_Call { + _c.Call.Return(run) + return _c +} + +// EmitUserLog provides a mock function with given fields: log +func (_m *MockExecutionHelper) EmitUserLog(log string) error { + ret := _m.Called(log) + + if len(ret) == 0 { + panic("no return value specified for EmitUserLog") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(log) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockExecutionHelper_EmitUserLog_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmitUserLog' +type MockExecutionHelper_EmitUserLog_Call struct { + *mock.Call +} + +// EmitUserLog is a helper method to define mock.On call +// - log string +func (_e *MockExecutionHelper_Expecter) EmitUserLog(log interface{}) *MockExecutionHelper_EmitUserLog_Call { + return &MockExecutionHelper_EmitUserLog_Call{Call: _e.mock.On("EmitUserLog", log)} +} + +func (_c *MockExecutionHelper_EmitUserLog_Call) Run(run func(log string)) *MockExecutionHelper_EmitUserLog_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockExecutionHelper_EmitUserLog_Call) Return(_a0 error) *MockExecutionHelper_EmitUserLog_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_EmitUserLog_Call) RunAndReturn(run func(string) error) *MockExecutionHelper_EmitUserLog_Call { + _c.Call.Return(run) + return _c +} + +// EmitUserMetric provides a mock function with given fields: ctx, metric +func (_m *MockExecutionHelper) EmitUserMetric(ctx context.Context, metric *v2.WorkflowUserMetric) error { + ret := _m.Called(ctx, metric) + + if len(ret) == 0 { + panic("no return value specified for EmitUserMetric") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *v2.WorkflowUserMetric) error); ok { + r0 = rf(ctx, metric) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockExecutionHelper_EmitUserMetric_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EmitUserMetric' +type MockExecutionHelper_EmitUserMetric_Call struct { + *mock.Call +} + +// EmitUserMetric is a helper method to define mock.On call +// - ctx context.Context +// - metric *v2.WorkflowUserMetric +func (_e *MockExecutionHelper_Expecter) EmitUserMetric(ctx interface{}, metric interface{}) *MockExecutionHelper_EmitUserMetric_Call { + return &MockExecutionHelper_EmitUserMetric_Call{Call: _e.mock.On("EmitUserMetric", ctx, metric)} +} + +func (_c *MockExecutionHelper_EmitUserMetric_Call) Run(run func(ctx context.Context, metric *v2.WorkflowUserMetric)) *MockExecutionHelper_EmitUserMetric_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*v2.WorkflowUserMetric)) + }) + return _c +} + +func (_c *MockExecutionHelper_EmitUserMetric_Call) Return(_a0 error) *MockExecutionHelper_EmitUserMetric_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_EmitUserMetric_Call) RunAndReturn(run func(context.Context, *v2.WorkflowUserMetric) error) *MockExecutionHelper_EmitUserMetric_Call { + _c.Call.Return(run) + return _c +} + +// GetDONTime provides a mock function with no fields +func (_m *MockExecutionHelper) GetDONTime() (time.Time, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetDONTime") + } + + var r0 time.Time + var r1 error + if rf, ok := ret.Get(0).(func() (time.Time, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_GetDONTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetDONTime' +type MockExecutionHelper_GetDONTime_Call struct { + *mock.Call +} + +// GetDONTime is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetDONTime() *MockExecutionHelper_GetDONTime_Call { + return &MockExecutionHelper_GetDONTime_Call{Call: _e.mock.On("GetDONTime")} +} + +func (_c *MockExecutionHelper_GetDONTime_Call) Run(run func()) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetDONTime_Call) Return(_a0 time.Time, _a1 error) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_GetDONTime_Call) RunAndReturn(run func() (time.Time, error)) *MockExecutionHelper_GetDONTime_Call { + _c.Call.Return(run) + return _c +} + +// GetNodeTime provides a mock function with no fields +func (_m *MockExecutionHelper) GetNodeTime() time.Time { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetNodeTime") + } + + var r0 time.Time + if rf, ok := ret.Get(0).(func() time.Time); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(time.Time) + } + + return r0 +} + +// MockExecutionHelper_GetNodeTime_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetNodeTime' +type MockExecutionHelper_GetNodeTime_Call struct { + *mock.Call +} + +// GetNodeTime is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetNodeTime() *MockExecutionHelper_GetNodeTime_Call { + return &MockExecutionHelper_GetNodeTime_Call{Call: _e.mock.On("GetNodeTime")} +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) Run(run func()) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) Return(_a0 time.Time) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetNodeTime_Call) RunAndReturn(run func() time.Time) *MockExecutionHelper_GetNodeTime_Call { + _c.Call.Return(run) + return _c +} + +// GetSecrets provides a mock function with given fields: ctx, request +func (_m *MockExecutionHelper) GetSecrets(ctx context.Context, request *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for GetSecrets") + } + + var r0 []*sdk.SecretResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.GetSecretsRequest) []*sdk.SecretResponse); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*sdk.SecretResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.GetSecretsRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockExecutionHelper_GetSecrets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSecrets' +type MockExecutionHelper_GetSecrets_Call struct { + *mock.Call +} + +// GetSecrets is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.GetSecretsRequest +func (_e *MockExecutionHelper_Expecter) GetSecrets(ctx interface{}, request interface{}) *MockExecutionHelper_GetSecrets_Call { + return &MockExecutionHelper_GetSecrets_Call{Call: _e.mock.On("GetSecrets", ctx, request)} +} + +func (_c *MockExecutionHelper_GetSecrets_Call) Run(run func(ctx context.Context, request *sdk.GetSecretsRequest)) *MockExecutionHelper_GetSecrets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.GetSecretsRequest)) + }) + return _c +} + +func (_c *MockExecutionHelper_GetSecrets_Call) Return(_a0 []*sdk.SecretResponse, _a1 error) *MockExecutionHelper_GetSecrets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockExecutionHelper_GetSecrets_Call) RunAndReturn(run func(context.Context, *sdk.GetSecretsRequest) ([]*sdk.SecretResponse, error)) *MockExecutionHelper_GetSecrets_Call { + _c.Call.Return(run) + return _c +} + +// GetWorkflowExecutionID provides a mock function with no fields +func (_m *MockExecutionHelper) GetWorkflowExecutionID() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetWorkflowExecutionID") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockExecutionHelper_GetWorkflowExecutionID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetWorkflowExecutionID' +type MockExecutionHelper_GetWorkflowExecutionID_Call struct { + *mock.Call +} + +// GetWorkflowExecutionID is a helper method to define mock.On call +func (_e *MockExecutionHelper_Expecter) GetWorkflowExecutionID() *MockExecutionHelper_GetWorkflowExecutionID_Call { + return &MockExecutionHelper_GetWorkflowExecutionID_Call{Call: _e.mock.On("GetWorkflowExecutionID")} +} + +func (_c *MockExecutionHelper_GetWorkflowExecutionID_Call) Run(run func()) *MockExecutionHelper_GetWorkflowExecutionID_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockExecutionHelper_GetWorkflowExecutionID_Call) Return(_a0 string) *MockExecutionHelper_GetWorkflowExecutionID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockExecutionHelper_GetWorkflowExecutionID_Call) RunAndReturn(run func() string) *MockExecutionHelper_GetWorkflowExecutionID_Call { + _c.Call.Return(run) + return _c +} + +// NewMockExecutionHelper creates a new instance of MockExecutionHelper. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockExecutionHelper(t interface { + mock.TestingT + Cleanup(func()) +}) *MockExecutionHelper { + mock := &MockExecutionHelper{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/host/mocks/module.go b/pkg/workflows/host/mocks/module.go new file mode 100644 index 0000000000..8576d1238e --- /dev/null +++ b/pkg/workflows/host/mocks/module.go @@ -0,0 +1,207 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + + host "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + mock "github.com/stretchr/testify/mock" +) + +// Module is an autogenerated mock type for the Module type +type Module struct { + mock.Mock +} + +type Module_Expecter struct { + mock *mock.Mock +} + +func (_m *Module) EXPECT() *Module_Expecter { + return &Module_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with no fields +func (_m *Module) Close() { + _m.Called() +} + +// Module_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type Module_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *Module_Expecter) Close() *Module_Close_Call { + return &Module_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *Module_Close_Call) Run(run func()) *Module_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_Close_Call) Return() *Module_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *Module_Close_Call) RunAndReturn(run func()) *Module_Close_Call { + _c.Run(run) + return _c +} + +// Execute provides a mock function with given fields: ctx, request, handler +func (_m *Module) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper) (*sdk.ExecutionResult, error) { + ret := _m.Called(ctx, request, handler) + + if len(ret) == 0 { + panic("no return value specified for Execute") + } + + var r0 *sdk.ExecutionResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)); ok { + return rf(ctx, request, handler) + } + if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) *sdk.ExecutionResult); ok { + r0 = rf(ctx, request, handler) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sdk.ExecutionResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) error); ok { + r1 = rf(ctx, request, handler) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Module_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' +type Module_Execute_Call struct { + *mock.Call +} + +// Execute is a helper method to define mock.On call +// - ctx context.Context +// - request *sdk.ExecuteRequest +// - handler host.ExecutionHelper +func (_e *Module_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *Module_Execute_Call { + return &Module_Execute_Call{Call: _e.mock.On("Execute", ctx, request, handler)} +} + +func (_c *Module_Execute_Call) Run(run func(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper)) *Module_Execute_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*sdk.ExecuteRequest), args[2].(host.ExecutionHelper)) + }) + return _c +} + +func (_c *Module_Execute_Call) Return(_a0 *sdk.ExecutionResult, _a1 error) *Module_Execute_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Module_Execute_Call) RunAndReturn(run func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)) *Module_Execute_Call { + _c.Call.Return(run) + return _c +} + +// IsLegacyDAG provides a mock function with no fields +func (_m *Module) IsLegacyDAG() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsLegacyDAG") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Module_IsLegacyDAG_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsLegacyDAG' +type Module_IsLegacyDAG_Call struct { + *mock.Call +} + +// IsLegacyDAG is a helper method to define mock.On call +func (_e *Module_Expecter) IsLegacyDAG() *Module_IsLegacyDAG_Call { + return &Module_IsLegacyDAG_Call{Call: _e.mock.On("IsLegacyDAG")} +} + +func (_c *Module_IsLegacyDAG_Call) Run(run func()) *Module_IsLegacyDAG_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_IsLegacyDAG_Call) Return(_a0 bool) *Module_IsLegacyDAG_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Module_IsLegacyDAG_Call) RunAndReturn(run func() bool) *Module_IsLegacyDAG_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with no fields +func (_m *Module) Start() { + _m.Called() +} + +// Module_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type Module_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +func (_e *Module_Expecter) Start() *Module_Start_Call { + return &Module_Start_Call{Call: _e.mock.On("Start")} +} + +func (_c *Module_Start_Call) Run(run func()) *Module_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Module_Start_Call) Return() *Module_Start_Call { + _c.Call.Return() + return _c +} + +func (_c *Module_Start_Call) RunAndReturn(run func()) *Module_Start_Call { + _c.Run(run) + return _c +} + +// NewModule creates a new instance of Module. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewModule(t interface { + mock.TestingT + Cleanup(func()) +}) *Module { + mock := &Module{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/workflows/host/module.go b/pkg/workflows/host/module.go new file mode 100644 index 0000000000..f4debbb922 --- /dev/null +++ b/pkg/workflows/host/module.go @@ -0,0 +1,48 @@ +//go:generate go run ./requirements_gen + +package host + +import ( + "context" + "time" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" +) + +type ModuleBase interface { + Start() + Close() + IsLegacyDAG() bool +} + +type Module interface { + ModuleBase + + // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution + Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) +} + +type RequirementEnforcingModule interface { + Module + + // SetRequirements must respect the requirements for the execution until it completes + SetRequirements(executionId string, requirements *sdkpb.Requirements) +} + +// ExecutionHelper Implemented by those running the host, for example the Workflow Engine +type ExecutionHelper interface { + // CallCapability blocking call to the Workflow Engine + CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) + GetSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) + + GetWorkflowExecutionID() string + + GetNodeTime() time.Time + + GetDONTime() (time.Time, error) + + EmitUserLog(log string) error + + EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error +} diff --git a/pkg/workflows/host/requirement_selecting_module.go b/pkg/workflows/host/requirement_selecting_module.go new file mode 100644 index 0000000000..7f83aedf94 --- /dev/null +++ b/pkg/workflows/host/requirement_selecting_module.go @@ -0,0 +1,114 @@ +package host + +import ( + "context" + "fmt" + "sync" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type ModuleAndHandler struct { + Module + RequirementsHandler +} + +// lazyModule wraps a ModuleAndHandler so that Start is called at most once. +type lazyModule struct { + ModuleAndHandler + startOnce sync.Once + started bool +} + +func (l *lazyModule) ensureStarted() { + l.startOnce.Do(func() { + l.Module.Start() + l.started = true + }) +} + +// NewRequirementSelectingModule creates a module that routes trigger executions +// based on subscription requirements. main is prepended as modules[0]; additional +// modules follow. Subscribe always runs on modules[0]. +func NewRequirementSelectingModule(main ModuleAndHandler, additional []ModuleAndHandler) Module { + modules := make([]*lazyModule, 1+len(additional)) + modules[0] = &lazyModule{ModuleAndHandler: main} + for i, a := range additional { + modules[1+i] = &lazyModule{ModuleAndHandler: a} + } + return &requirementSelectingModule{modules: modules} +} + +type triggerInfo struct { + moduleIdx int + requirements *sdk.Requirements +} + +type requirementSelectingModule struct { + modules []*lazyModule + // triggerID → triggerInfo + cache sync.Map +} + +func (r *requirementSelectingModule) Start() { + r.modules[0].ensureStarted() +} + +func (r *requirementSelectingModule) Close() { + for _, m := range r.modules { + if m.started { + m.Close() + } + } +} + +func (r *requirementSelectingModule) IsLegacyDAG() bool { + return r.modules[0].IsLegacyDAG() +} + +func (r *requirementSelectingModule) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + if request.GetTrigger() == nil { + return r.subscribe(ctx, request, handler) + } + return r.trigger(ctx, request, handler) +} + +func (r *requirementSelectingModule) subscribe(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + result, err := r.modules[0].Execute(ctx, request, handler) + if err != nil { + return nil, err + } + + for i, sub := range result.GetTriggerSubscriptions().GetSubscriptions() { + matched := false + for j, m := range r.modules { + if CheckRequirements(ctx, m.RequirementsHandler, sub.Requirements) { + m.ensureStarted() + r.cache.Store(uint64(i), triggerInfo{moduleIdx: j, requirements: sub.Requirements}) + matched = true + break + } + } + if !matched { + return nil, fmt.Errorf("cannot find a runner that can satisfy the requirements for trigger %d", i) + } + } + + return result, nil +} + +func (r *requirementSelectingModule) trigger(ctx context.Context, request *sdk.ExecuteRequest, handler ExecutionHelper) (*sdk.ExecutionResult, error) { + trigger := request.GetTrigger() + if val, cached := r.cache.Load(trigger.Id); cached { + info := val.(triggerInfo) + m := r.modules[info.moduleIdx] + if rem, ok := m.Module.(RequirementEnforcingModule); ok && info.requirements != nil { + rem.SetRequirements(handler.GetWorkflowExecutionID(), info.requirements) + } + + return m.Execute(ctx, request, handler) + } + return r.modules[0].Execute(ctx, request, handler) +} + +var _ Module = &requirementSelectingModule{} diff --git a/pkg/workflows/host/requirement_selecting_module_test.go b/pkg/workflows/host/requirement_selecting_module_test.go new file mode 100644 index 0000000000..1ffd5d94f7 --- /dev/null +++ b/pkg/workflows/host/requirement_selecting_module_test.go @@ -0,0 +1,575 @@ +package host + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +type stubModule struct { + startFn func() + closeFn func() + legacyFn func() bool + executeFn func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) +} + +func (s *stubModule) Start() { s.startFn() } +func (s *stubModule) Close() { s.closeFn() } +func (s *stubModule) IsLegacyDAG() bool { return s.legacyFn() } +func (s *stubModule) Execute(ctx context.Context, req *sdk.ExecuteRequest, h ExecutionHelper) (*sdk.ExecutionResult, error) { + return s.executeFn(ctx, req, h) +} + +type requirementEnforcingStub struct { + *stubModule + setRequirementsFn func(string, *sdk.Requirements) +} + +func (s *requirementEnforcingStub) SetRequirements(executionID string, requirements *sdk.Requirements) { + s.setRequirementsFn(executionID, requirements) +} + +func noop() {} +func noopClose() {} + +func triggerRequest(id uint64) *sdk.ExecuteRequest { + return &sdk.ExecuteRequest{ + Request: &sdk.ExecuteRequest_Trigger{ + Trigger: &sdk.Trigger{Id: id}, + }, + } +} + +func subscribeRequest() *sdk.ExecuteRequest { + return &sdk.ExecuteRequest{ + Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}, + } +} + +func subscribeResult(subs ...*sdk.TriggerSubscription) *sdk.ExecutionResult { + return &sdk.ExecutionResult{ + Result: &sdk.ExecutionResult_TriggerSubscriptions{ + TriggerSubscriptions: &sdk.TriggerSubscriptionRequest{ + Subscriptions: subs, + }, + }, + } +} + +func subWithReqs(reqs *sdk.Requirements) *sdk.TriggerSubscription { + return &sdk.TriggerSubscription{Requirements: reqs} +} + +func TestRequirementSelectingModule_Start(t *testing.T) { + t.Run("starts only main module", func(t *testing.T) { + var mainStarted, additionalStarted bool + main := ModuleAndHandler{Module: &stubModule{startFn: func() { mainStarted = true }}} + add := ModuleAndHandler{Module: &stubModule{startFn: func() { additionalStarted = true }}} + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + assert.True(t, mainStarted) + assert.False(t, additionalStarted) + }) +} + +func TestRequirementSelectingModule_Close(t *testing.T) { + t.Run("closes main and no additional when none started", func(t *testing.T) { + var mainClosed, addClosed bool + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, closeFn: func() { mainClosed = true }, + }} + add := ModuleAndHandler{Module: &stubModule{ + startFn: noop, closeFn: func() { addClosed = true }, + }} + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + m.Close() + + assert.True(t, mainClosed) + assert.False(t, addClosed) + }) + + t.Run("closes main and all started additional modules", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + var mainClosed, add0Closed, add1Closed bool + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + closeFn: func() { mainClosed = true }, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add0 := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: func() { add0Closed = true }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + add1 := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: func() { add1Closed = true }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + m.Close() + + assert.True(t, mainClosed, "main should be closed") + assert.True(t, add0Closed, "started additional should be closed") + assert.False(t, add1Closed, "never-started additional should not be closed") + }) +} + +func TestRequirementSelectingModule_IsLegacyDAG(t *testing.T) { + main := ModuleAndHandler{Module: &stubModule{legacyFn: func() bool { return true }}} + m := NewRequirementSelectingModule(main, nil) + assert.True(t, m.IsLegacyDAG()) +} + +func TestRequirementSelectingModule_Execute(t *testing.T) { + t.Run("trigger with no cached entry goes to main", func(t *testing.T) { + want := &sdk.ExecutionResult{} + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + return want, nil + } + return subscribeResult(), nil + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + got, err := m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("main error on subscribe propagates", func(t *testing.T) { + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return nil, assert.AnError + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + assert.ErrorIs(t, err, assert.AnError) + }) + + t.Run("subscribe with requirements routes trigger to additional", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("subscribe with unmatched requirements returns error", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{startFn: noop}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner that can satisfy the requirements") + }) + + t.Run("subscribe skips non-matching and selects later additional", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add0 := ModuleAndHandler{ + Module: &stubModule{startFn: noop}, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + add1 := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add0, add1}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("additional module started lazily during subscribe", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var addStartCount int32 + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: func() { atomic.AddInt32(&addStartCount, 1) }, + closeFn: noopClose, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + assert.Equal(t, int32(0), atomic.LoadInt32(&addStartCount)) + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + + // Second subscribe does not start additional again (sync.Once). + _, err = m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&addStartCount)) + }) + + t.Run("subscribe with no requirements returns main result", func(t *testing.T) { + want := subscribeResult() + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return want, nil + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + got, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + }) + + t.Run("main module satisfying requirements keeps trigger on main", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + + var mainTriggerCalls int32 + main := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return want, nil + } + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + t.Fatal("additional module should not be called when main satisfies requirements") + return nil, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls), "trigger should run on main") + }) + + t.Run("cached trigger sets requirements before execute", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + want := &sdk.ExecutionResult{} + executionID := "wf-exec-1" + + main := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }}, + } + + var calls []string + var gotReqs *sdk.Requirements + var gotExecutionID string + enforcingAdd := &requirementEnforcingStub{ + stubModule: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + calls = append(calls, "execute") + return want, nil + }, + }, + setRequirementsFn: func(id string, requirements *sdk.Requirements) { + calls = append(calls, "set") + gotExecutionID = id + gotReqs = requirements + }, + } + add := ModuleAndHandler{ + Module: enforcingAdd, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + helper := &MockExecutionHelper{} + helper.On("GetWorkflowExecutionID").Return(executionID).Once() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + got, err := m.Execute(t.Context(), triggerRequest(0), helper) + require.NoError(t, err) + assert.Equal(t, want, got) + assert.Equal(t, []string{"set", "execute"}, calls) + assert.Equal(t, executionID, gotExecutionID) + assert.Same(t, teeReqs, gotReqs) + helper.AssertExpectations(t) + }) +} + +func TestRequirementSelectingModule_TriggerCache(t *testing.T) { + t.Run("cached trigger skips main on subsequent calls", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var mainTriggerCalls int32 + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + } + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls), "cached trigger should skip main") + + _, err = m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls), "cached trigger should skip main on repeat") + }) + + t.Run("trigger not in cache goes to main", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + var mainTriggerCalls int32 + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return &sdk.ExecutionResult{}, nil + } + // subscription 0 has requirements; subscription 1 does not + return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, + closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return &sdk.ExecutionResult{}, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + // trigger 1 has no requirements → goes to main + _, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls)) + }) + + t.Run("different triggers route to different modules", func(t *testing.T) { + // subscription 0: TEE required → additional; subscription 1: no requirements → main + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{ + Item: &sdk.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdk.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO}}, + }}, + }} + var mainTriggerCalls int32 + wantAdditional := &sdk.ExecutionResult{} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, req *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + if req.GetTrigger() != nil { + atomic.AddInt32(&mainTriggerCalls, 1) + return &sdk.ExecutionResult{}, nil + } + return subscribeResult(subWithReqs(teeReqs), subWithReqs(nil)), nil + }, + }} + add := ModuleAndHandler{ + Module: &stubModule{ + startFn: noop, closeFn: noopClose, + executeFn: func(context.Context, *sdk.ExecuteRequest, ExecutionHelper) (*sdk.ExecutionResult, error) { + return wantAdditional, nil + }, + }, + RequirementsHandler: RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }}, + } + + m := NewRequirementSelectingModule(main, []ModuleAndHandler{add}) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.NoError(t, err) + + // trigger 0 has TEE requirements → additional + got, err := m.Execute(t.Context(), triggerRequest(0), nil) + require.NoError(t, err) + assert.Equal(t, wantAdditional, got) + assert.Equal(t, int32(0), atomic.LoadInt32(&mainTriggerCalls)) + + // trigger 1 has no requirements → main + _, err = m.Execute(t.Context(), triggerRequest(1), nil) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&mainTriggerCalls)) + }) + + t.Run("no additional modules when subscribe has requirements returns error", func(t *testing.T) { + teeReqs := &sdk.Requirements{Tee: &sdk.Tee{}} + + main := ModuleAndHandler{Module: &stubModule{ + startFn: noop, + executeFn: func(_ context.Context, _ *sdk.ExecuteRequest, _ ExecutionHelper) (*sdk.ExecutionResult, error) { + return subscribeResult(subWithReqs(teeReqs)), nil + }, + }} + + m := NewRequirementSelectingModule(main, nil) + m.Start() + + _, err := m.Execute(t.Context(), subscribeRequest(), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "cannot find a runner") + }) +} diff --git a/pkg/workflows/host/requirements_gen/main.go b/pkg/workflows/host/requirements_gen/main.go new file mode 100644 index 0000000000..b4b4bd3e6f --- /dev/null +++ b/pkg/workflows/host/requirements_gen/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "bytes" + _ "embed" + "log" + "os" + "reflect" + "text/template" + + "github.com/smartcontractkit/chainlink-common/pkg/utils/codegen" + sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +//go:embed requirements_helper.go.tmpl +var tmplSrc string + +type fieldInfo struct { + Name string + Type string +} + +type templateData struct { + Fields []fieldInfo +} + +func main() { + requirementsType := reflect.TypeOf(sdk.Requirements{}) + + var fields []fieldInfo + for i := 0; i < requirementsType.NumField(); i++ { + f := requirementsType.Field(i) + if !f.IsExported() { + continue + } + fields = append(fields, fieldInfo{ + Name: f.Name, + Type: f.Type.String(), + }) + } + + tmpl, err := template.New("requirements_helper").Parse(tmplSrc) + if err != nil { + log.Fatalf("failed to parse template: %v", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, templateData{Fields: fields}); err != nil { + log.Fatalf("failed to execute template: %v", err) + } + + const outFile = "requirements_helper_gen.go" + settings := codegen.PrettySettings{ + Tool: "requirements_gen", + GoPrettySettings: codegen.GoPrettySettings{ + LocalPrefix: "github.com/smartcontractkit/chainlink-common", + }, + } + + content, err := codegen.PrettyFile(outFile, buf.String(), settings) + if err != nil { + log.Fatalf("failed to format generated code: %v\n%s", err, buf.String()) + } + + if err := os.WriteFile(outFile, []byte(content), 0644); err != nil { + log.Fatalf("failed to write output: %v", err) + } +} diff --git a/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl new file mode 100644 index 0000000000..5ec3a5d9cb --- /dev/null +++ b/pkg/workflows/host/requirements_gen/requirements_helper.go.tmpl @@ -0,0 +1,35 @@ +package host + +import "ctx" + +// RequirementsHandler contains a callback for each public field in sdk.Requirements. +// Each callback receives the field value and returns a list of strings or an error. +type RequirementsHandler struct { +{{- range .Fields}} + {{.Name}} func(context.Context, {{.Type}}) bool +{{- end}} +} + +// CheckRequirements calls each non-nil callback in the handler for the corresponding +// non-nil field in req, returning false if any are false, or if the handler is nil. +// Unknown fields on the proto also result in a false return value. +func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { + if req == nil { + return true + } + + if len(req.ProtoReflect().GetUnknown()) != 0 { + return false + } + +{{range .Fields}} + if req.{{.Name}} != nil { + if handler.{{.Name}} == nil || !handler.{{.Name}}(ctx, req.{{.Name}}) { + return false + } + + } +{{end}} + + return true +} diff --git a/pkg/workflows/host/requirements_helper_gen.go b/pkg/workflows/host/requirements_helper_gen.go new file mode 100644 index 0000000000..85fb0e8b7c --- /dev/null +++ b/pkg/workflows/host/requirements_helper_gen.go @@ -0,0 +1,37 @@ +// Code generated by requirements_gen, DO NOT EDIT. + +package host + +import ( + "context" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +// RequirementsHandler contains a callback for each public field in sdk.Requirements. +// Each callback receives the field value and returns a list of strings or an error. +type RequirementsHandler struct { + Tee func(context.Context, *sdk.Tee) bool +} + +// CheckRequirements calls each non-nil callback in the handler for the corresponding +// non-nil field in req, returning false if any are false, or if the handler is nil. +// Unknown fields on the proto also result in a false return value. +func CheckRequirements(ctx context.Context, handler RequirementsHandler, req *sdk.Requirements) bool { + if req == nil { + return true + } + + if len(req.ProtoReflect().GetUnknown()) != 0 { + return false + } + + if req.Tee != nil { + if handler.Tee == nil || !handler.Tee(ctx, req.Tee) { + return false + } + + } + + return true +} diff --git a/pkg/workflows/host/requirements_helper_gen_test.go b/pkg/workflows/host/requirements_helper_gen_test.go new file mode 100644 index 0000000000..68bb23b3d7 --- /dev/null +++ b/pkg/workflows/host/requirements_helper_gen_test.go @@ -0,0 +1,48 @@ +package host + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func Test_CheckRequirements(t *testing.T) { + t.Parallel() + t.Run("unknown proto fields", func(t *testing.T) { + // Encode a field number (99) unknown to Requirements so proto.Unmarshal + // preserves it as unknown bytes. + b := protowire.AppendTag(nil, 99, protowire.VarintType) + b = protowire.AppendVarint(b, 1) + req := &sdk.Requirements{} + require.NoError(t, proto.Unmarshal(b, req)) + + assert.False(t, CheckRequirements(context.Background(), RequirementsHandler{}, req)) + }) + + t.Run("no fields always passes", func(t *testing.T) { + assert.True(t, CheckRequirements(context.Background(), RequirementsHandler{}, &sdk.Requirements{})) + }) + + t.Run("handler not set returns false", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + assert.False(t, CheckRequirements(context.Background(), RequirementsHandler{}, req)) + }) + + t.Run("handler returns false causes false return value", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + handler := RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return false }} + assert.False(t, CheckRequirements(context.Background(), handler, req)) + }) + + t.Run("handler returns true causes true return value", func(t *testing.T) { + req := &sdk.Requirements{Tee: &sdk.Tee{}} + handler := RequirementsHandler{Tee: func(context.Context, *sdk.Tee) bool { return true }} + assert.True(t, CheckRequirements(context.Background(), handler, req)) + }) +} diff --git a/pkg/workflows/host/tee_provider.go b/pkg/workflows/host/tee_provider.go new file mode 100644 index 0000000000..4b9344be8d --- /dev/null +++ b/pkg/workflows/host/tee_provider.go @@ -0,0 +1,53 @@ +package host + +import sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" + +type teeProvider struct { + sdkpb.TeeType + regions map[string]bool +} + +func NewTeeProvider(tpe sdkpb.TeeType, regions []string) func(tee *sdkpb.Tee) bool { + supportedRegions := map[string]bool{} + for _, region := range regions { + supportedRegions[region] = true + } + return (&teeProvider{TeeType: tpe, regions: supportedRegions}).Provides +} + +func (t *teeProvider) Provides(tee *sdkpb.Tee) bool { + var regions []string + switch teet := tee.Item.(type) { + case *sdkpb.Tee_AnyRegions: + regions = teet.AnyRegions.Regions + case *sdkpb.Tee_TeeTypesAndRegions: + if teet.TeeTypesAndRegions == nil { + return false + } + + found := false + for _, tr := range teet.TeeTypesAndRegions.TeeTypeAndRegions { + if tr.Type == t.TeeType { + found = true + regions = tr.Regions + break + } + } + + if !found { + return false + } + } + + if len(regions) == 0 { + return true + } + + for _, region := range regions { + if t.regions[region] { + return true + } + } + + return false +} diff --git a/pkg/workflows/host/tee_provider_test.go b/pkg/workflows/host/tee_provider_test.go new file mode 100644 index 0000000000..0ca39602ea --- /dev/null +++ b/pkg/workflows/host/tee_provider_test.go @@ -0,0 +1,171 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func TestNewTeeProvider(t *testing.T) { + t.Parallel() + t.Run("matches any", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, p.Provides(tee)) + }) + + t.Run("matches type selection with matching region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(99)}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("does not match different types", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType(99)} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches type and region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("matches type but not region", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches one of multiple requested regions", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, regions: map[string]bool{"eu-west-1": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2", "eu-west-1"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("provider has multiple regions and one matches", func(t *testing.T) { + p := teeProvider{ + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + } + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-east-1"}}, + }, + }, + }} + assert.True(t, p.Provides(tee)) + }) + + t.Run("no matching region across multiple provider regions", func(t *testing.T) { + p := teeProvider{ + TeeType: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + regions: map[string]bool{"us-west-2": true, "us-east-1": true}, + } + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"ap-southeast-1"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("type mismatch ignores region match", func(t *testing.T) { + p := teeProvider{TeeType: sdkpb.TeeType(99), regions: map[string]bool{"us-west-2": true}} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }} + assert.False(t, p.Provides(tee)) + }) + + t.Run("matches any tee", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, provides(tee)) + }) + + t.Run("returns a function that checks regions", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }, + }, + }} + assert.False(t, provides(tee)) + }) + + t.Run("returns false when tee item is nil", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{} + assert.True(t, provides(tee)) + }) + + t.Run("AnyRegions with empty region list returns false", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{}}} + assert.True(t, provides(tee)) + }) + + t.Run("TeeTypesAndRegions with empty region list returns true", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + }, + }}} + assert.True(t, provides(tee)) + }) + + t.Run("TeeTypesAndRegions with nil regions returns false", func(t *testing.T) { + provides := NewTeeProvider(sdkpb.TeeType_TEE_TYPE_AWS_NITRO, []string{"us-west-2"}) + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: nil}, + }, + }}} + assert.True(t, provides(tee)) + }) +} diff --git a/pkg/workflows/host/tee_selection_provider.go b/pkg/workflows/host/tee_selection_provider.go new file mode 100644 index 0000000000..c31938e4db --- /dev/null +++ b/pkg/workflows/host/tee_selection_provider.go @@ -0,0 +1,49 @@ +package host + +import ( + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func NewProviderFromSelection(types []*sdkpb.TeeTypeAndRegions) func(tee *sdkpb.Tee) bool { + if len(types) == 1 { + return NewTeeProvider(types[0].Type, types[0].Regions) + } + + supplies := make(map[sdkpb.TeeType][]string) + for _, t := range types { + supplies[t.Type] = append(supplies[t.Type], t.Regions...) + } + + providers := make(map[sdkpb.TeeType]func(tee *sdkpb.Tee) bool) + for k, v := range supplies { + providers[k] = NewTeeProvider(k, v) + } + + return func(tee *sdkpb.Tee) bool { + switch teet := tee.Item.(type) { + case *sdkpb.Tee_AnyRegions: + for _, provider := range providers { + if provider(tee) { + return true + } + } + + return false + case *sdkpb.Tee_TeeTypesAndRegions: + for _, requestedType := range teet.TeeTypesAndRegions.TeeTypeAndRegions { + provider, ok := providers[requestedType.Type] + if !ok { + continue + } + + if provider(tee) { + return true + } + } + + return false + default: + return false + } + } +} diff --git a/pkg/workflows/host/tee_selection_provider_test.go b/pkg/workflows/host/tee_selection_provider_test.go new file mode 100644 index 0000000000..4cb10d1dbe --- /dev/null +++ b/pkg/workflows/host/tee_selection_provider_test.go @@ -0,0 +1,231 @@ +package host + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func TestNewProviderFromSelection(t *testing.T) { + t.Parallel() + + t.Run("returns false for nil selection", func(t *testing.T) { + provider := NewProviderFromSelection(nil) + assert.False(t, provider(&sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}})) + }) + + t.Run("single type selection delegates to tee provider", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types support any tee", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multiple types merges regions for same type", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"eu-west-1"}}, + }) + + regions := []string{"eu-west-1"} + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: regions, + }}}}} + assert.True(t, provider(tee)) + regions[0] = "us-west-2" + assert.True(t, provider(tee)) + }) + + t.Run("multiple types returns false when requested type is not supplied", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType(999), + Regions: []string{"us-west-2"}, + }}}}} + assert.False(t, provider(tee)) + }) + + t.Run("returns false for unsupported tee shape", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + }}) + assert.False(t, provider(&sdkpb.Tee{})) + }) + + t.Run("multi-type AnyRegions returns false when no provider matches", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + // AnyRegions with a region that doesn't match any provider's regions + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"ap-southeast-1"}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with nil TeeTypesAndRegions returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: nil}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions returns false when all requested types not in providers", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + // Request types that don't exist in providers + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(999), Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(888), Regions: []string{"eu-west-1"}}, + }, + }}} + assert.False(t, provider(tee)) + }) + + t.Run("single type TeeTypesAndRegions with non-matching region returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"eu-west-1"}, + }}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions partial match skips non-providers", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + // Request both AWS_NITRO and an unknown type; AWS_NITRO should match + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("single type returns directly without closure for AnyRegions", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"us-west-2"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("single type returns false for non-matching AnyRegions", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{{ + Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, + Regions: []string{"us-west-2"}, + }}) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with empty TeeTypeAndRegions array returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{}, + }}} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type AnyRegions with empty regions list returns false", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{}}} + assert.False(t, provider(tee)) + }) + + t.Run("multiple types with no regions in first type", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_AnyRegions{AnyRegions: &sdkpb.Regions{Regions: []string{"eu-west-1"}}}} + assert.True(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions with first type not matching then match on second", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(888), Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }, + }}} + assert.True(t, provider(tee)) + }) + + t.Run("unsupported item type in multi-type scenario", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(999), Regions: []string{"eu-west-1"}}, + }) + + tee := &sdkpb.Tee{} + assert.False(t, provider(tee)) + }) + + t.Run("multi-type TeeTypesAndRegions all types not in providers with continue path", func(t *testing.T) { + provider := NewProviderFromSelection([]*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + {Type: sdkpb.TeeType(555), Regions: []string{"us-west-2"}}, + }) + + // Request a type that is never in providers - forces continue on every iteration + tee := &sdkpb.Tee{Item: &sdkpb.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdkpb.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdkpb.TeeTypeAndRegions{ + {Type: sdkpb.TeeType(777), Regions: []string{"us-west-2"}}, + }, + }}} + assert.False(t, provider(tee)) + }) +} diff --git a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go index dfdad8114c..e18a5fcd59 100644 --- a/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go +++ b/pkg/workflows/wasm/host/internal/rawsdk/helpers_wasip1.go @@ -58,6 +58,7 @@ func SendError(err error) { func SendSubscription(subscriptions *sdk.TriggerSubscriptionRequest) { execResult := &sdk.ExecutionResult{Result: &sdk.ExecutionResult_TriggerSubscriptions{TriggerSubscriptions: subscriptions}} sendResponse(BufferToPointerLen(Must(proto.Marshal(execResult)))) + os.Exit(0) } func Now() time.Time { diff --git a/pkg/workflows/wasm/host/mocks/module_v2.go b/pkg/workflows/wasm/host/mocks/module_v2.go index 4c84a3b4ae..dbcb7aa0c1 100644 --- a/pkg/workflows/wasm/host/mocks/module_v2.go +++ b/pkg/workflows/wasm/host/mocks/module_v2.go @@ -1,207 +1,20 @@ -// Code generated by mockery v2.53.3. DO NOT EDIT. - package mocks import ( - context "context" + "github.com/stretchr/testify/mock" - host "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host" - sdk "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" - mock "github.com/stretchr/testify/mock" + hostmocks "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" ) -// ModuleV2 is an autogenerated mock type for the ModuleV2 type -type ModuleV2 struct { - mock.Mock -} - -type ModuleV2_Expecter struct { - mock *mock.Mock -} - -func (_m *ModuleV2) EXPECT() *ModuleV2_Expecter { - return &ModuleV2_Expecter{mock: &_m.Mock} -} - -// Close provides a mock function with no fields -func (_m *ModuleV2) Close() { - _m.Called() -} - -// ModuleV2_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' -type ModuleV2_Close_Call struct { - *mock.Call -} - -// Close is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) Close() *ModuleV2_Close_Call { - return &ModuleV2_Close_Call{Call: _e.mock.On("Close")} -} - -func (_c *ModuleV2_Close_Call) Run(run func()) *ModuleV2_Close_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_Close_Call) Return() *ModuleV2_Close_Call { - _c.Call.Return() - return _c -} - -func (_c *ModuleV2_Close_Call) RunAndReturn(run func()) *ModuleV2_Close_Call { - _c.Run(run) - return _c -} - -// Execute provides a mock function with given fields: ctx, request, handler -func (_m *ModuleV2) Execute(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper) (*sdk.ExecutionResult, error) { - ret := _m.Called(ctx, request, handler) - - if len(ret) == 0 { - panic("no return value specified for Execute") - } - - var r0 *sdk.ExecutionResult - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)); ok { - return rf(ctx, request, handler) - } - if rf, ok := ret.Get(0).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) *sdk.ExecutionResult); ok { - r0 = rf(ctx, request, handler) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*sdk.ExecutionResult) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) error); ok { - r1 = rf(ctx, request, handler) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ModuleV2_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' -type ModuleV2_Execute_Call struct { - *mock.Call -} - -// Execute is a helper method to define mock.On call -// - ctx context.Context -// - request *sdk.ExecuteRequest -// - handler host.ExecutionHelper -func (_e *ModuleV2_Expecter) Execute(ctx interface{}, request interface{}, handler interface{}) *ModuleV2_Execute_Call { - return &ModuleV2_Execute_Call{Call: _e.mock.On("Execute", ctx, request, handler)} -} - -func (_c *ModuleV2_Execute_Call) Run(run func(ctx context.Context, request *sdk.ExecuteRequest, handler host.ExecutionHelper)) *ModuleV2_Execute_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*sdk.ExecuteRequest), args[2].(host.ExecutionHelper)) - }) - return _c -} - -func (_c *ModuleV2_Execute_Call) Return(_a0 *sdk.ExecutionResult, _a1 error) *ModuleV2_Execute_Call { - _c.Call.Return(_a0, _a1) - return _c -} +// ModuleV2 is a backward-compatible alias for hostmocks.Module. +// The ModuleV2 interface now lives in pkg/workflows/host as Module; +// this alias keeps existing consumers compiling without changes. +type ModuleV2 = hostmocks.Module -func (_c *ModuleV2_Execute_Call) RunAndReturn(run func(context.Context, *sdk.ExecuteRequest, host.ExecutionHelper) (*sdk.ExecutionResult, error)) *ModuleV2_Execute_Call { - _c.Call.Return(run) - return _c -} - -// IsLegacyDAG provides a mock function with no fields -func (_m *ModuleV2) IsLegacyDAG() bool { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for IsLegacyDAG") - } - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -// ModuleV2_IsLegacyDAG_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsLegacyDAG' -type ModuleV2_IsLegacyDAG_Call struct { - *mock.Call -} - -// IsLegacyDAG is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) IsLegacyDAG() *ModuleV2_IsLegacyDAG_Call { - return &ModuleV2_IsLegacyDAG_Call{Call: _e.mock.On("IsLegacyDAG")} -} - -func (_c *ModuleV2_IsLegacyDAG_Call) Run(run func()) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_IsLegacyDAG_Call) Return(_a0 bool) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *ModuleV2_IsLegacyDAG_Call) RunAndReturn(run func() bool) *ModuleV2_IsLegacyDAG_Call { - _c.Call.Return(run) - return _c -} - -// Start provides a mock function with no fields -func (_m *ModuleV2) Start() { - _m.Called() -} - -// ModuleV2_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' -type ModuleV2_Start_Call struct { - *mock.Call -} - -// Start is a helper method to define mock.On call -func (_e *ModuleV2_Expecter) Start() *ModuleV2_Start_Call { - return &ModuleV2_Start_Call{Call: _e.mock.On("Start")} -} - -func (_c *ModuleV2_Start_Call) Run(run func()) *ModuleV2_Start_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ModuleV2_Start_Call) Return() *ModuleV2_Start_Call { - _c.Call.Return() - return _c -} - -func (_c *ModuleV2_Start_Call) RunAndReturn(run func()) *ModuleV2_Start_Call { - _c.Run(run) - return _c -} - -// NewModuleV2 creates a new instance of ModuleV2. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. +// NewModuleV2 creates a new instance of ModuleV2 (alias for hostmocks.NewModule). func NewModuleV2(t interface { mock.TestingT Cleanup(func()) }) *ModuleV2 { - mock := &ModuleV2{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock + return hostmocks.NewModule(t) } diff --git a/pkg/workflows/wasm/host/module.go b/pkg/workflows/wasm/host/module.go index 07b42a00eb..b1ad3731a7 100644 --- a/pkg/workflows/wasm/host/module.go +++ b/pkg/workflows/wasm/host/module.go @@ -21,6 +21,8 @@ import ( "github.com/bytecodealliance/wasmtime-go/v28" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host" + "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -31,7 +33,6 @@ import ( wasmdagpb "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/pb" sdkpb "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" "github.com/smartcontractkit/chainlink-protos/cre/go/values" - wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" ) const v2ImportPrefix = "version_v2" @@ -104,11 +105,7 @@ type ModuleConfig struct { Determinism *DeterminismConfig } -type ModuleBase interface { - Start() - Close() - IsLegacyDAG() bool -} +type ModuleBase = host.ModuleBase type ModuleV1 interface { ModuleBase @@ -117,29 +114,9 @@ type ModuleV1 interface { Run(ctx context.Context, request *wasmdagpb.Request) (*wasmdagpb.Response, error) } -type ModuleV2 interface { - ModuleBase - - // V2/"NoDAG" API - request either the list of Trigger Subscriptions or launch workflow execution - Execute(ctx context.Context, request *sdkpb.ExecuteRequest, handler ExecutionHelper) (*sdkpb.ExecutionResult, error) -} - -// ExecutionHelper Implemented by those running the host, for example the Workflow Engine -type ExecutionHelper interface { - // CallCapability blocking call to the Workflow Engine - CallCapability(ctx context.Context, request *sdkpb.CapabilityRequest) (*sdkpb.CapabilityResponse, error) - GetSecrets(ctx context.Context, request *sdkpb.GetSecretsRequest) ([]*sdkpb.SecretResponse, error) +type ModuleV2 = host.Module - GetWorkflowExecutionID() string - - GetNodeTime() time.Time - - GetDONTime() (time.Time, error) - - EmitUserLog(log string) error - - EmitUserMetric(ctx context.Context, metric *wfpb.WorkflowUserMetric) error -} +type ExecutionHelper = host.ExecutionHelper type module struct { engine *wasmtime.Engine diff --git a/pkg/workflows/wasm/host/module_test.go b/pkg/workflows/wasm/host/module_test.go index ff7307ca79..0d0e55ee72 100644 --- a/pkg/workflows/wasm/host/module_test.go +++ b/pkg/workflows/wasm/host/module_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/custmsg" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/utils/matches" @@ -620,7 +622,7 @@ func Test_SdkLabeler(t *testing.T) { // CallAwaitRace validates that every call can be awaited. func Test_CallAwaitRace(t *testing.T) { ctx := t.Context() - mockExecHelper := NewMockExecutionHelper(t) + mockExecHelper := mocks.NewMockExecutionHelper(t) mockExecHelper.EXPECT(). CallCapability(matches.AnyContext, mock.Anything). Return(&sdkpb.CapabilityResponse{}, nil) diff --git a/pkg/workflows/wasm/host/standard_test.go b/pkg/workflows/wasm/host/standard_test.go index 426d343cdd..525d0a1c80 100644 --- a/pkg/workflows/wasm/host/standard_test.go +++ b/pkg/workflows/wasm/host/standard_test.go @@ -25,6 +25,8 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/emptypb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/actionandtrigger" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basicaction" "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" @@ -47,7 +49,7 @@ func init() { func TestStandardConfig(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Some languages call time during initiation of the executable before the main is called. // This would be in unknown mode, which would call Node mode by default. @@ -63,7 +65,7 @@ func TestStandardConfig(t *testing.T) { func TestStandardErrors(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -77,7 +79,7 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { // To ensure the calls are actually async, the mock will block the first call until the second call is made. // The first call sets InputThing to true, the second to false. t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -120,7 +122,7 @@ func TestStandardCapabilityCallsAreAsync(t *testing.T) { func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() }).Maybe() @@ -152,7 +154,7 @@ func TestStandardHostWasmWriteErrorsAreRespected(t *testing.T) { func TestStandardModeSwitch(t *testing.T) { t.Parallel() t.Run("successful mode switch", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Node calls may occur on initialization depending on the language. var donCall bool @@ -192,7 +194,7 @@ func TestStandardModeSwitch(t *testing.T) { }) t.Run("node runtime in don mode", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -217,7 +219,7 @@ func TestStandardModeSwitch(t *testing.T) { }) t.Run("don runtime in node mode", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -252,7 +254,7 @@ func TestStandardModeSwitch(t *testing.T) { func TestStandardLogging(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -272,7 +274,7 @@ func TestStandardMultipleTriggers(t *testing.T) { t.Parallel() m := makeTestModule(t) t.Run("test registration", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -324,7 +326,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("first callback", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -338,7 +340,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("same trigger as first one but different registration", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -351,7 +353,7 @@ func TestStandardMultipleTriggers(t *testing.T) { }) t.Run("different capability callback", func(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -370,7 +372,7 @@ func TestStandardRandom(t *testing.T) { // Test binary executes node mode code conditionally based on the value >= 100 anyId := "Id" - gte100Exec := NewMockExecutionHelper(t) + gte100Exec := mocks.NewMockExecutionHelper(t) gte100Exec.EXPECT().GetWorkflowExecutionID().Return(anyId) gte100Exec.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -400,7 +402,7 @@ func TestStandardRandom(t *testing.T) { value1 := executeWithResult[any](t, m, anyRequest, gte100Exec) t.Run("Same execution id gives the same randoms even if random is called in node mode", func(t *testing.T) { - lt100Exec := NewMockExecutionHelper(t) + lt100Exec := mocks.NewMockExecutionHelper(t) lt100Exec.EXPECT().GetWorkflowExecutionID().Return(anyId) lt100Exec.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -424,7 +426,7 @@ func TestStandardRandom(t *testing.T) { t.Run("Different execution id give different randoms", func(t *testing.T) { require.NoError(t, err) - gte100Exec2 := NewMockExecutionHelper(t) + gte100Exec2 := mocks.NewMockExecutionHelper(t) gte100Exec2.EXPECT().GetWorkflowExecutionID().Return("differentId") gte100Exec2.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -470,7 +472,7 @@ func TestStandardSecrets(t *testing.T) { } func TestStandardSecretsFailInNodeMode(t *testing.T) { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -504,7 +506,7 @@ func TestStandardSecretsFailInNodeMode(t *testing.T) { func TestStandardTimeInterpretation(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") // Inject fixed timestamp: 1577934245000 milliseconds = 2020-01-02T03:04:05Z fixedTime := time.UnixMilli(1577934245000) @@ -523,6 +525,62 @@ func TestStandardTimeInterpretation(t *testing.T) { require.Equal(t, "2020-01-02T03:04:05Z", result) } +func TestStandardTeeRuntime(t *testing.T) { + t.Parallel() + + cfg := defaultNoDAGModCfg(t) + m := makeTestModuleWithConfig(t, cfg) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) + mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") + mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { + return time.Now() + }).Maybe() + + subscribe := &sdk.ExecuteRequest{Request: &sdk.ExecuteRequest_Subscribe{Subscribe: &emptypb.Empty{}}} + actual, err := m.Execute(t.Context(), subscribe, mockExecutionHelper) + require.NoError(t, err) + + payload0, err := anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + }) + require.NoError(t, err) + + payload1, err := anypb.New(&basictrigger.Config{ + Name: "second-trigger", + Number: 200, + }) + require.NoError(t, err) + + expected := &sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: payload0, + Method: "Trigger", + Requirements: &sdk.Requirements{ + Tee: &sdk.Tee{ + Item: &sdk.Tee_TeeTypesAndRegions{ + TeeTypesAndRegions: &sdk.TeeTypesAndRegions{ + TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{ + {Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}, + }, + }, + }, + }, + }, + }, + { + Id: "basic-test-trigger@1.0.0", + Payload: payload1, + Method: "Trigger", + }, + }, + } + + assertProto(t, expected, actual.GetTriggerSubscriptions()) +} + func triggerExecuteRequest(t *testing.T, id uint64, trigger proto.Message) *sdk.ExecuteRequest { wrappedTrigger, err := anypb.New(trigger) require.NoError(t, err) @@ -549,8 +607,12 @@ func runWithBasicTrigger(t *testing.T, executor ExecutionHelper) *sdk.ExecutionR // To re-use a binary, an outer test can create the module and use t.Run to run subtests using that module. // When subtests have their own binaries, those binaries are expected to be nested in a subfolder. func makeTestModule(t *testing.T) *module { + return makeTestModuleWithConfig(t, nil) +} + +func makeTestModuleWithConfig(t *testing.T, cfg *ModuleConfig) *module { testName := strcase.ToSnake(t.Name()[len("TestStandard"):]) - return makeTestModuleByName(t, testName, nil) + return makeTestModuleByName(t, testName, cfg) } func makeTestModuleByName(t *testing.T, testName string, cfg *ModuleConfig) *module { @@ -559,6 +621,7 @@ func makeTestModuleByName(t *testing.T, testName string, cfg *ModuleConfig) *mod absPath, err := filepath.Abs(testPath) require.NoError(t, err, "Failed to get absolute path for test directory") cmd.Dir = absPath + fmt.Printf("Compiling test module from %s with command %s\n:", cmd.Dir, cmd.String()) output, err := cmd.CombinedOutput() require.NoError(t, err, string(output)) @@ -637,6 +700,7 @@ func wrapValue(t *testing.T, nodeResponse *nodeaction.NodeOutputs) *valuespb.Val func assertProto[T proto.Message](t *testing.T, expected, actual T) { t.Helper() + require.NotNil(t, actual) diff := cmp.Diff(expected, actual, protocmp.Transform()) var sb strings.Builder @@ -649,7 +713,7 @@ func assertProto[T proto.Message](t *testing.T, expected, actual T) { } func runSecretTest(t *testing.T, m *module, secretResponse *sdk.SecretResponse) *sdk.ExecutionResult { - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("Id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() diff --git a/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go new file mode 100644 index 0000000000..a6fdfec748 --- /dev/null +++ b/pkg/workflows/wasm/host/standard_tests/tee_runtime/main_wasip1.go @@ -0,0 +1,36 @@ +package main + +import ( + "google.golang.org/protobuf/types/known/anypb" + + "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" + "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" +) + +func main() { + requirements := &sdk.Requirements{Tee: &sdk.Tee{Item: &sdk.Tee_TeeTypesAndRegions{TeeTypesAndRegions: &sdk.TeeTypesAndRegions{TeeTypeAndRegions: []*sdk.TeeTypeAndRegions{{Type: sdk.TeeType_TEE_TYPE_AWS_NITRO, Regions: []string{"us-west-2"}}}}}}} + subscription := &sdk.TriggerSubscriptionRequest{ + Subscriptions: []*sdk.TriggerSubscription{ + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "first-trigger", + Number: 100, + })), + Method: "Trigger", + Requirements: requirements, + }, + { + Id: "basic-test-trigger@1.0.0", + Payload: rawsdk.Must(anypb.New(&basictrigger.Config{ + Name: "second-trigger", + Number: 200, + })), + Method: "Trigger", + }, + }, + } + + rawsdk.SendSubscription(subscription) +} diff --git a/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go b/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go new file mode 100644 index 0000000000..bd31aa32c8 --- /dev/null +++ b/pkg/workflows/wasm/host/test/requirements/invalid_memory/main_wasip1.go @@ -0,0 +1,13 @@ +package main + +import ( + "unsafe" + + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" +) + +func main() { + buf := make([]byte, 4) + rawsdk.Requirements(unsafe.Pointer(&buf[0]), 100) + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go b/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go new file mode 100644 index 0000000000..1df7e1da8f --- /dev/null +++ b/pkg/workflows/wasm/host/test/requirements/invalid_proto/main_wasip1.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/smartcontractkit/chainlink-common/pkg/workflows/wasm/host/internal/rawsdk" +) + +func main() { + rawsdk.Requirements(rawsdk.BufferToPointerLen([]byte{0x3E, 0x80, 0xFF, 0x0A, 0xFF, 0x01, 0x0C, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01})) + rawsdk.SendResponse(0) +} diff --git a/pkg/workflows/wasm/host/time_test.go b/pkg/workflows/wasm/host/time_test.go index b792588c99..7f7139b230 100644 --- a/pkg/workflows/wasm/host/time_test.go +++ b/pkg/workflows/wasm/host/time_test.go @@ -8,13 +8,14 @@ import ( "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" ) func TestTimeFetcher_GetTime_NODE(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) expected := time.Now() mockExec.EXPECT().GetNodeTime().Return(expected) @@ -29,7 +30,7 @@ func TestTimeFetcher_GetTime_NODE(t *testing.T) { func TestTimeFetcher_GetTime_DON(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) expected := time.Now() mockExec.EXPECT().GetDONTime().Return(expected, nil) @@ -44,7 +45,7 @@ func TestTimeFetcher_GetTime_DON(t *testing.T) { func TestTimeFetcher_GetTime_DON_Error(t *testing.T) { ctx := t.Context() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Return(time.Time{}, errors.New("don error")) tf := newTimeFetcher(ctx, mockExec) @@ -58,7 +59,7 @@ func TestTimeFetcher_ContextCancelledBeforeRequest(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) cancel() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Return(time.Time{}, context.Canceled).Maybe() tf := newTimeFetcher(ctx, mockExec) @@ -81,7 +82,7 @@ func TestTimeFetcher_ContextCancelledDuringResponse(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - mockExec := NewMockExecutionHelper(t) + mockExec := mocks.NewMockExecutionHelper(t) mockExec.EXPECT().GetDONTime().Run(func() { time.Sleep(20 * time.Millisecond) // force timeout }).Return(time.Time{}, nil) diff --git a/pkg/workflows/wasm/host/wasm_nodag_test.go b/pkg/workflows/wasm/host/wasm_nodag_test.go index 917692b523..b68fd1b730 100644 --- a/pkg/workflows/wasm/host/wasm_nodag_test.go +++ b/pkg/workflows/wasm/host/wasm_nodag_test.go @@ -12,6 +12,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/capabilities/v2/protoc/pkg/test_capabilities/basictrigger" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/settings/limits" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/host/mocks" "github.com/smartcontractkit/chainlink-protos/cre/go/sdk" wfpb "github.com/smartcontractkit/chainlink-protos/workflows/go/v2" @@ -43,7 +44,7 @@ func Test_Sleep_Timeout(t *testing.T) { m.Start() defer m.Close() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -197,7 +198,7 @@ func Test_NoDag_Run(t *testing.T) { func Test_NoDAG_LoggingWithLimits(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -239,7 +240,7 @@ func Test_NoDAG_LoggingWithLimits(t *testing.T) { func Test_NoDAG_EmitMetricWithLimits(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -285,7 +286,7 @@ func Test_NoDAG_EmitMetricWithLimits(t *testing.T) { func Test_NoDAG_EmitMetricDisabled(t *testing.T) { t.Parallel() - mockExecutionHelper := NewMockExecutionHelper(t) + mockExecutionHelper := mocks.NewMockExecutionHelper(t) mockExecutionHelper.EXPECT().GetWorkflowExecutionID().Return("id") mockExecutionHelper.EXPECT().GetNodeTime().RunAndReturn(func() time.Time { return time.Now() @@ -319,7 +320,7 @@ func defaultNoDAGModCfg(t testing.TB) *ModuleConfig { } func getTriggersSpec(t *testing.T, m ModuleV2, config []byte) (*sdk.TriggerSubscriptionRequest, error) { - helper := NewMockExecutionHelper(t) + helper := mocks.NewMockExecutionHelper(t) helper.EXPECT().GetWorkflowExecutionID().Return("Id") helper.EXPECT().GetNodeTime().Return(time.Now()).Maybe() execResult, err := m.Execute(t.Context(), &sdk.ExecuteRequest{