Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions packages/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ const (
operationCallGetMFASessionStatus = "CallGetMFASessionStatus"
operationCallOrgRelayHeartBeat = "CallOrgRelayHeartBeat"
operationCallInstanceRelayHeartBeat = "CallInstanceRelayHeartBeat"
operationCallRelayLogin = "CallRelayLogin"
operationCallRelayConnect = "CallRelayConnect"
operationCallRelayHeartbeatV2 = "CallRelayHeartbeatV2"
operationCallIssueCertificate = "CallIssueCertificate"
operationCallRetrieveCertificate = "CallRetrieveCertificate"
operationCallGetCertificateBundle = "CallGetCertificateBundle"
Expand Down Expand Up @@ -901,6 +904,62 @@ func CallGetRelays(httpClient *resty.Client) (GetRelaysResponse, error) {
return resBody, nil
}

func CallRelayLogin(httpClient *resty.Client, request RelayLoginRequest) (RelayLoginResponse, error) {
var resBody RelayLoginResponse
response, err := httpClient.
R().
SetResult(&resBody).
SetHeader("User-Agent", USER_AGENT).
SetBody(request).
Post(fmt.Sprintf("%v/v2/relays/login", config.INFISICAL_URL))

if err != nil {
return RelayLoginResponse{}, NewGenericRequestError(operationCallRelayLogin, err)
}

if response.IsError() {
return RelayLoginResponse{}, NewAPIErrorWithResponse(operationCallRelayLogin, response, nil)
}

return resBody, nil
}

func CallRelayConnect(httpClient *resty.Client) (RelayConnectResponse, error) {
var resBody RelayConnectResponse
response, err := httpClient.
R().
SetResult(&resBody).
SetHeader("User-Agent", USER_AGENT).
Post(fmt.Sprintf("%v/v2/relays/connect", config.INFISICAL_URL))

if err != nil {
return RelayConnectResponse{}, NewGenericRequestError(operationCallRelayConnect, err)
}

if response.IsError() {
return RelayConnectResponse{}, NewAPIErrorWithResponse(operationCallRelayConnect, response, nil)
}

return resBody, nil
}

func CallRelayHeartbeatV2(httpClient *resty.Client) error {
response, err := httpClient.
R().
SetHeader("User-Agent", USER_AGENT).
Post(fmt.Sprintf("%v/v2/relays/heartbeat", config.INFISICAL_URL))

if err != nil {
return NewGenericRequestError(operationCallRelayHeartbeatV2, err)
}

if response.IsError() {
return NewAPIErrorWithResponse(operationCallRelayHeartbeatV2, response, nil)
}

return nil
}

func CallConnectGateway(httpClient *resty.Client, request ConnectGatewayRequest) (RegisterGatewayResponse, error) {
var resBody RegisterGatewayResponse
response, err := httpClient.
Expand Down
29 changes: 29 additions & 0 deletions packages/api/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,35 @@ type RelayHeartbeatRequest struct {
Name string `json:"name"`
}

type RelayLoginRequest struct {
Method string `json:"method"`
Token string `json:"token,omitempty"`
RelayID string `json:"relayId,omitempty"`
HTTPRequestMethod string `json:"iamHttpRequestMethod,omitempty"`
IamRequestBody string `json:"iamRequestBody,omitempty"`
IamRequestHeaders string `json:"iamRequestHeaders,omitempty"`
}

type RelayLoginResponse struct {
AccessToken string `json:"accessToken"`
RelayID string `json:"relayId"`
TokenType string `json:"tokenType"`
}

type RelayConnectResponse struct {
RelayID string `json:"relayId"`
PKI struct {
ServerCertificate string `json:"serverCertificate"`
ServerPrivateKey string `json:"serverPrivateKey"`
ClientCertificateChain string `json:"clientCertificateChain"`
} `json:"pki"`
SSH struct {
ServerCertificate string `json:"serverCertificate"`
ServerPrivateKey string `json:"serverPrivateKey"`
ClientCAPublicKey string `json:"clientCAPublicKey"`
} `json:"ssh"`
}

type AltName struct {
Type string `json:"type"`
Value string `json:"value"`
Expand Down
170 changes: 158 additions & 12 deletions packages/cmd/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"syscall"
"time"

"github.com/Infisical/infisical-merge/packages/api"
"github.com/Infisical/infisical-merge/packages/config"
gatewayv2 "github.com/Infisical/infisical-merge/packages/gateway-v2"
"github.com/Infisical/infisical-merge/packages/relay"
"github.com/Infisical/infisical-merge/packages/util"
Expand Down Expand Up @@ -39,29 +41,172 @@ var relayStartCmd = &cobra.Command{
util.HandleError(err, fmt.Sprintf("unable to get name flag or %s env", gatewayv2.RELAY_NAME_ENV_NAME))
}

host, err := util.GetCmdFlagOrEnv(cmd, "host", []string{gatewayv2.RELAY_HOST_ENV_NAME})
if err != nil || host == "" {
util.HandleError(err, fmt.Sprintf("unable to get host flag or %s env", gatewayv2.RELAY_HOST_ENV_NAME))
enrollMethod, _ := cmd.Flags().GetString("enroll-method")
if enrollMethod == "" {
enrollMethod = os.Getenv("INFISICAL_RELAY_ENROLL_METHOD")
}
if enrollMethod != "" && enrollMethod != relay.EnrollMethodToken && enrollMethod != relay.EnrollMethodAws {
util.HandleError(fmt.Errorf("invalid --enroll-method %q: supported values are %q and %q",
enrollMethod, relay.EnrollMethodToken, relay.EnrollMethodAws))
}

host, _ := util.GetCmdFlagOrEnv(cmd, "host", []string{gatewayv2.RELAY_HOST_ENV_NAME})
if host == "" && enrollMethod == "" {
util.HandleError(fmt.Errorf("please provide host flag"), fmt.Sprintf("unable to get host flag or %s env", gatewayv2.RELAY_HOST_ENV_NAME))
Comment thread
saifsmailbox98 marked this conversation as resolved.
}

instanceType, err := util.GetCmdFlagOrEnvWithDefaultValue(cmd, "type", []string{gatewayv2.RELAY_TYPE_ENV_NAME}, "org")
if err != nil {
util.HandleError(err, fmt.Sprintf("unable to get type flag or %s env", gatewayv2.RELAY_TYPE_ENV_NAME))
}

var enrolledAccessToken string

// --- AWS Auth path ---
if enrollMethod == relay.EnrollMethodAws {
relayID, _ := cmd.Flags().GetString("relay-id")
if relayID == "" {
relayID = os.Getenv(relay.INFISICAL_RELAY_ID_KEY)
}
if relayID == "" {
stored, _ := relay.LoadStoredRelayID(relayName)
relayID = stored
}
if relayID == "" {
util.HandleError(errors.New("--relay-id is required when --enroll-method=aws"))
}

domain, _ := cmd.Flags().GetString("domain")
if domain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(domain)
} else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain)
}

httpClient, err := util.GetRestyClientWithCustomHeaders()
if err != nil {
util.HandleError(err, "unable to create HTTP client")
}

log.Info().Msg("Authenticating relay via AWS Auth (STS GetCallerIdentity)...")
accessTokenStr, err := relay.LoginRelayWithAws(cmd.Context(), httpClient, relayID)
if err != nil {
util.HandleError(err, "AWS Auth login failed")
}

enrolledAccessToken = accessTokenStr

if err := relay.SaveRelayID(relayName, relayID); err != nil {
util.HandleError(err, "failed to save relay id to config")
}

effectiveDomain := domain
if effectiveDomain == "" {
effectiveDomain = config.INFISICAL_URL
}
if effectiveDomain != "" {
if err := relay.SaveDomain(relayName, effectiveDomain); err != nil {
util.HandleError(err, "failed to save domain to config")
}
}

log.Info().Msgf("Relay authenticated via AWS Auth. State saved to %s", relay.GetConfPathDisplay(relayName))
log.Info().Msg("Starting relay...")
}

// --- Enrollment token path ---
if enrollMethod == relay.EnrollMethodToken {
enrollToken, _ := cmd.Flags().GetString("token")
if enrollToken == "" {
util.HandleError(errors.New("--token is required when --enroll-method=token"))
}

storedEnrollToken, _ := relay.LoadStoredEnrollmentToken(relayName)
alreadyEnrolled := storedEnrollToken != "" && storedEnrollToken == enrollToken

if alreadyEnrolled {
log.Info().Msg("Enrollment token matches stored token. Skipping enrollment.")
} else {
domain, _ := cmd.Flags().GetString("domain")
if domain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(domain)
} else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain)
}
Comment thread
saifsmailbox98 marked this conversation as resolved.

httpClient, err := util.GetRestyClientWithCustomHeaders()
if err != nil {
util.HandleError(err, "unable to create HTTP client")
}

log.Info().Msg("Enrolling relay with enrollment token...")
enrollResp, err := api.CallRelayLogin(httpClient, api.RelayLoginRequest{
Method: "token",
Token: enrollToken,
})
if err != nil {
util.HandleError(err, "enrollment failed")
}

enrolledAccessToken = enrollResp.AccessToken
if err := relay.SaveAccessToken(relayName, enrollResp.AccessToken); err != nil {
util.HandleError(err, "failed to save relay access token")
}
if err := relay.SaveEnrollmentToken(relayName, enrollToken); err != nil {
util.HandleError(err, "failed to save enrollment token to config")
}

effectiveDomain := domain
if effectiveDomain == "" {
effectiveDomain = config.INFISICAL_URL
}
if effectiveDomain != "" {
if err := relay.SaveDomain(relayName, effectiveDomain); err != nil {
util.HandleError(err, "failed to save domain to config")
}
}

log.Info().Msgf("Relay enrolled successfully. Access token saved to %s", relay.GetConfPathDisplay(relayName))
}

log.Info().Msg("Starting relay...")
}

// --- Domain resolution for resource auth / stored token ---
isResourceAuth := enrollMethod == relay.EnrollMethodToken || enrollMethod == relay.EnrollMethodAws
if isResourceAuth {
if flagDomain, _ := cmd.Flags().GetString("domain"); flagDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(flagDomain)
} else if storedDomain, _ := relay.LoadStoredDomain(relayName); storedDomain != "" {
config.INFISICAL_URL = util.AppendAPIEndpoint(storedDomain)
}
}

relayInstance, err := relay.NewRelay(&relay.RelayConfig{
RelayName: relayName,
SSHPort: "2222",
TLSPort: "8443",
Host: host,
Type: instanceType,
RelayName: relayName,
SSHPort: "2222",
TLSPort: "8443",
Host: host,
Type: instanceType,
EnrollMethod: enrollMethod,
})

if err != nil {
util.HandleError(err, "unable to create relay instance")
}

if instanceType == "instance" {
if isResourceAuth {
// Use the freshly enrolled token, or load the stored one.
if enrolledAccessToken != "" {
relayInstance.SetToken(enrolledAccessToken)
} else {
storedToken, err := relay.LoadStoredAccessToken(relayName)
if err != nil || storedToken == "" {
util.HandleError(errors.New("no stored access token found — re-run with enrollment token"))
}
relayInstance.SetToken(storedToken)
}
} else if instanceType == "instance" {
relayAuthSecret := os.Getenv(gatewayv2.RELAY_AUTH_SECRET_ENV_NAME)
if relayAuthSecret == "" {
util.HandleError(fmt.Errorf("%s is not set", gatewayv2.RELAY_AUTH_SECRET_ENV_NAME), "unable to get relay auth secret")
Expand Down Expand Up @@ -96,7 +241,6 @@ var relayStartCmd = &cobra.Command{
cancelCmd()
cancelSdk()

// Give graceful shutdown 10 seconds, then force exit on second signal
select {
case <-sigCh:
log.Warn().Msg("Second signal received, force exit triggered")
Expand All @@ -107,7 +251,6 @@ var relayStartCmd = &cobra.Command{
}
}()

// Token refresh goroutine - runs every 10 seconds
go func() {
tokenRefreshTicker := time.NewTicker(10 * time.Second)
defer tokenRefreshTicker.Stop()
Expand Down Expand Up @@ -259,8 +402,11 @@ func init() {
relayStartCmd.Flags().String("type", "", "The type of relay to run. Defaults to 'org'")
relayStartCmd.Flags().String("host", "", "The IP or hostname for the relay")
relayStartCmd.Flags().String("name", "", "The name of the relay")
relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token. if not provided, you must set the auth-method flag")
relayStartCmd.Flags().String("token", "", "connect with Infisical using machine identity access token, or a one-time enrollment token when --enroll-method=token")
relayStartCmd.Flags().String("auth-method", "", "login method [universal-auth, kubernetes, azure, gcp-id-token, gcp-iam, aws-iam, oidc-auth]. if not provided, you must set the token flag")
relayStartCmd.Flags().String("enroll-method", "", "relay auth method [token, aws]. when set to 'token', uses --token as a one-time enrollment token. when set to 'aws', authenticates via signed STS GetCallerIdentity using --relay-id")
relayStartCmd.Flags().String("relay-id", "", "relay id (required when --enroll-method=aws)")
relayStartCmd.Flags().String("domain", "", "domain of your self-hosted Infisical instance (used with --enroll-method)")
relayStartCmd.Flags().String("client-id", "", "client id for universal auth")
relayStartCmd.Flags().String("client-secret", "", "client secret for universal auth")
relayStartCmd.Flags().String("machine-identity-id", "", "machine identity id for kubernetes, azure, gcp-id-token, gcp-iam, and aws-iam auth methods")
Expand Down
Loading
Loading