diff --git a/packages/api/api.go b/packages/api/api.go index 60630df9..3a6a9d6b 100644 --- a/packages/api/api.go +++ b/packages/api/api.go @@ -63,6 +63,9 @@ const ( operationCallGetMFASessionStatus = "CallGetMFASessionStatus" operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat" operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat" + operationCallRelayLogin = "CallRelayLogin" + operationCallRelayConnect = "CallRelayConnect" + operationCallRelayHeartbeatV2 = "CallRelayHeartbeatV2" operationCallIssueCertificate = "CallIssueCertificate" operationCallRetrieveCertificate = "CallRetrieveCertificate" operationCallGetCertificateBundle = "CallGetCertificateBundle" @@ -901,6 +904,62 @@ func CallGetRelays(httpClient *resty.Client) (GetRelaysResponse, error) { return resBody, nil } +func CallRelayLogin(httpClient *resty.Client, request RelayLoginRequest) (RelayLoginResponse, error) { + var resBody RelayLoginResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v2/relays/login", config.INFISICAL_URL)) + + if err != nil { + return RelayLoginResponse{}, NewGenericRequestError(operationCallRelayLogin, err) + } + + if response.IsError() { + return RelayLoginResponse{}, NewAPIErrorWithResponse(operationCallRelayLogin, response, nil) + } + + return resBody, nil +} + +func CallRelayConnect(httpClient *resty.Client) (RelayConnectResponse, error) { + var resBody RelayConnectResponse + response, err := httpClient. + R(). + SetResult(&resBody). + SetHeader("User-Agent", USER_AGENT). + Post(fmt.Sprintf("%v/v2/relays/connect", config.INFISICAL_URL)) + + if err != nil { + return RelayConnectResponse{}, NewGenericRequestError(operationCallRelayConnect, err) + } + + if response.IsError() { + return RelayConnectResponse{}, NewAPIErrorWithResponse(operationCallRelayConnect, response, nil) + } + + return resBody, nil +} + +func CallRelayHeartbeatV2(httpClient *resty.Client) error { + response, err := httpClient. + R(). + SetHeader("User-Agent", USER_AGENT). + Post(fmt.Sprintf("%v/v2/relays/heartbeat", config.INFISICAL_URL)) + + if err != nil { + return NewGenericRequestError(operationCallRelayHeartbeatV2, err) + } + + if response.IsError() { + return NewAPIErrorWithResponse(operationCallRelayHeartbeatV2, response, nil) + } + + return nil +} + func CallConnectGateway(httpClient *resty.Client, request ConnectGatewayRequest) (RegisterGatewayResponse, error) { var resBody RegisterGatewayResponse response, err := httpClient. diff --git a/packages/api/model.go b/packages/api/model.go index 5dd3fe95..abbb0229 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -986,6 +986,35 @@ type RelayHeartbeatRequest struct { Name string `json:"name"` } +type RelayLoginRequest struct { + Method string `json:"method"` + Token string `json:"token,omitempty"` + RelayID string `json:"relayId,omitempty"` + HTTPRequestMethod string `json:"iamHttpRequestMethod,omitempty"` + IamRequestBody string `json:"iamRequestBody,omitempty"` + IamRequestHeaders string `json:"iamRequestHeaders,omitempty"` +} + +type RelayLoginResponse struct { + AccessToken string `json:"accessToken"` + RelayID string `json:"relayId"` + TokenType string `json:"tokenType"` +} + +type RelayConnectResponse struct { + RelayID string `json:"relayId"` + PKI struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCertificateChain string `json:"clientCertificateChain"` + } `json:"pki"` + SSH struct { + ServerCertificate string `json:"serverCertificate"` + ServerPrivateKey string `json:"serverPrivateKey"` + ClientCAPublicKey string `json:"clientCAPublicKey"` + } `json:"ssh"` +} + type AltName struct { Type string `json:"type"` Value string `json:"value"` diff --git a/packages/cmd/relay.go b/packages/cmd/relay.go index d38ff31c..0ad1e815 100644 --- a/packages/cmd/relay.go +++ b/packages/cmd/relay.go @@ -12,6 +12,8 @@ import ( "syscall" "time" + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/config" gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2" "github.com/Infisical/infisical-merge/packages/relay" "github.com/Infisical/infisical-merge/packages/util" @@ -39,9 +41,18 @@ var relayStartCmd = &cobra.Command{ util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME)) } - host, err := util.GetCmdFlagOrEnv(cmd, "host", []string{gatewayv2.RELAY_HOST_ENV_NAME}) - if err != nil || host == "" { - util.HandleError(err, fmt.Sprintf("unable to get host flag or %s env", gatewayv2.RELAY_HOST_ENV_NAME)) + enrollMethod, _ := cmd.Flags().GetString("enroll-method") + if enrollMethod == "" { + enrollMethod = os.Getenv("INFISICAL_RELAY_ENROLL_METHOD") + } + if enrollMethod != "" && enrollMethod != relay.EnrollMethodToken && enrollMethod != relay.EnrollMethodAws { + util.HandleError(fmt.Errorf("invalid --enroll-method %q: supported values are %q and %q", + enrollMethod, relay.EnrollMethodToken, relay.EnrollMethodAws)) + } + + host, _ := util.GetCmdFlagOrEnv(cmd, "host", []string{gatewayv2.RELAY_HOST_ENV_NAME}) + if host == "" && enrollMethod == "" { + util.HandleError(fmt.Errorf("please provide host flag"), fmt.Sprintf("unable to get host flag or %s env", gatewayv2.RELAY_HOST_ENV_NAME)) } instanceType, err := util.GetCmdFlagOrEnvWithDefaultValue(cmd, "type", []string{gatewayv2.RELAY_TYPE_ENV_NAME}, "org") @@ -49,19 +60,153 @@ var relayStartCmd = &cobra.Command{ util.HandleError(err, fmt.Sprintf("unable to get type flag or %s env", gatewayv2.RELAY_TYPE_ENV_NAME)) } + var enrolledAccessToken string + + // --- AWS Auth path --- + if enrollMethod == relay.EnrollMethodAws { + relayID, _ := cmd.Flags().GetString("relay-id") + if relayID == "" { + relayID = os.Getenv(relay.INFISICAL_RELAY_ID_KEY) + } + if relayID == "" { + stored, _ := relay.LoadStoredRelayID(relayName) + relayID = stored + } + if relayID == "" { + util.HandleError(errors.New("--relay-id is required when --enroll-method=aws")) + } + + domain, _ := cmd.Flags().GetString("domain") + if domain != "" { + config.INFISICAL_URL = util.AppendAPIEndpoint(domain) + } else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" { + config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain) + } + + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + util.HandleError(err, "unable to create HTTP client") + } + + log.Info().Msg("Authenticating relay via AWS Auth (STS GetCallerIdentity)...") + accessTokenStr, err := relay.LoginRelayWithAws(cmd.Context(), httpClient, relayID) + if err != nil { + util.HandleError(err, "AWS Auth login failed") + } + + enrolledAccessToken = accessTokenStr + + if err := relay.SaveRelayID(relayName, relayID); err != nil { + util.HandleError(err, "failed to save relay id to config") + } + + effectiveDomain := domain + if effectiveDomain == "" { + effectiveDomain = config.INFISICAL_URL + } + if effectiveDomain != "" { + if err := relay.SaveDomain(relayName, effectiveDomain); err != nil { + util.HandleError(err, "failed to save domain to config") + } + } + + log.Info().Msgf("Relay authenticated via AWS Auth. State saved to %s", relay.GetConfPathDisplay(relayName)) + log.Info().Msg("Starting relay...") + } + + // --- Enrollment token path --- + if enrollMethod == relay.EnrollMethodToken { + enrollToken, _ := cmd.Flags().GetString("token") + if enrollToken == "" { + util.HandleError(errors.New("--token is required when --enroll-method=token")) + } + + storedEnrollToken, _ := relay.LoadStoredEnrollmentToken(relayName) + alreadyEnrolled := storedEnrollToken != "" && storedEnrollToken == enrollToken + + if alreadyEnrolled { + log.Info().Msg("Enrollment token matches stored token. Skipping enrollment.") + } else { + domain, _ := cmd.Flags().GetString("domain") + if domain != "" { + config.INFISICAL_URL = util.AppendAPIEndpoint(domain) + } else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" { + config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain) + } + + httpClient, err := util.GetRestyClientWithCustomHeaders() + if err != nil { + util.HandleError(err, "unable to create HTTP client") + } + + log.Info().Msg("Enrolling relay with enrollment token...") + enrollResp, err := api.CallRelayLogin(httpClient, api.RelayLoginRequest{ + Method: "token", + Token: enrollToken, + }) + if err != nil { + util.HandleError(err, "enrollment failed") + } + + enrolledAccessToken = enrollResp.AccessToken + if err := relay.SaveAccessToken(relayName, enrollResp.AccessToken); err != nil { + util.HandleError(err, "failed to save relay access token") + } + if err := relay.SaveEnrollmentToken(relayName, enrollToken); err != nil { + util.HandleError(err, "failed to save enrollment token to config") + } + + effectiveDomain := domain + if effectiveDomain == "" { + effectiveDomain = config.INFISICAL_URL + } + if effectiveDomain != "" { + if err := relay.SaveDomain(relayName, effectiveDomain); err != nil { + util.HandleError(err, "failed to save domain to config") + } + } + + log.Info().Msgf("Relay enrolled successfully. Access token saved to %s", relay.GetConfPathDisplay(relayName)) + } + + log.Info().Msg("Starting relay...") + } + + // --- Domain resolution for resource auth / stored token --- + isResourceAuth := enrollMethod == relay.EnrollMethodToken || enrollMethod == relay.EnrollMethodAws + if isResourceAuth { + if flagDomain, _ := cmd.Flags().GetString("domain"); flagDomain != "" { + config.INFISICAL_URL = util.AppendAPIEndpoint(flagDomain) + } else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" { + config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain) + } + } + relayInstance, err := relay.NewRelay(&relay.RelayConfig{ - RelayName: relayName, - SSHPort: "2222", - TLSPort: "8443", - Host: host, - Type: instanceType, + RelayName: relayName, + SSHPort: "2222", + TLSPort: "8443", + Host: host, + Type: instanceType, + EnrollMethod: enrollMethod, }) if err != nil { util.HandleError(err, "unable to create relay instance") } - if instanceType == "instance" { + if isResourceAuth { + // Use the freshly enrolled token, or load the stored one. + if enrolledAccessToken != "" { + relayInstance.SetToken(enrolledAccessToken) + } else { + storedToken, err := relay.LoadStoredAccessToken(relayName) + if err != nil || storedToken == "" { + util.HandleError(errors.New("no stored access token found — re-run with enrollment token")) + } + relayInstance.SetToken(storedToken) + } + } else if instanceType == "instance" { relayAuthSecret := os.Getenv(gatewayv2.RELAY_AUTH_SECRET_ENV_NAME) if relayAuthSecret == "" { util.HandleError(fmt.Errorf("%s is not set", gatewayv2.RELAY_AUTH_SECRET_ENV_NAME), "unable to get relay auth secret") @@ -96,7 +241,6 @@ var relayStartCmd = &cobra.Command{ cancelCmd() cancelSdk() - // Give graceful shutdown 10 seconds, then force exit on second signal select { case <-sigCh: log.Warn().Msg("Second signal received, force exit triggered") @@ -107,7 +251,6 @@ var relayStartCmd = &cobra.Command{ } }() - // Token refresh goroutine - runs every 10 seconds go func() { tokenRefreshTicker := time.NewTicker(10 * time.Second) defer tokenRefreshTicker.Stop() @@ -259,8 +402,11 @@ func init() { relayStartCmd.Flags().String("type", "", "The type of relay to run. Defaults to 'org'") relayStartCmd.Flags().String("host", "", "The IP or hostname for the relay") relayStartCmd.Flags().String("name", "", "The name of the relay") - relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag") + relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token, or a one-time enrollment token when --enroll-method=token") relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag") + relayStartCmd.Flags().String("enroll-method", "", "relay auth method [token, aws]. when set to 'token', uses --token as a one-time enrollment token. when set to 'aws', authenticates via signed STS GetCallerIdentity using --relay-id") + relayStartCmd.Flags().String("relay-id", "", "relay id (required when --enroll-method=aws)") + relayStartCmd.Flags().String("domain", "", "domain of your self-hosted Infisical instance (used with --enroll-method)") relayStartCmd.Flags().String("client-id", "", "client id for universal auth") relayStartCmd.Flags().String("client-secret", "", "client secret for universal auth") relayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods") diff --git a/packages/relay/aws_auth.go b/packages/relay/aws_auth.go new file mode 100644 index 00000000..73059df6 --- /dev/null +++ b/packages/relay/aws_auth.go @@ -0,0 +1,75 @@ +package relay + +import ( + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/go-resty/resty/v2" + infisicalSdkUtil "github.com/infisical/go-sdk/packages/util" +) + +func LoginRelayWithAws(ctx context.Context, httpClient *resty.Client, relayID string) (string, error) { + if relayID == "" { + return "", errors.New("--relay-id is required when --enroll-method=aws") + } + + awsCredentials, awsRegion, err := infisicalSdkUtil.RetrieveAwsCredentials() + if err != nil { + return "", fmt.Errorf("unable to retrieve AWS credentials: %w", err) + } + + iamRequestURL := fmt.Sprintf("https://sts.%s.amazonaws.com/", awsRegion) + iamRequestBody := "Action=GetCallerIdentity&Version=2011-06-15" + + req, err := http.NewRequest(http.MethodPost, iamRequestURL, strings.NewReader(iamRequestBody)) + if err != nil { + return "", fmt.Errorf("error building STS request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") + + hash := sha256.New() + hash.Write([]byte(iamRequestBody)) + payloadHash := fmt.Sprintf("%x", hash.Sum(nil)) + + signer := v4.NewSigner() + if err := signer.SignHTTP(ctx, awsCredentials, req, payloadHash, "sts", awsRegion, time.Now()); err != nil { + return "", fmt.Errorf("error signing STS request: %w", err) + } + + headers := make(map[string]string) + for name, values := range req.Header { + if strings.ToLower(name) == "content-length" { + continue + } + headers[name] = values[0] + } + headers["Host"] = fmt.Sprintf("sts.%s.amazonaws.com", awsRegion) + + headersJSON, err := json.Marshal(headers) + if err != nil { + return "", fmt.Errorf("error marshalling headers: %w", err) + } + + resp, err := api.CallRelayLogin(httpClient, api.RelayLoginRequest{ + Method: EnrollMethodAws, + RelayID: relayID, + HTTPRequestMethod: req.Method, + IamRequestBody: base64.StdEncoding.EncodeToString([]byte(iamRequestBody)), + IamRequestHeaders: base64.StdEncoding.EncodeToString(headersJSON), + }) + if err != nil { + return "", err + } + + return resp.AccessToken, nil +} diff --git a/packages/relay/enroll.go b/packages/relay/enroll.go new file mode 100644 index 00000000..e37898bb --- /dev/null +++ b/packages/relay/enroll.go @@ -0,0 +1,142 @@ +package relay + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" +) + +const ( + EnrollMethodToken = "token" + EnrollMethodAws = "aws" + + INFISICAL_RELAY_ACCESS_TOKEN_KEY = "INFISICAL_RELAY_ACCESS_TOKEN" + INFISICAL_RELAY_DOMAIN_KEY = "INFISICAL_RELAY_DOMAIN" + INFISICAL_RELAY_ENROLLMENT_TOKEN_KEY = "INFISICAL_RELAY_ENROLLMENT_TOKEN" + INFISICAL_RELAY_ID_KEY = "INFISICAL_RELAY_ID" +) + +func relayConfPath(name string) (string, error) { + if os.Geteuid() == 0 { + return filepath.Join("/etc/infisical/relays", name+".conf"), nil + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("unable to determine home directory: %w", err) + } + + return filepath.Join(homeDir, ".infisical", "relays", name+".conf"), nil +} + +func loadConfKey(name, key string) (string, error) { + confPath, err := relayConfPath(name) + if err != nil { + return "", err + } + + data, err := os.ReadFile(confPath) + if os.IsNotExist(err) { + return "", nil + } + if err != nil { + return "", fmt.Errorf("failed to read relay config: %w", err) + } + + prefix := key + "=" + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, prefix) { + return strings.TrimPrefix(line, prefix), nil + } + } + + return "", nil +} + +func saveConfKey(name, key, value string) error { + confPath, err := relayConfPath(name) + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(confPath), 0700); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + var existingLines []string + data, err := os.ReadFile(confPath) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to read existing config: %w", err) + } + if err == nil { + prefix := key + "=" + for _, line := range strings.Split(string(data), "\n") { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, prefix) { + continue + } + existingLines = append(existingLines, line) + } + } + + existingLines = append(existingLines, fmt.Sprintf("%s=%s", key, value)) + content := strings.Join(existingLines, "\n") + "\n" + + if err := os.WriteFile(confPath, []byte(content), 0600); err != nil { + return fmt.Errorf("failed to write relay config: %w", err) + } + + return nil +} + +func LoadStoredAccessToken(name string) (string, error) { + if envToken := os.Getenv(INFISICAL_RELAY_ACCESS_TOKEN_KEY); envToken != "" { + return envToken, nil + } + return loadConfKey(name, INFISICAL_RELAY_ACCESS_TOKEN_KEY) +} + +func SaveAccessToken(name, token string) error { + return saveConfKey(name, INFISICAL_RELAY_ACCESS_TOKEN_KEY, token) +} + +func LoadStoredDomain(name string) (string, error) { + return loadConfKey(name, INFISICAL_RELAY_DOMAIN_KEY) +} + +func SaveDomain(name, domain string) error { + return saveConfKey(name, INFISICAL_RELAY_DOMAIN_KEY, domain) +} + +func LoadStoredEnrollmentToken(name string) (string, error) { + return loadConfKey(name, INFISICAL_RELAY_ENROLLMENT_TOKEN_KEY) +} + +func SaveEnrollmentToken(name, token string) error { + return saveConfKey(name, INFISICAL_RELAY_ENROLLMENT_TOKEN_KEY, token) +} + +func LoadStoredRelayID(name string) (string, error) { + if envID := os.Getenv(INFISICAL_RELAY_ID_KEY); envID != "" { + return envID, nil + } + return loadConfKey(name, INFISICAL_RELAY_ID_KEY) +} + +func SaveRelayID(name, relayID string) error { + return saveConfKey(name, INFISICAL_RELAY_ID_KEY, relayID) +} + +func GetConfPathDisplay(name string) string { + path, err := relayConfPath(name) + if err != nil { + if runtime.GOOS == "linux" { + return "/etc/infisical/relays/" + name + ".conf" + } + return "~/.infisical/relays/" + name + ".conf" + } + return path +} diff --git a/packages/relay/relay.go b/packages/relay/relay.go index 91ceb98f..60ee8328 100644 --- a/packages/relay/relay.go +++ b/packages/relay/relay.go @@ -42,6 +42,8 @@ type RelayConfig struct { // Network Configuration Host string + + EnrollMethod string } type Relay struct { @@ -91,11 +93,15 @@ func (r *Relay) SetToken(token string) { func (r *Relay) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() error { var err error - heartbeatBody := api.RelayHeartbeatRequest{Name: r.config.RelayName} - if r.config.Type == "instance" { - err = api.CallInstanceRelayHeartBeat(r.httpClient, heartbeatBody) + if r.config.EnrollMethod != "" { + err = api.CallRelayHeartbeatV2(r.httpClient) } else { - err = api.CallOrgRelayHeartBeat(r.httpClient, heartbeatBody) + heartbeatBody := api.RelayHeartbeatRequest{Name: r.config.RelayName} + if r.config.Type == "instance" { + err = api.CallInstanceRelayHeartBeat(r.httpClient, heartbeatBody) + } else { + err = api.CallOrgRelayHeartBeat(r.httpClient, heartbeatBody) + } } if err != nil { @@ -201,6 +207,19 @@ func (r *Relay) Start(ctx context.Context) error { } func (r *Relay) registerRelay() error { + if r.config.EnrollMethod != "" { + certResp, err := api.CallRelayConnect(r.httpClient) + if err != nil { + return fmt.Errorf("failed to connect relay via v2: %v", err) + } + r.certificates = &api.RegisterRelayResponse{ + PKI: certResp.PKI, + SSH: certResp.SSH, + } + log.Info().Msg("Successfully connected relay and received certificates via v2 API") + return nil + } + body := api.RegisterRelayRequest{ Host: r.config.Host, Name: r.config.RelayName,