diff --git a/src/commands/deploy.rs b/src/commands/deploy.rs index a31d004..a265a07 100644 --- a/src/commands/deploy.rs +++ b/src/commands/deploy.rs @@ -113,3 +113,145 @@ pub async fn rollback_latest() -> Result> { save_manifest(&m).await?; Ok(Some(entry)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::EnvGuard; + use std::sync::{Mutex, OnceLock}; + + fn env_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + } + + #[test] + fn rollback_entry_serialization() { + let entry = RollbackEntry { + job_id: "job-1".to_string(), + backup_path: "/backups/status.bak".to_string(), + install_path: "/usr/bin/status".to_string(), + timestamp: Utc::now(), + }; + let json = serde_json::to_string(&entry).unwrap(); + let deserialized: RollbackEntry = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.job_id, "job-1"); + assert_eq!(deserialized.backup_path, "/backups/status.bak"); + assert_eq!(deserialized.install_path, "/usr/bin/status"); + } + + #[test] + fn rollback_manifest_default_is_empty() { + let manifest = RollbackManifest::default(); + assert!(manifest.entries.is_empty()); + } + + #[test] + fn rollback_manifest_serialization_roundtrip() { + let manifest = RollbackManifest { + entries: vec![ + RollbackEntry { + job_id: "job-1".to_string(), + backup_path: "/backups/a.bak".to_string(), + install_path: "/usr/bin/status".to_string(), + timestamp: Utc::now(), + }, + RollbackEntry { + job_id: "job-2".to_string(), + backup_path: "/backups/b.bak".to_string(), + install_path: "/usr/bin/status".to_string(), + timestamp: Utc::now(), + }, + ], + }; + let json = serde_json::to_string_pretty(&manifest).unwrap(); + let deserialized: RollbackManifest = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.entries.len(), 2); + assert_eq!(deserialized.entries[0].job_id, "job-1"); + assert_eq!(deserialized.entries[1].job_id, "job-2"); + } + + #[tokio::test] + async fn load_manifest_nonexistent_returns_default() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let dir = tempfile::tempdir().unwrap(); + let _env = EnvGuard::set("UPDATE_STORAGE_PATH", dir.path().to_str().unwrap()); + let manifest = load_manifest().await.unwrap(); + assert!(manifest.entries.is_empty()); + } + + #[tokio::test] + async fn save_and_load_manifest_roundtrip() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let dir = tempfile::tempdir().unwrap(); + let _env = EnvGuard::set("UPDATE_STORAGE_PATH", dir.path().to_str().unwrap()); + + let manifest = RollbackManifest { + entries: vec![RollbackEntry { + job_id: "test-job".to_string(), + backup_path: "/backup/test.bak".to_string(), + install_path: "/usr/bin/status".to_string(), + timestamp: Utc::now(), + }], + }; + save_manifest(&manifest).await.unwrap(); + + let loaded = load_manifest().await.unwrap(); + assert_eq!(loaded.entries.len(), 1); + assert_eq!(loaded.entries[0].job_id, "test-job"); + } + + #[tokio::test] + async fn record_rollback_appends_entry() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let dir = tempfile::tempdir().unwrap(); + let _env = EnvGuard::set("UPDATE_STORAGE_PATH", dir.path().to_str().unwrap()); + + // Save an initial empty manifest + save_manifest(&RollbackManifest::default()).await.unwrap(); + + record_rollback("job-1", "/backup/1.bak", "/usr/bin/status") + .await + .unwrap(); + record_rollback("job-2", "/backup/2.bak", "/usr/bin/status") + .await + .unwrap(); + + let loaded = load_manifest().await.unwrap(); + assert_eq!(loaded.entries.len(), 2); + assert_eq!(loaded.entries[0].job_id, "job-1"); + assert_eq!(loaded.entries[1].job_id, "job-2"); + } + + #[tokio::test] + async fn backup_current_binary_creates_file() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let dir = tempfile::tempdir().unwrap(); + let _env = EnvGuard::set("UPDATE_STORAGE_PATH", dir.path().to_str().unwrap()); + + // Create a fake binary to back up + let src = dir.path().join("status"); + tokio::fs::write(&src, b"fake binary content") + .await + .unwrap(); + + let backup_path = backup_current_binary(src.to_str().unwrap(), "test-job") + .await + .unwrap(); + assert!(Path::new(&backup_path).exists()); + + let content = tokio::fs::read(&backup_path).await.unwrap(); + assert_eq!(content, b"fake binary content"); + } + + #[tokio::test] + async fn rollback_latest_with_empty_manifest_returns_none() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let dir = tempfile::tempdir().unwrap(); + let _env = EnvGuard::set("UPDATE_STORAGE_PATH", dir.path().to_str().unwrap()); + + save_manifest(&RollbackManifest::default()).await.unwrap(); + let result = rollback_latest().await.unwrap(); + assert!(result.is_none()); + } +} diff --git a/src/commands/version_check.rs b/src/commands/version_check.rs index 13345ca..b12facd 100644 --- a/src/commands/version_check.rs +++ b/src/commands/version_check.rs @@ -34,3 +34,62 @@ pub async fn check_remote_version() -> Result> { .context("parsing remote version response")?; Ok(Some(rv)) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::EnvGuard; + use std::sync::{Mutex, OnceLock}; + + fn env_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + } + + #[test] + fn remote_version_deserialize_with_checksum() { + let json = r#"{"version": "1.2.3", "checksum": "abc123"}"#; + let rv: RemoteVersion = serde_json::from_str(json).unwrap(); + assert_eq!(rv.version, "1.2.3"); + assert_eq!(rv.checksum, Some("abc123".to_string())); + } + + #[test] + fn remote_version_deserialize_without_checksum() { + let json = r#"{"version": "1.2.3"}"#; + let rv: RemoteVersion = serde_json::from_str(json).unwrap(); + assert_eq!(rv.version, "1.2.3"); + assert_eq!(rv.checksum, None); + } + + #[test] + fn remote_version_deserialize_null_checksum() { + let json = r#"{"version": "0.1.0", "checksum": null}"#; + let rv: RemoteVersion = serde_json::from_str(json).unwrap(); + assert_eq!(rv.version, "0.1.0"); + assert_eq!(rv.checksum, None); + } + + #[test] + fn remote_version_deserialize_missing_version_fails() { + let json = r#"{"checksum": "abc"}"#; + let result: std::result::Result = serde_json::from_str(json); + assert!(result.is_err()); + } + + #[tokio::test] + async fn check_remote_version_no_env_returns_none() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::remove("UPDATE_SERVER_URL"); + let result = check_remote_version().await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn check_remote_version_empty_env_returns_none() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::set("UPDATE_SERVER_URL", ""); + let result = check_remote_version().await.unwrap(); + assert!(result.is_none()); + } +} diff --git a/src/lib.rs b/src/lib.rs index ace8c72..2451389 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,5 +7,8 @@ pub mod security; pub mod transport; pub mod utils; +#[cfg(test)] +pub(crate) mod test_utils; + // Crate version exposed for runtime queries pub const VERSION: &str = env!("CARGO_PKG_VERSION"); diff --git a/src/monitoring/mod.rs b/src/monitoring/mod.rs index 5dc948d..9a41984 100644 --- a/src/monitoring/mod.rs +++ b/src/monitoring/mod.rs @@ -222,3 +222,143 @@ pub fn spawn_heartbeat( } }) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn metrics_snapshot_default() { + let snapshot = MetricsSnapshot::default(); + assert_eq!(snapshot.timestamp_ms, 0); + assert_eq!(snapshot.cpu_usage_pct, 0.0); + assert_eq!(snapshot.memory_total_bytes, 0); + assert_eq!(snapshot.memory_used_bytes, 0); + assert_eq!(snapshot.memory_used_pct, 0.0); + assert_eq!(snapshot.disk_total_bytes, 0); + assert_eq!(snapshot.disk_used_bytes, 0); + assert_eq!(snapshot.disk_used_pct, 0.0); + } + + #[test] + fn metrics_snapshot_serialization() { + let snapshot = MetricsSnapshot { + timestamp_ms: 1700000000000, + cpu_usage_pct: 45.5, + memory_total_bytes: 16_000_000_000, + memory_used_bytes: 8_000_000_000, + memory_used_pct: 50.0, + disk_total_bytes: 500_000_000_000, + disk_used_bytes: 250_000_000_000, + disk_used_pct: 50.0, + }; + let json = serde_json::to_string(&snapshot).unwrap(); + assert!(json.contains("\"cpu_usage_pct\":45.5")); + assert!(json.contains("\"memory_total_bytes\":16000000000")); + } + + #[test] + fn control_plane_display() { + assert_eq!(ControlPlane::StatusPanel.to_string(), "status_panel"); + assert_eq!(ControlPlane::ComposeAgent.to_string(), "compose_agent"); + } + + #[test] + fn control_plane_serialization() { + let json = serde_json::to_string(&ControlPlane::StatusPanel).unwrap(); + assert_eq!(json, "\"status_panel\""); + let json = serde_json::to_string(&ControlPlane::ComposeAgent).unwrap(); + assert_eq!(json, "\"compose_agent\""); + } + + #[test] + fn control_plane_equality() { + assert_eq!(ControlPlane::StatusPanel, ControlPlane::StatusPanel); + assert_ne!(ControlPlane::StatusPanel, ControlPlane::ComposeAgent); + } + + #[test] + fn command_execution_metrics_default() { + let metrics = CommandExecutionMetrics::default(); + assert_eq!(metrics.status_panel_count, 0); + assert_eq!(metrics.compose_agent_count, 0); + assert_eq!(metrics.total_count, 0); + assert!(metrics.last_control_plane.is_none()); + assert_eq!(metrics.last_command_timestamp_ms, 0); + } + + #[test] + fn record_status_panel_execution() { + let mut metrics = CommandExecutionMetrics::default(); + metrics.record_execution(ControlPlane::StatusPanel); + + assert_eq!(metrics.status_panel_count, 1); + assert_eq!(metrics.compose_agent_count, 0); + assert_eq!(metrics.total_count, 1); + assert_eq!(metrics.last_control_plane, Some("status_panel".to_string())); + assert!(metrics.last_command_timestamp_ms > 0); + } + + #[test] + fn record_compose_agent_execution() { + let mut metrics = CommandExecutionMetrics::default(); + metrics.record_execution(ControlPlane::ComposeAgent); + + assert_eq!(metrics.status_panel_count, 0); + assert_eq!(metrics.compose_agent_count, 1); + assert_eq!(metrics.total_count, 1); + assert_eq!( + metrics.last_control_plane, + Some("compose_agent".to_string()) + ); + } + + #[test] + fn record_multiple_executions() { + let mut metrics = CommandExecutionMetrics::default(); + metrics.record_execution(ControlPlane::StatusPanel); + metrics.record_execution(ControlPlane::StatusPanel); + metrics.record_execution(ControlPlane::ComposeAgent); + + assert_eq!(metrics.status_panel_count, 2); + assert_eq!(metrics.compose_agent_count, 1); + assert_eq!(metrics.total_count, 3); + assert_eq!( + metrics.last_control_plane, + Some("compose_agent".to_string()) + ); + } + + #[tokio::test] + async fn metrics_collector_snapshot_returns_valid_data() { + let collector = MetricsCollector::new(); + let snapshot = collector.snapshot().await; + + assert!(snapshot.timestamp_ms > 0); + // On any machine, total memory should be > 0 + assert!(snapshot.memory_total_bytes > 0); + // Used memory should not exceed total + assert!(snapshot.memory_used_bytes <= snapshot.memory_total_bytes); + // Percentages should be 0-100 range + assert!(snapshot.memory_used_pct >= 0.0 && snapshot.memory_used_pct <= 100.0); + assert!(snapshot.disk_used_pct >= 0.0 && snapshot.disk_used_pct <= 100.0); + } + + #[test] + fn command_execution_metrics_serialization() { + let mut metrics = CommandExecutionMetrics::default(); + metrics.record_execution(ControlPlane::StatusPanel); + + let json = serde_json::to_string(&metrics).unwrap(); + assert!(json.contains("\"status_panel_count\":1")); + assert!(json.contains("\"compose_agent_count\":0")); + assert!(json.contains("\"total_count\":1")); + } + + #[test] + fn metrics_collector_default() { + // Verify Default trait works + let collector = MetricsCollector::default(); + let _ = format!("{:?}", collector); + } +} diff --git a/src/security/audit_log.rs b/src/security/audit_log.rs index 7f163fa..533d4e2 100644 --- a/src/security/audit_log.rs +++ b/src/security/audit_log.rs @@ -55,3 +55,91 @@ impl AuditLogger { error!(target: "audit", event = "internal_error", agent_id = agent_id.unwrap_or("") , request_id = request_id.unwrap_or(""), error = error_msg); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn audit_logger_creation() { + let logger = AuditLogger::new(); + // Verify Debug trait works + let _ = format!("{:?}", logger); + } + + #[test] + fn audit_logger_default() { + let logger = AuditLogger::default(); + let _ = format!("{:?}", logger); + } + + #[test] + fn audit_logger_auth_success_does_not_panic() { + let logger = AuditLogger::new(); + logger.auth_success("agent-1", Some("req-1"), "login"); + logger.auth_success("agent-1", None, "login"); + } + + #[test] + fn audit_logger_auth_failure_does_not_panic() { + let logger = AuditLogger::new(); + logger.auth_failure(Some("agent-1"), Some("req-1"), "bad password"); + logger.auth_failure(None, None, "unknown agent"); + } + + #[test] + fn audit_logger_signature_invalid_does_not_panic() { + let logger = AuditLogger::new(); + logger.signature_invalid(Some("agent-1"), Some("req-1")); + logger.signature_invalid(None, None); + } + + #[test] + fn audit_logger_rate_limited_does_not_panic() { + let logger = AuditLogger::new(); + logger.rate_limited("agent-1", Some("req-1")); + logger.rate_limited("agent-1", None); + } + + #[test] + fn audit_logger_replay_detected_does_not_panic() { + let logger = AuditLogger::new(); + logger.replay_detected(Some("agent-1"), Some("req-1")); + logger.replay_detected(None, None); + } + + #[test] + fn audit_logger_scope_denied_does_not_panic() { + let logger = AuditLogger::new(); + logger.scope_denied("agent-1", Some("req-1"), "docker:restart"); + logger.scope_denied("agent-1", None, "admin"); + } + + #[test] + fn audit_logger_command_executed_does_not_panic() { + let logger = AuditLogger::new(); + logger.command_executed("agent-1", Some("req-1"), "cmd-1", "restart"); + logger.command_executed("agent-1", None, "cmd-2", "stop"); + } + + #[test] + fn audit_logger_token_rotated_does_not_panic() { + let logger = AuditLogger::new(); + logger.token_rotated("agent-1", Some("req-1")); + logger.token_rotated("agent-1", None); + } + + #[test] + fn audit_logger_internal_error_does_not_panic() { + let logger = AuditLogger::new(); + logger.internal_error(Some("agent-1"), Some("req-1"), "database timeout"); + logger.internal_error(None, None, "unknown error"); + } + + #[test] + fn audit_logger_clone() { + let logger = AuditLogger::new(); + let cloned = logger.clone(); + cloned.auth_success("agent-1", None, "test"); + } +} diff --git a/src/security/rate_limit.rs b/src/security/rate_limit.rs index 245e835..d30e31e 100644 --- a/src/security/rate_limit.rs +++ b/src/security/rate_limit.rs @@ -41,3 +41,66 @@ impl RateLimiter { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn allows_requests_within_limit() { + let limiter = RateLimiter::new_per_minute(3); + assert!(limiter.allow("client-1").await); + assert!(limiter.allow("client-1").await); + assert!(limiter.allow("client-1").await); + } + + #[tokio::test] + async fn blocks_requests_over_limit() { + let limiter = RateLimiter::new_per_minute(2); + assert!(limiter.allow("client-1").await); + assert!(limiter.allow("client-1").await); + assert!(!limiter.allow("client-1").await); + } + + #[tokio::test] + async fn independent_keys() { + let limiter = RateLimiter::new_per_minute(1); + assert!(limiter.allow("client-1").await); + assert!(limiter.allow("client-2").await); + // client-1 is now blocked, client-2 is also blocked + assert!(!limiter.allow("client-1").await); + assert!(!limiter.allow("client-2").await); + } + + #[tokio::test] + async fn window_expiry_allows_new_requests() { + // Use a very short window to test expiry + let limiter = RateLimiter { + window: Duration::from_millis(50), + limit: 1, + inner: Arc::new(Mutex::new(HashMap::new())), + }; + assert!(limiter.allow("client").await); + assert!(!limiter.allow("client").await); + + // Wait for window to expire + tokio::time::sleep(Duration::from_millis(100)).await; + assert!(limiter.allow("client").await); + } + + #[tokio::test] + async fn limit_of_zero_blocks_all() { + let limiter = RateLimiter::new_per_minute(0); + assert!(!limiter.allow("client").await); + } + + #[tokio::test] + async fn limiter_is_clone_safe() { + let limiter = RateLimiter::new_per_minute(1); + let limiter_clone = limiter.clone(); + + assert!(limiter.allow("client").await); + // Clone shares state, so this should be blocked + assert!(!limiter_clone.allow("client").await); + } +} diff --git a/src/security/request_signer.rs b/src/security/request_signer.rs index 026daab..2695267 100644 --- a/src/security/request_signer.rs +++ b/src/security/request_signer.rs @@ -77,3 +77,194 @@ pub fn verify_signature( Err(anyhow!("signature mismatch")) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn compute_signature_deterministic() { + let sig1 = compute_signature_base64("secret", b"hello"); + let sig2 = compute_signature_base64("secret", b"hello"); + assert_eq!(sig1, sig2); + } + + #[test] + fn compute_signature_different_keys() { + let sig1 = compute_signature_base64("key1", b"body"); + let sig2 = compute_signature_base64("key2", b"body"); + assert_ne!(sig1, sig2); + } + + #[test] + fn compute_signature_different_bodies() { + let sig1 = compute_signature_base64("key", b"body1"); + let sig2 = compute_signature_base64("key", b"body2"); + assert_ne!(sig1, sig2); + } + + #[test] + fn compute_signature_empty_body() { + let sig = compute_signature_base64("key", b""); + assert!(!sig.is_empty()); + // Verify it's valid base64 + assert!(general_purpose::STANDARD.decode(&sig).is_ok()); + } + + #[test] + fn decode_signature_base64() { + let original = b"test data for signature"; + let encoded = general_purpose::STANDARD.encode(original); + let decoded = decode_signature(&encoded).unwrap(); + assert_eq!(decoded, original); + } + + #[test] + fn decode_signature_hex_fallback() { + // "hello" in hex + let decoded = decode_signature("68656c6c6f").unwrap(); + assert_eq!(decoded, b"hello"); + } + + #[test] + fn decode_signature_hex_uppercase() { + let decoded = decode_signature("48454C4C4F").unwrap(); + assert_eq!(decoded, b"HELLO"); + } + + #[test] + fn decode_signature_invalid_encoding() { + // Odd-length string that's not valid base64 and not valid hex + let result = decode_signature("xyz"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("invalid signature encoding")); + } + + #[test] + fn verify_signature_valid() { + let key = "test-secret"; + let body = b"request body"; + let sig = compute_signature_base64(key, body); + let ts = Utc::now().timestamp().to_string(); + + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", ts.parse().unwrap()); + headers.insert("X-Agent-Signature", sig.parse().unwrap()); + + assert!(verify_signature(&headers, body, key, 60).is_ok()); + } + + #[test] + fn verify_signature_missing_timestamp() { + let mut headers = HeaderMap::new(); + headers.insert("X-Agent-Signature", "sig".parse().unwrap()); + + let result = verify_signature(&headers, b"body", "key", 60); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("missing X-Timestamp")); + } + + #[test] + fn verify_signature_invalid_timestamp() { + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", "not-a-number".parse().unwrap()); + headers.insert("X-Agent-Signature", "sig".parse().unwrap()); + + let result = verify_signature(&headers, b"body", "key", 60); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("invalid X-Timestamp")); + } + + #[test] + fn verify_signature_stale_timestamp() { + let key = "test-secret"; + let body = b"body"; + let sig = compute_signature_base64(key, body); + let old_ts = (Utc::now().timestamp() - 120).to_string(); + + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", old_ts.parse().unwrap()); + headers.insert("X-Agent-Signature", sig.parse().unwrap()); + + let result = verify_signature(&headers, body, key, 60); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("stale request")); + } + + #[test] + fn verify_signature_missing_signature_header() { + let ts = Utc::now().timestamp().to_string(); + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", ts.parse().unwrap()); + + let result = verify_signature(&headers, b"body", "key", 60); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("missing X-Agent-Signature")); + } + + #[test] + fn verify_signature_wrong_key() { + let body = b"body"; + let sig = compute_signature_base64("correct-key", body); + let ts = Utc::now().timestamp().to_string(); + + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", ts.parse().unwrap()); + headers.insert("X-Agent-Signature", sig.parse().unwrap()); + + let result = verify_signature(&headers, body, "wrong-key", 60); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("signature mismatch")); + } + + #[test] + fn verify_signature_tampered_body() { + let key = "test-secret"; + let body = b"original body"; + let sig = compute_signature_base64(key, body); + let ts = Utc::now().timestamp().to_string(); + + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", ts.parse().unwrap()); + headers.insert("X-Agent-Signature", sig.parse().unwrap()); + + // Verify with a different body + let result = verify_signature(&headers, b"tampered body", key, 60); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("signature mismatch")); + } + + #[test] + fn verify_signature_large_skew_allowed() { + let key = "test-secret"; + let body = b"body"; + let sig = compute_signature_base64(key, body); + // Timestamp 30 seconds in the past + let ts = (Utc::now().timestamp() - 30).to_string(); + + let mut headers = HeaderMap::new(); + headers.insert("X-Timestamp", ts.parse().unwrap()); + headers.insert("X-Agent-Signature", sig.parse().unwrap()); + + // 60 second skew allows 30 second old request + assert!(verify_signature(&headers, body, key, 60).is_ok()); + } +} diff --git a/src/security/scopes.rs b/src/security/scopes.rs index 68247c2..8b38b0d 100644 --- a/src/security/scopes.rs +++ b/src/security/scopes.rs @@ -28,3 +28,83 @@ impl Scopes { self.allowed.contains(scope) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::EnvGuard; + use std::sync::{Mutex, OnceLock}; + + fn env_lock() -> &'static Mutex<()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + } + + #[test] + fn empty_scopes_allow_everything() { + let scopes = Scopes::default(); + assert!(scopes.is_allowed("anything")); + assert!(scopes.is_allowed("docker:restart")); + assert!(scopes.is_allowed("")); + } + + #[test] + fn scopes_with_values_restrict_access() { + let mut allowed = HashSet::new(); + allowed.insert("docker:restart".to_string()); + allowed.insert("docker:logs".to_string()); + let scopes = Scopes { allowed }; + + assert!(scopes.is_allowed("docker:restart")); + assert!(scopes.is_allowed("docker:logs")); + assert!(!scopes.is_allowed("docker:stop")); + assert!(!scopes.is_allowed("admin")); + } + + #[test] + fn scopes_from_env_parses_comma_separated() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::set("AGENT_SCOPES", "docker:restart,docker:logs,admin"); + let scopes = Scopes::from_env(); + assert!(scopes.is_allowed("docker:restart")); + assert!(scopes.is_allowed("docker:logs")); + assert!(scopes.is_allowed("admin")); + assert!(!scopes.is_allowed("docker:stop")); + } + + #[test] + fn scopes_from_env_trims_whitespace() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::set("AGENT_SCOPES", " docker:restart , admin "); + let scopes = Scopes::from_env(); + assert!(scopes.is_allowed("docker:restart")); + assert!(scopes.is_allowed("admin")); + } + + #[test] + fn scopes_from_env_skips_empty_items() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::set("AGENT_SCOPES", "docker:restart,,, ,admin"); + let scopes = Scopes::from_env(); + assert!(scopes.is_allowed("docker:restart")); + assert!(scopes.is_allowed("admin")); + // The empty strings should NOT be in the set + assert!(!scopes.is_allowed("")); + } + + #[test] + fn scopes_from_env_missing_var_allows_all() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::remove("AGENT_SCOPES"); + let scopes = Scopes::from_env(); + assert!(scopes.is_allowed("anything")); + } + + #[test] + fn scopes_from_env_empty_string_allows_all() { + let _lock = env_lock().lock().expect("env lock poisoned"); + let _env = EnvGuard::set("AGENT_SCOPES", ""); + let scopes = Scopes::from_env(); + assert!(scopes.is_allowed("anything")); + } +} diff --git a/src/test_utils.rs b/src/test_utils.rs new file mode 100644 index 0000000..f3c7c30 --- /dev/null +++ b/src/test_utils.rs @@ -0,0 +1,31 @@ +/// A drop-based guard that saves an environment variable's original value before modification +/// and restores it when dropped, ensuring cleanup even when test assertions panic. +pub(crate) struct EnvGuard { + key: &'static str, + original: Option, +} + +impl EnvGuard { + /// Sets `key` to `value` and saves the previous value for restoration on drop. + pub(crate) fn set(key: &'static str, value: &str) -> Self { + let original = std::env::var(key).ok(); + std::env::set_var(key, value); + Self { key, original } + } + + /// Removes `key` from the environment and saves the previous value for restoration on drop. + pub(crate) fn remove(key: &'static str) -> Self { + let original = std::env::var(key).ok(); + std::env::remove_var(key); + Self { key, original } + } +} + +impl Drop for EnvGuard { + fn drop(&mut self) { + match &self.original { + Some(v) => std::env::set_var(self.key, v), + None => std::env::remove_var(self.key), + } + } +} diff --git a/src/transport/http_polling.rs b/src/transport/http_polling.rs index c8099c7..a29b288 100644 --- a/src/transport/http_polling.rs +++ b/src/transport/http_polling.rs @@ -602,4 +602,179 @@ mod tests { assert!(result.command.is_none()); mock.assert(); } + + #[test] + fn build_wait_command_url_with_priority() { + let url = build_wait_command_url("https://example.com", "dep-1", 30, Some("high")); + assert_eq!( + url, + "https://example.com/api/v1/agent/commands/wait/dep-1?timeout=30&priority=high" + ); + } + + #[test] + fn build_wait_command_url_default_priority() { + let url = build_wait_command_url("https://example.com", "dep-1", 60, None); + assert_eq!( + url, + "https://example.com/api/v1/agent/commands/wait/dep-1?timeout=60&priority=normal" + ); + } + + #[test] + fn extract_next_poll_secs_from_numeric() { + let json = json!({"meta": {"next_poll_secs": 30}}); + assert_eq!(extract_next_poll_secs(&json), Some(30)); + } + + #[test] + fn extract_next_poll_secs_from_string() { + let json = json!({"meta": {"next_poll_secs": "45"}}); + assert_eq!(extract_next_poll_secs(&json), Some(45)); + } + + #[test] + fn extract_next_poll_secs_missing_meta() { + let json = json!({"data": "something"}); + assert_eq!(extract_next_poll_secs(&json), None); + } + + #[test] + fn extract_next_poll_secs_missing_field() { + let json = json!({"meta": {"other": 1}}); + assert_eq!(extract_next_poll_secs(&json), None); + } + + #[test] + fn extract_field_or_default_present() { + let mut obj = serde_json::Map::new(); + obj.insert("type".to_string(), json!("restart")); + assert_eq!(extract_field_or_default(&obj, "type", "unknown"), "restart"); + } + + #[test] + fn extract_field_or_default_missing() { + let obj = serde_json::Map::new(); + assert_eq!(extract_field_or_default(&obj, "type", "unknown"), "unknown"); + } + + #[test] + fn extract_parameters_present() { + let mut obj = serde_json::Map::new(); + obj.insert("parameters".to_string(), json!({"key": "value"})); + let params = extract_parameters(&obj); + assert_eq!(params["key"], "value"); + } + + #[test] + fn extract_parameters_missing() { + let obj = serde_json::Map::new(); + let params = extract_parameters(&obj); + assert!(params.is_object()); + assert!(params.as_object().unwrap().is_empty()); + } + + #[test] + fn extract_optional_string_present() { + let mut obj = serde_json::Map::new(); + obj.insert("field".to_string(), json!("value")); + assert_eq!( + extract_optional_string(&obj, "field"), + Some("value".to_string()) + ); + } + + #[test] + fn extract_optional_string_missing() { + let obj = serde_json::Map::new(); + assert_eq!(extract_optional_string(&obj, "field"), None); + } + + #[test] + fn extract_optional_string_non_string() { + let mut obj = serde_json::Map::new(); + obj.insert("field".to_string(), json!(42)); + assert_eq!(extract_optional_string(&obj, "field"), None); + } + + #[test] + fn extract_app_code_from_params() { + let params = json!({"app_code": "myapp", "other": "data"}); + assert_eq!(extract_app_code(¶ms), Some("myapp".to_string())); + } + + #[test] + fn extract_app_code_missing() { + let params = json!({"other": "data"}); + assert_eq!(extract_app_code(¶ms), None); + } + + #[test] + fn extract_command_from_json_valid() { + let json = json!({ + "item": { + "command_id": "cmd-1", + "type": "restart", + "parameters": {"container": "nginx"}, + "deployment_hash": "dep-1", + }, + "meta": {"next_poll_secs": 10} + }); + let body = serde_json::to_string(&json).unwrap(); + let result = extract_command_from_json(json, &body).unwrap(); + + assert!(result.command.is_some()); + let cmd = result.command.unwrap(); + assert_eq!(cmd.command_id, "cmd-1"); + assert_eq!(cmd.name, "restart"); + assert_eq!(cmd.deployment_hash, Some("dep-1".to_string())); + assert_eq!(result.next_poll_secs, Some(10)); + } + + #[test] + fn extract_command_from_json_null_item() { + let json = json!({"item": null}); + let body = serde_json::to_string(&json).unwrap(); + let result = extract_command_from_json(json, &body).unwrap(); + + assert!(result.command.is_none()); + } + + #[test] + fn extract_command_from_json_empty_item() { + let json = json!({"item": {}}); + let body = serde_json::to_string(&json).unwrap(); + let result = extract_command_from_json(json, &body).unwrap(); + + assert!(result.command.is_none()); + } + + #[test] + fn extract_command_from_json_missing_command_id() { + let json = json!({ + "item": { + "type": "restart", + "parameters": {} + } + }); + let body = serde_json::to_string(&json).unwrap(); + let result = extract_command_from_json(json, &body); + + assert!(result.is_err()); + } + + #[test] + fn validate_command_id_present() { + let mut obj = serde_json::Map::new(); + obj.insert("command_id".to_string(), json!("cmd-1")); + let result = validate_command_id(&obj, "body"); + assert_eq!(result.unwrap(), "cmd-1"); + } + + #[test] + fn validate_command_id_missing() { + let obj = serde_json::Map::new(); + let result = validate_command_id(&obj, "body"); + assert!(result.is_err()); + } }