diff --git a/Makefile b/Makefile index 4826559b..eac6360d 100644 --- a/Makefile +++ b/Makefile @@ -194,3 +194,13 @@ start-docs-website: generate-ca-certificates: scripts/generate-ca-certificates.sh cli/internal/client/ca-certificates.pem + + +start-proxy: install-cli + tyger-proxy start -f <($(MAKE) get-proxy-config) + +run-proxy: install-cli + tyger-proxy run -f <($(MAKE) get-proxy-config) + +kill-proxy: + killall -s SIGINT tyger-proxy diff --git a/Makefile.cloud b/Makefile.cloud index b1c8bcc5..1d8fd173 100644 --- a/Makefile.cloud +++ b/Makefile.cloud @@ -256,12 +256,6 @@ get-proxy-config: $${auth_parameters} " -start-proxy: install-cli - tyger-proxy start -f <($(MAKE) get-proxy-config) - -kill-proxy: - killall tyger-proxy - connect-db: set-context helm_values=$$(helm get values -n ${HELM_NAMESPACE} ${HELM_RELEASE} -o json || true) diff --git a/cli/cmd/tyger-proxy/main.go b/cli/cmd/tyger-proxy/main.go index 3c4b1ee2..bdae0dc1 100644 --- a/cli/cmd/tyger-proxy/main.go +++ b/cli/cmd/tyger-proxy/main.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net/url" "os" "path" "path/filepath" @@ -133,45 +134,71 @@ func readProxyOptions(optionsFilePath string, options *tygerproxy.ProxyOptions) return errors.New("serverUrl must be specified") } - if options.ManagedIdentity { - if options.ServicePrincipal != "" { - return errors.New("servicePrincipal cannot be specified when using managed identity") + parsedUrl, err := url.Parse(options.ServerUrl) + if err != nil { + return fmt.Errorf("invalid serverUrl: %v", err) + } + + if parsedUrl.Scheme == "ssh" || parsedUrl.Scheme == "http+unix" { + if options.ManagedIdentity { + return errors.New("managedIdentity cannot be specified when using SSH or Unix socket connection") } - if options.CertificatePath != "" { - return errors.New("certificatePath cannot be specified when using managed identity") + if options.GitHub { + return errors.New("github cannot be specified when using SSH or Unix socket connection") } - if options.CertificateThumbprint != "" { - return errors.New("certificateThumbprint cannot be specified when using managed identity") - } - } else if options.GitHub { if options.ServicePrincipal != "" { - return errors.New("servicePrincipal cannot be specified when using GitHub authentication") + return errors.New("servicePrincipal cannot be specified when using SSH or Unix socket connection") } if options.CertificatePath != "" { - return errors.New("certificatePath cannot be specified when using GitHub authentication") + return errors.New("certificatePath cannot be specified when using SSH or Unix socket connection") } if options.CertificateThumbprint != "" { - return errors.New("certificateThumbprint cannot be specified when using GitHub authentication") + return errors.New("certificateThumbprint cannot be specified when using SSH or Unix socket connection") } - } else { - if options.ServicePrincipal == "" { - return errors.New("if both managedIdentity and github are both not true, servicePrincipal must be specified in the options file") + if options.TargetFederatedIdentity != "" { + return errors.New("targetFederatedIdentity cannot be specified when using SSH or Unix socket connection") } - - if runtime.GOOS == "windows" { - if options.CertificatePath == "" && options.CertificateThumbprint == "" { - return errors.New("either certificatePath or certificateThumbprint must be specified in the options file") + } else { + if options.ManagedIdentity { + if options.ServicePrincipal != "" { + return errors.New("servicePrincipal cannot be specified when using managed identity") } + if options.CertificatePath != "" { + return errors.New("certificatePath cannot be specified when using managed identity") + } + if options.CertificateThumbprint != "" { + return errors.New("certificateThumbprint cannot be specified when using managed identity") + } + } else if options.GitHub { + if options.ServicePrincipal != "" { + return errors.New("servicePrincipal cannot be specified when using GitHub authentication") + } + if options.CertificatePath != "" { + return errors.New("certificatePath cannot be specified when using GitHub authentication") + } + if options.CertificateThumbprint != "" { + return errors.New("certificateThumbprint cannot be specified when using GitHub authentication") + } + } else { + if options.ServicePrincipal == "" { + return errors.New("if both managedIdentity and github are both not true, servicePrincipal must be specified in the options file") + } + + if runtime.GOOS == "windows" { + if options.CertificatePath == "" && options.CertificateThumbprint == "" { + return errors.New("either certificatePath or certificateThumbprint must be specified in the options file") + } - if options.CertificatePath != "" && options.CertificateThumbprint != "" { - return errors.New("certificatePath and certificateThumbprint cannot both be specified") + if options.CertificatePath != "" && options.CertificateThumbprint != "" { + return errors.New("certificatePath and certificateThumbprint cannot both be specified") + } + } else if options.CertificatePath == "" { + return errors.New("certificatePath must be specified in the options file") } - } else if options.CertificatePath == "" { - return errors.New("certificatePath must be specified in the options file") - } - if options.TargetFederatedIdentity != "" { - return errors.New("targetFederatedIdentity cannot be specified when using service principal authentication") + if options.TargetFederatedIdentity != "" { + return errors.New("targetFederatedIdentity cannot be specified when using service principal authentication") + } } } diff --git a/cli/cmd/tyger-proxy/proxyrun.go b/cli/cmd/tyger-proxy/proxyrun.go index faae40c3..caea60db 100644 --- a/cli/cmd/tyger-proxy/proxyrun.go +++ b/cli/cmd/tyger-proxy/proxyrun.go @@ -4,10 +4,13 @@ package main import ( + "context" "io" "os" + "os/signal" "path" "path/filepath" + "syscall" "github.com/microsoft/tyger/cli/internal/controlplane" "github.com/microsoft/tyger/cli/internal/logging" @@ -66,12 +69,19 @@ func newProxyRunCommand(optionsFilePath *string, options *tygerproxy.ProxyOption log.Info().Str("path", logFile.Name()).Msg("Logging to file") } - client, err := controlplane.Login(cmd.Context(), options.LoginConfig) + // Set up signal handling for graceful shutdown + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + client, serviceMetadata, err := controlplane.Login(ctx, options.LoginConfig) if err != nil { log.Fatal().Err(err).Msg("login failed") } - _, err = tygerproxy.RunProxy(cmd.Context(), client, options, log.Logger) + closeProxy, err := tygerproxy.RunProxy(ctx, client, options, serviceMetadata, log.Logger) if err != nil { if err == tygerproxy.ErrProxyAlreadyRunning { log.Info().Int("port", options.Port).Msg("A proxy is already running at this address.") @@ -83,8 +93,21 @@ func newProxyRunCommand(optionsFilePath *string, options *tygerproxy.ProxyOption log.Info().Int("port", options.Port).Msg(proxyIsListeningMessage) - // wait indefinitely - <-(make(chan any)) + // Wait for shutdown signal + sig := <-sigChan + log.Info().Str("signal", sig.String()).Msg("Received shutdown signal, cleaning up...") + + // Cancel context to trigger cleanup of SSH tunnels and other resources + cancel() + + // Close the proxy + if closeProxy != nil { + if err := closeProxy(); err != nil { + log.Warn().Err(err).Msg("Error closing proxy") + } + } + + log.Info().Msg("Proxy shutdown complete") }, } diff --git a/cli/integrationtest/expected_openapi_spec.yaml b/cli/integrationtest/expected_openapi_spec.yaml index 6d261a9b..dbac499f 100644 --- a/cli/integrationtest/expected_openapi_spec.yaml +++ b/cli/integrationtest/expected_openapi_spec.yaml @@ -1111,6 +1111,12 @@ components: items: type: string nullable: true + storageEndpoints: + type: array + items: + type: string + format: uri + nullable: true additionalProperties: false Socket: type: object diff --git a/cli/integrationtest/httpproxy_test.go b/cli/integrationtest/httpproxy_test.go index 37973618..4fe19fe3 100644 --- a/cli/integrationtest/httpproxy_test.go +++ b/cli/integrationtest/httpproxy_test.go @@ -9,16 +9,23 @@ import ( "fmt" "os" "path" + "path/filepath" "testing" "time" "github.com/microsoft/tyger/cli/internal/controlplane" + "github.com/microsoft/tyger/cli/internal/controlplane/model" "github.com/microsoft/tyger/cli/internal/install" "github.com/stretchr/testify/require" "sigs.k8s.io/yaml" ) -const composeFile = ` +func TestHttpProxy(t *testing.T) { + t.Parallel() + skipIfOnlyFastTests(t) + skipIfUsingUnixSocket(t) + + const composeFile = ` name: http-proxy-test services: @@ -47,12 +54,7 @@ networks: - subnet: 192.168.250.0/24 ` -func TestHttpProxy(t *testing.T) { - t.Parallel() - skipIfOnlyFastTests(t) - skipIfUsingUnixSocket(t) - - s := NewComposeSession(t) + s := NewComposeSession(t, composeFile) defer s.Cleanup() s.CommandSucceeds("create") @@ -162,14 +164,156 @@ func TestHttpProxy(t *testing.T) { s.ShellExecSucceeds("client", fmt.Sprintf("curl --fail --retry 5 http://tyger-proxy:6888/metadata && tyger login http://tyger-proxy:6888 && tyger buffer read %s > /dev/null", bufferId)) } +func TestTygerProxyOverSsh(t *testing.T) { + // Deliberately not parallel because the interactive portion of Login() does not perform retries and + // when running in GitHub actions, sshd may refuse connections if too many are opened simultaneously. + + skipIfOnlyFastTests(t) + skipIfNotUsingSSH(t) + + const composeFile = ` +name: tyger-proxy-ssh-test + +services: + tyger-proxy: + image: mcr.microsoft.com/devcontainers/base:ubuntu + command: ["tyger-proxy", "run", "-f", "/proxy-config.yml"] + network_mode: bridge + + client: + image: mcr.microsoft.com/devcontainers/base:ubuntu + command: ["sleep", "infinity"] + network_mode: bridge + +` + + s := NewComposeSession(t, composeFile) + defer s.Cleanup() + + s.CommandSucceeds("create") + + ssh_host := runCommandSucceeds(t, "docker", "inspect", "-f", "{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}", "tyger-test-ssh") + + c, _ := controlplane.GetClientFromCache() + + tempDir := t.TempDir() + + const containerKeyFile = "/root/id" + + sshUrl := *c.RawControlPlaneUrl + sshUrl.Host = ssh_host + sshUrlQuery := sshUrl.Query() + sshUrlQuery.Set("option[StrictHostKeyChecking]", "no") + sshUrlQuery.Set("option[IdentityFile]", containerKeyFile) + + sshUrl.RawQuery = sshUrlQuery.Encode() + + proxyConfig := controlplane.LoginConfig{ + ServerUrl: sshUrl.String(), + LogPath: "/logs", + } + + b, err := yaml.Marshal(proxyConfig) + require.NoError(t, err) + proxyConfigString := string(b) + proxyConfigFilePath := path.Join(tempDir, "proxy-config.yml") + require.NoError(t, os.WriteFile(proxyConfigFilePath, []byte(proxyConfigString), 0644)) + + tygerPath := runCommandSucceeds(t, "which", "tyger") + tygerProxyPath := runCommandSucceeds(t, "which", "tyger-proxy") + s.CommandSucceeds("cp", tygerPath, "tyger-proxy:/usr/local/bin/tyger") + s.CommandSucceeds("cp", tygerPath, "client:/usr/local/bin/tyger") + s.CommandSucceeds("cp", tygerProxyPath, "tyger-proxy:/usr/local/bin/tyger-proxy") + s.CommandSucceeds("cp", proxyConfigFilePath, "tyger-proxy:/proxy-config.yml") + + localKeyFile := path.Join(tempDir, "id") + runCommandSucceeds(t, "ssh-keygen", "-t", "ed25519", "-f", localKeyFile, "-N", "") + s.CommandSucceeds("cp", localKeyFile, "tyger-proxy:/root/id") + + runCommandSucceeds(t, "ssh-copy-id", "-f", "-i", localKeyFile+".pub", "tygersshhost") + + s.CommandSucceeds("start", "tyger-proxy") + defer func() { + logs := s.CommandSucceeds("logs", "tyger-proxy") + t.Log(logs) + }() + + s.CommandSucceeds("start", "client") + + runSpec := fmt.Sprintf(` +job: + codespec: + image: %s + buffers: + inputs: ["input"] + outputs: ["output"] + command: + - "sh" + - "-c" + - | + set -euo pipefail + inp=$(cat "$INPUT_PIPE") + echo "${inp}: Bonjour" > "$OUTPUT_PIPE" +timeoutSeconds: 600`, BasicImage) + + runSpecPath := filepath.Join(tempDir, "runspec.yaml") + require.NoError(t, os.WriteFile(runSpecPath, []byte(runSpec), 0644)) + + runId := runTygerSucceeds(t, "run", "create", "-f", runSpecPath) + + tygerProxyContainerId := s.CommandSucceeds("ps", "-q", "tyger-proxy") + + proxyIp := runCommandSucceeds(t, "docker", "inspect", "-f", "{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}", tygerProxyContainerId) + + // it could take some time for the proxy to be ready, so retry a few times + for attempt := 1; ; attempt++ { + _, stdErr, err := s.ShellExec("client", fmt.Sprintf("tyger login http://%s:6888", proxyIp)) + if err == nil { + break + } + if attempt >= 5 { + t.Fatalf("tyger login failed after %d attempts: %v\n%s", attempt, err, stdErr) + } + time.Sleep(time.Duration(attempt*2) * time.Second) + } + + run := model.Run{} + runYaml := s.ShellExecSucceeds("client", "tyger run show "+runId) + require.NoError(t, yaml.Unmarshal([]byte(runYaml), &run)) + + inputBufferId := run.Job.Buffers["input"] + outputBufferId := run.Job.Buffers["output"] + + s.ShellExecSucceeds("client", "echo Carl | tyger buffer write "+inputBufferId) + output := s.ShellExecSucceeds("client", "tyger buffer read "+outputBufferId) + require.Equal(t, "Carl: Bonjour", output) + + // repeat with ephemeral buffers + + runId = runTygerSucceeds(t, "run", "create", "-f", runSpecPath, "-b", "input=_", "-b", "output=_") + runYaml = s.ShellExecSucceeds("client", "tyger run show "+runId) + require.NoError(t, yaml.Unmarshal([]byte(runYaml), &run)) + + inputBufferId = run.Job.Buffers["input"] + outputBufferId = run.Job.Buffers["output"] + + _, stderr, err := s.ShellExec("client", fmt.Sprintf("echo Isabelle | tyger buffer write %s --log-level trace", inputBufferId)) + require.NoError(t, err, stderr) + t.Log(stderr) + output, stderr, err = s.ShellExec("client", fmt.Sprintf("tyger buffer read %s --log-level trace", outputBufferId)) + require.NoError(t, err, stderr) + t.Log(stderr) + require.Equal(t, "Isabelle: Bonjour", output) +} + type ComposeSession struct { t *testing.T dir string } -func NewComposeSession(t *testing.T) *ComposeSession { +func NewComposeSession(t *testing.T, composeFileContent string) *ComposeSession { s := &ComposeSession{t: t, dir: t.TempDir()} - require.NoError(t, os.WriteFile(path.Join(s.dir, "/docker-compose.yml"), []byte(composeFile), 0644)) + require.NoError(t, os.WriteFile(path.Join(s.dir, "/docker-compose.yml"), []byte(composeFileContent), 0644)) s.Cleanup() return s } diff --git a/cli/integrationtest/testutils.go b/cli/integrationtest/testutils.go index d3a28c89..fb004a4d 100644 --- a/cli/integrationtest/testutils.go +++ b/cli/integrationtest/testutils.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "github.com/microsoft/tyger/cli/internal/client" "github.com/microsoft/tyger/cli/internal/controlplane" "github.com/microsoft/tyger/cli/internal/controlplane/model" "github.com/microsoft/tyger/cli/internal/install/cloudinstall" @@ -303,6 +304,12 @@ func skipIfNotUsingUnixSocketDirectly(t *testing.T) { } } +func skipIfNotUsingSSH(t *testing.T) { + if c, _ := controlplane.GetClientFromCache(); c.ConnectionType() != client.TygerConnectionTypeSsh { + t.Skip("Skipping test because the control plane is not using SSH") + } +} + func skipIfOnlyFastTests(t *testing.T) { if *runOnlyFastTestsFlag { t.Skip("Skipping test because --fast flag is set") diff --git a/cli/integrationtest/tygerproxy_test.go b/cli/integrationtest/tygerproxy_test.go index 81779e7d..6719ab30 100644 --- a/cli/integrationtest/tygerproxy_test.go +++ b/cli/integrationtest/tygerproxy_test.go @@ -56,13 +56,15 @@ timeoutSeconds: 600`, BasicImage) runId := runTygerSucceeds(t, "run", "create", "--file", runSpecPath) tygerClient, err := controlplane.GetClientFromCache() + serviceMetadata, err := controlplane.GetServiceMetadata(t.Context(), tygerClient.ControlPlaneUrl.String(), tygerClient.ControlPlaneClient.HTTPClient) + require.NoError(err) proxyOptions := tygerproxy.ProxyOptions{} proxyLogBuffer := SyncBuffer{} logger := zerolog.New(&proxyLogBuffer) - closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, logger) + closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, serviceMetadata, logger) require.NoError(err) defer closeProxy() @@ -182,7 +184,9 @@ func TestProxiedRequestsFromAllowedCIDR(t *testing.T) { tygerClient, err := controlplane.GetClientFromCache() require.NoError(err) - closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, logger) + serviceMetadata, err := controlplane.GetServiceMetadata(t.Context(), tygerClient.ControlPlaneUrl.String(), tygerClient.ControlPlaneClient.HTTPClient) + require.NoError(err) + closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, serviceMetadata, logger) defer closeProxy() resp, err := tygerClient.ControlPlaneClient.Get(fmt.Sprintf("http://localhost:%d/metadata", proxyOptions.Port)) require.NoError(err) @@ -203,10 +207,12 @@ func TestProxiedRequestsFromDisallowedAllowedCIDR(t *testing.T) { tygerClient, err := controlplane.GetClientFromCache() require.NoError(err) + serviceMetadata, err := controlplane.GetServiceMetadata(t.Context(), tygerClient.ControlPlaneUrl.String(), tygerClient.ControlPlaneClient.HTTPClient) + require.NoError(err) + proxyLogBuffer := SyncBuffer{} logger := zerolog.New(&proxyLogBuffer) - - closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, logger) + closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, serviceMetadata, logger) defer closeProxy() resp, err := tygerClient.ControlPlaneClient.Get(fmt.Sprintf("http://localhost:%d/runs/1", proxyOptions.Port)) require.NoError(err) @@ -226,6 +232,9 @@ func TestRunningProxyOnSamePort(t *testing.T) { tygerClient, err := controlplane.GetClientFromCache() require.NoError(err) + serviceMetadata, err := controlplane.GetServiceMetadata(t.Context(), tygerClient.ControlPlaneUrl.String(), tygerClient.ControlPlaneClient.HTTPClient) + require.NoError(err) + proxyOptions := tygerproxy.ProxyOptions{ LoginConfig: controlplane.LoginConfig{ ServerUrl: tygerClient.ControlPlaneUrl.String(), @@ -234,11 +243,11 @@ func TestRunningProxyOnSamePort(t *testing.T) { proxyLogBuffer := SyncBuffer{} logger := zerolog.New(&proxyLogBuffer) - closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, logger) + closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, serviceMetadata, logger) require.NoError(err) defer closeProxy() - _, err = tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, logger) + _, err = tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, serviceMetadata, logger) require.ErrorIs(err, tygerproxy.ErrProxyAlreadyRunning) } @@ -250,6 +259,9 @@ func TestRunningProxyOnSamePortDifferentTarget(t *testing.T) { tygerClient, err := controlplane.GetClientFromCache() require.NoError(err) + serviceMetadata, err := controlplane.GetServiceMetadata(t.Context(), tygerClient.ControlPlaneUrl.String(), tygerClient.ControlPlaneClient.HTTPClient) + require.NoError(err) + proxyOptions := tygerproxy.ProxyOptions{ LoginConfig: controlplane.LoginConfig{ ServerUrl: tygerClient.ControlPlaneUrl.String(), @@ -258,14 +270,14 @@ func TestRunningProxyOnSamePortDifferentTarget(t *testing.T) { proxyLogBuffer := SyncBuffer{} logger := zerolog.New(&proxyLogBuffer) - closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, logger) + closeProxy, err := tygerproxy.RunProxy(context.Background(), tygerClient, &proxyOptions, serviceMetadata, logger) require.NoError(err) defer closeProxy() secondProxyOptions := proxyOptions secondProxyOptions.LoginConfig.ServerUrl = "http://someotherserver" - _, err = tygerproxy.RunProxy(context.Background(), tygerClient, &secondProxyOptions, logger) + _, err = tygerproxy.RunProxy(context.Background(), tygerClient, &secondProxyOptions, serviceMetadata, logger) require.ErrorIs(err, tygerproxy.ErrProxyAlreadyRunningWrongTarget) } diff --git a/cli/internal/client/sshurl.go b/cli/internal/client/sshurl.go index b9aff6eb..1b8e8566 100644 --- a/cli/internal/client/sshurl.go +++ b/cli/internal/client/sshurl.go @@ -5,6 +5,7 @@ package client import ( "fmt" + "maps" "net/url" "os" "runtime" @@ -23,6 +24,7 @@ type SshParams struct { User string SocketPath string CliPath string + ConfigPath string Options map[string]string } @@ -47,6 +49,11 @@ func ParseSshUrl(u *url.URL) (*SshParams, error) { sp.SocketPath = u.Path if queryParams := u.Query(); len(queryParams) > 0 { + if configPath, ok := queryParams["configPath"]; ok { + sp.ConfigPath = configPath[0] + queryParams.Del("configPath") + } + if cliPath, ok := queryParams["cliPath"]; ok { sp.CliPath = cliPath[0] queryParams.Del("cliPath") @@ -60,7 +67,7 @@ func ParseSshUrl(u *url.URL) (*SshParams, error) { } sp.Options[name] = v[0] } else { - return nil, errors.Errorf("unexpected query parameter: %q. Only 'cliPath' and 'option[]' are suported", k) + return nil, errors.Errorf("unexpected query parameter: %q. Only 'configPath', 'cliPath' and 'option[]' are suported", k) } } } @@ -108,25 +115,34 @@ func (sp *SshParams) FormatCmdLine(add ...string) []string { sshOptions := map[string]string{ "StrictHostKeyChecking": "yes", } - return sp.formatCmdLine(sshOptions, nil, add...) + return sp.formatCmdLine(sshOptions, nil, nil, true, add...) } -func (sp *SshParams) formatCmdLine(sshOptions map[string]string, otherSshArgs []string, cmdArgs ...string) []string { +func (sp *SshParams) formatCmdLine(defaultSshOptions map[string]string, overridingSshOptions map[string]string, otherSshArgs []string, callTyger bool, cmdArgs ...string) []string { args := []string{sp.Host} var combinedSshOptions map[string]string if sp.Options == nil { - combinedSshOptions = sshOptions - } else if sshOptions == nil { - combinedSshOptions = sp.Options + combinedSshOptions = defaultSshOptions + } else if defaultSshOptions == nil { + combinedSshOptions = make(map[string]string) + maps.Copy(combinedSshOptions, sp.Options) + } else { combinedSshOptions = make(map[string]string) - for k, v := range sshOptions { - combinedSshOptions[k] = v - } - for k, v := range sp.Options { - combinedSshOptions[k] = v + maps.Copy(combinedSshOptions, defaultSshOptions) + maps.Copy(combinedSshOptions, sp.Options) + } + + if combinedSshOptions == nil && overridingSshOptions != nil { + combinedSshOptions = make(map[string]string) + } + + if overridingSshOptions != nil { + if combinedSshOptions == nil { + combinedSshOptions = make(map[string]string) } + maps.Copy(combinedSshOptions, overridingSshOptions) } for k, v := range combinedSshOptions { @@ -139,20 +155,26 @@ func (sp *SshParams) formatCmdLine(sshOptions map[string]string, otherSshArgs [] if sp.Port != "" { args = append(args, "-p", sp.Port) } + if sp.ConfigPath != "" { + args = append(args, "-F", sp.ConfigPath) + } args = append(args, otherSshArgs...) - args = append(args, "--") + if callTyger { + args = append(args, "--") - if sp.CliPath != "" { - args = append(args, sp.CliPath) - } else { - args = append(args, "tyger") - } + if sp.CliPath != "" { + args = append(args, sp.CliPath) + } else { + args = append(args, "tyger") + } - args = append(args, "stdio-proxy") + args = append(args, "stdio-proxy") + + args = append(args, cmdArgs...) + } - args = append(args, cmdArgs...) return args } @@ -177,7 +199,25 @@ func (sp *SshParams) FormatLoginArgs(add ...string) []string { } args = append(args, add...) - return sp.formatCmdLine(sshOptions, otherSshArgs, args...) + return sp.formatCmdLine(sshOptions, nil, otherSshArgs, true, args...) +} + +func (sp *SshParams) FormatTunnelArgs(local string) []string { + overridingSshOptions := map[string]string{ + "ControlMaster": "no", + "ControlPath": "none", + "ExitOnForwardFailure": "yes", + "ServerAliveInterval": "15", + "ServerAliveCountMax": "3", + "StrictHostKeyChecking": "yes", + } + + otherSshArgs := []string{ + "-nNT", + "-L", fmt.Sprintf("%s:%s", local, sp.SocketPath), + } + + return sp.formatCmdLine(nil, overridingSshOptions, otherSshArgs, false) } func (sp *SshParams) FormatDataPlaneCmdLine(add ...string) []string { @@ -185,12 +225,13 @@ func (sp *SshParams) FormatDataPlaneCmdLine(add ...string) []string { "StrictHostKeyChecking": "yes", } + sshOverrideOptions := map[string]string{} if runtime.GOOS != "windows" { // create a dedicated control socket for this process - sshOptions["ControlMaster"] = "auto" - sshOptions["ControlPath"] = fmt.Sprintf("/tmp/%s", uuid.New().String()) - sshOptions["ControlPersist"] = "2m" + sshOverrideOptions["ControlMaster"] = "auto" + sshOverrideOptions["ControlPath"] = fmt.Sprintf("/tmp/%s", uuid.New().String()) + sshOverrideOptions["ControlPersist"] = "2m" } - return sp.formatCmdLine(sshOptions, nil, add...) + return sp.formatCmdLine(sshOptions, sshOverrideOptions, nil, true, add...) } diff --git a/cli/internal/cmd/install/accesscontrol.go b/cli/internal/cmd/install/accesscontrol.go index fb63e1c7..6bf54fe9 100644 --- a/cli/internal/cmd/install/accesscontrol.go +++ b/cli/internal/cmd/install/accesscontrol.go @@ -290,7 +290,7 @@ func newAccessControlPrettyPrintCommand() *cobra.Command { } func getAccessControlConfigFromServerUrl(ctx context.Context, serverUrl string) *cloudinstall.AccessControlConfig { - serviceMetadata, err := controlplane.GetServiceMetadata(ctx, serverUrl) + serviceMetadata, err := controlplane.GetServiceMetadata(ctx, serverUrl, nil) if err != nil { log.Fatal().Err(err).Msg("Unable to get service metadata") } diff --git a/cli/internal/cmd/login.go b/cli/internal/cmd/login.go index 298cc924..4fef48c7 100644 --- a/cli/internal/cmd/login.go +++ b/cli/internal/cmd/login.go @@ -57,7 +57,7 @@ Subsequent commands will be performed against this server.`, } options.ServerUrl = controlplane.LocalUrlSentinel - _, err = controlplane.Login(cmd.Context(), options) + _, _, err = controlplane.Login(cmd.Context(), options) return err } @@ -168,7 +168,7 @@ Subsequent commands will be performed against this server.`, options.CertificatePath = filepath.Clean(filepath.Join(filepath.Dir(optionsFilePath), options.CertificatePath)) } - _, err = controlplane.Login(cmd.Context(), options) + _, _, err = controlplane.Login(cmd.Context(), options) return err case 1: if options.ServicePrincipal != "" { @@ -240,7 +240,7 @@ Subsequent commands will be performed against this server.`, } options.ServerUrl = args[0] - _, err := controlplane.Login(cmd.Context(), options) + _, _, err := controlplane.Login(cmd.Context(), options) return err default: return errors.New("too many arguments") diff --git a/cli/internal/controlplane/login.go b/cli/internal/controlplane/login.go index 4985f0b8..d09e9a15 100644 --- a/cli/internal/controlplane/login.go +++ b/cli/internal/controlplane/login.go @@ -146,27 +146,27 @@ func (si *serviceInfo) UnmarshalJSON(data []byte) error { return nil } -func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error) { +func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, *model.ServiceMetadata, error) { if options.ServerUrl == LocalUrlSentinel { optionsClone := options optionsClone.ServerUrl = client.GetDefaultSocketUrl() - c, errUnix := Login(ctx, optionsClone) + c, sm, errUnix := Login(ctx, optionsClone) if errUnix == nil { - return c, nil + return c, sm, nil } optionsClone.ServerUrl = "docker://" - c, errDocker := Login(ctx, optionsClone) + c, sm, errDocker := Login(ctx, optionsClone) if errDocker == nil { - return c, nil + return c, sm, nil } - return nil, errUnix + return nil, sm, errUnix } normalizedServerUrl, err := NormalizeServerUrl(options.ServerUrl) if err != nil { - return nil, err + return nil, nil, err } options.ServerUrl = normalizedServerUrl.String() @@ -189,15 +189,16 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error } if err := client.SetDefaultNetworkClientSettings(&defaultClientOptions); err != nil { - return nil, err + return nil, nil, err } var tygerClient *client.TygerClient + var serviceMetadata *model.ServiceMetadata switch normalizedServerUrl.Scheme { case "docker": dockerParams, err := client.ParseDockerUrl(normalizedServerUrl) if err != nil { - return nil, fmt.Errorf("invalid Docker URL: %w", err) + return nil, nil, fmt.Errorf("invalid Docker URL: %w", err) } loginCommand := exec.CommandContext(ctx, "docker", dockerParams.FormatLoginArgs()...) @@ -205,12 +206,12 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error loginCommand.Stdout = &outb loginCommand.Stderr = &errb if err := loginCommand.Run(); err != nil { - return nil, fmt.Errorf("failed to establish a tyger connection: %w. stderr: %s", err, errb.String()) + return nil, nil, fmt.Errorf("failed to establish a tyger connection: %w. stderr: %s", err, errb.String()) } socketUrl, err := NormalizeServerUrl(outb.String()) if err != nil { - return nil, fmt.Errorf("failed to parse socket URL: %w", err) + return nil, nil, fmt.Errorf("failed to parse socket URL: %w", err) } if socketUrl.Scheme != "http+unix" { @@ -226,14 +227,19 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error controlPlaneClientOptions.CreateTransport = client.MakeCommandTransport(dockerConcurrencyLimit, "docker", dockerParams.FormatCmdLine()...) controlPlaneClient, err := client.NewControlPlaneClient(&controlPlaneClientOptions) if err != nil { - return nil, fmt.Errorf("unable to create control plane client: %w", err) + return nil, nil, fmt.Errorf("unable to create control plane client: %w", err) } dataPlaneClientOptions := controlPlaneClientOptions dataPlaneClientOptions.DisableRetries = true dataPlaneClient, err := client.NewDataPlaneClient(&dataPlaneClientOptions) if err != nil { - return nil, fmt.Errorf("unable to create data plane client: %w", err) + return nil, nil, fmt.Errorf("unable to create data plane client: %w", err) + } + + serviceMetadata, err = GetServiceMetadata(ctx, socketUrl.String(), controlPlaneClient.HTTPClient) + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch service metadata: %w", err) } tygerClient = &client.TygerClient{ @@ -248,16 +254,17 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error case "ssh": sshParams, err := client.ParseSshUrl(normalizedServerUrl) if err != nil { - return nil, fmt.Errorf("invalid ssh URL: %w", err) + return nil, nil, fmt.Errorf("invalid ssh URL: %w", err) } // Give the user a chance accept remote host key if necessary preFlightCommand := exec.CommandContext(ctx, "ssh", sshParams.FormatLoginArgs("--preflight")...) preFlightCommand.Stdin = os.Stdin preFlightCommand.Stdout = os.Stdout - preFlightCommand.Stderr = os.Stderr + var preFlightStderr bytes.Buffer + preFlightCommand.Stderr = io.MultiWriter(os.Stderr, &preFlightStderr) if err := preFlightCommand.Run(); err != nil { - return nil, fmt.Errorf("failed to establish a remote tyger connection: %w", err) + return nil, nil, fmt.Errorf("failed to establish a remote tyger connection: %w. stderr: %s", err, preFlightStderr.String()) } loginCommand := exec.CommandContext(ctx, "ssh", sshParams.FormatLoginArgs()...) @@ -265,12 +272,12 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error loginCommand.Stdout = &outb loginCommand.Stderr = &errb if err := loginCommand.Run(); err != nil { - return nil, fmt.Errorf("failed to establish a remote tyger connection: %w. stderr: %s", err, errb.String()) + return nil, nil, fmt.Errorf("failed to establish a remote tyger connection: %w. stderr: %s", err, errb.String()) } socketUrl, err := NormalizeServerUrl(outb.String()) if err != nil { - return nil, fmt.Errorf("failed to parse socket URL: %w", err) + return nil, nil, fmt.Errorf("failed to parse socket URL: %w", err) } if socketUrl.Scheme != "http+unix" { @@ -286,7 +293,7 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error controlPlaneOptions.CreateTransport = client.MakeCommandTransport(sshConcurrencyLimit, "ssh", sshParams.FormatCmdLine()...) controlPlaneClient, err := client.NewControlPlaneClient(&controlPlaneOptions) if err != nil { - return nil, fmt.Errorf("unable to create control plane client: %w", err) + return nil, nil, fmt.Errorf("unable to create control plane client: %w", err) } dataPlaneOptions := controlPlaneOptions // clone @@ -294,7 +301,12 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error dataPlaneClient, err := client.NewDataPlaneClient(&dataPlaneOptions) if err != nil { - return nil, fmt.Errorf("unable to create data plane client: %w", err) + return nil, nil, fmt.Errorf("unable to create data plane client: %w", err) + } + + serviceMetadata, err = GetServiceMetadata(ctx, socketUrl.String(), controlPlaneClient.HTTPClient) + if err != nil { + return nil, nil, fmt.Errorf("unable to fetch service metadata: %w", err) } tygerClient = &client.TygerClient{ @@ -306,14 +318,15 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error RawControlPlaneUrl: si.parsedServerUrl, RawProxy: si.parsedProxy, } + default: if err := validateServiceInfo(si); err != nil { - return nil, err + return nil, nil, err } - serviceMetadata, err := GetServiceMetadata(ctx, options.ServerUrl) + serviceMetadata, err = GetServiceMetadata(ctx, options.ServerUrl, nil) if err != nil { - return nil, err + return nil, nil, err } // augment with data received from the metadata endpoint @@ -328,7 +341,7 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error } if err := validateServiceInfo(si); err != nil { - return nil, err + return nil, nil, err } if serviceMetadata.Authority != "" { @@ -349,19 +362,19 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error // So we need to extract the client ID from the access token and use that next time. claims := jwt.MapClaims{} if _, _, err := jwt.NewParser().ParseUnverified(authResult.Token, claims); err != nil { - return nil, fmt.Errorf("unable to parse access token: %w", err) + return nil, nil, fmt.Errorf("unable to parse access token: %w", err) } else { var ok bool si.ClientId, ok = claims["appid"].(string) if !ok { - return nil, errors.New("unable to extract client ID from access token; the client is not compatible with this version of the CLI") + return nil, nil, errors.New("unable to extract client ID from access token; the client is not compatible with this version of the CLI") } } } } if err != nil { - return nil, err + return nil, nil, err } si.LastToken = authResult.Token @@ -386,12 +399,12 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error cpClient, err := client.NewControlPlaneClient(&controlPlaneOptions) if err != nil { - return nil, fmt.Errorf("unable to create control plane client: %w", err) + return nil, nil, fmt.Errorf("unable to create control plane client: %w", err) } dpClient, err := client.NewDataPlaneClient(&dataPlaneOptions) if err != nil { - return nil, fmt.Errorf("unable to create data plane client: %w", err) + return nil, nil, fmt.Errorf("unable to create data plane client: %w", err) } tygerClient = &client.TygerClient{ @@ -406,7 +419,7 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error if serviceMetadata.RbacEnabled { if roles, err := tygerClient.GetRoleAssignments(ctx); err == nil && len(roles) == 0 { - return nil, errors.New("the principal does not have any assigned roles") + return nil, nil, errors.New("the principal does not have any assigned roles") } } } @@ -414,11 +427,11 @@ func Login(ctx context.Context, options LoginConfig) (*client.TygerClient, error if options.Persisted { err = si.persist() if err != nil { - return nil, err + return nil, nil, err } } - return tygerClient, nil + return tygerClient, serviceMetadata, nil } func NormalizeServerUrl(serverUrl string) (*url.URL, error) { @@ -820,18 +833,24 @@ func readCachedServiceInfo() (*serviceInfo, error) { return &si, nil } -func GetServiceMetadata(ctx context.Context, serverUrl string) (*model.ServiceMetadata, error) { - // Not using a retryable client because when doing `tyger login --local` we first try to use the unix socket - // before trying the docker gateway and we don't want to wait for retries. +func GetServiceMetadata(ctx context.Context, serverUrl string, httpClient *http.Client) (*model.ServiceMetadata, error) { + if httpClient == nil { + // Not using a retryable client because when doing `tyger login --local` we first try to use the unix socket + // before trying the docker gateway and we don't want to wait for retries. + httpClient = client.DefaultClient.HTTPClient + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s/metadata", serverUrl), nil) if err != nil { return nil, fmt.Errorf("unable to create request: %w", err) } - resp, err := client.DefaultClient.HTTPClient.Do(req) + resp, err := httpClient.Do(req) if err != nil { return nil, err } + defer resp.Body.Close() + serviceMetadata := &model.ServiceMetadata{} if err := json.NewDecoder(resp.Body).Decode(serviceMetadata); err != nil { // Check if the server is older than the client (uses the old `/v1/` path) @@ -840,10 +859,12 @@ func GetServiceMetadata(ctx context.Context, serverUrl string) (*model.ServiceMe return nil, fmt.Errorf("unable to create request: %w", err) } - resp, err := client.DefaultClient.HTTPClient.Do(req) + resp, err := httpClient.Do(req) if err != nil { return nil, err } + defer resp.Body.Close() + respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) diff --git a/cli/internal/controlplane/model/model.go b/cli/internal/controlplane/model/model.go index 537b6f85..ee1ee071 100644 --- a/cli/internal/controlplane/model/model.go +++ b/cli/internal/controlplane/model/model.go @@ -24,6 +24,7 @@ type ServiceMetadata struct { ApiVersions []string `json:"apiVersions,omitempty"` DataPlaneProxy string `json:"dataPlaneProxy,omitempty"` // Only used by tyger-proxy TlsCaCertificates string `json:"tlsCaCertificates,omitempty"` // Only used by tyger-proxy + StorageEndpoints []string `json:"storageEndpoints,omitempty"` } type Buffer struct { diff --git a/cli/internal/dataplane/read.go b/cli/internal/dataplane/read.go index f2eafbab..40beb49a 100644 --- a/cli/internal/dataplane/read.go +++ b/cli/internal/dataplane/read.go @@ -78,7 +78,7 @@ func Read(ctx context.Context, container *Container, outputWriter io.Writer, opt readOptions.connectionType = tygerClient.ConnectionType() readOptions.httpClient = tygerClient.DataPlaneClient.Client if tygerClient.ConnectionType() == client.TygerConnectionTypeSsh && container.Scheme() == "http+unix" && !container.SupportsRelay() { - httpClient, tunnelPool, err := createSshTunnelPoolClient(ctx, tygerClient, container, readOptions.dop) + httpClient, tunnelPool, err := createSshTunnelPoolClientFromContainer(ctx, tygerClient, container, readOptions.dop) if err != nil { return err } diff --git a/cli/internal/dataplane/relayclient.go b/cli/internal/dataplane/relayclient.go index 6e0f1c2b..f8f330a2 100644 --- a/cli/internal/dataplane/relayclient.go +++ b/cli/internal/dataplane/relayclient.go @@ -142,11 +142,13 @@ func relayErrorCodeToErr(errorCode string) error { func pingRelay(ctx context.Context, containerClient *ContainerClient, connectionType client.TygerConnectionType) error { log.Ctx(ctx).Info().Msg("Attempting to connect to relay server...") - headRequest := containerClient.NewRequestWithRelativeUrl(ctx, http.MethodHead, "", nil) + headRequest := containerClient.NewNonRetryableRequestWithRelativeUrl(ctx, http.MethodHead, "", nil) + // don't use retryable client here so that we can do special error handling unknownErrCount := 0 for retryCount := 0; ; retryCount++ { - resp, err := containerClient.Do(headRequest) + containerClient.updateRequestUrl(headRequest) + resp, err := containerClient.innerClient.HTTPClient.Do(headRequest) if err == nil { io.Copy(io.Discard, resp.Body) resp.Body.Close() diff --git a/cli/internal/dataplane/tunnel.go b/cli/internal/dataplane/tunnel.go index 6c466c20..c55ab3e9 100644 --- a/cli/internal/dataplane/tunnel.go +++ b/cli/internal/dataplane/tunnel.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os/exec" + "strconv" "strings" "sync" "time" @@ -26,47 +27,59 @@ import ( // A pool of SSH tunnels for the dataplane. -func createSshTunnelPoolClient(ctx context.Context, tygerClient *client.TygerClient, container *Container, count int) (*retryablehttp.Client, *sshTunnelPool, error) { +func createSshTunnelPoolClientFromContainer(ctx context.Context, tygerClient *client.TygerClient, container *Container, count int) (*retryablehttp.Client, *SshTunnelPool, error) { + return CreateSshTunnelPoolClient(ctx, tygerClient, strings.Split(container.initialAccessUrl.Path, ":")[0], count) +} + +func CreateSshTunnelPoolClient(ctx context.Context, tygerClient *client.TygerClient, socketPath string, count int) (*retryablehttp.Client, *SshTunnelPool, error) { controlPlaneSshParams, err := client.ParseSshUrl(tygerClient.RawControlPlaneUrl) if err != nil { return nil, nil, err } dpSshParams := *controlPlaneSshParams - - dpSshParams.SocketPath = strings.Split(container.initialAccessUrl.Path, ":")[0] + dpSshParams.SocketPath = socketPath tunnelPool := NewSshTunnelPool(ctx, dpSshParams, count) httpClient := client.CloneRetryableClient(tygerClient.DataPlaneClient.Client) httpClient.RequestLogHook = func(_ retryablehttp.Logger, r *http.Request, _ int) { + if relayParam, ok := r.URL.Query()["relay"]; ok && len(relayParam) == 1 && relayParam[0] == "true" { + // don't use the tunnel pool for ephemeral buffers + return + } + r.URL = tunnelPool.GetUrl(r.URL) } return httpClient, tunnelPool, nil } -type sshTunnelPool struct { +type SshTunnelPool struct { socketPath string + sshParams client.SshParams ctx context.Context cancelCtx context.CancelFunc mutex sync.Mutex + wg sync.WaitGroup allTunnels *list.List healthyTunnels []*sshTunnel index int } -func (tp *sshTunnelPool) Close() { +func (tp *SshTunnelPool) Close() { log.Debug().Msg("Closing SSH tunnel pool") tp.cancelCtx() tp.mutex.Lock() - defer tp.mutex.Unlock() for e := tp.allTunnels.Front(); e != nil; e = e.Next() { e.Value.(*sshTunnel).Close() } + tp.mutex.Unlock() + tp.wg.Wait() + log.Debug().Msg("SSH tunnel pool closed") } -func (tp *sshTunnelPool) GetUrl(input *url.URL) *url.URL { +func (tp *SshTunnelPool) GetUrl(input *url.URL) *url.URL { tp.mutex.Lock() if len(tp.healthyTunnels) == 0 || tp.ctx.Err() != nil { tp.mutex.Unlock() @@ -98,7 +111,7 @@ func (tp *sshTunnelPool) GetUrl(input *url.URL) *url.URL { return &outputUrl } -func (tp *sshTunnelPool) watch(ctx context.Context, tunnel *sshTunnel) { +func (tp *SshTunnelPool) watch(ctx context.Context, tunnel *sshTunnel) { active := true healthCheckEndpoint := fmt.Sprintf("http://%s/healthcheck", tunnel.Host) @@ -107,58 +120,101 @@ func (tp *sshTunnelPool) watch(ctx context.Context, tunnel *sshTunnel) { panic(err) } + healthCheckTicker := time.NewTicker(5 * time.Second) + defer healthCheckTicker.Stop() + for { - if ctx.Err() != nil { + select { + case <-ctx.Done(): return - } + case exitErr := <-tunnel.exited: + log.Warn().Str("host", tunnel.Host).Err(exitErr).Msg("SSH tunnel process exited") + tp.removeTunnelFromHealthy(tunnel) + tp.wg.Go(func() { tp.recreateTunnel(ctx) }) + return + case <-healthCheckTicker.C: + _, err := http.DefaultClient.Do(req) + if err == nil { + if !active { + log.Ctx(ctx).Info().Str("host", tunnel.Host).Msg("SSH tunnel is active") + tp.mutex.Lock() + tp.healthyTunnels = append(tp.healthyTunnels, tunnel) + tp.mutex.Unlock() + active = true + } + } else { + if errors.Is(err, ctx.Err()) { + return + } - _, err := http.DefaultClient.Do(req) - if err == nil { - if !active { - log.Ctx(ctx).Info().Str("host", tunnel.Host).Msg("SSH tunnel is active") - tp.mutex.Lock() - tp.healthyTunnels = append(tp.healthyTunnels, tunnel) - tp.mutex.Unlock() - } - } else { - if errors.Is(err, ctx.Err()) { - return - } + if active { + active = false + log.Warn().Str("host", tunnel.Host).Err(err).Msg("SSH tunnel is inactive") + tp.removeTunnelFromHealthy(tunnel) - if active { - active = false - log.Warn().Str("host", tunnel.Host).Err(err).Msg("SSH tunnel is inactive") - tp.mutex.Lock() - for i, t := range tp.healthyTunnels { - if t == tunnel { - if len(tp.healthyTunnels) == 1 { - tp.healthyTunnels = tp.healthyTunnels[:0] - } else if i == len(tp.healthyTunnels)-1 { - tp.healthyTunnels = tp.healthyTunnels[:i] - } else if i == 0 { - tp.healthyTunnels = tp.healthyTunnels[1:] - } else { - tp.healthyTunnels[i] = tp.healthyTunnels[len(tp.healthyTunnels)-1] - tp.healthyTunnels = tp.healthyTunnels[:len(tp.healthyTunnels)-1] - } - break - } + // Close the dead tunnel and attempt to create a new one + tunnel.Close() + tp.wg.Go(func() { tp.recreateTunnel(ctx) }) + return } - tp.mutex.Unlock() } } + } +} - if active { - time.Sleep(1 * time.Second) +func (tp *SshTunnelPool) removeTunnelFromHealthy(tunnel *sshTunnel) { + tp.mutex.Lock() + defer tp.mutex.Unlock() + for i, t := range tp.healthyTunnels { + if t == tunnel { + tp.healthyTunnels = append(tp.healthyTunnels[:i], tp.healthyTunnels[i+1:]...) + break } } } -func NewSshTunnelPool(ctx context.Context, sshParams client.SshParams, count int) *sshTunnelPool { +func (tp *SshTunnelPool) recreateTunnel(ctx context.Context) { + for retryCount := 0; ; retryCount++ { + if ctx.Err() != nil { + return + } + + // Exponential backoff with jitter + backoff := time.Duration(rand.IntnRange(200, 1500)*(1< 30*time.Second { + backoff = 30 * time.Second + } + time.Sleep(backoff) + + tunnel, err := newSshTunnel(ctx, tp, tp.sshParams) + if err != nil { + var evt *zerolog.Event + if retryCount > 5 { + evt = log.Warn() + } else { + evt = log.Debug() + } + + evt.Err(err).Int("retryCount", retryCount).Msg("Failed to recreate tunnel") + continue + } + + log.Info().Str("host", tunnel.Host).Msg("Successfully recreated SSH tunnel") + tp.mutex.Lock() + tp.healthyTunnels = append(tp.healthyTunnels, tunnel) + tp.mutex.Unlock() + + tp.wg.Go(func() { tp.watch(ctx, tunnel) }) + return + } +} + +func NewSshTunnelPool(ctx context.Context, sshParams client.SshParams, count int) *SshTunnelPool { ctx, cancelCtx := context.WithCancel(ctx) - pool := &sshTunnelPool{ + pool := &SshTunnelPool{ socketPath: sshParams.SocketPath, + sshParams: sshParams, ctx: ctx, cancelCtx: cancelCtx, mutex: sync.Mutex{}, @@ -166,7 +222,7 @@ func NewSshTunnelPool(ctx context.Context, sshParams client.SshParams, count int } for range count { - go func() { + pool.wg.Go(func() { for retryCount := 0; ; retryCount++ { if ctx.Err() != nil { return @@ -195,10 +251,10 @@ func NewSshTunnelPool(ctx context.Context, sshParams client.SshParams, count int pool.healthyTunnels = append(pool.healthyTunnels, tunnel) pool.mutex.Unlock() - go pool.watch(ctx, tunnel) + pool.wg.Go(func() { pool.watch(ctx, tunnel) }) return } - }() + }) } return pool @@ -230,28 +286,13 @@ func (t *sshTunnel) healthCheck() error { return err } -func newSshTunnel(ctx context.Context, pool *sshTunnelPool, sshParams client.SshParams) (*sshTunnel, error) { +func newSshTunnel(ctx context.Context, pool *SshTunnelPool, sshParams client.SshParams) (*sshTunnel, error) { port, err := GetFreePort() if err != nil { return nil, err } - args := []string{ - "-nNT", - "-o", "ControlMaster=no", - "-o", "ControlPath=none", - "-o", "ExitOnForwardFailure=yes", - "-L", fmt.Sprintf("%d:%s", port, sshParams.SocketPath), - } - - if sshParams.User != "" { - args = append(args, "-l", sshParams.User) - } - if sshParams.Port != "" { - args = append(args, "-p", sshParams.Port) - } - - args = append(args, sshParams.Host) + args := sshParams.FormatTunnelArgs(strconv.Itoa(port)) cmd := exec.Command("ssh", args...) log.Debug().Int("port", port).Msg("Creating SSH tunnel...") @@ -285,7 +326,7 @@ func newSshTunnel(ctx context.Context, pool *sshTunnelPool, sshParams client.Ssh // ignore the error since we are cleaning up err = nil } - log.Debug().Int("port", port).AnErr("error", err).Bytes("stderr", stdErr.Bytes()).Msg("SSH tunnel closed") + err = fmt.Errorf("ssh tunnel closed: %w: %s", err, stdErr.String()) tunnel.exited <- err close(tunnel.exited) }() diff --git a/cli/internal/dataplane/write.go b/cli/internal/dataplane/write.go index 2eb59fe1..cb3e371d 100644 --- a/cli/internal/dataplane/write.go +++ b/cli/internal/dataplane/write.go @@ -103,7 +103,7 @@ func Write(ctx context.Context, container *Container, inputReader io.Reader, opt writeOptions.httpClient = tygerClient.DataPlaneClient.Client writeOptions.connectionType = tygerClient.ConnectionType() if tygerClient.ConnectionType() == client.TygerConnectionTypeSsh && container.Scheme() == "http+unix" && !container.SupportsRelay() { - httpClient, tunnelPool, err := createSshTunnelPoolClient(ctx, tygerClient, container, writeOptions.dop) + httpClient, tunnelPool, err := createSshTunnelPoolClientFromContainer(ctx, tygerClient, container, writeOptions.dop) if err != nil { return err } diff --git a/cli/internal/tygerproxy/tygerproxy.go b/cli/internal/tygerproxy/tygerproxy.go index 9cb0d294..aec95cf7 100644 --- a/cli/internal/tygerproxy/tygerproxy.go +++ b/cli/internal/tygerproxy/tygerproxy.go @@ -14,7 +14,9 @@ import ( "net" "net/http" "net/url" + "slices" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -22,10 +24,12 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/hashicorp/go-cleanhttp" + "github.com/hashicorp/go-retryablehttp" pool "github.com/libp2p/go-buffer-pool" "github.com/microsoft/tyger/cli/internal/client" "github.com/microsoft/tyger/cli/internal/controlplane" "github.com/microsoft/tyger/cli/internal/controlplane/model" + "github.com/microsoft/tyger/cli/internal/dataplane" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) @@ -40,6 +44,11 @@ type ProxyServiceMetadata struct { LogPath string `json:"logPath,omitempty"` } +const ( + LocalBuffersCapability = "LocalBuffers" + sshTunnelPoolSize = 8 +) + var ( ErrProxyAlreadyRunning = errors.New("the proxy is already running") ErrProxyAlreadyRunningWrongTarget = errors.New("the proxy is already running on the requested port, but targets a different server") @@ -48,17 +57,55 @@ var ( type CloseProxyFunc func() error -func RunProxy(ctx context.Context, tygerClient *client.TygerClient, options *ProxyOptions, logger zerolog.Logger) (CloseProxyFunc, error) { +func RunProxy(ctx context.Context, tygerClient *client.TygerClient, options *ProxyOptions, serviceMetadata *model.ServiceMetadata, logger zerolog.Logger) (CloseProxyFunc, error) { + // Disable retries from the proxy (the client will handle its retries) + tygerClient.ControlPlaneClient.RetryMax = 0 + tygerClient.DataPlaneClient.RetryMax = 0 + controlPlaneTargetUrl := tygerClient.ControlPlaneUrl handler := proxyHandler{ tygerClient: tygerClient, targetControlPlaneUrl: controlPlaneTargetUrl, options: options, + serviceMetadata: serviceMetadata, getCaCertsPemString: sync.OnceValue(func() string { pemBytes, _ := client.GetCaCertPemBytes(options.TlsCaCertificates) return string(pemBytes) }), - nextProxyFunc: tygerClient.DataPlaneClient.Proxy, + requiresDataPlaneConnectTunnel: !slices.Contains(serviceMetadata.Capabilities, LocalBuffersCapability), + nextProxyFunc: tygerClient.DataPlaneClient.Proxy, + } + + var sshTunnelPool *dataplane.SshTunnelPool + if !handler.requiresDataPlaneConnectTunnel { + if len(serviceMetadata.StorageEndpoints) != 1 { + return nil, fmt.Errorf("expected exactly one storage endpoint for local buffer proxying, found %d", len(serviceMetadata.StorageEndpoints)) + } + + storageEndpoint := serviceMetadata.StorageEndpoints[0] + parsedStorageEndpoint, err := url.Parse(storageEndpoint) + if err != nil { + return nil, fmt.Errorf("invalid storage endpoint URL '%s': %w", storageEndpoint, err) + } + if parsedStorageEndpoint.Scheme != "http+unix" { + return nil, fmt.Errorf("unsupported storage endpoint scheme '%s' for local buffer proxying", parsedStorageEndpoint.Scheme) + } + + storageSocketPath := strings.Split(parsedStorageEndpoint.Path, ":")[0] + + var tunnelPoolAwareRetryableHttpClient *retryablehttp.Client + + tunnelPoolAwareRetryableHttpClient, sshTunnelPool, err = dataplane.CreateSshTunnelPoolClient(ctx, tygerClient, storageSocketPath, sshTunnelPoolSize) + if err != nil { + return nil, fmt.Errorf("failed to create SSH tunnel pool for local buffer proxying: %w", err) + } + + tygerClient.DataPlaneClient.Client = tunnelPoolAwareRetryableHttpClient + + go func() { + <-ctx.Done() + sshTunnelPool.Close() + }() } r := chi.NewRouter() @@ -73,11 +120,13 @@ func RunProxy(ctx context.Context, tygerClient *client.TygerClient, options *Pro r.Group(func(r chi.Router) { r.Route("/", func(r chi.Router) { r.Route("/runs/{runId}", func(r chi.Router) { - r.Get("/", handler.forwardControlPlaneRequest) - r.Get("/logs", handler.forwardControlPlaneRequest) + r.Get("/", handler.makeForwardControlPlaneRequestFunc(copyResponse)) + r.Get("/logs", handler.makeForwardControlPlaneRequestFunc(copyResponse)) }) - r.Post("/buffers/{id}/access", handler.forwardControlPlaneRequest) + r.Post("/buffers/{id}/access", handler.makeForwardControlPlaneRequestFunc(handler.processBufferAccessResponse)) r.Get("/metadata", handler.handleMetadataRequest) + r.HandleFunc("/dataplane/*", handler.handleDataPlaneRequest) + r.HandleFunc("/dataplane", handler.handleDataPlaneRequest) }) }) @@ -126,7 +175,13 @@ func RunProxy(ctx context.Context, tygerClient *client.TygerClient, options *Pro } }() - return func() error { return server.Close() }, nil + closeProxy := func() error { + if sshTunnelPool != nil { + sshTunnelPool.Close() + } + return server.Close() + } + return closeProxy, nil } func CheckProxyAlreadyRunning(options *ProxyOptions) (*ProxyServiceMetadata, error) { @@ -162,11 +217,13 @@ func GetExistingProxyMetadata(options *ProxyOptions) *ProxyServiceMetadata { } type proxyHandler struct { - tygerClient *client.TygerClient - targetControlPlaneUrl *url.URL - options *ProxyOptions - getCaCertsPemString func() string - nextProxyFunc func(*http.Request) (*url.URL, error) + tygerClient *client.TygerClient + targetControlPlaneUrl *url.URL + options *ProxyOptions + serviceMetadata *model.ServiceMetadata + getCaCertsPemString func() string + requiresDataPlaneConnectTunnel bool + nextProxyFunc func(*http.Request) (*url.URL, error) } func (h *proxyHandler) handleMetadataRequest(w http.ResponseWriter, r *http.Request) { @@ -178,13 +235,27 @@ func (h *proxyHandler) handleMetadataRequest(w http.ResponseWriter, r *http.Requ dataPlaneProxyUrl.Scheme = "https" } + serviceMetadata := *h.serviceMetadata + if h.requiresDataPlaneConnectTunnel { + serviceMetadata.DataPlaneProxy = dataPlaneProxyUrl.String() + } + + // Remove auth-related fields, so that tyger login + // against the proxy does not require any authentication. + serviceMetadata.ApiAppId = "" + serviceMetadata.ApiAppUri = "" + serviceMetadata.CliAppId = "" + serviceMetadata.CliAppUri = "" + serviceMetadata.Authority = "" + serviceMetadata.Audience = "" + serviceMetadata.RbacEnabled = false + + serviceMetadata.TlsCaCertificates = h.getCaCertsPemString() + metadata := &ProxyServiceMetadata{ - ServiceMetadata: model.ServiceMetadata{ - DataPlaneProxy: dataPlaneProxyUrl.String(), - TlsCaCertificates: h.getCaCertsPemString(), - }, - ServerUrl: h.targetControlPlaneUrl.String(), - LogPath: h.options.LogPath, + ServiceMetadata: serviceMetadata, + ServerUrl: h.targetControlPlaneUrl.String(), + LogPath: h.options.LogPath, } w.WriteHeader(http.StatusOK) @@ -193,28 +264,88 @@ func (h *proxyHandler) handleMetadataRequest(w http.ResponseWriter, r *http.Requ } } -func (h *proxyHandler) forwardControlPlaneRequest(w http.ResponseWriter, r *http.Request) { - proxyReq := r.Clone(r.Context()) - proxyReq.RequestURI = "" // need to clear this since the instance will be used for a new request - proxyReq.URL.Scheme = h.targetControlPlaneUrl.Scheme - proxyReq.URL.Host = h.targetControlPlaneUrl.Host - proxyReq.Host = h.targetControlPlaneUrl.Host - if h.targetControlPlaneUrl.Path != "" { - p := proxyReq.URL.Path - proxyReq.URL.Path = "/" - proxyReq.URL = proxyReq.URL.JoinPath(h.targetControlPlaneUrl.Path, p) +func (h *proxyHandler) makeForwardControlPlaneRequestFunc(responseHandler func(originalRequest *http.Request, w http.ResponseWriter, resp *http.Response)) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + proxyReq := r.Clone(r.Context()) + proxyReq.RequestURI = "" // need to clear this since the instance will be used for a new request + proxyReq.URL.Scheme = h.targetControlPlaneUrl.Scheme + proxyReq.URL.Host = h.targetControlPlaneUrl.Host + proxyReq.Host = h.targetControlPlaneUrl.Host + if h.targetControlPlaneUrl.Path != "" { + p := proxyReq.URL.Path + proxyReq.URL.Path = "/" + proxyReq.URL = proxyReq.URL.JoinPath(h.targetControlPlaneUrl.Path, p) + } + + token, err := h.tygerClient.GetAccessToken(r.Context()) + + if err != nil { + log.Ctx(r.Context()).Error().Err(err).Send() + http.Error(w, "failed to get access token", http.StatusInternalServerError) + return + } + + proxyReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + resp, err := h.tygerClient.ControlPlaneClient.HTTPClient.Transport.RoundTrip(proxyReq) + if err != nil { + log.Ctx(r.Context()).Error().Err(err).Msg("Failed to forward request") + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + + defer resp.Body.Close() + + copyHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + responseHandler(r, w, resp) } +} - token, err := h.tygerClient.GetAccessToken(r.Context()) +// This is only used when we are connected to the downstream service over SSH. +func (h *proxyHandler) handleDataPlaneRequest(w http.ResponseWriter, r *http.Request) { + if h.requiresDataPlaneConnectTunnel { + h.handleUnsupportedRequest(w, r) + return + } + originalParam := r.URL.Query().Get("original") + if originalParam == "" { + http.Error(w, "Missing 'original' query parameter", http.StatusBadRequest) + return + } + originalUrl, err := url.Parse(originalParam) if err != nil { - log.Ctx(r.Context()).Error().Err(err).Send() - http.Error(w, "failed to get access token", http.StatusInternalServerError) + http.Error(w, "Invalid 'original' query parameter", http.StatusBadRequest) return } - proxyReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - resp, err := h.tygerClient.ControlPlaneClient.HTTPClient.Transport.RoundTrip(proxyReq) + // Get the subpath after /dataplane + subpath := strings.TrimPrefix(r.URL.Path, "/dataplane") + if subpath != "" { + originalUrl.Path = originalUrl.Path + subpath + } + + // Copy query parameters from the incoming request (except 'original') + query := originalUrl.Query() + for key, values := range r.URL.Query() { + if key != "original" { + query[key] = values + } + } + originalUrl.RawQuery = query.Encode() + + proxyReq := r.Clone(r.Context()) + proxyReq.RequestURI = "" // need to clear this since the instance will be used for a new request + proxyReq.URL = originalUrl + + retryableProxyRequest, err := retryablehttp.FromRequest(proxyReq) + if err != nil { + log.Ctx(r.Context()).Error().Err(err).Msg("Failed to create proxy request") + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + + resp, err := h.tygerClient.DataPlaneClient.Do(retryableProxyRequest) if err != nil { log.Ctx(r.Context()).Error().Err(err).Msg("Failed to forward request") http.Error(w, "Bad Gateway", http.StatusBadGateway) @@ -223,10 +354,7 @@ func (h *proxyHandler) forwardControlPlaneRequest(w http.ResponseWriter, r *http copyHeaders(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) - - if err := copyResponse(w, resp); err != nil { - log.Ctx(r.Context()).Error().Err(err).Msg("error copying response") - } + copyResponse(r, w, resp) } func (h *proxyHandler) handleUnsupportedRequest(w http.ResponseWriter, r *http.Request) { @@ -375,6 +503,51 @@ func (h *proxyHandler) handleTunnelRequest(w http.ResponseWriter, r *http.Reques }() } +func (h *proxyHandler) processBufferAccessResponse(originalRequest *http.Request, w http.ResponseWriter, resp *http.Response) { + if h.requiresDataPlaneConnectTunnel || resp.StatusCode != http.StatusCreated { + copyResponse(originalRequest, w, resp) + return + } + + // Local buffer proxying + defer resp.Body.Close() + var accessInfo model.BufferAccess + if err := json.NewDecoder(resp.Body).Decode(&accessInfo); err != nil { + log.Error().Err(err).Msg("Failed to decode buffer access info") + http.Error(w, "Bad Gateway", http.StatusBadGateway) + return + } + + dataPlaneUrl := url.URL{ + Host: originalRequest.Host, + Path: "/dataplane", + } + if originalRequest.TLS == nil { + dataPlaneUrl.Scheme = "http" + } else { + dataPlaneUrl.Scheme = "https" + } + query := dataPlaneUrl.Query() + query.Set("original", accessInfo.Uri) + + if parsedOriginalUrl, err := url.Parse(accessInfo.Uri); err == nil { + // copy the relay parameter so the client knows this is an ephemeral buffer + if relay, ok := parsedOriginalUrl.Query()["relay"]; ok { + query["relay"] = relay + } + } + + dataPlaneUrl.RawQuery = query.Encode() + + accessInfo.Uri = dataPlaneUrl.String() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(accessInfo); err != nil { + log.Ctx(originalRequest.Context()).Error().Err(err).Msg("Failed to write buffer access response") + } +} + // Returns a connection that is created after successfully calling CONNECT // on another proxy server func openTunnel(proxyAddress string, destination *url.URL) (net.Conn, error) { @@ -414,12 +587,15 @@ func transfer(destination io.WriteCloser, source io.ReadCloser, wg *sync.WaitGro io.Copy(destination, source) } -func copyResponse(w http.ResponseWriter, resp *http.Response) error { +func copyResponse(originalRequest *http.Request, w http.ResponseWriter, resp *http.Response) { flusher, ok := w.(http.Flusher) if !ok { // The ResponseWriter doesn't support flushing, fallback to simple copy _, err := io.Copy(w, resp.Body) - return err + if err != nil { + log.Ctx(resp.Request.Context()).Error().Err(err).Msg("Failed to copy response body") + } + return } // Copy with flushing whenever there is data so that a trickle of data does not get buffered @@ -438,9 +614,9 @@ func copyResponse(w http.ResponseWriter, resp *http.Response) error { } if err != nil { if err != io.EOF { - return err + log.Ctx(resp.Request.Context()).Error().Err(err).Msg("Failed to copy response body") } - return nil + return } } } diff --git a/scripts/run-ssh-tests.sh b/scripts/run-ssh-tests.sh index 5f898610..b9f3926d 100755 --- a/scripts/run-ssh-tests.sh +++ b/scripts/run-ssh-tests.sh @@ -3,6 +3,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# This script sets up an SSH-based test environment for testing tyger's SSH functionality. +# It is invoked by the 'variant-test' target in Makefile.docker. +# +# What it does: +# 1. Creates a Docker container running an SSH server (using panubo/sshd image) +# 2. Configures SSH authentication using either existing ssh-agent keys or generates a new key pair +# 3. Copies the tyger CLI and docker CLI into the container for testing +# 4. Adds a temporary SSH host configuration to ~/.ssh/config +# 5. Logs into tyger using the SSH protocol (ssh://) +# 6. Runs the integration tests against this SSH-based connection +# +# Options: +# --start-only Start the SSH container and configure SSH, but don't run tests +# (useful for manual testing or debugging) +# +# The script cleans up after itself by removing the container and SSH config on exit. + set -euo pipefail container_name=tyger-test-ssh @@ -13,10 +30,24 @@ ssh_host=tygersshhost start_marker="# START OF TYGER TESTING SECTION" end_marker="# END OF TYGER TESTING SECTION" +usage() { + echo "Usage: $(basename "$0") [OPTIONS]" + echo "" + echo "Sets up an SSH-based test environment for testing tyger's SSH login functionality." + echo "" + echo "Options:" + echo " --start-only Start the SSH container and configure SSH, but don't run tests" + echo " -h, --help Show this help message and exit" +} + while [[ $# -gt 0 ]]; do key="$1" case $key in + -h | --help) + usage + exit 0 + ;; --start-only) start_only=1 shift @@ -41,14 +72,16 @@ cleanup() { docker rm -f $container_name >/dev/null } -if [[ -z ${start_only:-} ]]; then - trap cleanup SIGINT SIGTERM EXIT -fi - # Copy bind mounts from tyger-local-gateway gateway_bind_mounts=$(docker inspect -f '{{range .Mounts}}{{if eq .Type "bind"}}-v {{.Source}}:{{.Destination}} {{end}}{{end}}' tyger-local-gateway) -docker rm -f $container_name &>/dev/null +# Check if container already exists before removing +container_existed=$(docker inspect -f '{{.State.Running}}' $container_name 2>/dev/null || echo "") +docker rm -f $container_name &>/dev/null || true + +if [[ -z ${start_only:-} ]] && [[ -z $container_existed ]]; then + trap cleanup SIGINT SIGTERM EXIT +fi # shellcheck disable=SC2086 docker create \ @@ -97,7 +130,8 @@ else ssh_connection_port=$ssh_port fi -host_config="$start_marker +host_config=" +$start_marker Host $ssh_host HostName $ssh_connection_host Port $ssh_connection_port @@ -142,5 +176,5 @@ tyger login "ssh://${ssh_host}${tyger_socket_path}?option[StrictHostKeyChecking] tyger login status if [[ -z ${start_only:-} ]]; then - make -f "$(dirname "$0")/../Makefile" integration-test-no-up + make -s -f "$(dirname "$0")/../Makefile" integration-test-no-up fi diff --git a/server/ControlPlane/Buffers/Buffers.cs b/server/ControlPlane/Buffers/Buffers.cs index 214a155f..99ff0c03 100644 --- a/server/ControlPlane/Buffers/Buffers.cs +++ b/server/ControlPlane/Buffers/Buffers.cs @@ -11,6 +11,7 @@ using Tyger.ControlPlane.Json; using Tyger.ControlPlane.Model; using Tyger.ControlPlane.OpenApi; +using Tyger.ControlPlane.ServiceMetadata; using Buffer = Tyger.ControlPlane.Model.Buffer; namespace Tyger.ControlPlane.Buffers; @@ -49,6 +50,7 @@ public static void AddBuffers(this IHostApplicationBuilder builder) { builder.Services.AddSingleton(); builder.Services.AddSingleton(sp => sp.GetRequiredService()); + builder.Services.AddSingleton(sp => sp.GetRequiredService()); builder.Services.AddHostedService(sp => sp.GetRequiredService()); builder.Services.AddHealthChecks().AddCheck("data plane"); } diff --git a/server/ControlPlane/Buffers/LocalStorageBufferProvider.cs b/server/ControlPlane/Buffers/LocalStorageBufferProvider.cs index 76dbd893..49d2b9fc 100644 --- a/server/ControlPlane/Buffers/LocalStorageBufferProvider.cs +++ b/server/ControlPlane/Buffers/LocalStorageBufferProvider.cs @@ -8,11 +8,12 @@ using Tyger.Common.Buffers; using Tyger.ControlPlane.Database; using Tyger.ControlPlane.Model; +using Tyger.ControlPlane.ServiceMetadata; using Buffer = Tyger.ControlPlane.Model.Buffer; namespace Tyger.ControlPlane.Buffers; -public sealed class LocalStorageBufferProvider : IBufferProvider, IHostedService, IHealthCheck, IDisposable +public sealed class LocalStorageBufferProvider : IBufferProvider, IHostedService, IHealthCheck, ICapabilitiesContributor, IDisposable { public const string AccountName = "local"; public const string AccountLocation = "local"; @@ -238,4 +239,6 @@ public async Task TryMarkBufferAsFailed(string id, CancellationToken cancellatio _logger.FailedToMarkBufferAsFailed(e); } } + + public Capabilities GetCapabilities() => Capabilities.LocalBuffers; } diff --git a/server/ControlPlane/Compute/Docker/DockerEphemeralBufferProvider.cs b/server/ControlPlane/Compute/Docker/DockerEphemeralBufferProvider.cs index 2770c54d..7a050edf 100644 --- a/server/ControlPlane/Compute/Docker/DockerEphemeralBufferProvider.cs +++ b/server/ControlPlane/Compute/Docker/DockerEphemeralBufferProvider.cs @@ -77,6 +77,15 @@ public DockerEphemeralBufferProvider(DockerClient client, IOptions? Capabilities { get; init; } public IEnumerable? ApiVersions { get; init; } + public IEnumerable? StorageEndpoints { get; init; } } public record StorageAccount(string Name, string Location, string Endpoint) : ModelBase; diff --git a/server/ControlPlane/ServiceMetadata/ServiceMetadata.cs b/server/ControlPlane/ServiceMetadata/ServiceMetadata.cs index a7e914bc..fc81fc04 100644 --- a/server/ControlPlane/ServiceMetadata/ServiceMetadata.cs +++ b/server/ControlPlane/ServiceMetadata/ServiceMetadata.cs @@ -3,6 +3,7 @@ using Microsoft.Extensions.Options; using Tyger.ControlPlane.AccessControl; +using Tyger.ControlPlane.Buffers; using Tyger.ControlPlane.Versioning; namespace Tyger.ControlPlane.ServiceMetadata; @@ -14,7 +15,7 @@ public static void MapServiceMetadata(this WebApplication app) Model.ServiceMetadata? serviceMetadata = null; app.MapGet( "/metadata", - (IEnumerable contributor, IOptions accessControl) => + (IEnumerable contributor, IOptions accessControl, IBufferProvider bufferProvider) => { if (serviceMetadata is null) { @@ -25,7 +26,8 @@ public static void MapServiceMetadata(this WebApplication app) serviceMetadata = new Model.ServiceMetadata { Capabilities = capabilityStrings, - ApiVersions = apiVersionsSupported + ApiVersions = apiVersionsSupported, + StorageEndpoints = bufferProvider.GetStorageAccounts().Select(sa => new Uri(sa.Endpoint)), }; if (accessControl.Value.Enabled) @@ -58,6 +60,7 @@ public enum Capabilities EphemeralBuffers = 1 << 3, Docker = 1 << 4, Kubernetes = 1 << 5, + LocalBuffers = 1 << 6, } public interface ICapabilitiesContributor