From 71aba63cc37b147ec20f562b40bf89b2b0c96def Mon Sep 17 00:00:00 2001 From: x032205 Date: Fri, 8 May 2026 04:04:28 -0400 Subject: [PATCH 01/10] fix: detect and clean up dead gateway connections to prevent OOM --- packages/gateway-v2/gateway.go | 84 ++++++++++++++++++++++++++++-- packages/pam/handlers/ssh/proxy.go | 29 +++++++++++ packages/pam/pam-proxy.go | 64 ++++++++++++++++++----- packages/util/ssh.go | 23 ++++++++ 4 files changed, 181 insertions(+), 19 deletions(-) create mode 100644 packages/util/ssh.go diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 07c32120..4d05872a 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -14,6 +14,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/Infisical/infisical-merge/packages/api" @@ -86,8 +87,9 @@ type GatewayConfig struct { } type pamSessionEntry struct { - cancel context.CancelFunc - conn *tls.Conn + cancel context.CancelFunc + conn *tls.Conn + lastActivity atomic.Int64 } type Gateway struct { @@ -166,10 +168,18 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { } // RegisterPAMSession registers an active PAM proxy connection for cancellation support -func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) { +// Returns a function that handlers should call when data flows through the connection +func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) func() { + entry := &pamSessionEntry{cancel: cancel, conn: conn} + entry.lastActivity.Store(time.Now().Unix()) + g.pamSessionsMu.Lock() defer g.pamSessionsMu.Unlock() - g.pamSessions[sessionID] = append(g.pamSessions[sessionID], &pamSessionEntry{cancel: cancel, conn: conn}) + g.pamSessions[sessionID] = append(g.pamSessions[sessionID], entry) + + return func() { + entry.lastActivity.Store(time.Now().Unix()) + } } // DeregisterPAMSession removes a specific connection from the session 
registry. @@ -264,6 +274,48 @@ func (g *Gateway) closeMongoProxy(sessionID string) { } } +const pamIdleTimeout = 30 * time.Minute + +// startIdleReaper periodically scans the PAM session registry and cancels +// sessions whose connections have had no data flow for pamIdleTimeout +func (g *Gateway) startIdleReaper(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + g.reapIdleSessions() + } + } +} + +func (g *Gateway) reapIdleSessions() { + cutoff := time.Now().Add(-pamIdleTimeout).Unix() + + g.pamSessionsMu.Lock() + var stale []string + for sessionID, entries := range g.pamSessions { + allIdle := true + for _, e := range entries { + if e.lastActivity.Load() > cutoff { + allIdle = false + break + } + } + if allIdle { + stale = append(stale, sessionID) + } + } + g.pamSessionsMu.Unlock() + + for _, sessionID := range stale { + log.Info().Str("sessionId", sessionID).Dur("idleTimeout", pamIdleTimeout).Msg("Reaping idle PAM session") + g.CancelPAMSession(sessionID) + } +} + func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() error { if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { @@ -329,6 +381,8 @@ func (g *Gateway) Start(ctx context.Context) error { // Start session uploader goroutine for PAM g.pamSessionUploader.Start() + go g.startIdleReaper(ctx) + go func() { for { select { @@ -489,6 +543,25 @@ func (g *Gateway) handleConnection(client *ssh.Client) error { client.Close() }() + // Keepalive on the relay SSH connection. 
If the relay drops silently, + // this closes the client so the reconnect loop in connectWithRetry kicks in + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := util.SSHKeepalive(client, 15*time.Second); err != nil { + log.Warn().Err(err).Msg("Relay SSH keepalive failed, closing connection") + client.Close() + return + } + case <-g.ctx.Done(): + return + } + } + }() + // Process incoming channels with context cancellation support for { select { @@ -751,8 +824,9 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } else if forwardConfig.Mode == ForwardModePAM { sessionCtx, sessionCancel := context.WithCancel(g.ctx) - g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) + touchSession := g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) defer g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn) + forwardConfig.PAMConfig.OnActivity = touchSession if err := pam.HandlePAMProxy(sessionCtx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil { if err.Error() == "unexpected EOF" { log.Debug().Err(err).Msg("PAM proxy handler ended with unexpected connection termination") diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go index 0d2657af..2897b083 100644 --- a/packages/pam/handlers/ssh/proxy.go +++ b/packages/pam/handlers/ssh/proxy.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/Infisical/infisical-merge/packages/util" "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) @@ -26,6 +27,7 @@ type SSHProxyConfig struct { SessionID string SessionLogger session.SessionLogger BlockedCommandPatterns []*regexp.Regexp // Regex patterns for command blocking (nil = no blocking) + OnActivity func() // Called when channel data flows } // SSHProxy handles proxying SSH connections with credential injection @@ 
-123,6 +125,29 @@ func (p *SSHProxy) HandleConnection(ctx context.Context, clientConn net.Conn) er // Discard global requests (not needed for basic remote access) go ssh.DiscardRequests(clientRequests) + // SSH keepalive: detect dead connections where TCP goes silent. Probes both sides every 30s + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := util.SSHKeepalive(clientSSHConn, 15*time.Second); err != nil { + log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to client failed, tearing down connection") + clientConn.Close() + return + } + if err := util.SSHKeepalive(serverSSHConn, 15*time.Second); err != nil { + log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to target failed, tearing down connection") + clientConn.Close() + return + } + case <-ctx.Done(): + return + } + } + }() + // Handle channels from client (this is where actual SSH sessions happen) for newChannel := range clientChannels { go p.handleChannel(ctx, newChannel, serverSSHConn, sessionID) @@ -500,6 +525,10 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses for { n, err := src.Read(buf) if n > 0 { + if p.config.OnActivity != nil { + p.config.OnActivity() + } + // Check if this channel is a binary session (SFTP/SCP) chState.mutex.Lock() isBinary := chState.isBinarySession diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 3c44db0d..3feb9c60 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "regexp" + "sync" "time" "github.com/Infisical/infisical-merge/packages/api" @@ -37,6 +38,7 @@ type GatewayPAMConfig struct { CredentialsManager *session.CredentialsManager SessionUploader *session.SessionUploader GetMongoProxy MongoProxyGetter // Session-level MongoDB proxy sharing + OnActivity func() // Called on data flow } type PAMCapabilitiesResponse struct { @@ -109,6 +111,28 @@ 
func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *Gatew return nil } +// activityConn wraps a net.Conn and calls onActivity on every successful read or write +type activityConn struct { + net.Conn + onActivity func() +} + +func (c *activityConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + if n > 0 { + c.onActivity() + } + return n, err +} + +func (c *activityConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + if n > 0 { + c.onActivity() + } + return n, err +} + // compilePolicyPatterns compiles regex pattern strings, logging warnings for any that fail. func compilePolicyPatterns(config *api.PAMPolicyRuleConfig, sessionID string, ruleType string) []*regexp.Regexp { if config == nil || len(config.Patterns) == 0 { @@ -138,6 +162,17 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo return fmt.Errorf("failed to retrieve PAM session credentials: %w", err) } + // Cleanup must run exactly once regardless of how the session ends + var cleanupOnce sync.Once + cleanupSession := func(reason string) { + cleanupOnce.Do(func() { + if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, reason); err != nil { + log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Str("reason", reason).Msg("Failed to cleanup PAM session") + } + }) + } + defer cleanupSession("connection_closed") + // Start a goroutine to monitor session expiry and close connection when exceeded go func() { timeUntilExpiry := time.Until(pamConfig.ExpiryTime) @@ -153,10 +188,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). 
Msg("PAM session expired, closing connection") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "expiry"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session on expiry") - } - + cleanupSession("expiry") conn.Close() case <-ctx.Done(): // Context cancelled, exit gracefully @@ -169,10 +201,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). Msg("PAM session already expired, closing connection immediately") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "already_expired"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup already expired PAM session") - } - + cleanupSession("already_expired") conn.Close() } }() @@ -234,6 +263,12 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo } } + // Wrap the connection so every read/write resets the idle reaper timer + var handlerConn net.Conn = conn + if pamConfig.OnActivity != nil { + handlerConn = &activityConn{Conn: conn, onActivity: pamConfig.OnActivity} + } + switch pamConfig.ResourceType { case session.ResourceTypePostgres: proxyConfig := handlers.PostgresProxyConfig{ @@ -252,7 +287,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", proxyConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting PostgreSQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMysql: mysqlConfig := mysql.MysqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -271,7 +306,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", mysqlConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). 
Msg("Starting MySQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMssql: mssqlConfig := mssql.MssqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -290,7 +325,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", mssqlConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MSSQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeRedis: redisConfig := redis.RedisProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -308,7 +343,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", redisConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting Redis PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeSSH: // Compile command blocking patterns from policy rules var blockedCommandPatterns []*regexp.Regexp @@ -326,6 +361,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo SessionID: pamConfig.SessionId, SessionLogger: sessionLogger, BlockedCommandPatterns: blockedCommandPatterns, + OnActivity: pamConfig.OnActivity, } proxy := ssh.NewSSHProxy(sshConfig) log.Info(). @@ -379,7 +415,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", kubernetesConfig.TargetApiServer). Str("authMethod", credentials.AuthMethod). 
Msg("Starting Kubernetes PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMongodb: mongoConfig := mongodb.MongoDBProxyConfig{ Host: credentials.ConnectionString, @@ -404,7 +440,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo return fmt.Errorf("MongoDB proxy init: %w", err) } - return proxy.HandleConnection(ctx, conn, sessionLogger) + return proxy.HandleConnection(ctx, handlerConn, sessionLogger) default: return fmt.Errorf("unsupported resource type: %s", pamConfig.ResourceType) } diff --git a/packages/util/ssh.go b/packages/util/ssh.go new file mode 100644 index 00000000..d207a75d --- /dev/null +++ b/packages/util/ssh.go @@ -0,0 +1,23 @@ +package util + +import ( + "fmt" + "time" + + "golang.org/x/crypto/ssh" +) + +// SSHKeepalive sends an SSH keepalive request and waits up to timeout for a response +func SSHKeepalive(conn ssh.Conn, timeout time.Duration) error { + errCh := make(chan error, 1) + go func() { + _, _, err := conn.SendRequest("keepalive@openssh.com", true, nil) + errCh <- err + }() + select { + case err := <-errCh: + return err + case <-time.After(timeout): + return fmt.Errorf("no keepalive response within %v", timeout) + } +} From 786fcd72eb6aa1f17d19250a208c60f193f9f035 Mon Sep 17 00:00:00 2001 From: x032205 Date: Fri, 8 May 2026 13:35:41 -0400 Subject: [PATCH 02/10] codex review fix --- packages/gateway-v2/gateway.go | 22 +++++++++++++++++++--- packages/pam/pam-proxy.go | 14 -------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 4d05872a..5d2a6a9c 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -183,13 +183,18 @@ func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc } // DeregisterPAMSession removes a specific connection from the session registry. 
+// Returns true if this was the last connection for the session. // The MongoDB proxy (if any) is NOT closed here — it persists across connections // so that subsequent client connections (e.g. mongosh retries) find a warm topology. // The proxy is cleaned up on session cancellation or gateway shutdown. -func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) { +func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) bool { g.pamSessionsMu.Lock() defer g.pamSessionsMu.Unlock() - entries := g.pamSessions[sessionID] + + entries, exists := g.pamSessions[sessionID] + if !exists { + return false + } for i, e := range entries { if e.conn == conn { g.pamSessions[sessionID] = append(entries[:i], entries[i+1:]...) @@ -198,7 +203,9 @@ func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) { } if len(g.pamSessions[sessionID]) == 0 { delete(g.pamSessions, sessionID) + return true } + return false } // CancelPAMSession kills all active connections for a PAM session @@ -313,6 +320,9 @@ func (g *Gateway) reapIdleSessions() { for _, sessionID := range stale { log.Info().Str("sessionId", sessionID).Dur("idleTimeout", pamIdleTimeout).Msg("Reaping idle PAM session") g.CancelPAMSession(sessionID) + if err := g.pamSessionUploader.CleanupPAMSession(sessionID, "idle_timeout"); err != nil { + log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to cleanup reaped PAM session") + } } } @@ -825,7 +835,6 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { } else if forwardConfig.Mode == ForwardModePAM { sessionCtx, sessionCancel := context.WithCancel(g.ctx) touchSession := g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) - defer g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn) forwardConfig.PAMConfig.OnActivity = touchSession if err := pam.HandlePAMProxy(sessionCtx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil { if err.Error() == "unexpected EOF" { 
@@ -834,6 +843,13 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Error().Err(err).Msg("PAM proxy handler ended with error") } } + if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn { + if err := forwardConfig.PAMConfig.SessionUploader.CleanupPAMSession( + forwardConfig.PAMConfig.SessionId, "connection_closed", + ); err != nil { + log.Error().Err(err).Str("sessionId", forwardConfig.PAMConfig.SessionId).Msg("Failed to cleanup PAM session") + } + } return } else if forwardConfig.Mode == ForwardModePAMCancellation { if err := pam.HandlePAMCancellation(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient, g.CancelPAMSession); err != nil { diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 3feb9c60..ee7fb2be 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -10,7 +10,6 @@ import ( "net/url" "os" "regexp" - "sync" "time" "github.com/Infisical/infisical-merge/packages/api" @@ -162,17 +161,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo return fmt.Errorf("failed to retrieve PAM session credentials: %w", err) } - // Cleanup must run exactly once regardless of how the session ends - var cleanupOnce sync.Once - cleanupSession := func(reason string) { - cleanupOnce.Do(func() { - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, reason); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Str("reason", reason).Msg("Failed to cleanup PAM session") - } - }) - } - defer cleanupSession("connection_closed") - // Start a goroutine to monitor session expiry and close connection when exceeded go func() { timeUntilExpiry := time.Until(pamConfig.ExpiryTime) @@ -188,7 +176,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). 
Msg("PAM session expired, closing connection") - cleanupSession("expiry") conn.Close() case <-ctx.Done(): // Context cancelled, exit gracefully @@ -201,7 +188,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). Msg("PAM session already expired, closing connection immediately") - cleanupSession("already_expired") conn.Close() } }() From f9dd5b63916332fe87aa984d0f76a464196a4836 Mon Sep 17 00:00:00 2001 From: x032205 Date: Fri, 8 May 2026 13:39:36 -0400 Subject: [PATCH 03/10] address claude review --- packages/pam/handlers/ssh/proxy.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go index 2897b083..8d9a6195 100644 --- a/packages/pam/handlers/ssh/proxy.go +++ b/packages/pam/handlers/ssh/proxy.go @@ -773,6 +773,10 @@ func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer, for { n, err := src.Read(buf) if n > 0 { + if p.config.OnActivity != nil { + p.config.OnActivity() + } + chState.mutex.Lock() isBinary := chState.isBinarySession sftpParser := chState.sftpParser From 19a086fe86640c7fcd26bb5871a0a62c6514e969 Mon Sep 17 00:00:00 2001 From: x032205 Date: Fri, 8 May 2026 14:12:10 -0400 Subject: [PATCH 04/10] move session cleanup to only run if a session was found and cancelled --- packages/pam/pam-proxy.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 614f42f7..3240e682 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -105,14 +105,13 @@ func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *Gatew // Kill the active proxy connection if it exists in the registry if cancelled := cancelSession(pamConfig.SessionId); cancelled { log.Info().Str("sessionId", pamConfig.SessionId).Msg("Active proxy session cancelled via registry") + if err := 
pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { + log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") + } } else { log.Info().Str("sessionId", pamConfig.SessionId).Msg("No active proxy session found in registry (may have already ended)") } - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") - } - conn.Close() return nil From 268a8e65f13e61cb487d5633a7247c41e4aa9b5b Mon Sep 17 00:00:00 2001 From: x032205 Date: Fri, 8 May 2026 14:24:40 -0400 Subject: [PATCH 05/10] tiny noise fix --- packages/gateway-v2/gateway.go | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 5d2a6a9c..b8f4e127 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -843,6 +843,7 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Error().Err(err).Msg("PAM proxy handler ended with error") } } + sessionCancel() if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn { if err := forwardConfig.PAMConfig.SessionUploader.CleanupPAMSession( forwardConfig.PAMConfig.SessionId, "connection_closed", From bfab68580449afa18390aedc08a17117f51b340d Mon Sep 17 00:00:00 2001 From: bernie-g Date: Fri, 8 May 2026 18:11:10 -0400 Subject: [PATCH 06/10] fix(release): remove create-release-draft and fix --skip=build goreleaser v1.x's mode: append cannot find draft releases (GitHub's "get release by tag" API excludes drafts), so the pre-created draft was ignored and goreleaser created a second, separate release. Fix: remove the create-release-draft job and let the goreleaser jobs handle release creation themselves via mode: append. The first job to finish publishing creates the release, the others append to it. 
Also replaces the broken --skip=build (v2-only flag) with separate dry-run/release goreleaser steps. --- .../workflows/release_build_infisical_cli.yml | 49 +++---------------- .goreleaser-darwin.yaml | 2 - .goreleaser-windows.yaml | 2 - .goreleaser.yaml | 4 -- 4 files changed, 6 insertions(+), 51 deletions(-) diff --git a/.github/workflows/release_build_infisical_cli.yml b/.github/workflows/release_build_infisical_cli.yml index d388356a..bb5a74e9 100644 --- a/.github/workflows/release_build_infisical_cli.yml +++ b/.github/workflows/release_build_infisical_cli.yml @@ -31,37 +31,6 @@ jobs: build-rdp-bridge: uses: ./.github/workflows/build-rdp-bridge.yml - # Create the GitHub release draft up front so both goreleaser - # (ubuntu) and goreleaser-darwin (macos) can append to it in - # parallel instead of serializing on ubuntu creating the draft. - # Skipped on dry-run since --snapshot doesn't touch GitHub at all. - create-release-draft: - if: | - always() && - (needs.validate-tag-branch.result == 'success' || needs.validate-tag-branch.result == 'skipped') && - needs.cli-tests.result == 'success' && - (github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run)) - needs: - - validate-tag-branch - - cli-tests - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Create GitHub release draft (idempotent) - run: | - if gh release view "${{ github.ref_name }}" >/dev/null 2>&1; then - echo "Release for ${{ github.ref_name }} already exists, skipping creation" - else - gh release create "${{ github.ref_name }}" \ - --draft \ - --title "${{ github.ref_name }}" \ - --generate-notes - fi - env: - GH_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} - # cli-integration-tests: # name: Run tests before deployment # uses: ./.github/workflows/run-cli-tests.yml @@ -136,12 +105,10 @@ jobs: - validate-tag-branch - cli-tests - build-rdp-bridge - - create-release-draft if: | always() && needs.cli-tests.result == 
'success' && needs.build-rdp-bridge.result == 'success' && - (needs.create-release-draft.result == 'success' || needs.create-release-draft.result == 'skipped') && (github.event_name == 'workflow_dispatch' || needs.validate-tag-branch.result == 'success') steps: - uses: actions/checkout@v3 @@ -218,14 +185,13 @@ jobs: mkdir -p "$target_dir" cp "/tmp/rdp-bridge-artifacts/rdp-bridge-$triple/libinfisical_rdp_bridge.a" "$target_dir/" done - - name: GoReleaser (build, no publish) + - name: GoReleaser (dry-run snapshot) + if: github.event_name == 'workflow_dispatch' && inputs.dry_run uses: goreleaser/goreleaser-action@v4 with: distribution: goreleaser-pro version: v1.26.2-pro - args: >- - release --clean --skip=publish,announce - ${{ (github.event_name == 'workflow_dispatch' && inputs.dry_run) && '--snapshot' || '' }} + args: release --clean --snapshot --skip=publish env: GITHUB_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} POSTHOG_API_KEY_FOR_CLI: ${{ secrets.POSTHOG_API_KEY_FOR_CLI }} @@ -240,6 +206,7 @@ jobs: path: dist/ retention-days: 7 - name: Smoke test linux binary across supported distros + if: github.event_name == 'workflow_dispatch' && inputs.dry_run run: | set -uo pipefail fail=0 @@ -282,13 +249,13 @@ jobs: echo "::endgroup::" [ "$fail" -eq 0 ] || exit 1 - - name: GoReleaser (publish) + - name: GoReleaser (release) if: github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run) uses: goreleaser/goreleaser-action@v4 with: distribution: goreleaser-pro version: v1.26.2-pro - args: release --skip=build,validate,before + args: release --clean env: GITHUB_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} POSTHOG_API_KEY_FOR_CLI: ${{ secrets.POSTHOG_API_KEY_FOR_CLI }} @@ -344,12 +311,10 @@ jobs: - validate-tag-branch - cli-tests - build-rdp-bridge - - create-release-draft if: | always() && needs.cli-tests.result == 'success' && needs.build-rdp-bridge.result == 'success' && - (needs.create-release-draft.result == 'success' || 
needs.create-release-draft.result == 'skipped') && (github.event_name == 'workflow_dispatch' || needs.validate-tag-branch.result == 'success') steps: - uses: actions/checkout@v4 @@ -426,11 +391,9 @@ jobs: needs: - validate-tag-branch - cli-tests - - create-release-draft if: | always() && needs.cli-tests.result == 'success' && - (needs.create-release-draft.result == 'success' || needs.create-release-draft.result == 'skipped') && (github.event_name == 'workflow_dispatch' || needs.validate-tag-branch.result == 'success') steps: - uses: actions/checkout@v3 diff --git a/.goreleaser-darwin.yaml b/.goreleaser-darwin.yaml index adfe00b1..2a6a8e96 100644 --- a/.goreleaser-darwin.yaml +++ b/.goreleaser-darwin.yaml @@ -45,9 +45,7 @@ archives: - manpages/* - completions/* -# Append to the release draft created by the ubuntu goreleaser job. release: - replace_existing_draft: false mode: append checksum: diff --git a/.goreleaser-windows.yaml b/.goreleaser-windows.yaml index 73752b9d..b945f8ab 100644 --- a/.goreleaser-windows.yaml +++ b/.goreleaser-windows.yaml @@ -1,6 +1,4 @@ -# Append to the release draft created by the ubuntu goreleaser job. release: - replace_existing_draft: false mode: append builds: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index b39dfde9..1b0a5362 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -161,10 +161,6 @@ archives: - completions/* release: - # The draft is created up front by the create-release-draft workflow - # job, so both this config and .goreleaser-darwin.yaml use append mode - # to add their artifacts in parallel. 
- replace_existing_draft: false mode: append checksum: From d422e12dd348b06d61cfb9ab8bfa1d08d8743c1f Mon Sep 17 00:00:00 2001 From: Saif Ur Rahman Date: Sat, 9 May 2026 09:04:16 +0530 Subject: [PATCH 07/10] feat(pam): add Oracle DB access support (#192) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(pam): Oracle database support via proxied-auth gateway Adds Oracle DB as the 8th PAM handler. Gateway accepts client connections with a placeholder password, proxies pre-auth bytes verbatim to upstream, intercepts at O5Logon to swap client-supplied placeholder-keyed material for real-password-keyed material, and byte-relays post-auth. Credential injection works end-to-end for JDBC thin clients (sqlcl, SQL Developer, DBeaver) and go-ora; user never sees real Oracle credentials. Handler lives in packages/pam/handlers/oracle/. Ports crypto primitives (PBKDF2+SHA512, AES-CBC session-key encryption, PKCS5 padding) and TTC codec (compressed ints, CLR byte strings, KVP encoding) from MIT-licensed sijms/go-ora. See ATTRIBUTION.md and ORACLE_PAM_NOTES.md for architecture notes and handoff. Known dead code remains in this commit from the earlier full- impersonation attempt (ano.go, nego.go, nego_templates.go, parts of o5logon*.go, upstream.go, handshake_test.go). Kept intact to preserve history of the approach; cleanup follows as a separate commit. * refactor(pam-oracle): remove impersonation-era dead code The initial attempt used server-side Oracle impersonation (see prior commit). That design worked through authentication but hit a state-mismatch problem post-auth: upstream (via go-ora) and client negotiated different TTC capabilities, so relayed queries were rejected as protocol violations. The replacement — proxied-auth — is already in proxy_auth.go and is the flow wired to HandleConnection. 
This commit removes the dead files and vestigial symbols that supported the old path: Removed entirely: - ano.go, nego.go, nego_templates.go (pre-auth TNS/TTC negotiation handlers; pre-auth is now forwarded verbatim to upstream) - upstream.go (go-ora-based upstream dial + KVP extraction; replaced by dialUpstreamRaw in proxy_auth.go) - handshake_test.go (tested the impersonation path, orphaned) Pruned: - proxy.go: handleConnectionLegacy (~200 LOC) - o5logon.go: O5LogonServerState, NewO5LogonServerState, VerifyClientPassword, deriveKey11g, md5Hash, parseIntVal - o5logon_server.go: AuthPhaseOne, ParseAuthPhaseOne, BuildAuthPhaseOneResponse, BuildAuthPhaseTwoResponse, BuildAuthPhaseTwoResponseFromUpstream, RunServerO5Logon, dumpBytes, readUint32 - tns.go: AcceptPacket, AcceptFromConnect, ConnectPacket, ParseConnectPacket, MarkerPacketBytes (we forward raw packet bytes rather than parse/build CONNECT or ACCEPT) Kept: crypto primitives, DATA packet codec, TTC reader/builder, query logger, prependedConn, error helpers — all live in the proxied-auth flow. Drops github.com/sijms/go-ora/v2 from go.mod — no longer imported. Adds .idea and .vscode to .gitignore. Updates ORACLE_PAM_NOTES.md with a current-state header; historical sections below retained for context. Net: ~1,600 LOC removed. The handler directory goes from 12 files to 8. Build, fmt, and the Oracle SQL test matrix (SELECT, INSERT, DDL, PL/SQL, bind vars, NLS queries) still pass against sqlcl → gateway → AWS RDS Oracle 19c. * fix(pam-oracle): capture SQL in session recordings The TTC query extractor was logging empty strings for OALL8 payloads because tryExtractSQL's "skip 6 compressed ints then read a CLR" heuristic didn't land on the SQL text — the OALL8 wire format has variable-length headers that differ by client driver and bind pattern. As a result, session recordings contained only session headers (124 bytes) with no actual query content. 
Replace structured parsing with a simple scan for the longest printable ASCII run in the payload. In practice the SQL text is always the longest such run. Verified with sqlcl: .enc file grows from 124 bytes (empty) to ~880 bytes (with captured queries) for a SELECT + COMMIT session. This only affects content — the tap, packet demultiplexing, and encrypted file I/O were all working correctly. Fix is localised to tryExtractSQL. * fix(pam-oracle): capture queries bundled behind piggyback OCLOSE sqlcl (and other JDBC thin clients) frequently bundle a piggybacked OCLOSE for the previous cursor with the next OALL8 query in a single TTC packet. The previous parser checked byte 0 for 0x03 (function call) and bailed when it saw 0x11 (piggyback marker), missing the OALL8 underneath. Scan the payload for the function-call + OALL8 byte pair instead, so the parser finds the query regardless of any preceding piggyback prefix. Same treatment for COMMIT and ROLLBACK, which also get piggybacked. * feat(pam-oracle): TLS (TCPS) support for upstream Matches the other SQL handlers' pattern: the client speaks plain TCP to our local listener, we do TLS to the upstream database and translate in the middle. No change to the client-facing UX — the same JDBC URL that works against a plain-TCP Oracle resource now also works against a TCPS-enabled one. Implementing this correctly required discovering Oracle TCPS's two-handshake flow (by reading go-ora's network/session.go readPacket RESEND branch — credit there): 1. Dial TCP, do an initial TLS handshake. Forward the client's CONNECT through this first TLS session. 2. When the upstream RESEND packet's byte-5 flag has 0x08 set, Oracle expects the client to abandon the first TLS session and run a FRESH TLS handshake on the bare TCP socket. Server drops its first-round TLS state in lockstep. We mirror this via upgradeToTLS on the same rawConn, then continue the Oracle handshake through the new session. 3. 
Mask byte 5 on packets going downstream so thin clients (JDBC thin, python-oracledb thin) don't see the 0x08 signal and try to cast their local TcpNTAdapter to TcpsNTAdapter — the cast would fail because the client-to-proxy socket is plain TCP. 4. Accept TLS 1.0 as the floor for the upstream dial: Oracle 19c's second-round handshake negotiates down to 1.0 in some configurations (AWS RDS's SSL option being one of them). The outer ALPN mTLS tunnel remains TLS 1.2+. Also removes now-dead SSLRejectUnauthorized / SSLCertificate fields from OracleProxyConfig — the shared TLSConfig built in pam-proxy.go carries that information already. Verified end-to-end against AWS RDS Oracle 19c (SSL option, port 2484) with sqlcl: authentication, DDL, DML+COMMIT, PL/SQL, DBMS_OUTPUT, bind variables, and session recording all work. Plain-TCP Oracle path is unchanged. * chore(pam-oracle): tighten TCPS TLS config comment Pure comment cleanup on buildOracleTLSConfig — no behavior change. * chore(pam-oracle): update stale comments + attribution - ATTRIBUTION.md: drop reference to nego.go (deleted in the impersonation-era cleanup) and add the upstream TCPS two-handshake flow adaptation from go-ora's session.readPacket RESEND branch. - o5logon_server.go: the file header still described the impersonation-era architecture where the gateway acted as an Oracle server and drove the O5Logon exchange. Rewrite it to describe the current proxied-auth role — packet-layer helpers used by the byte-level O5Logon translation in proxy_auth.go — and drop the reference to upstream.go, which was removed in 3ff9cff. * chore(pam-oracle): fix go vet warnings in proxy_auth.go Two warnings, both from the original feat commit (4b25ec2): - Dead self-assignment state.ServerSessKey = state.ServerSessKey on the phase-2 request translation path. The inline comment "no-op; kept for clarity" already admitted it was pointless. Delete. - Unreachable _ = prefix after a return statement inside replaceKvpValueKeepingSize. 
The prefix variable was left over from a refactor — the new code uses oldStart / oldEnd instead — and the _ = prefix trick to silence "declared but not used" landed on the wrong side of the return. Delete both the dead prefix declaration and the unreachable suppression. Also drops valStart (only used by the dead prefix computation). No behavior change; go vet is now clean on the oracle handler. * chore(pam-oracle): drop TNS_ADMIN dance, shorten placeholder password Two UX simplifications to the Oracle PAM access command: 1. Stop creating a per-session TNS_ADMIN directory and printing an `export TNS_ADMIN=...` line in the connect instructions. The directory only existed to hold a sqlnet.ora that set DISABLE_OOB=TRUE — a defence against sqlcl's out-of-band Ctrl-C signalling that we never actually observed breaking. If a real interrupt problem surfaces we can revisit with a proper fix instead of a per-session file dance. 2. Change ProxyPasswordPlaceholder from "infisical-pam-proxy" to "password". The string value is cryptographically arbitrary — Oracle's O5Logon needs the client and gateway to agree on SOME string; any works. Shorter is easier to copy-paste. The accompanying "not a real credential" note in the CLI output stays. * chore(pam-oracle): drop unrelated changes from branch Three items that snuck in but aren't part of Oracle PAM: 1. Untrack ORACLE_PAM_NOTES.md. It's a development notes/research log that belongs with scratch work, not in the branch. The file stays on disk (uncommitted) for local reference. Its content is also stale since the placeholder password changed from "infisical-pam-proxy" to "password" in 3ee7951 without matching updates to the notes. 2. Revert the github.com/emirpasic/gods indirect bump from v1.12.0 to v1.18.1 in go.mod/go.sum. 
This was residue from when go-ora/v2 was temporarily added as a direct dep in the initial feat commit — go-ora needed the newer gods, and Go module MVS held the bumped version even after go-ora was removed in 3ff9cff. Running `go mod tidy` against current HEAD (with main's go.mod/go.sum as the baseline) produces no further changes, confirming nothing we ship actually needs v1.18.1. .gitignore additions (.vscode, .idea) stay — reasonable hygiene that doesn't hurt the PR and will prevent future editor-file noise. * fix(pam-oracle): AI-review findings + username parity with other handlers - Password plaintext no longer embedded in "password mismatch" error (was bubbling to gateway logs via zerolog's .Err chain). - Long Oracle passwords (≥ 96 chars) now encode correctly: replaceKVPValue routes AUTH_PASSWORD through TTCBuilder.PutClr so values above the short-form threshold emit the 0xFE chunked form instead of a truncated single-byte length. - Client-supplied username rewritten to InjectUsername in phase-1 and phase-2 auth requests. Matches the effect of how the postgres/mysql/mssql handlers overwrite the startup-packet user — the client's choice becomes inert; upstream always looks up the configured account's verifier. - Dead / misleading code removed: local PacketTypeResendMarker constant that duplicated tns.go's PacketTypeResend, package-level min() shadowing the Go 1.21+ builtin, the if !use32Bit branch in extractDataPayload where both arms assigned the same value, and the now-unused encodeCompressedInt helper. * chore(pam-oracle): remove dead code + tighten attribution Static-analysis sweep (staticcheck + manual cross-file grep) across the oracle handler package. All removals are symbols that were defined but never referenced anywhere: - tns.go: markerTypeReset, markerTypeInterrupt constants. 
- o5logon_server.go: TTCMsgAuthResponse, TTCMsgBreak, LogonModeUserAndPass, LogonModeNoNewPass constants; AuthPhaseTwo fields ESpeedyKey / AlterSession / ClientInfo / LogonMode (parser wrote them, no reader downstream); ParseAuthPhaseTwo's KVP switch trimmed to the two keys actually consumed (AUTH_SESSKEY, AUTH_PASSWORD). - ttc.go: TTCReader.GetNullTermString(), TTCReader.SetUseBigClrChunks() — both uncalled. Also added a proper attribution header to o5logon_server.go (it was adapting go-ora's phase-2 layout + summary-object format but lacked the same kind of header the other ported files have), and expanded ATTRIBUTION.md to cover it plus the specific primitives borrowed by o5logon.go. No behavior change. go build / go vet / staticcheck clean on the package. * chore(pam-oracle): remove dead verifier-type and error-code constants Exhaustive dead-export sweep via `go doc -all` (a more reliable check than my earlier regex-based grep, which missed untyped constants). Removed: - VerifierType10g / VerifierType11g / VerifierType12c — defined but no callers; the only verifier type our code implements (18453, 12c+ PBKDF2+SHA512) is hardcoded, the three named constants were never referenced. The 18453-specific comments in the code retain the documentation. - ORA12660EncryptionRequired — defined but no callers. Only ORA1017InvalidCredentials is actually emitted. Post-sweep: every exported symbol in the oracle handler package has a call site. `go build` / `go vet` / `staticcheck` clean. * chore(pam-oracle): remove dead enum-block const members staticcheck's U1000 has an exemption for const blocks: if any member is used, sibling members aren't flagged as unused even when they are. An exhaustive sweep that enumerates every top-level identifier via manual grep (so const-block membership is irrelevant) caught the ones staticcheck skipped. Removed: - tns.go: PacketTypeAbort, PacketTypeAck, PacketTypeAttn, PacketTypeCtrl, PacketTypeNull. 
Only the 7 PacketType values we actually dispatch on remain. - query_logger.go: ttcFuncOFETCH, ttcFuncOCLOSE, ttcFuncOSTMT, ttcFuncOLOGOFF, ttcMsgPiggyback. Only the 4 TTC opcodes the query tap actually looks for remain (OALL8, OCOMMIT, ORLLBK, the outer msgFunction). Post-sweep: exhaustive enumeration finds zero unused symbols across the package (189 candidates checked). go build / go vet / staticcheck clean. * fix(pam-oracle): show Infisical account name (not real DB user) in the banner The banner was printing the real upstream DB username (pamResponse.Metadata["username"]) in the connection URL, even though the preceding "Account:" label already shows the Infisical account name. Since the gateway now rewrites the client-supplied username in the O5Logon exchange to the configured real user, the client can (and should) connect using the account name — and the banner makes that explicit. Before: Resource: aws-oracledb Account: admin2 oracle://admin:password@localhost:53521/DATABASE ← confusing After: Resource: aws-oracledb Account: admin2 oracle://admin2:password@localhost:53521/DATABASE ← matches the label Scope: Oracle only for now. The postgres/mysql/mssql handlers also overwrite the client username on the wire, so the same banner change would work there, but that needs a separate verification pass per dialect before we extend it. * fix(pam-oracle): inject SERVICE_NAME from config into CONNECT packet The client's CONNECT description string was forwarded unchanged, requiring users to know the real Oracle service name. Now we rewrite SERVICE_NAME to match InjectDatabase from the vault config, consistent with how username and password are already injected. * fix(pam-oracle): remove placeholder password verification in phase-2 The gateway no longer checks whether the client sent the placeholder password "password". It unconditionally encrypts the real password from the vault, regardless of what the client typed. 
Auth will succeed either way since we inject the real credentials. The placeholder password is still shown to the user in the CLI banner and still used for key derivation in phase-1 — only the verification check is removed. * fix(pam-oracle): forward phase-2 response directly, remove SVR_RESPONSE regen AUTH_SVR_RESPONSE is encrypted with encKey, which is derived from session keys + CSK salt — not the password. The client and Oracle derive the same encKey, so Oracle's original proof is already valid for the client. No need to regenerate it. Removes translatePhase2Response, BuildSvrResponse, and the placeholderEncKey field from ProxyAuthState. * fix(pam-oracle): replace 3s timeout peek with deterministic supplement drain The post-ACCEPT supplement peek used a 3-second read deadline to detect whether a go-ora client sent connect-data as a separate packet. This added a fixed 3s delay for every non-go-ora client (sqlcl, JDBC thin). Instead, check the CONNECT packet structure: if connect-data-length + connect-data-offset exceeds the packet size, the data wasn't inline and a supplement will follow. Track whether the RESEND handler already consumed it; if not, do a blocking read (the supplement is guaranteed to be in the TCP buffer since the client sent it before waiting for a response). Removes prependedConn, detectConnectDataSupplement, and the timeout. 
* chore(pam-oracle): strip comments, enforce InjectDatabase, fix ora() formatting - Remove verbose comments across all files (~20% → ~2% comment rate) - Remove per-file go-ora attribution (ATTRIBUTION.md carries the license) - Trim ATTRIBUTION.md to just the copyright notice + MIT text - Make InjectDatabase mandatory (error if empty, always overwrite client's SERVICE_NAME) - Format unknown Oracle error codes as ORA-XXXXX instead of bare "ERROR" - Clarify ProxyPasswordPlaceholder as a decoy * chore(pam-oracle): remove redundant banner comment and auth note * fix(pam-oracle): forward connect-data supplement before handshake loop JDBC thin sends the CONNECT description as a separate follow-up packet when it doesn't fit inline (74-byte header-only CONNECT). Previously we drained this supplement only after ACCEPT, relying on the server to RESEND. AWS RDS sends RESEND, but OCI Autonomous DB does not — it waits for the connect-data and eventually drops the connection. Fix: check the CONNECT packet for non-inline data (connect-data-offset + connect-data-length > packet size) and forward the supplement immediately. Also forward it again after a RESEND+TLS-restart, since the client re-sends CONNECT + supplement over the new TLS session. * fix(pam-oracle): restore placeholder password check for clearer errors Wrong password fails cryptographically anyway (ORA-17452 from garbage decryption), but the error is confusing. Restore the explicit check so the client gets a clean ORA-01017 instead. 
* fix(pam-oracle): show real DB username in banner, matching other handlers * chore(pam-oracle): update ProxyPasswordPlaceholder comment * chore(pam-oracle): simplify ProxyPasswordPlaceholder comment * chore(pam-oracle): rename ResourceTypeOracle to ResourceTypeOracledb * fix(pam-oracle): cap handshake loop, zero keys after auth, remove dead comment - Add maxHandshakeAttempts (10) to prevent infinite RESEND loops - Zero RealKey, PlaceholderKey, ServerSessKey after auth completes - Remove leftover debug comment from proxy_auth.go * chore(pam-oracle): simplify key cleanup to state = nil * fix(pam-oracle): show Easy Connect and JDBC connection strings in banner --------- Co-authored-by: saif <11242541+saifsmailbox98@users.noreply.github.com> --- .gitignore | 2 +- packages/pam/handlers/oracle/ATTRIBUTION.md | 23 + packages/pam/handlers/oracle/constants.go | 4 + packages/pam/handlers/oracle/o5logon.go | 108 +++ .../pam/handlers/oracle/o5logon_server.go | 163 ++++ packages/pam/handlers/oracle/proxy.go | 64 ++ packages/pam/handlers/oracle/proxy_auth.go | 830 ++++++++++++++++++ packages/pam/handlers/oracle/query_logger.go | 276 ++++++ packages/pam/handlers/oracle/tns.go | 115 +++ packages/pam/handlers/oracle/ttc.go | 302 +++++++ packages/pam/local/database-proxy.go | 5 + packages/pam/pam-proxy.go | 20 + packages/pam/session/uploader.go | 5 +- 13 files changed, 1914 insertions(+), 3 deletions(-) create mode 100644 packages/pam/handlers/oracle/ATTRIBUTION.md create mode 100644 packages/pam/handlers/oracle/constants.go create mode 100644 packages/pam/handlers/oracle/o5logon.go create mode 100644 packages/pam/handlers/oracle/o5logon_server.go create mode 100644 packages/pam/handlers/oracle/proxy.go create mode 100644 packages/pam/handlers/oracle/proxy_auth.go create mode 100644 packages/pam/handlers/oracle/query_logger.go create mode 100644 packages/pam/handlers/oracle/tns.go create mode 100644 packages/pam/handlers/oracle/ttc.go diff --git a/.gitignore b/.gitignore index 
2891130d..c654c50c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,5 +14,5 @@ infisical /agent-testing .vscode/ -# PAM CLI session artifacts (local testing only) +.idea/ /session/ diff --git a/packages/pam/handlers/oracle/ATTRIBUTION.md b/packages/pam/handlers/oracle/ATTRIBUTION.md new file mode 100644 index 00000000..8a3a8c60 --- /dev/null +++ b/packages/pam/handlers/oracle/ATTRIBUTION.md @@ -0,0 +1,23 @@ +This package contains code adapted from [sijms/go-ora](https://github.com/sijms/go-ora). + +MIT License + +Copyright (c) 2020 Samy Sultan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/packages/pam/handlers/oracle/constants.go b/packages/pam/handlers/oracle/constants.go new file mode 100644 index 00000000..7d1b3541 --- /dev/null +++ b/packages/pam/handlers/oracle/constants.go @@ -0,0 +1,4 @@ +package oracle + +// Must be passed by the client; also used for O5Logon key derivation in phase 1. 
+const ProxyPasswordPlaceholder = "password" diff --git a/packages/pam/handlers/oracle/o5logon.go b/packages/pam/handlers/oracle/o5logon.go new file mode 100644 index 00000000..ca66ea51 --- /dev/null +++ b/packages/pam/handlers/oracle/o5logon.go @@ -0,0 +1,108 @@ +package oracle + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha512" + "encoding/hex" + "fmt" +) + +const ( + ORA1017InvalidCredentials = 1017 +) + +func PKCS5Padding(cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(cipherText, padtext...) +} + +func generateSpeedyKey(buffer, key []byte, turns int) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(append(buffer, 0, 0, 0, 1)) + firstHash := mac.Sum(nil) + tempHash := make([]byte, len(firstHash)) + copy(tempHash, firstHash) + for index1 := 2; index1 <= turns; index1++ { + mac.Reset() + mac.Write(tempHash) + tempHash = mac.Sum(nil) + for index2 := 0; index2 < 64; index2++ { + firstHash[index2] = firstHash[index2] ^ tempHash[index2] + } + } + return firstHash +} + +func decryptSessionKey(padding bool, encKey []byte, sessionKeyHex string) ([]byte, error) { + result, err := hex.DecodeString(sessionKeyHex) + if err != nil { + return nil, err + } + blk, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + dec := cipher.NewCBCDecrypter(blk, make([]byte, 16)) + output := make([]byte, len(result)) + dec.CryptBlocks(output, result) + cutLen := 0 + if padding { + num := int(output[len(output)-1]) + if num < dec.BlockSize() { + apply := true + for x := len(output) - num; x < len(output); x++ { + if output[x] != uint8(num) { + apply = false + break + } + } + if apply { + cutLen = int(output[len(output)-1]) + } + } + } + return output[:len(output)-cutLen], nil +} + +func encryptSessionKey(padding bool, encKey []byte, sessionKey []byte) (string, error) { + blk, err := 
aes.NewCipher(encKey) + if err != nil { + return "", err + } + enc := cipher.NewCBCEncrypter(blk, make([]byte, 16)) + originalLen := len(sessionKey) + sessionKey = PKCS5Padding(sessionKey, blk.BlockSize()) + output := make([]byte, len(sessionKey)) + enc.CryptBlocks(output, sessionKey) + if !padding { + return fmt.Sprintf("%X", output[:originalLen]), nil + } + return fmt.Sprintf("%X", output), nil +} + +func encryptPassword(password, key []byte, padding bool) (string, error) { + buff1 := make([]byte, 0x10) + if _, err := rand.Read(buff1); err != nil { + return "", err + } + buffer := append(buff1, password...) + return encryptSessionKey(padding, key, buffer) +} + +func deriveServerKey(password string, salt []byte, vGenCount int) (key []byte, speedy []byte, err error) { + message := append([]byte(nil), salt...) + message = append(message, []byte("AUTH_PBKDF2_SPEEDY_KEY")...) + speedy = generateSpeedyKey(message, []byte(password), vGenCount) + + buffer := append([]byte(nil), speedy...) + buffer = append(buffer, salt...) 
+ h := sha512.New() + h.Write(buffer) + key = h.Sum(nil)[:32] + return +} diff --git a/packages/pam/handlers/oracle/o5logon_server.go b/packages/pam/handlers/oracle/o5logon_server.go new file mode 100644 index 00000000..7df71436 --- /dev/null +++ b/packages/pam/handlers/oracle/o5logon_server.go @@ -0,0 +1,163 @@ +package oracle + +import ( + "fmt" + "net" +) + +const ( + TTCMsgAuthRequest = 0x03 + TTCMsgError = 0x04 +) + +const ( + AuthSubOpPhaseOne = 0x76 + AuthSubOpPhaseTwo = 0x73 +) + +type AuthPhaseTwo struct { + EClientSessKey string + EPassword string +} + +func readDataPayload(conn net.Conn, use32BitLen bool) ([]byte, error) { + raw, err := ReadFullPacket(conn, use32BitLen) + if err != nil { + return nil, err + } + if PacketTypeOf(raw) == PacketTypeMarker { + return readDataPayload(conn, use32BitLen) + } + if PacketTypeOf(raw) != PacketTypeData { + return nil, fmt.Errorf("expected DATA packet, got type=%d", raw[4]) + } + pkt, err := ParseDataPacket(raw, use32BitLen) + if err != nil { + return nil, err + } + return pkt.Payload, nil +} + +func writeDataPayload(conn net.Conn, payload []byte, use32BitLen bool) error { + d := &DataPacket{Payload: payload} + _, err := conn.Write(d.Bytes(use32BitLen)) + return err +} + +func ParseAuthPhaseTwo(payload []byte) (*AuthPhaseTwo, error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, err + } + if op != TTCMsgAuthRequest { + return nil, fmt.Errorf("phase2 unexpected opcode 0x%02X", op) + } + sub, err := r.GetByte() + if err != nil { + return nil, err + } + if sub != AuthSubOpPhaseTwo { + return nil, fmt.Errorf("phase2 unexpected sub-op 0x%02X", sub) + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + + out := &AuthPhaseTwo{} + + hasUser, err := r.GetByte() + if err != nil { + return nil, err + } + var userLen int + if hasUser == 1 { + userLen, err = r.GetInt(4, true, true) + if err != nil { + return nil, err + } + } else { + if _, err := r.GetByte(); err != nil { + 
return nil, err + } + } + + if _, err := r.GetInt(4, true, true); err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + count, err := r.GetInt(4, true, true) + if err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + if hasUser == 1 && userLen > 0 { + // go-ora prefixes username with CLR length byte; JDBC thin sends it raw. + peek, perr := r.PeekByte() + if perr != nil { + return nil, fmt.Errorf("peek phase2 username: %w", perr) + } + if int(peek) == userLen && peek < 0x20 { + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("consume phase2 username length prefix: %w", err) + } + } + if _, err := r.GetBytes(userLen); err != nil { + return nil, fmt.Errorf("read phase2 username bytes: %w", err) + } + } + + for i := 0; i < count; i++ { + k, v, _, err := r.GetKeyVal() + if err != nil { + return nil, fmt.Errorf("phase2 KVP #%d: %w", i, err) + } + switch string(k) { + case "AUTH_SESSKEY": + out.EClientSessKey = string(v) + case "AUTH_PASSWORD": + out.EPassword = string(v) + } + } + return out, nil +} + +func BuildErrorPacket(oraCode int, message string) []byte { + b := NewTTCBuilder() + b.PutBytes(TTCMsgError) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(int64(oraCode), 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutString(message) + b.PutInt(0, 2, true, true) + return b.Bytes() +} + +func 
WriteErrorToClient(conn net.Conn, oraCode int, message string, use32BitLen bool) error { + return writeDataPayload(conn, BuildErrorPacket(oraCode, message), use32BitLen) +} diff --git a/packages/pam/handlers/oracle/proxy.go b/packages/pam/handlers/oracle/proxy.go new file mode 100644 index 00000000..71916b6a --- /dev/null +++ b/packages/pam/handlers/oracle/proxy.go @@ -0,0 +1,64 @@ +package oracle + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/Infisical/infisical-merge/packages/pam/session" +) + +type OracleProxyConfig struct { + TargetAddr string + InjectUsername string + InjectPassword string + InjectDatabase string + EnableTLS bool + TLSConfig *tls.Config + SessionID string + SessionLogger session.SessionLogger +} + +type OracleProxy struct { + config OracleProxyConfig +} + +func NewOracleProxy(config OracleProxyConfig) *OracleProxy { + return &OracleProxy{config: config} +} + +func (p *OracleProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { + return p.handleConnectionProxied(ctx, clientConn) +} + +func relayWithTap(src, dst net.Conn, tap *QueryExtractor, errCh chan<- error) { + buf := make([]byte, 32*1024) + for { + n, err := src.Read(buf) + if n > 0 { + if _, werr := dst.Write(buf[:n]); werr != nil { + errCh <- werr + return + } + tap.Feed(buf[:n]) + } + if err != nil { + errCh <- err + return + } + } +} + +func splitHostPort(addr string) (string, int, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", 0, err + } + var port int + _, err = fmt.Sscanf(portStr, "%d", &port) + if err != nil { + return "", 0, fmt.Errorf("bad port %q: %w", portStr, err) + } + return host, port, nil +} diff --git a/packages/pam/handlers/oracle/proxy_auth.go b/packages/pam/handlers/oracle/proxy_auth.go new file mode 100644 index 00000000..51caa643 --- /dev/null +++ b/packages/pam/handlers/oracle/proxy_auth.go @@ -0,0 +1,830 @@ +package oracle + +import ( + "bytes" + "context" + "crypto/tls" + 
"encoding/binary" + "encoding/hex" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/rs/zerolog/log" +) + +func (p *OracleProxy) handleConnectionProxied(ctx context.Context, clientConn net.Conn) error { + defer clientConn.Close() + defer func() { + if err := p.config.SessionLogger.Close(); err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to close session logger") + } + }() + + log.Info().Str("sessionID", p.config.SessionID).Str("target", p.config.TargetAddr).Msg("Oracle PAM session started (proxied auth)") + + // 1. Dial upstream (keep raw TCP ref — TCPS may need a second TLS handshake mid-flow). + rawUpstream, tlsUpstream, err := dialUpstreamRaw(ctx, p.config) + if err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to dial Oracle upstream") + _ = WriteRefuseToClient(clientConn, "(DESCRIPTION=(ERR=12564)(VSNNUM=0)(ERROR_STACK=(ERROR=(CODE=12564)(EMFI=4))))") + return fmt.Errorf("upstream dial: %w", err) + } + var upstreamConn net.Conn + if tlsUpstream != nil { + upstreamConn = tlsUpstream + } else { + upstreamConn = rawUpstream + } + defer func() { upstreamConn.Close() }() + + // 2. Read client's CONNECT, rewrite SERVICE_NAME, forward to upstream. + connectRaw, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read client CONNECT: %w", err) + } + if PacketTypeOf(connectRaw) != PacketTypeConnect { + return fmt.Errorf("expected CONNECT, got type=%d", connectRaw[4]) + } + if p.config.InjectDatabase == "" { + return fmt.Errorf("InjectDatabase (service name) is required but empty") + } + connectRaw = rewriteConnectServiceName(connectRaw, p.config.InjectDatabase) + if _, err := upstreamConn.Write(connectRaw); err != nil { + return fmt.Errorf("forward CONNECT: %w", err) + } + + // If connect-data wasn't inline (JDBC thin / go-ora with long descriptions), + // the client already sent it as a follow-up packet. 
Forward it now — some + // Oracle listeners (e.g., OCI Autonomous DB) won't RESEND, they just wait. + if len(connectRaw) >= 28 { + cdLen := int(binary.BigEndian.Uint16(connectRaw[24:26])) + cdOff := int(binary.BigEndian.Uint16(connectRaw[26:28])) + if cdLen > 0 && cdOff+cdLen > len(connectRaw) { + supplement, serr := ReadFullPacket(clientConn, false) + if serr != nil { + return fmt.Errorf("read connect-data supplement: %w", serr) + } + log.Info().Str("sessionID", p.config.SessionID).Int("supplementLen", len(supplement)).Msg("Proxy: forwarding connect-data supplement before handshake") + if _, werr := upstreamConn.Write(supplement); werr != nil { + return fmt.Errorf("forward connect-data supplement: %w", werr) + } + } + } + + // 3. Read upstream responses until ACCEPT. Handle RESEND (TLS restart). + const maxHandshakeAttempts = 10 + var acceptRaw []byte + for attempt := 0; acceptRaw == nil; attempt++ { + if attempt >= maxHandshakeAttempts { + return fmt.Errorf("too many handshake packets (%d) without ACCEPT", attempt) + } + pkt, err := ReadFullPacket(upstreamConn, false) + if err != nil { + return fmt.Errorf("read upstream handshake packet (attempt %d): %w", attempt, err) + } + pktType := PacketTypeOf(pkt) + var origFlag byte + if len(pkt) > 5 { + origFlag = pkt[5] + } + log.Info().Str("sessionID", p.config.SessionID).Uint8("pktType", uint8(pktType)).Int("pktLen", len(pkt)).Uint8("flag", origFlag).Msg("Proxy: upstream handshake packet") + + // RESEND flag 0x08: tear down current TLS, do a fresh handshake on the raw socket. 
+ if p.config.EnableTLS && pktType == PacketTypeResend && origFlag&0x08 != 0 { + tc, terr := upgradeToTLS(ctx, rawUpstream, p.config) + if terr != nil { + return fmt.Errorf("upstream TLS upgrade after RESEND(flag=0x08): %w", terr) + } + upstreamConn = tc + log.Info().Str("sessionID", p.config.SessionID).Str("tlsVersion", tlsVersionString(tc.ConnectionState().Version)).Str("cipher", tls.CipherSuiteName(tc.ConnectionState().CipherSuite)).Msg("Proxy: upstream TLS re-handshook on RESEND(flag=0x08)") + } + + // Mask byte 5 so thin clients don't try TLS upgrade on plain TCP. + if p.config.EnableTLS && len(pkt) > 5 { + pkt[5] = 0x00 + } + if _, werr := clientConn.Write(pkt); werr != nil { + return fmt.Errorf("forward upstream handshake packet: %w", werr) + } + switch pktType { + case PacketTypeAccept: + acceptRaw = pkt + case PacketTypeRefuse: + return fmt.Errorf("upstream REFUSE during handshake") + case PacketTypeRedirect: + return fmt.Errorf("upstream REDIRECT during handshake (not supported)") + case PacketTypeResend: + clientPkt, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read client response after RESEND: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("len", len(clientPkt)).Uint8("type", uint8(PacketTypeOf(clientPkt))).Msg("Proxy: forwarding client response after RESEND") + if _, werr := upstreamConn.Write(clientPkt); werr != nil { + return fmt.Errorf("forward client response after RESEND: %w", werr) + } + // Client may re-send CONNECT with non-inline connect-data — forward + // the supplement too, same as we did before the handshake loop. 
+ if PacketTypeOf(clientPkt) == PacketTypeConnect && len(clientPkt) >= 28 { + cdLen := int(binary.BigEndian.Uint16(clientPkt[24:26])) + cdOff := int(binary.BigEndian.Uint16(clientPkt[26:28])) + if cdLen > 0 && cdOff+cdLen > len(clientPkt) { + supp, serr := ReadFullPacket(clientConn, false) + if serr != nil { + return fmt.Errorf("read connect-data supplement after RESEND: %w", serr) + } + log.Info().Str("sessionID", p.config.SessionID).Int("supplementLen", len(supp)).Msg("Proxy: forwarding connect-data supplement after RESEND") + if _, werr := upstreamConn.Write(supp); werr != nil { + return fmt.Errorf("forward connect-data supplement after RESEND: %w", werr) + } + } + } + } + } + + var acceptVersion uint16 + if len(acceptRaw) >= 10 { + acceptVersion = binary.BigEndian.Uint16(acceptRaw[8:10]) + } + use32Bit := acceptVersion >= 315 + log.Info().Str("sessionID", p.config.SessionID).Uint16("acceptVersion", acceptVersion).Bool("use32Bit", use32Bit).Msg("Proxy: ACCEPT forwarded") + + + + p1Payload, err := proxyUntilAuthRequest(clientConn, upstreamConn, use32Bit, p.config.SessionID) + if err != nil { + return fmt.Errorf("pre-auth proxy: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("p1Len", len(p1Payload)).Msg("Proxy: auth-request boundary reached") + + p1Forward := p1Payload + if p.config.InjectUsername != "" { + rewritten, rerr := rewritePhase1User(p1Payload, p.config.InjectUsername) + if rerr != nil { + return fmt.Errorf("rewrite phase 1 username: %w", rerr) + } + p1Forward = rewritten + } + if err := writeDataPayload(upstreamConn, p1Forward, use32Bit); err != nil { + return fmt.Errorf("forward phase 1 request: %w", err) + } + + p1RespUpstream, err := readDataPayload(upstreamConn, use32Bit) + if err != nil { + return fmt.Errorf("read upstream phase 1 response: %w", err) + } + state, p1RespTranslated, err := translatePhase1Response(p1RespUpstream, p.config.InjectPassword) + if err != nil { + _ = WriteErrorToClient(clientConn, 
ORA1017InvalidCredentials, "ORA-01017: invalid username/password; logon denied", use32Bit) + return fmt.Errorf("translate phase 1 response: %w", err) + } + if err := writeDataPayload(clientConn, p1RespTranslated, use32Bit); err != nil { + return fmt.Errorf("write translated phase 1 response: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-1 response translated and forwarded") + + p2ReqClient, err := readDataPayload(clientConn, use32Bit) + if err != nil { + return fmt.Errorf("read client phase 2 request: %w", err) + } + p2ReqTranslated, err := translatePhase2Request(p2ReqClient, state, p.config.InjectPassword) + if err != nil { + _ = WriteErrorToClient(clientConn, ORA1017InvalidCredentials, "ORA-01017: invalid username/password; logon denied", use32Bit) + return fmt.Errorf("translate phase 2 request: %w", err) + } + // Oracle cross-checks phase-2 username against phase-1. + if p.config.InjectUsername != "" { + rewritten, rerr := rewritePhase2User(p2ReqTranslated, p.config.InjectUsername) + if rerr != nil { + return fmt.Errorf("rewrite phase 2 username: %w", rerr) + } + p2ReqTranslated = rewritten + } + if err := writeDataPayload(upstreamConn, p2ReqTranslated, use32Bit); err != nil { + return fmt.Errorf("forward phase 2 request: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-2 request translated and forwarded") + + // AUTH_SVR_RESPONSE is keyed on session material (not password) — forward unchanged. 
+ p2RespRaw, err := ReadFullPacket(upstreamConn, use32Bit) + if err != nil { + return fmt.Errorf("read upstream phase 2 response: %w", err) + } + if _, err := clientConn.Write(p2RespRaw); err != nil { + return fmt.Errorf("forward phase 2 response: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-2 response forwarded; client authenticated") + + state = nil + + c2u, u2c := NewQueryExtractorPair(p.config.SessionLogger, p.config.SessionID, use32Bit) + defer c2u.Stop() + defer u2c.Stop() + + errCh := make(chan error, 2) + go relayWithTap(clientConn, upstreamConn, c2u, errCh) + go relayWithTap(upstreamConn, clientConn, u2c, errCh) + + select { + case rerr := <-errCh: + if rerr != nil && rerr != io.EOF { + log.Debug().Err(rerr).Str("sessionID", p.config.SessionID).Msg("Oracle relay ended") + } + case <-ctx.Done(): + log.Info().Str("sessionID", p.config.SessionID).Msg("Oracle session cancelled by context") + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Oracle PAM session ended") + return nil +} + +// Includes legacy RSA-CBC suites needed by Oracle 19c / AWS RDS. +var oracleUpstreamCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, +} + +// TLS 1.0–1.2 only: Oracle TCPS has no TLS-1.3 restart mechanism; RDS negotiates down to 1.0. 
+func buildOracleTLSConfig(base *tls.Config, host string) *tls.Config { + cfg := base.Clone() + if cfg.ServerName == "" { + cfg.ServerName = host + } + cfg.MinVersion = tls.VersionTLS10 + cfg.MaxVersion = tls.VersionTLS12 + cfg.CipherSuites = oracleUpstreamCiphers + return cfg +} + +func dialUpstreamRaw(ctx context.Context, cfg OracleProxyConfig) (rawConn net.Conn, tlsConn *tls.Conn, err error) { + host, _, err := splitHostPort(cfg.TargetAddr) + if err != nil { + return nil, nil, fmt.Errorf("invalid target addr: %w", err) + } + d := &net.Dialer{Timeout: 15 * time.Second} + rawConn, err = d.DialContext(ctx, "tcp", cfg.TargetAddr) + if err != nil { + return nil, nil, err + } + if !cfg.EnableTLS { + return rawConn, nil, nil + } + if cfg.TLSConfig == nil { + rawConn.Close() + return nil, nil, fmt.Errorf("upstream TLS requested but no TLSConfig provided") + } + tlsCfg := buildOracleTLSConfig(cfg.TLSConfig, host) + tc := tls.Client(rawConn, tlsCfg) + if err := tc.HandshakeContext(ctx); err != nil { + rawConn.Close() + return nil, nil, fmt.Errorf("upstream TLS handshake: %w", err) + } + return rawConn, tc, nil +} + +func tlsVersionString(v uint16) string { + switch v { + case tls.VersionTLS10: + return "TLS1.0" + case tls.VersionTLS11: + return "TLS1.1" + case tls.VersionTLS12: + return "TLS1.2" + case tls.VersionTLS13: + return "TLS1.3" + default: + return fmt.Sprintf("0x%04x", v) + } +} + +func upgradeToTLS(ctx context.Context, rawConn net.Conn, cfg OracleProxyConfig) (*tls.Conn, error) { + host, _, err := splitHostPort(cfg.TargetAddr) + if err != nil { + return nil, fmt.Errorf("invalid target addr: %w", err) + } + tlsCfg := buildOracleTLSConfig(cfg.TLSConfig, host) + tc := tls.Client(rawConn, tlsCfg) + if err := tc.HandshakeContext(ctx); err != nil { + return nil, fmt.Errorf("upstream TLS handshake: %w", err) + } + return tc, nil +} + +func proxyUntilAuthRequest(client, upstream net.Conn, use32Bit bool, sessionID string) ([]byte, error) { + type result struct { + 
payload []byte + err error + } + done := make(chan result, 2) + stop := make(chan struct{}) + + go func() { + for { + select { + case <-stop: + return + default: + } + pkt, err := ReadFullPacket(upstream, use32Bit) + if err != nil { + select { + case done <- result{err: fmt.Errorf("read upstream: %w", err)}: + default: + } + return + } + if _, werr := client.Write(pkt); werr != nil { + select { + case done <- result{err: fmt.Errorf("write client: %w", werr)}: + default: + } + return + } + log.Debug().Str("sessionID", sessionID).Uint8("type", uint8(PacketTypeOf(pkt))).Int("len", len(pkt)).Msg("Proxy pre-auth: upstream → client") + } + }() + + go func() { + for { + select { + case <-stop: + return + default: + } + pkt, err := ReadFullPacket(client, use32Bit) + if err != nil { + select { + case done <- result{err: fmt.Errorf("read client: %w", err)}: + default: + } + return + } + pktType := PacketTypeOf(pkt) + if pktType == PacketTypeData { + payload, perr := extractDataPayload(pkt) + if perr == nil && len(payload) >= 2 && + payload[0] == TTCMsgAuthRequest && payload[1] == AuthSubOpPhaseOne { + select { + case done <- result{payload: payload}: + default: + } + return + } + } + if _, werr := upstream.Write(pkt); werr != nil { + select { + case done <- result{err: fmt.Errorf("write upstream: %w", werr)}: + default: + } + return + } + log.Debug().Str("sessionID", sessionID).Uint8("type", uint8(pktType)).Int("len", len(pkt)).Msg("Proxy pre-auth: client → upstream") + } + }() + + res := <-done + close(stop) + // Unblock the other goroutine so it doesn't steal the phase-1 response. 
+ if uc, ok := upstream.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = uc.SetReadDeadline(time.Now().Add(-1 * time.Second)) + } + time.Sleep(50 * time.Millisecond) + if uc, ok := upstream.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = uc.SetReadDeadline(time.Time{}) + } + if res.err != nil { + return nil, res.err + } + return res.payload, nil +} + +func extractDataPayload(pkt []byte) ([]byte, error) { + const headerLen = 10 + if len(pkt) < headerLen { + return nil, fmt.Errorf("packet too short: %d", len(pkt)) + } + return pkt[headerLen:], nil +} + +func rewriteAuthRequestUser(payload []byte, expectedSubOp byte, newUser string) ([]byte, error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, fmt.Errorf("opcode: %w", err) + } + if op != TTCMsgAuthRequest { + return nil, fmt.Errorf("unexpected opcode 0x%02X", op) + } + sub, err := r.GetByte() + if err != nil { + return nil, err + } + if sub != expectedSubOp { + return nil, fmt.Errorf("unexpected sub-op 0x%02X (want 0x%02X)", sub, expectedSubOp) + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + + hasUser, err := r.GetByte() + if err != nil { + return nil, err + } + if hasUser != 1 { + return payload, nil + } + origUserLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, fmt.Errorf("userLen: %w", err) + } + if origUserLen <= 0 { + return payload, nil + } + + middleStart := r.Pos() + + if _, err := r.GetInt(4, true, true); err != nil { + return nil, fmt.Errorf("mode: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker after mode: %w", err) + } + if _, err := r.GetInt(4, true, true); err != nil { + return nil, fmt.Errorf("count: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker 1: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker 2: %w", err) + } + middleEnd := r.Pos() + + // go-ora prefixes user bytes with a 
CLR-length byte; JDBC thin omits it. + peek, perr := r.PeekByte() + if perr != nil { + return nil, fmt.Errorf("peek user: %w", perr) + } + usedCLRPrefix := int(peek) == origUserLen && peek < 0x20 + if usedCLRPrefix { + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("consume user CLR length: %w", err) + } + } + if _, err := r.GetBytes(origUserLen); err != nil { + return nil, fmt.Errorf("user bytes: %w", err) + } + userEnd := r.Pos() + + newUserBytes := []byte(newUser) + newUserLen := len(newUserBytes) + + out := make([]byte, 0, len(payload)+16) + out = append(out, payload[:3]...) + out = append(out, 0x01) + lb := NewTTCBuilder() + lb.PutInt(int64(newUserLen), 4, true, true) + out = append(out, lb.Bytes()...) + out = append(out, payload[middleStart:middleEnd]...) + if usedCLRPrefix { + out = append(out, byte(newUserLen)) + } + out = append(out, newUserBytes...) + out = append(out, payload[userEnd:]...) + return out, nil +} + +func rewriteConnectServiceName(pkt []byte, newName string) []byte { + marker := []byte("SERVICE_NAME=") + idx := bytes.Index(pkt, marker) + if idx < 0 { + return pkt + } + valStart := idx + len(marker) + valEnd := bytes.IndexByte(pkt[valStart:], ')') + if valEnd < 0 { + return pkt + } + valEnd += valStart + + oldVal := pkt[valStart:valEnd] + newVal := []byte(newName) + if bytes.Equal(oldVal, newVal) { + return pkt + } + + out := make([]byte, 0, len(pkt)+len(newVal)-len(oldVal)) + out = append(out, pkt[:valStart]...) + out = append(out, newVal...) + out = append(out, pkt[valEnd:]...) 
+ + binary.BigEndian.PutUint16(out[0:2], uint16(len(out))) + if len(out) >= 26 { + oldCDLen := binary.BigEndian.Uint16(pkt[24:26]) + binary.BigEndian.PutUint16(out[24:26], uint16(int(oldCDLen)+len(newVal)-len(oldVal))) + } + return out +} + +func rewritePhase1User(payload []byte, newUser string) ([]byte, error) { + return rewriteAuthRequestUser(payload, AuthSubOpPhaseOne, newUser) +} + +func rewritePhase2User(payload []byte, newUser string) ([]byte, error) { + return rewriteAuthRequestUser(payload, AuthSubOpPhaseTwo, newUser) +} + +type ProxyAuthState struct { + Salt []byte + Pbkdf2CSKSalt string + Pbkdf2VGenCount int + Pbkdf2SDerCount int + RealKey []byte + PlaceholderKey []byte + ServerSessKey []byte +} + +func translatePhase1Response(payload []byte, realPassword string) (*ProxyAuthState, []byte, error) { + kvs, trailer, err := parseAuthRespKVPList(payload) + if err != nil { + return nil, nil, fmt.Errorf("parse upstream phase 1: %w", err) + } + + var eSessKey, vfrData, cskSalt, vGenStr, sDerStr string + for _, kv := range kvs { + switch kv.Key { + case "AUTH_SESSKEY": + eSessKey = kv.Value + case "AUTH_VFR_DATA": + vfrData = kv.Value + case "AUTH_PBKDF2_CSK_SALT": + cskSalt = kv.Value + case "AUTH_PBKDF2_VGEN_COUNT": + vGenStr = kv.Value + case "AUTH_PBKDF2_SDER_COUNT": + sDerStr = kv.Value + } + } + if eSessKey == "" || vfrData == "" { + return nil, nil, fmt.Errorf("upstream phase 1 missing AUTH_SESSKEY or AUTH_VFR_DATA") + } + salt, err := hex.DecodeString(vfrData) + if err != nil { + return nil, nil, fmt.Errorf("decode salt: %w", err) + } + vGen, _ := strconv.Atoi(vGenStr) + if vGen == 0 { + vGen = 4096 + } + sDer, _ := strconv.Atoi(sDerStr) + if sDer == 0 { + sDer = 3 + } + + realKey, _, err := deriveServerKey(realPassword, salt, vGen) + if err != nil { + return nil, nil, fmt.Errorf("derive real key: %w", err) + } + placeholderKey, _, err := deriveServerKey(ProxyPasswordPlaceholder, salt, vGen) + if err != nil { + return nil, nil, fmt.Errorf("derive 
placeholder key: %w", err) + } + + serverSessKey, err := decryptSessionKey(false, realKey, eSessKey) + if err != nil { + return nil, nil, fmt.Errorf("decrypt upstream server session key: %w", err) + } + newESessKey, err := encryptSessionKey(false, placeholderKey, serverSessKey) + if err != nil { + return nil, nil, fmt.Errorf("re-encrypt server session key: %w", err) + } + + for i := range kvs { + if kvs[i].Key == "AUTH_SESSKEY" { + kvs[i].Value = newESessKey + break + } + } + + rebuilt := rebuildAuthRespPayload(kvs, trailer) + + state := &ProxyAuthState{ + Salt: salt, + Pbkdf2CSKSalt: cskSalt, + Pbkdf2VGenCount: vGen, + Pbkdf2SDerCount: sDer, + RealKey: realKey, + PlaceholderKey: placeholderKey, + ServerSessKey: serverSessKey, + } + return state, rebuilt, nil +} + +func translatePhase2Request(payload []byte, state *ProxyAuthState, realPassword string) ([]byte, error) { + p2, err := ParseAuthPhaseTwo(payload) + if err != nil { + return nil, fmt.Errorf("parse client phase 2: %w", err) + } + + if p2.EClientSessKey == "" || p2.EPassword == "" { + return nil, fmt.Errorf("client phase 2 missing AUTH_SESSKEY or AUTH_PASSWORD") + } + + clientSessKey, err := decryptSessionKey(false, state.PlaceholderKey, p2.EClientSessKey) + if err != nil { + return nil, fmt.Errorf("decrypt client session key: %w", err) + } + if len(clientSessKey) != len(state.ServerSessKey) { + return nil, fmt.Errorf("client session key length mismatch: got %d want %d", len(clientSessKey), len(state.ServerSessKey)) + } + newEClientSessKey, err := encryptSessionKey(false, state.RealKey, clientSessKey) + if err != nil { + return nil, fmt.Errorf("re-encrypt client session key: %w", err) + } + + // encKey derives from session keys + CSK salt, not the password. + encKey, err := deriveProxyPasswordEncKey(clientSessKey, state.ServerSessKey, state.Pbkdf2CSKSalt, state.Pbkdf2SDerCount) + if err != nil { + return nil, fmt.Errorf("derive enc key: %w", err) + } + // Verify the client used the placeholder password. 
Wrong password would also + // fail cryptographically in phase 1 (ORA-17452), but this gives a clearer error. + decoded, err := decryptSessionKey(true, encKey, p2.EPassword) + if err != nil { + return nil, fmt.Errorf("decrypt client password: %w", err) + } + if len(decoded) <= 16 || string(decoded[16:]) != ProxyPasswordPlaceholder { + return nil, fmt.Errorf("password mismatch") + } + newEPassword, err := encryptPassword([]byte(realPassword), encKey, true) + if err != nil { + return nil, fmt.Errorf("encrypt real password: %w", err) + } + + rebuilt, err := rebuildPhase2Request(payload, newEClientSessKey, newEPassword) + if err != nil { + return nil, fmt.Errorf("rebuild phase 2: %w", err) + } + return rebuilt, nil +} + +func deriveProxyPasswordEncKey(clientSessKey, serverSessKey []byte, pbkdf2CSKSaltHex string, sderCount int) ([]byte, error) { + buffer := append([]byte(nil), clientSessKey...) + buffer = append(buffer, serverSessKey...) + keyBuffer := []byte(fmt.Sprintf("%X", buffer)) + cskSalt, err := hex.DecodeString(pbkdf2CSKSaltHex) + if err != nil { + return nil, fmt.Errorf("decode pbkdf2 salt: %w", err) + } + full := generateSpeedyKey(cskSalt, keyBuffer, sderCount) + if len(full) < 32 { + return nil, fmt.Errorf("speedy key too short: %d", len(full)) + } + return full[:32], nil +} + +type parsedKVP struct { + Key string + Value string + Flag int +} + +func parseAuthRespKVPList(payload []byte) (kvs []parsedKVP, trailer []byte, err error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, nil, err + } + if op != 0x08 { + return nil, nil, fmt.Errorf("expected auth response opcode 0x08, got 0x%02X", op) + } + dictLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("dict len: %w", err) + } + for i := 0; i < dictLen; i++ { + keyLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d key len: %w", i, err) + } + var keyBytes []byte + if keyLen > 0 { + keyBytes, err = 
r.GetClr() + if err != nil { + return nil, nil, fmt.Errorf("kvp %d key: %w", i, err) + } + if len(keyBytes) > keyLen { + keyBytes = keyBytes[:keyLen] + } + } + valLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d val len: %w", i, err) + } + var valBytes []byte + if valLen > 0 { + valBytes, err = r.GetClr() + if err != nil { + return nil, nil, fmt.Errorf("kvp %d val: %w", i, err) + } + if len(valBytes) > valLen { + valBytes = valBytes[:valLen] + } + } + flag, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d flag: %w", i, err) + } + kvs = append(kvs, parsedKVP{ + Key: string(bytes.TrimRight(keyBytes, "\x00")), + Value: string(valBytes), + Flag: flag, + }) + } + trailer = make([]byte, r.Remaining()) + rem, _ := r.GetBytes(r.Remaining()) + copy(trailer, rem) + return kvs, trailer, nil +} + +func rebuildAuthRespPayload(kvs []parsedKVP, trailer []byte) []byte { + b := NewTTCBuilder() + b.PutBytes(0x08) + b.PutUint(uint64(len(kvs)), 4, true, true) + for _, kv := range kvs { + b.PutKeyValString(kv.Key, kv.Value, uint32(kv.Flag)) + } + b.PutBytes(trailer...) + return b.Bytes() +} + +func rebuildPhase2Request(payload []byte, newESessKey, newEPassword string) ([]byte, error) { + out := make([]byte, 0, len(payload)+128) + out = append(out, payload...) 
+ + out, err := replaceKVPValue(out, "AUTH_SESSKEY", newESessKey) + if err != nil { + return nil, fmt.Errorf("replace AUTH_SESSKEY: %w", err) + } + out, err = replaceKVPValue(out, "AUTH_PASSWORD", newEPassword) + if err != nil { + return nil, fmt.Errorf("replace AUTH_PASSWORD: %w", err) + } + return out, nil +} + +func replaceKVPValue(payload []byte, key, newValue string) ([]byte, error) { + keyBytes := []byte(key) + idx := bytes.Index(payload, keyBytes) + if idx < 0 { + return nil, fmt.Errorf("key %q not found", key) + } + pos := idx + len(keyBytes) + if pos >= len(payload) { + return nil, fmt.Errorf("truncated after key") + } + vSizeByte := payload[pos] + pos++ + var vLen int + if vSizeByte == 0 { + vLen = 0 + } else if int(vSizeByte) <= 8 { + for i := 0; i < int(vSizeByte); i++ { + vLen = (vLen << 8) | int(payload[pos+i]) + } + pos += int(vSizeByte) + } else { + return nil, fmt.Errorf("invalid val_len size byte %d", vSizeByte) + } + if vLen > 0 { + if pos >= len(payload) || int(payload[pos]) != vLen { + return nil, fmt.Errorf("CLR length byte mismatch for %q: got %d want %d", key, payload[pos], vLen) + } + pos++ + valBodyStart := pos + valBodyEnd := valBodyStart + vLen + // PutClr handles chunked 0xFE form for values > 0xFC bytes. + newVal := []byte(newValue) + vb := NewTTCBuilder() + vb.PutUint(uint64(len(newVal)), 4, true, true) + vb.PutClr(newVal) + newValSection := vb.Bytes() + oldStart := idx + len(keyBytes) + oldEnd := valBodyEnd + out := make([]byte, 0, len(payload)+len(newValSection)) + out = append(out, payload[:oldStart]...) + out = append(out, newValSection...) + out = append(out, payload[oldEnd:]...) 
+ return out, nil + } + return payload, fmt.Errorf("unexpected empty value for %q", key) +} diff --git a/packages/pam/handlers/oracle/query_logger.go b/packages/pam/handlers/oracle/query_logger.go new file mode 100644 index 00000000..2860b452 --- /dev/null +++ b/packages/pam/handlers/oracle/query_logger.go @@ -0,0 +1,276 @@ +package oracle + +import ( + "bytes" + "encoding/binary" + "fmt" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/rs/zerolog/log" +) + +const ( + ttcFuncOALL8 = 0x5E + ttcFuncOCOMMIT = 0x0E + ttcFuncORLLBK = 0x0F + ttcMsgFunction = 0x03 +) + +type pendingQuery struct { + sql string + timestamp time.Time +} + +// Best-effort SQL extraction from the byte stream. +type QueryExtractor struct { + logger session.SessionLogger + sessionID string + direction string + ch chan []byte + stopCh chan struct{} + wg sync.WaitGroup + use32Bit bool + pair *pairState +} + +type pairState struct { + mu sync.Mutex + pending *pendingQuery +} + +func NewQueryExtractorPair(logger session.SessionLogger, sessionID string, use32Bit bool) (clientToUpstream, upstreamToClient *QueryExtractor) { + p := &pairState{} + clientToUpstream = newExtractor(logger, sessionID, "client->upstream", use32Bit, p) + upstreamToClient = newExtractor(logger, sessionID, "upstream->client", use32Bit, p) + return +} + +func newExtractor(logger session.SessionLogger, sessionID, direction string, use32Bit bool, pair *pairState) *QueryExtractor { + e := &QueryExtractor{ + logger: logger, + sessionID: sessionID, + direction: direction, + ch: make(chan []byte, 64), + stopCh: make(chan struct{}), + use32Bit: use32Bit, + pair: pair, + } + e.wg.Add(1) + go e.loop() + return e +} + +func (e *QueryExtractor) Feed(data []byte) { + if len(data) == 0 { + return + } + cp := make([]byte, len(data)) + copy(cp, data) + select { + case e.ch <- cp: + default: + } +} + +func (e *QueryExtractor) Stop() { + close(e.stopCh) + e.wg.Wait() +} + +func (e *QueryExtractor) 
loop() { + defer e.wg.Done() + var buffer bytes.Buffer + + for { + select { + case <-e.stopCh: + return + case chunk := <-e.ch: + buffer.Write(chunk) + e.drain(&buffer) + } + } +} + +func (e *QueryExtractor) drain(buf *bytes.Buffer) { + for { + if buf.Len() < 8 { + return + } + head := buf.Bytes()[:8] + var length uint32 + if e.use32Bit { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } + if length < 8 || length > 16*1024*1024 { + buf.Reset() + return + } + if buf.Len() < int(length) { + return + } + packet := make([]byte, length) + if _, err := buf.Read(packet); err != nil { + return + } + e.handlePacket(packet) + } +} + +func (e *QueryExtractor) handlePacket(raw []byte) { + if PacketTypeOf(raw) != PacketTypeData { + return + } + d, err := ParseDataPacket(raw, e.use32Bit) + if err != nil { + return + } + if len(d.Payload) < 1 { + return + } + switch e.direction { + case "client->upstream": + e.handleClientRequest(d.Payload) + case "upstream->client": + e.handleServerResponse(d.Payload) + } +} + +func (e *QueryExtractor) handleClientRequest(payload []byte) { + // Clients often piggyback an OCLOSE before the new function call; scan for + // the function-call+opcode marker pair instead of parsing from offset 0. 
+ if idx := findBytePair(payload, ttcMsgFunction, ttcFuncOALL8); idx >= 0 { + r := NewTTCReader(payload[idx+2:]) + if sqlText := tryExtractSQL(r); sqlText != "" { + e.pair.mu.Lock() + e.pair.pending = &pendingQuery{sql: sqlText, timestamp: time.Now()} + e.pair.mu.Unlock() + } + return + } + if findBytePair(payload, ttcMsgFunction, ttcFuncOCOMMIT) >= 0 { + e.recordLiteral("COMMIT") + return + } + if findBytePair(payload, ttcMsgFunction, ttcFuncORLLBK) >= 0 { + e.recordLiteral("ROLLBACK") + return + } +} + +func findBytePair(data []byte, b1, b2 byte) int { + for i := 0; i+1 < len(data); i++ { + if data[i] == b1 && data[i+1] == b2 { + return i + } + } + return -1 +} + +func (e *QueryExtractor) recordLiteral(sql string) { + e.pair.mu.Lock() + e.pair.pending = &pendingQuery{sql: sql, timestamp: time.Now()} + e.pair.mu.Unlock() +} + +// tryExtractSQL uses a longest-printable-run heuristic because OALL8 headers +// vary across client drivers and bind patterns. +func tryExtractSQL(r *TTCReader) string { + remaining := r.Remaining() + if remaining <= 0 { + return "" + } + buf, err := r.GetBytes(remaining) + if err != nil { + return "" + } + return longestPrintableRun(buf) +} + +func longestPrintableRun(data []byte) string { + bestStart, bestLen := 0, 0 + curStart, curLen := 0, 0 + for i, b := range data { + printable := b == '\t' || b == '\n' || b == '\r' || (b >= 0x20 && b <= 0x7E) + if printable { + if curLen == 0 { + curStart = i + } + curLen++ + if curLen > bestLen { + bestLen = curLen + bestStart = curStart + } + } else { + curLen = 0 + } + } + if bestLen < 4 { + return "" + } + return string(data[bestStart : bestStart+bestLen]) +} + +func (e *QueryExtractor) handleServerResponse(payload []byte) { + e.pair.mu.Lock() + pending := e.pair.pending + e.pair.pending = nil + e.pair.mu.Unlock() + if pending == nil { + return + } + output := extractResponseOutcome(payload) + err := e.logger.LogEntry(session.SessionLogEntry{ + Timestamp: pending.timestamp, + Input: pending.sql, 
+ Output: output, + }) + if err != nil { + log.Debug().Err(err).Str("sessionID", e.sessionID).Msg("session log entry dropped") + } +} + +func extractResponseOutcome(payload []byte) string { + r := NewTTCReader(payload) + for r.Remaining() > 0 { + op, err := r.GetByte() + if err != nil { + break + } + if op == 0x04 { + for i := 0; i < 3; i++ { + if _, err := r.GetInt(4, true, true); err != nil { + return "OK" + } + } + code, err := r.GetInt(4, true, true) + if err != nil || code == 0 { + return "OK" + } + return ora(code) + } + } + return "" +} + +func ora(code int) string { + switch code { + case 0: + return "OK" + case 1: + return "ERROR: ORA-00001: unique constraint violated" + case 900: + return "ERROR: ORA-00900: invalid SQL statement" + case 942: + return "ERROR: ORA-00942: table or view does not exist" + case 1017: + return "ERROR: ORA-01017: invalid username/password" + case 28000: + return "ERROR: ORA-28000: the account is locked" + } + return fmt.Sprintf("ERROR: ORA-%05d", code) +} diff --git a/packages/pam/handlers/oracle/tns.go b/packages/pam/handlers/oracle/tns.go new file mode 100644 index 00000000..93d01f54 --- /dev/null +++ b/packages/pam/handlers/oracle/tns.go @@ -0,0 +1,115 @@ +package oracle + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +type PacketType uint8 + +const ( + PacketTypeConnect PacketType = 1 + PacketTypeAccept PacketType = 2 + PacketTypeRefuse PacketType = 4 + PacketTypeRedirect PacketType = 5 + PacketTypeData PacketType = 6 + PacketTypeResend PacketType = 11 + PacketTypeMarker PacketType = 12 +) + +// use32BitLen: 32-bit length framing after ACCEPT (version >= 315), 16-bit before. 
+func ReadFullPacket(r io.Reader, use32BitLen bool) ([]byte, error) { + head := make([]byte, 8) + if _, err := io.ReadFull(r, head); err != nil { + return nil, err + } + var length uint32 + if use32BitLen { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } + if length < 8 { + return nil, fmt.Errorf("invalid TNS packet length: %d", length) + } + if length > 1<<22 { + return nil, fmt.Errorf("TNS packet too large: %d", length) + } + buf := make([]byte, length) + copy(buf, head) + if length > 8 { + if _, err := io.ReadFull(r, buf[8:]); err != nil { + return nil, err + } + } + return buf, nil +} + +func PacketTypeOf(packet []byte) PacketType { + if len(packet) < 5 { + return 0 + } + return PacketType(packet[4]) +} + +type DataPacket struct { + DataFlag uint16 + Payload []byte +} + +func ParseDataPacket(raw []byte, use32BitLen bool) (*DataPacket, error) { + if len(raw) < 10 || PacketType(raw[4]) != PacketTypeData { + return nil, errors.New("not a DATA packet") + } + return &DataPacket{ + DataFlag: binary.BigEndian.Uint16(raw[8:]), + Payload: append([]byte(nil), raw[10:]...), + }, nil +} + +func (d *DataPacket) Bytes(use32BitLen bool) []byte { + length := uint32(10 + len(d.Payload)) + out := make([]byte, length) + if use32BitLen { + binary.BigEndian.PutUint32(out, length) + } else { + binary.BigEndian.PutUint16(out, uint16(length)) + } + out[4] = byte(PacketTypeData) + out[5] = 0 + binary.BigEndian.PutUint16(out[8:], d.DataFlag) + copy(out[10:], d.Payload) + return out +} + +type RefusePacket struct { + UserReason uint8 + SystemReason uint8 + Message string +} + +func (r *RefusePacket) Bytes() []byte { + msg := []byte(r.Message) + length := uint32(12 + len(msg)) + out := make([]byte, length) + binary.BigEndian.PutUint16(out, uint16(length)) + out[4] = byte(PacketTypeRefuse) + out[5] = 0 + out[8] = r.UserReason + out[9] = r.SystemReason + binary.BigEndian.PutUint16(out[10:], uint16(len(msg))) + copy(out[12:], msg) + 
return out +} + +func WriteRefuseToClient(w io.Writer, message string) error { + pkt := &RefusePacket{ + UserReason: 0, + SystemReason: 0, + Message: message, + } + _, err := w.Write(pkt.Bytes()) + return err +} diff --git a/packages/pam/handlers/oracle/ttc.go b/packages/pam/handlers/oracle/ttc.go new file mode 100644 index 00000000..30461e59 --- /dev/null +++ b/packages/pam/handlers/oracle/ttc.go @@ -0,0 +1,302 @@ +package oracle + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +type TTCBuilder struct { + buf bytes.Buffer + useBigClrChunks bool + clrChunkSize int +} + +func NewTTCBuilder() *TTCBuilder { + return &TTCBuilder{useBigClrChunks: true, clrChunkSize: 0x7FFF} +} + +func (b *TTCBuilder) Bytes() []byte { return b.buf.Bytes() } + +func (b *TTCBuilder) PutBytes(data ...byte) { b.buf.Write(data) } + +func (b *TTCBuilder) PutUint(num uint64, size uint8, bigEndian, compress bool) { + if size == 1 { + b.buf.WriteByte(uint8(num)) + return + } + if compress { + temp := make([]byte, 8) + binary.BigEndian.PutUint64(temp, num) + temp = bytes.TrimLeft(temp, "\x00") + if size > uint8(len(temp)) { + size = uint8(len(temp)) + } + if size == 0 { + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(size) + b.buf.Write(temp) + return + } + temp := make([]byte, size) + if bigEndian { + switch size { + case 2: + binary.BigEndian.PutUint16(temp, uint16(num)) + case 4: + binary.BigEndian.PutUint32(temp, uint32(num)) + case 8: + binary.BigEndian.PutUint64(temp, num) + } + } else { + switch size { + case 2: + binary.LittleEndian.PutUint16(temp, uint16(num)) + case 4: + binary.LittleEndian.PutUint32(temp, uint32(num)) + case 8: + binary.LittleEndian.PutUint64(temp, num) + } + } + b.buf.Write(temp) +} + +func (b *TTCBuilder) PutInt(num int64, size uint8, bigEndian, compress bool) { + if compress { + temp := make([]byte, 8) + binary.BigEndian.PutUint64(temp, uint64(num)) + temp = bytes.TrimLeft(temp, "\x00") + if size > uint8(len(temp)) { + size = uint8(len(temp)) + } 
+ if size == 0 { + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(size) + b.buf.Write(temp[:size]) + return + } + b.PutUint(uint64(num), size, bigEndian, false) +} + +func (b *TTCBuilder) PutClr(data []byte) { + dataLen := len(data) + if dataLen == 0 { + b.buf.WriteByte(0) + return + } + if dataLen > 0xFC { + b.buf.WriteByte(0xFE) + start := 0 + for start < dataLen { + end := start + b.clrChunkSize + if end > dataLen { + end = dataLen + } + chunk := data[start:end] + if b.useBigClrChunks { + b.PutInt(int64(len(chunk)), 4, true, true) + } else { + b.buf.WriteByte(uint8(len(chunk))) + } + b.buf.Write(chunk) + start += b.clrChunkSize + } + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(uint8(dataLen)) + b.buf.Write(data) +} + +func (b *TTCBuilder) PutString(s string) { b.PutClr([]byte(s)) } + +func (b *TTCBuilder) PutKeyVal(key, val []byte, num uint32) { + if len(key) == 0 { + b.buf.WriteByte(0) + } else { + b.PutUint(uint64(len(key)), 4, true, true) + b.PutClr(key) + } + if len(val) == 0 { + b.buf.WriteByte(0) + } else { + b.PutUint(uint64(len(val)), 4, true, true) + b.PutClr(val) + } + b.PutInt(int64(num), 4, true, true) +} + +func (b *TTCBuilder) PutKeyValString(key, val string, num uint32) { + b.PutKeyVal([]byte(key), []byte(val), num) +} + +type TTCReader struct { + buf []byte + pos int + useBigClrChunks bool +} + +func NewTTCReader(payload []byte) *TTCReader { + return &TTCReader{buf: payload, useBigClrChunks: true} +} + +func (r *TTCReader) Remaining() int { return len(r.buf) - r.pos } + +func (r *TTCReader) Pos() int { return r.pos } + +func (r *TTCReader) read(n int) ([]byte, error) { + if r.pos+n > len(r.buf) { + return nil, io.ErrUnexpectedEOF + } + out := r.buf[r.pos : r.pos+n] + r.pos += n + return out, nil +} + +func (r *TTCReader) GetByte() (uint8, error) { + b, err := r.read(1) + if err != nil { + return 0, err + } + return b[0], nil +} + +func (r *TTCReader) PeekByte() (uint8, error) { + if r.pos >= len(r.buf) { + return 0, io.ErrUnexpectedEOF + 
} + return r.buf[r.pos], nil +} + +func (r *TTCReader) GetBytes(n int) ([]byte, error) { + b, err := r.read(n) + if err != nil { + return nil, err + } + out := make([]byte, len(b)) + copy(out, b) + return out, nil +} + +func (r *TTCReader) GetInt64(size int, compress, bigEndian bool) (int64, error) { + negFlag := false + if compress { + sb, err := r.read(1) + if err != nil { + return 0, err + } + size = int(sb[0]) + if size&0x80 > 0 { + negFlag = true + size = size & 0x7F + } + bigEndian = true + } + if size == 0 { + return 0, nil + } + if size > 8 { + return 0, errors.New("invalid size for GetInt64") + } + rb, err := r.read(size) + if err != nil { + return 0, err + } + temp := make([]byte, 8) + var v int64 + if bigEndian { + copy(temp[8-size:], rb) + v = int64(binary.BigEndian.Uint64(temp)) + } else { + copy(temp[:size], rb) + v = int64(binary.LittleEndian.Uint64(temp)) + } + if negFlag { + v = -v + } + return v, nil +} + +func (r *TTCReader) GetInt(size int, compress, bigEndian bool) (int, error) { + v, err := r.GetInt64(size, compress, bigEndian) + return int(v), err +} + +func (r *TTCReader) GetClr() ([]byte, error) { + nb, err := r.GetByte() + if err != nil { + return nil, err + } + if nb == 0 || nb == 0xFF || nb == 0xFD { + return nil, nil + } + if nb != 0xFE { + out, err := r.read(int(nb)) + if err != nil { + return nil, err + } + ret := make([]byte, len(out)) + copy(ret, out) + return ret, nil + } + var buf bytes.Buffer + for { + var chunkSize int + if r.useBigClrChunks { + chunkSize, err = r.GetInt(4, true, true) + } else { + b, err2 := r.GetByte() + err = err2 + chunkSize = int(b) + } + if err != nil { + return nil, err + } + if chunkSize == 0 { + break + } + chunk, err := r.read(chunkSize) + if err != nil { + return nil, err + } + buf.Write(chunk) + } + return buf.Bytes(), nil +} + +func (r *TTCReader) GetDlc() ([]byte, error) { + length, err := r.GetInt(4, true, true) + if err != nil { + return nil, err + } + if length <= 0 { + _, _ = r.GetClr() + 
return nil, nil + } + out, err := r.GetClr() + if err != nil { + return nil, err + } + if len(out) > length { + out = out[:length] + } + return out, nil +} + +func (r *TTCReader) GetKeyVal() (key, val []byte, num int, err error) { + key, err = r.GetDlc() + if err != nil { + return + } + val, err = r.GetDlc() + if err != nil { + return + } + num, err = r.GetInt(4, true, true) + return +} diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index 9f51f4e3..c418795c 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -10,6 +10,7 @@ import ( "syscall" "time" + "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" @@ -125,6 +126,10 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p util.PrintfStderr("sqlserver://%s@localhost:%d?database=%s&encrypt=false&trustServerCertificate=true", username, proxy.port, database) case session.ResourceTypeMongodb: util.PrintfStderr("mongodb://localhost:%d/%s?serverSelectionTimeoutMS=15000", proxy.port, database) + case session.ResourceTypeOracledb: + util.PrintfStderr("%s/%s@localhost:%d/%s", username, oracle.ProxyPasswordPlaceholder, proxy.port, database) + util.PrintfStderr("\njdbc:oracle:thin:@localhost:%d/%s (user: %s, password: %s)", proxy.port, database, username, oracle.ProxyPasswordPlaceholder) + util.PrintfStderr("\n\nNote: the password shown is a protocol placeholder required by Oracle, not a secret.") default: util.PrintfStderr("localhost:%d", proxy.port) } diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 3240e682..dd46900b 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -18,6 +18,7 @@ import ( "github.com/Infisical/infisical-merge/packages/pam/handlers/mongodb" 
"github.com/Infisical/infisical-merge/packages/pam/handlers/mssql" "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" + "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" "github.com/Infisical/infisical-merge/packages/pam/handlers/rdp" "github.com/Infisical/infisical-merge/packages/pam/handlers/redis" "github.com/Infisical/infisical-merge/packages/pam/handlers/ssh" @@ -55,6 +56,7 @@ func GetSupportedResourceTypes() []string { session.ResourceTypeKubernetes, session.ResourceTypeRedis, session.ResourceTypeMongodb, + session.ResourceTypeOracledb, } // Only advertise RDP when the real bridge is compiled in. A stub // build would otherwise accept RDP session routing and fail every @@ -409,6 +411,24 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("authMethod", credentials.AuthMethod). Msg("Starting Kubernetes PAM proxy") return proxy.HandleConnection(ctx, handlerConn) + case session.ResourceTypeOracledb: + oracleConfig := oracle.OracleProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectDatabase: credentials.Database, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + } + proxy := oracle.NewOracleProxy(oracleConfig) + log.Info(). + Str("sessionId", pamConfig.SessionId). + Str("target", oracleConfig.TargetAddr). + Bool("sslEnabled", credentials.SSLEnabled). 
+ Msg("Starting Oracle PAM proxy") + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMongodb: mongoConfig := mongodb.MongoDBProxyConfig{ Host: credentials.ConnectionString, diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index 6f43781c..e0b71cc1 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -31,7 +31,8 @@ const ( ResourceTypeSSH = "ssh" ResourceTypeKubernetes = "kubernetes" ResourceTypeMongodb = "mongodb" - ResourceTypeWindows = "windows" + ResourceTypeOracledb = "oracledb" + ResourceTypeWindows = "windows" ) type SessionFileInfo struct { @@ -75,7 +76,7 @@ func NewSessionUploader(httpClient *resty.Client, credentialsManager *Credential func ParseSessionFilename(filename string) (*SessionFileInfo, error) { // Try new format first: pam_session_{sessionID}_{resourceType}_expires_{timestamp}.enc // Build regex pattern using constants - resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeMssql, ResourceTypeKubernetes, ResourceTypeMongodb, ResourceTypeWindows) + resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeMssql, ResourceTypeKubernetes, ResourceTypeMongodb, ResourceTypeOracledb, ResourceTypeWindows) newFormatRegex := regexp.MustCompile(fmt.Sprintf(`^pam_session_(.+)_%s_expires_(\d+)\.enc$`, resourceTypePattern)) matches := newFormatRegex.FindStringSubmatch(filename) From 9ad910ba0ef712eb7b463399db9e6378f89b2f18 Mon Sep 17 00:00:00 2001 From: Saif Ur Rahman Date: Tue, 12 May 2026 00:09:10 +0530 Subject: [PATCH 08/10] fix(pam): fix SSH e2e test flakes from container readiness races (#228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * debug(pam): add diagnostic logging for SSH resource creation flake Add detailed logging 
when CreateSshPamResource returns non-200 to diagnose the intermittent 400 error seen only in CI. On failure, dumps: - Response body (the actual backend error message) - Request params (gatewayId, projectId, host, port) - Relay/gateway process status and stderr/stdout - Gateway API state (heartbeat, name) - Backend container logs (last 3000 chars) Also temporarily scopes CI to only PAM SSH tests in a loop (up to 10 runs, stops on first failure) to maximize chance of catching the flake. * fix(pam): use localhost for SSH resource host to avoid iptables race The SSH e2e tests were using getOutboundIP() (e.g. 10.1.0.34) as the resource host. The gateway dials this address to validate the connection during resource creation. On Linux CI with native Docker, this path goes through iptables DNAT rules which can have a brief propagation delay after container start — causing intermittent "connection refused" even though testcontainers confirmed the port was listening (via localhost). CI error: "Unable to validate connection to ssh: Connection lost before handshake" with gateway log: "dial tcp 10.1.0.34:32774: connect: connection refused" Fix: use localhost instead. The gateway runs on the host, so localhost:port reaches docker-proxy directly on the loopback interface without needing iptables DNAT. Diagnostic logging kept in place to catch any other failure modes. * debug(ci): add sleep between loop iterations for port release The relay binds to hardcoded ports (8443, 2222). When running multiple test iterations in a loop, the previous relay's ports may still be in TIME_WAIT when the next iteration starts, causing "Relay is reachable" to never appear. Add 15s sleep between iterations. * experiment: revert to getOutboundIP to confirm flake reproduces in CI Reverting the localhost fix to verify that getOutboundIP triggers the iptables race ("connection refused") in CI. If this fails, it confirms the root cause. Will re-apply the fix afterward. 
* fix(pam): re-apply localhost fix after confirming getOutboundIP flake Experiment confirmed: - getOutboundIP (10.1.0.168): FAILED run 1 — "dial tcp 10.1.0.168:32773: connect: connection refused" (iptables DNAT race) - localhost: 20/20 passed across two CI runs Re-applying the localhost fix. * experiment: second run with getOutboundIP to confirm reproducibility * fix(pam): re-apply localhost fix — confirmed root cause with A/B test A/B experiment in CI confirms getOutboundIP causes the flake: getOutboundIP (external IP): - Experiment 1: FAILED run 1 (password, 10.1.0.168:32773) - Experiment 2: FAILED run 4 (certificate, 10.1.0.114:32799) - Error: "dial tcp <ip>:<port>: connect: connection refused" localhost: - 3 CI runs: PASSED 25/25 consecutively Root cause: on Linux CI with native Docker, port mappings via external IP go through iptables DNAT rules which have a brief propagation delay after container start. localhost bypasses this via docker-proxy on loopback. * fix(pam): finalize — restore CI workflow, keep diagnostics - Restore all CI jobs (CLI, Agent, PAM) to normal single-run mode - Run full PAM test suite (not just SSH) - Keep diagnostic dump on createSSHPamResource failure - Log resourceHost value so any future failure shows which path was used * test: validation loop with log correlation check Runs SSH tests 5 times with localhost fix, then greps all logs for "connection refused" to verify the error is completely absent. With getOutboundIP: "connection refused" appeared in gateway logs right before every failure. With localhost: should see zero occurrences across all 5 runs. * fix(pam): use localhost for SSH test resource host + add failure diagnostics Root cause: SSH e2e tests used getOutboundIP() (e.g. 10.1.0.168) as the resource host. During backend's validateConnection(), the gateway dials this address to reach the SSH container. On Linux CI with native Docker, this path goes through iptables DNAT rules which have a brief propagation delay after container start.
testcontainers confirms the port via localhost (docker-proxy on loopback), but the iptables path isn't ready yet — causing "connection refused" at the gateway. Fix: use "localhost" instead. The gateway runs on the Docker host, so localhost reaches the port mapping directly via docker-proxy without needing iptables DNAT. Verified with A/B experiment in CI: getOutboundIP: FAILED 2/2 runs ("connection refused" in gateway logs) localhost: PASSED 30/30 runs (zero "connection refused" in logs) Also adds diagnostic dump on createSSHPamResource failure (response body, relay/gateway logs, backend container logs, gateway API state) for visibility into any future failure modes. * fix(pam): get SSH resource host from container.Host() like Postgres/Redis Use container.Host(ctx) instead of getOutboundIP() or hardcoded "localhost". This is the same pattern Postgres and Redis tests use, and returns the correct Docker daemon host (localhost on Linux, which routes through docker-proxy without iptables DNAT races). * fix(pam): use container.Host() for SSH resource host, remove debug code Get the SSH resource host from container.Host(ctx) — same as Postgres and Redis tests — instead of getOutboundIP(). On Linux CI, Host() returns localhost which routes through docker-proxy on loopback, avoiding the iptables DNAT propagation race that caused intermittent "connection refused" when using the host's external IP. Removes all diagnostic logging added during investigation. * chore: restore original workflow file (no changes needed) * debug: log response body on createSSHPamResource failure * debug: SSH-only loop to catch second flake mode with response body logging * ci: retrigger loop to catch rare second flake * debug: add gateway/relay stderr dump on failure to catch second flake * ci: retrigger to catch rare flake (attempt 3) * fix(pam): wait for sshd "Server listening" log, not just port The SSH container's port 22 becomes listenable before sshd is fully ready to accept connections. 
testcontainers' ForListeningPort confirms the port is open, but sshd may still be generating host keys or loading config. Connections during this window get "connection reset by peer" (0 bytes from service), which the backend sees as "Connection lost before handshake". Fix: use ForAll(ForListeningPort, ForLog("Server listening")) to wait for both port availability AND sshd's startup confirmation log. * ci: retrigger for confidence * fix(pam): use SSH handshake to verify sshd readiness after SIGHUP After configureCertAuth sends SIGHUP to reload sshd config, the test checked readiness with a TCP dial. But sshd never closes its listen socket during SIGHUP, so the port is always reachable — the dial passes instantly even while sshd is still processing the config reload. Connections during this window get accepted then dropped, causing "Connection to 127.0.0.1 closed by remote host". Fix: do an actual SSH handshake (ssh.Dial with password auth) instead of a TCP dial. This confirms sshd is fully serving connections, not just listening. * ci: retrigger confidence run * fix(pam): fix three SSH e2e test flakes 1. Use container.Host(ctx) for resource host instead of getOutboundIP(). On Linux CI, external IPs route through iptables DNAT which has a propagation delay after container start. 2. Wait for sshd "Server listening" log, not just port open. ForListeningPort passes before sshd finishes key generation, causing connection resets. 3. Use SSH handshake to verify sshd readiness after SIGHUP config reload. sshd never closes its listen socket during SIGHUP, so TCP dial passes instantly while sshd is still reloading. 
--------- Co-authored-by: saif <11242541+saifsmailbox98@users.noreply.github.com> --- e2e/pam/ssh_test.go | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/e2e/pam/ssh_test.go b/e2e/pam/ssh_test.go index 4d79b1d4..df329af5 100644 --- a/e2e/pam/ssh_test.go +++ b/e2e/pam/ssh_test.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "log/slog" - "net" "net/http" "os" "strings" @@ -33,8 +32,8 @@ const ( sshPassword = "testpass" ) -func startSSHContainer(t *testing.T, ctx context.Context, env map[string]string) (testcontainers.Container, int) { - container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ +func startSSHContainer(t *testing.T, ctx context.Context, env map[string]string) (testcontainers.Container, string, int) { + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ Context: "testdata/ssh-server", @@ -45,20 +44,25 @@ func startSSHContainer(t *testing.T, ctx context.Context, env map[string]string) HostConfigModifier: func(hc *container.HostConfig) { hc.ExtraHosts = append(hc.ExtraHosts, "host.docker.internal:host-gateway") }, - WaitingFor: wait.ForListeningPort("22/tcp").WithStartupTimeout(30 * time.Second), + WaitingFor: wait.ForAll( + wait.ForListeningPort("22/tcp"), + wait.ForLog("Server listening"), + ).WithStartupTimeout(30 * time.Second), }, Started: true, }) require.NoError(t, err) t.Cleanup(func() { - if err := container.Terminate(ctx); err != nil { + if err := ctr.Terminate(ctx); err != nil { t.Logf("Failed to terminate SSH container: %v", err) } }) - port, err := container.MappedPort(ctx, "22") + host, err := ctr.Host(ctx) require.NoError(t, err) - return container, port.Int() + port, err := ctr.MappedPort(ctx, "22") + require.NoError(t, err) + return ctr, host, port.Int() } func createSSHPamResource(t *testing.T, ctx 
context.Context, infra *PAMTestInfra, name, host string, port int) uuid.UUID { @@ -200,12 +204,20 @@ func configureCertAuth(t *testing.T, ctx context.Context, infra *PAMTestInfra, c require.NoError(t, err) require.Equal(t, 0, exitCode, "sshd reload should succeed") - // Wait for sshd to be responsive after config reload. + // Wait for sshd to be fully responsive after config reload. + // A TCP dial is not enough — sshd never closes the listen socket during + // SIGHUP, so the port is always reachable. Instead, do an actual SSH + // handshake to confirm sshd is serving connections after the reload. result := helpers.WaitFor(t, helpers.WaitForOptions{ Timeout: 10 * time.Second, Interval: 500 * time.Millisecond, Condition: func() helpers.ConditionResult { - conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", sshPort), time.Second) + conn, err := ssh.Dial("tcp", fmt.Sprintf("localhost:%d", sshPort), &ssh.ClientConfig{ + User: sshUser, + Auth: []ssh.AuthMethod{ssh.Password(sshPassword)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + }) if err != nil { return helpers.ConditionWait } @@ -221,7 +233,7 @@ func configureCertAuth(t *testing.T, ctx context.Context, infra *PAMTestInfra, c // - password: uses hardcoded testuser/testpass from entrypoint; account gets username + password // - public-key: container gets SSH_AUTHORIZED_KEY (generated ed25519); account gets username + privateKey // - certificate: container configured via curl | bash (ssh-ca-setup endpoint); account gets just username -func runSSHAuthTest(t *testing.T, ctx context.Context, infra *PAMTestInfra, resourceHost string, method string) { +func runSSHAuthTest(t *testing.T, ctx context.Context, infra *PAMTestInfra, method string) { containerEnv := map[string]string{} accountCreds := map[string]interface{}{ "authMethod": method, @@ -249,11 +261,11 @@ func runSSHAuthTest(t *testing.T, ctx context.Context, infra *PAMTestInfra, reso // Cert auth is configured after 
resource creation via curl | bash. } - container, sshPort := startSSHContainer(t, ctx, containerEnv) - slog.Info("SSH container started", "method", method, "host", resourceHost, "port", sshPort) + container, sshHost, sshPort := startSSHContainer(t, ctx, containerEnv) + slog.Info("SSH container started", "method", method, "host", sshHost, "port", sshPort) resourceName := fmt.Sprintf("ssh-%s-resource", method) - resourceId := createSSHPamResource(t, ctx, infra, resourceName, resourceHost, sshPort) + resourceId := createSSHPamResource(t, ctx, infra, resourceName, sshHost, sshPort) if method == "certificate" { configureCertAuth(t, ctx, infra, container, sshPort, resourceId) @@ -273,12 +285,10 @@ func TestPAM_SSH(t *testing.T) { infra := SetupPAMInfra(t, ctx) LoginUser(t, ctx, infra) - resourceHost := getOutboundIP(t) - methods := []string{"password", "public-key", "certificate"} for _, method := range methods { t.Run(method, func(t *testing.T) { - runSSHAuthTest(t, ctx, infra, resourceHost, method) + runSSHAuthTest(t, ctx, infra, method) }) } } From 33db99f47df31ee25dc9be7777fd09e4b79b78f7 Mon Sep 17 00:00:00 2001 From: bernie-g Date: Mon, 11 May 2026 15:23:45 -0400 Subject: [PATCH 09/10] add RDP E2E tests for PAM module 6 subtests: connection, bad credentials, unreachable target, reconnect, concurrent connections, and session duration. Uses an xrdp container as the target and FreeRDP under xvfb for headless verification. CI pam-test job updated to build the Rust bridge and CLI with -tags rdp. 
--- .github/workflows/run-cli-e2e-tests.yml | 19 +- e2e/openapi-cfg.yaml | 2 + e2e/pam/rdp_test.go | 336 ++++++++++++++++++++++ e2e/pam/testdata/rdp-server/Dockerfile | 16 ++ e2e/pam/testdata/rdp-server/entrypoint.sh | 20 ++ 5 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 e2e/pam/rdp_test.go create mode 100644 e2e/pam/testdata/rdp-server/Dockerfile create mode 100644 e2e/pam/testdata/rdp-server/entrypoint.sh diff --git a/.github/workflows/run-cli-e2e-tests.yml b/.github/workflows/run-cli-e2e-tests.yml index 1929dfd7..59c0eed7 100644 --- a/.github/workflows/run-cli-e2e-tests.yml +++ b/.github/workflows/run-cli-e2e-tests.yml @@ -82,8 +82,25 @@ jobs: go-version: "1.25.9" - name: Install dependencies run: go get . + - name: Cache cargo registry + target + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + packages/pam/handlers/rdp/native/target + key: rdp-bridge-cargo-${{ runner.os }}-${{ hashFiles('packages/pam/handlers/rdp/native/Cargo.lock') }} + restore-keys: rdp-bridge-cargo-${{ runner.os }}- + - name: Install pinned Rust toolchain + working-directory: packages/pam/handlers/rdp/native + run: rustup show active-toolchain + - name: Build Rust RDP bridge + working-directory: packages/pam/handlers/rdp/native + run: cargo build --release - name: Build the CLI - run: go build -o infisical-cli + run: CGO_ENABLED=1 go build -tags rdp -o infisical-cli + - name: Install RDP test dependencies + run: sudo apt-get update && sudo apt-get install -y --no-install-recommends freerdp2-x11 xvfb - name: Checkout infisical repo uses: actions/checkout@v6 with: diff --git a/e2e/openapi-cfg.yaml b/e2e/openapi-cfg.yaml index e40b0472..5c0ee79c 100644 --- a/e2e/openapi-cfg.yaml +++ b/e2e/openapi-cfg.yaml @@ -37,3 +37,5 @@ output-options: - createSshPamResource - createSshPamAccount - createRedisPamAccount + - createWindowsPamResource + - createWindowsPamAccount diff --git a/e2e/pam/rdp_test.go b/e2e/pam/rdp_test.go new file mode 100644 index 
00000000..efd0e186 --- /dev/null +++ b/e2e/pam/rdp_test.go @@ -0,0 +1,336 @@ +package pam + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os/exec" + "strings" + "sync" + "testing" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/google/uuid" + helpers "github.com/infisical/cli/e2e-tests/util" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + rdpUser = "testuser" + rdpPassword = "testpass" +) + +func startRDPContainer(t *testing.T, ctx context.Context) (testcontainers.Container, int) { + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + FromDockerfile: testcontainers.FromDockerfile{ + Context: "testdata/rdp-server", + Dockerfile: "Dockerfile", + }, + ExposedPorts: []string{"3389/tcp"}, + HostConfigModifier: func(hc *container.HostConfig) { + hc.ExtraHosts = append(hc.ExtraHosts, "host.docker.internal:host-gateway") + }, + WaitingFor: wait.ForListeningPort("3389/tcp").WithStartupTimeout(60 * time.Second), + }, + Started: true, + }) + require.NoError(t, err) + t.Cleanup(func() { + if err := ctr.Terminate(ctx); err != nil { + t.Logf("Failed to terminate RDP container: %v", err) + } + }) + + port, err := ctr.MappedPort(ctx, "3389") + require.NoError(t, err) + return ctr, port.Int() +} + +func pamAPIRequest(t *testing.T, infra *PAMTestInfra, method, path string, body interface{}) (int, []byte) { + jsonBody, err := json.Marshal(body) + require.NoError(t, err) + + url := infra.Infisical.ApiUrl(t) + path + req, err := http.NewRequest(method, url, bytes.NewReader(jsonBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+infra.ProvisionResult.Token) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer 
resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + return resp.StatusCode, respBody +} + +func createRDPPamResource(t *testing.T, ctx context.Context, infra *PAMTestInfra, name, host string, port int) uuid.UUID { + status, respBody := pamAPIRequest(t, infra, "POST", "/api/v1/pam/resources/windows", map[string]interface{}{ + "projectId": infra.ProjectId, + "gatewayId": infra.GatewayId, + "name": name, + "connectionDetails": map[string]interface{}{ + "protocol": "rdp", + "hostname": host, + "port": port, + "winrmPort": 5985, + "useWinrmHttps": false, + "winrmRejectUnauthorized": false, + }, + }) + require.Equal(t, http.StatusOK, status, "create Windows resource: %s", string(respBody)) + + var result struct { + Resource struct { + Id uuid.UUID `json:"id"` + } `json:"resource"` + } + require.NoError(t, json.Unmarshal(respBody, &result)) + slog.Info("Created Windows PAM resource", "resourceId", result.Resource.Id, "name", name) + return result.Resource.Id +} + +func createRDPPamAccount(t *testing.T, ctx context.Context, infra *PAMTestInfra, resourceId uuid.UUID, name, username, password string) { + body := map[string]interface{}{ + "resourceId": resourceId.String(), + "name": name, + "credentials": map[string]interface{}{ + "username": username, + "password": password, + }, + "internalMetadata": map[string]interface{}{ + "accountType": "user", + }, + } + + result := helpers.WaitFor(t, helpers.WaitForOptions{ + Timeout: 90 * time.Second, + Interval: 3 * time.Second, + Condition: func() helpers.ConditionResult { + status, respBody := pamAPIRequest(t, infra, "POST", "/api/v1/pam/accounts/windows", body) + if status != http.StatusOK { + slog.Warn("Windows PAM account creation returned non-200, retrying...", "status", status, "body", string(respBody)) + return helpers.ConditionWait + } + return helpers.ConditionSuccess + }, + }) + require.Equal(t, helpers.WaitSuccess, result, "Windows PAM account creation should succeed for %s", name) + 
slog.Info("Created Windows PAM account", "name", name) +} + +func startRDPProxy(t *testing.T, ctx context.Context, infra *PAMTestInfra, resourceName, accountName, duration string, port int) (int, *helpers.Command) { + pamCmd := helpers.Command{ + Test: t, + RunMethod: helpers.RunMethodSubprocess, + DisableTempHomeDir: true, + Args: []string{ + "pam", "rdp", "access", + "--resource", resourceName, + "--account", accountName, + "--project-id", infra.ProjectId, + "--duration", duration, + "--port", fmt.Sprintf("%d", port), + "--no-launch", + }, + Env: map[string]string{ + "HOME": infra.SharedHomeDir, + "INFISICAL_API_URL": infra.Infisical.ApiUrl(t), + }, + } + pamCmd.Start(ctx) + t.Cleanup(pamCmd.Stop) + + result := helpers.WaitFor(t, helpers.WaitForOptions{ + EnsureCmdRunning: &pamCmd, + Condition: func() helpers.ConditionResult { + if strings.Contains(pamCmd.Stderr(), "RDP Proxy Session Started") { + return helpers.ConditionSuccess + } + return helpers.ConditionWait + }, + }) + if result != helpers.WaitSuccess { + pamCmd.DumpOutput() + } + require.Equal(t, helpers.WaitSuccess, result, "RDP proxy should start successfully") + + return port, &pamCmd +} + +func findFreeRDPBinary(t *testing.T) string { + for _, name := range []string{"xfreerdp3", "xfreerdp"} { + if path, err := exec.LookPath(name); err == nil { + return path + } + } + t.Skip("xfreerdp not found; install freerdp2-x11 or freerdp3-x11") + return "" +} + +func connectFreeRDP(t *testing.T, ctx context.Context, binary string, proxyPort int, timeout time.Duration) error { + timeoutMs := int(timeout.Milliseconds()) + args := []string{ + binary, + fmt.Sprintf("/v:127.0.0.1:%d", proxyPort), + "/u:testuser", + "/p:", + "/cert:ignore", + fmt.Sprintf("/timeout:%d", timeoutMs), + } + + cmdCtx, cancel := context.WithTimeout(ctx, timeout+10*time.Second) + defer cancel() + + cmd := exec.CommandContext(cmdCtx, "xvfb-run", append([]string{"-a"}, args...)...) 
+ output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("xfreerdp failed (exit %v): %s", err, string(output)) + } + return nil +} + +func TestPAM_RDP(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + infra := SetupPAMInfra(t, ctx) + LoginUser(t, ctx, infra) + + rdpBinary := findFreeRDPBinary(t) + resourceHost := getOutboundIP(t) + + t.Run("connection", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + slog.Info("RDP container started", "host", resourceHost, "port", rdpPort) + + resourceName := "rdp-connection-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-connection-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + startRDPProxy(t, ctx, infra, resourceName, "rdp-connection-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 30*time.Second) + require.NoError(t, err, "FreeRDP connection through proxy should succeed") + slog.Info("RDP connection test passed") + }) + + t.Run("bad-credentials", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-badcreds-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-badcreds-account", rdpUser, "wrong-password") + + proxyPort := helpers.GetFreePort() + _, pamCmd := startRDPProxy(t, ctx, infra, resourceName, "rdp-badcreds-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 30*time.Second) + require.Error(t, err, "FreeRDP should fail with bad credentials") + slog.Info("Bad credentials test passed", "error", err) + + _ = pamCmd + }) + + t.Run("unreachable-target", func(t *testing.T) { + resourceName := "rdp-unreachable-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, "192.0.2.1", 3389) + 
createRDPPamAccount(t, ctx, infra, resourceId, "rdp-unreachable-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + _, pamCmd := startRDPProxy(t, ctx, infra, resourceName, "rdp-unreachable-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 30*time.Second) + require.Error(t, err, "FreeRDP should fail when target is unreachable") + slog.Info("Unreachable target test passed", "error", err) + + _ = pamCmd + }) + + t.Run("reconnect", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-reconnect-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-reconnect-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + startRDPProxy(t, ctx, infra, resourceName, "rdp-reconnect-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 15*time.Second) + require.NoError(t, err, "First FreeRDP connection should succeed") + slog.Info("First RDP connection succeeded, reconnecting...") + + time.Sleep(2 * time.Second) + + err = connectFreeRDP(t, ctx, rdpBinary, proxyPort, 15*time.Second) + require.NoError(t, err, "Second FreeRDP connection (reconnect) should succeed") + slog.Info("Reconnect test passed") + }) + + t.Run("concurrent-connections", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-concurrent-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-concurrent-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + startRDPProxy(t, ctx, infra, resourceName, "rdp-concurrent-account", "5m", proxyPort) + + const numClients = 3 + var wg sync.WaitGroup + errs := make([]error, numClients) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + errs[idx] = connectFreeRDP(t, 
ctx, rdpBinary, proxyPort, 20*time.Second) + }(i) + } + + wg.Wait() + for i, err := range errs { + require.NoError(t, err, "concurrent RDP client %d should succeed", i) + } + slog.Info("All concurrent RDP connections succeeded", "numClients", numClients) + }) + + t.Run("session-duration", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-duration-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-duration-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + _, pamCmd := startRDPProxy(t, ctx, infra, resourceName, "rdp-duration-account", "30s", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 15*time.Second) + require.NoError(t, err, "Initial FreeRDP connection should succeed") + + result := helpers.WaitFor(t, helpers.WaitForOptions{ + Timeout: 60 * time.Second, + Interval: 2 * time.Second, + Condition: func() helpers.ConditionResult { + if !pamCmd.IsRunning() { + return helpers.ConditionSuccess + } + return helpers.ConditionWait + }, + }) + require.Equal(t, helpers.WaitSuccess, result, "RDP proxy should terminate after session duration expires") + slog.Info("Session duration test passed") + }) +} diff --git a/e2e/pam/testdata/rdp-server/Dockerfile b/e2e/pam/testdata/rdp-server/Dockerfile new file mode 100644 index 00000000..e07c9fcb --- /dev/null +++ b/e2e/pam/testdata/rdp-server/Dockerfile @@ -0,0 +1,16 @@ +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + xrdp xorgxrdp openbox dbus-x11 xterm && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +COPY entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +EXPOSE 3389 + +ENTRYPOINT ["/entrypoint.sh"] diff --git a/e2e/pam/testdata/rdp-server/entrypoint.sh b/e2e/pam/testdata/rdp-server/entrypoint.sh new file mode 100644 index 00000000..8cd4917b --- 
/dev/null +++ b/e2e/pam/testdata/rdp-server/entrypoint.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +useradd -m -s /bin/bash testuser +echo "testuser:testpass" | chpasswd + +mkdir -p /home/testuser +echo "openbox-session" > /home/testuser/.xsession +chown testuser:testuser /home/testuser/.xsession + +if [ ! -f /etc/xrdp/rsakeys.ini ]; then + xrdp-keygen xrdp auto +fi + +mkdir -p /run/dbus +dbus-daemon --system --fork + +xrdp-sesman --nodaemon & + +exec xrdp --nodaemon From fbc6fdfc014c4d5b4aeebe661f59169c4f1afe88 Mon Sep 17 00:00:00 2001 From: bernie-g Date: Mon, 11 May 2026 15:56:36 -0400 Subject: [PATCH 10/10] fix(e2e): seed recording config before creating Windows PAM resources The backend requires a pam_project_recording_configs row to exist before allowing Windows resource creation. Insert a dummy row directly into the database, bypassing FK checks since we don't need a real app connection for the E2E tests. --- e2e/pam/rdp_test.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/e2e/pam/rdp_test.go b/e2e/pam/rdp_test.go index efd0e186..e7144feb 100644 --- a/e2e/pam/rdp_test.go +++ b/e2e/pam/rdp_test.go @@ -15,8 +15,10 @@ import ( "time" "github.com/docker/docker/api/types/container" + "github.com/docker/go-connections/nat" "github.com/google/uuid" helpers "github.com/infisical/cli/e2e-tests/util" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" @@ -73,6 +75,35 @@ func pamAPIRequest(t *testing.T, infra *PAMTestInfra, method, path string, body return resp.StatusCode, respBody } +func setupRecordingConfig(t *testing.T, ctx context.Context, infra *PAMTestInfra) { + dbContainer, err := infra.Infisical.Compose().ServiceContainer(ctx, "db") + require.NoError(t, err) + dbPort, err := dbContainer.MappedPort(ctx, nat.Port("5432")) + require.NoError(t, err) + + connStr := 
fmt.Sprintf("postgres://infisical:infisical@localhost:%s/infisical", dbPort.Port()) + conn, err := pgx.Connect(ctx, connStr) + require.NoError(t, err) + defer conn.Close(ctx) + + // Temporarily bypass FK checks so we can insert a dummy connectionId + // without needing a real app_connections row. + _, err = conn.Exec(ctx, `SET session_replication_role = 'replica'`) + require.NoError(t, err) + defer func() { + _, _ = conn.Exec(ctx, `SET session_replication_role = 'origin'`) + }() + + _, err = conn.Exec(ctx, ` + INSERT INTO pam_project_recording_configs (id, "projectId", "storageBackend", "connectionId", bucket, region) + VALUES ($1, $2, 'aws-s3', $3, 'e2e-test-bucket', 'us-east-1') + ON CONFLICT ("projectId") DO NOTHING`, + uuid.New().String(), infra.ProjectId, uuid.New().String(), + ) + require.NoError(t, err) + slog.Info("Inserted recording config for project", "projectId", infra.ProjectId) +} + func createRDPPamResource(t *testing.T, ctx context.Context, infra *PAMTestInfra, name, host string, port int) uuid.UUID { status, respBody := pamAPIRequest(t, infra, "POST", "/api/v1/pam/resources/windows", map[string]interface{}{ "projectId": infra.ProjectId, @@ -205,6 +236,7 @@ func TestPAM_RDP(t *testing.T) { infra := SetupPAMInfra(t, ctx) LoginUser(t, ctx, infra) + setupRecordingConfig(t, ctx, infra) rdpBinary := findFreeRDPBinary(t) resourceHost := getOutboundIP(t)