diff --git a/packages/api/model.go b/packages/api/model.go index 5dd3fe95..c1b5fea2 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -958,8 +958,7 @@ type UploadSessionLogEntry struct { Output string `json:"output"` } -// UploadTerminalEvent represents a terminal session event for upload -type UploadTerminalEvent struct { +type UploadSessionEvent struct { Timestamp time.Time `json:"timestamp"` EventType string `json:"eventType"` ChannelType string `json:"channelType,omitempty"` @@ -979,7 +978,7 @@ type UploadHttpEvent struct { } type UploadPAMSessionLogsRequest struct { - Logs interface{} `json:"logs"` // Can be []UploadSessionLogEntry or []UploadTerminalEvent + Logs interface{} `json:"logs"` // Can be []UploadSessionLogEntry or []UploadSessionEvent } type RelayHeartbeatRequest struct { diff --git a/packages/gateway-v2/gateway.go b/packages/gateway-v2/gateway.go index b8f4e127..7220cd9b 100644 --- a/packages/gateway-v2/gateway.go +++ b/packages/gateway-v2/gateway.go @@ -89,6 +89,7 @@ type GatewayConfig struct { type pamSessionEntry struct { cancel context.CancelFunc conn *tls.Conn + done chan struct{} // closed when HandlePAMProxy has fully returned for this entry lastActivity atomic.Int64 } @@ -170,7 +171,7 @@ func NewGateway(config *GatewayConfig) (*Gateway, error) { // RegisterPAMSession registers an active PAM proxy connection for cancellation support // Returns a function that handlers should call when data flows through the connection func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc, conn *tls.Conn) func() { - entry := &pamSessionEntry{cancel: cancel, conn: conn} + entry := &pamSessionEntry{cancel: cancel, conn: conn, done: make(chan struct{})} entry.lastActivity.Store(time.Now().Unix()) g.pamSessionsMu.Lock() @@ -189,23 +190,55 @@ func (g *Gateway) RegisterPAMSession(sessionID string, cancel context.CancelFunc // The proxy is cleaned up on session cancellation or gateway shutdown. 
func (g *Gateway) DeregisterPAMSession(sessionID string, conn *tls.Conn) bool { g.pamSessionsMu.Lock() - defer g.pamSessionsMu.Unlock() - entries, exists := g.pamSessions[sessionID] if !exists { + g.pamSessionsMu.Unlock() return false } + var removed *pamSessionEntry for i, e := range entries { if e.conn == conn { + removed = e g.pamSessions[sessionID] = append(entries[:i], entries[i+1:]...) break } } - if len(g.pamSessions[sessionID]) == 0 { + isLast := len(g.pamSessions[sessionID]) == 0 + if isLast { delete(g.pamSessions, sessionID) - return true } - return false + g.pamSessionsMu.Unlock() + if removed != nil { + close(removed.done) + } + return isLast +} + +// evictExistingPAMSessions cancels the session's prior entries and waits up to timeout for each to clean up. It works on a copy taken under the lock because DeregisterPAMSession compacts the live slice in place. +// RDP needs serial bridges so drain writes don't interleave into the recording file. +func (g *Gateway) evictExistingPAMSessions(sessionID string, timeout time.Duration) { + g.pamSessionsMu.Lock() + prior := append([]*pamSessionEntry(nil), g.pamSessions[sessionID]...) + g.pamSessionsMu.Unlock() + if len(prior) == 0 { + return + } + log.Info().Str("sessionId", sessionID).Int("priorCount", len(prior)). + Msg("Evicting existing PAM connections before starting new RDP bridge") + for _, e := range prior { + _ = e.conn.Close() + e.cancel() + } + deadline := time.After(timeout) + for _, e := range prior { + select { + case <-e.done: + case <-deadline: + log.Warn().Str("sessionId", sessionID). 
+ Msg("Timed out waiting for prior PAM connection to clean up; proceeding anyway") + return + } + } } // CancelPAMSession kills all active connections for a PAM session @@ -222,6 +255,7 @@ func (g *Gateway) CancelPAMSession(sessionID string) bool { for _, e := range entries { e.conn.Close() e.cancel() + close(e.done) } g.closeMongoProxy(sessionID) return true @@ -833,6 +867,11 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { } return } else if forwardConfig.Mode == ForwardModePAM { + // RDP only: prior bridge must fully tear down before the new one starts, + // else overlapping drains write non-monotonic elapsedMs to the recording. + if forwardConfig.PAMConfig.ResourceType == session.ResourceTypeWindows { + g.evictExistingPAMSessions(forwardConfig.PAMConfig.SessionId, 5*time.Second) + } sessionCtx, sessionCancel := context.WithCancel(g.ctx) touchSession := g.RegisterPAMSession(forwardConfig.PAMConfig.SessionId, sessionCancel, tlsConn) forwardConfig.PAMConfig.OnActivity = touchSession @@ -844,7 +883,11 @@ func (g *Gateway) handleIncomingChannel(newChannel ssh.NewChannel) { } } sessionCancel() - if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn { + // RDP reconnects via a stable .rdp file within the session's validity + // window; terminating on disconnect would break that. Idle reaper / + // expiry / explicit cancel still end the session normally. 
+ isRDP := forwardConfig.PAMConfig.ResourceType == session.ResourceTypeWindows + if lastConn := g.DeregisterPAMSession(forwardConfig.PAMConfig.SessionId, tlsConn); lastConn && !isRDP { if err := forwardConfig.PAMConfig.SessionUploader.CleanupPAMSession( forwardConfig.PAMConfig.SessionId, "connection_closed", ); err != nil { diff --git a/packages/pam/handlers/rdp/bridge.go b/packages/pam/handlers/rdp/bridge.go index f582c864..17970ccc 100644 --- a/packages/pam/handlers/rdp/bridge.go +++ b/packages/pam/handlers/rdp/bridge.go @@ -15,3 +15,44 @@ type Bridge struct { handle uint64 cleanup func() } + +// EventType discriminates the variants in Event. +type EventType uint8 + +const ( + EventTypeKeyboard EventType = 1 + EventTypeUnicode EventType = 2 + EventTypeMouse EventType = 3 + EventTypeTargetFrame EventType = 4 +) + +// Action identifies the RDP framing of a TargetFrame event. +type Action uint8 + +const ( + ActionX224 Action = 0 + ActionFastPath Action = 1 +) + +// Fields are reused across variants; switch on Type. +type Event struct { + Type EventType + ElapsedNs uint64 + Scancode uint8 + CodePoint uint16 + X uint16 + Y uint16 + Flags uint32 + WheelDelta int32 + Action Action + Payload []byte +} + +// PollResult discriminates PollEvent outcomes. 
+type PollResult uint8 + +const ( + PollOK PollResult = 0 + PollTimeout PollResult = 1 + PollEnded PollResult = 2 +) diff --git a/packages/pam/handlers/rdp/bridge_cgo_shared.go b/packages/pam/handlers/rdp/bridge_cgo_shared.go index 9c8d8fba..e6816821 100644 --- a/packages/pam/handlers/rdp/bridge_cgo_shared.go +++ b/packages/pam/handlers/rdp/bridge_cgo_shared.go @@ -5,6 +5,7 @@ package rdp /* #cgo CFLAGS: -I${SRCDIR}/native/include +#include #include "rdp_bridge.h" */ import "C" @@ -14,6 +15,8 @@ import ( "errors" "fmt" "net" + "time" + "unsafe" ) func (p *RDPProxy) HandleConnection(ctx context.Context, clientConn net.Conn) error { @@ -37,16 +40,40 @@ func (p *RDPProxy) HandleConnection(ctx context.Context, clientConn net.Conn) er } defer bridge.Close() + // Drain bridge tap events into the session logger. The Rust side closes + // the events channel when the session ends, so the goroutine exits via + // PollEnded without needing an explicit shutdown signal. + drainCtx, cancelDrain := context.WithCancel(ctx) + drainDone := make(chan struct{}) + go func() { + defer close(drainDone) + drainBridgeEvents(drainCtx, bridge, p.config.SessionLogger, p.config.SessionID, p.config.PriorElapsedNs, p.config.SessionUploader) + }() + // Wait for the drain to finish naturally on the normal-end path so the + // tail of the recording isn't dropped: PollEnded fires after the Rust + // side closes the events channel (post bridge.Wait return). Cancellation + // paths trigger cancelDrain() explicitly below to bail early. + defer func() { + select { + case <-drainDone: + case <-time.After(2 * pollTimeout): + } + // Always release the drain context (no-op if already cancelled). 
+ cancelDrain() + }() + waitErr := make(chan error, 1) go func() { waitErr <- bridge.Wait() }() select { case err := <-waitErr: if err != nil && !errors.Is(err, ErrInvalidHandle) { + cancelDrain() return fmt.Errorf("rdp proxy: session: %w", err) } return nil case <-ctx.Done(): + cancelDrain() _ = bridge.Cancel() <-waitErr return ctx.Err() @@ -90,8 +117,62 @@ func (b *Bridge) Close() error { return nil } -// IsSupported reports whether this build has a real RDP bridge. Used -// by the gateway to decide whether to advertise RDP in the capabilities -// response: a stub-build gateway that advertises support would route -// RDP sessions only to fail them at connect time. +// True when the real bridge is compiled in (vs the stub). func IsSupported() bool { return true } + +// PollEvent drains one tap event with the given timeout. The returned Event +// is only meaningful when result == PollOK. PollEvent is not safe to call +// concurrently for the same Bridge; serialize calls in a single goroutine. 
+func (b *Bridge) PollEvent(timeout time.Duration) (PollResult, Event, error) { + timeoutMs := timeout.Milliseconds() + if timeoutMs < 0 { + timeoutMs = 0 + } + if timeoutMs > int64(^C.uint32_t(0)) { + timeoutMs = int64(^C.uint32_t(0)) + } + + var raw C.struct_RdpEvent + rc := C.rdp_bridge_poll_event(C.uint64_t(b.handle), &raw, C.uint32_t(timeoutMs)) + + switch rc { + case C.RDP_POLL_OK: + // fall through to event materialization below + case C.RDP_POLL_TIMEOUT: + return PollTimeout, Event{}, nil + case C.RDP_POLL_ENDED: + return PollEnded, Event{}, nil + case C.RDP_POLL_INVALID_HANDLE: + return PollEnded, Event{}, ErrInvalidHandle + default: + return PollEnded, Event{}, fmt.Errorf("rdp bridge: poll returned unexpected status %d", int32(rc)) + } + + ev := Event{ + Type: EventType(uint8(raw.event_type)), + ElapsedNs: uint64(raw.elapsed_ns), + Flags: uint32(raw.flags), + WheelDelta: int32(raw.wheel_delta), + Action: Action(uint8(raw.action)), + } + switch ev.Type { + case EventTypeKeyboard: + ev.Scancode = uint8(raw.value_a) + case EventTypeUnicode: + ev.CodePoint = uint16(raw.value_a) + case EventTypeMouse: + ev.X = uint16(raw.value_a) + ev.Y = uint16(raw.value_b) + case EventTypeTargetFrame: + // Always free the libc-malloc'd buffer Rust handed us, even if + // the copy below is empty -- ownership transfer is unconditional. 
+ if raw.payload_ptr != nil { + defer C.free(unsafe.Pointer(raw.payload_ptr)) + if raw.payload_len > 0 { + ev.Payload = C.GoBytes(unsafe.Pointer(raw.payload_ptr), C.int(raw.payload_len)) + } + } + } + + return PollOK, ev, nil +} diff --git a/packages/pam/handlers/rdp/bridge_cgo_unix.go b/packages/pam/handlers/rdp/bridge_cgo_unix.go index 37e7d2ee..f5d3f454 100644 --- a/packages/pam/handlers/rdp/bridge_cgo_unix.go +++ b/packages/pam/handlers/rdp/bridge_cgo_unix.go @@ -72,19 +72,8 @@ func startWithDupedFD(dupFd int, targetHost string, targetPort uint16, username, return &Bridge{handle: uint64(handle)}, nil } -// StartWithReadWriter adapts an fd-less Go byte stream (e.g. *tls.Conn -// from the gateway's mTLS-wrapped virtual connection) to the bridge, -// which needs a real file descriptor because the Rust side uses tokio's -// TcpStream::from_raw_fd and does direct async I/O on the socket. -// -// Trick: open a loopback TCP pair. Hand one end's fd to the bridge (it -// thinks it has a real client). Keep the other end in Go and shuttle -// bytes between it and rw with two io.Copy goroutines. -// -// rw (e.g. *tls.Conn) <-io.Copy-> peer <-kernel loopback-> accepted (fd -> Rust bridge) -// -// Cost: two extra in-process copies and a loopback round-trip per byte. -// Negligible vs. the TLS + CredSSP work on either side. +// Adapts an fd-less Go byte stream to the Rust bridge (which needs a real fd +// for tokio's TcpStream::from_raw_fd) by routing through a loopback TCP pair. 
func StartWithReadWriter(rw io.ReadWriter, targetHost string, targetPort uint16, username, password, domain string) (*Bridge, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/packages/pam/handlers/rdp/bridge_stub.go b/packages/pam/handlers/rdp/bridge_stub.go index 2c488000..28da7815 100644 --- a/packages/pam/handlers/rdp/bridge_stub.go +++ b/packages/pam/handlers/rdp/bridge_stub.go @@ -6,6 +6,7 @@ import ( "context" "io" "net" + "time" ) // Stub implementations for builds without `-tags rdp` or on platforms @@ -29,6 +30,10 @@ func (b *Bridge) Wait() error { return ErrRdpUnavailable } func (b *Bridge) Cancel() error { return ErrRdpUnavailable } func (b *Bridge) Close() error { return ErrRdpUnavailable } +func (b *Bridge) PollEvent(_ time.Duration) (PollResult, Event, error) { + return PollEnded, Event{}, ErrRdpUnavailable +} + // IsSupported reports whether this build has a real RDP bridge. See the // rdp-enabled counterpart in bridge_cgo_shared.go for details. 
func IsSupported() bool { return false } diff --git a/packages/pam/handlers/rdp/native/Cargo.lock b/packages/pam/handlers/rdp/native/Cargo.lock index 5c04a3e5..c4652505 100644 --- a/packages/pam/handlers/rdp/native/Cargo.lock +++ b/packages/pam/handlers/rdp/native/Cargo.lock @@ -1309,9 +1309,11 @@ dependencies = [ "bytes", "ironrdp-acceptor", "ironrdp-connector", + "ironrdp-core", "ironrdp-pdu", "ironrdp-tls", "ironrdp-tokio", + "libc", "libz-sys", "rcgen", "rustls", diff --git a/packages/pam/handlers/rdp/native/Cargo.toml b/packages/pam/handlers/rdp/native/Cargo.toml index 500a2117..cb53a5d2 100644 --- a/packages/pam/handlers/rdp/native/Cargo.toml +++ b/packages/pam/handlers/rdp/native/Cargo.toml @@ -13,10 +13,12 @@ path = "src/lib.rs" [dependencies] ironrdp-acceptor = "0.8" ironrdp-connector = "0.8" +ironrdp-core = "0.1" ironrdp-tokio = { version = "0.8", features = ["reqwest"] } ironrdp-pdu = "0.7" ironrdp-tls = { version = "0.2", features = ["rustls"] } x509-cert = { version = "0.2", features = ["std"] } +libc = "0.2" tokio = { version = "1", features = ["full"] } tokio-util = "0.7" diff --git a/packages/pam/handlers/rdp/native/include/rdp_bridge.h b/packages/pam/handlers/rdp/native/include/rdp_bridge.h index 65200f5f..150c6b6f 100644 --- a/packages/pam/handlers/rdp/native/include/rdp_bridge.h +++ b/packages/pam/handlers/rdp/native/include/rdp_bridge.h @@ -1,8 +1,5 @@ -/* - * infisical-rdp-bridge C ABI. See ffi.rs for details. Lifecycle: - * start_* -> wait -> free; cancel may be called from any thread. - * start_* transfers ownership of the client fd/socket to the bridge. - */ +/* C ABI; see ffi.rs. Lifecycle: start_* -> wait -> free. start_* takes + * ownership of the client fd/socket. cancel is thread-safe. 
*/ #ifndef INFISICAL_RDP_BRIDGE_H #define INFISICAL_RDP_BRIDGE_H @@ -51,6 +48,35 @@ int32_t rdp_bridge_wait(uint64_t handle); int32_t rdp_bridge_cancel(uint64_t handle); int32_t rdp_bridge_free(uint64_t handle); +/* Poll return codes (distinct number space from the bridge status codes + * above; consumed by rdp_bridge_poll_event only). */ +#define RDP_POLL_OK 0 +#define RDP_POLL_TIMEOUT 1 +#define RDP_POLL_ENDED 2 +#define RDP_POLL_INVALID_HANDLE -1 + +/* Event type discriminator. */ +#define RDP_EVENT_KEYBOARD 1 +#define RDP_EVENT_UNICODE 2 +#define RDP_EVENT_MOUSE 3 +#define RDP_EVENT_TARGET_FRAME 4 + +/* Fields reused across variants; check event_type. For TargetFrame, + * payload_ptr is libc-malloc'd and the Go caller must C.free it. */ +struct RdpEvent { + uint8_t event_type; + uint64_t elapsed_ns; + uint32_t value_a; + uint32_t value_b; + uint32_t flags; + int32_t wheel_delta; + uint8_t action; + uint8_t *payload_ptr; + uint32_t payload_len; +}; + +int32_t rdp_bridge_poll_event(uint64_t handle, struct RdpEvent *out, uint32_t timeout_ms); + #ifdef __cplusplus } #endif diff --git a/packages/pam/handlers/rdp/native/src/bridge.rs b/packages/pam/handlers/rdp/native/src/bridge.rs index 1eac7b1e..d7ccab1a 100644 --- a/packages/pam/handlers/rdp/native/src/bridge.rs +++ b/packages/pam/handlers/rdp/native/src/bridge.rs @@ -1,30 +1,46 @@ -//! MITM bridge. Runs acceptor + connector only through CredSSP (to inject -//! credentials), then byte-forwards between the two TLS streams. Letting -//! client and target negotiate MCS/capabilities/share-state directly -//! avoids drift that breaks strict clients (Windows App, mstsc). +//! MITM bridge. Runs acceptor + connector through CredSSP only, then byte- +//! forwards. Letting client/target negotiate MCS directly avoids drift +//! that breaks strict clients (Windows App, mstsc). 
+use std::borrow::Cow; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; +use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use ironrdp_acceptor::{Acceptor, BeginResult}; use ironrdp_connector::credssp::{CredsspSequence, KerberosConfig}; use ironrdp_connector::sspi::credssp::ClientState; use ironrdp_connector::sspi::generator::GeneratorState; -use ironrdp_connector::{encode_x224_packet, ClientConnector, ClientConnectorState}; +use ironrdp_connector::{encode_x224_packet, ClientConnector, ClientConnectorState, Credentials}; +use ironrdp_core::ReadCursor; use ironrdp_pdu::gcc::ConferenceCreateRequest; -use ironrdp_pdu::ironrdp_core::{decode, WriteBuf}; -use ironrdp_pdu::mcs::ConnectInitial; -use ironrdp_pdu::nego::SecurityProtocol; +use ironrdp_pdu::input::fast_path::{FastPathInput, FastPathInputEvent}; +use ironrdp_pdu::ironrdp_core::{decode, encode_buf, DecodeOwned as _, WriteBuf}; +use ironrdp_pdu::mcs::{ConnectInitial, SendDataRequest}; +use ironrdp_pdu::nego::{ + ConnectionConfirm, ConnectionRequest, NegoRequestData, RequestFlags, SecurityProtocol, +}; use ironrdp_pdu::rdp::client_info::Credentials as AcceptorCredentials; +use ironrdp_pdu::rdp::headers::{ShareControlHeader, ShareControlPdu}; use ironrdp_pdu::x224::{X224Data, X224}; +use ironrdp_pdu::Action; use ironrdp_tokio::reqwest::ReqwestNetworkClient; use ironrdp_tokio::{FramedWrite, NetworkClient}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; +use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; -use tracing::info; +use tracing::{debug, info, warn}; +use crate::cap_filter; use crate::config::{connector_config, DEFAULT_HEIGHT, DEFAULT_WIDTH}; +use crate::events::{elapsed_ns_since, EventSender, SessionEvent}; + +/// Cap on c2t PDUs to inspect before giving up on the cap filter. +const CONFIRM_ACTIVE_SCAN_MAX_PDUS: usize = 32; +/// Wall-clock cap on the cap-filter scan window. 
+const CONFIRM_ACTIVE_SCAN_MAX_DURATION: Duration = Duration::from_secs(5); // The acceptor side of the bridge expects the user to type the target // username with an empty password. The real password is injected by the @@ -44,9 +60,10 @@ pub async fn run_mitm( client_tcp: TcpStream, target: TargetEndpoint, cancel: CancellationToken, + tx: EventSender, ) -> Result<()> { tokio::select! { - result = run_mitm_inner(client_tcp, target) => result, + result = run_mitm_inner(client_tcp, target, tx) => result, _ = cancel.cancelled() => { info!("session canceled by caller"); Ok(()) @@ -54,7 +71,11 @@ pub async fn run_mitm( } } -async fn run_mitm_inner(client_tcp: TcpStream, target: TargetEndpoint) -> Result<()> { +async fn run_mitm_inner( + client_tcp: TcpStream, + target: TargetEndpoint, + tx: EventSender, +) -> Result<()> { // Our tree pulls both ring (direct) and aws-lc-rs (via reqwest); rustls // 0.23 needs an explicit provider when more than one is compiled in. let _ = rustls::crypto::ring::default_provider().install_default(); @@ -62,16 +83,12 @@ async fn run_mitm_inner(client_tcp: TcpStream, target: TargetEndpoint) -> Result let acceptor_username = target.username.clone(); let (acceptor_output, connector_output) = tokio::try_join!( run_acceptor_half(client_tcp, acceptor_username), - run_connector_half(target) + run_connector_half(target), )?; let (mut client_stream, client_leftover) = acceptor_output; let (mut target_stream, target_leftover) = connector_output; - // Strip virtual channels (clipboard, drives, audio, USB, etc.) from the - // client's MCS Connect Initial before forwarding. Mouse/keyboard/screen - // ride the implicit MCS I/O channel, not virtual channels, so they're - // unaffected. 
filter_client_mcs_connect_initial(&mut client_stream, &mut target_stream, client_leftover) .await .context("filter client MCS Connect Initial")?; @@ -94,24 +111,410 @@ async fn run_mitm_inner(client_tcp: TcpStream, target: TargetEndpoint) -> Result .await .context("flush target stream before passthrough")?; - // Real RDP clients hard-close TCP without TLS close_notify, which - // rustls surfaces as UnexpectedEof. Treat that as clean shutdown. - match tokio::io::copy_bidirectional(&mut client_stream, &mut target_stream).await { - Ok(_) => info!("session ended cleanly"), - Err(e) if is_unexpected_eof(&e) => info!("session ended (peer hard-closed)"), - Err(e) => return Err(e).context("passthrough copy_bidirectional"), + // PDU-framed bridge with an event tap. read_pdu is pure TPKT/FastPath + // framing (no state machine) so this preserves the "no MCS drift" + // property of the byte-level copy_bidirectional it replaces. + let client_framed = ironrdp_tokio::TokioFramed::new(client_stream); + let target_framed = ironrdp_tokio::TokioFramed::new(target_stream); + bridge_pdus(client_framed, target_framed, tx).await +} + +async fn bridge_pdus( + client_framed: ironrdp_tokio::TokioFramed, + target_framed: ironrdp_tokio::TokioFramed, + tx: EventSender, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, +{ + let (mut client_read, mut client_write) = ironrdp_tokio::split_tokio_framed(client_framed); + let (mut target_read, mut target_write) = ironrdp_tokio::split_tokio_framed(target_framed); + + // Recording starts when the first FastPath frame arrives from the target + // (actual bitmap/surface data). Pre-visual negotiation PDUs are forwarded + // but not recorded, eliminating the black-screen preamble. + // The offset stores the nanosecond mark when recording activates so all + // recorded timestamps are relative to the first visual frame. 
+ const NOT_ACTIVE: u64 = u64::MAX; + let recording_offset_ns = Arc::new(AtomicU64::new(NOT_ACTIVE)); + let recording_offset_c2t = recording_offset_ns.clone(); + let started_at = Instant::now(); + let tx_c2t = tx.clone(); + let tx_t2c = tx; + + let c2t = async move { + let mut cap_filter = CapFilterState::Scanning { + started_at: Instant::now(), + pdus_seen: 0, + info_done: false, + confirm_done: false, + }; + loop { + let (action, frame) = match client_read.read_pdu().await { + Ok(v) => v, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err::<_, anyhow::Error>(e.into()), + }; + let offset = recording_offset_c2t.load(Ordering::Relaxed); + if offset != NOT_ACTIVE { + tap_client_to_target(action, &frame, started_at, offset, &tx_c2t); + } + + let bytes_to_forward: Vec = match cap_filter.consider(action, &frame) { + CapFilterDecision::Forward => frame.to_vec(), + CapFilterDecision::Replace(modified) => modified, + }; + target_write + .write_all(&bytes_to_forward) + .await + .context("write client PDU to target")?; + } + Ok(()) + }; + + let t2c = async move { + loop { + let (action, frame) = match target_read.read_pdu().await { + Ok(v) => v, + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break, + Err(e) => return Err::<_, anyhow::Error>(e.into()), + }; + let mut offset = recording_offset_ns.load(Ordering::Relaxed); + if offset == NOT_ACTIVE && action == Action::FastPath { + offset = elapsed_ns_since(started_at); + recording_offset_ns.store(offset, Ordering::Relaxed); + debug!( + skip_ms = offset / 1_000_000, + "first FastPath target frame, recording starts" + ); + } + if offset != NOT_ACTIVE { + tap_target_to_client(action, &frame, started_at, offset, &tx_t2c); + } + client_write + .write_all(&frame) + .await + .context("write target PDU to client")?; + } + Ok(()) + }; + + // select! (not try_join!) so the first branch to EOF cancels the other: + // try_join! 
waits for both to complete on Ok, but on a normal client + // disconnect the t2c read_pdu blocks indefinitely on a quiet target. + // Dropping the cancelled future releases its read+write halves; with + // the opposite branch already done, the underlying stream's Drop closes + // the socket and the peer observes the half-close. + let result = tokio::select! { + r = c2t => r, + r = t2c => r, + }; + match result { + Ok(_) => { + info!("session ended cleanly"); + Ok(()) + } + Err(e) => Err(e).context("bridge_pdus"), } - Ok(()) } -fn is_unexpected_eof(err: &std::io::Error) -> bool { - err.kind() == std::io::ErrorKind::UnexpectedEof +/// One-shot c2t scan that patches Client Info + Client Confirm Active. +enum CapFilterState { + Scanning { + started_at: Instant, + pdus_seen: usize, + info_done: bool, + confirm_done: bool, + }, + Done, +} + +enum CapFilterDecision { + Forward, + Replace(Vec), } -// Reads the client's MCS Connect Initial PDU, removes any virtual channels -// declared in its Client Network Data block, and forwards the rewritten PDU -// to the target. Any bytes after the PDU (rare; PDUs typically arrive one at -// a time at this stage) are forwarded unchanged. +impl CapFilterState { + fn consider(&mut self, action: Action, frame: &[u8]) -> CapFilterDecision { + let CapFilterState::Scanning { + started_at, + pdus_seen, + info_done, + confirm_done, + } = self + else { + return CapFilterDecision::Forward; + }; + + if action != Action::X224 { + return CapFilterDecision::Forward; + } + + *pdus_seen += 1; + if *pdus_seen > CONFIRM_ACTIVE_SCAN_MAX_PDUS + || started_at.elapsed() > CONFIRM_ACTIVE_SCAN_MAX_DURATION + { + warn!( + pdus_seen, + info_done = *info_done, + confirm_done = *confirm_done, + "scan window exhausted before both filters fired" + ); + *self = CapFilterState::Done; + return CapFilterDecision::Forward; + } + + // The two filters are disjoint, so a match short-circuits. 
+ if !*info_done { + if let Some(modified) = try_filter_client_info(frame) { + *info_done = true; + let both_done = *info_done && *confirm_done; + if both_done { + *self = CapFilterState::Done; + } + return CapFilterDecision::Replace(modified); + } + } + if !*confirm_done { + if let Some(modified) = try_filter_confirm_active(frame) { + *confirm_done = true; + let both_done = *info_done && *confirm_done; + if both_done { + *self = CapFilterState::Done; + } + return CapFilterDecision::Replace(modified); + } + } + CapFilterDecision::Forward + } +} + +#[derive(Debug, Clone, Copy)] +struct ByteRange { + offset: usize, + len: usize, +} + +impl ByteRange { + fn slice<'a>(&self, frame: &'a [u8]) -> &'a [u8] { + &frame[self.offset..self.offset + self.len] + } + + fn slice_mut<'a>(&self, frame: &'a mut [u8]) -> &'a mut [u8] { + &mut frame[self.offset..self.offset + self.len] + } +} + +/// Locate `send_data.user_data` inside `frame`. Bails on Cow::Owned. +fn user_data_range_within(frame: &[u8], send_data: &SendDataRequest<'_>) -> Option { + let slice: &[u8] = match &send_data.user_data { + Cow::Borrowed(s) => s, + Cow::Owned(_) => return None, + }; + let frame_start = frame.as_ptr() as usize; + let slice_start = slice.as_ptr() as usize; + if slice_start < frame_start || slice_start + slice.len() > frame_start + frame.len() { + return None; + } + Some(ByteRange { + offset: slice_start - frame_start, + len: slice.len(), + }) +} + +fn locate_client_info(frame: &[u8]) -> Option { + const SEC_INFO_PKT: u16 = 0x0040; + let send_data = decode::>>(frame).ok()?.0; + let user_data = user_data_range_within(frame, &send_data)?; + if user_data.len < 4 { + return None; + } + let bytes = user_data.slice(frame); + let sec_flags = u16::from_le_bytes([bytes[0], bytes[1]]); + (sec_flags & SEC_INFO_PKT != 0).then_some(user_data) +} + +struct ConfirmActiveLayout { + user_data: ByteRange, + caps_start_in_user_data: usize, +} + +fn locate_confirm_active(frame: &[u8]) -> Option { + let send_data = 
decode::>>(frame).ok()?.0; + let share_control = decode::(send_data.user_data.as_ref()).ok()?; + if !matches!( + share_control.share_control_pdu, + ShareControlPdu::ClientConfirmActive(_), + ) { + return None; + } + let user_data = user_data_range_within(frame, &send_data)?; + let caps_start_in_user_data = parse_confirm_active_caps_start(user_data.slice(frame))?; + Some(ConfirmActiveLayout { + user_data, + caps_start_in_user_data, + }) +} + +/// MS-RDPBCGR 2.2.1.13.2.1: ShareControlHeader(10) + originatorId(2) + +/// sourceDescLen(2) + combinedLen(2) + sourceDescriptor(var) + numCaps(2) + pad(2) +fn parse_confirm_active_caps_start(user_data: &[u8]) -> Option { + let mut p = 10 + 2; + if user_data.len() < p + 4 { + return None; + } + let source_desc_len = u16::from_le_bytes([user_data[p], user_data[p + 1]]) as usize; + p += 4 + source_desc_len + 4; + (p <= user_data.len()).then_some(p) +} + +fn try_filter_client_info(frame: &[u8]) -> Option> { + let user_data = locate_client_info(frame)?; + let mut out = frame.to_vec(); + if !cap_filter::client_info::clear_compression(user_data.slice_mut(&mut out)) { + return None; + } + debug!("Client Info PDU: cleared INFO_COMPRESSION + CompressionTypeMask"); + Some(out) +} + +fn try_filter_confirm_active(frame: &[u8]) -> Option> { + let layout = locate_confirm_active(frame)?; + let user_data_bytes = layout.user_data.slice(frame); + + let mut order_body_offset_in_frame: Option = None; + let mut codecs_body_offset_in_frame: Option = None; + for cap in cap_filter::walk_caps(user_data_bytes, layout.caps_start_in_user_data) { + let body_offset_in_frame = layout.user_data.offset + cap.body_offset_in_user_data; + match cap.cap_type { + cap_filter::cap_types::ORDER if cap.cap_len >= cap_filter::order_cap::BODY_LEN + 4 => { + order_body_offset_in_frame = Some(body_offset_in_frame); + } + cap_filter::cap_types::BITMAP_CODECS + if cap.cap_len >= cap_filter::bitmap_codecs_cap::MIN_BODY_LEN + 4 => + { + codecs_body_offset_in_frame = 
Some(body_offset_in_frame); + } + _ => {} + } + } + + // Without Order patched, server emits unrenderable Orders. + let order_offset = order_body_offset_in_frame?; + let mut out = frame.to_vec(); + cap_filter::order_cap::clear_order_support( + &mut out[order_offset..order_offset + cap_filter::order_cap::BODY_LEN], + ); + if let Some(codecs_offset) = codecs_body_offset_in_frame { + cap_filter::bitmap_codecs_cap::clear_codec_count(&mut out[codecs_offset..]); + } + debug!("Confirm Active: cleared Order support + BitmapCodecs count"); + Some(out) +} + +fn tap_client_to_target( + action: Action, + frame: &[u8], + started_at: Instant, + offset_ns: u64, + tx: &EventSender, +) { + if action != Action::FastPath { + return; + } + // Microsoft's Mac client sets spurious header flags that IronRDP + // rejects; mask them off before decoding (forwarded bytes are untouched). + let mut sanitized: Vec; + let bytes_for_decode: &[u8] = if frame.first().copied().unwrap_or(0) & 0xC0 != 0 { + sanitized = frame.to_vec(); + sanitized[0] &= 0x3F; + &sanitized + } else { + frame + }; + let input: FastPathInput = match decode_fast_path_input(bytes_for_decode) { + Ok(input) => input, + Err(_) => return, + }; + let elapsed_ns = elapsed_ns_since(started_at).saturating_sub(offset_ns); + for event in input.input_events() { + let session_event = match *event { + FastPathInputEvent::KeyboardEvent(flags, scancode) => SessionEvent::KeyboardInput { + scancode, + flags, + elapsed_ns, + }, + FastPathInputEvent::UnicodeKeyboardEvent(flags, code_point) => { + SessionEvent::UnicodeInput { + code_point, + flags, + elapsed_ns, + } + } + FastPathInputEvent::MouseEvent(pdu) => SessionEvent::MouseInput { + x: pdu.x_position, + y: pdu.y_position, + flags: pdu.flags, + wheel_delta: pdu.number_of_wheel_rotation_units, + elapsed_ns, + }, + // Windows clients route most mouse moves through MouseEventEx + // (XButton-aware variant). 
Replay only needs x/y to position the + // cursor; the X-button flags don't map onto PointerFlags so we + // surface MouseEventEx as a MouseInput with empty flags. + FastPathInputEvent::MouseEventEx(pdu) => SessionEvent::MouseInput { + x: pdu.x_position, + y: pdu.y_position, + flags: ironrdp_pdu::input::mouse::PointerFlags::empty(), + wheel_delta: 0, + elapsed_ns, + }, + // MouseEventRel, QoeEvent, SyncEvent: skip; uncommon and not + // needed for replay V1. + _ => continue, + }; + // try_send: never block the bridge byte stream on a slow consumer. + // Errors mean either Full (drop the input event; rare under typical + // sub-1k events/sec input rates) or Closed (poll loop exited; bridge + // keeps forwarding bytes regardless). + if let Err(e) = tx.try_send(session_event) { + if matches!(e, mpsc::error::TrySendError::Full(_)) { + warn!("session event channel full, dropping input event"); + } + } + } +} + +fn tap_target_to_client( + action: Action, + frame: &[u8], + started_at: Instant, + offset_ns: u64, + tx: &EventSender, +) { + if let Err(e) = tx.try_send(SessionEvent::TargetFrame { + action, + payload: frame.to_vec(), + elapsed_ns: elapsed_ns_since(started_at).saturating_sub(offset_ns), + }) { + if matches!(e, mpsc::error::TrySendError::Full(_)) { + warn!("session event channel full, dropping target frame"); + } + } +} + +fn decode_fast_path_input(frame: &[u8]) -> anyhow::Result { + use ironrdp_core::Decode as _; + let mut cursor = ReadCursor::new(frame); + FastPathInput::decode(&mut cursor).map_err(|e| anyhow::anyhow!("decode FastPathInput: {e}")) +} + +// Decode + mutate + re-encode the client's MCS Connect Initial: +// - set CS_CORE.serverSelectedProtocol to HYBRID_EX (FreeRDP echoes the +// wrong value, and target servers reject mismatched echoes) +// - clear CS_NET.channels so the target doesn't try to open virtual +// channels (clipboard, drives, audio, USB) the bridge can't service async fn filter_client_mcs_connect_initial( client_stream: &mut 
ErasedStream, target_stream: &mut ErasedStream, @@ -157,19 +560,9 @@ async fn filter_client_mcs_connect_initial( .map_err(|e| anyhow::anyhow!("decode MCS Connect Initial: {e:?}"))?; let mut gcc_blocks = connect_initial.conference_create_request.into_gcc_blocks(); + gcc_blocks.core.optional_data.server_selected_protocol = Some(SecurityProtocol::HYBRID_EX); if let Some(network) = gcc_blocks.network.as_mut() { - let stripped: Vec = network - .channels - .iter() - .map(|c| c.name.as_str().unwrap_or("?").to_owned()) - .collect(); - if !stripped.is_empty() { - info!( - ?stripped, - "stripped virtual channels from MCS Connect Initial" - ); - network.channels.clear(); - } + network.channels.clear(); } connect_initial.conference_create_request = ConferenceCreateRequest::new(gcc_blocks) .map_err(|e| anyhow::anyhow!("rebuild ConferenceCreateRequest: {e:?}"))?; @@ -205,7 +598,6 @@ async fn run_acceptor_half( password: ACCEPTOR_PASSWORD.to_owned(), domain: None, }; - // Capabilities/desktop-size are shape-fillers; we never call accept_finalize. let mut acceptor = Acceptor::new( SecurityProtocol::HYBRID_EX | SecurityProtocol::HYBRID | SecurityProtocol::SSL, ironrdp_acceptor::DesktopSize { @@ -269,16 +661,20 @@ async fn run_connector_half(target: TargetEndpoint) -> Result<(ErasedStream, byt ); let mut connector = ClientConnector::new(config, client_addr); - let should_upgrade = ironrdp_tokio::connect_begin(&mut target_framed, &mut connector) + // Request the same protocol set native clients send so the target's + // ServerCoreData.clientRequestedProtocols echo matches what they expect. 
+ let request_set = + SecurityProtocol::HYBRID_EX | SecurityProtocol::HYBRID | SecurityProtocol::SSL; + connector_x224_with_protocol(&mut target_framed, &mut connector, request_set) .await - .context("connector: connect_begin")?; + .context("connector: X.224 init")?; let (initial_stream, leftover) = target_framed.into_inner(); let (upgraded_stream, tls_cert) = ironrdp_tls::upgrade(initial_stream, &target.host) .await .context("connector: TLS upgrade")?; - let _upgraded = ironrdp_tokio::mark_as_upgraded(should_upgrade, &mut connector); + connector.mark_security_upgrade_as_done(); let erased: ErasedStream = Box::new(upgraded_stream); let mut target_framed = ironrdp_tokio::TokioFramed::new_with_leftover(erased, leftover); @@ -303,6 +699,71 @@ async fn run_connector_half(target: TargetEndpoint) -> Result<(ErasedStream, byt Ok(target_framed.into_inner()) } +// Drive the X.224 negotiation with the caller-chosen protocol set, then +// transition the connector into EnhancedSecurityUpgrade so the rest of +// the pipeline (TLS upgrade + CredSSP) proceeds normally. ironrdp's +// connect_begin hardcodes HYBRID|HYBRID_EX, which doesn't match the set +// native clients (Windows App, mstsc) advertise. +async fn connector_x224_with_protocol( + framed: &mut ironrdp_tokio::TokioFramed, + connector: &mut ClientConnector, + requested: SecurityProtocol, +) -> Result<()> +where + S: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, +{ + // Mirror what ironrdp's connect_begin includes: routing cookie with the + // username, which some Windows targets / load balancers expect. + let nego_data = + connector + .config + .request_data + .clone() + .or_else(|| match &connector.config.credentials { + Credentials::UsernamePassword { username, .. 
} if !username.is_empty() => { + Some(NegoRequestData::cookie(username.clone())) + } + _ => None, + }); + let request = ConnectionRequest { + nego_data, + flags: RequestFlags::empty(), + protocol: requested, + }; + + let mut buf = WriteBuf::new(); + encode_buf(&X224(request), &mut buf) + .map_err(|e| anyhow::anyhow!("encode X.224 connection request: {e:?}"))?; + framed + .write_all(buf.filled()) + .await + .context("write X.224 connection request")?; + + let pdu = framed + .read_pdu() + .await + .context("read X.224 connection confirm")?; + let confirm = ConnectionConfirm::decode_owned(&mut ReadCursor::new(&pdu.1)) + .map_err(|e| anyhow::anyhow!("decode X.224 connection confirm: {e:?}"))?; + + let selected_protocol = match confirm { + ConnectionConfirm::Response { protocol, .. } => protocol, + ConnectionConfirm::Failure { code } => { + anyhow::bail!("X.224 negotiation failure: {:?}", code); + } + }; + if !requested.contains(selected_protocol) { + anyhow::bail!( + "target selected protocol {:?} not in requested set {:?}", + selected_protocol, + requested + ); + } + + connector.state = ClientConnectorState::EnhancedSecurityUpgrade { selected_protocol }; + Ok(()) +} + // Replicated from ironrdp-async's private perform_credssp_step so we can // stop before connect_finalize (which would start MCS/capability exchange). 
async fn perform_connector_credssp( @@ -415,3 +876,50 @@ pub trait AsyncReadWrite: AsyncRead + AsyncWrite {} impl AsyncReadWrite for T where T: AsyncRead + AsyncWrite {} pub type ErasedStream = Box; + +#[cfg(test)] +mod tests { + use super::*; + + /// Build a synthetic ConfirmActive user_data prefix: + /// ShareControlHeader(10) + originatorId(2) + sourceDescLen(2) + + /// combinedLen(2) + sourceDescriptor(source_desc_len) + numCaps(2) + pad(2) + fn confirm_active_prefix(source_desc_len: usize) -> Vec { + let mut buf = vec![0xAA_u8; 10 + 2]; + buf.extend_from_slice(&(source_desc_len as u16).to_le_bytes()); + buf.extend_from_slice(&0xBBBB_u16.to_le_bytes()); + buf.extend_from_slice(&vec![0xCC; source_desc_len]); + buf.extend_from_slice(&0xDDDD_u16.to_le_bytes()); + buf.extend_from_slice(&0xEEEE_u16.to_le_bytes()); + buf + } + + #[test] + fn caps_start_after_variable_source_descriptor() { + let user_data = confirm_active_prefix(6); + let p = parse_confirm_active_caps_start(&user_data).expect("caps start"); + assert_eq!(p, 12 + 4 + 6 + 4); + assert_eq!(p, user_data.len()); + } + + #[test] + fn caps_start_works_when_source_descriptor_is_empty() { + let user_data = confirm_active_prefix(0); + let p = parse_confirm_active_caps_start(&user_data).expect("caps start"); + assert_eq!(p, 12 + 4 + 4); + } + + #[test] + fn caps_start_returns_none_when_header_truncated() { + let user_data = vec![0u8; 15]; + assert!(parse_confirm_active_caps_start(&user_data).is_none()); + } + + #[test] + fn caps_start_returns_none_when_source_desc_len_overflows() { + let mut user_data = vec![0u8; 12]; + user_data.extend_from_slice(&9999_u16.to_le_bytes()); + user_data.extend_from_slice(&[0u8; 2]); + assert!(parse_confirm_active_caps_start(&user_data).is_none()); + } +} diff --git a/packages/pam/handlers/rdp/native/src/cap_filter.rs b/packages/pam/handlers/rdp/native/src/cap_filter.rs new file mode 100644 index 00000000..d2076669 --- /dev/null +++ 
b/packages/pam/handlers/rdp/native/src/cap_filter.rs @@ -0,0 +1,276 @@ +//! Byte-level patches for Confirm Active / Client Info PDUs. +//! IronRDP's typed decode->encode loses unrelated fields, so we mutate in place. + +/// MS-RDPBCGR 2.2.7 +pub mod cap_types { + pub const ORDER: u16 = 0x0003; + pub const BITMAP_CODECS: u16 = 0x001d; +} + +/// MS-RDPBCGR 2.2.7.1.3 +pub mod order_cap { + use std::ops::Range; + + pub const BODY_LEN: usize = 84; + pub const ORDER_SUPPORT: Range = 32..64; + + /// Forces server to fall back to Bitmap updates. + /// orderFlags untouched so NEGOTIATEORDERSUPPORT (mandatory) stays set. + pub fn clear_order_support(body: &mut [u8]) { + body[ORDER_SUPPORT].fill(0); + } +} + +/// MS-RDPBCGR 2.2.7.2.10 +pub mod bitmap_codecs_cap { + pub const CODEC_COUNT_OFFSET: usize = 0; + pub const MIN_BODY_LEN: usize = 1; + + /// Prevents server from picking RFX/NSCodec/AVC. + pub fn clear_codec_count(body: &mut [u8]) { + body[CODEC_COUNT_OFFSET] = 0; + } +} + +/// MS-RDPBCGR 2.2.1.11.1.1, given user_data of an MCS Send Data Request +/// whose security header has SEC_INFO_PKT set. +pub mod client_info { + use std::ops::Range; + + /// 4 bytes security header + 4 bytes CodePage. + pub const FLAGS: Range = 8..12; + pub const INFO_COMPRESSION: u32 = 0x0000_0080; + pub const COMPRESSION_TYPE_MASK: u32 = 0x0000_1E00; + + /// Disables MPPC bulk compression (IronRDP-session can't decompress it). 
+ pub fn clear_compression(user_data: &mut [u8]) -> bool { + if user_data.len() < FLAGS.end { + return false; + } + let bytes: [u8; 4] = match user_data[FLAGS.clone()].try_into() { + Ok(b) => b, + Err(_) => return false, + }; + let flags = u32::from_le_bytes(bytes); + let new_flags = flags & !(INFO_COMPRESSION | COMPRESSION_TYPE_MASK); + if flags == new_flags { + return false; + } + user_data[FLAGS.clone()].copy_from_slice(&new_flags.to_le_bytes()); + true + } +} + +#[derive(Debug, Clone, Copy)] +pub struct WalkedCap { + pub cap_type: u16, + pub cap_len: usize, + pub body_offset_in_user_data: usize, +} + +/// Stops on a malformed cap header. +pub fn walk_caps(user_data: &[u8], caps_start: usize) -> CapIter<'_> { + CapIter { + user_data, + cursor: caps_start, + } +} + +pub struct CapIter<'a> { + user_data: &'a [u8], + cursor: usize, +} + +impl<'a> Iterator for CapIter<'a> { + type Item = WalkedCap; + + fn next(&mut self) -> Option { + if self.cursor + 4 > self.user_data.len() { + return None; + } + let cap_type = + u16::from_le_bytes([self.user_data[self.cursor], self.user_data[self.cursor + 1]]); + let cap_len = u16::from_le_bytes([ + self.user_data[self.cursor + 2], + self.user_data[self.cursor + 3], + ]) as usize; + if cap_len < 4 || self.cursor + cap_len > self.user_data.len() { + return None; + } + let item = WalkedCap { + cap_type, + cap_len, + body_offset_in_user_data: self.cursor + 4, + }; + self.cursor += cap_len; + Some(item) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn order_clear_zeros_only_the_support_array() { + let mut body = vec![0xff_u8; order_cap::BODY_LEN]; + order_cap::clear_order_support(&mut body); + assert_eq!(&body[order_cap::ORDER_SUPPORT], &[0; 32]); + assert_eq!(&body[28..32], &[0xff; 4]); + assert_eq!(&body[64..68], &[0xff; 4]); + } + + #[test] + fn bitmap_codecs_clears_only_first_byte() { + let mut body = vec![0xff_u8; 16]; + bitmap_codecs_cap::clear_codec_count(&mut body); + assert_eq!(body[0], 0); + 
assert_eq!(&body[1..], &[0xff; 15]); + } + + #[test] + fn client_info_clears_compression_bits() { + let mut user_data = vec![0u8; 12]; + user_data[8..12].copy_from_slice(&0x0000_1E80_u32.to_le_bytes()); + assert!(client_info::clear_compression(&mut user_data)); + let new_flags = u32::from_le_bytes(user_data[8..12].try_into().unwrap()); + assert_eq!(new_flags, 0); + } + + #[test] + fn client_info_noop_when_compression_already_off() { + let mut user_data = vec![0u8; 12]; + user_data[8..12].copy_from_slice(&0x0000_0040_u32.to_le_bytes()); + assert!(!client_info::clear_compression(&mut user_data)); + } + + #[test] + fn client_info_returns_false_when_user_data_too_short() { + let mut user_data = vec![0u8; 11]; + assert!(!client_info::clear_compression(&mut user_data)); + } + + #[test] + fn client_info_preserves_unrelated_flag_bits() { + let mut user_data = vec![0xAB_u8; 12]; + // INFO_COMPRESSION + CompressionTypeMask + INFO_AUTOLOGON(0x0008) + INFO_UNICODE(0x0010) + let original = 0x0000_1E80_u32 | 0x0000_0008 | 0x0000_0010; + user_data[8..12].copy_from_slice(&original.to_le_bytes()); + assert!(client_info::clear_compression(&mut user_data)); + let new_flags = u32::from_le_bytes(user_data[8..12].try_into().unwrap()); + assert_eq!(new_flags, 0x0000_0008 | 0x0000_0010); + assert_eq!(&user_data[..8], &[0xAB; 8]); + } + + #[test] + fn walk_caps_iterates_each_cap() { + let mut user_data = vec![0u8; 8]; + user_data.extend_from_slice(&[0x01, 0x00, 0x08, 0x00, 0xaa, 0xbb, 0xcc, 0xdd]); + user_data.extend_from_slice(&[ + 0x03, 0x00, 0x0c, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, + ]); + let caps: Vec<_> = walk_caps(&user_data, 8).collect(); + assert_eq!(caps.len(), 2); + assert_eq!(caps[0].cap_type, 0x0001); + assert_eq!(caps[0].cap_len, 8); + assert_eq!(caps[0].body_offset_in_user_data, 12); + assert_eq!(caps[1].cap_type, 0x0003); + assert_eq!(caps[1].cap_len, 12); + assert_eq!(caps[1].body_offset_in_user_data, 20); + } + + #[test] + fn 
walk_caps_stops_on_malformed_header() { + let mut user_data = vec![0u8; 4]; + user_data.extend_from_slice(&[0x01, 0x00, 0x64, 0x00]); + let caps: Vec<_> = walk_caps(&user_data, 4).collect(); + assert_eq!(caps.len(), 0); + } + + #[test] + fn walk_caps_stops_on_cap_len_below_header_size() { + let user_data = vec![0x01, 0x00, 0x02, 0x00]; + let caps: Vec<_> = walk_caps(&user_data, 0).collect(); + assert_eq!(caps.len(), 0); + } + + /// End-to-end byte-preservation contract: walk a synthetic caps array + /// containing Order, BitmapCodecs, and an unrelated cap; patch only + /// the targeted fields; assert every other byte is identical. + #[test] + fn walk_and_patch_preserves_unrelated_bytes() { + let mut buf: Vec = Vec::new(); + + // Cap 1: unrelated cap_type=0x0001, len=8, body filled with 0x77 + buf.extend_from_slice(&[0x01, 0x00, 0x08, 0x00]); + buf.extend_from_slice(&[0x77; 4]); + let unrelated_range = 0..buf.len(); + + // Cap 2: Order (0x0003), full body of 0xFF + 4-byte header + let order_header_offset = buf.len(); + let order_total_len = (order_cap::BODY_LEN + 4) as u16; + buf.extend_from_slice(&[0x03, 0x00]); + buf.extend_from_slice(&order_total_len.to_le_bytes()); + let order_body_offset = buf.len(); + buf.extend_from_slice(&[0xFF; order_cap::BODY_LEN]); + + // Cap 3: BitmapCodecs (0x001d), 4-byte header + body of 0xEE + let codecs_header_offset = buf.len(); + let codecs_body_len = 16usize; + buf.extend_from_slice(&[0x1D, 0x00]); + buf.extend_from_slice(&((codecs_body_len + 4) as u16).to_le_bytes()); + let codecs_body_offset = buf.len(); + buf.extend_from_slice(&vec![0xEE; codecs_body_len]); + + // Cap 4: trailing unrelated cap (filter must not stop early or read past it) + let trailing_offset = buf.len(); + buf.extend_from_slice(&[0x02, 0x00, 0x06, 0x00, 0x55, 0x55]); + + let original = buf.clone(); + + let caps: Vec<_> = walk_caps(&buf, 0).collect(); + assert_eq!(caps.len(), 4); + assert_eq!(caps[0].body_offset_in_user_data, order_header_offset - 4); + 
assert_eq!(caps[1].cap_type, cap_types::ORDER); + assert_eq!(caps[1].body_offset_in_user_data, order_body_offset); + assert_eq!(caps[2].cap_type, cap_types::BITMAP_CODECS); + assert_eq!(caps[2].body_offset_in_user_data, codecs_body_offset); + assert_eq!(caps[3].body_offset_in_user_data, trailing_offset + 4); + + order_cap::clear_order_support( + &mut buf[order_body_offset..order_body_offset + order_cap::BODY_LEN], + ); + bitmap_codecs_cap::clear_codec_count(&mut buf[codecs_body_offset..]); + + // Unrelated cap: byte-identical + assert_eq!(&buf[unrelated_range.clone()], &original[unrelated_range]); + // Order cap: header preserved, only ORDER_SUPPORT range zeroed + assert_eq!( + &buf[order_header_offset..order_body_offset], + &original[order_header_offset..order_body_offset] + ); + let zeroed_start = order_body_offset + order_cap::ORDER_SUPPORT.start; + let zeroed_end = order_body_offset + order_cap::ORDER_SUPPORT.end; + assert_eq!( + &buf[order_body_offset..zeroed_start], + &original[order_body_offset..zeroed_start] + ); + assert_eq!(&buf[zeroed_start..zeroed_end], &[0u8; 32]); + assert_eq!( + &buf[zeroed_end..codecs_header_offset], + &original[zeroed_end..codecs_header_offset] + ); + // BitmapCodecs cap: header preserved, only first body byte zeroed + assert_eq!( + &buf[codecs_header_offset..codecs_body_offset], + &original[codecs_header_offset..codecs_body_offset] + ); + assert_eq!(buf[codecs_body_offset], 0); + assert_eq!( + &buf[codecs_body_offset + 1..trailing_offset], + &original[codecs_body_offset + 1..trailing_offset] + ); + // Trailing cap: byte-identical + assert_eq!(&buf[trailing_offset..], &original[trailing_offset..]); + } +} diff --git a/packages/pam/handlers/rdp/native/src/config.rs b/packages/pam/handlers/rdp/native/src/config.rs index ba223311..d959fe18 100644 --- a/packages/pam/handlers/rdp/native/src/config.rs +++ b/packages/pam/handlers/rdp/native/src/config.rs @@ -17,10 +17,8 @@ pub fn connector_config(username: String, password: String, 
domain: Option, + elapsed_ns: u64, + }, +} + +pub fn elapsed_ns_since(started_at: Instant) -> u64 { + started_at.elapsed().as_nanos() as u64 +} + +pub type EventSender = mpsc::Sender; +pub type EventReceiver = mpsc::Receiver; + +// Bounded so a busy-disk gateway can't OOM under heavy graphics: producer +// (tap_*) uses try_send and drops on full rather than back-pressuring the +// bridge byte stream. Sized to ~few seconds of 60 fps RDP frames at typical +// PDU rates; lossy recording > unbounded memory. +pub const EVENT_CHANNEL_CAPACITY: usize = 1024; + +pub fn channel() -> (EventSender, EventReceiver) { + mpsc::channel(EVENT_CHANNEL_CAPACITY) +} diff --git a/packages/pam/handlers/rdp/native/src/ffi.rs b/packages/pam/handlers/rdp/native/src/ffi.rs index 96a2fd52..fb637e99 100644 --- a/packages/pam/handlers/rdp/native/src/ffi.rs +++ b/packages/pam/handlers/rdp/native/src/ffi.rs @@ -1,8 +1,5 @@ -//! C ABI for the bridge. Called from Go via CGo. -//! -//! Each session runs on its own OS thread with a current-thread tokio -//! runtime. `start_*` transfers ownership of the client fd/socket to -//! Rust (Go hands in a dup). Contract: wait, then free. +//! C ABI for the bridge. Each session runs on its own thread with a +//! current-thread tokio runtime. Caller contract: wait, then free. 
use std::collections::HashMap; use std::ffi::{c_char, CStr}; @@ -10,12 +7,15 @@ use std::net::TcpStream as StdTcpStream; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{LazyLock, Mutex}; use std::thread::JoinHandle; +use std::time::Duration; use tokio::net::TcpStream; +use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tracing::{error, info}; use crate::bridge::{run_mitm, TargetEndpoint}; +use crate::events::{self, SessionEvent}; pub const RDP_BRIDGE_OK: i32 = 0; pub const RDP_BRIDGE_SESSION_ERROR: i32 = 1; @@ -24,10 +24,146 @@ pub const RDP_BRIDGE_INVALID_HANDLE: i32 = -1; pub const RDP_BRIDGE_BAD_ARG: i32 = -2; pub const RDP_BRIDGE_RUNTIME_ERROR: i32 = -3; +// Distinct number space from the bridge status codes above; consumed by +// a different Go function. +pub const RDP_POLL_OK: i32 = 0; +pub const RDP_POLL_TIMEOUT: i32 = 1; +pub const RDP_POLL_ENDED: i32 = 2; +pub const RDP_POLL_INVALID_HANDLE: i32 = -1; + +#[repr(u8)] +pub enum RdpEventType { + Keyboard = 1, + Unicode = 2, + Mouse = 3, + TargetFrame = 4, +} + +/// Fields are reused across variants; check `event_type` first. +/// For TargetFrame, `payload_ptr` is libc::malloc'd; Go must libc::free it. +#[repr(C)] +pub struct RdpEvent { + pub event_type: u8, + /// Nanoseconds since bridge start. + pub elapsed_ns: u64, + /// Keyboard: scancode. Unicode: code point. Mouse: x. TargetFrame: bytes. + pub value_a: u32, + /// Mouse: y. Others: 0. + pub value_b: u32, + /// Keyboard / Unicode / Mouse flags (raw bits from the RDP layer). + pub flags: u32, + /// Mouse wheel delta (signed). 0 for others. + pub wheel_delta: i32, + /// TargetFrame: 0 = X.224, 1 = FastPath. 0 for others. 
+ pub action: u8, + pub payload_ptr: *mut u8, + pub payload_len: u32, +} + +impl RdpEvent { + const fn zero() -> Self { + Self { + event_type: 0, + elapsed_ns: 0, + value_a: 0, + value_b: 0, + flags: 0, + wheel_delta: 0, + action: 0, + payload_ptr: std::ptr::null_mut(), + payload_len: 0, + } + } + + fn from_session_event(ev: SessionEvent) -> Self { + match ev { + SessionEvent::KeyboardInput { + scancode, + flags, + elapsed_ns, + } => Self { + event_type: RdpEventType::Keyboard as u8, + elapsed_ns, + value_a: scancode.into(), + flags: flags.bits().into(), + ..Self::zero() + }, + SessionEvent::UnicodeInput { + code_point, + flags, + elapsed_ns, + } => Self { + event_type: RdpEventType::Unicode as u8, + elapsed_ns, + value_a: code_point.into(), + flags: flags.bits().into(), + ..Self::zero() + }, + SessionEvent::MouseInput { + x, + y, + flags, + wheel_delta, + elapsed_ns, + } => Self { + event_type: RdpEventType::Mouse as u8, + elapsed_ns, + value_a: x.into(), + value_b: y.into(), + flags: flags.bits().into(), + wheel_delta: wheel_delta.into(), + ..Self::zero() + }, + SessionEvent::TargetFrame { + action, + payload, + elapsed_ns, + } => { + // Copy into a libc::malloc'd buffer the Go caller will free. + // Using libc (not Rust's allocator) lets Go free directly via + // C.free without an extra trip back through the FFI. + let len = payload.len(); + let ptr = if len == 0 { + std::ptr::null_mut() + } else { + unsafe { + let p = libc::malloc(len) as *mut u8; + if p.is_null() { + std::ptr::null_mut() + } else { + std::ptr::copy_nonoverlapping(payload.as_ptr(), p, len); + p + } + } + }; + Self { + event_type: RdpEventType::TargetFrame as u8, + elapsed_ns, + value_a: len as u32, + action: match action { + ironrdp_pdu::Action::X224 => 0, + ironrdp_pdu::Action::FastPath => 1, + }, + payload_ptr: ptr, + payload_len: len as u32, + ..Self::zero() + } + } + } + } +} + struct BridgeEntry { cancel: CancellationToken, // Taken by wait(); None afterward. 
join: Mutex>>>, + // Receiver side of the bridge's event channel. Polled by Go via + // rdp_bridge_poll_event. Wrapped in Option so the poll loop can take it + // out for the duration of the await without holding the HANDLES lock. + events_rx: Mutex>>, + // Set once the events channel has reported closed; subsequent polls + // short-circuit to RDP_POLL_ENDED. + events_ended: Mutex, } static HANDLES: LazyLock>> = @@ -65,6 +201,8 @@ fn spawn_session( let cancel = CancellationToken::new(); let cancel_for_thread = cancel.clone(); + let (events_tx, events_rx) = events::channel(); + let join = std::thread::Builder::new() .name("rdp-bridge-session".to_owned()) .spawn(move || -> anyhow::Result<()> { @@ -80,13 +218,15 @@ fn spawn_session( password, domain, }; - run_mitm(client, endpoint, cancel_for_thread).await + run_mitm(client, endpoint, cancel_for_thread, events_tx).await }) })?; Ok(register(BridgeEntry { cancel, join: Mutex::new(Some(join)), + events_rx: Mutex::new(Some(events_rx)), + events_ended: Mutex::new(false), })) } @@ -203,10 +343,12 @@ pub extern "C" fn rdp_bridge_wait(handle: u64) -> i32 { } Ok(Err(e)) => { error!(handle, error = ?e, "rdp_bridge_wait: session failed"); + eprintln!("rdp bridge session failed (handle={handle}): {e:?}"); RDP_BRIDGE_SESSION_ERROR } - Err(_) => { + Err(panic) => { error!(handle, "rdp_bridge_wait: session thread panicked"); + eprintln!("rdp bridge session thread panicked (handle={handle}): {panic:?}"); RDP_BRIDGE_THREAD_PANIC } }, @@ -235,3 +377,75 @@ pub extern "C" fn rdp_bridge_free(handle: u64) -> i32 { RDP_BRIDGE_INVALID_HANDLE } } + +/// Poll the next event, blocking up to `timeout_ms` ms. On RDP_POLL_OK, +/// caller owns *payload_ptr (must libc::free). +/// +/// # Safety +/// +/// `out` must be a non-null, writable `*mut RdpEvent`. 
+#[no_mangle] +pub unsafe extern "C" fn rdp_bridge_poll_event( + handle: u64, + out: *mut RdpEvent, + timeout_ms: u32, +) -> i32 { + if out.is_null() { + return RDP_POLL_INVALID_HANDLE; + } + + // Avoid holding the HANDLES lock across the await. + let take_result: Result>, i32> = { + let handles = HANDLES.lock().expect("HANDLES poisoned"); + match handles.get(&handle) { + None => Err(RDP_POLL_INVALID_HANDLE), + Some(entry) => { + if *entry.events_ended.lock().expect("events_ended poisoned") { + Err(RDP_POLL_ENDED) + } else { + Ok(entry.events_rx.lock().expect("events_rx poisoned").take()) + } + } + } + }; + let mut rx = match take_result { + Ok(Some(rx)) => rx, + // Concurrent poll on the same handle; callers must serialize. + Ok(None) => return RDP_POLL_INVALID_HANDLE, + Err(code) => return code, + }; + + // Short-lived single-thread runtime just for the timeout. Cheap; the + // bridge thread runs its own runtime. + let result = { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .expect("build poll runtime"); + rt.block_on(async { + tokio::time::timeout(Duration::from_millis(timeout_ms.into()), rx.recv()).await + }) + }; + + let outcome = match result { + Ok(Some(event)) => { + let rdp_event = RdpEvent::from_session_event(event); + unsafe { out.write(rdp_event) }; + RDP_POLL_OK + } + Ok(None) => RDP_POLL_ENDED, // sender side dropped (bridge ended) + Err(_timeout) => RDP_POLL_TIMEOUT, + }; + + // Restore the receiver, or mark ended if the channel reported closed. 
+ let handles = HANDLES.lock().expect("HANDLES poisoned"); + if let Some(entry) = handles.get(&handle) { + if outcome == RDP_POLL_ENDED { + *entry.events_ended.lock().expect("events_ended poisoned") = true; + } else { + *entry.events_rx.lock().expect("events_rx poisoned") = Some(rx); + } + } + + outcome +} diff --git a/packages/pam/handlers/rdp/native/src/lib.rs b/packages/pam/handlers/rdp/native/src/lib.rs index 61c64480..abb6f0bd 100644 --- a/packages/pam/handlers/rdp/native/src/lib.rs +++ b/packages/pam/handlers/rdp/native/src/lib.rs @@ -3,5 +3,7 @@ //! passes bytes through. pub mod bridge; +pub mod cap_filter; pub mod config; +pub mod events; pub mod ffi; diff --git a/packages/pam/handlers/rdp/proxy.go b/packages/pam/handlers/rdp/proxy.go index 2bd6aa8d..18bef66f 100644 --- a/packages/pam/handlers/rdp/proxy.go +++ b/packages/pam/handlers/rdp/proxy.go @@ -1,6 +1,13 @@ package rdp import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/rs/zerolog/log" + "github.com/Infisical/infisical-merge/packages/pam/session" ) @@ -13,9 +20,12 @@ type RDPProxyConfig struct { // domain-joined NTLM CredSSP. Backend session credentials populate this. InjectDomain string SessionID string - // Retained for API symmetry with other PAM handlers; not yet written - // through (no RDP session recording in this MVP). SessionLogger session.SessionLogger + // Added to every event's elapsed_ns so timestamps stay monotonic across + // RDP reconnects within the same PAM session. Zero for the first connection. + PriorElapsedNs uint64 + // drain calls RecordEmittedElapsedNs after each persisted event. + SessionUploader *session.SessionUploader } type RDPProxy struct { @@ -25,3 +35,124 @@ type RDPProxy struct { func NewRDPProxy(config RDPProxyConfig) *RDPProxy { return &RDPProxy{config: config} } + +// Wire envelopes carried inside SessionEvent.Data for ChannelType=RDP. 
+type rdpTargetFrameEnvelope struct { + Type string `json:"type"` // "target_frame" + Action string `json:"action"` // "x224" | "fastpath" + Payload []byte `json:"payload"` // raw PDU bytes (base64 by Go's json.Marshal) + ElapsedNs uint64 `json:"elapsedNs"` +} + +type rdpKeyboardEnvelope struct { + Type string `json:"type"` // "keyboard" + Scancode uint8 `json:"scancode"` + Flags uint32 `json:"flags"` + ElapsedNs uint64 `json:"elapsedNs"` +} + +type rdpUnicodeEnvelope struct { + Type string `json:"type"` // "unicode" + CodePoint uint16 `json:"codePoint"` + Flags uint32 `json:"flags"` + ElapsedNs uint64 `json:"elapsedNs"` +} + +type rdpMouseEnvelope struct { + Type string `json:"type"` // "mouse" + X uint16 `json:"x"` + Y uint16 `json:"y"` + Flags uint32 `json:"flags"` + WheelDelta int32 `json:"wheelDelta"` + ElapsedNs uint64 `json:"elapsedNs"` +} + +// Bounds bridge poll latency so Cancel ends the drain loop promptly. +const pollTimeout = 250 * time.Millisecond + +var errUnknownRdpEventType = errors.New("rdp: unknown event type") + +// Logger errors are warned but don't stop the drain; dropping one event is +// better than back-pressuring the bridge byte stream. 
+func drainBridgeEvents(ctx context.Context, b *Bridge, logger session.SessionLogger, sessionID string, priorElapsedNs uint64, uploader *session.SessionUploader) { + if logger == nil { + return + } + for { + if ctx.Err() != nil { + return + } + result, ev, err := b.PollEvent(pollTimeout) + if err != nil { + log.Debug().Err(err).Str("sessionID", sessionID).Msg("rdp event drain stopped") + return + } + switch result { + case PollEnded: + return + case PollTimeout: + continue + case PollOK: + ev.ElapsedNs += priorElapsedNs + data, encErr := encodeRdpEvent(ev) + if encErr != nil { + log.Warn().Err(encErr).Str("sessionID", sessionID).Uint8("type", uint8(ev.Type)).Msg("encode RDP event") + continue + } + te := session.SessionEvent{ + Timestamp: time.Now(), + EventType: session.SessionEventRDP, + ChannelType: session.SessionChannelRDP, + Data: data, + ElapsedTime: float64(ev.ElapsedNs) / 1e9, + } + if logErr := logger.LogSessionEvent(te); logErr != nil { + log.Warn().Err(logErr).Str("sessionID", sessionID).Msg("log RDP event") + continue + } + if uploader != nil { + uploader.RecordEmittedElapsedNs(sessionID, ev.ElapsedNs) + } + } + } +} + +func encodeRdpEvent(ev Event) ([]byte, error) { + switch ev.Type { + case EventTypeTargetFrame: + action := "x224" + if ev.Action == ActionFastPath { + action = "fastpath" + } + return json.Marshal(rdpTargetFrameEnvelope{ + Type: "target_frame", + Action: action, + Payload: ev.Payload, + ElapsedNs: ev.ElapsedNs, + }) + case EventTypeKeyboard: + return json.Marshal(rdpKeyboardEnvelope{ + Type: "keyboard", + Scancode: ev.Scancode, + Flags: ev.Flags, + ElapsedNs: ev.ElapsedNs, + }) + case EventTypeUnicode: + return json.Marshal(rdpUnicodeEnvelope{ + Type: "unicode", + CodePoint: ev.CodePoint, + Flags: ev.Flags, + ElapsedNs: ev.ElapsedNs, + }) + case EventTypeMouse: + return json.Marshal(rdpMouseEnvelope{ + Type: "mouse", + X: ev.X, + Y: ev.Y, + Flags: ev.Flags, + WheelDelta: ev.WheelDelta, + ElapsedNs: ev.ElapsedNs, + }) + } + return nil, 
errUnknownRdpEventType +} diff --git a/packages/pam/handlers/ssh/proxy.go b/packages/pam/handlers/ssh/proxy.go index 8d9a6195..9b6eafeb 100644 --- a/packages/pam/handlers/ssh/proxy.go +++ b/packages/pam/handlers/ssh/proxy.go @@ -36,17 +36,17 @@ type SSHProxy struct { mutex sync.Mutex sessionData []byte // Store session data for logging inputBuffer []byte // Buffer for input data to batch keystrokes - inputChannelType session.TerminalChannelType // Channel type for buffered input + inputChannelType session.SessionChannelType // Channel type for buffered input escapeState int // 0=normal, 1=got ESC, 2=in CSI sequence outputMutex sync.Mutex outputBuffer []byte // Buffer for output data to enable masking across chunks - outputChannelType session.TerminalChannelType // Channel type for buffered output + outputChannelType session.SessionChannelType // Channel type for buffered output } // channelState holds per-channel state for tracking session type type channelState struct { mutex sync.Mutex - channelType session.TerminalChannelType // Type of channel (terminal, exec, sftp) + channelType session.SessionChannelType // Type of channel (terminal, exec, sftp) isBinarySession bool // True if this channel is SFTP/SCP binary protocol sftpParser *SFTPParser // Parser for SFTP protocol to extract file operations } @@ -360,13 +360,13 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha Msg("Blocked SSH exec command") // Log the blocked exec to session recording - blockedEvent := session.TerminalEvent{ + blockedEvent := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventInput, - ChannelType: session.TerminalChannelExec, + EventType: session.SessionEventInput, + ChannelType: session.SessionChannelExec, Data: []byte(fmt.Sprintf("$ %s\n[BLOCKED] Command not permitted\n", command)), } - if err := p.config.SessionLogger.LogTerminalEvent(blockedEvent); err != nil { + if err := 
p.config.SessionLogger.LogSessionEvent(blockedEvent); err != nil { log.Error().Err(err).Str("sessionID", sessionID).Msg("Failed to log blocked exec command") } @@ -382,9 +382,9 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha if isSCP { // Mark this channel as binary so we don't log the raw file data chState.isBinarySession = true - chState.channelType = session.TerminalChannelSFTP // SCP is file transfer + chState.channelType = session.SessionChannelSFTP // SCP is file transfer } else { - chState.channelType = session.TerminalChannelExec + chState.channelType = session.SessionChannelExec } chState.mutex.Unlock() @@ -395,9 +395,9 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha // Log the exec command to the session recording var logMessage string - var channelType session.TerminalChannelType + var channelType session.SessionChannelType if isSCP { - channelType = session.TerminalChannelSFTP + channelType = session.SessionChannelSFTP // Parse SCP command for more readable logging // scp -t /path = receiving file TO server // scp -f /path = sending file FROM server @@ -411,17 +411,17 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha logMessage = fmt.Sprintf("$ %s\n", command) } } else { - channelType = session.TerminalChannelExec + channelType = session.SessionChannelExec logMessage = fmt.Sprintf("$ %s\n", command) } - event := session.TerminalEvent{ + event := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventInput, + EventType: session.SessionEventInput, ChannelType: channelType, Data: []byte(logMessage), } - if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + if err := p.config.SessionLogger.LogSessionEvent(event); err != nil { log.Error().Err(err). Str("sessionID", sessionID). Str("command", command). 
@@ -431,7 +431,7 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha } case "shell": chState.mutex.Lock() - chState.channelType = session.TerminalChannelShell + chState.channelType = session.SessionChannelShell chState.mutex.Unlock() log.Info(). Str("sessionID", sessionID). @@ -451,18 +451,18 @@ func (p *SSHProxy) handleChannelRequests(requests <-chan *ssh.Request, targetCha // Log SFTP sessions and set up SFTP parser for file operation logging if subsystem == "sftp" { chState.mutex.Lock() - chState.channelType = session.TerminalChannelSFTP + chState.channelType = session.SessionChannelSFTP chState.isBinarySession = true chState.sftpParser = NewSFTPParser() chState.mutex.Unlock() - event := session.TerminalEvent{ + event := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventInput, - ChannelType: session.TerminalChannelSFTP, + EventType: session.SessionEventInput, + ChannelType: session.SessionChannelSFTP, Data: []byte("File transfer session started\n"), } - if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + if err := p.config.SessionLogger.LogSessionEvent(event); err != nil { log.Error().Err(err). Str("sessionID", sessionID). Msg("Failed to log SFTP session start") @@ -542,13 +542,13 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses for _, op := range operations { // Log each SFTP operation logMsg := FormatOperation(op) + "\n" - event := session.TerminalEvent{ + event := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventInput, - ChannelType: session.TerminalChannelSFTP, + EventType: session.SessionEventInput, + ChannelType: session.SessionChannelSFTP, Data: []byte(logMsg), } - if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + if err := p.config.SessionLogger.LogSessionEvent(event); err != nil { log.Error().Err(err). Str("sessionID", sessionID). Str("operation", op.Type). 
@@ -594,7 +594,7 @@ func (p *SSHProxy) proxyData(src io.Reader, dst io.Writer, direction string, ses // bufferInput accumulates input data and logs the effective command after processing edits. // It interprets control characters (backspace, Ctrl+C/U/W) so that the logged command // reflects what the user actually sent, not the raw keystrokes. -func (p *SSHProxy) bufferInput(data []byte, sessionID string, channelType session.TerminalChannelType) { +func (p *SSHProxy) bufferInput(data []byte, sessionID string, channelType session.SessionChannelType) { p.mutex.Lock() defer p.mutex.Unlock() @@ -669,18 +669,18 @@ func (p *SSHProxy) flushInputBufferUnsafe(sessionID string) { return } - event := session.TerminalEvent{ + event := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventInput, + EventType: session.SessionEventInput, ChannelType: p.inputChannelType, Data: make([]byte, len(p.inputBuffer)), } copy(event.Data, p.inputBuffer) - if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + if err := p.config.SessionLogger.LogSessionEvent(event); err != nil { log.Error().Err(err). Str("sessionID", sessionID). - Str("eventType", string(session.TerminalEventInput)). + Str("eventType", string(session.SessionEventInput)). Msg("Failed to log terminal event") } @@ -691,7 +691,7 @@ func (p *SSHProxy) flushInputBufferUnsafe(sessionID string) { // bufferOutput accumulates output data and flushes on newline or size limit. // This allows session log masking patterns to match across character-by-character echo, // because the regex sees a full line rather than individual bytes. 
-func (p *SSHProxy) bufferOutput(data []byte, sessionID string, channelType session.TerminalChannelType) { +func (p *SSHProxy) bufferOutput(data []byte, sessionID string, channelType session.SessionChannelType) { p.outputMutex.Lock() defer p.outputMutex.Unlock() @@ -720,18 +720,18 @@ func (p *SSHProxy) flushOutputBufferUnsafe(sessionID string) { return } - event := session.TerminalEvent{ + event := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventOutput, + EventType: session.SessionEventOutput, ChannelType: p.outputChannelType, Data: make([]byte, len(p.outputBuffer)), } copy(event.Data, p.outputBuffer) - if err := p.config.SessionLogger.LogTerminalEvent(event); err != nil { + if err := p.config.SessionLogger.LogSessionEvent(event); err != nil { log.Error().Err(err). Str("sessionID", sessionID). - Str("eventType", string(session.TerminalEventOutput)). + Str("eventType", string(session.SessionEventOutput)). Msg("Failed to log terminal event") } @@ -789,13 +789,13 @@ func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer, operations := sftpParser.Parse(buf[:n]) for _, op := range operations { logMsg := FormatOperation(op) + "\n" - event := session.TerminalEvent{ + event := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventInput, - ChannelType: session.TerminalChannelSFTP, + EventType: session.SessionEventInput, + ChannelType: session.SessionChannelSFTP, Data: []byte(logMsg), } - if logErr := p.config.SessionLogger.LogTerminalEvent(event); logErr != nil { + if logErr := p.config.SessionLogger.LogSessionEvent(event); logErr != nil { log.Error().Err(logErr). Str("sessionID", sessionID). Str("operation", op.Type). 
@@ -844,13 +844,13 @@ func (p *SSHProxy) proxyClientToServerWithBlocking(src io.Reader, dst io.Writer, clientWriter.Write([]byte(blockedMsg)) // Log the blocked message as output so it appears in session replay - blockedEvent := session.TerminalEvent{ + blockedEvent := session.SessionEvent{ Timestamp: time.Now(), - EventType: session.TerminalEventOutput, + EventType: session.SessionEventOutput, ChannelType: channelType, Data: []byte(blockedMsg), } - if logErr := p.config.SessionLogger.LogTerminalEvent(blockedEvent); logErr != nil { + if logErr := p.config.SessionLogger.LogSessionEvent(blockedEvent); logErr != nil { log.Error().Err(logErr).Str("sessionID", sessionID).Msg("Failed to log blocked command event") } diff --git a/packages/pam/local/rdp-proxy.go b/packages/pam/local/rdp-proxy.go index eee9d8a3..a128e763 100644 --- a/packages/pam/local/rdp-proxy.go +++ b/packages/pam/local/rdp-proxy.go @@ -19,22 +19,15 @@ import ( "github.com/rs/zerolog/log" ) -// RDPProxyServer exposes a local loopback TCP listener that tunnels bytes -// to the gateway's RDP MITM bridge via the existing mTLS + SSH relay. The -// user's RDP client connects to the loopback port; the gateway takes care -// of credential injection and forwarding to the Windows target. +// Loopback listener that tunnels RDP client traffic to the gateway's MITM bridge. type RDPProxyServer struct { BaseProxyServer server net.Listener port int - rdpFilePath string // path to the generated .rdp file, if any + rdpFilePath string } -// StartRDPLocalProxy is the CLI entry point for `infisical pam rdp access`. -// It creates a PAM session with the backend, binds a loopback listener, -// writes a .rdp file pointing at that loopback, optionally launches the -// user's default RDP client, and forwards accepted connections to the -// gateway. +// CLI entry point for `infisical pam rdp access`. 
func StartRDPLocalProxy(accessToken string, accessParams PAMAccessParams, projectID string, durationStr string, port int, noLaunch bool) { log.Info().Msgf("Starting RDP proxy for account: %s", accessParams.GetDisplayName()) log.Info().Msgf("Session duration: %s", durationStr) @@ -171,10 +164,8 @@ func (p *RDPProxyServer) gracefulShutdown() { p.shutdownOnce.Do(func() { log.Info().Msg("Starting graceful shutdown of RDP proxy...") - // Remove the .rdp file first: p.cancel() below unblocks Run(), - // which returns to main, which may exit before the rest of this - // goroutine completes. Do the cleanup that has to happen before - // anything that could let main race ahead. + // p.cancel() below can return main before this goroutine finishes; + // remove the .rdp file before risking that race. if p.rdpFilePath != "" { if err := os.Remove(p.rdpFilePath); err != nil && !os.IsNotExist(err) { log.Debug().Err(err).Str("path", p.rdpFilePath).Msg("Failed to remove .rdp file on exit") @@ -315,15 +306,8 @@ func (p *RDPProxyServer) handleConnection(clientConn net.Conn) { log.Info().Msgf("RDP connection closed for client: %s", clientConn.RemoteAddr().String()) } -// writeRDPFile creates a .rdp file pointing at the local loopback -// listener. Files live under `~/.infisical/rdp/` to match the CLI's -// existing convention for per-user state (alongside the login config -// and update-check cache). Filename includes the session ID so -// concurrent sessions don't collide. The file is removed on graceful -// shutdown (see gracefulShutdown) since the embedded loopback port -// becomes invalid as soon as the CLI exits; reopening the file later -// would just dial a dead port. -// Falls back to the OS temp dir if the home directory can't be resolved. +// Generates a per-session .rdp file under ~/.infisical/rdp/ pointing at +// the loopback listener. Removed on graceful shutdown. 
func writeRDPFile(listenPort int, sessionID, username string) (string, error) { filename := fmt.Sprintf("infisical-rdp-%s.rdp", sessionID) @@ -336,9 +320,14 @@ func writeRDPFile(listenPort int, sessionID, username string) (string, error) { } path := filepath.Join(dir, filename) + // authentication level:i:0 -> mstsc connects even if it can't verify the + // server's TLS cert. The bridge presents a self-signed cert, so without + // this mstsc terminates with "unexpected server authentication certificate". + // FreeRDP/Windows App ignore the cert by default; mstsc is the strict one. content := fmt.Sprintf( "full address:s:127.0.0.1:%d\r\n"+ - "username:s:%s\r\n", + "username:s:%s\r\n"+ + "authentication level:i:0\r\n", listenPort, username, ) diff --git a/packages/pam/pam-proxy.go b/packages/pam/pam-proxy.go index dd46900b..edf273af 100644 --- a/packages/pam/pam-proxy.go +++ b/packages/pam/pam-proxy.go @@ -107,13 +107,18 @@ func HandlePAMCancellation(ctx context.Context, conn *tls.Conn, pamConfig *Gatew // Kill the active proxy connection if it exists in the registry if cancelled := cancelSession(pamConfig.SessionId); cancelled { log.Info().Str("sessionId", pamConfig.SessionId).Msg("Active proxy session cancelled via registry") - if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { - log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") - } } else { log.Info().Str("sessionId", pamConfig.SessionId).Msg("No active proxy session found in registry (may have already ended)") } + // Always run cleanup on explicit cancellation. RDP keeps sessions alive + // across client disconnects to support .rdp-file reconnects, so when the + // CLI ctrl-C arrives the registry is already empty but the platform-side + // session is still active and needs to be terminated. 
+ if err := pamConfig.SessionUploader.CleanupPAMSession(pamConfig.SessionId, "cancellation"); err != nil { + log.Error().Err(err).Str("sessionId", pamConfig.SessionId).Msg("Failed to cleanup PAM session") + } + conn.Close() return nil @@ -459,13 +464,15 @@ func HandlePAMProxy(ctx context.Context, conn *tls.Conn, pamConfig *GatewayPAMCo return fmt.Errorf("rdp: target port %d out of range", credentials.Port) } rdpConfig := rdp.RDPProxyConfig{ - TargetHost: credentials.Host, - TargetPort: uint16(credentials.Port), - InjectUsername: credentials.Username, - InjectPassword: credentials.Password, - InjectDomain: credentials.Domain, - SessionID: pamConfig.SessionId, - SessionLogger: sessionLogger, + TargetHost: credentials.Host, + TargetPort: uint16(credentials.Port), + InjectUsername: credentials.Username, + InjectPassword: credentials.Password, + InjectDomain: credentials.Domain, + SessionID: pamConfig.SessionId, + SessionLogger: sessionLogger, + PriorElapsedNs: pamConfig.SessionUploader.GetPriorElapsedNs(pamConfig.SessionId), + SessionUploader: pamConfig.SessionUploader, } proxy := rdp.NewRDPProxy(rdpConfig) log.Info(). 
diff --git a/packages/pam/session/logger.go b/packages/pam/session/logger.go index 77c3c3e3..2b343ca8 100644 --- a/packages/pam/session/logger.go +++ b/packages/pam/session/logger.go @@ -25,29 +25,31 @@ type SessionLogEntry struct { Output string `json:"output"` } -// TerminalEventType represents the type of terminal event -type TerminalEventType string +// SessionEventType represents the type of session event +type SessionEventType string const ( - TerminalEventInput TerminalEventType = "input" // Data from user to server - TerminalEventOutput TerminalEventType = "output" // Data from server to user + SessionEventInput SessionEventType = "input" // Data from user to server + SessionEventOutput SessionEventType = "output" // Data from server to user + SessionEventRDP SessionEventType = "rdp" // RDP tap event (see SessionChannelRDP) ) -// TerminalChannelType represents the type of SSH channel -type TerminalChannelType string +// SessionChannelType represents the type of SSH channel +type SessionChannelType string const ( - TerminalChannelShell TerminalChannelType = "terminal" // Interactive shell session - TerminalChannelExec TerminalChannelType = "exec" // Single command execution - TerminalChannelSFTP TerminalChannelType = "sftp" // SFTP file transfer + SessionChannelShell SessionChannelType = "terminal" // Interactive shell session + SessionChannelExec SessionChannelType = "exec" // Single command execution + SessionChannelSFTP SessionChannelType = "sftp" // SFTP file transfer + SessionChannelRDP SessionChannelType = "rdp" // RDP frame/input tap; Data carries an RDP-specific JSON envelope ) -// TerminalEvent represents a single event in a terminal session -type TerminalEvent struct { +// SessionEvent represents a single event in a recorded session (SSH or RDP). 
+type SessionEvent struct { Timestamp time.Time `json:"timestamp"` - EventType TerminalEventType `json:"eventType"` - ChannelType TerminalChannelType `json:"channelType,omitempty"` // Type of SSH channel - Data []byte `json:"data"` // Raw terminal data + EventType SessionEventType `json:"eventType"` + ChannelType SessionChannelType `json:"channelType,omitempty"` // Channel kind (SSH shell/exec/sftp or RDP) + Data []byte `json:"data"` // SSH: raw terminal bytes; RDP: JSON envelope (base64-marshaled) ElapsedTime float64 `json:"elapsedTime"` // Seconds since session start (for replay) } @@ -73,7 +75,7 @@ const ( type SessionLogger interface { LogEntry(entry SessionLogEntry) error - LogTerminalEvent(event TerminalEvent) error + LogSessionEvent(event SessionEvent) error LogHttpEvent(event HttpEvent) error Close() error } @@ -300,12 +302,19 @@ func (sl *EncryptedSessionLogger) LogEntry(entry SessionLogEntry) error { }) } -func (sl *EncryptedSessionLogger) LogTerminalEvent(event TerminalEvent) error { +func (sl *EncryptedSessionLogger) LogSessionEvent(event SessionEvent) error { return sl.writeEvent(func() ([]byte, error) { if event.ElapsedTime == 0 { event.ElapsedTime = time.Since(sl.sessionStart).Seconds() } - event.Data = sl.applyMasking(event.Data) + // RDP carries a structured JSON envelope (with base64-encoded PDU + // bytes, scancodes, etc.) in Data, not free-form terminal text. + // Masking patterns are SSH-shaped regexes; running them over the + // envelope would corrupt valid recordings whenever a pattern + // happened to match a substring of the JSON or base64. 
+ if event.ChannelType != SessionChannelRDP { + event.Data = sl.applyMasking(event.Data) + } return json.Marshal(event) }) } diff --git a/packages/pam/session/uploader.go b/packages/pam/session/uploader.go index e0b71cc1..c204a11c 100644 --- a/packages/pam/session/uploader.go +++ b/packages/pam/session/uploader.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/Infisical/infisical-merge/packages/api" @@ -48,7 +49,10 @@ type sessionUploadState struct { legacyMode bool // true if the batch upload endpoint returned 404 (platform too old); fall back to bulk upload at session end startedAt time.Time lastEndElapsedMs int64 - mu sync.Mutex + // Advanced per-event by streaming writers so GetPriorElapsedNs sees a fresh + // anchor between flush ticks (rapid RDP reconnects within the 10s window). + lastEmittedElapsedNs atomic.Uint64 + mu sync.Mutex } type SessionUploader struct { @@ -228,8 +232,8 @@ func ReadEncryptedSessionLogByFilename(filename string, encryptionKey string) ([ return readEncryptedEntries[SessionLogEntry](filename, encryptionKey) } -func ReadEncryptedTerminalEventsFromFile(filename string, encryptionKey string) ([]TerminalEvent, error) { - return readEncryptedEntries[TerminalEvent](filename, encryptionKey) +func ReadEncryptedSessionEventsFromFile(filename string, encryptionKey string) ([]SessionEvent, error) { + return readEncryptedEntries[SessionEvent](filename, encryptionKey) } func ReadEncryptedHttpEventsFromFile(filename string, encryptionKey string) ([]HttpEvent, error) { @@ -274,27 +278,27 @@ func deletePersistedOffset(filename string) { _ = os.Remove(offsetFilePath(filename)) } -// readFromOffset reads length-prefixed encrypted records from filename starting at offset, -// decrypts each, and returns them as a JSON array payload plus the new file offset. 
-// When maxPayloadBytes > 0, stops accumulating once the next entry would push the serialized JSON array past that limit -// Returns nil payload (and the unchanged offset) if there are no new records. -func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadBytes int) ([]byte, int64, error) { +// Returns (payload JSON array, new offset, last entry's elapsedMs, err). +// lastEntryElapsedMs is 0 if entries lack the field. maxPayloadBytes>0 +// caps the JSON array size; caller loops for the rest. +func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadBytes int) ([]byte, int64, int64, error) { recordingDir := GetSessionRecordingDir() fullPath := filepath.Join(recordingDir, filename) file, err := os.Open(fullPath) if err != nil { - return nil, offset, fmt.Errorf("failed to open session file: %w", err) + return nil, offset, 0, fmt.Errorf("failed to open session file: %w", err) } defer file.Close() if _, err := file.Seek(offset, io.SeekStart); err != nil { - return nil, offset, fmt.Errorf("failed to seek to offset %d: %w", offset, err) + return nil, offset, 0, fmt.Errorf("failed to seek to offset %d: %w", offset, err) } var entries []json.RawMessage newOffset := offset runningSize := 2 // account for JSON array brackets [] + var lastEntryElapsedMs int64 for { lengthBytes := make([]byte, 4) @@ -302,7 +306,7 @@ func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadByte if err == io.EOF || err == io.ErrUnexpectedEOF { break // No more complete records } - return nil, newOffset, fmt.Errorf("failed to read length prefix: %w", err) + return nil, newOffset, 0, fmt.Errorf("failed to read length prefix: %w", err) } length := binary.BigEndian.Uint32(lengthBytes) @@ -313,7 +317,7 @@ func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadByte decryptedData, err := DecryptData(encryptedData, encryptionKey) if err != nil { - return nil, newOffset, fmt.Errorf("failed to decrypt record at offset %d: 
%w", newOffset, err) + return nil, newOffset, 0, fmt.Errorf("failed to decrypt record at offset %d: %w", newOffset, err) } entrySize := len(decryptedData) @@ -324,21 +328,66 @@ func readFromOffset(filename, encryptionKey string, offset int64, maxPayloadByte break // would exceed budget; caller will loop for the rest } + // Probe the entry's elapsedTime field. Absent on HTTP/Kubernetes events. + var probe struct { + ElapsedTime float64 `json:"elapsedTime"` + } + if jsonErr := json.Unmarshal(decryptedData, &probe); jsonErr == nil && probe.ElapsedTime > 0 { + lastEntryElapsedMs = int64(probe.ElapsedTime * 1000) + } + entries = append(entries, json.RawMessage(decryptedData)) newOffset += int64(4 + length) runningSize += entrySize } if len(entries) == 0 { - return nil, newOffset, nil + return nil, newOffset, 0, nil } payload, err := json.Marshal(entries) if err != nil { - return nil, newOffset, fmt.Errorf("failed to marshal event batch: %w", err) + return nil, newOffset, 0, fmt.Errorf("failed to marshal event batch: %w", err) + } + + return payload, newOffset, lastEntryElapsedMs, nil +} + +// GetPriorElapsedNs returns the last recorded elapsed time for this session +// in nanoseconds. On reconnects this is added to the bridge's elapsed_ns so +// timestamps stay monotonic across bridge restarts. +func (su *SessionUploader) GetPriorElapsedNs(sessionID string) uint64 { + su.activeSessionsMu.RLock() + defer su.activeSessionsMu.RUnlock() + state, ok := su.activeSessions[sessionID] + if !ok { + return 0 + } + emitted := state.lastEmittedElapsedNs.Load() + flushed := uint64(state.lastEndElapsedMs) * 1_000_000 + if emitted > flushed { + return emitted } + return flushed +} - return payload, newOffset, nil +// Monotonically advances the per-session GetPriorElapsedNs anchor; stale values are ignored. 
+func (su *SessionUploader) RecordEmittedElapsedNs(sessionID string, elapsedNs uint64) { + su.activeSessionsMu.RLock() + state, ok := su.activeSessions[sessionID] + su.activeSessionsMu.RUnlock() + if !ok { + return + } + for { + cur := state.lastEmittedElapsedNs.Load() + if elapsedNs <= cur { + return + } + if state.lastEmittedElapsedNs.CompareAndSwap(cur, elapsedNs) { + return + } + } } // RegisterSession registers a session for incremental batch uploads, resuming from @@ -360,11 +409,20 @@ func (su *SessionUploader) RegisterSession(sessionID string) { } su.activeSessionsMu.Lock() - su.activeSessions[sessionID] = &sessionUploadState{ - fileOffset: startOffset, - filename: fileInfo.Filename, - startedAt: time.Now().Add(-time.Duration(lastEndElapsedMs) * time.Millisecond), - lastEndElapsedMs: lastEndElapsedMs, + // Preserve the original anchor across RDP reconnects within the same PAM + // session: HandlePAMProxy calls RegisterSession on every gateway connection, + // and overwriting the entry would reset startedAt to ~now, making elapsedNs + // rewind on reconnect. The persisted .offset only catches up after a flush, + // so it can't be the source of truth here. + if _, exists := su.activeSessions[sessionID]; !exists { + state := &sessionUploadState{ + fileOffset: startOffset, + filename: fileInfo.Filename, + startedAt: time.Now().Add(-time.Duration(lastEndElapsedMs) * time.Millisecond), + lastEndElapsedMs: lastEndElapsedMs, + } + state.lastEmittedElapsedNs.Store(uint64(lastEndElapsedMs) * 1_000_000) + su.activeSessions[sessionID] = state } su.activeSessionsMu.Unlock() @@ -416,12 +474,8 @@ func (su *SessionUploader) startUploadRoutine() { }() } -// resumeInProgressSessions re-registers non-expired recording files into the upload loop at startup. -// A gateway restart kills all proxy connections, so any file on disk is from a session that is -// already over from the customer's perspective. 
Re-registering restores offset tracking so the -// ticker-based flush and chunk reconciliation can drive uploads to completion over subsequent ticks. -// Already-expired files are skipped here and handled exclusively by uploadExpiredSessionFiles -// to avoid duplicate back-to-back cleanup attempts on the same file at startup. +// Re-registers non-expired recording files at startup so the flush ticker +// can drain them. Expired files are handled by uploadExpiredSessionFiles. func (su *SessionUploader) resumeInProgressSessions() { allFiles, err := ListSessionFiles() if err != nil { @@ -495,10 +549,7 @@ func (su *SessionUploader) flushActiveSessions() { } } -// flushSession reads new events from the session recording file since the last uploaded offset, -// uploads them as a batch, and advances the offset on success. Returns nil when there is nothing -// to do (session not registered, already in legacy mode, no new events) or when a 404 cleanly -// transitions the session to legacy mode; the caller treats those as success. +// Uploads new events as a batch and advances the offset on success. 
func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { su.activeSessionsMu.RLock() state, ok := su.activeSessions[sessionID] @@ -519,7 +570,7 @@ func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { currentOffset := state.fileOffset for { - payload, newOffset, err := readFromOffset(state.filename, encryptionKey, currentOffset, pamRecordingMaxPlaintextBytes) + payload, newOffset, lastEntryElapsedMs, err := readFromOffset(state.filename, encryptionKey, currentOffset, pamRecordingMaxPlaintextBytes) if err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to read session events for chunk upload") break @@ -528,7 +579,14 @@ func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { break } - endElapsedMs := time.Since(state.startedAt).Milliseconds() + // Wallclock fallback only when the chunk carried no elapsedTime at all + // (HTTP/Kubernetes); otherwise it includes reconnect idle gaps. + endElapsedMs := lastEntryElapsedMs + if lastEntryElapsedMs == 0 { + endElapsedMs = time.Since(state.startedAt).Milliseconds() + } else if endElapsedMs < startElapsedMs { + endElapsedMs = startElapsedMs + } pc, encErr := su.chunkUploader.EncryptAndQueueChunk(sessionID, payload, startElapsedMs, endElapsedMs) if encErr != nil { @@ -552,7 +610,7 @@ func (su *SessionUploader) flushSession(sessionID, encryptionKey string) error { return nil } - payload, newOffset, err := readFromOffset(state.filename, encryptionKey, state.fileOffset, 0) + payload, newOffset, _, err := readFromOffset(state.filename, encryptionKey, state.fileOffset, 0) if err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Failed to read session events for batch upload") return err @@ -588,21 +646,25 @@ func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { return fmt.Errorf("failed to get encryption key: %w", err) } - if fileInfo.ResourceType == ResourceTypeSSH { - terminalEvents, err := 
ReadEncryptedTerminalEventsFromFile(fileInfo.Filename, encryptionKey) + // SSH and Windows both write SessionEvent records (SSH uses input/output/ + // resize/error; Windows uses ChannelType=rdp). Bulk-uploading either via + // the Database fallback would silently zero-fill input/output, dropping + // the entire recording. + if fileInfo.ResourceType == ResourceTypeSSH || fileInfo.ResourceType == ResourceTypeWindows { + sessionEvents, err := ReadEncryptedSessionEventsFromFile(fileInfo.Filename, encryptionKey) if err != nil { - return fmt.Errorf("failed to read SSH session file: %w", err) + return fmt.Errorf("failed to read session event file: %w", err) } log.Debug(). Str("sessionId", fileInfo.SessionID). Str("resourceType", fileInfo.ResourceType). - Int("eventCount", len(terminalEvents)). - Msg("Uploading terminal session events") + Int("eventCount", len(sessionEvents)). + Msg("Uploading session events") - var logs []api.UploadTerminalEvent - for _, event := range terminalEvents { - logs = append(logs, api.UploadTerminalEvent{ + var logs []api.UploadSessionEvent + for _, event := range sessionEvents { + logs = append(logs, api.UploadSessionEvent{ Timestamp: event.Timestamp, EventType: string(event.EventType), ChannelType: string(event.ChannelType), @@ -624,7 +686,7 @@ func (su *SessionUploader) uploadSessionFile(fileInfo *SessionFileInfo) error { Str("sessionId", fileInfo.SessionID). Str("resourceType", fileInfo.ResourceType). Int("eventCount", len(httpEvents)). - Msg("Uploading terminal session events") + Msg("Uploading Kubernetes session events") var logs []api.UploadHttpEvent for _, event := range httpEvents { @@ -701,10 +763,8 @@ func (su *SessionUploader) CleanupPAMSession(sessionID string, reason string) er su.RegisterSession(sessionID) } - // Final flush: upload any remaining events before we delete the file. 
Any failure on this path - // (key fetch, batch flush, or legacy bulk upload) returns early with the recording file, registry - // entry, and persisted offset intact so uploadExpiredSessionFiles can retry once the file crosses - // ExpiresAt. Deleting on failure would lose unuploaded events unrecoverably. + // On any failure here, return early so uploadExpiredSessionFiles can retry + // past ExpiresAt; deleting the file on failure would lose events. encryptionKey, err := su.credentialsManager.GetPAMSessionEncryptionKey() if err != nil { log.Error().Err(err).Str("sessionId", sessionID).Msg("Could not get encryption key for final flush, keeping recording file for retry") @@ -715,8 +775,7 @@ func (su *SessionUploader) CleanupPAMSession(sessionID string, reason string) er return flushErr } - // If the batch endpoint was not supported (or this session was already in legacy mode), - // fall back to a single bulk upload of the whole file. + // Legacy fallback: single bulk upload of the whole file. su.activeSessionsMu.RLock() state, stateExists := su.activeSessions[sessionID] su.activeSessionsMu.RUnlock()