diff --git a/aw-server/src/config.rs b/aw-server/src/config.rs index 37d59bbd..6e09bf42 100644 --- a/aw-server/src/config.rs +++ b/aw-server/src/config.rs @@ -1,5 +1,6 @@ -use std::fs::File; +use std::fs::{self, File}; use std::io::{Read, Write}; +use std::path::{Path, PathBuf}; use rocket::config::Config; use rocket::data::{Limits, ToByteUnit}; @@ -119,8 +120,11 @@ fn default_custom_static() -> std::collections::HashMap { std::collections::HashMap::new() } -pub fn create_config(testing: bool) -> AWConfig { - set_testing(testing); +fn get_config_path(testing: bool, config_override: Option<&Path>) -> PathBuf { + if let Some(config_path) = config_override { + return config_path.to_path_buf(); + } + let mut config_path = dirs::get_config_dir().unwrap(); if !testing { config_path.push("config.toml") @@ -128,6 +132,16 @@ pub fn create_config(testing: bool) -> AWConfig { config_path.push("config-testing.toml") } + config_path +} + +pub fn create_config(testing: bool, config_override: Option<&Path>) -> AWConfig { + set_testing(testing); + let config_path = get_config_path(testing, config_override); + if let Some(parent) = config_path.parent() { + fs::create_dir_all(parent).expect("Unable to create config dir"); + } + /* If there is no config file, create a new config file with default values but every value is * commented out by default in case we would change a default value at some point in the future */ if !config_path.is_file() { @@ -157,3 +171,62 @@ pub fn create_config(testing: bool) -> AWConfig { aw_config } + +#[cfg(test)] +mod tests { + use super::create_config; + use std::fs; + use std::path::PathBuf; + use std::sync::Mutex; + use uuid::Uuid; + + static TEST_LOCK: Mutex<()> = Mutex::new(()); + + struct TestConfigPath { + root: PathBuf, + config_path: PathBuf, + } + + impl TestConfigPath { + fn new(name: &str) -> Self { + let root = std::env::temp_dir() + .join("aw-server-config-tests") + .join(format!("{name}-{}", Uuid::new_v4())); + let config_path = root.join("config.toml"); + + Self { root, config_path } + } + } + + impl Drop for TestConfigPath { + fn drop(&mut self) { + let _ = fs::remove_dir_all(&self.root); + } + } + + #[test] + fn create_config_uses_override_path() { + // create_config mutates the TESTING global, so these tests must not overlap. + let _lock = TEST_LOCK.lock().unwrap(); + let paths = TestConfigPath::new("override"); + fs::create_dir_all(paths.config_path.parent().unwrap()).unwrap(); + fs::write(&paths.config_path, "address = \"0.0.0.0\"\nport = 5611\n").unwrap(); + + let config = create_config(false, Some(paths.config_path.as_path())); + + assert_eq!(config.address, "0.0.0.0"); + assert_eq!(config.port, 5611); + } + + #[test] + fn create_config_creates_missing_override_file() { + let _lock = TEST_LOCK.lock().unwrap(); + let paths = TestConfigPath::new("missing"); + + let config = create_config(false, Some(paths.config_path.as_path())); + + assert!(paths.config_path.is_file()); + assert_eq!(config.address, "127.0.0.1"); + assert_eq!(config.port, 5600); + } +} diff --git a/aw-server/src/main.rs b/aw-server/src/main.rs index 2cbf39e7..1d883d9b 100644 --- a/aw-server/src/main.rs +++ b/aw-server/src/main.rs @@ -42,6 +42,10 @@ struct Opts { #[clap(long)] dbpath: Option, + /// Path to config file override + #[clap(short = 'c', long = "config")] + config: Option, + /// Path to webui override #[clap(long)] webpath: Option, @@ -79,7 +83,7 @@ async fn main() -> Result<(), rocket::Error> { info!("Running server in Testing mode"); } - let mut config = config::create_config(testing); + let mut config = config::create_config(testing, opts.config.as_deref()); // set host if overridden if let Some(host) = opts.host {