diff --git a/libshpool/src/config.rs b/libshpool/src/config.rs index 3ee2203a..404a7b89 100644 --- a/libshpool/src/config.rs +++ b/libshpool/src/config.rs @@ -293,6 +293,12 @@ pub struct Config { /// See https://man7.org/linux/man-pages/man8/pam_motd.8.html /// for more info. pub motd_args: Option>, + + /// A list of default values for variables. Setting values + /// in this list is equivilant to running `shpool var set` for + /// each (var, value) tuple every time the shpool daemon starts + /// up. + pub var_default: Option>, } impl Config { @@ -327,6 +333,7 @@ impl Config { prompt_prefix: self.prompt_prefix.or(another.prompt_prefix), motd: self.motd.or(another.motd), motd_args: self.motd_args.or(another.motd_args), + var_default: self.var_default.or(another.var_default), } } } @@ -407,6 +414,14 @@ pub enum MotdDisplayMode { Dump, } +#[derive(Deserialize, Debug, Clone)] +pub struct VarSetting { + /// The variable name. + pub var: String, + /// The variable value. + pub value: String, +} + #[cfg(test)] mod test { use super::*; @@ -436,6 +451,11 @@ mod test { r#" session_restore_engine = "vterm" "#, + r#" + [[var_default]] + var = "foo" + value = "bar" + "#, ]; for case in cases.into_iter() { diff --git a/libshpool/src/daemon/server.rs b/libshpool/src/daemon/server.rs index 5c8d9cfc..98652140 100644 --- a/libshpool/src/daemon/server.rs +++ b/libshpool/src/daemon/server.rs @@ -106,6 +106,17 @@ impl Server { } }); + let vars = Mutex::new( + config + .get() + .var_default + .clone() + .unwrap_or(vec![]) + .into_iter() + .map(|v| (v.var, v.value)) + .collect(), + ); + let daily_messenger = Arc::new(show_motd::DailyMessenger::new(config.clone())?); Ok(Arc::new(Server { config, @@ -115,7 +126,7 @@ impl Server { hooks, daily_messenger, log_level_handle, - vars: HashMap::new().into(), + vars, })) } diff --git a/shpool/tests/data/var_default.toml b/shpool/tests/data/var_default.toml new file mode 100644 index 00000000..c9a78155 --- /dev/null +++ b/shpool/tests/data/var_default.toml @@ -0,0 +1,13 @@ +norc = true +noecho = true +shell = "/bin/bash" +session_restore_mode = "simple" +prompt_prefix = "" + +[env] +PS1 = "prompt> " +TERM = "" + +[[var_default]] +var = "default_foo" +value = "default_bar" diff --git a/shpool/tests/support/line_matcher.rs b/shpool/tests/support/line_matcher.rs index 6619c2a0..4e772d66 100644 --- a/shpool/tests/support/line_matcher.rs +++ b/shpool/tests/support/line_matcher.rs @@ -31,8 +31,8 @@ where pub fn scan_until_re(&mut self, re: &str) -> anyhow::Result<()> { let compiled_re = Regex::new(re)?; let start = time::Instant::now(); + let mut line = String::new(); loop { - let mut line = String::new(); match self.out.read_line(&mut line) { Ok(0) => { return Err(anyhow!("LineMatcher: EOF")); @@ -70,6 +70,7 @@ where return Ok(()); } else { eprintln!(" no match"); + line.clear(); } } } @@ -83,8 +84,8 @@ where pub fn capture_re(&mut self, re: &str) -> anyhow::Result>> { let start = time::Instant::now(); + let mut line = String::new(); loop { - let mut line = String::new(); match self.out.read_line(&mut line) { Ok(0) => { return Err(anyhow!("LineMatcher: EOF")); @@ -133,8 +134,8 @@ where /// assertions fail (the never match regex). pub fn drain(&mut self) -> anyhow::Result<()> { let start = time::Instant::now(); + let mut line = String::new(); loop { - let mut line = String::new(); match self.out.read_line(&mut line) { Ok(0) => { return Ok(()); @@ -165,6 +166,7 @@ where } self.check_persistant_assertions(&line)?; + line.clear(); } } diff --git a/shpool/tests/var.rs b/shpool/tests/var.rs index 915b8099..50f36879 100644 --- a/shpool/tests/var.rs +++ b/shpool/tests/var.rs @@ -135,3 +135,21 @@ fn no_daemon() -> anyhow::Result<()> { Ok(()) } + +#[test] +#[timeout(30000)] +fn default_vars() -> anyhow::Result<()> { + let mut daemon_proc = support::daemon::Proc::new( + "var_default.toml", + DaemonArgs { listen_events: false, ..DaemonArgs::default() }, + ) + .context("starting daemon proc")?; + + let out = daemon_proc.var_get("default_foo")?; + assert!(out.status.success(), "var get proc did not exit successfully"); + + let stdout = String::from_utf8_lossy(&out.stdout[..]); + assert_eq!(stdout.trim(), "default_bar"); + + Ok(()) +}