diff --git a/.github/workflows/release_build_infisical_cli.yml b/.github/workflows/release_build_infisical_cli.yml index d388356a..bb5a74e9 100644 --- a/.github/workflows/release_build_infisical_cli.yml +++ b/.github/workflows/release_build_infisical_cli.yml @@ -31,37 +31,6 @@ jobs: build-rdp-bridge: uses: ./.github/workflows/build-rdp-bridge.yml - # Create the GitHub release draft up front so both goreleaser - # (ubuntu) and goreleaser-darwin (macos) can append to it in - # parallel instead of serializing on ubuntu creating the draft. - # Skipped on dry-run since --snapshot doesn't touch GitHub at all. - create-release-draft: - if: | - always() && - (needs.validate-tag-branch.result == 'success' || needs.validate-tag-branch.result == 'skipped') && - needs.cli-tests.result == 'success' && - (github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run)) - needs: - - validate-tag-branch - - cli-tests - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Create GitHub release draft (idempotent) - run: | - if gh release view "${{ github.ref_name }}" >/dev/null 2>&1; then - echo "Release for ${{ github.ref_name }} already exists, skipping creation" - else - gh release create "${{ github.ref_name }}" \ - --draft \ - --title "${{ github.ref_name }}" \ - --generate-notes - fi - env: - GH_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} - # cli-integration-tests: # name: Run tests before deployment # uses: ./.github/workflows/run-cli-tests.yml @@ -136,12 +105,10 @@ jobs: - validate-tag-branch - cli-tests - build-rdp-bridge - - create-release-draft if: | always() && needs.cli-tests.result == 'success' && needs.build-rdp-bridge.result == 'success' && - (needs.create-release-draft.result == 'success' || needs.create-release-draft.result == 'skipped') && (github.event_name == 'workflow_dispatch' || needs.validate-tag-branch.result == 'success') steps: - uses: actions/checkout@v3 @@ -218,14 +185,13 @@ jobs: mkdir -p "$target_dir" cp "/tmp/rdp-bridge-artifacts/rdp-bridge-$triple/libinfisical_rdp_bridge.a" "$target_dir/" done - - name: GoReleaser (build, no publish) + - name: GoReleaser (dry-run snapshot) + if: github.event_name == 'workflow_dispatch' && inputs.dry_run uses: goreleaser/goreleaser-action@v4 with: distribution: goreleaser-pro version: v1.26.2-pro - args: >- - release --clean --skip=publish,announce - ${{ (github.event_name == 'workflow_dispatch' && inputs.dry_run) && '--snapshot' || '' }} + args: release --clean --snapshot --skip=publish env: GITHUB_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} POSTHOG_API_KEY_FOR_CLI: ${{ secrets.POSTHOG_API_KEY_FOR_CLI }} @@ -240,6 +206,7 @@ jobs: path: dist/ retention-days: 7 - name: Smoke test linux binary across supported distros + if: github.event_name == 'workflow_dispatch' && inputs.dry_run run: | set -uo pipefail fail=0 @@ -282,13 +249,13 @@ jobs: echo "::endgroup::" [ "$fail" -eq 0 ] || exit 1 - - name: GoReleaser (publish) + - name: GoReleaser (release) if: github.event_name == 'push' || (github.event_name == 'workflow_dispatch' && !inputs.dry_run) uses: goreleaser/goreleaser-action@v4 with: distribution: goreleaser-pro version: v1.26.2-pro - args: release --skip=build,validate,before + args: release --clean env: GITHUB_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} POSTHOG_API_KEY_FOR_CLI: ${{ secrets.POSTHOG_API_KEY_FOR_CLI }} @@ -344,12 +311,10 @@ jobs: - validate-tag-branch - cli-tests - build-rdp-bridge - - create-release-draft if: | always() && needs.cli-tests.result == 'success' && needs.build-rdp-bridge.result == 'success' && - (needs.create-release-draft.result == 'success' || needs.create-release-draft.result == 'skipped') && (github.event_name == 'workflow_dispatch' || needs.validate-tag-branch.result == 'success') steps: - uses: actions/checkout@v4 @@ -426,11 +391,9 @@ jobs: needs: - validate-tag-branch - cli-tests - - create-release-draft if: | always() && needs.cli-tests.result == 'success' && - (needs.create-release-draft.result == 'success' || needs.create-release-draft.result == 'skipped') && (github.event_name == 'workflow_dispatch' || needs.validate-tag-branch.result == 'success') steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run-cli-e2e-tests.yml b/.github/workflows/run-cli-e2e-tests.yml index 1929dfd7..59c0eed7 100644 --- a/.github/workflows/run-cli-e2e-tests.yml +++ b/.github/workflows/run-cli-e2e-tests.yml @@ -82,8 +82,25 @@ jobs: go-version: "1.25.9" - name: Install dependencies run: go get . + - name: Cache cargo registry + target + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + packages/pam/handlers/rdp/native/target + key: rdp-bridge-cargo-${{ runner.os }}-${{ hashFiles('packages/pam/handlers/rdp/native/Cargo.lock') }} + restore-keys: rdp-bridge-cargo-${{ runner.os }}- + - name: Install pinned Rust toolchain + working-directory: packages/pam/handlers/rdp/native + run: rustup show active-toolchain + - name: Build Rust RDP bridge + working-directory: packages/pam/handlers/rdp/native + run: cargo build --release - name: Build the CLI - run: go build -o infisical-cli + run: CGO_ENABLED=1 go build -tags rdp -o infisical-cli + - name: Install RDP test dependencies + run: sudo apt-get update && sudo apt-get install -y --no-install-recommends freerdp2-x11 xvfb - name: Checkout infisical repo uses: actions/checkout@v6 with: diff --git a/.gitignore b/.gitignore index 2891130d..c654c50c 100644 --- a/.gitignore +++ b/.gitignore @@ -14,5 +14,5 @@ infisical /agent-testing .vscode/ -# PAM CLI session artifacts (local testing only) +.idea/ /session/ diff --git a/.goreleaser-darwin.yaml b/.goreleaser-darwin.yaml index adfe00b1..2a6a8e96 100644 --- a/.goreleaser-darwin.yaml +++ b/.goreleaser-darwin.yaml @@ -45,9 +45,7 @@ archives: - manpages/* - completions/* -# Append to the release draft created by the ubuntu goreleaser job. release: - replace_existing_draft: false mode: append checksum: diff --git a/.goreleaser-windows.yaml b/.goreleaser-windows.yaml index 73752b9d..b945f8ab 100644 --- a/.goreleaser-windows.yaml +++ b/.goreleaser-windows.yaml @@ -1,6 +1,4 @@ -# Append to the release draft created by the ubuntu goreleaser job. release: - replace_existing_draft: false mode: append builds: diff --git a/.goreleaser.yaml b/.goreleaser.yaml index b39dfde9..1b0a5362 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -161,10 +161,6 @@ archives: - completions/* release: - # The draft is created up front by the create-release-draft workflow - # job, so both this config and .goreleaser-darwin.yaml use append mode - # to add their artifacts in parallel. - replace_existing_draft: false mode: append checksum: diff --git a/e2e/openapi-cfg.yaml b/e2e/openapi-cfg.yaml index e40b0472..5c0ee79c 100644 --- a/e2e/openapi-cfg.yaml +++ b/e2e/openapi-cfg.yaml @@ -37,3 +37,5 @@ output-options: - createSshPamResource - createSshPamAccount - createRedisPamAccount + - createWindowsPamResource + - createWindowsPamAccount diff --git a/e2e/pam/rdp_test.go b/e2e/pam/rdp_test.go new file mode 100644 index 00000000..e7144feb --- /dev/null +++ b/e2e/pam/rdp_test.go @@ -0,0 +1,368 @@ +package pam + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os/exec" + "strings" + "sync" + "testing" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/docker/go-connections/nat" + "github.com/google/uuid" + helpers "github.com/infisical/cli/e2e-tests/util" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +const ( + rdpUser = "testuser" + rdpPassword = "testpass" +) + +func startRDPContainer(t *testing.T, ctx context.Context) (testcontainers.Container, int) { + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: testcontainers.ContainerRequest{ + FromDockerfile: testcontainers.FromDockerfile{ + Context: "testdata/rdp-server", + Dockerfile: "Dockerfile", + }, + ExposedPorts: []string{"3389/tcp"}, + HostConfigModifier: func(hc *container.HostConfig) { + hc.ExtraHosts = append(hc.ExtraHosts, "host.docker.internal:host-gateway") + }, + WaitingFor: wait.ForListeningPort("3389/tcp").WithStartupTimeout(60 * time.Second), + }, + Started: true, + }) + require.NoError(t, err) + t.Cleanup(func() { + if err := ctr.Terminate(ctx); err != nil { + t.Logf("Failed to terminate RDP container: %v", err) + } + }) + + port, err := ctr.MappedPort(ctx, "3389") + require.NoError(t, err) + return ctr, port.Int() +} + +func pamAPIRequest(t *testing.T, infra *PAMTestInfra, method, path string, body interface{}) (int, []byte) { + jsonBody, err := json.Marshal(body) + require.NoError(t, err) + + url := infra.Infisical.ApiUrl(t) + path + req, err := http.NewRequest(method, url, bytes.NewReader(jsonBody)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+infra.ProvisionResult.Token) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + return resp.StatusCode, respBody +} + +func setupRecordingConfig(t *testing.T, ctx context.Context, infra *PAMTestInfra) { + dbContainer, err := infra.Infisical.Compose().ServiceContainer(ctx, "db") + require.NoError(t, err) + dbPort, err := dbContainer.MappedPort(ctx, nat.Port("5432")) + require.NoError(t, err) + + connStr := fmt.Sprintf("postgres://infisical:infisical@localhost:%s/infisical", dbPort.Port()) + conn, err := pgx.Connect(ctx, connStr) + require.NoError(t, err) + defer conn.Close(ctx) + + // Temporarily bypass FK checks so we can insert a dummy connectionId + // without needing a real app_connections row. + _, err = conn.Exec(ctx, `SET session_replication_role = 'replica'`) + require.NoError(t, err) + defer func() { + _, _ = conn.Exec(ctx, `SET session_replication_role = 'origin'`) + }() + + _, err = conn.Exec(ctx, ` + INSERT INTO pam_project_recording_configs (id, "projectId", "storageBackend", "connectionId", bucket, region) + VALUES ($1, $2, 'aws-s3', $3, 'e2e-test-bucket', 'us-east-1') + ON CONFLICT ("projectId") DO NOTHING`, + uuid.New().String(), infra.ProjectId, uuid.New().String(), + ) + require.NoError(t, err) + slog.Info("Inserted recording config for project", "projectId", infra.ProjectId) +} + +func createRDPPamResource(t *testing.T, ctx context.Context, infra *PAMTestInfra, name, host string, port int) uuid.UUID { + status, respBody := pamAPIRequest(t, infra, "POST", "/api/v1/pam/resources/windows", map[string]interface{}{ + "projectId": infra.ProjectId, + "gatewayId": infra.GatewayId, + "name": name, + "connectionDetails": map[string]interface{}{ + "protocol": "rdp", + "hostname": host, + "port": port, + "winrmPort": 5985, + "useWinrmHttps": false, + "winrmRejectUnauthorized": false, + }, + }) + require.Equal(t, http.StatusOK, status, "create Windows resource: %s", string(respBody)) + + var result struct { + Resource struct { + Id uuid.UUID `json:"id"` + } `json:"resource"` + } + require.NoError(t, json.Unmarshal(respBody, &result)) + slog.Info("Created Windows PAM resource", "resourceId", result.Resource.Id, "name", name) + return result.Resource.Id +} + +func createRDPPamAccount(t *testing.T, ctx context.Context, infra *PAMTestInfra, resourceId uuid.UUID, name, username, password string) { + body := map[string]interface{}{ + "resourceId": resourceId.String(), + "name": name, + "credentials": map[string]interface{}{ + "username": username, + "password": password, + }, + "internalMetadata": map[string]interface{}{ + "accountType": "user", + }, + } + + result := helpers.WaitFor(t, helpers.WaitForOptions{ + Timeout: 90 * time.Second, + Interval: 3 * time.Second, + Condition: func() helpers.ConditionResult { + status, respBody := pamAPIRequest(t, infra, "POST", "/api/v1/pam/accounts/windows", body) + if status != http.StatusOK { + slog.Warn("Windows PAM account creation returned non-200, retrying...", "status", status, "body", string(respBody)) + return helpers.ConditionWait + } + return helpers.ConditionSuccess + }, + }) + require.Equal(t, helpers.WaitSuccess, result, "Windows PAM account creation should succeed for %s", name) + slog.Info("Created Windows PAM account", "name", name) +} + +func startRDPProxy(t *testing.T, ctx context.Context, infra *PAMTestInfra, resourceName, accountName, duration string, port int) (int, *helpers.Command) { + pamCmd := helpers.Command{ + Test: t, + RunMethod: helpers.RunMethodSubprocess, + DisableTempHomeDir: true, + Args: []string{ + "pam", "rdp", "access", + "--resource", resourceName, + "--account", accountName, + "--project-id", infra.ProjectId, + "--duration", duration, + "--port", fmt.Sprintf("%d", port), + "--no-launch", + }, + Env: map[string]string{ + "HOME": infra.SharedHomeDir, + "INFISICAL_API_URL": infra.Infisical.ApiUrl(t), + }, + } + pamCmd.Start(ctx) + t.Cleanup(pamCmd.Stop) + + result := helpers.WaitFor(t, helpers.WaitForOptions{ + EnsureCmdRunning: &pamCmd, + Condition: func() helpers.ConditionResult { + if strings.Contains(pamCmd.Stderr(), "RDP Proxy Session Started") { + return helpers.ConditionSuccess + } + return helpers.ConditionWait + }, + }) + if result != helpers.WaitSuccess { + pamCmd.DumpOutput() + } + require.Equal(t, helpers.WaitSuccess, result, "RDP proxy should start successfully") + + return port, &pamCmd +} + +func findFreeRDPBinary(t *testing.T) string { + for _, name := range []string{"xfreerdp3", "xfreerdp"} { + if path, err := exec.LookPath(name); err == nil { + return path + } + } + t.Skip("xfreerdp not found; install freerdp2-x11 or freerdp3-x11") + return "" +} + +func connectFreeRDP(t *testing.T, ctx context.Context, binary string, proxyPort int, timeout time.Duration) error { + timeoutMs := int(timeout.Milliseconds()) + args := []string{ + binary, + fmt.Sprintf("/v:127.0.0.1:%d", proxyPort), + "/u:testuser", + "/p:", + "/cert:ignore", + fmt.Sprintf("/timeout:%d", timeoutMs), + } + + cmdCtx, cancel := context.WithTimeout(ctx, timeout+10*time.Second) + defer cancel() + + cmd := exec.CommandContext(cmdCtx, "xvfb-run", append([]string{"-a"}, args...)...) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("xfreerdp failed (exit %v): %s", err, string(output)) + } + return nil +} + +func TestPAM_RDP(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + infra := SetupPAMInfra(t, ctx) + LoginUser(t, ctx, infra) + setupRecordingConfig(t, ctx, infra) + + rdpBinary := findFreeRDPBinary(t) + resourceHost := getOutboundIP(t) + + t.Run("connection", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + slog.Info("RDP container started", "host", resourceHost, "port", rdpPort) + + resourceName := "rdp-connection-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-connection-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + startRDPProxy(t, ctx, infra, resourceName, "rdp-connection-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 30*time.Second) + require.NoError(t, err, "FreeRDP connection through proxy should succeed") + slog.Info("RDP connection test passed") + }) + + t.Run("bad-credentials", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-badcreds-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-badcreds-account", rdpUser, "wrong-password") + + proxyPort := helpers.GetFreePort() + _, pamCmd := startRDPProxy(t, ctx, infra, resourceName, "rdp-badcreds-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 30*time.Second) + require.Error(t, err, "FreeRDP should fail with bad credentials") + slog.Info("Bad credentials test passed", "error", err) + + _ = pamCmd + }) + + t.Run("unreachable-target", func(t *testing.T) { + resourceName := "rdp-unreachable-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, "192.0.2.1", 3389) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-unreachable-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + _, pamCmd := startRDPProxy(t, ctx, infra, resourceName, "rdp-unreachable-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 30*time.Second) + require.Error(t, err, "FreeRDP should fail when target is unreachable") + slog.Info("Unreachable target test passed", "error", err) + + _ = pamCmd + }) + + t.Run("reconnect", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-reconnect-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-reconnect-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + startRDPProxy(t, ctx, infra, resourceName, "rdp-reconnect-account", "5m", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 15*time.Second) + require.NoError(t, err, "First FreeRDP connection should succeed") + slog.Info("First RDP connection succeeded, reconnecting...") + + time.Sleep(2 * time.Second) + + err = connectFreeRDP(t, ctx, rdpBinary, proxyPort, 15*time.Second) + require.NoError(t, err, "Second FreeRDP connection (reconnect) should succeed") + slog.Info("Reconnect test passed") + }) + + t.Run("concurrent-connections", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-concurrent-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-concurrent-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + startRDPProxy(t, ctx, infra, resourceName, "rdp-concurrent-account", "5m", proxyPort) + + const numClients = 3 + var wg sync.WaitGroup + errs := make([]error, numClients) + + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + errs[idx] = connectFreeRDP(t, ctx, rdpBinary, proxyPort, 20*time.Second) + }(i) + } + + wg.Wait() + for i, err := range errs { + require.NoError(t, err, "concurrent RDP client %d should succeed", i) + } + slog.Info("All concurrent RDP connections succeeded", "numClients", numClients) + }) + + t.Run("session-duration", func(t *testing.T) { + _, rdpPort := startRDPContainer(t, ctx) + + resourceName := "rdp-duration-resource" + resourceId := createRDPPamResource(t, ctx, infra, resourceName, resourceHost, rdpPort) + createRDPPamAccount(t, ctx, infra, resourceId, "rdp-duration-account", rdpUser, rdpPassword) + + proxyPort := helpers.GetFreePort() + _, pamCmd := startRDPProxy(t, ctx, infra, resourceName, "rdp-duration-account", "30s", proxyPort) + + err := connectFreeRDP(t, ctx, rdpBinary, proxyPort, 15*time.Second) + require.NoError(t, err, "Initial FreeRDP connection should succeed") + + result := helpers.WaitFor(t, helpers.WaitForOptions{ + Timeout: 60 * time.Second, + Interval: 2 * time.Second, + Condition: func() helpers.ConditionResult { + if !pamCmd.IsRunning() { + return helpers.ConditionSuccess + } + return helpers.ConditionWait + }, + }) + require.Equal(t, helpers.WaitSuccess, result, "RDP proxy should terminate after session duration expires") + slog.Info("Session duration test passed") + }) +} diff --git a/e2e/pam/ssh_test.go b/e2e/pam/ssh_test.go index 4d79b1d4..df329af5 100644 --- a/e2e/pam/ssh_test.go +++ b/e2e/pam/ssh_test.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "log/slog" - "net" "net/http" "os" "strings" @@ -33,8 +32,8 @@ const ( sshPassword = "testpass" ) -func startSSHContainer(t *testing.T, ctx context.Context, env map[string]string) (testcontainers.Container, int) { - container, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ +func startSSHContainer(t *testing.T, ctx context.Context, env map[string]string) (testcontainers.Container, string, int) { + ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ ContainerRequest: testcontainers.ContainerRequest{ FromDockerfile: testcontainers.FromDockerfile{ Context: "testdata/ssh-server", @@ -45,20 +44,25 @@ func startSSHContainer(t *testing.T, ctx context.Context, env map[string]string) HostConfigModifier: func(hc *container.HostConfig) { hc.ExtraHosts = append(hc.ExtraHosts, "host.docker.internal:host-gateway") }, - WaitingFor: wait.ForListeningPort("22/tcp").WithStartupTimeout(30 * time.Second), + WaitingFor: wait.ForAll( + wait.ForListeningPort("22/tcp"), + wait.ForLog("Server listening"), + ).WithStartupTimeout(30 * time.Second), }, Started: true, }) require.NoError(t, err) t.Cleanup(func() { - if err := container.Terminate(ctx); err != nil { + if err := ctr.Terminate(ctx); err != nil { t.Logf("Failed to terminate SSH container: %v", err) } }) - port, err := container.MappedPort(ctx, "22") + host, err := ctr.Host(ctx) require.NoError(t, err) - return container, port.Int() + port, err := ctr.MappedPort(ctx, "22") + require.NoError(t, err) + return ctr, host, port.Int() } func createSSHPamResource(t *testing.T, ctx context.Context, infra *PAMTestInfra, name, host string, port int) uuid.UUID { @@ -200,12 +204,20 @@ func configureCertAuth(t *testing.T, ctx context.Context, infra *PAMTestInfra, c require.NoError(t, err) require.Equal(t, 0, exitCode, "sshd reload should succeed") - // Wait for sshd to be responsive after config reload. + // Wait for sshd to be fully responsive after config reload. + // A TCP dial is not enough — sshd never closes the listen socket during + // SIGHUP, so the port is always reachable. Instead, do an actual SSH + // handshake to confirm sshd is serving connections after the reload. result := helpers.WaitFor(t, helpers.WaitForOptions{ Timeout: 10 * time.Second, Interval: 500 * time.Millisecond, Condition: func() helpers.ConditionResult { - conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", sshPort), time.Second) + conn, err := ssh.Dial("tcp", fmt.Sprintf("localhost:%d", sshPort), &ssh.ClientConfig{ + User: sshUser, + Auth: []ssh.AuthMethod{ssh.Password(sshPassword)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + }) if err != nil { return helpers.ConditionWait } @@ -221,7 +233,7 @@ func configureCertAuth(t *testing.T, ctx context.Context, infra *PAMTestInfra, c // - password: uses hardcoded testuser/testpass from entrypoint; account gets username + password // - public-key: container gets SSH_AUTHORIZED_KEY (generated ed25519); account gets username + privateKey // - certificate: container configured via curl | bash (ssh-ca-setup endpoint); account gets just username -func runSSHAuthTest(t *testing.T, ctx context.Context, infra *PAMTestInfra, resourceHost string, method string) { +func runSSHAuthTest(t *testing.T, ctx context.Context, infra *PAMTestInfra, method string) { containerEnv := map[string]string{} accountCreds := map[string]interface{}{ "authMethod": method, @@ -249,11 +261,11 @@ func runSSHAuthTest(t *testing.T, ctx context.Context, infra *PAMTestInfra, reso // Cert auth is configured after resource creation via curl | bash. } - container, sshPort := startSSHContainer(t, ctx, containerEnv) - slog.Info("SSH container started", "method", method, "host", resourceHost, "port", sshPort) + container, sshHost, sshPort := startSSHContainer(t, ctx, containerEnv) + slog.Info("SSH container started", "method", method, "host", sshHost, "port", sshPort) resourceName := fmt.Sprintf("ssh-%s-resource", method) - resourceId := createSSHPamResource(t, ctx, infra, resourceName, resourceHost, sshPort) + resourceId := createSSHPamResource(t, ctx, infra, resourceName, sshHost, sshPort) if method == "certificate" { configureCertAuth(t, ctx, infra, container, sshPort, resourceId) @@ -273,12 +285,10 @@ func TestPAM_SSH(t *testing.T) { infra := SetupPAMInfra(t, ctx) LoginUser(t, ctx, infra) - resourceHost := getOutboundIP(t) - methods := []string{"password", "public-key", "certificate"} for _, method := range methods { t.Run(method, func(t *testing.T) { - runSSHAuthTest(t, ctx, infra, resourceHost, method) + runSSHAuthTest(t, ctx, infra, method) }) } } diff --git a/e2e/pam/testdata/rdp-server/Dockerfile b/e2e/pam/testdata/rdp-server/Dockerfile new file mode 100644 index 00000000..e07c9fcb --- /dev/null +++ b/e2e/pam/testdata/rdp-server/Dockerfile @@ -0,0 +1,16 @@ +FROM ubuntu:22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + xrdp xorgxrdp openbox dbus-x11 xterm && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +COPY entrypoint.sh /entrypoint.sh +RUN chmod +x /entrypoint.sh + +EXPOSE 3389 + +ENTRYPOINT ["/entrypoint.sh"] diff --git a/e2e/pam/testdata/rdp-server/entrypoint.sh b/e2e/pam/testdata/rdp-server/entrypoint.sh new file mode 100644 index 00000000..8cd4917b --- /dev/null +++ b/e2e/pam/testdata/rdp-server/entrypoint.sh @@ -0,0 +1,20 @@ +#!/bin/bash +set -e + +useradd -m -s /bin/bash testuser +echo "testuser:testpass" | chpasswd + +mkdir -p /home/testuser +echo "openbox-session" > /home/testuser/.xsession +chown testuser:testuser /home/testuser/.xsession + +if [ ! -f /etc/xrdp/rsakeys.ini ]; then + xrdp-keygen xrdp auto +fi + +mkdir -p /run/dbus +dbus-daemon --system --fork + +xrdp-sesman --nodaemon & + +exec xrdp --nodaemon diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index 07c32120..b8f4e127 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -14,6 +14,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/Infisical/infisical-merge/packages/api" @@ -86,8 +87,9 @@ type GatewayConfig struct { } type pamSessionEntry struct { - cancel context.CancelFunc - conn *tls.Conn + cancel context.CancelFunc + conn *tls.Conn + lastActivity atomic.Int64 } type Gateway struct { @@ -166,20 +168,33 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { } // RegisterPAMSession registers an active PAM proxy connection for cancellation support -func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) { +// Returns a function that handlers should call when data flows through the connection +func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) func() { + entry := &pamSessionEntry{cancel: cancel, conn: conn} + entry.lastActivity.Store(time.Now().Unix()) + g.pamSessionsMu.Lock() defer g.pamSessionsMu.Unlock() - g.pamSessions[sessionID] = append(g.pamSessions[sessionID], &pamSessionEntry{cancel: cancel, conn: conn}) + g.pamSessions[sessionID] = append(g.pamSessions[sessionID], entry) + + return func() { + entry.lastActivity.Store(time.Now().Unix()) + } } // DeregisterPAMSession removes a specific connection from the session registry. +// Returns true if this was the last connection for the session. // The MongoDB proxy (if any) is NOT closed here — it persists across connections // so that subsequent client connections (e.g. mongosh retries) find a warm topology. // The proxy is cleaned up on session cancellation or gateway shutdown. -func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) { +func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) bool { g.pamSessionsMu.Lock() defer g.pamSessionsMu.Unlock() - entries := g.pamSessions[sessionID] + + entries, exists := g.pamSessions[sessionID] + if !exists { + return false + } for i, e := range entries { if e.conn == conn { g.pamSessions[sessionID] = append(entries[:i], entries[i+1:]...) @@ -188,7 +203,9 @@ func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) { } if len(g.pamSessions[sessionID]) == 0 { delete(g.pamSessions, sessionID) + return true } + return false } // CancelPAMSession kills all active connections for a PAM session @@ -264,6 +281,51 @@ func (g *Gateway) closeMongoProxy(sessionID string) { } } +const pamIdleTimeout = 30 * time.Minute + +// startIdleReaper periodically scans the PAM session registry and cancels +// sessions whose connections have had no data flow for pamIdleTimeout +func (g *Gateway) startIdleReaper(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + g.reapIdleSessions() + } + } +} + +func (g *Gateway) reapIdleSessions() { + cutoff := time.Now().Add(-pamIdleTimeout).Unix() + + g.pamSessionsMu.Lock() + var stale []string + for sessionID, entries := range g.pamSessions { + allIdle := true + for _, e := range entries { + if e.lastActivity.Load() > cutoff { + allIdle = false + break + } + } + if allIdle { + stale = append(stale, sessionID) + } + } + g.pamSessionsMu.Unlock() + + for _, sessionID := range stale { + log.Info().Str("sessionId", sessionID).Dur("idleTimeout", pamIdleTimeout).Msg("Reaping idle PAM session") + g.CancelPAMSession(sessionID) + if err := g.pamSessionUploader.CleanupPAMSession(sessionID, "idle_timeout"); err != nil { + log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to cleanup reaped PAM session") + } + } +} + func (g *Gateway) registerHeartBeat(ctx context.Context, errCh chan error) { sendHeartbeat := func() error { if err := api.CallGatewayHeartBeatV2(g.httpClient); err != nil { @@ -329,6 +391,8 @@ func (g *Gateway) Start(ctx context.Context) error { // Start session uploader goroutine for PAM g.pamSessionUploader.Start() + go g.startIdleReaper(ctx) + go func() { for { select { @@ -489,6 +553,25 @@ func (g *Gateway) handleConnection(client *ssh.Client) error { client.Close() }() + // Keepalive on the relay SSH connection. If the relay drops silently, + // this closes the client so the reconnect loop in connectWithRetry kicks in + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := util.SSHKeepalive(client, 15*time.Second); err != nil { + log.Warn().Err(err).Msg("Relay SSH keepalive failed, closing connection") + client.Close() + return + } + case <-g.ctx.Done(): + return + } + } + }() + // Process incoming channels with context cancellation support for { select { @@ -751,8 +834,8 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { return } else if forwardConfig.Mode == ForwardModePAM { sessionCtx, sessionCancel := context.WithCancel(g.ctx) - g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) - defer g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn) + touchSession := g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) + forwardConfig.PAMConfig.OnActivity = touchSession if err := pam.HandlePAMProxy(sessionCtx, tlsConn, &forwardConfig.PAMConfig, g.httpClient); err != nil { if err.Error() == "unexpected EOF" { log.Debug().Err(err).Msg("PAM proxy handler ended with unexpected connection termination") @@ -760,6 +843,14 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { log.Error().Err(err).Msg("PAM proxy handler ended with error") } } + sessionCancel() + if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn { + if err := forwardConfig.PAMConfig.SessionUploader.CleanupPAMSession( + forwardConfig.PAMConfig.SessionId, "connection_closed", + ); err != nil { + log.Error().Err(err).Str("sessionId", forwardConfig.PAMConfig.SessionId).Msg("Failed to cleanup PAM session") + } + } return } else if forwardConfig.Mode == ForwardModePAMCancellation { if err := pam.HandlePAMCancellation(g.ctx, tlsConn, &forwardConfig.PAMConfig, g.httpClient, g.CancelPAMSession); err != nil { diff --git a/packages/pam/handlers/oracle/ATTRIBUTION.md b/packages/pam/handlers/oracle/ATTRIBUTION.md new file mode 100644 index 00000000..8a3a8c60 --- /dev/null +++ b/packages/pam/handlers/oracle/ATTRIBUTION.md @@ -0,0 +1,23 @@ +This package contains code adapted from [sijms/go-ora](https://github.com/sijms/go-ora). + +MIT License + +Copyright (c) 2020 Samy Sultan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/packages/pam/handlers/oracle/constants.go b/packages/pam/handlers/oracle/constants.go new file mode 100644 index 00000000..7d1b3541 --- /dev/null +++ b/packages/pam/handlers/oracle/constants.go @@ -0,0 +1,4 @@ +package oracle + +// Must be passed by the client; also used for O5Logon key derivation in phase 1. +const ProxyPasswordPlaceholder = "password" diff --git a/packages/pam/handlers/oracle/o5logon.go b/packages/pam/handlers/oracle/o5logon.go new file mode 100644 index 00000000..ca66ea51 --- /dev/null +++ b/packages/pam/handlers/oracle/o5logon.go @@ -0,0 +1,108 @@ +package oracle + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha512" + "encoding/hex" + "fmt" +) + +const ( + ORA1017InvalidCredentials = 1017 +) + +func PKCS5Padding(cipherText []byte, blockSize int) []byte { + padding := blockSize - len(cipherText)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(cipherText, padtext...) +} + +func generateSpeedyKey(buffer, key []byte, turns int) []byte { + mac := hmac.New(sha512.New, key) + mac.Write(append(buffer, 0, 0, 0, 1)) + firstHash := mac.Sum(nil) + tempHash := make([]byte, len(firstHash)) + copy(tempHash, firstHash) + for index1 := 2; index1 <= turns; index1++ { + mac.Reset() + mac.Write(tempHash) + tempHash = mac.Sum(nil) + for index2 := 0; index2 < 64; index2++ { + firstHash[index2] = firstHash[index2] ^ tempHash[index2] + } + } + return firstHash +} + +func decryptSessionKey(padding bool, encKey []byte, sessionKeyHex string) ([]byte, error) { + result, err := hex.DecodeString(sessionKeyHex) + if err != nil { + return nil, err + } + blk, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + dec := cipher.NewCBCDecrypter(blk, make([]byte, 16)) + output := make([]byte, len(result)) + dec.CryptBlocks(output, result) + cutLen := 0 + if padding { + num := int(output[len(output)-1]) + if num < dec.BlockSize() { + apply := true + for x := len(output) - num; x < len(output); x++ { + if output[x] != uint8(num) { + apply = false + break + } + } + if apply { + cutLen = int(output[len(output)-1]) + } + } + } + return output[:len(output)-cutLen], nil +} + +func encryptSessionKey(padding bool, encKey []byte, sessionKey []byte) (string, error) { + blk, err := aes.NewCipher(encKey) + if err != nil { + return "", err + } + enc := cipher.NewCBCEncrypter(blk, make([]byte, 16)) + originalLen := len(sessionKey) + sessionKey = PKCS5Padding(sessionKey, blk.BlockSize()) + output := make([]byte, len(sessionKey)) + enc.CryptBlocks(output, sessionKey) + if !padding { + return fmt.Sprintf("%X", output[:originalLen]), nil + } + return fmt.Sprintf("%X", output), nil +} + +func encryptPassword(password, key []byte, padding bool) (string, error) { + buff1 := make([]byte, 0x10) + if _, err := rand.Read(buff1); err != nil { + return "", err + } + buffer := append(buff1, password...) + return encryptSessionKey(padding, key, buffer) +} + +func deriveServerKey(password string, salt []byte, vGenCount int) (key []byte, speedy []byte, err error) { + message := append([]byte(nil), salt...) + message = append(message, []byte("AUTH_PBKDF2_SPEEDY_KEY")...) + speedy = generateSpeedyKey(message, []byte(password), vGenCount) + + buffer := append([]byte(nil), speedy...) + buffer = append(buffer, salt...) + h := sha512.New() + h.Write(buffer) + key = h.Sum(nil)[:32] + return +} diff --git a/packages/pam/handlers/oracle/o5logon_server.go b/packages/pam/handlers/oracle/o5logon_server.go new file mode 100644 index 00000000..7df71436 --- /dev/null +++ b/packages/pam/handlers/oracle/o5logon_server.go @@ -0,0 +1,163 @@ +package oracle + +import ( + "fmt" + "net" +) + +const ( + TTCMsgAuthRequest = 0x03 + TTCMsgError = 0x04 +) + +const ( + AuthSubOpPhaseOne = 0x76 + AuthSubOpPhaseTwo = 0x73 +) + +type AuthPhaseTwo struct { + EClientSessKey string + EPassword string +} + +func readDataPayload(conn net.Conn, use32BitLen bool) ([]byte, error) { + raw, err := ReadFullPacket(conn, use32BitLen) + if err != nil { + return nil, err + } + if PacketTypeOf(raw) == PacketTypeMarker { + return readDataPayload(conn, use32BitLen) + } + if PacketTypeOf(raw) != PacketTypeData { + return nil, fmt.Errorf("expected DATA packet, got type=%d", raw[4]) + } + pkt, err := ParseDataPacket(raw, use32BitLen) + if err != nil { + return nil, err + } + return pkt.Payload, nil +} + +func writeDataPayload(conn net.Conn, payload []byte, use32BitLen bool) error { + d := &DataPacket{Payload: payload} + _, err := conn.Write(d.Bytes(use32BitLen)) + return err +} + +func ParseAuthPhaseTwo(payload []byte) (*AuthPhaseTwo, error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, err + } + if op != TTCMsgAuthRequest { + return nil, fmt.Errorf("phase2 unexpected opcode 0x%02X", op) + } + sub, err := r.GetByte() + if err != nil { + return nil, err + } + if sub != AuthSubOpPhaseTwo { + return nil, fmt.Errorf("phase2 unexpected sub-op 0x%02X", sub) + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + + out := &AuthPhaseTwo{} + + hasUser, err := r.GetByte() + if err != nil { + return nil, err + } + var userLen int + if hasUser == 1 { + userLen, err = r.GetInt(4, true, true) + if err != nil { + return nil, err + } + } else { + if _, err := r.GetByte(); err != nil { + return nil, err + } + } + + if _, err := r.GetInt(4, true, true); err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + count, err := r.GetInt(4, true, true) + if err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + if hasUser == 1 && userLen > 0 { + // go-ora prefixes username with CLR length byte; JDBC thin sends it raw. + peek, perr := r.PeekByte() + if perr != nil { + return nil, fmt.Errorf("peek phase2 username: %w", perr) + } + if int(peek) == userLen && peek < 0x20 { + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("consume phase2 username length prefix: %w", err) + } + } + if _, err := r.GetBytes(userLen); err != nil { + return nil, fmt.Errorf("read phase2 username bytes: %w", err) + } + } + + for i := 0; i < count; i++ { + k, v, _, err := r.GetKeyVal() + if err != nil { + return nil, fmt.Errorf("phase2 KVP #%d: %w", i, err) + } + switch string(k) { + case "AUTH_SESSKEY": + out.EClientSessKey = string(v) + case "AUTH_PASSWORD": + out.EPassword = string(v) + } + } + return out, nil +} + +func BuildErrorPacket(oraCode int, message string) []byte { + b := NewTTCBuilder() + b.PutBytes(TTCMsgError) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(int64(oraCode), 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 1, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 4, true, true) + b.PutInt(0, 2, true, true) + b.PutInt(0, 2, true, true) + b.PutString(message) + b.PutInt(0, 2, true, true) + return b.Bytes() +} + +func WriteErrorToClient(conn net.Conn, oraCode int, message string, use32BitLen bool) error { + return writeDataPayload(conn, BuildErrorPacket(oraCode, message), use32BitLen) +} diff --git a/packages/pam/handlers/oracle/proxy.go b/packages/pam/handlers/oracle/proxy.go new file mode 100644 index 00000000..71916b6a --- /dev/null +++ b/packages/pam/handlers/oracle/proxy.go @@ -0,0 +1,64 @@ +package oracle + +import ( + "context" + "crypto/tls" + "fmt" + "net" + + "github.com/Infisical/infisical-merge/packages/pam/session" +) + +type OracleProxyConfig struct { + TargetAddr string + InjectUsername string + InjectPassword string + InjectDatabase string + EnableTLS bool + TLSConfig *tls.Config + SessionID string + SessionLogger session.SessionLogger +} + +type OracleProxy struct { + config OracleProxyConfig +} + +func NewOracleProxy(config OracleProxyConfig) *OracleProxy { + return &OracleProxy{config: config} +} + +func (p *OracleProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { + return p.handleConnectionProxied(ctx, clientConn) +} + +func relayWithTap(src, dst net.Conn, tap *QueryExtractor, errCh chan<- error) { + buf := make([]byte, 32*1024) + for { + n, err := src.Read(buf) + if n > 0 { + if _, werr := dst.Write(buf[:n]); werr != nil { + errCh <- werr + return + } + tap.Feed(buf[:n]) + } + if err != nil { + errCh <- err + return + } + } +} + +func splitHostPort(addr string) (string, int, error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", 0, err + } + var port int + _, err = fmt.Sscanf(portStr, "%d", &port) + if err != nil { + return "", 0, fmt.Errorf("bad port %q: %w", portStr, err) + } + return host, port, nil +} diff --git a/packages/pam/handlers/oracle/proxy_auth.go b/packages/pam/handlers/oracle/proxy_auth.go new file mode 100644 index 00000000..51caa643 --- /dev/null +++ b/packages/pam/handlers/oracle/proxy_auth.go @@ -0,0 +1,830 @@ +package oracle + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/rs/zerolog/log" +) + +func (p *OracleProxy) handleConnectionProxied(ctx context.Context, clientConn net.Conn) error { + defer clientConn.Close() + defer func() { + if err := p.config.SessionLogger.Close(); err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to close session logger") + } + }() + + log.Info().Str("sessionID", p.config.SessionID).Str("target", p.config.TargetAddr).Msg("Oracle PAM session started (proxied auth)") + + // 1. Dial upstream (keep raw TCP ref — TCPS may need a second TLS handshake mid-flow). + rawUpstream, tlsUpstream, err := dialUpstreamRaw(ctx, p.config) + if err != nil { + log.Error().Err(err).Str("sessionID", p.config.SessionID).Msg("Failed to dial Oracle upstream") + _ = WriteRefuseToClient(clientConn, "(DESCRIPTION=(ERR=12564)(VSNNUM=0)(ERROR_STACK=(ERROR=(CODE=12564)(EMFI=4))))") + return fmt.Errorf("upstream dial: %w", err) + } + var upstreamConn net.Conn + if tlsUpstream != nil { + upstreamConn = tlsUpstream + } else { + upstreamConn = rawUpstream + } + defer func() { upstreamConn.Close() }() + + // 2. Read client's CONNECT, rewrite SERVICE_NAME, forward to upstream. + connectRaw, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read client CONNECT: %w", err) + } + if PacketTypeOf(connectRaw) != PacketTypeConnect { + return fmt.Errorf("expected CONNECT, got type=%d", connectRaw[4]) + } + if p.config.InjectDatabase == "" { + return fmt.Errorf("InjectDatabase (service name) is required but empty") + } + connectRaw = rewriteConnectServiceName(connectRaw, p.config.InjectDatabase) + if _, err := upstreamConn.Write(connectRaw); err != nil { + return fmt.Errorf("forward CONNECT: %w", err) + } + + // If connect-data wasn't inline (JDBC thin / go-ora with long descriptions), + // the client already sent it as a follow-up packet. Forward it now — some + // Oracle listeners (e.g., OCI Autonomous DB) won't RESEND, they just wait. + if len(connectRaw) >= 28 { + cdLen := int(binary.BigEndian.Uint16(connectRaw[24:26])) + cdOff := int(binary.BigEndian.Uint16(connectRaw[26:28])) + if cdLen > 0 && cdOff+cdLen > len(connectRaw) { + supplement, serr := ReadFullPacket(clientConn, false) + if serr != nil { + return fmt.Errorf("read connect-data supplement: %w", serr) + } + log.Info().Str("sessionID", p.config.SessionID).Int("supplementLen", len(supplement)).Msg("Proxy: forwarding connect-data supplement before handshake") + if _, werr := upstreamConn.Write(supplement); werr != nil { + return fmt.Errorf("forward connect-data supplement: %w", werr) + } + } + } + + // 3. Read upstream responses until ACCEPT. Handle RESEND (TLS restart). + const maxHandshakeAttempts = 10 + var acceptRaw []byte + for attempt := 0; acceptRaw == nil; attempt++ { + if attempt >= maxHandshakeAttempts { + return fmt.Errorf("too many handshake packets (%d) without ACCEPT", attempt) + } + pkt, err := ReadFullPacket(upstreamConn, false) + if err != nil { + return fmt.Errorf("read upstream handshake packet (attempt %d): %w", attempt, err) + } + pktType := PacketTypeOf(pkt) + var origFlag byte + if len(pkt) > 5 { + origFlag = pkt[5] + } + log.Info().Str("sessionID", p.config.SessionID).Uint8("pktType", uint8(pktType)).Int("pktLen", len(pkt)).Uint8("flag", origFlag).Msg("Proxy: upstream handshake packet") + + // RESEND flag 0x08: tear down current TLS, do a fresh handshake on the raw socket. + if p.config.EnableTLS && pktType == PacketTypeResend && origFlag&0x08 != 0 { + tc, terr := upgradeToTLS(ctx, rawUpstream, p.config) + if terr != nil { + return fmt.Errorf("upstream TLS upgrade after RESEND(flag=0x08): %w", terr) + } + upstreamConn = tc + log.Info().Str("sessionID", p.config.SessionID).Str("tlsVersion", tlsVersionString(tc.ConnectionState().Version)).Str("cipher", tls.CipherSuiteName(tc.ConnectionState().CipherSuite)).Msg("Proxy: upstream TLS re-handshook on RESEND(flag=0x08)") + } + + // Mask byte 5 so thin clients don't try TLS upgrade on plain TCP. + if p.config.EnableTLS && len(pkt) > 5 { + pkt[5] = 0x00 + } + if _, werr := clientConn.Write(pkt); werr != nil { + return fmt.Errorf("forward upstream handshake packet: %w", werr) + } + switch pktType { + case PacketTypeAccept: + acceptRaw = pkt + case PacketTypeRefuse: + return fmt.Errorf("upstream REFUSE during handshake") + case PacketTypeRedirect: + return fmt.Errorf("upstream REDIRECT during handshake (not supported)") + case PacketTypeResend: + clientPkt, err := ReadFullPacket(clientConn, false) + if err != nil { + return fmt.Errorf("read client response after RESEND: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("len", len(clientPkt)).Uint8("type", uint8(PacketTypeOf(clientPkt))).Msg("Proxy: forwarding client response after RESEND") + if _, werr := upstreamConn.Write(clientPkt); werr != nil { + return fmt.Errorf("forward client response after RESEND: %w", werr) + } + // Client may re-send CONNECT with non-inline connect-data — forward + // the supplement too, same as we did before the handshake loop. + if PacketTypeOf(clientPkt) == PacketTypeConnect && len(clientPkt) >= 28 { + cdLen := int(binary.BigEndian.Uint16(clientPkt[24:26])) + cdOff := int(binary.BigEndian.Uint16(clientPkt[26:28])) + if cdLen > 0 && cdOff+cdLen > len(clientPkt) { + supp, serr := ReadFullPacket(clientConn, false) + if serr != nil { + return fmt.Errorf("read connect-data supplement after RESEND: %w", serr) + } + log.Info().Str("sessionID", p.config.SessionID).Int("supplementLen", len(supp)).Msg("Proxy: forwarding connect-data supplement after RESEND") + if _, werr := upstreamConn.Write(supp); werr != nil { + return fmt.Errorf("forward connect-data supplement after RESEND: %w", werr) + } + } + } + } + } + + var acceptVersion uint16 + if len(acceptRaw) >= 10 { + acceptVersion = binary.BigEndian.Uint16(acceptRaw[8:10]) + } + use32Bit := acceptVersion >= 315 + log.Info().Str("sessionID", p.config.SessionID).Uint16("acceptVersion", acceptVersion).Bool("use32Bit", use32Bit).Msg("Proxy: ACCEPT forwarded") + + + + p1Payload, err := proxyUntilAuthRequest(clientConn, upstreamConn, use32Bit, p.config.SessionID) + if err != nil { + return fmt.Errorf("pre-auth proxy: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Int("p1Len", len(p1Payload)).Msg("Proxy: auth-request boundary reached") + + p1Forward := p1Payload + if p.config.InjectUsername != "" { + rewritten, rerr := rewritePhase1User(p1Payload, p.config.InjectUsername) + if rerr != nil { + return fmt.Errorf("rewrite phase 1 username: %w", rerr) + } + p1Forward = rewritten + } + if err := writeDataPayload(upstreamConn, p1Forward, use32Bit); err != nil { + return fmt.Errorf("forward phase 1 request: %w", err) + } + + p1RespUpstream, err := readDataPayload(upstreamConn, use32Bit) + if err != nil { + return fmt.Errorf("read upstream phase 1 response: %w", err) + } + state, p1RespTranslated, err := translatePhase1Response(p1RespUpstream, p.config.InjectPassword) + if err != nil { + _ = WriteErrorToClient(clientConn, ORA1017InvalidCredentials, "ORA-01017: invalid username/password; logon denied", use32Bit) + return fmt.Errorf("translate phase 1 response: %w", err) + } + if err := writeDataPayload(clientConn, p1RespTranslated, use32Bit); err != nil { + return fmt.Errorf("write translated phase 1 response: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-1 response translated and forwarded") + + p2ReqClient, err := readDataPayload(clientConn, use32Bit) + if err != nil { + return fmt.Errorf("read client phase 2 request: %w", err) + } + p2ReqTranslated, err := translatePhase2Request(p2ReqClient, state, p.config.InjectPassword) + if err != nil { + _ = WriteErrorToClient(clientConn, ORA1017InvalidCredentials, "ORA-01017: invalid username/password; logon denied", use32Bit) + return fmt.Errorf("translate phase 2 request: %w", err) + } + // Oracle cross-checks phase-2 username against phase-1. + if p.config.InjectUsername != "" { + rewritten, rerr := rewritePhase2User(p2ReqTranslated, p.config.InjectUsername) + if rerr != nil { + return fmt.Errorf("rewrite phase 2 username: %w", rerr) + } + p2ReqTranslated = rewritten + } + if err := writeDataPayload(upstreamConn, p2ReqTranslated, use32Bit); err != nil { + return fmt.Errorf("forward phase 2 request: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-2 request translated and forwarded") + + // AUTH_SVR_RESPONSE is keyed on session material (not password) — forward unchanged. + p2RespRaw, err := ReadFullPacket(upstreamConn, use32Bit) + if err != nil { + return fmt.Errorf("read upstream phase 2 response: %w", err) + } + if _, err := clientConn.Write(p2RespRaw); err != nil { + return fmt.Errorf("forward phase 2 response: %w", err) + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Proxy: phase-2 response forwarded; client authenticated") + + state = nil + + c2u, u2c := NewQueryExtractorPair(p.config.SessionLogger, p.config.SessionID, use32Bit) + defer c2u.Stop() + defer u2c.Stop() + + errCh := make(chan error, 2) + go relayWithTap(clientConn, upstreamConn, c2u, errCh) + go relayWithTap(upstreamConn, clientConn, u2c, errCh) + + select { + case rerr := <-errCh: + if rerr != nil && rerr != io.EOF { + log.Debug().Err(rerr).Str("sessionID", p.config.SessionID).Msg("Oracle relay ended") + } + case <-ctx.Done(): + log.Info().Str("sessionID", p.config.SessionID).Msg("Oracle session cancelled by context") + } + log.Info().Str("sessionID", p.config.SessionID).Msg("Oracle PAM session ended") + return nil +} + +// Includes legacy RSA-CBC suites needed by Oracle 19c / AWS RDS. +var oracleUpstreamCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, +} + +// TLS 1.0–1.2 only: Oracle TCPS has no TLS-1.3 restart mechanism; RDS negotiates down to 1.0. +func buildOracleTLSConfig(base *tls.Config, host string) *tls.Config { + cfg := base.Clone() + if cfg.ServerName == "" { + cfg.ServerName = host + } + cfg.MinVersion = tls.VersionTLS10 + cfg.MaxVersion = tls.VersionTLS12 + cfg.CipherSuites = oracleUpstreamCiphers + return cfg +} + +func dialUpstreamRaw(ctx context.Context, cfg OracleProxyConfig) (rawConn net.Conn, tlsConn *tls.Conn, err error) { + host, _, err := splitHostPort(cfg.TargetAddr) + if err != nil { + return nil, nil, fmt.Errorf("invalid target addr: %w", err) + } + d := &net.Dialer{Timeout: 15 * time.Second} + rawConn, err = d.DialContext(ctx, "tcp", cfg.TargetAddr) + if err != nil { + return nil, nil, err + } + if !cfg.EnableTLS { + return rawConn, nil, nil + } + if cfg.TLSConfig == nil { + rawConn.Close() + return nil, nil, fmt.Errorf("upstream TLS requested but no TLSConfig provided") + } + tlsCfg := buildOracleTLSConfig(cfg.TLSConfig, host) + tc := tls.Client(rawConn, tlsCfg) + if err := tc.HandshakeContext(ctx); err != nil { + rawConn.Close() + return nil, nil, fmt.Errorf("upstream TLS handshake: %w", err) + } + return rawConn, tc, nil +} + +func tlsVersionString(v uint16) string { + switch v { + case tls.VersionTLS10: + return "TLS1.0" + case tls.VersionTLS11: + return "TLS1.1" + case tls.VersionTLS12: + return "TLS1.2" + case tls.VersionTLS13: + return "TLS1.3" + default: + return fmt.Sprintf("0x%04x", v) + } +} + +func upgradeToTLS(ctx context.Context, rawConn net.Conn, cfg OracleProxyConfig) (*tls.Conn, error) { + host, _, err := splitHostPort(cfg.TargetAddr) + if err != nil { + return nil, fmt.Errorf("invalid target addr: %w", err) + } + tlsCfg := buildOracleTLSConfig(cfg.TLSConfig, host) + tc := tls.Client(rawConn, tlsCfg) + if err := tc.HandshakeContext(ctx); err != nil { + return nil, fmt.Errorf("upstream TLS handshake: %w", err) + } + return tc, nil +} + +func proxyUntilAuthRequest(client, upstream net.Conn, use32Bit bool, sessionID string) ([]byte, error) { + type result struct { + payload []byte + err error + } + done := make(chan result, 2) + stop := make(chan struct{}) + + go func() { + for { + select { + case <-stop: + return + default: + } + pkt, err := ReadFullPacket(upstream, use32Bit) + if err != nil { + select { + case done <- result{err: fmt.Errorf("read upstream: %w", err)}: + default: + } + return + } + if _, werr := client.Write(pkt); werr != nil { + select { + case done <- result{err: fmt.Errorf("write client: %w", werr)}: + default: + } + return + } + log.Debug().Str("sessionID", sessionID).Uint8("type", uint8(PacketTypeOf(pkt))).Int("len", len(pkt)).Msg("Proxy pre-auth: upstream → client") + } + }() + + go func() { + for { + select { + case <-stop: + return + default: + } + pkt, err := ReadFullPacket(client, use32Bit) + if err != nil { + select { + case done <- result{err: fmt.Errorf("read client: %w", err)}: + default: + } + return + } + pktType := PacketTypeOf(pkt) + if pktType == PacketTypeData { + payload, perr := extractDataPayload(pkt) + if perr == nil && len(payload) >= 2 && + payload[0] == TTCMsgAuthRequest && payload[1] == AuthSubOpPhaseOne { + select { + case done <- result{payload: payload}: + default: + } + return + } + } + if _, werr := upstream.Write(pkt); werr != nil { + select { + case done <- result{err: fmt.Errorf("write upstream: %w", werr)}: + default: + } + return + } + log.Debug().Str("sessionID", sessionID).Uint8("type", uint8(pktType)).Int("len", len(pkt)).Msg("Proxy pre-auth: client → upstream") + } + }() + + res := <-done + close(stop) + // Unblock the other goroutine so it doesn't steal the phase-1 response. + if uc, ok := upstream.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = uc.SetReadDeadline(time.Now().Add(-1 * time.Second)) + } + time.Sleep(50 * time.Millisecond) + if uc, ok := upstream.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = uc.SetReadDeadline(time.Time{}) + } + if res.err != nil { + return nil, res.err + } + return res.payload, nil +} + +func extractDataPayload(pkt []byte) ([]byte, error) { + const headerLen = 10 + if len(pkt) < headerLen { + return nil, fmt.Errorf("packet too short: %d", len(pkt)) + } + return pkt[headerLen:], nil +} + +func rewriteAuthRequestUser(payload []byte, expectedSubOp byte, newUser string) ([]byte, error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, fmt.Errorf("opcode: %w", err) + } + if op != TTCMsgAuthRequest { + return nil, fmt.Errorf("unexpected opcode 0x%02X", op) + } + sub, err := r.GetByte() + if err != nil { + return nil, err + } + if sub != expectedSubOp { + return nil, fmt.Errorf("unexpected sub-op 0x%02X (want 0x%02X)", sub, expectedSubOp) + } + if _, err := r.GetByte(); err != nil { + return nil, err + } + + hasUser, err := r.GetByte() + if err != nil { + return nil, err + } + if hasUser != 1 { + return payload, nil + } + origUserLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, fmt.Errorf("userLen: %w", err) + } + if origUserLen <= 0 { + return payload, nil + } + + middleStart := r.Pos() + + if _, err := r.GetInt(4, true, true); err != nil { + return nil, fmt.Errorf("mode: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker after mode: %w", err) + } + if _, err := r.GetInt(4, true, true); err != nil { + return nil, fmt.Errorf("count: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker 1: %w", err) + } + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("marker 2: %w", err) + } + middleEnd := r.Pos() + + // go-ora prefixes user bytes with a CLR-length byte; JDBC thin omits it. + peek, perr := r.PeekByte() + if perr != nil { + return nil, fmt.Errorf("peek user: %w", perr) + } + usedCLRPrefix := int(peek) == origUserLen && peek < 0x20 + if usedCLRPrefix { + if _, err := r.GetByte(); err != nil { + return nil, fmt.Errorf("consume user CLR length: %w", err) + } + } + if _, err := r.GetBytes(origUserLen); err != nil { + return nil, fmt.Errorf("user bytes: %w", err) + } + userEnd := r.Pos() + + newUserBytes := []byte(newUser) + newUserLen := len(newUserBytes) + + out := make([]byte, 0, len(payload)+16) + out = append(out, payload[:3]...) + out = append(out, 0x01) + lb := NewTTCBuilder() + lb.PutInt(int64(newUserLen), 4, true, true) + out = append(out, lb.Bytes()...) + out = append(out, payload[middleStart:middleEnd]...) + if usedCLRPrefix { + out = append(out, byte(newUserLen)) + } + out = append(out, newUserBytes...) + out = append(out, payload[userEnd:]...) + return out, nil +} + +func rewriteConnectServiceName(pkt []byte, newName string) []byte { + marker := []byte("SERVICE_NAME=") + idx := bytes.Index(pkt, marker) + if idx < 0 { + return pkt + } + valStart := idx + len(marker) + valEnd := bytes.IndexByte(pkt[valStart:], ')') + if valEnd < 0 { + return pkt + } + valEnd += valStart + + oldVal := pkt[valStart:valEnd] + newVal := []byte(newName) + if bytes.Equal(oldVal, newVal) { + return pkt + } + + out := make([]byte, 0, len(pkt)+len(newVal)-len(oldVal)) + out = append(out, pkt[:valStart]...) + out = append(out, newVal...) + out = append(out, pkt[valEnd:]...) + + binary.BigEndian.PutUint16(out[0:2], uint16(len(out))) + if len(out) >= 26 { + oldCDLen := binary.BigEndian.Uint16(pkt[24:26]) + binary.BigEndian.PutUint16(out[24:26], uint16(int(oldCDLen)+len(newVal)-len(oldVal))) + } + return out +} + +func rewritePhase1User(payload []byte, newUser string) ([]byte, error) { + return rewriteAuthRequestUser(payload, AuthSubOpPhaseOne, newUser) +} + +func rewritePhase2User(payload []byte, newUser string) ([]byte, error) { + return rewriteAuthRequestUser(payload, AuthSubOpPhaseTwo, newUser) +} + +type ProxyAuthState struct { + Salt []byte + Pbkdf2CSKSalt string + Pbkdf2VGenCount int + Pbkdf2SDerCount int + RealKey []byte + PlaceholderKey []byte + ServerSessKey []byte +} + +func translatePhase1Response(payload []byte, realPassword string) (*ProxyAuthState, []byte, error) { + kvs, trailer, err := parseAuthRespKVPList(payload) + if err != nil { + return nil, nil, fmt.Errorf("parse upstream phase 1: %w", err) + } + + var eSessKey, vfrData, cskSalt, vGenStr, sDerStr string + for _, kv := range kvs { + switch kv.Key { + case "AUTH_SESSKEY": + eSessKey = kv.Value + case "AUTH_VFR_DATA": + vfrData = kv.Value + case "AUTH_PBKDF2_CSK_SALT": + cskSalt = kv.Value + case "AUTH_PBKDF2_VGEN_COUNT": + vGenStr = kv.Value + case "AUTH_PBKDF2_SDER_COUNT": + sDerStr = kv.Value + } + } + if eSessKey == "" || vfrData == "" { + return nil, nil, fmt.Errorf("upstream phase 1 missing AUTH_SESSKEY or AUTH_VFR_DATA") + } + salt, err := hex.DecodeString(vfrData) + if err != nil { + return nil, nil, fmt.Errorf("decode salt: %w", err) + } + vGen, _ := strconv.Atoi(vGenStr) + if vGen == 0 { + vGen = 4096 + } + sDer, _ := strconv.Atoi(sDerStr) + if sDer == 0 { + sDer = 3 + } + + realKey, _, err := deriveServerKey(realPassword, salt, vGen) + if err != nil { + return nil, nil, fmt.Errorf("derive real key: %w", err) + } + placeholderKey, _, err := deriveServerKey(ProxyPasswordPlaceholder, salt, vGen) + if err != nil { + return nil, nil, fmt.Errorf("derive placeholder key: %w", err) + } + + serverSessKey, err := decryptSessionKey(false, realKey, eSessKey) + if err != nil { + return nil, nil, fmt.Errorf("decrypt upstream server session key: %w", err) + } + newESessKey, err := encryptSessionKey(false, placeholderKey, serverSessKey) + if err != nil { + return nil, nil, fmt.Errorf("re-encrypt server session key: %w", err) + } + + for i := range kvs { + if kvs[i].Key == "AUTH_SESSKEY" { + kvs[i].Value = newESessKey + break + } + } + + rebuilt := rebuildAuthRespPayload(kvs, trailer) + + state := &ProxyAuthState{ + Salt: salt, + Pbkdf2CSKSalt: cskSalt, + Pbkdf2VGenCount: vGen, + Pbkdf2SDerCount: sDer, + RealKey: realKey, + PlaceholderKey: placeholderKey, + ServerSessKey: serverSessKey, + } + return state, rebuilt, nil +} + +func translatePhase2Request(payload []byte, state *ProxyAuthState, realPassword string) ([]byte, error) { + p2, err := ParseAuthPhaseTwo(payload) + if err != nil { + return nil, fmt.Errorf("parse client phase 2: %w", err) + } + + if p2.EClientSessKey == "" || p2.EPassword == "" { + return nil, fmt.Errorf("client phase 2 missing AUTH_SESSKEY or AUTH_PASSWORD") + } + + clientSessKey, err := decryptSessionKey(false, state.PlaceholderKey, p2.EClientSessKey) + if err != nil { + return nil, fmt.Errorf("decrypt client session key: %w", err) + } + if len(clientSessKey) != len(state.ServerSessKey) { + return nil, fmt.Errorf("client session key length mismatch: got %d want %d", len(clientSessKey), len(state.ServerSessKey)) + } + newEClientSessKey, err := encryptSessionKey(false, state.RealKey, clientSessKey) + if err != nil { + return nil, fmt.Errorf("re-encrypt client session key: %w", err) + } + + // encKey derives from session keys + CSK salt, not the password. + encKey, err := deriveProxyPasswordEncKey(clientSessKey, state.ServerSessKey, state.Pbkdf2CSKSalt, state.Pbkdf2SDerCount) + if err != nil { + return nil, fmt.Errorf("derive enc key: %w", err) + } + // Verify the client used the placeholder password. Wrong password would also + // fail cryptographically in phase 1 (ORA-17452), but this gives a clearer error. + decoded, err := decryptSessionKey(true, encKey, p2.EPassword) + if err != nil { + return nil, fmt.Errorf("decrypt client password: %w", err) + } + if len(decoded) <= 16 || string(decoded[16:]) != ProxyPasswordPlaceholder { + return nil, fmt.Errorf("password mismatch") + } + newEPassword, err := encryptPassword([]byte(realPassword), encKey, true) + if err != nil { + return nil, fmt.Errorf("encrypt real password: %w", err) + } + + rebuilt, err := rebuildPhase2Request(payload, newEClientSessKey, newEPassword) + if err != nil { + return nil, fmt.Errorf("rebuild phase 2: %w", err) + } + return rebuilt, nil +} + +func deriveProxyPasswordEncKey(clientSessKey, serverSessKey []byte, pbkdf2CSKSaltHex string, sderCount int) ([]byte, error) { + buffer := append([]byte(nil), clientSessKey...) + buffer = append(buffer, serverSessKey...) + keyBuffer := []byte(fmt.Sprintf("%X", buffer)) + cskSalt, err := hex.DecodeString(pbkdf2CSKSaltHex) + if err != nil { + return nil, fmt.Errorf("decode pbkdf2 salt: %w", err) + } + full := generateSpeedyKey(cskSalt, keyBuffer, sderCount) + if len(full) < 32 { + return nil, fmt.Errorf("speedy key too short: %d", len(full)) + } + return full[:32], nil +} + +type parsedKVP struct { + Key string + Value string + Flag int +} + +func parseAuthRespKVPList(payload []byte) (kvs []parsedKVP, trailer []byte, err error) { + r := NewTTCReader(payload) + op, err := r.GetByte() + if err != nil { + return nil, nil, err + } + if op != 0x08 { + return nil, nil, fmt.Errorf("expected auth response opcode 0x08, got 0x%02X", op) + } + dictLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("dict len: %w", err) + } + for i := 0; i < dictLen; i++ { + keyLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d key len: %w", i, err) + } + var keyBytes []byte + if keyLen > 0 { + keyBytes, err = r.GetClr() + if err != nil { + return nil, nil, fmt.Errorf("kvp %d key: %w", i, err) + } + if len(keyBytes) > keyLen { + keyBytes = keyBytes[:keyLen] + } + } + valLen, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d val len: %w", i, err) + } + var valBytes []byte + if valLen > 0 { + valBytes, err = r.GetClr() + if err != nil { + return nil, nil, fmt.Errorf("kvp %d val: %w", i, err) + } + if len(valBytes) > valLen { + valBytes = valBytes[:valLen] + } + } + flag, err := r.GetInt(4, true, true) + if err != nil { + return nil, nil, fmt.Errorf("kvp %d flag: %w", i, err) + } + kvs = append(kvs, parsedKVP{ + Key: string(bytes.TrimRight(keyBytes, "\x00")), + Value: string(valBytes), + Flag: flag, + }) + } + trailer = make([]byte, r.Remaining()) + rem, _ := r.GetBytes(r.Remaining()) + copy(trailer, rem) + return kvs, trailer, nil +} + +func rebuildAuthRespPayload(kvs []parsedKVP, trailer []byte) []byte { + b := NewTTCBuilder() + b.PutBytes(0x08) + b.PutUint(uint64(len(kvs)), 4, true, true) + for _, kv := range kvs { + b.PutKeyValString(kv.Key, kv.Value, uint32(kv.Flag)) + } + b.PutBytes(trailer...) + return b.Bytes() +} + +func rebuildPhase2Request(payload []byte, newESessKey, newEPassword string) ([]byte, error) { + out := make([]byte, 0, len(payload)+128) + out = append(out, payload...) + + out, err := replaceKVPValue(out, "AUTH_SESSKEY", newESessKey) + if err != nil { + return nil, fmt.Errorf("replace AUTH_SESSKEY: %w", err) + } + out, err = replaceKVPValue(out, "AUTH_PASSWORD", newEPassword) + if err != nil { + return nil, fmt.Errorf("replace AUTH_PASSWORD: %w", err) + } + return out, nil +} + +func replaceKVPValue(payload []byte, key, newValue string) ([]byte, error) { + keyBytes := []byte(key) + idx := bytes.Index(payload, keyBytes) + if idx < 0 { + return nil, fmt.Errorf("key %q not found", key) + } + pos := idx + len(keyBytes) + if pos >= len(payload) { + return nil, fmt.Errorf("truncated after key") + } + vSizeByte := payload[pos] + pos++ + var vLen int + if vSizeByte == 0 { + vLen = 0 + } else if int(vSizeByte) <= 8 { + for i := 0; i < int(vSizeByte); i++ { + vLen = (vLen << 8) | int(payload[pos+i]) + } + pos += int(vSizeByte) + } else { + return nil, fmt.Errorf("invalid val_len size byte %d", vSizeByte) + } + if vLen > 0 { + if pos >= len(payload) || int(payload[pos]) != vLen { + return nil, fmt.Errorf("CLR length byte mismatch for %q: got %d want %d", key, payload[pos], vLen) + } + pos++ + valBodyStart := pos + valBodyEnd := valBodyStart + vLen + // PutClr handles chunked 0xFE form for values > 0xFC bytes. + newVal := []byte(newValue) + vb := NewTTCBuilder() + vb.PutUint(uint64(len(newVal)), 4, true, true) + vb.PutClr(newVal) + newValSection := vb.Bytes() + oldStart := idx + len(keyBytes) + oldEnd := valBodyEnd + out := make([]byte, 0, len(payload)+len(newValSection)) + out = append(out, payload[:oldStart]...) + out = append(out, newValSection...) + out = append(out, payload[oldEnd:]...) + return out, nil + } + return payload, fmt.Errorf("unexpected empty value for %q", key) +} diff --git a/packages/pam/handlers/oracle/query_logger.go b/packages/pam/handlers/oracle/query_logger.go new file mode 100644 index 00000000..2860b452 --- /dev/null +++ b/packages/pam/handlers/oracle/query_logger.go @@ -0,0 +1,276 @@ +package oracle + +import ( + "bytes" + "encoding/binary" + "fmt" + "sync" + "time" + + "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/rs/zerolog/log" +) + +const ( + ttcFuncOALL8 = 0x5E + ttcFuncOCOMMIT = 0x0E + ttcFuncORLLBK = 0x0F + ttcMsgFunction = 0x03 +) + +type pendingQuery struct { + sql string + timestamp time.Time +} + +// Best-effort SQL extraction from the byte stream. +type QueryExtractor struct { + logger session.SessionLogger + sessionID string + direction string + ch chan []byte + stopCh chan struct{} + wg sync.WaitGroup + use32Bit bool + pair *pairState +} + +type pairState struct { + mu sync.Mutex + pending *pendingQuery +} + +func NewQueryExtractorPair(logger session.SessionLogger, sessionID string, use32Bit bool) (clientToUpstream, upstreamToClient *QueryExtractor) { + p := &pairState{} + clientToUpstream = newExtractor(logger, sessionID, "client->upstream", use32Bit, p) + upstreamToClient = newExtractor(logger, sessionID, "upstream->client", use32Bit, p) + return +} + +func newExtractor(logger session.SessionLogger, sessionID, direction string, use32Bit bool, pair *pairState) *QueryExtractor { + e := &QueryExtractor{ + logger: logger, + sessionID: sessionID, + direction: direction, + ch: make(chan []byte, 64), + stopCh: make(chan struct{}), + use32Bit: use32Bit, + pair: pair, + } + e.wg.Add(1) + go e.loop() + return e +} + +func (e *QueryExtractor) Feed(data []byte) { + if len(data) == 0 { + return + } + cp := make([]byte, len(data)) + copy(cp, data) + select { + case e.ch <- cp: + default: + } +} + +func (e *QueryExtractor) Stop() { + close(e.stopCh) + e.wg.Wait() +} + +func (e *QueryExtractor) loop() { + defer e.wg.Done() + var buffer bytes.Buffer + + for { + select { + case <-e.stopCh: + return + case chunk := <-e.ch: + buffer.Write(chunk) + e.drain(&buffer) + } + } +} + +func (e *QueryExtractor) drain(buf *bytes.Buffer) { + for { + if buf.Len() < 8 { + return + } + head := buf.Bytes()[:8] + var length uint32 + if e.use32Bit { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } + if length < 8 || length > 16*1024*1024 { + buf.Reset() + return + } + if buf.Len() < int(length) { + return + } + packet := make([]byte, length) + if _, err := buf.Read(packet); err != nil { + return + } + e.handlePacket(packet) + } +} + +func (e *QueryExtractor) handlePacket(raw []byte) { + if PacketTypeOf(raw) != PacketTypeData { + return + } + d, err := ParseDataPacket(raw, e.use32Bit) + if err != nil { + return + } + if len(d.Payload) < 1 { + return + } + switch e.direction { + case "client->upstream": + e.handleClientRequest(d.Payload) + case "upstream->client": + e.handleServerResponse(d.Payload) + } +} + +func (e *QueryExtractor) handleClientRequest(payload []byte) { + // Clients often piggyback an OCLOSE before the new function call; scan for + // the function-call+opcode marker pair instead of parsing from offset 0. + if idx := findBytePair(payload, ttcMsgFunction, ttcFuncOALL8); idx >= 0 { + r := NewTTCReader(payload[idx+2:]) + if sqlText := tryExtractSQL(r); sqlText != "" { + e.pair.mu.Lock() + e.pair.pending = &pendingQuery{sql: sqlText, timestamp: time.Now()} + e.pair.mu.Unlock() + } + return + } + if findBytePair(payload, ttcMsgFunction, ttcFuncOCOMMIT) >= 0 { + e.recordLiteral("COMMIT") + return + } + if findBytePair(payload, ttcMsgFunction, ttcFuncORLLBK) >= 0 { + e.recordLiteral("ROLLBACK") + return + } +} + +func findBytePair(data []byte, b1, b2 byte) int { + for i := 0; i+1 < len(data); i++ { + if data[i] == b1 && data[i+1] == b2 { + return i + } + } + return -1 +} + +func (e *QueryExtractor) recordLiteral(sql string) { + e.pair.mu.Lock() + e.pair.pending = &pendingQuery{sql: sql, timestamp: time.Now()} + e.pair.mu.Unlock() +} + +// tryExtractSQL uses a longest-printable-run heuristic because OALL8 headers +// vary across client drivers and bind patterns. +func tryExtractSQL(r *TTCReader) string { + remaining := r.Remaining() + if remaining <= 0 { + return "" + } + buf, err := r.GetBytes(remaining) + if err != nil { + return "" + } + return longestPrintableRun(buf) +} + +func longestPrintableRun(data []byte) string { + bestStart, bestLen := 0, 0 + curStart, curLen := 0, 0 + for i, b := range data { + printable := b == '\t' || b == '\n' || b == '\r' || (b >= 0x20 && b <= 0x7E) + if printable { + if curLen == 0 { + curStart = i + } + curLen++ + if curLen > bestLen { + bestLen = curLen + bestStart = curStart + } + } else { + curLen = 0 + } + } + if bestLen < 4 { + return "" + } + return string(data[bestStart : bestStart+bestLen]) +} + +func (e *QueryExtractor) handleServerResponse(payload []byte) { + e.pair.mu.Lock() + pending := e.pair.pending + e.pair.pending = nil + e.pair.mu.Unlock() + if pending == nil { + return + } + output := extractResponseOutcome(payload) + err := e.logger.LogEntry(session.SessionLogEntry{ + Timestamp: pending.timestamp, + Input: pending.sql, + Output: output, + }) + if err != nil { + log.Debug().Err(err).Str("sessionID", e.sessionID).Msg("session log entry dropped") + } +} + +func extractResponseOutcome(payload []byte) string { + r := NewTTCReader(payload) + for r.Remaining() > 0 { + op, err := r.GetByte() + if err != nil { + break + } + if op == 0x04 { + for i := 0; i < 3; i++ { + if _, err := r.GetInt(4, true, true); err != nil { + return "OK" + } + } + code, err := r.GetInt(4, true, true) + if err != nil || code == 0 { + return "OK" + } + return ora(code) + } + } + return "" +} + +func ora(code int) string { + switch code { + case 0: + return "OK" + case 1: + return "ERROR: ORA-00001: unique constraint violated" + case 900: + return "ERROR: ORA-00900: invalid SQL statement" + case 942: + return "ERROR: ORA-00942: table or view does not exist" + case 1017: + return "ERROR: ORA-01017: invalid username/password" + case 28000: + return "ERROR: ORA-28000: the account is locked" + } + return fmt.Sprintf("ERROR: ORA-%05d", code) +} diff --git a/packages/pam/handlers/oracle/tns.go b/packages/pam/handlers/oracle/tns.go new file mode 100644 index 00000000..93d01f54 --- /dev/null +++ b/packages/pam/handlers/oracle/tns.go @@ -0,0 +1,115 @@ +package oracle + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +type PacketType uint8 + +const ( + PacketTypeConnect PacketType = 1 + PacketTypeAccept PacketType = 2 + PacketTypeRefuse PacketType = 4 + PacketTypeRedirect PacketType = 5 + PacketTypeData PacketType = 6 + PacketTypeResend PacketType = 11 + PacketTypeMarker PacketType = 12 +) + +// use32BitLen: 32-bit length framing after ACCEPT (version >= 315), 16-bit before. +func ReadFullPacket(r io.Reader, use32BitLen bool) ([]byte, error) { + head := make([]byte, 8) + if _, err := io.ReadFull(r, head); err != nil { + return nil, err + } + var length uint32 + if use32BitLen { + length = binary.BigEndian.Uint32(head) + } else { + length = uint32(binary.BigEndian.Uint16(head)) + } + if length < 8 { + return nil, fmt.Errorf("invalid TNS packet length: %d", length) + } + if length > 1<<22 { + return nil, fmt.Errorf("TNS packet too large: %d", length) + } + buf := make([]byte, length) + copy(buf, head) + if length > 8 { + if _, err := io.ReadFull(r, buf[8:]); err != nil { + return nil, err + } + } + return buf, nil +} + +func PacketTypeOf(packet []byte) PacketType { + if len(packet) < 5 { + return 0 + } + return PacketType(packet[4]) +} + +type DataPacket struct { + DataFlag uint16 + Payload []byte +} + +func ParseDataPacket(raw []byte, use32BitLen bool) (*DataPacket, error) { + if len(raw) < 10 || PacketType(raw[4]) != PacketTypeData { + return nil, errors.New("not a DATA packet") + } + return &DataPacket{ + DataFlag: binary.BigEndian.Uint16(raw[8:]), + Payload: append([]byte(nil), raw[10:]...), + }, nil +} + +func (d *DataPacket) Bytes(use32BitLen bool) []byte { + length := uint32(10 + len(d.Payload)) + out := make([]byte, length) + if use32BitLen { + binary.BigEndian.PutUint32(out, length) + } else { + binary.BigEndian.PutUint16(out, uint16(length)) + } + out[4] = byte(PacketTypeData) + out[5] = 0 + binary.BigEndian.PutUint16(out[8:], d.DataFlag) + copy(out[10:], d.Payload) + return out +} + +type RefusePacket struct { + UserReason uint8 + SystemReason uint8 + Message string +} + +func (r *RefusePacket) Bytes() []byte { + msg := []byte(r.Message) + length := uint32(12 + len(msg)) + out := make([]byte, length) + binary.BigEndian.PutUint16(out, uint16(length)) + out[4] = byte(PacketTypeRefuse) + out[5] = 0 + out[8] = r.UserReason + out[9] = r.SystemReason + binary.BigEndian.PutUint16(out[10:], uint16(len(msg))) + copy(out[12:], msg) + return out +} + +func WriteRefuseToClient(w io.Writer, message string) error { + pkt := &RefusePacket{ + UserReason: 0, + SystemReason: 0, + Message: message, + } + _, err := w.Write(pkt.Bytes()) + return err +} diff --git a/packages/pam/handlers/oracle/ttc.go b/packages/pam/handlers/oracle/ttc.go new file mode 100644 index 00000000..30461e59 --- /dev/null +++ b/packages/pam/handlers/oracle/ttc.go @@ -0,0 +1,302 @@ +package oracle + +import ( + "bytes" + "encoding/binary" + "errors" + "io" +) + +type TTCBuilder struct { + buf bytes.Buffer + useBigClrChunks bool + clrChunkSize int +} + +func NewTTCBuilder() *TTCBuilder { + return &TTCBuilder{useBigClrChunks: true, clrChunkSize: 0x7FFF} +} + +func (b *TTCBuilder) Bytes() []byte { return b.buf.Bytes() } + +func (b *TTCBuilder) PutBytes(data ...byte) { b.buf.Write(data) } + +func (b *TTCBuilder) PutUint(num uint64, size uint8, bigEndian, compress bool) { + if size == 1 { + b.buf.WriteByte(uint8(num)) + return + } + if compress { + temp := make([]byte, 8) + binary.BigEndian.PutUint64(temp, num) + temp = bytes.TrimLeft(temp, "\x00") + if size > uint8(len(temp)) { + size = uint8(len(temp)) + } + if size == 0 { + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(size) + b.buf.Write(temp) + return + } + temp := make([]byte, size) + if bigEndian { + switch size { + case 2: + binary.BigEndian.PutUint16(temp, uint16(num)) + case 4: + binary.BigEndian.PutUint32(temp, uint32(num)) + case 8: + binary.BigEndian.PutUint64(temp, num) + } + } else { + switch size { + case 2: + binary.LittleEndian.PutUint16(temp, uint16(num)) + case 4: + binary.LittleEndian.PutUint32(temp, uint32(num)) + case 8: + binary.LittleEndian.PutUint64(temp, num) + } + } + b.buf.Write(temp) +} + +func (b *TTCBuilder) PutInt(num int64, size uint8, bigEndian, compress bool) { + if compress { + temp := make([]byte, 8) + binary.BigEndian.PutUint64(temp, uint64(num)) + temp = bytes.TrimLeft(temp, "\x00") + if size > uint8(len(temp)) { + size = uint8(len(temp)) + } + if size == 0 { + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(size) + b.buf.Write(temp[:size]) + return + } + b.PutUint(uint64(num), size, bigEndian, false) +} + +func (b *TTCBuilder) PutClr(data []byte) { + dataLen := len(data) + if dataLen == 0 { + b.buf.WriteByte(0) + return + } + if dataLen > 0xFC { + b.buf.WriteByte(0xFE) + start := 0 + for start < dataLen { + end := start + b.clrChunkSize + if end > dataLen { + end = dataLen + } + chunk := data[start:end] + if b.useBigClrChunks { + b.PutInt(int64(len(chunk)), 4, true, true) + } else { + b.buf.WriteByte(uint8(len(chunk))) + } + b.buf.Write(chunk) + start += b.clrChunkSize + } + b.buf.WriteByte(0) + return + } + b.buf.WriteByte(uint8(dataLen)) + b.buf.Write(data) +} + +func (b *TTCBuilder) PutString(s string) { b.PutClr([]byte(s)) } + +func (b *TTCBuilder) PutKeyVal(key, val []byte, num uint32) { + if len(key) == 0 { + b.buf.WriteByte(0) + } else { + b.PutUint(uint64(len(key)), 4, true, true) + b.PutClr(key) + } + if len(val) == 0 { + b.buf.WriteByte(0) + } else { + b.PutUint(uint64(len(val)), 4, true, true) + b.PutClr(val) + } + b.PutInt(int64(num), 4, true, true) +} + +func (b *TTCBuilder) PutKeyValString(key, val string, num uint32) { + b.PutKeyVal([]byte(key), []byte(val), num) +} + +type TTCReader struct { + buf []byte + pos int + useBigClrChunks bool +} + +func NewTTCReader(payload []byte) *TTCReader { + return &TTCReader{buf: payload, useBigClrChunks: true} +} + +func (r *TTCReader) Remaining() int { return len(r.buf) - r.pos } + +func (r *TTCReader) Pos() int { return r.pos } + +func (r *TTCReader) read(n int) ([]byte, error) { + if r.pos+n > len(r.buf) { + return nil, io.ErrUnexpectedEOF + } + out := r.buf[r.pos : r.pos+n] + r.pos += n + return out, nil +} + +func (r *TTCReader) GetByte() (uint8, error) { + b, err := r.read(1) + if err != nil { + return 0, err + } + return b[0], nil +} + +func (r *TTCReader) PeekByte() (uint8, error) { + if r.pos >= len(r.buf) { + return 0, io.ErrUnexpectedEOF + } + return r.buf[r.pos], nil +} + +func (r *TTCReader) GetBytes(n int) ([]byte, error) { + b, err := r.read(n) + if err != nil { + return nil, err + } + out := make([]byte, len(b)) + copy(out, b) + return out, nil +} + +func (r *TTCReader) GetInt64(size int, compress, bigEndian bool) (int64, error) { + negFlag := false + if compress { + sb, err := r.read(1) + if err != nil { + return 0, err + } + size = int(sb[0]) + if size&0x80 > 0 { + negFlag = true + size = size & 0x7F + } + bigEndian = true + } + if size == 0 { + return 0, nil + } + if size > 8 { + return 0, errors.New("invalid size for GetInt64") + } + rb, err := r.read(size) + if err != nil { + return 0, err + } + temp := make([]byte, 8) + var v int64 + if bigEndian { + copy(temp[8-size:], rb) + v = int64(binary.BigEndian.Uint64(temp)) + } else { + copy(temp[:size], rb) + v = int64(binary.LittleEndian.Uint64(temp)) + } + if negFlag { + v = -v + } + return v, nil +} + +func (r *TTCReader) GetInt(size int, compress, bigEndian bool) (int, error) { + v, err := r.GetInt64(size, compress, bigEndian) + return int(v), err +} + +func (r *TTCReader) GetClr() ([]byte, error) { + nb, err := r.GetByte() + if err != nil { + return nil, err + } + if nb == 0 || nb == 0xFF || nb == 0xFD { + return nil, nil + } + if nb != 0xFE { + out, err := r.read(int(nb)) + if err != nil { + return nil, err + } + ret := make([]byte, len(out)) + copy(ret, out) + return ret, nil + } + var buf bytes.Buffer + for { + var chunkSize int + if r.useBigClrChunks { + chunkSize, err = r.GetInt(4, true, true) + } else { + b, err2 := r.GetByte() + err = err2 + chunkSize = int(b) + } + if err != nil { + return nil, err + } + if chunkSize == 0 { + break + } + chunk, err := r.read(chunkSize) + if err != nil { + return nil, err + } + buf.Write(chunk) + } + return buf.Bytes(), nil +} + +func (r *TTCReader) GetDlc() ([]byte, error) { + length, err := r.GetInt(4, true, true) + if err != nil { + return nil, err + } + if length <= 0 { + _, _ = r.GetClr() + return nil, nil + } + out, err := r.GetClr() + if err != nil { + return nil, err + } + if len(out) > length { + out = out[:length] + } + return out, nil +} + +func (r *TTCReader) GetKeyVal() (key, val []byte, num int, err error) { + key, err = r.GetDlc() + if err != nil { + return + } + val, err = r.GetDlc() + if err != nil { + return + } + num, err = r.GetInt(4, true, true) + return +} diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go index 0d2657af..8d9a6195 100644 --- a/packages/pam/handlers/ssh/proxy.go +++ b/packages/pam/handlers/ssh/proxy.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Infisical/infisical-merge/packages/pam/session" + "github.com/Infisical/infisical-merge/packages/util" "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) @@ -26,6 +27,7 @@ type SSHProxyConfig struct { SessionID string SessionLogger session.SessionLogger BlockedCommandPatterns []*regexp.Regexp // Regex patterns for command blocking (nil = no blocking) + OnActivity func() // Called when channel data flows } // SSHProxy handles proxying SSH connections with credential injection @@ -123,6 +125,29 @@ func (p *SSHProxy) HandleConnection(ctx context.Context, clientConn net.Conn) er // Discard global requests (not needed for basic remote access) go ssh.DiscardRequests(clientRequests) + // SSH keepalive: detect dead connections where TCP goes silent. Probes both sides every 30s + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := util.SSHKeepalive(clientSSHConn, 15*time.Second); err != nil { + log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to client failed, tearing down connection") + clientConn.Close() + return + } + if err := util.SSHKeepalive(serverSSHConn, 15*time.Second); err != nil { + log.Info().Err(err).Str("sessionID", sessionID).Msg("SSH keepalive to target failed, tearing down connection") + clientConn.Close() + return + } + case <-ctx.Done(): + return + } + } + }() + // Handle channels from client (this is where actual SSH sessions happen) for newChannel := range clientChannels { go p.handleChannel(ctx, newChannel, serverSSHConn, sessionID) @@ -500,6 +525,10 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses for { n, err := src.Read(buf) if n > 0 { + if p.config.OnActivity != nil { + p.config.OnActivity() + } + // Check if this channel is a binary session (SFTP/SCP) chState.mutex.Lock() isBinary := chState.isBinarySession @@ -744,6 +773,10 @@ func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer, for { n, err := src.Read(buf) if n > 0 { + if p.config.OnActivity != nil { + p.config.OnActivity() + } + chState.mutex.Lock() isBinary := chState.isBinarySession sftpParser := chState.sftpParser diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index 9f51f4e3..c418795c 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -10,6 +10,7 @@ import ( "syscall" "time" + "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" "github.com/Infisical/infisical-merge/packages/pam/session" "github.com/Infisical/infisical-merge/packages/util" "github.com/go-resty/resty/v2" @@ -125,6 +126,10 @@ func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, p util.PrintfStderr("sqlserver://%s@localhost:%d?database=%s&encrypt=false&trustServerCertificate=true", username, proxy.port, database) case session.ResourceTypeMongodb: util.PrintfStderr("mongodb://localhost:%d/%s?serverSelectionTimeoutMS=15000", proxy.port, database) + case session.ResourceTypeOracledb: + util.PrintfStderr("%s/%s@localhost:%d/%s", username, oracle.ProxyPasswordPlaceholder, proxy.port, database) + util.PrintfStderr("\njdbc:oracle:thin:@localhost:%d/%s (user: %s, password: %s)", proxy.port, database, username, oracle.ProxyPasswordPlaceholder) + util.PrintfStderr("\n\nNote: the password shown is a protocol placeholder required by Oracle, not a secret.") default: util.PrintfStderr("localhost:%d", proxy.port) } diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index 8d374539..3dd81882 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -18,6 +18,7 @@ import ( "github.com/Infisical/infisical-merge/packages/pam/handlers/mongodb" "github.com/Infisical/infisical-merge/packages/pam/handlers/mssql" "github.com/Infisical/infisical-merge/packages/pam/handlers/mysql" + "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" "github.com/Infisical/infisical-merge/packages/pam/handlers/rdp" "github.com/Infisical/infisical-merge/packages/pam/handlers/redis" "github.com/Infisical/infisical-merge/packages/pam/handlers/ssh" @@ -38,6 +39,7 @@ type GatewayPAMConfig struct { CredentialsManager *session.CredentialsManager SessionUploader *session.SessionUploader GetMongoProxy MongoProxyGetter // Session-level MongoDB proxy sharing + OnActivity func() // Called on data flow } type PAMCapabilitiesResponse struct { @@ -54,6 +56,7 @@ func GetSupportedResourceTypes() []string { session.ResourceTypeKubernetes, session.ResourceTypeRedis, session.ResourceTypeMongodb, + session.ResourceTypeOracledb, } // Only advertise RDP when the real bridge is compiled in. A stub // build would otherwise accept RDP session routing and fail every @@ -104,19 +107,40 @@ func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *Gatew // Kill the active proxy connection if it exists in the registry if cancelled := cancelSession(pamConfig.SessionId); cancelled { log.Info().Str("sessionId", pamConfig.SessionId).Msg("Active proxy session cancelled via registry") + if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { + log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") + } } else { log.Info().Str("sessionId", pamConfig.SessionId).Msg("No active proxy session found in registry (may have already ended)") } - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") - } - conn.Close() return nil } +// activityConn wraps a net.Conn and calls onActivity on every successful read or write +type activityConn struct { + net.Conn + onActivity func() +} + +func (c *activityConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + if n > 0 { + c.onActivity() + } + return n, err +} + +func (c *activityConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + if n > 0 { + c.onActivity() + } + return n, err +} + // compilePolicyPatterns compiles regex pattern strings, logging warnings for any that fail. func compilePolicyPatterns(config *api.PAMPolicyRuleConfig, sessionID string, ruleType string) []*regexp.Regexp { if config == nil || len(config.Patterns) == 0 { @@ -161,10 +185,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). Msg("PAM session expired, closing connection") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "expiry"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session on expiry") - } - conn.Close() case <-ctx.Done(): // Context cancelled, exit gracefully @@ -177,10 +197,6 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Time("expiryTime", pamConfig.ExpiryTime). Msg("PAM session already expired, closing connection immediately") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "already_expired"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup already expired PAM session") - } - conn.Close() } }() @@ -242,6 +258,12 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo } } + // Wrap the connection so every read/write resets the idle reaper timer + var handlerConn net.Conn = conn + if pamConfig.OnActivity != nil { + handlerConn = &activityConn{Conn: conn, onActivity: pamConfig.OnActivity} + } + switch pamConfig.ResourceType { case session.ResourceTypePostgres: proxyConfig := handlers.PostgresProxyConfig{ @@ -260,7 +282,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", proxyConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting PostgreSQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMysql: mysqlConfig := mysql.MysqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -279,7 +301,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", mysqlConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MySQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMssql: mssqlConfig := mssql.MssqlProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -298,7 +320,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", mssqlConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting MSSQL PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeRedis: redisConfig := redis.RedisProxyConfig{ TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), @@ -316,7 +338,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", redisConfig.TargetAddr). Bool("sslEnabled", credentials.SSLEnabled). Msg("Starting Redis PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeSSH: // Compile command blocking patterns from policy rules var blockedCommandPatterns []*regexp.Regexp @@ -334,6 +356,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo SessionID: pamConfig.SessionId, SessionLogger: sessionLogger, BlockedCommandPatterns: blockedCommandPatterns, + OnActivity: pamConfig.OnActivity, } proxy := ssh.NewSSHProxy(sshConfig) log.Info(). @@ -387,7 +410,25 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("target", kubernetesConfig.TargetApiServer). Str("authMethod", credentials.AuthMethod). Msg("Starting Kubernetes PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) + case session.ResourceTypeOracledb: + oracleConfig := oracle.OracleProxyConfig{ + TargetAddr: fmt.Sprintf("%s:%d", credentials.Host, credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectDatabase: credentials.Database, + EnableTLS: credentials.SSLEnabled, + TLSConfig: tlsConfig, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + } + proxy := oracle.NewOracleProxy(oracleConfig) + log.Info(). + Str("sessionId", pamConfig.SessionId). + Str("target", oracleConfig.TargetAddr). + Bool("sslEnabled", credentials.SSLEnabled). + Msg("Starting Oracle PAM proxy") + return proxy.HandleConnection(ctx, handlerConn) case session.ResourceTypeMongodb: mongoConfig := mongodb.MongoDBProxyConfig{ Host: credentials.ConnectionString, @@ -412,7 +453,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo return fmt.Errorf("MongoDB proxy init: %w", err) } - return proxy.HandleConnection(ctx, conn, sessionLogger) + return proxy.HandleConnection(ctx, handlerConn, sessionLogger) case session.ResourceTypeWindows: if credentials.Port <= 0 || credentials.Port > 65535 { return fmt.Errorf("rdp: target port %d out of range", credentials.Port) @@ -432,7 +473,7 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo Str("sessionId", pamConfig.SessionId). Str("target", fmt.Sprintf("%s:%d", credentials.Host, credentials.Port)). Msg("Starting RDP PAM proxy") - return proxy.HandleConnection(ctx, conn) + return proxy.HandleConnection(ctx, handlerConn) default: return fmt.Errorf("unsupported resource type: %s", pamConfig.ResourceType) } diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index cafb196c..41ce36d4 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -31,7 +31,8 @@ const ( ResourceTypeSSH = "ssh" ResourceTypeKubernetes = "kubernetes" ResourceTypeMongodb = "mongodb" - ResourceTypeWindows = "windows" + ResourceTypeOracledb = "oracledb" + ResourceTypeWindows = "windows" ) type SessionFileInfo struct { @@ -75,7 +76,7 @@ func NewSessionUploader(httpClient *resty.Client, credentialsManager *Credential func ParseSessionFilename(filename string) (*SessionFileInfo, error) { // Try new format first: pam_session_{sessionID}_{resourceType}_expires_{timestamp}.enc // Build regex pattern using constants - resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeMssql, ResourceTypeKubernetes, ResourceTypeMongodb, ResourceTypeWindows) + resourceTypePattern := fmt.Sprintf("(%s|%s|%s|%s|%s|%s|%s|%s|%s)", ResourceTypeSSH, ResourceTypePostgres, ResourceTypeRedis, ResourceTypeMysql, ResourceTypeMssql, ResourceTypeKubernetes, ResourceTypeMongodb, ResourceTypeOracledb, ResourceTypeWindows) newFormatRegex := regexp.MustCompile(fmt.Sprintf(`^pam_session_(.+)_%s_expires_(\d+)\.enc$`, resourceTypePattern)) matches := newFormatRegex.FindStringSubmatch(filename) diff --git a/packages/util/ssh.go b/packages/util/ssh.go new file mode 100644 index 00000000..d207a75d --- /dev/null +++ b/packages/util/ssh.go @@ -0,0 +1,23 @@ +package util + +import ( + "fmt" + "time" + + "golang.org/x/crypto/ssh" +) + +// SSHKeepalive sends an SSH keepalive request and waits up to timeout for a response +func SSHKeepalive(conn ssh.Conn, timeout time.Duration) error { + errCh := make(chan error, 1) + go func() { + _, _, err := conn.SendRequest("keepalive@openssh.com", true, nil) + errCh <- err + }() + select { + case err := <-errCh: + return err + case <-time.After(timeout): + return fmt.Errorf("no keepalive response within %v", timeout) + } +}