diff --git a/libshpool/src/attach.rs b/libshpool/src/attach.rs index 816f584d..a5c614d8 100644 --- a/libshpool/src/attach.rs +++ b/libshpool/src/attach.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Google LLC +// Copyright 2023-2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,17 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{env, fmt, io, path::PathBuf, thread, time}; +use std::{ + collections::HashMap, + env, fmt, io, + os::fd::AsFd, + path::PathBuf, + sync::{Arc, Mutex}, + thread, time, +}; use anyhow::{anyhow, bail, Context}; +use nix::unistd; use shpool_protocol::{ - AttachHeader, AttachReplyHeader, ConnectHeader, DetachReply, DetachRequest, ResizeReply, - ResizeRequest, SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload, - TtySize, + AttachHeader, AttachReplyHeader, ConnectHeader, DetachReply, DetachRequest, MaybeSwitch, + ResizeReply, ResizeRequest, SessionMessageReply, SessionMessageRequest, + SessionMessageRequestPayload, TtySize, }; use tracing::{debug, error, info, warn}; -use super::{config, duration, protocol, protocol::ClientResult, test_hooks, tty::TtySizeExt as _}; +use crate::{ + config, duration, protocol, + protocol::{ClientResult, PipeBytesResult}, + template, test_hooks, + tty::TtySizeExt as _, +}; const MAX_FORCE_RETRIES: usize = 20; @@ -40,18 +53,7 @@ pub fn run( info!("\n\n======================== STARTING ATTACH ============================\n\n"); test_hooks::emit("attach-startup"); - if name.is_empty() { - eprintln!("blank session names are not allowed"); - return Ok(()); - } - if name.contains(char::is_whitespace) { - eprintln!("whitespace is not allowed in session names"); - return Ok(()); - } - - if !background { - SignalHandler::new(name.clone(), socket.clone()).spawn()?; - } + let session_name_tmpl = template::Template::new(&name).context("parsing session name tmpl")?; let ttl = match &ttl { Some(src) => match duration::parse(src.as_str()) { @@ -63,224 +65,312 @@ pub fn run( None => None, }; - let mut detached = false; - let mut tries = 0; - let attach_client = loop { - match do_attach(&config_manager, name.as_str(), background, &ttl, &cmd, &dir, &socket) { - Ok(client) => break client, - Err(err) => match err.downcast() { - Ok(BusyError) if !force => { - eprintln!("session '{name}' already has a terminal attached"); - return Ok(()); - } - Ok(BusyError) => { - if !detached { - let mut client = dial_client(&socket, background)?; - client - .write_connect_header(ConnectHeader::Detach(DetachRequest { - sessions: vec![name.clone()], - })) - .context("writing detach request header")?; - let detach_reply: DetachReply = - client.read_reply().context("reading reply")?; - if !detach_reply.not_found_sessions.is_empty() { - warn!("could not find session '{}' to detach it", name); - } + let attach = + Attach { config_manager, session_name_tmpl, force, background, ttl, cmd, dir, socket }; - detached = true; - } - thread::sleep(time::Duration::from_millis(100)); + attach.run() +} - if tries > MAX_FORCE_RETRIES { - eprintln!("session '{name}' already has a terminal which remains attached even after attempting to detach it"); - return Err(anyhow!("could not detach session, forced attach failed")); - } - tries += 1; - } - Err(err) => return Err(err), - }, - } - }; +struct Attach { + config_manager: config::Manager, + session_name_tmpl: template::Template, + force: bool, + background: bool, + ttl: Option, + cmd: Option, + dir: Option, + socket: PathBuf, +} - if background { - // Close the attached connection first so the daemon can observe EOF. - // We still send an explicit Detach on a fresh connection as a best-effort - // fallback in case EOF processing is delayed. - drop(attach_client); - let mut client = dial_client(&socket, true)?; - client - .write_connect_header(ConnectHeader::Detach(DetachRequest { - sessions: vec![name.clone()], - })) - .context("writing detach request header")?; - let detach_reply: DetachReply = client.read_reply().context("reading reply")?; - if !detach_reply.not_found_sessions.is_empty() { - warn!("could not find session '{}' to detach it", name); - } - if !detach_reply.not_attached_sessions.is_empty() { - debug!( - "session '{}' was already detached while processing background detach request (expected)", - name - ); +impl Attach { + fn run(self) -> anyhow::Result<()> { + // This is the first time we dial the daemon, so we do want to show + // warnings. After this we shouldn't show them again. + let mut client = self.dial_client(false).context("dialing daemon")?; + info!("dialed initial conn for GetVars"); + + client.write_connect_header(ConnectHeader::GetVars).context("getting vars")?; + let mut maybe_switch: MaybeSwitch = client.read_reply().context("reading reply")?; + + let var_map = maybe_switch.vars.iter().cloned().collect(); + let mut resolved_name = self.session_name_tmpl.apply(&var_map); + + let sig_handler_session_name_slot = if !self.background { + Some(SignalHandler::new(resolved_name.clone(), self.socket.clone()).spawn()?) + } else { + None + }; + + info!("looping on attach_with_name"); + loop { + match self.attach_with_name(resolved_name) { + Ok(AttachResult::Done) => return Ok(()), + Ok(AttachResult::Switch(s)) => maybe_switch = s, + Err(e) => return Err(e), + } + + let var_map = maybe_switch.vars.iter().cloned().collect(); + resolved_name = self.session_name_tmpl.apply(&var_map); + + if let Some(ref slot) = sig_handler_session_name_slot { + let mut slot = slot.lock().unwrap(); + *slot = resolved_name.clone(); + } } - return Ok(()); } - match attach_client.pipe_bytes() { - Ok(exit_status) => std::process::exit(exit_status), - Err(e) => Err(e), - } -} + /// Attach with the given resolved name. This will run until exit or until + /// we need to reconnect due to + pub fn attach_with_name(&self, resolved_name: String) -> anyhow::Result { + if resolved_name.is_empty() { + eprintln!("blank session names are not allowed"); + return Ok(AttachResult::Done); + } + if resolved_name.contains(char::is_whitespace) { + eprintln!("session name '{}' may not have whitespace", resolved_name); + return Ok(AttachResult::Done); + } -#[derive(Debug)] -struct BusyError; -impl fmt::Display for BusyError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BusyError") - } -} -impl std::error::Error for BusyError {} + let mut detached = false; + let mut tries = 0; + let attach_client = loop { + match self.dial_attach(resolved_name.as_str()) { + Ok(client) => break client, + Err(err) => match err.downcast() { + Ok(BusyError) if !self.force => { + eprintln!("session '{resolved_name}' already has a terminal attached"); + return Ok(AttachResult::Done); + } + Ok(BusyError) => { + if !detached { + let mut client = self.dial_client(true)?; + client + .write_connect_header(ConnectHeader::Detach(DetachRequest { + sessions: vec![resolved_name.clone()], + })) + .context("writing detach request header")?; + let detach_reply: DetachReply = + client.read_reply().context("reading reply")?; + if !detach_reply.not_found_sessions.is_empty() { + warn!("could not find session '{}' to detach it", resolved_name); + } + + detached = true; + } + thread::sleep(time::Duration::from_millis(100)); -/// Attach to a session and return the connected client without piping stdio. -/// -/// `background` is forwarded to `dial_client` to suppress the interactive -/// version-mismatch prompt on stdin; no other behavior changes. -fn do_attach( - config: &config::Manager, - name: &str, - background: bool, - ttl: &Option, - cmd: &Option, - dir: &Option, - socket: &PathBuf, -) -> anyhow::Result { - let mut client = dial_client(socket, background)?; - - let tty_size = match TtySize::from_fd(0) { - Ok(s) => s, - Err(e) => { - warn!("stdin is not a tty, using default size (err: {e:?})"); - TtySize { rows: 24, cols: 80, xpixel: 0, ypixel: 0 } + if tries > MAX_FORCE_RETRIES { + eprintln!("session '{resolved_name}' already has a terminal which remains attached even after attempting to detach it"); + return Err(anyhow!("could not detach session, forced attach failed")); + } + tries += 1; + } + Err(err) => return Err(err), + }, + } + }; + info!("got attach client"); + + if self.background { + // Close the attached connection first so the daemon can observe EOF. + // We still send an explicit Detach on a fresh connection as a best-effort + // fallback in case EOF processing is delayed. + drop(attach_client); + let mut client = self.dial_client(true)?; + client + .write_connect_header(ConnectHeader::Detach(DetachRequest { + sessions: vec![resolved_name.clone()], + })) + .context("writing detach request header")?; + let detach_reply: DetachReply = client.read_reply().context("reading reply")?; + if !detach_reply.not_found_sessions.is_empty() { + warn!("could not find session '{}' to detach it", resolved_name); + } + if !detach_reply.not_attached_sessions.is_empty() { + debug!( + "session '{}' was already detached while processing background detach request (expected)", + resolved_name + ); + } + return Ok(AttachResult::Done); } - }; - let forward_env = config.get().forward_env.clone(); - let mut local_env_keys = vec!["TERM", "DISPLAY", "LANG", "SSH_AUTH_SOCK"]; - if let Some(fenv) = &forward_env { - for var in fenv.iter() { - local_env_keys.push(var); + info!("entering bidi streaming mode"); + let session_name_tmpl = self.session_name_tmpl.clone(); + match attach_client.pipe_bytes(move |maybe_switch: &MaybeSwitch| { + let var_map: HashMap = maybe_switch.vars.iter().cloned().collect(); + session_name_tmpl.apply(&var_map) != resolved_name + }) { + Ok(PipeBytesResult::Exit(exit_status)) => std::process::exit(exit_status), + Ok(PipeBytesResult::MaybeSwitch(s)) => Ok(AttachResult::Switch(s)), + Err(e) => Err(e), } } - info!("local env keys: {local_env_keys:?}"); - - let cwd = String::from(env::current_dir().context("getting cwd")?.to_string_lossy()); - let default_dir = config.get().default_dir.clone().unwrap_or(String::from("$HOME")); - let start_dir = match (default_dir.as_str(), dir.as_deref()) { - (".", None) => Some(cwd), - ("$HOME", None) => None, - (d, None) => Some(String::from(d)), - (_, Some(".")) => Some(cwd), - (_, Some(d)) => Some(String::from(d)), - }; - client - .write_connect_header(ConnectHeader::Attach(AttachHeader { - name: String::from(name), - local_tty_size: tty_size, - local_env: local_env_keys - .into_iter() - .filter_map(|var| { - let val = env::var(var).context("resolving var").ok()?; - Some((String::from(var), val)) - }) - .collect::>(), - ttl_secs: ttl.map(|d| d.as_secs()), - cmd: cmd.clone(), - dir: start_dir, - })) - .context("writing attach header")?; - - let attach_resp: AttachReplyHeader = client.read_reply().context("reading attach reply")?; - info!("attach_resp.status={:?}", attach_resp.status); - - { - use shpool_protocol::AttachStatus::*; - match attach_resp.status { - Busy => { - return Err(BusyError.into()); + /// Attach to a session and return the connected client without piping + /// stdio. + fn dial_attach(&self, name: &str) -> anyhow::Result { + let mut client = self.dial_client(true)?; + + let tty_size = match TtySize::from_fd(0) { + Ok(s) => s, + Err(e) => { + warn!("stdin is not a tty, using default size (err: {e:?})"); + TtySize { rows: 24, cols: 80, xpixel: 0, ypixel: 0 } } - Forbidden(reason) => { - eprintln!("forbidden: {reason}"); - return Err(anyhow!("forbidden: {reason}")); + }; + + let forward_env = self.config_manager.get().forward_env.clone(); + let mut local_env_keys = vec!["TERM", "DISPLAY", "LANG", "SSH_AUTH_SOCK"]; + if let Some(fenv) = &forward_env { + for var in fenv.iter() { + local_env_keys.push(var); } - Attached { warnings } => { - for warning in warnings.into_iter() { - eprintln!("shpool: warn: {warning}"); + } + info!("local env keys: {local_env_keys:?}"); + + let cwd = String::from(env::current_dir().context("getting cwd")?.to_string_lossy()); + let default_dir = + self.config_manager.get().default_dir.clone().unwrap_or(String::from("$HOME")); + let start_dir = match (default_dir.as_str(), self.dir.as_deref()) { + (".", None) => Some(cwd), + ("$HOME", None) => None, + (d, None) => Some(String::from(d)), + (_, Some(".")) => Some(cwd), + (_, Some(d)) => Some(String::from(d)), + }; + + client + .write_connect_header(ConnectHeader::Attach(AttachHeader { + name: String::from(name), + local_tty_size: tty_size, + local_env: local_env_keys + .into_iter() + .filter_map(|var| { + let val = env::var(var).context("resolving var").ok()?; + Some((String::from(var), val)) + }) + .collect::>(), + ttl_secs: self.ttl.map(|d| d.as_secs()), + cmd: self.cmd.clone(), + dir: start_dir, + })) + .context("writing attach header")?; + + let attach_resp: AttachReplyHeader = client.read_reply().context("reading attach reply")?; + info!("attach_resp.status={:?}", attach_resp.status); + + { + use shpool_protocol::AttachStatus::*; + match attach_resp.status { + Busy => { + return Err(BusyError.into()); } - info!("attached to an existing session: '{}'", name); - } - Created { warnings } => { - for warning in warnings.into_iter() { - eprintln!("shpool: warn: {warning}"); + Forbidden(reason) => { + eprintln!("forbidden: {reason}"); + return Err(anyhow!("forbidden: {reason}")); + } + Attached { warnings } => { + for warning in warnings.into_iter() { + eprintln!("shpool: warn: {warning}"); + } + info!("attached to an existing session: '{}'", name); + } + Created { warnings } => { + for warning in warnings.into_iter() { + eprintln!("shpool: warn: {warning}"); + } + info!("created a new session: '{}'", name); + } + UnexpectedError(err) => { + return Err(anyhow!("BUG: unexpected error attaching to '{}': {}", name, err)); } - info!("created a new session: '{}'", name); - } - UnexpectedError(err) => { - return Err(anyhow!("BUG: unexpected error attaching to '{}': {}", name, err)); } } + + Ok(client) } - Ok(client) -} + // Dial the daemon. If silent is true, don't attempt to warn the user. + // After the first dial, silent should always be true. + fn dial_client(&self, silent: bool) -> anyhow::Result { + match protocol::Client::new(&self.socket) { + Ok(ClientResult::JustClient(c)) => Ok(c), + Ok(ClientResult::VersionMismatch { warning, client }) => { + if silent { + return Ok(client); + } -fn dial_client(socket: &PathBuf, background: bool) -> anyhow::Result { - match protocol::Client::new(socket) { - Ok(ClientResult::JustClient(c)) => Ok(c), - Ok(ClientResult::VersionMismatch { warning, client }) => { - if background { - eprintln!( - "warning: {warning}, proceeding in background mode; try restarting your daemon" - ); - } else { - eprintln!("warning: {warning}, try restarting your daemon"); - eprintln!("hit enter to continue anyway or ^C to exit"); - - let _ = io::stdin() - .lines() - .next() - .context("waiting for a continue through a version mismatch")?; - } + if self.background { + eprintln!( + "warning: {warning}, proceeding in background mode; try restarting your daemon" + ); + } else { + eprintln!("warning: {warning}, try restarting your daemon"); + eprintln!("hit enter to continue anyway or ^C to exit"); + + let mut buf = [0u8; 1]; + loop { + match unistd::read(io::stdin().as_fd(), &mut buf) { + Ok(0) => break, + Ok(1) if buf[0] == b'\n' => break, + Ok(_) => continue, + Err(nix::errno::Errno::EINTR) => continue, + Err(e) => { + return Err(anyhow::Error::from(e)) + .context("waiting for a continue through a version mismatch") + } + } + } + info!("user continued through version mismatch"); + } - Ok(client) - } - Err(err) => { - let io_err = err.downcast::()?; - if io_err.kind() == io::ErrorKind::NotFound { - eprintln!("could not connect to daemon"); + Ok(client) + } + Err(err) => { + let io_err = err.downcast::()?; + if io_err.kind() == io::ErrorKind::NotFound { + eprintln!("could not connect to daemon"); + } + Err(io_err).context("connecting to daemon") } - Err(io_err).context("connecting to daemon") } } } +#[derive(Debug)] +struct BusyError; +impl fmt::Display for BusyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "BusyError") + } +} +impl std::error::Error for BusyError {} + +enum AttachResult { + Done, + Switch(MaybeSwitch), +} + // // Signal Handling // struct SignalHandler { - session_name: String, + session_name: Arc>, socket: PathBuf, } impl SignalHandler { fn new(session_name: String, socket: PathBuf) -> Self { - SignalHandler { session_name, socket } + SignalHandler { session_name: Arc::new(Mutex::new(session_name)), socket } } - fn spawn(self) -> anyhow::Result<()> { + fn spawn(self) -> anyhow::Result>> { use signal_hook::{consts::*, iterator::*}; + let session_name_slot = Arc::clone(&self.session_name); + let sigs = vec![SIGWINCH]; let mut signals = Signals::new(sigs).context("creating signal iterator")?; @@ -299,7 +389,7 @@ impl SignalHandler { } }); - Ok(()) + Ok(session_name_slot) } fn handle_sigwinch(&self) -> anyhow::Result<()> { @@ -318,7 +408,7 @@ impl SignalHandler { // write the request on a new, seperate connection client .write_connect_header(ConnectHeader::SessionMessage(SessionMessageRequest { - session_name: self.session_name.clone(), + session_name: self.get_session_name(), payload: SessionMessageRequestPayload::Resize(ResizeRequest { tty_size: tty_size.clone(), }), @@ -331,11 +421,15 @@ impl SignalHandler { SessionMessageReply::NotFound => { warn!( "handle_sigwinch: sent resize for session '{}', but the daemon has no record of that session", - self.session_name + self.get_session_name() ); } SessionMessageReply::Resize(ResizeReply::Ok) => { - info!("handle_sigwinch: resized session '{}' to {:?}", self.session_name, tty_size); + info!( + "handle_sigwinch: resized session '{}' to {:?}", + self.get_session_name(), + tty_size + ); } reply => { warn!("handle_sigwinch: unexpected resize reply: {:?}", reply); @@ -344,4 +438,9 @@ impl SignalHandler { Ok(()) } + + fn get_session_name(&self) -> String { + let session_name = self.session_name.lock().unwrap(); + session_name.clone() + } } diff --git a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index adbeb622..5eb16e70 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -35,9 +35,10 @@ use anyhow::{anyhow, Context}; use nix::unistd; use shpool_protocol::{ AttachHeader, AttachReplyHeader, AttachStatus, ConnectHeader, DetachReply, DetachRequest, - KillReply, KillRequest, ListReply, LogLevel, ResizeReply, Session, SessionMessageDetachReply, - SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload, SessionStatus, - SetLogLevelReply, SetLogLevelRequest, VersionHeader, + KillReply, KillRequest, ListReply, LogLevel, MaybeSwitch, ModifyVarReply, ModifyVarRequest, + ResizeReply, Session, SessionMessageDetachReply, SessionMessageReply, SessionMessageRequest, + SessionMessageRequestPayload, SessionStatus, SetLogLevelReply, SetLogLevelRequest, + VersionHeader, }; use tracing::{debug, error, info, instrument, span, warn, Level}; @@ -79,6 +80,7 @@ pub struct Server { tracing_subscriber::filter::LevelFilter, tracing_subscriber::registry::Registry, >, + vars: Mutex>, } impl Server { @@ -112,6 +114,7 @@ impl Server { hooks, daily_messenger, log_level_handle, + vars: HashMap::new().into(), })) } @@ -209,6 +212,8 @@ impl Server { ConnectHeader::List => self.handle_list(stream), ConnectHeader::SessionMessage(header) => self.handle_session_message(stream, header), ConnectHeader::SetLogLevel(r) => self.handle_set_log_level(stream, r), + ConnectHeader::GetVars => self.handle_get_vars(stream), + ConnectHeader::ModifyVar(r) => self.handle_modify_var(stream, r), } } @@ -609,6 +614,57 @@ impl Server { Ok(()) } + #[instrument(skip_all)] + fn handle_get_vars(&self, mut stream: UnixStream) -> anyhow::Result<()> { + let maybe_switch = { + let var_map = self.vars.lock().unwrap(); + let vars: Vec<(String, String)> = + var_map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + shpool_protocol::MaybeSwitch { switch_to: None, vars } + }; + + write_reply(&mut stream, maybe_switch).context("writing maybe_switch reply")?; + Ok(()) + } + + #[instrument(skip_all)] + fn handle_modify_var( + &self, + mut stream: UnixStream, + request: ModifyVarRequest, + ) -> anyhow::Result<()> { + let maybe_switch = { + let mut vars = self.vars.lock().unwrap(); + if let Some(val) = request.val { + vars.insert(request.var, val); + } else { + vars.remove(&request.var); + } + + MaybeSwitch { + switch_to: None, + vars: vars.iter().map(|(k, v)| (k.clone(), v.clone())).collect(), + } + }; + + let mut ctls = Vec::new(); + { + let shells = self.shells.lock().unwrap(); + for (_, session) in shells.iter() { + ctls.push(Arc::clone(&session.shell_to_client_ctl)); + } + } + for ctl in ctls.into_iter() { + let ctl = ctl.lock().unwrap(); + ctl.maybe_switch + .send_timeout(maybe_switch.clone(), SESSION_MSG_TIMEOUT) + .context("broadcasting maybe_switch")?; + } + + write_reply(&mut stream, ModifyVarReply {}).context("writing modify var reply")?; + Ok(()) + } + #[instrument(skip_all)] fn handle_kill(&self, mut stream: UnixStream, request: KillRequest) -> anyhow::Result<()> { let mut not_found_sessions = vec![]; @@ -983,13 +1039,18 @@ impl Server { let (heartbeat_tx, heartbeat_rx) = crossbeam_channel::bounded(0); let (heartbeat_ack_tx, heartbeat_ack_rx) = crossbeam_channel::bounded(0); - let shell_to_client_ctl = Arc::new(Mutex::new(shell::ReaderCtl { + // We make this buffered to avoid blocking during a broadcast. There is + // no ack chan so we can afford to buffer a bit. + let (maybe_switch_tx, maybe_switch_rx) = crossbeam_channel::bounded(10); + + let shell_to_client_ctl = Arc::new(Mutex::new(shell::ShellToClientCtl { client_connection: client_connection_tx, client_connection_ack: client_connection_ack_rx, tty_size_change: tty_size_change_tx, tty_size_change_ack: tty_size_change_ack_rx, heartbeat: heartbeat_tx, heartbeat_ack: heartbeat_ack_rx, + maybe_switch: maybe_switch_tx, })); let mut session_inner = shell::SessionInner { @@ -1023,6 +1084,7 @@ impl Server { tty_size_change_ack: tty_size_change_ack_tx, heartbeat: heartbeat_rx, heartbeat_ack: heartbeat_ack_tx, + maybe_switch: maybe_switch_rx, child_exit_notifier: shell_to_client_child_exit_notifier, })?); diff --git a/libshpool/src/daemon/shell.rs b/libshpool/src/daemon/shell.rs index ee74bc2c..b7aa9ba2 100644 --- a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -28,12 +28,13 @@ use std::{ use anyhow::{anyhow, Context}; use nix::{poll, poll::PollFlags, sys::signal, unistd::Pid}; -use shpool_protocol::{Chunk, ChunkKind, TtySize}; +use shpool_protocol::{Chunk, ChunkKind, MaybeSwitch, TtySize}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use crate::{ common, consts, daemon::{config, exit_notify::ExitNotifier, keybindings, pager::PagerCtl, prompt, show_motd}, + protocol, protocol::ChunkExt as _, session_restore, test_hooks, tty::TtySizeExt as _, @@ -74,7 +75,7 @@ pub struct Session { pub lifecycle_timestamps: Mutex, pub child_pid: libc::pid_t, pub child_exit_notifier: Arc, - pub shell_to_client_ctl: Arc>, + pub shell_to_client_ctl: Arc>, pub pager_ctl: Arc>>, /// Mutable state with the lock held by the servicing handle_attach thread /// while a tty is attached to the session. Probing the mutex can be used @@ -109,7 +110,7 @@ impl Session { #[derive(Debug)] pub struct SessionInner { pub name: String, // to improve logging - pub shell_to_client_ctl: Arc>, + pub shell_to_client_ctl: Arc>, pub pty_master: shpool_pty::fork::Fork, pub client_stream: Option, pub config: config::Manager, @@ -207,6 +208,7 @@ pub struct ShellToClientArgs { pub tty_size_change: crossbeam_channel::Receiver, pub tty_size_change_ack: crossbeam_channel::Sender<()>, pub heartbeat: crossbeam_channel::Receiver<()>, + pub maybe_switch: crossbeam_channel::Receiver, // true if the client is still live, false if it has hung up on us pub heartbeat_ack: crossbeam_channel::Sender, pub child_exit_notifier: Arc, @@ -393,6 +395,42 @@ impl SessionInner { args.heartbeat_ack.send(client_present) .context("sending heartbeat ack")?; } + recv(args.maybe_switch) -> maybe_switch => { + let maybe_switch = match maybe_switch { + Ok(ms) => ms, + Err(e) => { + error!("error recving MaybeSwitch: {:?}", e); + continue; + }, + }; + + let conn = if let ClientConnectionMsg::New(c) = &mut client_conn { + c + } else { + info!("got MaybeSwitch, but no attached client, dropping"); + continue; + }; + + let mut encoded = Vec::new(); + if let Err(e) = protocol::encode_to(&maybe_switch, &mut encoded) { + error!("error encoding MaybeSwitch: {:?}", e); + continue; + } + + let chunk = Chunk { kind: ChunkKind::MaybeSwitch, buf: &encoded[..] }; + match chunk.write_to(&mut conn.sink).and_then(|_| conn.sink.flush()) { + Ok(_) => { + trace!("wrote MaybeSwitch"); + } + Err(e) if e.kind() == io::ErrorKind::BrokenPipe => { + trace!("writing MaybeSwitch: client hangup: {:?}", e); + } + Err(e) => { + error!("unexpected IO error while writing heartbeat: {}", e); + return Err(e).context("writing MaybeSwitch")?; + } + } + } // make this select non-blocking so we spend most of our time parked // in poll @@ -1003,7 +1041,7 @@ impl SessionInner { /// Shared between the session struct (for calls originating with the cli) /// and the session inner struct (for calls resulting from keybindings). #[derive(Debug)] -pub struct ReaderCtl { +pub struct ShellToClientCtl { /// A control channel for the shell->client thread. Whenever a new client /// dials in, the output stream for that client must be attached to the /// shell->client thread by sending it down this channel. A disconnect @@ -1028,6 +1066,12 @@ pub struct ReaderCtl { // True if the client is still listening, false if it has hung up // on us. pub heartbeat_ack: crossbeam_channel::Receiver, + + /// A control channel telling the shell->client thread to + /// broadcast the given MaybeSwitch. There is no ack channel + /// because we just blast this out and the caller doesn't need + /// to know about completion. + pub maybe_switch: crossbeam_channel::Sender, } /// Given a buffer, a length after which the data is not valid, a list of diff --git a/libshpool/src/lib.rs b/libshpool/src/lib.rs index 528fcd62..82e16a4c 100644 --- a/libshpool/src/lib.rs +++ b/libshpool/src/lib.rs @@ -43,9 +43,11 @@ mod list; mod protocol; mod session_restore; mod set_log_level; +mod template; mod test_hooks; mod tty; mod user; +mod var; /// The command line arguments that shpool expects. /// These can be directly parsed with clap or manually @@ -210,6 +212,54 @@ needs debugging, but would be clobbered by a restart.")] #[clap(help = "new log level")] level: shpool_protocol::LogLevel, }, + + #[clap(about = "Manipulate template variables + +shpool session names can include {variables} which are resolved via +an environment stored globally in the shpool daemon. This command +manipulates that environment. + +The main usecase for templated session names is the ability to switch +multiple shpool sessions to new targets at the same time. For example, +you might have a `shpool attach -f '{workspace}-edit'` session and +a `shpool attach -f '{workspace}-term'` session. To switch both +sessions from the fun-feature workspace to the key-bugfix workspace, +you could just do `shpool var set workspace key-bugfix`. +")] + #[non_exhaustive] + Var { + #[clap(subcommand)] + command: VarCommands, + }, +} + +/// The subcommds of the var command. +#[derive(Subcommand, Debug)] +#[non_exhaustive] +pub enum VarCommands { + #[clap(about = "List the variables + +This command dumps out the whole variable list with +both vars and values in a JSON object using vars as keys.")] + List { + #[clap(short, long, help = "Output as JSON")] + json: bool, + }, + #[clap(about = "Get a variable + +This returns the raw value of the given variable.")] + #[non_exhaustive] + Get { var: String }, + #[clap(about = "Set a variable + +This updates the value of the given variable.")] + #[non_exhaustive] + Set { var: String, val: String }, + #[clap(about = "Unset a variable + +This removes the given variable from the environment.")] + #[non_exhaustive] + Unset { var: String }, } impl Args { @@ -382,6 +432,7 @@ pub fn run(args: Args, hooks: Option>) -> an Commands::Kill { sessions } => kill::run(sessions, socket), Commands::List { json } => list::run(socket, json), Commands::SetLogLevel { level } => set_log_level::run(level, socket), + Commands::Var { command } => var::run(socket, command), }; if let Err(err) = res { diff --git a/libshpool/src/protocol.rs b/libshpool/src/protocol.rs index e4f64a4a..109a2e14 100644 --- a/libshpool/src/protocol.rs +++ b/libshpool/src/protocol.rs @@ -15,27 +15,29 @@ use std::{ cmp, io::{self, Read, Write}, - os::unix::net::UnixStream, + os::{fd::AsFd, unix::net::UnixStream}, path::Path, - sync::atomic::{AtomicI32, Ordering}, + sync::Mutex, thread, time, }; use anyhow::{anyhow, Context}; use byteorder::{LittleEndian, ReadBytesExt as _, WriteBytesExt as _}; +use nix::poll; use serde::{Deserialize, Serialize}; use shpool_protocol::{Chunk, ChunkKind, ConnectHeader, VersionHeader}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use super::{common, consts, tty}; -const DETACH_DISCONNECT_FAST_WAIT_DUR: time::Duration = time::Duration::from_millis(10); const MAX_DETACH_WAIT_DUR: time::Duration = time::Duration::from_millis(300); const DETACH_BACKOFF_INITIAL_DUR: time::Duration = time::Duration::from_millis(1); // Cap backoff steps so slow-path stays responsive while still avoiding busy // waits. const DETACH_BACKOFF_MAX_STEP_DUR: time::Duration = time::Duration::from_millis(25); +const STDIN_READ_POLL_MS: u16 = 50; + /// The centralized encoding function that should be used for all protocol /// serialization. pub fn encode_to(d: &T, w: W) -> anyhow::Result<()> @@ -241,16 +243,23 @@ impl Client { /// socket and back again. It is the main loop of /// `shpool attach`. /// - /// Return value: the exit status that `shpool attach` should - /// exit with. + /// The on_maybe_switch callback should return true if pipe_bytes + /// should exit with a PipeBytesResult::MaybeSwitch because the + /// attach process needs to reattach. #[instrument(skip_all)] - pub fn pipe_bytes(self) -> anyhow::Result { + pub fn pipe_bytes( + self, + on_maybe_switch: OnMaybeSwitchF, + ) -> anyhow::Result + where + OnMaybeSwitchF: Fn(&shpool_protocol::MaybeSwitch) -> bool + Send + Sync + 'static, + { let tty_guard = tty::set_attach_flags()?; let mut read_client_stream = self.stream.try_clone().context("cloning read stream")?; let mut write_client_stream = self.stream.try_clone().context("cloning read stream")?; - let exit_status = AtomicI32::new(1); + let result_slot = Mutex::new(None); thread::scope(|s| { // stdin -> sock let stdin_to_sock_h = s.spawn(|| -> anyhow::Result<()> { @@ -259,6 +268,24 @@ impl Client { let mut buf = vec![0; consts::BUF_SIZE]; loop { + { + let res = result_slot.lock().unwrap(); + if res.is_some() { + return Ok(()); + } + } + + { + let mut poll_fds = + [poll::PollFd::new(stdin.as_fd(), poll::PollFlags::POLLIN)]; + let nready = poll::poll(&mut poll_fds, STDIN_READ_POLL_MS) + .context("polling stdin")?; + if nready == 0 { + // timeout + continue; + } + } + let nread = stdin.read(&mut buf).context("reading stdin from user")?; if nread == 0 { return Ok(()); @@ -281,6 +308,26 @@ impl Client { let mut buf = vec![0; consts::BUF_SIZE]; loop { + { + let res = result_slot.lock().unwrap(); + if res.is_some() { + return Ok(()); + } + } + + { + let mut poll_fds = [poll::PollFd::new( + read_client_stream.as_fd(), + poll::PollFlags::POLLIN, + )]; + let nready = poll::poll(&mut poll_fds, STDIN_READ_POLL_MS) + .context("polling stdin")?; + if nready == 0 { + // timeout + continue; + } + } + let chunk = match Chunk::read_into(&mut read_client_stream, &mut buf) { Ok(c) => c, Err(err) => { @@ -323,46 +370,48 @@ impl Client { .read_i32::() .context("reading exit status from exit status chunk")?; info!("got exit status frame (status={})", stat); - exit_status.store(stat, Ordering::Release); + { + let mut res = result_slot.lock().unwrap(); + *res = Some(PipeBytesResult::Exit(stat)); + } + } + ChunkKind::MaybeSwitch => { + let maybe_switch_reader = io::Cursor::new(chunk.buf); + let maybe_switch: shpool_protocol::MaybeSwitch = + decode_from(maybe_switch_reader).context("decoding vars list")?; + + info!("got vars update (maybe_switch={:?})", maybe_switch); + if on_maybe_switch(&maybe_switch) { + let mut res = result_slot.lock().unwrap(); + *res = Some(PipeBytesResult::MaybeSwitch(maybe_switch)); + } } } } }); loop { - let mut nfinished_threads = 0; - if stdin_to_sock_h.is_finished() { - nfinished_threads += 1; - } - if sock_to_stdout_h.is_finished() { - nfinished_threads += 1; - } + let mut nfinished_threads = (stdin_to_sock_h.is_finished() as usize) + + (sock_to_stdout_h.is_finished() as usize); if nfinished_threads > 0 { - if nfinished_threads < 2 { - // Fast-path: when sock->stdout already ended (detach/disconnect), - // stdin->sock can stay blocked on stdin. In that case, do a very - // short grace wait and then exit quickly. This is independent - // of stdin being a TTY or a pipe. - // Slow-path: for other shutdown orders, keep compatibility by - // waiting up to 300ms with backoff. - let mut stdin_done = stdin_to_sock_h.is_finished(); - let mut stdout_done = sock_to_stdout_h.is_finished(); - - // Keep max_wait fixed for this detach sequence. Recomputing it inside - // the loop could accidentally switch paths mid-cleanup. - let max_wait = if stdout_done && !stdin_done { - DETACH_DISCONNECT_FAST_WAIT_DUR - } else { - MAX_DETACH_WAIT_DUR - }; + // If one of the threads has exited, but not the other + // make sure that the exit result slot has some contents + // so the other thread will exit the next time it wakes + // from its poll(). + { + let mut res = result_slot.lock().unwrap(); + if res.is_none() { + *res = Some(PipeBytesResult::Exit(1)); + } + } + if nfinished_threads < 2 { let finished_waiting = common::sleep_unless( - max_wait, + MAX_DETACH_WAIT_DUR, || { - stdin_done = stdin_to_sock_h.is_finished(); - stdout_done = sock_to_stdout_h.is_finished(); - nfinished_threads = (stdin_done as usize) + (stdout_done as usize); + nfinished_threads = (stdin_to_sock_h.is_finished() as usize) + + (sock_to_stdout_h.is_finished() as usize); nfinished_threads >= 2 }, common::PollStrategy::Backoff { @@ -375,12 +424,17 @@ impl Client { if !finished_waiting { // Re-probe after timeout because thread state can change // during the final sleep inside sleep_unless. - stdin_done = stdin_to_sock_h.is_finished(); - stdout_done = sock_to_stdout_h.is_finished(); - nfinished_threads = (stdin_done as usize) + (stdout_done as usize); + nfinished_threads = (stdin_to_sock_h.is_finished() as usize) + + (sock_to_stdout_h.is_finished() as usize); } if nfinished_threads < 2 { + // It should be impossible to get here because both + // loops use poll() to wake up every so often to + // check if they need to exit. Nevertheless, if + // we somehow still have a stuck thread at this + // point, we'll just exit. + // If one of the worker threads is done and the // other is not exiting, we are likely blocked on // some IO. Fortunately, since there isn't much else @@ -389,15 +443,20 @@ impl Client { // by just hard-exiting the whole process. This allows // us to use simple blocking IO. warn!( - "exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}", - stdin_done, - stdout_done + "internal error: exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}", + stdin_to_sock_h.is_finished(), + sock_to_stdout_h.is_finished(), ); // make sure that we restore the tty flags on the input // tty before exiting the process. drop(tty_guard); - std::process::exit(exit_status.load(Ordering::Acquire)); + let res = result_slot.lock().unwrap(); + if let Some(PipeBytesResult::Exit(stat)) = *res { + std::process::exit(stat); + } else { + std::process::exit(1); + } } } break; @@ -425,11 +484,29 @@ impl Client { } } - Ok(exit_status.load(Ordering::Acquire)) + let mut ret = PipeBytesResult::Exit(1); + { + let res = result_slot.lock().unwrap(); + if let Some(r) = res.clone() { + ret = r; + } + } + Ok(ret) }) } } +#[derive(Clone)] +pub enum PipeBytesResult { + /// The attach proc should exit with the given exit status. + Exit(i32), + /// The on_maybe_switch callback has requested the client stop streaming + /// due to some change that requires a reconnect (almost certainly a + /// changed var in the session name template). The attach proc + /// should recompute any relevant templates and re-attach. + MaybeSwitch(shpool_protocol::MaybeSwitch), +} + #[cfg(test)] mod test { use super::*; diff --git a/libshpool/src/template.rs b/libshpool/src/template.rs new file mode 100644 index 00000000..d1be900a --- /dev/null +++ b/libshpool/src/template.rs @@ -0,0 +1,206 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use anyhow::anyhow; + +// A guess at how large the values of variables will be on average. +// This is intended as a slight over-estimate as we use it to compute +// the buffer size we should pre-allocate for instantiation. +const VAR_SIZE_GUESS: usize = 40; + +/// A template is a simple variable substitution string template used +/// by the templated session name feature to allow automatic client +/// switching. +/// +/// The template syntax is that variable subsitutions look like +/// `#{var_name}`, where var_name must be some alphanumeric string. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Template { + chunks: Vec, + instantiated_size_guess: usize, +} + +/// A chunk is either a raw hunk of text or a variable substitution. +#[derive(Debug, Clone, Eq, PartialEq)] +enum Chunk { + Raw(String), + Var(String), +} + +impl Template { + pub fn new(src: &str) -> anyhow::Result