diff --git a/crates/agent-tunnel/Cargo.toml b/crates/agent-tunnel/Cargo.toml
index f72af2cfc..eba3e2e82 100644
--- a/crates/agent-tunnel/Cargo.toml
+++ b/crates/agent-tunnel/Cargo.toml
@@ -8,6 +8,12 @@ publish = false
 [lints]
 workspace = true
 
+[features]
+# Exposes `_for_test` helpers (e.g. `AgentPeer::set_last_seen_for_test`) so
+# integration tests in other crates can force specific peer states without
+# wall-clock sleeps. Production builds must not enable this.
+test-utils = []
+
 [dependencies]
 # Internal crates
 agent-tunnel-proto = { path = "../agent-tunnel-proto", features = ["serde"] }
diff --git a/crates/agent-tunnel/src/lib.rs b/crates/agent-tunnel/src/lib.rs
index c5aaaed91..2c5f98795 100644
--- a/crates/agent-tunnel/src/lib.rs
+++ b/crates/agent-tunnel/src/lib.rs
@@ -10,6 +10,7 @@ pub mod cert;
 pub mod enrollment_store;
 pub mod listener;
 pub mod registry;
+pub mod routing;
 pub mod stream;
 
 pub use enrollment_store::EnrollmentTokenStore;
diff --git a/crates/agent-tunnel/src/listener.rs b/crates/agent-tunnel/src/listener.rs
index 530e9ff26..9a96872ae 100644
--- a/crates/agent-tunnel/src/listener.rs
+++ b/crates/agent-tunnel/src/listener.rs
@@ -158,7 +158,7 @@ impl AgentTunnelListener {
         let handle = AgentTunnelHandle {
             registry: Arc::clone(&registry),
             agent_connections: Arc::clone(&agent_connections),
-            ca_manager,
+            ca_manager: Arc::clone(&ca_manager),
             enrollment_token_store,
         };
@@ -170,6 +170,11 @@
 
         Ok((listener, handle))
     }
+
+    /// Returns the local address the QUIC endpoint is bound to.
+ pub fn local_addr(&self) -> SocketAddr { + self.endpoint.local_addr().expect("endpoint has local addr") + } } #[async_trait] @@ -202,9 +207,7 @@ impl devolutions_gateway_task::Task for AgentTunnelListener { let registry = Arc::clone(&self.registry); let agent_connections = Arc::clone(&self.agent_connections); - conn_handles.spawn( - run_agent_connection(registry, agent_connections, incoming), - ); + conn_handles.spawn(run_agent_connection(registry, agent_connections, incoming)); } // Reap completed connection tasks to prevent unbounded growth. @@ -253,7 +256,7 @@ async fn run_agent_connection( info!(%agent_id, %agent_name, %peer_addr, "Agent authenticated via mTLS"); - let peer = Arc::new(AgentPeer::new(agent_id, agent_name, fingerprint)); + let peer = Arc::new(AgentPeer::new(agent_id, agent_name.clone(), fingerprint)); registry.register(Arc::clone(&peer)).await; agent_connections.write().await.insert(agent_id, conn.clone()); diff --git a/crates/agent-tunnel/src/registry.rs b/crates/agent-tunnel/src/registry.rs index 3026acd20..b28fe0118 100644 --- a/crates/agent-tunnel/src/registry.rs +++ b/crates/agent-tunnel/src/registry.rs @@ -10,6 +10,8 @@ use serde::Serialize; use tokio::sync::RwLock as TokioRwLock; use uuid::Uuid; +use crate::routing::RouteTarget; + /// Duration after which an agent is considered offline if no heartbeat has been received. pub const AGENT_OFFLINE_TIMEOUT: Duration = Duration::from_secs(90); @@ -32,6 +34,33 @@ pub struct RouteAdvertisementState { pub updated_at: SystemTime, } +impl RouteAdvertisementState { + /// Match this route set against a parsed target host. + /// + /// Returns a specificity score if matched, or `None` if no match. + /// IP subnet matches return `usize::MAX` (always highest priority). + /// Domain matches return the matched domain length (longer = more specific). 
+    pub fn matches_target(&self, target: &RouteTarget) -> Option<usize> {
+        use std::net::IpAddr;
+
+        match target {
+            // Only IPv4 subnets are currently tracked; only match IPv4 target IPs.
+            RouteTarget::Ip(IpAddr::V4(ipv4)) => self
+                .subnets
+                .iter()
+                .any(|subnet| subnet.contains(*ipv4))
+                .then_some(usize::MAX),
+            RouteTarget::Ip(IpAddr::V6(_)) => None,
+            RouteTarget::Hostname(hostname) => self
+                .domains
+                .iter()
+                .filter(|adv| adv.domain.matches_hostname(hostname.as_str()))
+                .map(|adv| adv.domain.as_str().len())
+                .max(),
+        }
+    }
+}
+
 impl Default for RouteAdvertisementState {
     fn default() -> Self {
         let now = SystemTime::now();
@@ -79,6 +108,37 @@ impl AgentPeer {
         self.last_seen.store(now_ms, Ordering::Release);
     }
 
+    /// Set `last_seen` to an explicit timestamp (milliseconds since UNIX epoch).
+    ///
+    /// Test-only API — the `_for_test` suffix is the project's signal that
+    /// production code must not call this. Used by integration tests in other
+    /// crates (e.g. the workspace `testsuite`) to force an agent into the
+    /// "offline" state without waiting for the real timeout to elapse;
+    /// production code should use [`touch`](Self::touch) instead. Gated behind
+    /// `test-utils` (and `cfg(test)` for this crate's own unit tests) so
+    /// production builds cannot link against it; cross-crate consumers must
+    /// opt in via `features = ["test-utils"]` on their `agent-tunnel`
+    /// dev-dependency.
+    #[cfg(any(test, feature = "test-utils"))]
+    #[doc(hidden)]
+    pub fn set_last_seen_for_test(&self, last_seen_ms: u64) {
+        self.last_seen.store(last_seen_ms, Ordering::Release);
+    }
+
+    /// Overwrite `received_at` on the current route state.
+    ///
+    /// Test-only API. Intended for tests that need to assert ordering by
+    /// arrival time without relying on wall-clock `thread::sleep` — which is
+    /// flaky on platforms with coarse timer resolution (e.g. Windows ~16 ms).
+    /// See [`set_last_seen_for_test`](Self::set_last_seen_for_test) for the
+    /// gating rationale.
+    #[cfg(any(test, feature = "test-utils"))]
+    #[doc(hidden)]
+    pub fn set_received_at_for_test(&self, received_at: SystemTime) {
+        let mut state = self.route_state.write();
+        state.received_at = received_at;
+    }
+
     /// Returns the last-seen timestamp as milliseconds since UNIX epoch.
     pub fn last_seen_ms(&self) -> u64 {
         self.last_seen.load(Ordering::Acquire)
     }
@@ -223,6 +283,39 @@ impl AgentRegistry {
     pub async fn agent_infos(&self) -> Vec<AgentInfo> {
         self.agents.read().await.values().map(AgentInfo::from).collect()
     }
+
+    /// Find all online agents that can route to the given parsed target host.
+    ///
+    /// For IP targets: matches against advertised subnets.
+    /// For domain targets: uses longest suffix match (more specific domain wins).
+    ///
+    /// Results with equal specificity are sorted by `received_at` descending (most recent first).
+    pub async fn find_agents_for(&self, target: &RouteTarget) -> Vec<Arc<AgentPeer>> {
+        let mut best_specificity: usize = 0;
+        let mut candidates: Vec<(SystemTime, Arc<AgentPeer>)> = Vec::new();
+
+        let agents = self.agents.read().await;
+        for agent in agents.values() {
+            if !agent.is_online(AGENT_OFFLINE_TIMEOUT) {
+                continue;
+            }
+
+            let route_state = agent.route_state();
+
+            if let Some(specificity) = route_state.matches_target(target) {
+                if specificity > best_specificity {
+                    best_specificity = specificity;
+                    candidates.clear();
+                    candidates.push((route_state.received_at, Arc::clone(agent)));
+                } else if specificity == best_specificity {
+                    candidates.push((route_state.received_at, Arc::clone(agent)));
+                }
+            }
+        }
+
+        candidates.sort_by(|a, b| b.0.cmp(&a.0));
+        candidates.into_iter().map(|(_, agent)| agent).collect()
+    }
 }
 
 impl Default for AgentRegistry {
diff --git a/crates/agent-tunnel/src/routing.rs b/crates/agent-tunnel/src/routing.rs
new file mode 100644
index 000000000..b205fd2ae
--- /dev/null
+++ b/crates/agent-tunnel/src/routing.rs
@@ -0,0 +1,192 @@
+//! Shared routing pipeline for agent tunnel.
+//!
+//! Consumed by the upstream connection paths (forwarding, RDP clean path,
+//! generic client) to ensure consistent routing behavior and error messages.
+
+use std::net::IpAddr;
+use std::sync::Arc;
+
+use agent_tunnel_proto::DomainName;
+use anyhow::{Result, anyhow};
+use uuid::Uuid;
+
+use super::listener::AgentTunnelHandle;
+use super::registry::{AgentPeer, AgentRegistry};
+use super::stream::TunnelStream;
+
+/// A parsed target host used for route matching.
+///
+/// Routing cares only about the host identity, not the port or scheme used by
+/// the eventual connection attempt.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum RouteTarget {
+    Ip(IpAddr),
+    Hostname(DomainName),
+}
+
+impl RouteTarget {
+    pub fn ip(ip: IpAddr) -> Self {
+        Self::Ip(ip)
+    }
+
+    pub fn hostname(hostname: impl Into<String>) -> Self {
+        Self::Hostname(DomainName::new(hostname))
+    }
+}
+
+impl std::fmt::Display for RouteTarget {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Ip(ip) => ip.fmt(f),
+            Self::Hostname(hostname) => hostname.fmt(f),
+        }
+    }
+}
+
+/// Result of the routing pipeline.
+///
+/// Each variant carries enough context for the caller to produce an actionable error message.
+#[derive(Debug)]
+pub enum RoutingDecision {
+    /// Route through these agent candidates (try in order, first success wins).
+    ViaAgent(Vec<Arc<AgentPeer>>),
+    /// Explicit agent_id was specified but not found in registry.
+    ExplicitAgentNotFound(Uuid),
+    /// No agent matched — caller should attempt direct connection.
+    Direct,
+}
+
+/// Determines how to route a connection to the given target.
+///
+/// Pipeline (in order of priority):
+/// 1. Explicit agent_id (from JWT) → route to that agent
+/// 2. Target match (IP subnet or domain suffix) → best match wins
+/// 3. No match → direct connection
+pub async fn resolve_route(
+    registry: &AgentRegistry,
+    explicit_agent_id: Option<Uuid>,
+    target: &RouteTarget,
+) -> RoutingDecision {
+    // Step 1: Explicit agent ID (from JWT)
+    if let Some(id) = explicit_agent_id {
+        return match registry.get(&id).await {
+            Some(agent) => RoutingDecision::ViaAgent(vec![agent]),
+            None => RoutingDecision::ExplicitAgentNotFound(id),
+        };
+    }
+
+    // Step 2: Match target against all agents (IP subnet or domain suffix)
+    let agents = registry.find_agents_for(target).await;
+
+    if agents.is_empty() {
+        RoutingDecision::Direct
+    } else {
+        RoutingDecision::ViaAgent(agents)
+    }
+}
+
+/// Attempt to route a connection via the agent tunnel.
+///
+/// Returns `Ok(Some(stream))` if routed through an agent, `Ok(None)` if the caller
+/// should fall through to direct connect, or `Err` if an explicit agent was specified
+/// but not found (or all candidates failed).
+pub async fn try_route(
+    handle: Option<&AgentTunnelHandle>,
+    explicit_agent_id: Option<Uuid>,
+    target: &RouteTarget,
+    session_id: Uuid,
+    target_addr: &str,
+) -> Result<Option<(TunnelStream, Arc<AgentPeer>)>> {
+    let Some(handle) = handle else {
+        // An explicit `jet_agent_id` claim means the token requires routing via that
+        // specific agent; silently falling back to a direct connect would bypass the
+        // intended network boundary. Reject instead.
+        return match explicit_agent_id {
+            Some(id) => Err(anyhow!(
+                "agent {id} specified in token requires agent tunnel routing, but no tunnel handle is configured"
+            )),
+            None => Ok(None),
+        };
+    };
+
+    match resolve_route(handle.registry(), explicit_agent_id, target).await {
+        RoutingDecision::ExplicitAgentNotFound(id) => {
+            Err(anyhow!("agent {id} specified in token not found in registry"))
+        }
+        RoutingDecision::Direct => Ok(None),
+        RoutingDecision::ViaAgent(candidates) => {
+            let result = route_and_connect(handle, &candidates, session_id, target_addr).await?;
+            Ok(Some(result))
+        }
+    }
+}
+
+/// Try connecting to target through agent candidates (try-fail-retry).
+///
+/// Returns the connected `TunnelStream` and the agent that succeeded.
+///
+/// Callers must handle `RoutingDecision::ExplicitAgentNotFound` and
+/// `RoutingDecision::Direct` before calling this function.
+pub async fn route_and_connect(
+    handle: &AgentTunnelHandle,
+    candidates: &[Arc<AgentPeer>],
+    session_id: Uuid,
+    target: &str,
+) -> Result<(TunnelStream, Arc<AgentPeer>)> {
+    if candidates.is_empty() {
+        return Err(anyhow!("route_and_connect called with empty candidates"));
+    }
+
+    let mut last_error = None;
+
+    for agent in candidates {
+        info!(
+            agent_id = %agent.agent_id,
+            agent_name = %agent.name,
+            %target,
+            "Routing via agent tunnel"
+        );
+
+        match handle.connect_via_agent(agent.agent_id, session_id, target).await {
+            Ok(stream) => {
+                info!(
+                    agent_id = %agent.agent_id,
+                    agent_name = %agent.name,
+                    %target,
+                    "Agent tunnel connection established"
+                );
+                return Ok((stream, Arc::clone(agent)));
+            }
+            Err(error) => {
+                warn!(
+                    agent_id = %agent.agent_id,
+                    agent_name = %agent.name,
+                    %target,
+                    error = format!("{error:#}"),
+                    "Agent tunnel connection failed, trying next candidate"
+                );
+                last_error = Some(error);
+            }
+        }
+    }
+
+    let agent_names: Vec<&str> = candidates.iter().map(|a| a.name.as_str()).collect();
+    let last_err_msg = last_error.as_ref().map(|e| format!("{e:#}")).unwrap_or_default();
+
+    error!(
+        agent_count = candidates.len(),
+        %target,
+        agents = ?agent_names,
+        last_error = %last_err_msg,
+        "All agent tunnel candidates failed"
+    );
+
+    Err(last_error.unwrap_or_else(|| {
+        anyhow!(
+            "All {} agents matching target '{}' failed to connect. Agents tried: [{}]",
+            candidates.len(),
+            target,
+            agent_names.join(", "),
+        )
+    }))
+}
diff --git a/devolutions-agent/src/tunnel_helpers.rs b/devolutions-agent/src/tunnel_helpers.rs
index fdb54198c..fad94c88b 100644
--- a/devolutions-agent/src/tunnel_helpers.rs
+++ b/devolutions-agent/src/tunnel_helpers.rs
@@ -1,4 +1,4 @@
-use std::net::{IpAddr, Ipv4Addr, SocketAddr};
+use std::net::{IpAddr, SocketAddr};
 
 use anyhow::{Context as _, bail};
 use ipnetwork::Ipv4Network;
@@ -7,19 +7,16 @@ use tokio::net::TcpStream;
 /// Parsed connection target — either a raw IP or a domain name.
 #[derive(Debug)]
 pub(crate) enum Target {
-    Ip(Ipv4Addr, u16),
+    Ip(IpAddr, u16),
     Domain(String, u16),
 }
 
 impl Target {
     /// Parse a `host:port` string into a typed target.
     pub(crate) fn parse(target: &str) -> anyhow::Result<Self> {
-        // Try IP:port first.
+        // Try IP:port first (handles both IPv4 and IPv6).
         if let Ok(addr) = target.parse::<SocketAddr>() {
-            return match addr.ip() {
-                IpAddr::V4(ip) => Ok(Self::Ip(ip, addr.port())),
-                IpAddr::V6(_) => bail!("IPv6 targets are not supported: {target}"),
-            };
+            return Ok(Self::Ip(addr.ip(), addr.port()));
         }
 
         // Otherwise it's domain:port — split on last ':'.
@@ -35,26 +32,33 @@ impl Target {
 }
 
 /// Resolve a target to candidate socket addresses within the advertised subnets.
+///
+/// Only IPv4 subnets are supported right now, matching the wire protocol. IPv6 targets
+/// never match.
 pub(crate) async fn resolve_target(
     target: &Target,
     advertise_subnets: &[Ipv4Network],
 ) -> anyhow::Result<Vec<SocketAddr>> {
+    fn matches(advertise_subnets: &[Ipv4Network], ip: IpAddr) -> bool {
+        match ip {
+            IpAddr::V4(ipv4) => advertise_subnets.iter().any(|subnet| subnet.contains(ipv4)),
+            IpAddr::V6(_) => false,
+        }
+    }
+
     match target {
         Target::Ip(ip, port) => {
-            if !advertise_subnets.iter().any(|subnet| subnet.contains(*ip)) {
+            if !matches(advertise_subnets, *ip) {
                 bail!("target {ip}:{port} is not in advertised subnets");
             }
-            Ok(vec![SocketAddr::new(IpAddr::V4(*ip), *port)])
+            Ok(vec![SocketAddr::new(*ip, *port)])
         }
         Target::Domain(host, port) => {
             let lookup = format!("{host}:{port}");
             let resolved: Vec<SocketAddr> = tokio::net::lookup_host(&lookup)
                 .await
                 .with_context(|| format!("resolve target {lookup}"))?
-                .filter(|addr| match addr.ip() {
-                    IpAddr::V4(ipv4) => advertise_subnets.iter().any(|subnet| subnet.contains(ipv4)),
-                    IpAddr::V6(_) => false,
-                })
+                .filter(|addr| matches(advertise_subnets, addr.ip()))
                 .collect();
 
             if resolved.is_empty() {
diff --git a/devolutions-gateway/src/api/fwd.rs b/devolutions-gateway/src/api/fwd.rs
index f0b1701d6..b2099519a 100644
--- a/devolutions-gateway/src/api/fwd.rs
+++ b/devolutions-gateway/src/api/fwd.rs
@@ -15,6 +15,7 @@ use tracing::{Instrument as _, field};
 use typed_builder::TypedBuilder;
 use uuid::Uuid;
 
+use crate::DgwState;
 use crate::config::Conf;
 use crate::extract::{AssociationToken, BridgeToken};
 use crate::http::HttpError;
@@ -22,7 +23,7 @@ use crate::proxy::Proxy;
 use crate::session::{ConnectionModeDetails, DisconnectInterest, SessionInfo, SessionMessageSender};
 use crate::subscriber::SubscriberSender;
 use crate::token::{ApplicationProtocol, AssociationTokenClaims, ConnectionMode, Protocol, RecordingPolicy};
-use crate::{DgwState, utils};
+use crate::upstream::{self, PreparedUpstream, UpstreamMode};
 
 pub fn make_router(state: DgwState) -> Router {
     use axum::routing::{self, MethodFilter, get};
@@ -54,6 +55,7 @@ async fn fwd_tcp(
sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -78,6 +80,7 @@ async fn fwd_tcp( claims, source_addr, false, + agent_tunnel_handle, ) .instrument(span) }); @@ -91,6 +94,7 @@ async fn fwd_tls( sessions, subscriber_tx, shutdown_signal, + agent_tunnel_handle, .. }): State, AssociationToken(claims): AssociationToken, @@ -115,6 +119,7 @@ async fn fwd_tls( claims, source_addr, true, + agent_tunnel_handle, ) .instrument(span) }); @@ -132,6 +137,7 @@ async fn handle_fwd( claims: AssociationTokenClaims, source_addr: SocketAddr, with_tls: bool, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -153,7 +159,8 @@ async fn handle_fwd( .claims(claims) .sessions(sessions) .subscriber_tx(subscriber_tx) - .with_tls(with_tls) + .mode(if with_tls { UpstreamMode::Tls } else { UpstreamMode::Tcp }) + .agent_tunnel_handle(agent_tunnel_handle) .build() .run() .instrument(span.clone()) @@ -183,7 +190,9 @@ struct Forward { client_addr: SocketAddr, sessions: SessionMessageSender, subscriber_tx: SubscriberSender, - with_tls: bool, + mode: UpstreamMode, + #[builder(default)] + agent_tunnel_handle: Option>, } #[derive(Debug, thiserror::Error)] @@ -206,33 +215,17 @@ where client_addr, sessions, subscriber_tx, - with_tls, + mode, + agent_tunnel_handle, } = self; - match claims.jet_rec { - RecordingPolicy::None | RecordingPolicy::Stream => (), - RecordingPolicy::Proxy => { - return Err(ForwardError::Internal(anyhow::anyhow!( - "recording policy not supported" - ))); - } - } + validate_forward_request(&claims)?; - let ConnectionMode::Fwd { targets, .. } = claims.jet_cm else { - return Err(ForwardError::Internal(anyhow::anyhow!("connection mode not supported"))); + let targets = match &claims.jet_cm { + ConnectionMode::Fwd { targets, .. 
} => targets, + _ => unreachable!("validated connection mode"), }; - let span = tracing::Span::current(); - - trace!("Select and connect to target"); - - let ((server_stream, server_addr), selected_target) = utils::successive_try(&targets, utils::tcp_connect) - .await - .map_err(ForwardError::BadGateway)?; - - trace!(%selected_target, "Connected"); - span.record("target", selected_target.to_string()); - // ARD uses MVS codec which doesn't like buffering. let buffer_size = if claims.jet_ap == ApplicationProtocol::Known(Protocol::Ard) { Some(1024) @@ -240,79 +233,81 @@ where None }; - if with_tls { - trace!("Establishing TLS connection with server"); + let connected = upstream::connect_upstream( + targets, + claims.jet_agent_id, + claims.jet_aid, + agent_tunnel_handle.as_deref(), + ) + .await + .map_err(ForwardError::BadGateway)?; + + let PreparedUpstream { + session, + server_addr, + selected_target, + } = upstream::prepare_upstream(connected, mode, claims.cert_thumb256) + .await + .map_err(ForwardError::BadGateway)?; - // Establish TLS connection with server. 
- let server_stream = - crate::tls::safe_connect(selected_target.host().to_owned(), server_stream, claims.cert_thumb256) - .await - .context("TLS connect") - .map_err(ForwardError::BadGateway)?; - - info!("WebSocket-TLS forwarding"); - - let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) - .details(ConnectionModeDetails::Fwd { - destination_host: selected_target.clone(), - }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) - .filtering_policy(claims.jet_flt) - .build(); - - Proxy::builder() - .conf(conf) - .session_info(info) - .address_a(client_addr) - .transport_a(client_stream) - .address_b(server_addr) - .transport_b(server_stream) - .sessions(sessions) - .subscriber_tx(subscriber_tx) - .buffer_size(buffer_size) - .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) - .build() - .select_dissector_and_forward() - .await - .context("encountered a failure during plain tls traffic proxying") - .map_err(ForwardError::Internal) - } else { - info!("WebSocket-TCP forwarding"); - - let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) - .details(ConnectionModeDetails::Fwd { - destination_host: selected_target.clone(), - }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) - .filtering_policy(claims.jet_flt) - .build(); - - Proxy::builder() - .conf(conf) - .session_info(info) - .address_a(client_addr) - .transport_a(client_stream) - .address_b(server_addr) - .transport_b(server_stream) - .sessions(sessions) - .subscriber_tx(subscriber_tx) - .buffer_size(buffer_size) - .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) - .build() - .select_dissector_and_forward() - .await - .context("encountered a failure during plain tcp traffic proxying") - .map_err(ForwardError::Internal) + tracing::Span::current().record("target", selected_target.to_string()); + + let info = SessionInfo::builder() + 
.id(claims.jet_aid) + .application_protocol(claims.jet_ap) + .details(ConnectionModeDetails::Fwd { + destination_host: selected_target, + }) + .time_to_live(claims.jet_ttl) + .recording_policy(claims.jet_rec) + .filtering_policy(claims.jet_flt) + .build(); + + // Keep the per-mode message shape that pre-refactor logs and + // integration tests grep for ("WebSocket-TCP forwarding" / + // "WebSocket-TLS forwarding"); structured `mode` field is for + // newer telemetry. + match mode { + UpstreamMode::Tcp => info!(mode = "tcp", "WebSocket-TCP forwarding"), + UpstreamMode::Tls => info!(mode = "tls", "WebSocket-TLS forwarding"), } + + Proxy::builder() + .conf(conf) + .session_info(info) + .address_a(client_addr) + .transport_a(client_stream) + .address_b(server_addr) + .transport_b(session) + .sessions(sessions) + .subscriber_tx(subscriber_tx) + .buffer_size(buffer_size) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .build() + .select_dissector_and_forward() + .await + .context("forward websocket traffic") + .map_err(ForwardError::Internal) } } +fn validate_forward_request(claims: &AssociationTokenClaims) -> Result<(), ForwardError> { + match claims.jet_rec { + RecordingPolicy::None | RecordingPolicy::Stream => {} + RecordingPolicy::Proxy => { + return Err(ForwardError::Internal(anyhow::anyhow!( + "recording policy not supported" + ))); + } + } + + if !matches!(claims.jet_cm, ConnectionMode::Fwd { .. }) { + return Err(ForwardError::Internal(anyhow::anyhow!("connection mode not supported"))); + } + + Ok(()) +} + async fn fwd_http( State(state): State, BridgeToken(claims): BridgeToken, diff --git a/devolutions-gateway/src/api/rdp.rs b/devolutions-gateway/src/api/rdp.rs index 65cfe5b2e..6129a776c 100644 --- a/devolutions-gateway/src/api/rdp.rs +++ b/devolutions-gateway/src/api/rdp.rs @@ -26,6 +26,7 @@ pub async fn handler( recordings, shutdown_signal, credential_store, + agent_tunnel_handle, .. 
}): State, ConnectInfo(source_addr): ConnectInfo, @@ -46,6 +47,7 @@ pub async fn handler( recordings.active_recordings, source_addr, credential_store, + agent_tunnel_handle, ) .instrument(span) }); @@ -65,6 +67,7 @@ async fn handle_socket( active_recordings: Arc, source_addr: SocketAddr, credential_store: crate::credential::CredentialStoreHandle, + agent_tunnel_handle: Option>, ) { let (stream, close_handle) = crate::ws::handle( ws, @@ -82,6 +85,7 @@ async fn handle_socket( subscriber_tx, &active_recordings, &credential_store, + agent_tunnel_handle, ) .await; diff --git a/devolutions-gateway/src/generic_client.rs b/devolutions-gateway/src/generic_client.rs index 4209bea57..7b5e6c47b 100644 --- a/devolutions-gateway/src/generic_client.rs +++ b/devolutions-gateway/src/generic_client.rs @@ -15,7 +15,7 @@ use crate::recording::ActiveRecordings; use crate::session::{ConnectionModeDetails, DisconnectInterest, SessionInfo, SessionMessageSender}; use crate::subscriber::SubscriberSender; use crate::token::{self, ConnectionMode, CurrentJrl, RecordingPolicy, TokenCache}; -use crate::utils; +use crate::upstream::{self, ConnectedUpstream}; #[derive(TypedBuilder)] pub struct GenericClient { @@ -113,82 +113,19 @@ where RecordingPolicy::Proxy => anyhow::bail!("can't meet recording policy"), } - // Route via agent tunnel if jet_agent_id is specified. 
- if let Some(agent_id) = claims.jet_agent_id { - let handle = agent_tunnel_handle.context("agent tunnel not configured on this gateway")?; + let ConnectedUpstream { + leg: mut server_stream, + server_addr, + selected_target, + } = upstream::connect_upstream( + &targets, + claims.jet_agent_id, + claims.jet_aid, + agent_tunnel_handle.as_deref(), + ) + .await + .context("connect to upstream")?; - let mut selected_target = None; - let mut server_stream = None; - let mut last_error = None; - - for candidate in targets.iter() { - let target_str = format!("{}:{}", candidate.host(), candidate.port()); - - info!(%agent_id, %target_str, "Routing via agent tunnel"); - - match handle.connect_via_agent(agent_id, claims.jet_aid, &target_str).await { - Ok(stream) => { - selected_target = Some(candidate.clone()); - server_stream = Some(stream); - break; - } - Err(error) => { - warn!( - %agent_id, - %target_str, - error = format!("{error:#}"), - "Agent tunnel target failed" - ); - last_error = Some(error); - } - } - } - - let selected_target = selected_target.ok_or_else(|| { - last_error.unwrap_or_else(|| anyhow::anyhow!("agent tunnel target selection failed")) - })?; - span.record("target", selected_target.to_string()); - let server_stream = server_stream.expect("server stream should be present when target is selected"); - - let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) - .details(ConnectionModeDetails::Fwd { - destination_host: selected_target.clone(), - }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) - .filtering_policy(claims.jet_flt) - .build(); - - let disconnect_interest = DisconnectInterest::from_reconnection_policy(claims.jet_reuse); - - // Agent handles the TCP connection; no leftover bytes to forward. - // Use a placeholder server address since the actual target is behind the agent. 
- let server_addr: SocketAddr = "0.0.0.0:0".parse().expect("valid placeholder"); - - return Proxy::builder() - .conf(conf) - .session_info(info) - .address_a(client_addr) - .transport_a(client_stream) - .address_b(server_addr) - .transport_b(server_stream) - .sessions(sessions) - .subscriber_tx(subscriber_tx) - .disconnect_interest(disconnect_interest) - .build() - .select_dissector_and_forward() - .await - .context("encountered a failure during agent tunnel traffic proxying"); - } - - trace!("Select and connect to target"); - - let ((mut server_stream, server_addr), selected_target) = - utils::successive_try(&targets, utils::tcp_connect).await?; - - trace!(%selected_target, "Connected"); span.record("target", selected_target.to_string()); let is_rdp = claims.jet_ap == token::ApplicationProtocol::Known(token::Protocol::Rdp); @@ -209,6 +146,9 @@ where // We support proxy-based credential injection for RDP. // If a credential mapping has been pushed, we automatically switch to this mode. // Otherwise, we continue the generic procedure. + // + // RdpProxy is generic over the server stream, so credential injection now works + // regardless of whether the upstream is direct TCP or tunnelled via an agent. 
if is_rdp { let token_id = token::extract_jti(token).context("failed to extract jti claim from token")?; @@ -238,7 +178,7 @@ where } } - info!("TCP forwarding"); + info!("Upstream forwarding"); server_stream .write_buf(&mut leftover_bytes) @@ -258,7 +198,7 @@ where .build() .select_dissector_and_forward() .await - .context("encountered a failure during plain tcp traffic proxying") + .context("encountered a failure during upstream traffic proxying") } } } diff --git a/devolutions-gateway/src/lib.rs b/devolutions-gateway/src/lib.rs index 4ac5d4729..3af9f7999 100644 --- a/devolutions-gateway/src/lib.rs +++ b/devolutions-gateway/src/lib.rs @@ -40,6 +40,7 @@ pub mod target_addr; pub mod tls; pub mod token; pub mod traffic_audit; +pub mod upstream; pub mod utils; pub mod ws; diff --git a/devolutions-gateway/src/proxy.rs b/devolutions-gateway/src/proxy.rs index 0c2f09e6a..fb5d1b4c6 100644 --- a/devolutions-gateway/src/proxy.rs +++ b/devolutions-gateway/src/proxy.rs @@ -32,8 +32,8 @@ pub struct Proxy { impl Proxy where - A: AsyncWrite + AsyncRead + Unpin, - B: AsyncWrite + AsyncRead + Unpin, + A: AsyncWrite + AsyncRead + Unpin + Send, + B: AsyncWrite + AsyncRead + Unpin + Send, { pub async fn select_dissector_and_forward(self) -> anyhow::Result<()> { match self.session_info.application_protocol { diff --git a/devolutions-gateway/src/rd_clean_path.rs b/devolutions-gateway/src/rd_clean_path.rs index 6d4614b5e..41a118a02 100644 --- a/devolutions-gateway/src/rd_clean_path.rs +++ b/devolutions-gateway/src/rd_clean_path.rs @@ -20,6 +20,7 @@ use crate::session::{ConnectionModeDetails, DisconnectInterest, DisconnectedInfo use crate::subscriber::SubscriberSender; use crate::target_addr::TargetAddr; use crate::token::{AssociationTokenClaims, CurrentJrl, TokenCache, TokenError}; +use crate::upstream::{self, ConnectedUpstream, UpstreamLeg}; #[derive(Debug, Error)] enum AuthorizationError { @@ -158,25 +159,24 @@ enum CleanPathError { Io(#[from] io::Error), } -struct CleanPathResult { 
+// Upstream transport (TCP or agent-tunnel) comes from `crate::upstream::UpstreamLeg`, +// which is also used by fwd.rs and generic_client.rs. + +struct CleanPathAuth { claims: AssociationTokenClaims, - destination: TargetAddr, - server_addr: SocketAddr, - server_stream: tokio_rustls::client::TlsStream, - x224_rsp: Vec, } -async fn process_cleanpath( - cleanpath_pdu: RDCleanPathPdu, +/// Validate the RDCleanPath PDU token and authorize the session. +/// Pure validation — no connections established. +async fn authorize_cleanpath( + cleanpath_pdu: &RDCleanPathPdu, client_addr: SocketAddr, conf: &Conf, token_cache: &TokenCache, jrl: &CurrentJrl, active_recordings: &ActiveRecordings, sessions: &SessionMessageSender, -) -> Result { - use crate::utils; - +) -> Result { let token = cleanpath_pdu .proxy_auth .as_deref() @@ -207,10 +207,9 @@ async fn process_cleanpath( }; let span = tracing::Span::current(); - span.record("session_id", claims.jet_aid.to_string()); - // Sanity check. + // Sanity check destination in PDU vs token. match cleanpath_pdu.destination.as_deref() { Some(destination) => match TargetAddr::parse(destination, 3389) { Ok(destination) if !destination.eq(targets.first()) => { @@ -224,14 +223,50 @@ async fn process_cleanpath( None => warn!("RDCleanPath PDU is missing the destination field"), } + Ok(CleanPathAuth { claims }) +} + +struct ConnectedRdpServer { + tls_stream: tokio_rustls::client::TlsStream, + server_addr: SocketAddr, + selected_target: TargetAddr, + x224_rsp: Vec, +} + +/// Establish a connection to the RDP server: route (agent/direct) → connect → X224 → TLS. +/// +/// The routing pipeline (explicit agent → subnet/domain match → direct) is shared with +/// the WebSocket forwarders in [`crate::upstream`]; here we just do the RDP-specific +/// PCB + X224 + TLS upgrade on top of whatever leg that returns. 
+async fn connect_rdp_server( + claims: &AssociationTokenClaims, + cleanpath_pdu: RDCleanPathPdu, + agent_tunnel_handle: Option<&Arc>, +) -> Result { + let crate::token::ConnectionMode::Fwd { ref targets, .. } = claims.jet_cm else { + return anyhow::Error::msg("unexpected connection mode") + .pipe(CleanPathError::BadRequest) + .pipe(Err); + }; + trace!(?targets, "Connecting to destination server"); - let ((mut server_stream, server_addr), selected_target) = utils::successive_try(targets, utils::tcp_connect) - .await - .context("connect to RDP server")?; + let ConnectedUpstream { + leg: mut server_stream, + server_addr, + selected_target, + } = upstream::connect_upstream( + targets, + claims.jet_agent_id, + claims.jet_aid, + agent_tunnel_handle.map(AsRef::as_ref), + ) + .await + .context("connect to RDP server") + .map_err(CleanPathError::Internal)?; debug!(%selected_target, "Connected to destination server"); - span.record("target", selected_target.to_string()); + tracing::Span::current().record("target", selected_target.to_string()); // Send preconnection blob if applicable. 
if let Some(pcb) = cleanpath_pdu.preconnection_blob { @@ -245,8 +280,6 @@ async fn process_cleanpath( .map_err(CleanPathError::BadRequest)?; server_stream.write_all(x224_req.as_bytes()).await?; - // == Receive server X224 connection response == - trace!("Receiving X224 response"); let x224_rsp = read_x224_response(&mut server_stream) @@ -256,20 +289,17 @@ async fn process_cleanpath( trace!("Establishing TLS connection with server"); - // == Establish TLS connection with server == - - let server_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) + let tls_stream = crate::tls::dangerous_connect(selected_target.host().to_owned(), server_stream) .await .map_err(|source| CleanPathError::TlsHandshake { source, - target_server: selected_target.to_owned(), + target_server: selected_target.clone(), })?; - Ok(CleanPathResult { - destination: selected_target.to_owned(), - claims, + Ok(ConnectedRdpServer { + tls_stream, server_addr, - server_stream, + selected_target, x224_rsp, }) } @@ -287,6 +317,7 @@ async fn handle_with_credential_injection( active_recordings: &ActiveRecordings, cleanpath_pdu: RDCleanPathPdu, credential_entry: Arc, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { let tls_conf = conf.credssp_tls.get().context("CredSSP TLS configuration")?; @@ -318,16 +349,9 @@ async fn handle_with_credential_injection( ) }; - // Run normal RDCleanPath flow (this will handle server-side TLS and get certs). - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - .. - } = process_cleanpath( - cleanpath_pdu, + // Authorize and connect to the RDP server. 
+ let CleanPathAuth { claims } = authorize_cleanpath( + &cleanpath_pdu, client_addr, &conf, token_cache, @@ -336,7 +360,16 @@ async fn handle_with_credential_injection( &sessions, ) .await - .context("RDCleanPath processing failed")?; + .context("RDCleanPath authorization failed")?; + + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connect_rdp_server(&claims, cleanpath_pdu, agent_tunnel_handle.as_ref()) + .await + .context("RDCleanPath connection failed")?; // Retrieve the Gateway TLS public key that must be used for client-proxy CredSSP later on. let gateway_cert_chain_handle = tokio::spawn(crate::tls::get_cert_chain_for_acceptor_cached( @@ -532,6 +565,7 @@ pub async fn handle( subscriber_tx: SubscriberSender, active_recordings: &ActiveRecordings, credential_store: &CredentialStoreHandle, + agent_tunnel_handle: Option>, ) -> anyhow::Result<()> { // Special handshake of our RDP extension @@ -569,27 +603,29 @@ pub async fn handle( active_recordings, cleanpath_pdu, entry, + agent_tunnel_handle.clone(), ) .await; } trace!("Processing RDCleanPath"); - let CleanPathResult { - claims, - destination, - server_addr, - server_stream, - x224_rsp, - } = match process_cleanpath( - cleanpath_pdu, - client_addr, - &conf, - token_cache, - jrl, - active_recordings, - &sessions, - ) + let (auth, connected) = match async { + let auth = authorize_cleanpath( + &cleanpath_pdu, + client_addr, + &conf, + token_cache, + jrl, + active_recordings, + &sessions, + ) + .await?; + + let connected = connect_rdp_server(&auth.claims, cleanpath_pdu, agent_tunnel_handle.as_ref()).await?; + + Ok::<_, CleanPathError>((auth, connected)) + } .await { Ok(result) => result, @@ -602,6 +638,13 @@ pub async fn handle( } }; + let ConnectedRdpServer { + tls_stream: server_stream, + server_addr, + selected_target: destination, + x224_rsp, + } = connected; + // == Send success RDCleanPathPdu response == let x509_chain = server_stream @@ 
-622,13 +665,13 @@ pub async fn handle( // == Start actual RDP session == let info = SessionInfo::builder() - .id(claims.jet_aid) - .application_protocol(claims.jet_ap) + .id(auth.claims.jet_aid) + .application_protocol(auth.claims.jet_ap) .details(ConnectionModeDetails::Fwd { destination_host: destination.clone(), }) - .time_to_live(claims.jet_ttl) - .recording_policy(claims.jet_rec) + .time_to_live(auth.claims.jet_ttl) + .recording_policy(auth.claims.jet_rec) .build(); info!("RDP-TLS forwarding (RDCleanPath)"); @@ -642,7 +685,7 @@ pub async fn handle( .transport_b(server_stream) .sessions(sessions) .subscriber_tx(subscriber_tx) - .disconnect_interest(DisconnectInterest::from_reconnection_policy(claims.jet_reuse)) + .disconnect_interest(DisconnectInterest::from_reconnection_policy(auth.claims.jet_reuse)) .build() .select_dissector_and_forward() .await diff --git a/devolutions-gateway/src/rdp_proxy.rs b/devolutions-gateway/src/rdp_proxy.rs index b3dc466a7..af7d5f090 100644 --- a/devolutions-gateway/src/rdp_proxy.rs +++ b/devolutions-gateway/src/rdp_proxy.rs @@ -637,6 +637,10 @@ where async fn send_network_request(request: &NetworkRequest) -> anyhow::Result> { let target_addr = TargetAddr::parse(request.url.as_str(), Some(88))?; + // TODO(DGW-384): plumb `agent_tunnel_handle` through `RdpProxy` so + // CredSSP-originated Kerberos requests can traverse the agent tunnel. + // Currently these go direct from the gateway host, bypassing the + // routing pipeline used by every other proxy path. send_krb_message(&target_addr, &request.data) .await .map_err(|err| anyhow::Error::msg("failed to send KDC message").context(err)) diff --git a/devolutions-gateway/src/upstream.rs b/devolutions-gateway/src/upstream.rs new file mode 100644 index 000000000..94e853026 --- /dev/null +++ b/devolutions-gateway/src/upstream.rs @@ -0,0 +1,382 @@ +//! Shared upstream (server-side) connection machinery. +//! +//! 
All proxy paths that forward a client to some upstream target share the same +//! routing decision + connect sequence: +//! +//! 1. For each target in the token's `jet_cm: Fwd { targets }` list: +//! - If the JWT named a specific `jet_agent_id`, route via that agent or fail. +//! - Otherwise ask the registry for agents that cover the target (subnet / +//! domain match); route via the best match, or fall back to direct TCP. +//! 2. On the first successful connection, optionally wrap in client TLS. +//! +//! The two consumer patterns differ only in whether they want the TLS wrap +//! applied here (fwd.rs) or manage their own TLS upgrade (rd_clean_path.rs does +//! X224 first, then TLS). Both share `UpstreamLeg` and [`connect_upstream`]. + +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use agent_tunnel::AgentTunnelHandle; +use agent_tunnel::registry::AgentPeer; +use agent_tunnel::routing::{RouteTarget, RoutingDecision, resolve_route}; +use agent_tunnel::stream::TunnelStream; +use anyhow::{Context as _, Result, anyhow}; +use nonempty::NonEmpty; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::TcpStream; +use tokio_rustls::client::TlsStream; +use uuid::Uuid; + +use crate::target_addr::TargetAddr; +use crate::tls::thumbprint::Sha256Thumbprint; +use crate::utils; + +// --------------------------------------------------------------------------- +// Upstream transport types +// --------------------------------------------------------------------------- + +/// Upstream transport to the target server. +/// +/// An enum (not `Box`) because the surrounding proxy futures must be +/// `Send` with a concrete type, and trait-object projections block that proof +/// on the `ws.on_upgrade()` boundary. +pub enum UpstreamLeg { + /// Direct TCP to the target. + Tcp(TcpStream), + /// Tunnelled through an enrolled agent via QUIC. 
+ Tunnel(TunnelStream), +} + +impl AsyncRead for UpstreamLeg { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + Self::Tunnel(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for UpstreamLeg { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + Self::Tunnel(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_flush(cx), + Self::Tunnel(stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + Self::Tunnel(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +/// An `UpstreamLeg` optionally wrapped in a client TLS session. 
+pub enum UpstreamSession { + Tcp(UpstreamLeg), + Tls(Box>), +} + +impl AsyncRead for UpstreamSession { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + Self::Tls(stream) => Pin::new(stream.as_mut()).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for UpstreamSession { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + Self::Tls(stream) => Pin::new(stream.as_mut()).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_flush(cx), + Self::Tls(stream) => Pin::new(stream.as_mut()).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + Self::Tls(stream) => Pin::new(stream.as_mut()).poll_shutdown(cx), + } + } +} + +/// Whether the caller wants the upstream wrapped in client TLS before being +/// handed back from [`prepare_upstream`]. +#[derive(Debug, Clone, Copy)] +pub enum UpstreamMode { + Tcp, + Tls, +} + +// --------------------------------------------------------------------------- +// Result structs +// --------------------------------------------------------------------------- + +/// A successfully-connected upstream leg. The transport is either direct TCP +/// or an agent-tunnel stream; a TLS wrap has not yet been applied. +pub struct ConnectedUpstream { + pub leg: UpstreamLeg, + /// Remote peer address. For direct routing this is the resolved TCP peer. 
+ /// For agent-tunnel routing the real TCP peer lives on the agent side, so
+ /// we surface the target IP:port (when the target is an IP literal) or
+ /// `0.0.0.0:<port>` (when the target is a hostname the gateway
+ /// never resolves) — both are more useful in logs/PCAP than a true zero.
+ pub server_addr: SocketAddr,
+ pub selected_target: TargetAddr,
+}
+
+/// A [`ConnectedUpstream`] that has gone through [`prepare_upstream`]; the
+/// session may or may not be wrapped in TLS depending on [`UpstreamMode`].
+pub struct PreparedUpstream {
+ pub session: UpstreamSession,
+ pub server_addr: SocketAddr,
+ pub selected_target: TargetAddr,
+}
+
+// ---------------------------------------------------------------------------
+// Routing
+// ---------------------------------------------------------------------------
+
+/// A routing decision for a single target.
+pub(crate) enum RoutePlan<'a> {
+ Direct(&'a TargetAddr),
+ ViaAgent {
+ target: &'a TargetAddr,
+ candidates: Vec<Arc<AgentPeer>>,
+ },
+}
+
+impl<'a> RoutePlan<'a> {
+ /// Pick how to reach `target`:
+ /// - Explicit `jet_agent_id` → route via that agent (or error if missing).
+ /// - Otherwise registry subnet/domain match → best candidates, else Direct.
+ pub(crate) async fn resolve( + handle: Option<&AgentTunnelHandle>, + explicit_agent_id: Option, + target: &'a TargetAddr, + ) -> Result { + if let Some(agent_id) = explicit_agent_id { + let handle = handle.ok_or_else(|| { + anyhow!( + "agent {agent_id} specified in token requires agent tunnel routing, but no tunnel handle is configured" + ) + })?; + + let agent = handle + .registry() + .get(&agent_id) + .await + .ok_or_else(|| anyhow!("agent {agent_id} specified in token not found in registry"))?; + + return Ok(Self::ViaAgent { + target, + candidates: vec![agent], + }); + } + + let Some(handle) = handle else { + return Ok(Self::Direct(target)); + }; + + let route_target = route_target_from_target_addr(target); + let decision = resolve_route(handle.registry(), None, &route_target).await; + debug!( + target = %route_target, + decision = ?match &decision { + RoutingDecision::ViaAgent(c) => format!("ViaAgent({} candidates)", c.len()), + RoutingDecision::Direct => "Direct".to_owned(), + RoutingDecision::ExplicitAgentNotFound(id) => format!("ExplicitAgentNotFound({id})"), + }, + "Routing decision for implicit lookup" + ); + match decision { + RoutingDecision::ViaAgent(candidates) => Ok(Self::ViaAgent { target, candidates }), + RoutingDecision::Direct => Ok(Self::Direct(target)), + RoutingDecision::ExplicitAgentNotFound(agent_id) => { + // resolve_route only returns this when an explicit agent_id is passed + // in; we pass None above. Treat as a soft failure rather than panic + // so a future change in the routing crate cannot crash the gateway. + warn!( + %agent_id, + "routing crate returned ExplicitAgentNotFound for an implicit lookup; falling back to direct" + ); + Ok(Self::Direct(target)) + } + } + } + + /// Establish a concrete transport based on the plan. + /// + /// For `Direct`, does a TCP connect. For `ViaAgent`, tries each candidate + /// in order until one succeeds. 
+ pub(crate) async fn execute(
+ self,
+ handle: Option<&AgentTunnelHandle>,
+ session_id: Uuid,
+ ) -> Result<ConnectedUpstream> {
+ match self {
+ Self::Direct(target) => {
+ trace!(%target, "Connecting to target directly");
+
+ let (stream, server_addr) = utils::tcp_connect(target).await?;
+
+ trace!(%target, "Connected");
+
+ Ok(ConnectedUpstream {
+ leg: UpstreamLeg::Tcp(stream),
+ server_addr,
+ selected_target: target.clone(),
+ })
+ }
+ Self::ViaAgent { target, candidates } => {
+ let handle = handle.expect("route plan requires configured agent tunnel");
+ let mut last_error = None;
+
+ for agent in &candidates {
+ info!(
+ agent_id = %agent.agent_id,
+ agent_name = %agent.name,
+ target = %target.as_addr(),
+ "Routing via agent tunnel"
+ );
+
+ match handle
+ .connect_via_agent(agent.agent_id, session_id, target.as_addr())
+ .await
+ {
+ Ok(stream) => {
+ // The TCP peer lives on the agent side; surface the target
+ // IP:port for logs/PCAP when the target is a literal IP, or
+ // 0.0.0.0:<port> when it's a hostname the gateway never
+ // resolved itself. Either is more useful than 0.0.0.0:0.
+ let server_addr = match target.host_ip() { + Some(ip) => SocketAddr::new(ip, target.port()), + None => SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, target.port())), + }; + + return Ok(ConnectedUpstream { + leg: UpstreamLeg::Tunnel(stream), + server_addr, + selected_target: target.clone(), + }); + } + Err(error) => { + warn!( + agent_id = %agent.agent_id, + agent_name = %agent.name, + target = %target.as_addr(), + error = format!("{error:#}"), + "Agent tunnel candidate failed" + ); + last_error = Some(error); + } + } + } + + Err(last_error.unwrap_or_else(|| anyhow!("all agent tunnel candidates failed"))) + } + } + } +} + +fn route_target_from_target_addr(target: &TargetAddr) -> RouteTarget { + match target.host_ip() { + Some(ip) => RouteTarget::ip(ip), + None => RouteTarget::hostname(target.host()), + } +} + +// --------------------------------------------------------------------------- +// Public entry points +// --------------------------------------------------------------------------- + +/// Iterate `targets` in token order, resolving and connecting each. The first +/// successful connection wins. +/// +/// Errors from earlier targets are chained onto the final error so the caller +/// sees the full failure story. +pub async fn connect_upstream( + targets: &NonEmpty, + explicit_agent_id: Option, + session_id: Uuid, + handle: Option<&AgentTunnelHandle>, +) -> Result { + let mut accumulated: Option = None; + + for target in targets { + let attempt = async { + RoutePlan::resolve(handle, explicit_agent_id, target) + .await? 
+ .execute(handle, session_id) + .await + }; + + match attempt.await { + Ok(connected) => return Ok(connected), + Err(error) => { + let annotated = error.context(format!("{target} failed")); + accumulated = Some(match accumulated.take() { + Some(prev) => prev.context(annotated), + None => annotated, + }); + } + } + } + + Err(accumulated.unwrap_or_else(|| anyhow!("no target candidates available"))) +} + +/// Optionally wrap a [`ConnectedUpstream`] in a client TLS session. +/// +/// When `mode` is [`UpstreamMode::Tls`] the TLS handshake uses the gateway's +/// safe verifier with an optional SHA-256 thumbprint pin; otherwise the +/// session is returned as a plain TCP transport. +pub async fn prepare_upstream( + connected: ConnectedUpstream, + mode: UpstreamMode, + cert_thumb256: Option, +) -> Result { + let ConnectedUpstream { + leg, + server_addr, + selected_target, + } = connected; + + let session = match mode { + UpstreamMode::Tcp => UpstreamSession::Tcp(leg), + UpstreamMode::Tls => { + trace!(target = %selected_target, "Establishing TLS connection with upstream"); + + let tls_stream = crate::tls::safe_connect(selected_target.host().to_owned(), leg, cert_thumb256) + .await + .context("TLS connect")?; + + UpstreamSession::Tls(Box::new(tls_stream)) + } + }; + + Ok(PreparedUpstream { + session, + server_addr, + selected_target, + }) +}