diff --git a/guest-agent/src/server.rs b/guest-agent/src/server.rs index 00d3d90c..e6cc8bd7 100644 --- a/guest-agent/src/server.rs +++ b/guest-agent/src/server.rs @@ -10,10 +10,11 @@ use crate::http_routes; use crate::rpc_service::{AppState, ExternalRpcHandler, InternalRpcHandler, InternalRpcHandlerV0}; use crate::socket_activation::{ActivatedSockets, ActivatedUnixListener}; use anyhow::{anyhow, Context, Result}; +use ra_rpc::rocket_helper::UnixPeerCredListener; use rocket::{ fairing::AdHoc, figment::Figment, - listener::{Bind, DefaultListener}, + listener::{unix::UnixListener, Bind, DefaultListener, Endpoint}, }; use rocket_vsock_listener::VsockListener; use sd_notify::{notify as sd_notify, NotifyState}; @@ -43,7 +44,7 @@ async fn run_internal_v0( if let Some(std_listener) = activated_socket { info!("Using systemd-activated socket for tappd.sock"); - let listener = ActivatedUnixListener::new(std_listener)?; + let listener = UnixPeerCredListener::new(ActivatedUnixListener::new(std_listener)?); sock_ready_tx.send(()).ok(); ignite .launch_on(listener) @@ -52,14 +53,29 @@ async fn run_internal_v0( } else { let endpoint = DefaultListener::bind_endpoint(&ignite) .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; - let listener = DefaultListener::bind(&ignite) - .await - .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; + match endpoint { + Endpoint::Unix(_) => { + let listener = UnixPeerCredListener::new( + ::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?, + ); + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + _ => { + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + } } Ok(()) } @@ -80,7 +96,7 @@ async fn run_internal( if let Some(std_listener) = activated_socket { info!("Using systemd-activated socket for dstack.sock"); - let listener = ActivatedUnixListener::new(std_listener)?; + let listener = UnixPeerCredListener::new(ActivatedUnixListener::new(std_listener)?); sock_ready_tx.send(()).ok(); ignite .launch_on(listener) @@ -89,14 +105,29 @@ async fn run_internal( } else { let endpoint = DefaultListener::bind_endpoint(&ignite) .map_err(|err| anyhow!("Failed to get endpoint: {err}"))?; - let listener = DefaultListener::bind(&ignite) - .await - .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; sock_ready_tx.send(()).ok(); - ignite - .launch_on(listener) - .await - .map_err(|err| anyhow!(err.to_string()))?; + match endpoint { + Endpoint::Unix(_) => { + let listener = UnixPeerCredListener::new( + ::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?, + ); + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + _ => { + let listener = DefaultListener::bind(&ignite) + .await + .map_err(|err| anyhow!("Failed to bind on {endpoint}: {err}"))?; + ignite + .launch_on(listener) + .await + .map_err(|err| anyhow!(err.to_string()))?; + } + } } Ok(()) } diff --git a/ra-rpc/src/lib.rs b/ra-rpc/src/lib.rs index 253d8eaa..bf1e2208 100644 --- a/ra-rpc/src/lib.rs +++ b/ra-rpc/src/lib.rs @@ -22,12 +22,31 @@ pub mod client; #[cfg(feature = "openapi")] pub mod openapi; +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UnixPeerCred { + /// Peer process ID (platform-independent representation) + pub pid: u64, + /// Peer user ID + pub uid: u64, + /// Peer group ID + pub gid: u64, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum RemoteEndpoint { Tcp(SocketAddr), Quic(SocketAddr), - Unix(PathBuf), - Vsock { cid: u32, port: u32 }, + /// Unix domain socket endpoint. + /// + /// When available, `peer` can carry SO_PEERCRED (pid/uid/gid) of the caller. + Unix { + path: PathBuf, + peer: Option, + }, + Vsock { + cid: u32, + port: u32, + }, Other(String), } diff --git a/ra-rpc/src/rocket_helper.rs b/ra-rpc/src/rocket_helper.rs index 770b0739..b3074a4c 100644 --- a/ra-rpc/src/rocket_helper.rs +++ b/ra-rpc/src/rocket_helper.rs @@ -3,6 +3,11 @@ // SPDX-License-Identifier: Apache-2.0 use std::convert::Infallible; +use std::fmt; +use std::io; +use std::path::PathBuf; +use std::pin::Pin; +use std::task::{Context as TaskContext, Poll}; #[cfg(all(feature = "rocket", feature = "openapi"))] use crate::openapi::{OpenApiDoc, RenderedDoc}; @@ -13,6 +18,9 @@ use std::{borrow::Cow, sync::Arc}; use anyhow::{Context, Result}; use ra_tls::traits::CertExt; +use rocket::listener::unix::UnixStream; +use rocket::listener::{Connection, Listener}; +use rocket::tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use rocket::{ data::{ByteUnit, Data, Limits, ToByteUnit}, http::{uri::Origin, ContentType, Method, Status}, @@ -25,7 +33,7 @@ use rocket::{ use rocket_vsock_listener::VsockEndpoint; use tracing::warn; -use crate::{encode_error, CallContext, RemoteEndpoint, RpcCall}; +use crate::{encode_error, CallContext, RemoteEndpoint, RpcCall, UnixPeerCred}; pub struct RpcResponse { is_json: bool, @@ -48,6 +56,131 @@ impl<'r> Responder<'r, 'static> for RpcResponse { } } +#[derive(Debug, Clone)] +struct UnixPeerEndpoint { + path: PathBuf, + peer: Option, +} + +impl fmt::Display for UnixPeerEndpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "unix:{}", self.path.display()) + } +} + +pub struct UnixPeerCredListener { + inner: L, +} + +impl UnixPeerCredListener { + pub fn new(inner: L) -> Self { + Self { inner } + } +} + +pub struct UnixPeerCredConnection { + stream: UnixStream, + endpoint: rocket::listener::Endpoint, +} + +impl AsyncRead for UnixPeerCredConnection { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl AsyncWrite for UnixPeerCredConnection { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut TaskContext<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write_vectored(cx, bufs) + } +} + +impl Connection for UnixPeerCredConnection { + fn endpoint(&self) -> io::Result { + Ok(self.endpoint.clone()) + } +} + +impl Listener for UnixPeerCredListener +where + L: Listener, +{ + type Accept = UnixStream; + type Connection = UnixPeerCredConnection; + + async fn accept(&self) -> io::Result { + self.inner.accept().await + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + let path = accept + .local_addr()? + .as_pathname() + .map(PathBuf::from) + .or_else(|| { + self.inner + .endpoint() + .ok() + .and_then(|e| e.unix().map(PathBuf::from)) + }); + + let endpoint = match path { + Some(path) => rocket::listener::Endpoint::new(UnixPeerEndpoint { + path, + peer: unix_peer_cred(&accept), + }), + None => accept.local_addr()?.try_into()?, + }; + + Ok(UnixPeerCredConnection { + stream: accept, + endpoint, + }) + } + + fn endpoint(&self) -> io::Result { + self.inner.endpoint() + } +} + +fn unix_peer_cred(stream: &UnixStream) -> Option { + let cred = stream.peer_cred().ok()?; + let pid = cred.pid()?; + Some(UnixPeerCred { + pid: pid as u64, + uid: cred.uid() as u64, + gid: cred.gid() as u64, + }) +} + #[derive(Debug, Clone)] pub struct QuoteVerifier { pccs_url: Option, @@ -265,7 +398,27 @@ impl From for RemoteEndpoint { match endpoint { Endpoint::Tcp(addr) => RemoteEndpoint::Tcp(addr), Endpoint::Quic(addr) => RemoteEndpoint::Quic(addr), - Endpoint::Unix(path) => RemoteEndpoint::Unix(path), + Endpoint::Unix(path) => RemoteEndpoint::Unix { path, peer: None }, + Endpoint::Custom(endpoint) => { + if let Some(endpoint) = + (endpoint.as_ref() as &dyn std::any::Any).downcast_ref::() + { + RemoteEndpoint::Unix { + path: endpoint.path.clone(), + peer: endpoint.peer.clone(), + } + } else { + let address = endpoint.to_string(); + match address.parse::() { + Ok(addr) => RemoteEndpoint::Vsock { + cid: addr.cid, + port: addr.port, + }, + Err(_) => RemoteEndpoint::Other(address), + } + } + } + Endpoint::Tls(inner, _) => RemoteEndpoint::from((*inner).clone()), _ => { let address = endpoint.to_string(); match address.parse::() { @@ -280,6 +433,74 @@ impl From for RemoteEndpoint { } } +#[cfg(test)] +mod tests { + use super::*; + use rocket::listener::unix::UnixListener; + use rocket::tokio; + use std::time::{SystemTime, UNIX_EPOCH}; + + #[test] + fn custom_unix_endpoint_maps_to_remote_endpoint() { + let endpoint = Endpoint::new(UnixPeerEndpoint { + path: PathBuf::from("/tmp/test.sock"), + peer: Some(UnixPeerCred { + pid: 1, + uid: 2, + gid: 3, + }), + }); + + let remote = RemoteEndpoint::from(endpoint); + assert_eq!( + remote, + RemoteEndpoint::Unix { + path: PathBuf::from("/tmp/test.sock"), + peer: Some(UnixPeerCred { + pid: 1, + uid: 2, + gid: 3, + }), + } + ); + } + + #[tokio::test] + async fn unix_peer_cred_listener_populates_peer() { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + let path = std::env::temp_dir().join(format!("ra-rpc-peer-{unique}.sock")); + + let listener = UnixListener::bind(&path, false).await.unwrap(); + let listener = UnixPeerCredListener::new(listener); + + let client = tokio::spawn({ + let path = path.clone(); + async move { tokio::net::UnixStream::connect(path).await } + }); + let accepted = listener.accept().await.unwrap(); + let _client = client.await.unwrap().unwrap(); + let conn = listener.connect(accepted).await.unwrap(); + + let remote = RemoteEndpoint::from(conn.endpoint().unwrap()); + match remote { + RemoteEndpoint::Unix { + path: got_path, + peer, + } => { + assert_eq!(got_path, path); + let peer = peer.expect("expected unix peer credentials"); + assert_eq!(peer.pid, std::process::id() as u64); + } + other => panic!("unexpected remote endpoint: {other:?}"), + } + + let _ = std::fs::remove_file(path); + } +} + pub async fn handle_prpc_impl>( args: PrpcHandler<'_, '_, S>, ) -> Result {