diff --git a/Cargo.lock b/Cargo.lock index ec1f3551..5b602a4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2762,6 +2762,7 @@ dependencies = [ "indexmap", "log", "memchr", + "native-tls", "once_cell", "percent-encoding", "serde", diff --git a/Cargo.toml b/Cargo.toml index 67119a08..cd0c07b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ sentry_protos = "0.4.11" serde = "1.0.214" serde_yaml = "0.9.34" sha2 = "0.10.8" -sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "postgres"] } +sqlx = { version = "0.8.3", features = ["sqlite", "runtime-tokio", "chrono", "postgres", "tls-native-tls"] } tokio = { version = "1.43.1", features = ["full"] } tokio-stream = { version = "0.1.16", features = ["full"] } tokio-util = "0.7.12" diff --git a/src/store/postgres_activation_store.rs b/src/store/postgres_activation_store.rs index fdfc493b..4d148e2a 100644 --- a/src/store/postgres_activation_store.rs +++ b/src/store/postgres_activation_store.rs @@ -20,28 +20,74 @@ pub async fn create_postgres_pool( url: &str, database_name: &str, ) -> Result<(Pool, Pool), Error> { - let conn_str = url.to_owned() + "/" + database_name; + let conn_opts = build_pg_connect_options(url, database_name)?; let read_pool = PgPoolOptions::new() .max_connections(64) - .connect_with(PgConnectOptions::from_str(&conn_str)?) + .connect_with(conn_opts.clone()) .await?; let write_pool = PgPoolOptions::new() .max_connections(64) - .connect_with(PgConnectOptions::from_str(&conn_str)?) + .connect_with(conn_opts) .await?; Ok((read_pool, write_pool)) } pub async fn create_default_postgres_pool(url: &str) -> Result, Error> { - let conn_str = url.to_owned() + "/postgres"; + let conn_opts = build_pg_connect_options(url, "postgres")?; let default_pool = PgPoolOptions::new() .max_connections(64) - .connect_with(PgConnectOptions::from_str(&conn_str)?) + .connect_with(conn_opts) .await?; Ok(default_pool) } +fn build_pg_connect_options(url: &str, database_name: &str) -> Result { + Ok(PgConnectOptions::from_str(url)?.database(database_name)) +} + +#[cfg(test)] +mod tests { + use super::build_pg_connect_options; + use sqlx::postgres::PgSslMode; + + #[test] + fn test_connect_opts_plain_url() { + let pg_url = "postgresql://user:pass@localhost:5432"; + let custom_db_name = "my-custom-db"; + let opts = build_pg_connect_options(pg_url, custom_db_name).unwrap(); + assert_eq!(opts.get_database(), Some(custom_db_name)); + assert_eq!(opts.get_host(), "localhost"); + assert_eq!(opts.get_port(), 5432); + } + + #[test] + fn test_connect_opts_preserves_sslmode_query_param() { + let pg_url_with_query = "postgresql://user:pass@localhost:5432?sslmode=require"; + let custom_db_name = "my-custom-db"; + let opts = build_pg_connect_options(pg_url_with_query, custom_db_name).unwrap(); + assert_eq!(opts.get_database(), Some(custom_db_name)); + assert!(matches!(opts.get_ssl_mode(), PgSslMode::Require)); + } + + #[test] + fn test_connect_opts_overrides_existing_db_in_url() { + let pg_url_with_existing_db = "postgresql://user:pass@localhost:5432/olddb-in-path"; + let new_db_name = "newdb"; + let opts = build_pg_connect_options(pg_url_with_existing_db, new_db_name).unwrap(); + assert_eq!(opts.get_database(), Some(new_db_name)); + } + + #[test] + fn test_connect_opts_overrides_db_and_preserves_tls() { + let pg_url_with_existing_db_and_tls = "postgresql://user:pass@localhost:5432/olddb?sslmode=verify-ca"; + let new_db_name = "newdb"; + let opts = build_pg_connect_options(pg_url_with_existing_db_and_tls, new_db_name).unwrap(); + assert_eq!(opts.get_database(), Some("newdb")); + assert!(matches!(opts.get_ssl_mode(), PgSslMode::VerifyCa)); + } +} + pub struct PostgresActivationStoreConfig { pub pg_url: String, pub pg_database_name: String,