diff --git a/cmd/workon.go b/cmd/workon.go index ef07346..639aec2 100644 --- a/cmd/workon.go +++ b/cmd/workon.go @@ -92,6 +92,8 @@ func runWorkOn(cmd *cobra.Command, target string, step string, execCmd []string) credEnvVars := creds.EnvVars() + ptdRoot := helpers.GetTargetsConfigPath() + // Start proxy if needed (non-fatal) proxyFile := path.Join(internal.DataDir(), "proxy.json") stopProxy, err := kube.StartProxy(cmd.Context(), t, proxyFile) @@ -173,6 +175,7 @@ func runWorkOn(cmd *cobra.Command, target string, step string, execCmd []string) for k, v := range credEnvVars { shellCommand.Env = append(shellCommand.Env, k+"="+v) } + shellCommand.Env = append(shellCommand.Env, "PTD_ROOT="+ptdRoot) if kubeconfigPath != "" { shellCommand.Env = append(shellCommand.Env, "KUBECONFIG="+kubeconfigPath) } @@ -210,6 +213,7 @@ func runWorkOn(cmd *cobra.Command, target string, step string, execCmd []string) for k, v := range credEnvVars { shellCommand.Env = append(shellCommand.Env, k+"="+v) } + shellCommand.Env = append(shellCommand.Env, "PTD_ROOT="+ptdRoot) if kubeconfigPath != "" { shellCommand.Env = append(shellCommand.Env, "KUBECONFIG="+kubeconfigPath) } diff --git a/lib/azure/proxy.go b/lib/azure/proxy.go index a39ba22..09bc01e 100644 --- a/lib/azure/proxy.go +++ b/lib/azure/proxy.go @@ -7,7 +7,6 @@ import ( "os" "os/exec" "os/signal" - "regexp" "time" "github.com/posit-dev/ptd/lib/helpers" @@ -22,6 +21,7 @@ type ProxySession struct { tunnelCommand *exec.Cmd socksCommand *exec.Cmd localPort string + sshKeyPath string // temp file for bastion SSH key, cleaned up on Stop runningProxy *proxy.RunningProxy isReused bool // indicates if the session is reused from an existing running proxy @@ -88,16 +88,10 @@ func (p *ProxySession) Start(ctx context.Context) error { return err } - bastionName, err := p.target.BastionName(ctx) - - if err != nil { - slog.Error("Error getting bastion name", "error", err) - } - - jumpBoxId, err := p.target.JumpBoxId(ctx) - + bastionInfo, err := p.target.BastionInfo(ctx) if err != nil { - slog.Error("Error getting jump box ID", "error", err) + slog.Error("Error getting bastion info", "error", err) + return fmt.Errorf("failed to get bastion info: %w", err) } // Determine which resource group to use for the bastion tunnel @@ -114,25 +108,32 @@ func (p *ProxySession) Start(ctx context.Context) error { return fmt.Errorf("Resource Group name is empty, cannot continue.") } - // HACK: at the moment, the ssh key is written to a path and named based on the bastion name. - // This is a temporary workaround to remove the "-host" suffix from the bastion name, since that isn't in the key name - r := regexp.MustCompile(`-host.*`) - bastionSshKeyName := r.ReplaceAllString(bastionName, "") + // Write the SSH private key from Pulumi state to a temp file (cleaned up in Stop) + sshKeyFile, err := os.CreateTemp("", "ptd-bastion-ssh-*") + if err != nil { + return fmt.Errorf("failed to create temp file for SSH key: %w", err) + } + if _, err := sshKeyFile.WriteString(bastionInfo.SSHPrivateKey); err != nil { + sshKeyFile.Close() + os.Remove(sshKeyFile.Name()) + return fmt.Errorf("failed to write SSH key: %w", err) + } + sshKeyFile.Close() + p.sshKeyPath = sshKeyFile.Name() // build the command to start the bastion tunnel, this will connect jumpbox:22 to localhost:22001 (enabling SSH connection via separate command) p.tunnelCommand = exec.CommandContext( ctx, p.azCliPath, "network", "bastion", "tunnel", - "--name", bastionName, + "--name", bastionInfo.Name, "--resource-group", resourceGroupName, - "--target-resource-id", jumpBoxId, + "--target-resource-id", bastionInfo.JumpBoxID, "--resource-port", "22", "--port", "22001", ) // build the command to start the SOCKS proxy via SSH, using the jumpbox tunnel from above - // ssh -ND 1080 ptd-admin@localhost -p 22001 -i ~/.ssh/bas-ptd-madrigal01-production-bastion p.socksCommand = exec.CommandContext( ctx, "ssh", @@ -141,7 +142,7 @@ func (p *ProxySession) Start(ctx context.Context) error { "-p", "22001", "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", - "-i", fmt.Sprintf("%s/.ssh/%s", os.Getenv("HOME"), bastionSshKeyName)) + "-i", p.sshKeyPath) // set the environment variables for the command // add each az env var to command @@ -150,7 +151,7 @@ func (p *ProxySession) Start(ctx context.Context) error { p.socksCommand.Env = append(p.socksCommand.Env, fmt.Sprintf("%s=%s", k, v)) } - slog.Debug("Starting Azure bastion tunnel", "bastion_name", bastionName, "resource_group", resourceGroupName, "tunnel_port", "22001", "target_port", "22") + slog.Debug("Starting Azure bastion tunnel", "bastion_name", bastionInfo.Name, "resource_group", resourceGroupName, "tunnel_port", "22001", "target_port", "22") if ctx.Value("verbose") != nil && ctx.Value("verbose").(bool) { slog.Debug("Verbose turned on, attaching command output to stdout and stderr") p.tunnelCommand.Stdout = os.Stdout @@ -201,6 +202,10 @@ func (p *ProxySession) Start(ctx context.Context) error { } func (p *ProxySession) Stop() error { + if p.sshKeyPath != "" { + os.Remove(p.sshKeyPath) + } + if p.isReused { slog.Debug("Proxy session was reused, not stopping", "target", p.target.Name(), "local_port", p.localPort) return nil diff --git a/lib/azure/target.go b/lib/azure/target.go index 9be7159..a89c72e 100644 --- a/lib/azure/target.go +++ b/lib/azure/target.go @@ -166,10 +166,18 @@ func (t Target) fullPulumiEnvVars(ctx context.Context) (map[string]string, error return creds.EnvVars(), nil } -func (t Target) BastionName(ctx context.Context) (string, error) { +// BastionInfo holds the bastion connection details from the persistent stack. +type BastionInfo struct { + Name string + JumpBoxID string + SSHPrivateKey string +} + +// BastionInfo retrieves bastion connection details from the persistent stack outputs. +func (t Target) BastionInfo(ctx context.Context) (*BastionInfo, error) { envVars, err := t.fullPulumiEnvVars(ctx) if err != nil { - return "", err + return nil, err } persistentStack, err := pulumi.NewPythonPulumiStack( @@ -185,57 +193,35 @@ func (t Target) BastionName(ctx context.Context) (string, error) { false, ) if err != nil { - return "", err + return nil, err } - persistentOutputs, err := persistentStack.Outputs(ctx) + outputs, err := persistentStack.Outputs(ctx) if err != nil { - return "", err - } - - if _, ok := persistentOutputs["bastion_name"]; !ok { - return "", fmt.Errorf("bastion_name output not found in persistent stack outputs") + return nil, err } - bastionName := persistentOutputs["bastion_name"].Value.(string) + info := &BastionInfo{} - return bastionName, nil -} - -func (t Target) JumpBoxId(ctx context.Context) (string, error) { - envVars, err := t.fullPulumiEnvVars(ctx) - if err != nil { - return "", err - } - - persistentStack, err := pulumi.NewPythonPulumiStack( - ctx, - "azure", - "workload", - "persistent", - t.Name(), - t.Region(), - t.PulumiBackendUrl(), - t.PulumiSecretsProviderKey(), - envVars, - false, - ) - if err != nil { - return "", err + if v, ok := outputs["bastion_name"]; ok { + info.Name = v.Value.(string) + } else { + return nil, fmt.Errorf("bastion_name output not found in persistent stack outputs") } - persistentOutputs, err := persistentStack.Outputs(ctx) - if err != nil { - return "", err + if v, ok := outputs["bastion_jumpbox_id"]; ok { + info.JumpBoxID = v.Value.(string) + } else { + return nil, fmt.Errorf("bastion_jumpbox_id output not found in persistent stack outputs") } - if _, ok := persistentOutputs["bastion_jumpbox_id"]; !ok { - return "", fmt.Errorf("bastion_jumpbox_id output not found in persistent stack outputs") + if v, ok := outputs["bastion_ssh_private_key"]; ok { + info.SSHPrivateKey = v.Value.(string) + } else { + return nil, fmt.Errorf("bastion_ssh_private_key output not found in persistent stack outputs") } - jumpBoxId := persistentOutputs["bastion_jumpbox_id"].Value.(string) - - return jumpBoxId, nil + return info, nil } // HashName returns an obfuscated name for the target that can be used as a unique identifier. diff --git a/python-pulumi/src/ptd/pulumi_resources/azure_bastion.py b/python-pulumi/src/ptd/pulumi_resources/azure_bastion.py index 2894ad1..e1cdc6f 100644 --- a/python-pulumi/src/ptd/pulumi_resources/azure_bastion.py +++ b/python-pulumi/src/ptd/pulumi_resources/azure_bastion.py @@ -1,7 +1,6 @@ import pulumi import pulumi_tls as tls from pulumi_azure_native import compute, network -from pulumi_command import local class AzureBastion(pulumi.ComponentResource): @@ -50,19 +49,6 @@ def __init__( algorithm="ED25519", ) - # write the private key to a file on the local machine - # this needs to be repeated by any engineer who wants to access the jumpbox - local.run_output( - command=pulumi.Output.format( - "FILE=~/.ssh/{1}; " - 'if [ ! -f "$FILE" ]; then ' - 'echo \'{0}\' > "$FILE" && chmod 600 "$FILE"; ' - 'else echo "File $FILE already exists, skipping."; fi', - self.jumpbox_ssh_key.private_key_openssh, - name, - ), - ) - # Create a Public IP for Bastion self.public_ip = network.PublicIPAddress( f"{name}-pip", diff --git a/python-pulumi/src/ptd/pulumi_resources/azure_workload_persistent.py b/python-pulumi/src/ptd/pulumi_resources/azure_workload_persistent.py index da80cd8..bfa2ccb 100644 --- a/python-pulumi/src/ptd/pulumi_resources/azure_workload_persistent.py +++ b/python-pulumi/src/ptd/pulumi_resources/azure_workload_persistent.py @@ -96,6 +96,7 @@ def __init__( "app_gateway_subnet_id": self.app_gateway_subnet.id, "bastion_name": self.bastion.bastion_host.name, "bastion_jumpbox_id": self.bastion.jumpbox_host.id, + "bastion_ssh_private_key": self.bastion.jumpbox_ssh_key.private_key_openssh, "mimir_password": self.mimir_password.result, "private_subnet_name": self.private_subnet.name, "private_subnet_cidr": self.private_subnet.address_prefix,