From 473433d27d42802a8e7609306b3109c3cadf6d1d Mon Sep 17 00:00:00 2001 From: Ethan Pailes Date: Wed, 29 Apr 2026 20:43:45 +0000 Subject: [PATCH] feat: add session name template support This patch adds support for session name templates so that you can switch multiple shpool sessions all at once. In some sense this is a superset of the 'shpool switch' feature request. I did lay some groundwork for implementing 'shpool switch' in this change, though I'm starting to wonder if templates are good enough on their own. The only extra thing that switch would bring is the ability to switch sessions that you didn't declare as switchable up front by putting a dedicated variable in their names. One thing worth bikeshedding: at first I was thinking of '${var}' for the substitution syntax, then realized that would nest awkwardly with shell syntax, so I switched to '#{var}', but then I realized that '#' is the comment character in shells, and wound up on '@{var}'. I'm open to other symbol/syntax ideas. BREAKING: this breaks shpool-protocol since we have a new chunk kind. 
--- libshpool/src/attach.rs | 495 ++++++++++++++++++++------------- libshpool/src/daemon/server.rs | 70 ++++- libshpool/src/daemon/shell.rs | 52 +++- libshpool/src/lib.rs | 51 ++++ libshpool/src/protocol.rs | 165 ++++++++--- libshpool/src/template.rs | 206 ++++++++++++++ libshpool/src/var.rs | 67 +++++ shpool-protocol/src/lib.rs | 39 ++- shpool/tests/attach.rs | 175 +++++++----- shpool/tests/daemon.rs | 2 +- shpool/tests/support/daemon.rs | 74 +++++ shpool/tests/var.rs | 137 +++++++++ 12 files changed, 1214 insertions(+), 319 deletions(-) create mode 100644 libshpool/src/template.rs create mode 100644 libshpool/src/var.rs create mode 100644 shpool/tests/var.rs diff --git a/libshpool/src/attach.rs b/libshpool/src/attach.rs index 816f584d..a5c614d8 100644 --- a/libshpool/src/attach.rs +++ b/libshpool/src/attach.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Google LLC +// Copyright 2023-2026 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,17 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::{env, fmt, io, path::PathBuf, thread, time}; +use std::{ + collections::HashMap, + env, fmt, io, + os::fd::AsFd, + path::PathBuf, + sync::{Arc, Mutex}, + thread, time, +}; use anyhow::{anyhow, bail, Context}; +use nix::unistd; use shpool_protocol::{ - AttachHeader, AttachReplyHeader, ConnectHeader, DetachReply, DetachRequest, ResizeReply, - ResizeRequest, SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload, - TtySize, + AttachHeader, AttachReplyHeader, ConnectHeader, DetachReply, DetachRequest, MaybeSwitch, + ResizeReply, ResizeRequest, SessionMessageReply, SessionMessageRequest, + SessionMessageRequestPayload, TtySize, }; use tracing::{debug, error, info, warn}; -use super::{config, duration, protocol, protocol::ClientResult, test_hooks, tty::TtySizeExt as _}; +use crate::{ + config, duration, protocol, + protocol::{ClientResult, PipeBytesResult}, + template, test_hooks, + tty::TtySizeExt as _, +}; const MAX_FORCE_RETRIES: usize = 20; @@ -40,18 +53,7 @@ pub fn run( info!("\n\n======================== STARTING ATTACH ============================\n\n"); test_hooks::emit("attach-startup"); - if name.is_empty() { - eprintln!("blank session names are not allowed"); - return Ok(()); - } - if name.contains(char::is_whitespace) { - eprintln!("whitespace is not allowed in session names"); - return Ok(()); - } - - if !background { - SignalHandler::new(name.clone(), socket.clone()).spawn()?; - } + let session_name_tmpl = template::Template::new(&name).context("parsing session name tmpl")?; let ttl = match &ttl { Some(src) => match duration::parse(src.as_str()) { @@ -63,224 +65,312 @@ pub fn run( None => None, }; - let mut detached = false; - let mut tries = 0; - let attach_client = loop { - match do_attach(&config_manager, name.as_str(), background, &ttl, &cmd, &dir, &socket) { - Ok(client) => break client, - Err(err) => match err.downcast() { - Ok(BusyError) if !force => { - eprintln!("session '{name}' already has a terminal attached"); - 
return Ok(()); - } - Ok(BusyError) => { - if !detached { - let mut client = dial_client(&socket, background)?; - client - .write_connect_header(ConnectHeader::Detach(DetachRequest { - sessions: vec![name.clone()], - })) - .context("writing detach request header")?; - let detach_reply: DetachReply = - client.read_reply().context("reading reply")?; - if !detach_reply.not_found_sessions.is_empty() { - warn!("could not find session '{}' to detach it", name); - } + let attach = + Attach { config_manager, session_name_tmpl, force, background, ttl, cmd, dir, socket }; - detached = true; - } - thread::sleep(time::Duration::from_millis(100)); + attach.run() +} - if tries > MAX_FORCE_RETRIES { - eprintln!("session '{name}' already has a terminal which remains attached even after attempting to detach it"); - return Err(anyhow!("could not detach session, forced attach failed")); - } - tries += 1; - } - Err(err) => return Err(err), - }, - } - }; +struct Attach { + config_manager: config::Manager, + session_name_tmpl: template::Template, + force: bool, + background: bool, + ttl: Option, + cmd: Option, + dir: Option, + socket: PathBuf, +} - if background { - // Close the attached connection first so the daemon can observe EOF. - // We still send an explicit Detach on a fresh connection as a best-effort - // fallback in case EOF processing is delayed. 
- drop(attach_client); - let mut client = dial_client(&socket, true)?; - client - .write_connect_header(ConnectHeader::Detach(DetachRequest { - sessions: vec![name.clone()], - })) - .context("writing detach request header")?; - let detach_reply: DetachReply = client.read_reply().context("reading reply")?; - if !detach_reply.not_found_sessions.is_empty() { - warn!("could not find session '{}' to detach it", name); - } - if !detach_reply.not_attached_sessions.is_empty() { - debug!( - "session '{}' was already detached while processing background detach request (expected)", - name - ); +impl Attach { + fn run(self) -> anyhow::Result<()> { + // This is the first time we dial the daemon, so we do want to show + // warnings. After this we shouldn't show them again. + let mut client = self.dial_client(false).context("dialing daemon")?; + info!("dialed initial conn for GetVars"); + + client.write_connect_header(ConnectHeader::GetVars).context("getting vars")?; + let mut maybe_switch: MaybeSwitch = client.read_reply().context("reading reply")?; + + let var_map = maybe_switch.vars.iter().cloned().collect(); + let mut resolved_name = self.session_name_tmpl.apply(&var_map); + + let sig_handler_session_name_slot = if !self.background { + Some(SignalHandler::new(resolved_name.clone(), self.socket.clone()).spawn()?) 
+ } else { + None + }; + + info!("looping on attach_with_name"); + loop { + match self.attach_with_name(resolved_name) { + Ok(AttachResult::Done) => return Ok(()), + Ok(AttachResult::Switch(s)) => maybe_switch = s, + Err(e) => return Err(e), + } + + let var_map = maybe_switch.vars.iter().cloned().collect(); + resolved_name = self.session_name_tmpl.apply(&var_map); + + if let Some(ref slot) = sig_handler_session_name_slot { + let mut slot = slot.lock().unwrap(); + *slot = resolved_name.clone(); + } } - return Ok(()); } - match attach_client.pipe_bytes() { - Ok(exit_status) => std::process::exit(exit_status), - Err(e) => Err(e), - } -} + /// Attach with the given resolved name. This will run until exit or until + /// we need to reconnect due to a session switch. + pub fn attach_with_name(&self, resolved_name: String) -> anyhow::Result { + if resolved_name.is_empty() { + eprintln!("blank session names are not allowed"); + return Ok(AttachResult::Done); + } + if resolved_name.contains(char::is_whitespace) { + eprintln!("session name '{}' may not have whitespace", resolved_name); + return Ok(AttachResult::Done); + } -#[derive(Debug)] -struct BusyError; -impl fmt::Display for BusyError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "BusyError") - } -} -impl std::error::Error for BusyError {} + let mut detached = false; + let mut tries = 0; + let attach_client = loop { + match self.dial_attach(resolved_name.as_str()) { + Ok(client) => break client, + Err(err) => match err.downcast() { + Ok(BusyError) if !self.force => { + eprintln!("session '{resolved_name}' already has a terminal attached"); + return Ok(AttachResult::Done); + } + Ok(BusyError) => { + if !detached { + let mut client = self.dial_client(true)?; + client + .write_connect_header(ConnectHeader::Detach(DetachRequest { + sessions: vec![resolved_name.clone()], + })) + .context("writing detach request header")?; + let detach_reply: DetachReply = + client.read_reply().context("reading reply")?; + if 
!detach_reply.not_found_sessions.is_empty() { + warn!("could not find session '{}' to detach it", resolved_name); + } + + detached = true; + } + thread::sleep(time::Duration::from_millis(100)); -/// Attach to a session and return the connected client without piping stdio. -/// -/// `background` is forwarded to `dial_client` to suppress the interactive -/// version-mismatch prompt on stdin; no other behavior changes. -fn do_attach( - config: &config::Manager, - name: &str, - background: bool, - ttl: &Option, - cmd: &Option, - dir: &Option, - socket: &PathBuf, -) -> anyhow::Result { - let mut client = dial_client(socket, background)?; - - let tty_size = match TtySize::from_fd(0) { - Ok(s) => s, - Err(e) => { - warn!("stdin is not a tty, using default size (err: {e:?})"); - TtySize { rows: 24, cols: 80, xpixel: 0, ypixel: 0 } + if tries > MAX_FORCE_RETRIES { + eprintln!("session '{resolved_name}' already has a terminal which remains attached even after attempting to detach it"); + return Err(anyhow!("could not detach session, forced attach failed")); + } + tries += 1; + } + Err(err) => return Err(err), + }, + } + }; + info!("got attach client"); + + if self.background { + // Close the attached connection first so the daemon can observe EOF. + // We still send an explicit Detach on a fresh connection as a best-effort + // fallback in case EOF processing is delayed. 
+ drop(attach_client); + let mut client = self.dial_client(true)?; + client + .write_connect_header(ConnectHeader::Detach(DetachRequest { + sessions: vec![resolved_name.clone()], + })) + .context("writing detach request header")?; + let detach_reply: DetachReply = client.read_reply().context("reading reply")?; + if !detach_reply.not_found_sessions.is_empty() { + warn!("could not find session '{}' to detach it", resolved_name); + } + if !detach_reply.not_attached_sessions.is_empty() { + debug!( + "session '{}' was already detached while processing background detach request (expected)", + resolved_name + ); + } + return Ok(AttachResult::Done); } - }; - let forward_env = config.get().forward_env.clone(); - let mut local_env_keys = vec!["TERM", "DISPLAY", "LANG", "SSH_AUTH_SOCK"]; - if let Some(fenv) = &forward_env { - for var in fenv.iter() { - local_env_keys.push(var); + info!("entering bidi streaming mode"); + let session_name_tmpl = self.session_name_tmpl.clone(); + match attach_client.pipe_bytes(move |maybe_switch: &MaybeSwitch| { + let var_map: HashMap = maybe_switch.vars.iter().cloned().collect(); + session_name_tmpl.apply(&var_map) != resolved_name + }) { + Ok(PipeBytesResult::Exit(exit_status)) => std::process::exit(exit_status), + Ok(PipeBytesResult::MaybeSwitch(s)) => Ok(AttachResult::Switch(s)), + Err(e) => Err(e), } } - info!("local env keys: {local_env_keys:?}"); - - let cwd = String::from(env::current_dir().context("getting cwd")?.to_string_lossy()); - let default_dir = config.get().default_dir.clone().unwrap_or(String::from("$HOME")); - let start_dir = match (default_dir.as_str(), dir.as_deref()) { - (".", None) => Some(cwd), - ("$HOME", None) => None, - (d, None) => Some(String::from(d)), - (_, Some(".")) => Some(cwd), - (_, Some(d)) => Some(String::from(d)), - }; - client - .write_connect_header(ConnectHeader::Attach(AttachHeader { - name: String::from(name), - local_tty_size: tty_size, - local_env: local_env_keys - .into_iter() - .filter_map(|var| { 
- let val = env::var(var).context("resolving var").ok()?; - Some((String::from(var), val)) - }) - .collect::>(), - ttl_secs: ttl.map(|d| d.as_secs()), - cmd: cmd.clone(), - dir: start_dir, - })) - .context("writing attach header")?; - - let attach_resp: AttachReplyHeader = client.read_reply().context("reading attach reply")?; - info!("attach_resp.status={:?}", attach_resp.status); - - { - use shpool_protocol::AttachStatus::*; - match attach_resp.status { - Busy => { - return Err(BusyError.into()); + /// Attach to a session and return the connected client without piping + /// stdio. + fn dial_attach(&self, name: &str) -> anyhow::Result { + let mut client = self.dial_client(true)?; + + let tty_size = match TtySize::from_fd(0) { + Ok(s) => s, + Err(e) => { + warn!("stdin is not a tty, using default size (err: {e:?})"); + TtySize { rows: 24, cols: 80, xpixel: 0, ypixel: 0 } } - Forbidden(reason) => { - eprintln!("forbidden: {reason}"); - return Err(anyhow!("forbidden: {reason}")); + }; + + let forward_env = self.config_manager.get().forward_env.clone(); + let mut local_env_keys = vec!["TERM", "DISPLAY", "LANG", "SSH_AUTH_SOCK"]; + if let Some(fenv) = &forward_env { + for var in fenv.iter() { + local_env_keys.push(var); } - Attached { warnings } => { - for warning in warnings.into_iter() { - eprintln!("shpool: warn: {warning}"); + } + info!("local env keys: {local_env_keys:?}"); + + let cwd = String::from(env::current_dir().context("getting cwd")?.to_string_lossy()); + let default_dir = + self.config_manager.get().default_dir.clone().unwrap_or(String::from("$HOME")); + let start_dir = match (default_dir.as_str(), self.dir.as_deref()) { + (".", None) => Some(cwd), + ("$HOME", None) => None, + (d, None) => Some(String::from(d)), + (_, Some(".")) => Some(cwd), + (_, Some(d)) => Some(String::from(d)), + }; + + client + .write_connect_header(ConnectHeader::Attach(AttachHeader { + name: String::from(name), + local_tty_size: tty_size, + local_env: local_env_keys + .into_iter() 
+ .filter_map(|var| { + let val = env::var(var).context("resolving var").ok()?; + Some((String::from(var), val)) + }) + .collect::>(), + ttl_secs: self.ttl.map(|d| d.as_secs()), + cmd: self.cmd.clone(), + dir: start_dir, + })) + .context("writing attach header")?; + + let attach_resp: AttachReplyHeader = client.read_reply().context("reading attach reply")?; + info!("attach_resp.status={:?}", attach_resp.status); + + { + use shpool_protocol::AttachStatus::*; + match attach_resp.status { + Busy => { + return Err(BusyError.into()); } - info!("attached to an existing session: '{}'", name); - } - Created { warnings } => { - for warning in warnings.into_iter() { - eprintln!("shpool: warn: {warning}"); + Forbidden(reason) => { + eprintln!("forbidden: {reason}"); + return Err(anyhow!("forbidden: {reason}")); + } + Attached { warnings } => { + for warning in warnings.into_iter() { + eprintln!("shpool: warn: {warning}"); + } + info!("attached to an existing session: '{}'", name); + } + Created { warnings } => { + for warning in warnings.into_iter() { + eprintln!("shpool: warn: {warning}"); + } + info!("created a new session: '{}'", name); + } + UnexpectedError(err) => { + return Err(anyhow!("BUG: unexpected error attaching to '{}': {}", name, err)); } - info!("created a new session: '{}'", name); - } - UnexpectedError(err) => { - return Err(anyhow!("BUG: unexpected error attaching to '{}': {}", name, err)); } } + + Ok(client) } - Ok(client) -} + // Dial the daemon. If silent is true, don't attempt to warn the user. + // After the first dial, silent should always be true. 
+ fn dial_client(&self, silent: bool) -> anyhow::Result { + match protocol::Client::new(&self.socket) { + Ok(ClientResult::JustClient(c)) => Ok(c), + Ok(ClientResult::VersionMismatch { warning, client }) => { + if silent { + return Ok(client); + } -fn dial_client(socket: &PathBuf, background: bool) -> anyhow::Result { - match protocol::Client::new(socket) { - Ok(ClientResult::JustClient(c)) => Ok(c), - Ok(ClientResult::VersionMismatch { warning, client }) => { - if background { - eprintln!( - "warning: {warning}, proceeding in background mode; try restarting your daemon" - ); - } else { - eprintln!("warning: {warning}, try restarting your daemon"); - eprintln!("hit enter to continue anyway or ^C to exit"); - - let _ = io::stdin() - .lines() - .next() - .context("waiting for a continue through a version mismatch")?; - } + if self.background { + eprintln!( + "warning: {warning}, proceeding in background mode; try restarting your daemon" + ); + } else { + eprintln!("warning: {warning}, try restarting your daemon"); + eprintln!("hit enter to continue anyway or ^C to exit"); + + let mut buf = [0u8; 1]; + loop { + match unistd::read(io::stdin().as_fd(), &mut buf) { + Ok(0) => break, + Ok(1) if buf[0] == b'\n' => break, + Ok(_) => continue, + Err(nix::errno::Errno::EINTR) => continue, + Err(e) => { + return Err(anyhow::Error::from(e)) + .context("waiting for a continue through a version mismatch") + } + } + } + info!("user continued through version mismatch"); + } - Ok(client) - } - Err(err) => { - let io_err = err.downcast::()?; - if io_err.kind() == io::ErrorKind::NotFound { - eprintln!("could not connect to daemon"); + Ok(client) + } + Err(err) => { + let io_err = err.downcast::()?; + if io_err.kind() == io::ErrorKind::NotFound { + eprintln!("could not connect to daemon"); + } + Err(io_err).context("connecting to daemon") } - Err(io_err).context("connecting to daemon") } } } +#[derive(Debug)] +struct BusyError; +impl fmt::Display for BusyError { + fn fmt(&self, f: &mut 
fmt::Formatter<'_>) -> fmt::Result { + write!(f, "BusyError") + } +} +impl std::error::Error for BusyError {} + +enum AttachResult { + Done, + Switch(MaybeSwitch), +} + // // Signal Handling // struct SignalHandler { - session_name: String, + session_name: Arc>, socket: PathBuf, } impl SignalHandler { fn new(session_name: String, socket: PathBuf) -> Self { - SignalHandler { session_name, socket } + SignalHandler { session_name: Arc::new(Mutex::new(session_name)), socket } } - fn spawn(self) -> anyhow::Result<()> { + fn spawn(self) -> anyhow::Result>> { use signal_hook::{consts::*, iterator::*}; + let session_name_slot = Arc::clone(&self.session_name); + let sigs = vec![SIGWINCH]; let mut signals = Signals::new(sigs).context("creating signal iterator")?; @@ -299,7 +389,7 @@ impl SignalHandler { } }); - Ok(()) + Ok(session_name_slot) } fn handle_sigwinch(&self) -> anyhow::Result<()> { @@ -318,7 +408,7 @@ impl SignalHandler { // write the request on a new, seperate connection client .write_connect_header(ConnectHeader::SessionMessage(SessionMessageRequest { - session_name: self.session_name.clone(), + session_name: self.get_session_name(), payload: SessionMessageRequestPayload::Resize(ResizeRequest { tty_size: tty_size.clone(), }), @@ -331,11 +421,15 @@ impl SignalHandler { SessionMessageReply::NotFound => { warn!( "handle_sigwinch: sent resize for session '{}', but the daemon has no record of that session", - self.session_name + self.get_session_name() ); } SessionMessageReply::Resize(ResizeReply::Ok) => { - info!("handle_sigwinch: resized session '{}' to {:?}", self.session_name, tty_size); + info!( + "handle_sigwinch: resized session '{}' to {:?}", + self.get_session_name(), + tty_size + ); } reply => { warn!("handle_sigwinch: unexpected resize reply: {:?}", reply); @@ -344,4 +438,9 @@ impl SignalHandler { Ok(()) } + + fn get_session_name(&self) -> String { + let session_name = self.session_name.lock().unwrap(); + session_name.clone() + } } diff --git 
a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index adbeb622..5eb16e70 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -35,9 +35,10 @@ use anyhow::{anyhow, Context}; use nix::unistd; use shpool_protocol::{ AttachHeader, AttachReplyHeader, AttachStatus, ConnectHeader, DetachReply, DetachRequest, - KillReply, KillRequest, ListReply, LogLevel, ResizeReply, Session, SessionMessageDetachReply, - SessionMessageReply, SessionMessageRequest, SessionMessageRequestPayload, SessionStatus, - SetLogLevelReply, SetLogLevelRequest, VersionHeader, + KillReply, KillRequest, ListReply, LogLevel, MaybeSwitch, ModifyVarReply, ModifyVarRequest, + ResizeReply, Session, SessionMessageDetachReply, SessionMessageReply, SessionMessageRequest, + SessionMessageRequestPayload, SessionStatus, SetLogLevelReply, SetLogLevelRequest, + VersionHeader, }; use tracing::{debug, error, info, instrument, span, warn, Level}; @@ -79,6 +80,7 @@ pub struct Server { tracing_subscriber::filter::LevelFilter, tracing_subscriber::registry::Registry, >, + vars: Mutex>, } impl Server { @@ -112,6 +114,7 @@ impl Server { hooks, daily_messenger, log_level_handle, + vars: HashMap::new().into(), })) } @@ -209,6 +212,8 @@ impl Server { ConnectHeader::List => self.handle_list(stream), ConnectHeader::SessionMessage(header) => self.handle_session_message(stream, header), ConnectHeader::SetLogLevel(r) => self.handle_set_log_level(stream, r), + ConnectHeader::GetVars => self.handle_get_vars(stream), + ConnectHeader::ModifyVar(r) => self.handle_modify_var(stream, r), } } @@ -609,6 +614,57 @@ impl Server { Ok(()) } + #[instrument(skip_all)] + fn handle_get_vars(&self, mut stream: UnixStream) -> anyhow::Result<()> { + let maybe_switch = { + let var_map = self.vars.lock().unwrap(); + let vars: Vec<(String, String)> = + var_map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + shpool_protocol::MaybeSwitch { switch_to: None, vars } + }; + + write_reply(&mut stream, 
maybe_switch).context("writing maybe_switch reply")?; + Ok(()) + } + + #[instrument(skip_all)] + fn handle_modify_var( + &self, + mut stream: UnixStream, + request: ModifyVarRequest, + ) -> anyhow::Result<()> { + let maybe_switch = { + let mut vars = self.vars.lock().unwrap(); + if let Some(val) = request.val { + vars.insert(request.var, val); + } else { + vars.remove(&request.var); + } + + MaybeSwitch { + switch_to: None, + vars: vars.iter().map(|(k, v)| (k.clone(), v.clone())).collect(), + } + }; + + let mut ctls = Vec::new(); + { + let shells = self.shells.lock().unwrap(); + for (_, session) in shells.iter() { + ctls.push(Arc::clone(&session.shell_to_client_ctl)); + } + } + for ctl in ctls.into_iter() { + let ctl = ctl.lock().unwrap(); + ctl.maybe_switch + .send_timeout(maybe_switch.clone(), SESSION_MSG_TIMEOUT) + .context("broadcasting maybe_switch")?; + } + + write_reply(&mut stream, ModifyVarReply {}).context("writing modify var reply")?; + Ok(()) + } + #[instrument(skip_all)] fn handle_kill(&self, mut stream: UnixStream, request: KillRequest) -> anyhow::Result<()> { let mut not_found_sessions = vec![]; @@ -983,13 +1039,18 @@ impl Server { let (heartbeat_tx, heartbeat_rx) = crossbeam_channel::bounded(0); let (heartbeat_ack_tx, heartbeat_ack_rx) = crossbeam_channel::bounded(0); - let shell_to_client_ctl = Arc::new(Mutex::new(shell::ReaderCtl { + // We make this buffered to avoid blocking during a broadcast. There is + // no ack chan so we can afford to buffer a bit. 
+ let (maybe_switch_tx, maybe_switch_rx) = crossbeam_channel::bounded(10); + + let shell_to_client_ctl = Arc::new(Mutex::new(shell::ShellToClientCtl { client_connection: client_connection_tx, client_connection_ack: client_connection_ack_rx, tty_size_change: tty_size_change_tx, tty_size_change_ack: tty_size_change_ack_rx, heartbeat: heartbeat_tx, heartbeat_ack: heartbeat_ack_rx, + maybe_switch: maybe_switch_tx, })); let mut session_inner = shell::SessionInner { @@ -1023,6 +1084,7 @@ impl Server { tty_size_change_ack: tty_size_change_ack_tx, heartbeat: heartbeat_rx, heartbeat_ack: heartbeat_ack_tx, + maybe_switch: maybe_switch_rx, child_exit_notifier: shell_to_client_child_exit_notifier, })?); diff --git a/libshpool/src/daemon/shell.rs b/libshpool/src/daemon/shell.rs index ee74bc2c..b7aa9ba2 100644 --- a/libshpool/src/daemon/shell.rs +++ b/libshpool/src/daemon/shell.rs @@ -28,12 +28,13 @@ use std::{ use anyhow::{anyhow, Context}; use nix::{poll, poll::PollFlags, sys::signal, unistd::Pid}; -use shpool_protocol::{Chunk, ChunkKind, TtySize}; +use shpool_protocol::{Chunk, ChunkKind, MaybeSwitch, TtySize}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use crate::{ common, consts, daemon::{config, exit_notify::ExitNotifier, keybindings, pager::PagerCtl, prompt, show_motd}, + protocol, protocol::ChunkExt as _, session_restore, test_hooks, tty::TtySizeExt as _, @@ -74,7 +75,7 @@ pub struct Session { pub lifecycle_timestamps: Mutex, pub child_pid: libc::pid_t, pub child_exit_notifier: Arc, - pub shell_to_client_ctl: Arc>, + pub shell_to_client_ctl: Arc>, pub pager_ctl: Arc>>, /// Mutable state with the lock held by the servicing handle_attach thread /// while a tty is attached to the session. 
Probing the mutex can be used @@ -109,7 +110,7 @@ impl Session { #[derive(Debug)] pub struct SessionInner { pub name: String, // to improve logging - pub shell_to_client_ctl: Arc>, + pub shell_to_client_ctl: Arc>, pub pty_master: shpool_pty::fork::Fork, pub client_stream: Option, pub config: config::Manager, @@ -207,6 +208,7 @@ pub struct ShellToClientArgs { pub tty_size_change: crossbeam_channel::Receiver, pub tty_size_change_ack: crossbeam_channel::Sender<()>, pub heartbeat: crossbeam_channel::Receiver<()>, + pub maybe_switch: crossbeam_channel::Receiver, // true if the client is still live, false if it has hung up on us pub heartbeat_ack: crossbeam_channel::Sender, pub child_exit_notifier: Arc, @@ -393,6 +395,42 @@ impl SessionInner { args.heartbeat_ack.send(client_present) .context("sending heartbeat ack")?; } + recv(args.maybe_switch) -> maybe_switch => { + let maybe_switch = match maybe_switch { + Ok(ms) => ms, + Err(e) => { + error!("error recving MaybeSwitch: {:?}", e); + continue; + }, + }; + + let conn = if let ClientConnectionMsg::New(c) = &mut client_conn { + c + } else { + info!("got MaybeSwitch, but no attached client, dropping"); + continue; + }; + + let mut encoded = Vec::new(); + if let Err(e) = protocol::encode_to(&maybe_switch, &mut encoded) { + error!("error encoding MaybeSwitch: {:?}", e); + continue; + } + + let chunk = Chunk { kind: ChunkKind::MaybeSwitch, buf: &encoded[..] 
}; + match chunk.write_to(&mut conn.sink).and_then(|_| conn.sink.flush()) { + Ok(_) => { + trace!("wrote MaybeSwitch"); + } + Err(e) if e.kind() == io::ErrorKind::BrokenPipe => { + trace!("writing MaybeSwitch: client hangup: {:?}", e); + } + Err(e) => { + error!("unexpected IO error while writing heartbeat: {}", e); + return Err(e).context("writing MaybeSwitch")?; + } + } + } // make this select non-blocking so we spend most of our time parked // in poll @@ -1003,7 +1041,7 @@ impl SessionInner { /// Shared between the session struct (for calls originating with the cli) /// and the session inner struct (for calls resulting from keybindings). #[derive(Debug)] -pub struct ReaderCtl { +pub struct ShellToClientCtl { /// A control channel for the shell->client thread. Whenever a new client /// dials in, the output stream for that client must be attached to the /// shell->client thread by sending it down this channel. A disconnect @@ -1028,6 +1066,12 @@ pub struct ReaderCtl { // True if the client is still listening, false if it has hung up // on us. pub heartbeat_ack: crossbeam_channel::Receiver, + + /// A control channel telling the shell->client thread to + /// broadcast the given MaybeSwitch. There is no ack channel + /// because we just blast this out and the caller doesn't need + /// to know about completion. + pub maybe_switch: crossbeam_channel::Sender, } /// Given a buffer, a length after which the data is not valid, a list of diff --git a/libshpool/src/lib.rs b/libshpool/src/lib.rs index 528fcd62..82e16a4c 100644 --- a/libshpool/src/lib.rs +++ b/libshpool/src/lib.rs @@ -43,9 +43,11 @@ mod list; mod protocol; mod session_restore; mod set_log_level; +mod template; mod test_hooks; mod tty; mod user; +mod var; /// The command line arguments that shpool expects. 
/// These can be directly parsed with clap or manually @@ -210,6 +212,54 @@ needs debugging, but would be clobbered by a restart.")] #[clap(help = "new log level")] level: shpool_protocol::LogLevel, }, + + #[clap(about = "Manipulate template variables + +shpool session names can include {variables} which are resolved via +an environment stored globally in the shpool daemon. This command +manipulates that environment. + +The main usecase for templated session names is the ability to switch +multiple shpool sessions to new targets at the same time. For example, +you might have a `shpool attach -f '{workspace}-edit'` session and +a `shpool attach -f '{workspace}-term'` session. To switch both +sessions from the fun-feature workspace to the key-bugfix workspace, +you could just do `shpool var set workspace key-bugfix`. +")] + #[non_exhaustive] + Var { + #[clap(subcommand)] + command: VarCommands, + }, +} + +/// The subcommands of the var command. +#[derive(Subcommand, Debug)] +#[non_exhaustive] +pub enum VarCommands { + #[clap(about = "List the variables + +This command dumps out the whole variable list with +both vars and values in a JSON object using vars as keys.")] + List { + #[clap(short, long, help = "Output as JSON")] + json: bool, + }, + #[clap(about = "Get a variable + +This returns the raw value of the given variable.")] + #[non_exhaustive] + Get { var: String }, + #[clap(about = "Set a variable + +This updates the value of the given variable.")] + #[non_exhaustive] + Set { var: String, val: String }, + #[clap(about = "Unset a variable + +This removes the given variable from the environment.")] + #[non_exhaustive] + Unset { var: String }, } impl Args { @@ -382,6 +432,7 @@ pub fn run(args: Args, hooks: Option>) -> an Commands::Kill { sessions } => kill::run(sessions, socket), Commands::List { json } => list::run(socket, json), Commands::SetLogLevel { level } => set_log_level::run(level, socket), + Commands::Var { command } => var::run(socket, command), }; if let 
Err(err) = res { diff --git a/libshpool/src/protocol.rs b/libshpool/src/protocol.rs index e4f64a4a..109a2e14 100644 --- a/libshpool/src/protocol.rs +++ b/libshpool/src/protocol.rs @@ -15,27 +15,29 @@ use std::{ cmp, io::{self, Read, Write}, - os::unix::net::UnixStream, + os::{fd::AsFd, unix::net::UnixStream}, path::Path, - sync::atomic::{AtomicI32, Ordering}, + sync::Mutex, thread, time, }; use anyhow::{anyhow, Context}; use byteorder::{LittleEndian, ReadBytesExt as _, WriteBytesExt as _}; +use nix::poll; use serde::{Deserialize, Serialize}; use shpool_protocol::{Chunk, ChunkKind, ConnectHeader, VersionHeader}; use tracing::{debug, error, info, instrument, span, trace, warn, Level}; use super::{common, consts, tty}; -const DETACH_DISCONNECT_FAST_WAIT_DUR: time::Duration = time::Duration::from_millis(10); const MAX_DETACH_WAIT_DUR: time::Duration = time::Duration::from_millis(300); const DETACH_BACKOFF_INITIAL_DUR: time::Duration = time::Duration::from_millis(1); // Cap backoff steps so slow-path stays responsive while still avoiding busy // waits. const DETACH_BACKOFF_MAX_STEP_DUR: time::Duration = time::Duration::from_millis(25); +const STDIN_READ_POLL_MS: u16 = 50; + /// The centralized encoding function that should be used for all protocol /// serialization. pub fn encode_to(d: &T, w: W) -> anyhow::Result<()> @@ -241,16 +243,23 @@ impl Client { /// socket and back again. It is the main loop of /// `shpool attach`. /// - /// Return value: the exit status that `shpool attach` should - /// exit with. + /// The on_maybe_switch callback should return true if pipe_bytes + /// should exit with a PipeBytesResult::MaybeSwitch because the + /// attach process needs to reattach. 
#[instrument(skip_all)] - pub fn pipe_bytes(self) -> anyhow::Result { + pub fn pipe_bytes( + self, + on_maybe_switch: OnMaybeSwitchF, + ) -> anyhow::Result + where + OnMaybeSwitchF: Fn(&shpool_protocol::MaybeSwitch) -> bool + Send + Sync + 'static, + { let tty_guard = tty::set_attach_flags()?; let mut read_client_stream = self.stream.try_clone().context("cloning read stream")?; let mut write_client_stream = self.stream.try_clone().context("cloning read stream")?; - let exit_status = AtomicI32::new(1); + let result_slot = Mutex::new(None); thread::scope(|s| { // stdin -> sock let stdin_to_sock_h = s.spawn(|| -> anyhow::Result<()> { @@ -259,6 +268,24 @@ impl Client { let mut buf = vec![0; consts::BUF_SIZE]; loop { + { + let res = result_slot.lock().unwrap(); + if res.is_some() { + return Ok(()); + } + } + + { + let mut poll_fds = + [poll::PollFd::new(stdin.as_fd(), poll::PollFlags::POLLIN)]; + let nready = poll::poll(&mut poll_fds, STDIN_READ_POLL_MS) + .context("polling stdin")?; + if nready == 0 { + // timeout + continue; + } + } + let nread = stdin.read(&mut buf).context("reading stdin from user")?; if nread == 0 { return Ok(()); @@ -281,6 +308,26 @@ impl Client { let mut buf = vec![0; consts::BUF_SIZE]; loop { + { + let res = result_slot.lock().unwrap(); + if res.is_some() { + return Ok(()); + } + } + + { + let mut poll_fds = [poll::PollFd::new( + read_client_stream.as_fd(), + poll::PollFlags::POLLIN, + )]; + let nready = poll::poll(&mut poll_fds, STDIN_READ_POLL_MS) + .context("polling stdin")?; + if nready == 0 { + // timeout + continue; + } + } + let chunk = match Chunk::read_into(&mut read_client_stream, &mut buf) { Ok(c) => c, Err(err) => { @@ -323,46 +370,48 @@ impl Client { .read_i32::() .context("reading exit status from exit status chunk")?; info!("got exit status frame (status={})", stat); - exit_status.store(stat, Ordering::Release); + { + let mut res = result_slot.lock().unwrap(); + *res = Some(PipeBytesResult::Exit(stat)); + } + } + 
ChunkKind::MaybeSwitch => { + let maybe_switch_reader = io::Cursor::new(chunk.buf); + let maybe_switch: shpool_protocol::MaybeSwitch = + decode_from(maybe_switch_reader).context("decoding vars list")?; + + info!("got vars update (maybe_switch={:?})", maybe_switch); + if on_maybe_switch(&maybe_switch) { + let mut res = result_slot.lock().unwrap(); + *res = Some(PipeBytesResult::MaybeSwitch(maybe_switch)); + } } } } }); loop { - let mut nfinished_threads = 0; - if stdin_to_sock_h.is_finished() { - nfinished_threads += 1; - } - if sock_to_stdout_h.is_finished() { - nfinished_threads += 1; - } + let mut nfinished_threads = (stdin_to_sock_h.is_finished() as usize) + + (sock_to_stdout_h.is_finished() as usize); if nfinished_threads > 0 { - if nfinished_threads < 2 { - // Fast-path: when sock->stdout already ended (detach/disconnect), - // stdin->sock can stay blocked on stdin. In that case, do a very - // short grace wait and then exit quickly. This is independent - // of stdin being a TTY or a pipe. - // Slow-path: for other shutdown orders, keep compatibility by - // waiting up to 300ms with backoff. - let mut stdin_done = stdin_to_sock_h.is_finished(); - let mut stdout_done = sock_to_stdout_h.is_finished(); - - // Keep max_wait fixed for this detach sequence. Recomputing it inside - // the loop could accidentally switch paths mid-cleanup. - let max_wait = if stdout_done && !stdin_done { - DETACH_DISCONNECT_FAST_WAIT_DUR - } else { - MAX_DETACH_WAIT_DUR - }; + // If one of the threads has exited, but not the other + // make sure that the exit result slot has some contents + // so the other thread will exit the next time it wakes + // from its poll(). 
+ { + let mut res = result_slot.lock().unwrap(); + if res.is_none() { + *res = Some(PipeBytesResult::Exit(1)); + } + } + if nfinished_threads < 2 { let finished_waiting = common::sleep_unless( - max_wait, + MAX_DETACH_WAIT_DUR, || { - stdin_done = stdin_to_sock_h.is_finished(); - stdout_done = sock_to_stdout_h.is_finished(); - nfinished_threads = (stdin_done as usize) + (stdout_done as usize); + nfinished_threads = (stdin_to_sock_h.is_finished() as usize) + + (sock_to_stdout_h.is_finished() as usize); nfinished_threads >= 2 }, common::PollStrategy::Backoff { @@ -375,12 +424,17 @@ impl Client { if !finished_waiting { // Re-probe after timeout because thread state can change // during the final sleep inside sleep_unless. - stdin_done = stdin_to_sock_h.is_finished(); - stdout_done = sock_to_stdout_h.is_finished(); - nfinished_threads = (stdin_done as usize) + (stdout_done as usize); + nfinished_threads = (stdin_to_sock_h.is_finished() as usize) + + (sock_to_stdout_h.is_finished() as usize); } if nfinished_threads < 2 { + // It should be impossible to get here because both + // loops use poll() to wake up every so often to + // check if they need to exit. Nevertheless, if + // we somehow still have a stuck thread at this + // point, we'll just exit. + // If one of the worker threads is done and the // other is not exiting, we are likely blocked on // some IO. Fortunately, since there isn't much else @@ -389,15 +443,20 @@ impl Client { // by just hard-exiting the whole process. This allows // us to use simple blocking IO. warn!( - "exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}", - stdin_done, - stdout_done + "internal error: exiting due to a stuck IO thread stdin_to_sock_finished={} sock_to_stdout_finished={}", + stdin_to_sock_h.is_finished(), + sock_to_stdout_h.is_finished(), ); // make sure that we restore the tty flags on the input // tty before exiting the process. 
drop(tty_guard); - std::process::exit(exit_status.load(Ordering::Acquire)); + let res = result_slot.lock().unwrap(); + if let Some(PipeBytesResult::Exit(stat)) = *res { + std::process::exit(stat); + } else { + std::process::exit(1); + } } } break; @@ -425,11 +484,29 @@ impl Client { } } - Ok(exit_status.load(Ordering::Acquire)) + let mut ret = PipeBytesResult::Exit(1); + { + let res = result_slot.lock().unwrap(); + if let Some(r) = res.clone() { + ret = r; + } + } + Ok(ret) }) } } +#[derive(Clone)] +pub enum PipeBytesResult { + /// The attach proc should exit with the given exit status. + Exit(i32), + /// The on_maybe_switch callback has requested the client stop streaming + /// due to some change that requires a reconnect (almost certainly a + /// changed var in the session name template). The attach proc + /// should recompute any relevant templates and re-attach. + MaybeSwitch(shpool_protocol::MaybeSwitch), +} + #[cfg(test)] mod test { use super::*; diff --git a/libshpool/src/template.rs b/libshpool/src/template.rs new file mode 100644 index 00000000..d1be900a --- /dev/null +++ b/libshpool/src/template.rs @@ -0,0 +1,206 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use anyhow::anyhow; + +// A guess at how large the values of variables will be on average. 
+// This is intended as a slight over-estimate as we use it to compute +// the buffer size we should pre-allocate for instantiation. +const VAR_SIZE_GUESS: usize = 40; + +/// A template is a simple variable substitution string template used +/// by the templated session name feature to allow automatic client +/// switching. +/// +/// The template syntax is that variable substitutions look like +/// `#{var_name}`, where var_name must be some alphanumeric string. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Template { + chunks: Vec, + instantiated_size_guess: usize, +} + +/// A chunk is either a raw hunk of text or a variable substitution. +#[derive(Debug, Clone, Eq, PartialEq)] +enum Chunk { + Raw(String), + Var(String), +} + +impl Template { + pub fn new(src: &str) -> anyhow::Result