From 1101c5ad2e7207552f2d3b84a29e603039e36a5d Mon Sep 17 00:00:00 2001 From: vsilent Date: Fri, 3 Apr 2026 17:34:14 +0300 Subject: [PATCH 01/10] Add live mail abuse guard Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/docker/client.rs | 73 ++++++- src/docker/containers.rs | 8 + src/docker/mail_guard.rs | 437 +++++++++++++++++++++++++++++++++++++++ src/docker/mod.rs | 2 + src/main.rs | 10 + src/sniff/reader.rs | 22 +- 6 files changed, 540 insertions(+), 12 deletions(-) create mode 100644 src/docker/mail_guard.rs diff --git a/src/docker/client.rs b/src/docker/client.rs index 3d57091..44211d8 100644 --- a/src/docker/client.rs +++ b/src/docker/client.rs @@ -4,9 +4,10 @@ use anyhow::{Context, Result}; use std::collections::HashMap; // Bollard imports -use bollard::container::{InspectContainerOptions, ListContainersOptions}; +use bollard::container::{InspectContainerOptions, ListContainersOptions, Stats, StatsOptions}; use bollard::network::{DisconnectNetworkOptions, ListNetworksOptions}; use bollard::Docker; +use futures_util::stream::StreamExt; /// Docker client wrapper pub struct DockerClient { @@ -135,16 +136,68 @@ impl DockerClient { } /// Get container stats - pub async fn get_container_stats(&self, _container_id: &str) -> Result { - // Implementation would use Docker stats API - // For now, return placeholder + pub async fn get_container_stats(&self, container_id: &str) -> Result { + let mut stream = self.client.stats( + container_id, + Some(StatsOptions { + stream: false, + one_shot: true, + }), + ); + let stats = stream + .next() + .await + .context("No stats returned from Docker")? + .context("Failed to fetch Docker stats")?; + + let (network_rx, network_tx, network_rx_packets, network_tx_packets) = + aggregate_network_stats(&stats); + Ok(ContainerStats { - cpu_percent: 0.0, - memory_usage: 0, - memory_limit: 0, - network_rx: 0, - network_tx: 0, + cpu_percent: calculate_cpu_percent(&stats), + memory_usage: stats.memory_stats.usage.unwrap_or(0), + memory_limit: stats.memory_stats.limit.unwrap_or(0), + network_rx, + network_tx, + network_rx_packets, + network_tx_packets, + }) + } +} + +fn aggregate_network_stats(stats: &Stats) -> (u64, u64, u64, u64) { + if let Some(networks) = stats.networks.as_ref() { + networks.values().fold((0, 0, 0, 0), |acc, network| { + ( + acc.0 + network.rx_bytes, + acc.1 + network.tx_bytes, + acc.2 + network.rx_packets, + acc.3 + network.tx_packets, + ) }) + } else if let Some(network) = stats.network { + ( + network.rx_bytes, + network.tx_bytes, + network.rx_packets, + network.tx_packets, + ) + } else { + (0, 0, 0, 0) + } +} + +fn calculate_cpu_percent(stats: &Stats) -> f64 { + let cpu_delta = stats.cpu_stats.cpu_usage.total_usage as f64 + - stats.precpu_stats.cpu_usage.total_usage as f64; + let system_delta = stats.cpu_stats.system_cpu_usage.unwrap_or(0) as f64 + - stats.precpu_stats.system_cpu_usage.unwrap_or(0) as f64; + let online_cpus = stats.cpu_stats.online_cpus.unwrap_or(1) as f64; + + if cpu_delta <= 0.0 || system_delta <= 0.0 { + 0.0 + } else { + (cpu_delta / system_delta) * online_cpus * 100.0 } } @@ -167,6 +220,8 @@ pub struct ContainerStats { pub memory_limit: u64, pub network_rx: u64, pub network_tx: u64, + pub network_rx_packets: u64, + pub network_tx_packets: u64, } #[cfg(test)] diff --git a/src/docker/containers.rs b/src/docker/containers.rs index 146f2f9..7b4d072 100644 --- a/src/docker/containers.rs +++ b/src/docker/containers.rs @@ -30,6 +30,14 @@ impl ContainerManager { self.docker.get_container_info(container_id).await } + /// Get live container stats + pub async fn get_container_stats( + &self, + container_id: &str, + ) -> Result { + self.docker.get_container_stats(container_id).await + } + /// Quarantine a container pub async fn quarantine_container(&self, container_id: &str, reason: &str) -> Result<()> { // Disconnect from networks diff --git a/src/docker/mail_guard.rs b/src/docker/mail_guard.rs new file mode 100644 index 0000000..1a2a9c1 --- /dev/null +++ b/src/docker/mail_guard.rs @@ -0,0 +1,437 @@ +use std::collections::{HashMap, HashSet}; +use std::env; + +use chrono::Utc; +use tokio::time::{sleep, Duration}; +use uuid::Uuid; + +use crate::database::models::Alert; +use crate::database::repositories::alerts::create_alert; +use crate::database::DbPool; +use crate::docker::client::{ContainerInfo, ContainerStats}; +use crate::docker::containers::ContainerManager; + +const DEFAULT_TARGET_PATTERNS: &[&str] = &[ + "wordpress", + "php", + "php-fpm", + "apache", + "httpd", + "drupal", + "joomla", + "woocommerce", +]; +const DEFAULT_ALLOWLIST_PATTERNS: &[&str] = + &["postfix", "exim", "mailhog", "mailpit", "smtp", "sendmail"]; + +#[derive(Debug, Clone)] +pub struct MailAbuseGuardConfig { + pub enabled: bool, + pub poll_interval_secs: u64, + pub min_tx_packets_per_interval: u64, + pub min_tx_bytes_per_interval: u64, + pub max_avg_bytes_per_packet: u64, + pub consecutive_suspicious_intervals: u32, + pub target_patterns: Vec, + pub allowlist_patterns: Vec, +} + +impl MailAbuseGuardConfig { + pub fn from_env() -> Self { + Self { + enabled: parse_bool_env("STACKDOG_MAIL_GUARD_ENABLED", true), + poll_interval_secs: parse_u64_env("STACKDOG_MAIL_GUARD_INTERVAL_SECS", 10), + min_tx_packets_per_interval: parse_u64_env("STACKDOG_MAIL_GUARD_MIN_TX_PACKETS", 250), + min_tx_bytes_per_interval: parse_u64_env("STACKDOG_MAIL_GUARD_MIN_TX_BYTES", 64 * 1024), + max_avg_bytes_per_packet: parse_u64_env( + "STACKDOG_MAIL_GUARD_MAX_AVG_BYTES_PER_PACKET", + 800, + ), + consecutive_suspicious_intervals: parse_u32_env( + "STACKDOG_MAIL_GUARD_CONSECUTIVE_INTERVALS", + 3, + ), + target_patterns: parse_list_env("STACKDOG_MAIL_GUARD_TARGETS").unwrap_or_else(|| { + DEFAULT_TARGET_PATTERNS + .iter() + .map(|s| s.to_string()) + .collect() + }), + allowlist_patterns: parse_list_env("STACKDOG_MAIL_GUARD_ALLOWLIST").unwrap_or_else( + || { + DEFAULT_ALLOWLIST_PATTERNS + .iter() + .map(|s| s.to_string()) + .collect() + }, + ), + } + } +} + +fn parse_bool_env(name: &str, default: bool) -> bool { + env::var(name) + .ok() + .and_then(|value| match value.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Some(true), + "0" | "false" | "no" | "off" => Some(false), + _ => None, + }) + .unwrap_or(default) +} + +fn parse_u64_env(name: &str, default: u64) -> u64 { + env::var(name) + .ok() + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(default) +} + +fn parse_u32_env(name: &str, default: u32) -> u32 { + env::var(name) + .ok() + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(default) +} + +fn parse_list_env(name: &str) -> Option> { + env::var(name).ok().map(|value| { + value + .split(',') + .map(|part| part.trim().to_ascii_lowercase()) + .filter(|part| !part.is_empty()) + .collect() + }) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct TrafficSnapshot { + tx_bytes: u64, + rx_bytes: u64, + tx_packets: u64, + rx_packets: u64, +} + +impl From<&ContainerStats> for TrafficSnapshot { + fn from(stats: &ContainerStats) -> Self { + Self { + tx_bytes: stats.network_tx, + rx_bytes: stats.network_rx, + tx_packets: stats.network_tx_packets, + rx_packets: stats.network_rx_packets, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct TrafficDelta { + tx_bytes: u64, + rx_bytes: u64, + tx_packets: u64, + rx_packets: u64, +} + +#[derive(Debug, Default)] +struct ContainerTrafficState { + previous: Option, + suspicious_intervals: u32, + quarantined: bool, +} + +#[derive(Debug, Clone)] +struct GuardDecision { + should_quarantine: bool, + reason: Option, +} + +impl GuardDecision { + fn no_action() -> Self { + Self { + should_quarantine: false, + reason: None, + } + } +} + +#[derive(Debug, Default)] +struct MailAbuseDetector { + states: HashMap, +} + +impl MailAbuseDetector { + fn evaluate_container( + &mut self, + info: &ContainerInfo, + stats: &ContainerStats, + config: &MailAbuseGuardConfig, + ) -> GuardDecision { + if is_allowlisted(info, config) { + self.states.remove(&info.id); + return GuardDecision::no_action(); + } + + let state = self.states.entry(info.id.clone()).or_default(); + let current = TrafficSnapshot::from(stats); + + let Some(previous) = state.previous.replace(current) else { + return GuardDecision::no_action(); + }; + + let Some(delta) = compute_delta(previous, current) else { + state.suspicious_intervals = 0; + return GuardDecision::no_action(); + }; + + if state.quarantined { + return GuardDecision::no_action(); + } + + if !is_targeted_container(info, config) || !is_suspicious_egress(delta, config) { + state.suspicious_intervals = 0; + return GuardDecision::no_action(); + } + + state.suspicious_intervals += 1; + let avg_bytes_per_packet = if delta.tx_packets == 0 { + 0 + } else { + delta.tx_bytes / delta.tx_packets + }; + let reason = format!( + "possible outbound mail abuse detected for {} (image: {}) — {} tx packets / {} bytes over {}s, avg {} bytes/packet, strike {}/{}", + info.name, + info.image, + delta.tx_packets, + delta.tx_bytes, + config.poll_interval_secs, + avg_bytes_per_packet, + state.suspicious_intervals, + config.consecutive_suspicious_intervals + ); + + GuardDecision { + should_quarantine: state.suspicious_intervals + >= config.consecutive_suspicious_intervals, + reason: Some(reason), + } + } + + fn mark_quarantined(&mut self, container_id: &str) { + if let Some(state) = self.states.get_mut(container_id) { + state.quarantined = true; + } + } + + fn prune(&mut self, active_container_ids: &HashSet) { + self.states + .retain(|container_id, _| active_container_ids.contains(container_id)); + } +} + +fn compute_delta(previous: TrafficSnapshot, current: TrafficSnapshot) -> Option { + Some(TrafficDelta { + tx_bytes: current.tx_bytes.checked_sub(previous.tx_bytes)?, + rx_bytes: current.rx_bytes.checked_sub(previous.rx_bytes)?, + tx_packets: current.tx_packets.checked_sub(previous.tx_packets)?, + rx_packets: current.rx_packets.checked_sub(previous.rx_packets)?, + }) +} + +fn is_targeted_container(info: &ContainerInfo, config: &MailAbuseGuardConfig) -> bool { + let identity = format!( + "{} {} {}", + info.id.to_ascii_lowercase(), + info.name.to_ascii_lowercase(), + info.image.to_ascii_lowercase() + ); + config + .target_patterns + .iter() + .any(|pattern| identity.contains(pattern)) +} + +fn is_allowlisted(info: &ContainerInfo, config: &MailAbuseGuardConfig) -> bool { + let identity = format!( + "{} {} {}", + info.id.to_ascii_lowercase(), + info.name.to_ascii_lowercase(), + info.image.to_ascii_lowercase() + ); + config + .allowlist_patterns + .iter() + .any(|pattern| identity.contains(pattern)) +} + +fn is_suspicious_egress(delta: TrafficDelta, config: &MailAbuseGuardConfig) -> bool { + if delta.tx_packets < config.min_tx_packets_per_interval + || delta.tx_bytes < config.min_tx_bytes_per_interval + { + return false; + } + + let avg_bytes_per_packet = delta.tx_bytes / delta.tx_packets.max(1); + avg_bytes_per_packet <= config.max_avg_bytes_per_packet +} + +pub struct MailAbuseGuard; + +impl MailAbuseGuard { + pub async fn run(pool: DbPool, config: MailAbuseGuardConfig) { + log::info!( + "Starting mail abuse guard (interval={}s, min_tx_packets={}, min_tx_bytes={}, max_avg_bytes_per_packet={}, strikes={})", + config.poll_interval_secs, + config.min_tx_packets_per_interval, + config.min_tx_bytes_per_interval, + config.max_avg_bytes_per_packet, + config.consecutive_suspicious_intervals + ); + + let mut detector = MailAbuseDetector::default(); + + loop { + if let Err(err) = Self::poll_once(&pool, &config, &mut detector).await { + log::warn!("Mail abuse guard poll failed: {}", err); + } + + sleep(Duration::from_secs(config.poll_interval_secs)).await; + } + } + + async fn poll_once( + pool: &DbPool, + config: &MailAbuseGuardConfig, + detector: &mut MailAbuseDetector, + ) -> anyhow::Result<()> { + let manager = ContainerManager::new(pool.clone()).await?; + let containers = manager.list_containers().await?; + let mut active_container_ids = HashSet::new(); + + for container in containers { + if container.status != "Running" { + continue; + } + + active_container_ids.insert(container.id.clone()); + let stats = manager.get_container_stats(&container.id).await?; + let decision = detector.evaluate_container(&container, &stats, config); + + if decision.should_quarantine { + let reason = decision.reason.unwrap_or_else(|| { + format!( + "possible outbound mail abuse detected for {}", + container.name + ) + }); + + manager.quarantine_container(&container.id, &reason).await?; + detector.mark_quarantined(&container.id); + create_alert( + pool, + Alert { + id: Uuid::new_v4().to_string(), + alert_type: "ThreatDetected".into(), + severity: "Critical".into(), + message: format!( + "Mail abuse guard quarantined container {} ({})", + container.name, container.id + ), + status: "New".into(), + timestamp: Utc::now().to_rfc3339(), + metadata: Some(reason.clone()), + }, + ) + .await?; + log::warn!("{}", reason); + } + } + + detector.prune(&active_container_ids); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn config() -> MailAbuseGuardConfig { + MailAbuseGuardConfig { + enabled: true, + poll_interval_secs: 10, + min_tx_packets_per_interval: 100, + min_tx_bytes_per_interval: 10_000, + max_avg_bytes_per_packet: 300, + consecutive_suspicious_intervals: 2, + target_patterns: vec!["wordpress".into()], + allowlist_patterns: vec!["mailhog".into()], + } + } + + fn container(name: &str, image: &str) -> ContainerInfo { + ContainerInfo { + id: "abc123".into(), + name: name.into(), + image: image.into(), + status: "Running".into(), + created: String::new(), + network_settings: HashMap::new(), + } + } + + fn stats(tx_bytes: u64, rx_bytes: u64, tx_packets: u64, rx_packets: u64) -> ContainerStats { + ContainerStats { + cpu_percent: 0.0, + memory_usage: 0, + memory_limit: 0, + network_rx: rx_bytes, + network_tx: tx_bytes, + network_rx_packets: rx_packets, + network_tx_packets: tx_packets, + } + } + + #[test] + fn test_detector_requires_consecutive_intervals() { + let mut detector = MailAbuseDetector::default(); + let info = container("wordpress", "wordpress:latest"); + let config = config(); + + let first = detector.evaluate_container(&info, &stats(10_000, 5_000, 100, 50), &config); + assert!(!first.should_quarantine); + + let second = detector.evaluate_container(&info, &stats(40_000, 8_000, 260, 80), &config); + assert!(!second.should_quarantine); + + let third = detector.evaluate_container(&info, &stats(80_000, 11_000, 420, 100), &config); + assert!(third.should_quarantine); + } + + #[test] + fn test_detector_ignores_allowlisted_container() { + let mut detector = MailAbuseDetector::default(); + let info = container("mailhog", "mailhog/mailhog"); + let config = config(); + + detector.evaluate_container(&info, &stats(10_000, 5_000, 100, 50), &config); + let decision = detector.evaluate_container(&info, &stats(50_000, 8_000, 260, 80), &config); + + assert!(!decision.should_quarantine); + } + + #[test] + fn test_detector_resets_strikes_after_normal_interval() { + let mut detector = MailAbuseDetector::default(); + let info = container("wordpress", "wordpress:latest"); + let config = config(); + + detector.evaluate_container(&info, &stats(10_000, 5_000, 100, 50), &config); + detector.evaluate_container(&info, &stats(40_000, 8_000, 260, 80), &config); + let normal = detector.evaluate_container(&info, &stats(42_000, 9_000, 265, 82), &config); + assert!(!normal.should_quarantine); + + let suspicious = + detector.evaluate_container(&info, &stats(82_000, 12_000, 430, 100), &config); + assert!(!suspicious.should_quarantine); + } +} diff --git a/src/docker/mod.rs b/src/docker/mod.rs index 03de6d2..4e4650f 100644 --- a/src/docker/mod.rs +++ b/src/docker/mod.rs @@ -2,6 +2,8 @@ pub mod client; pub mod containers; +pub mod mail_guard; pub use client::{ContainerInfo, ContainerStats, DockerClient}; pub use containers::{ContainerManager, ContainerSecurityStatus}; +pub use mail_guard::{MailAbuseGuard, MailAbuseGuardConfig}; diff --git a/src/main.rs b/src/main.rs index 3eefadd..a0ce50d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -121,6 +121,16 @@ async fn run_serve() -> io::Result<()> { init_database(&pool).expect("Failed to initialize database"); info!("Database initialized successfully"); + let mail_guard_config = stackdog::docker::MailAbuseGuardConfig::from_env(); + if mail_guard_config.enabled { + let guard_pool = pool.clone(); + actix_rt::spawn(async move { + stackdog::docker::MailAbuseGuard::run(guard_pool, mail_guard_config).await; + }); + } else { + info!("Mail abuse guard disabled"); + } + info!("🎉 Stackdog Security ready!"); info!(""); info!("API Endpoints:"); diff --git a/src/sniff/reader.rs b/src/sniff/reader.rs index fa3e450..6f1c235 100644 --- a/src/sniff/reader.rs +++ b/src/sniff/reader.rs @@ -76,10 +76,11 @@ impl FileLogReader { reader.seek(SeekFrom::Start(self.offset))?; let mut entries = Vec::new(); - let mut line = String::new(); + let mut line = Vec::new(); - while reader.read_line(&mut line)? > 0 { - let trimmed = line.trim_end().to_string(); + while reader.read_until(b'\n', &mut line)? > 0 { + let decoded = String::from_utf8_lossy(&line); + let trimmed = decoded.trim_end().to_string(); if !trimmed.is_empty() { entries.push(LogEntry { source_id: self.source_id.clone(), @@ -350,6 +351,21 @@ mod tests { assert_eq!(entries[0].line, "line C"); } + #[tokio::test] + async fn test_file_log_reader_handles_invalid_utf8() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("invalid-utf8.log"); + std::fs::write(&path, b"ok line\nbad byte \xff\n").unwrap(); + + let mut reader = FileLogReader::new("utf8".into(), path.to_string_lossy().to_string(), 0); + let entries = reader.read_new_entries().await.unwrap(); + + assert_eq!(entries.len(), 2); + assert_eq!(entries[0].line, "ok line"); + assert!(entries[1].line.contains("bad byte")); + assert!(entries[1].line.contains('\u{fffd}')); + } + #[tokio::test] async fn test_file_log_reader_handles_truncation() { let dir = tempfile::tempdir().unwrap(); From 83434cf83271a34cab7f2ce968cce0ba3dee2a78 Mon Sep 17 00:00:00 2001 From: vsilent Date: Fri, 3 Apr 2026 17:59:56 +0300 Subject: [PATCH 02/10] iptables & nftables --- Cargo.toml | 2 +- README.md | 2 +- VERSION.md | 2 +- src/firewall/iptables.rs | 62 ++++++++++++++----- src/firewall/nftables.rs | 125 +++++++++++++++++++++++++++++++++++---- src/firewall/response.rs | 110 +++++++++++++++++++++++++++++++--- web/package.json | 2 +- 7 files changed, 265 insertions(+), 40 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7f450b9..bdb3005 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stackdog" -version = "0.2.0" +version = "0.2.1" authors = ["Vasili Pascal "] edition = "2021" description = "Security platform for Docker containers and Linux servers" diff --git a/README.md b/README.md index a4a1f86..f81a0ef 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Stackdog Security -![Version](https://img.shields.io/badge/version-0.2.0-blue.svg) +![Version](https://img.shields.io/badge/version-0.2.1-blue.svg) ![License](https://img.shields.io/badge/license-MIT-green.svg) ![Rust](https://img.shields.io/badge/rust-1.75+-orange.svg) ![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey.svg) diff --git a/VERSION.md b/VERSION.md index 341cf11..0c62199 100644 --- a/VERSION.md +++ b/VERSION.md @@ -1 +1 @@ -0.2.0 \ No newline at end of file +0.2.1 diff --git a/src/firewall/iptables.rs b/src/firewall/iptables.rs index b45e29e..544b202 100644 --- a/src/firewall/iptables.rs +++ b/src/firewall/iptables.rs @@ -45,6 +45,19 @@ pub struct IptablesBackend { } impl IptablesBackend { + fn run_iptables(&self, args: &[&str], context: &str) -> Result<()> { + let output = Command::new("iptables") + .args(args) + .output() + .context(context)?; + + if !output.status.success() { + anyhow::bail!("{}", String::from_utf8_lossy(&output.stderr).trim()); + } + + Ok(()) + } + /// Create a new iptables backend pub fn new() -> Result { #[cfg(target_os = "linux")] @@ -193,37 +206,47 @@ impl FirewallBackend for IptablesBackend { } fn block_ip(&self, ip: &str) -> Result<()> { - let chain = IptChain::new("filter", "INPUT"); - let rule = IptRule::new(&chain, format!("-s {} -j DROP", ip)); - self.add_rule(&rule) + self.run_iptables( + &["-I", "INPUT", "-s", ip, "-j", "DROP"], + "Failed to block IP with iptables", + ) } fn unblock_ip(&self, ip: &str) -> Result<()> { - let chain = IptChain::new("filter", "INPUT"); - let rule = IptRule::new(&chain, format!("-s {} -j DROP", ip)); - self.delete_rule(&rule) + self.run_iptables( + &["-D", "INPUT", "-s", ip, "-j", "DROP"], + "Failed to unblock IP with iptables", + ) } fn block_port(&self, port: u16) -> Result<()> { - let chain = IptChain::new("filter", "INPUT"); - let rule = IptRule::new(&chain, format!("-p tcp --dport {} -j DROP", port)); - self.add_rule(&rule) + let port = port.to_string(); + self.run_iptables( + &["-I", "OUTPUT", "-p", "tcp", "--dport", &port, "-j", "DROP"], + "Failed to block port with iptables", + ) } fn unblock_port(&self, port: u16) -> Result<()> { - let chain = IptChain::new("filter", "INPUT"); - let rule = IptRule::new(&chain, format!("-p tcp --dport {} -j DROP", port)); - self.delete_rule(&rule) + let port = port.to_string(); + self.run_iptables( + &["-D", "OUTPUT", "-p", "tcp", "--dport", &port, "-j", "DROP"], + "Failed to unblock port with iptables", + ) } fn block_container(&self, container_id: &str) -> Result<()> { - log::info!("Would block container via iptables: {}", container_id); - Ok(()) + anyhow::bail!( + "Container-specific iptables blocking is not implemented yet for {}", + container_id + ) } fn unblock_container(&self, container_id: &str) -> Result<()> { - log::info!("Would unblock container via iptables: {}", container_id); - Ok(()) + anyhow::bail!( + "Container-specific iptables unblocking is not implemented yet for {}", + container_id + ) } fn name(&self) -> &str { @@ -248,4 +271,11 @@ mod tests { let rule = IptRule::new(&chain, "-p tcp --dport 22 -j DROP"); assert_eq!(rule.rule_spec, "-p tcp --dport 22 -j DROP"); } + + #[test] + fn test_block_container_is_explicitly_unsupported() { + let backend = IptablesBackend { available: true }; + let result = backend.block_container("container-1"); + assert!(result.is_err()); + } } diff --git a/src/firewall/nftables.rs b/src/firewall/nftables.rs index afb8b2b..7c7703d 100644 --- a/src/firewall/nftables.rs +++ b/src/firewall/nftables.rs @@ -69,6 +69,73 @@ pub struct NfTablesBackend { } impl NfTablesBackend { + fn run_nft(&self, args: &[&str], context: &str) -> Result<()> { + let output = Command::new("nft").args(args).output().context(context)?; + + if !output.status.success() { + anyhow::bail!("{}", String::from_utf8_lossy(&output.stderr).trim()); + } + + Ok(()) + } + + fn base_table(&self) -> NfTable { + NfTable::new("inet", "stackdog") + } + + fn ensure_filter_table(&self) -> Result<()> { + let table = self.base_table(); + let _ = self.run_nft( + &["add", "table", &table.family, &table.name], + "Failed to ensure nftables table", + ); + let _ = self.run_nft( + &[ + "add", + "chain", + &table.family, + &table.name, + "input", + "{", + "type", + "filter", + "hook", + "input", + "priority", + "0", + ";", + "policy", + "accept", + ";", + "}", + ], + "Failed to ensure nftables input chain", + ); + let _ = self.run_nft( + &[ + "add", + "chain", + &table.family, + &table.name, + "output", + "{", + "type", + "filter", + "hook", + "output", + "priority", + "0", + ";", + "policy", + "accept", + ";", + "}", + ], + "Failed to ensure nftables output chain", + ); + Ok(()) + } + /// Create a new nftables backend pub fn new() -> Result { #[cfg(target_os = "linux")] @@ -274,34 +341,59 @@ impl FirewallBackend for NfTablesBackend { } fn block_ip(&self, ip: &str) -> Result<()> { - // Implementation would add nftables rule to block IP - log::info!("Would block IP: {}", ip); - Ok(()) + self.ensure_filter_table()?; + self.run_nft( + &[ + "add", "rule", "inet", "stackdog", "input", "ip", "saddr", ip, "drop", + ], + "Failed to block IP with nftables", + ) } fn unblock_ip(&self, ip: &str) -> Result<()> { - log::info!("Would unblock IP: {}", ip); - Ok(()) + self.ensure_filter_table()?; + self.run_nft( + &[ + "delete", "rule", "inet", "stackdog", "input", "ip", "saddr", ip, "drop", + ], + "Failed to unblock IP with nftables", + ) } fn block_port(&self, port: u16) -> Result<()> { - log::info!("Would block port: {}", port); - Ok(()) + self.ensure_filter_table()?; + let port = port.to_string(); + self.run_nft( + &[ + "add", "rule", "inet", "stackdog", "output", "tcp", "dport", &port, "drop", + ], + "Failed to block port with nftables", + ) } fn unblock_port(&self, port: u16) -> Result<()> { - log::info!("Would unblock port: {}", port); - Ok(()) + self.ensure_filter_table()?; + let port = port.to_string(); + self.run_nft( + &[ + "delete", "rule", "inet", "stackdog", "output", "tcp", "dport", &port, "drop", + ], + "Failed to unblock port with nftables", + ) } fn block_container(&self, container_id: &str) -> Result<()> { - log::info!("Would block container: {}", container_id); - Ok(()) + anyhow::bail!( + "Container-specific nftables blocking is not implemented yet for {}", + container_id + ) } fn unblock_container(&self, container_id: &str) -> Result<()> { - log::info!("Would unblock container: {}", container_id); - Ok(()) + anyhow::bail!( + "Container-specific nftables unblocking is not implemented yet for {}", + container_id + ) } fn name(&self) -> &str { @@ -327,4 +419,11 @@ mod tests { assert_eq!(chain.name, "input"); assert_eq!(chain.chain_type, "filter"); } + + #[test] + fn test_block_container_is_explicitly_unsupported() { + let backend = NfTablesBackend { available: true }; + let result = backend.block_container("container-1"); + assert!(result.is_err()); + } } diff --git a/src/firewall/response.rs b/src/firewall/response.rs index f4a6d91..52e9cd8 100644 --- a/src/firewall/response.rs +++ b/src/firewall/response.rs @@ -4,9 +4,12 @@ use anyhow::Result; use chrono::{DateTime, Utc}; +use std::process::Command; use std::sync::{Arc, RwLock}; use crate::alerting::alert::Alert; +use crate::firewall::backend::FirewallBackend; +use crate::firewall::{IptablesBackend, NfTablesBackend}; /// Response action types #[derive(Debug, Clone)] @@ -30,6 +33,17 @@ pub struct ResponseAction { } impl ResponseAction { + fn preferred_backend() -> Result> { + if let Ok(mut backend) = NfTablesBackend::new() { + backend.initialize()?; + return Ok(Box::new(backend)); + } + + let mut backend = IptablesBackend::new()?; + backend.initialize()?; + Ok(Box::new(backend)) + } + /// Create a new response action pub fn new(action_type: ResponseType, description: String) -> Self { Self { @@ -84,19 +98,24 @@ impl ResponseAction { Ok(()) } ResponseType::BlockIP(ip) => { - log::info!("Would block IP: {}", ip); - Ok(()) + let backend = Self::preferred_backend()?; + backend.block_ip(ip) } ResponseType::BlockPort(port) => { - log::info!("Would block port: {}", port); - Ok(()) + let backend = Self::preferred_backend()?; + backend.block_port(*port) } ResponseType::QuarantineContainer(id) => { - log::info!("Would quarantine container: {}", id); - Ok(()) + let backend = Self::preferred_backend()?; + backend.block_container(id) } ResponseType::KillProcess(pid) => { - log::info!("Would kill process: {}", pid); + let output = Command::new("kill") + .args(["-TERM", &pid.to_string()]) + .output()?; + if !output.status.success() { + anyhow::bail!("{}", String::from_utf8_lossy(&output.stderr).trim()); + } Ok(()) } ResponseType::SendAlert(msg) => { @@ -330,6 +349,7 @@ impl Default for ResponseAudit { #[cfg(test)] mod tests { use super::*; + use std::time::Instant; #[test] fn test_response_action_creation() { @@ -381,4 +401,80 @@ mod tests { assert!(log.success()); assert_eq!(log.action_name(), "test_action"); } + + #[test] + fn test_quarantine_action_returns_error_when_container_blocking_missing() { + let action = ResponseAction::new( + ResponseType::QuarantineContainer("container-1".to_string()), + "Quarantine".to_string(), + ); + + let result = action.execute(); + assert!(result.is_err()); + } + + #[test] + fn test_response_chain_stops_on_failure() { + let mut chain = ResponseChain::new("stop-on-failure"); + chain.set_stop_on_failure(true); + chain.add_action(ResponseAction::new( + ResponseType::QuarantineContainer("container-1".to_string()), + "Quarantine".to_string(), + )); + chain.add_action(ResponseAction::new( + ResponseType::LogAction("after".to_string()), + "After".to_string(), + )); + + let result = chain.execute(); + assert!(result.is_err()); + } + + #[test] + fn test_response_chain_continues_when_failure_allowed() { + let mut chain = ResponseChain::new("continue-on-failure"); + chain.add_action(ResponseAction::new( + ResponseType::QuarantineContainer("container-1".to_string()), + "Quarantine".to_string(), + )); + chain.add_action(ResponseAction::new( + ResponseType::LogAction("after".to_string()), + "After".to_string(), + )); + + let result = chain.execute(); + assert!(result.is_ok()); + } + + #[test] + fn test_execute_with_retry_honors_retry_count() { + let mut action = ResponseAction::new( + ResponseType::QuarantineContainer("container-1".to_string()), + "Quarantine".to_string(), + ); + action.set_retry_config(2, 0); + + let started = Instant::now(); + let result = action.execute_with_retry(); + + assert!(result.is_err()); + assert!(started.elapsed().as_millis() < 100); + } + + #[test] + fn test_response_executor_records_failed_action() { + let mut executor = ResponseExecutor::new().unwrap(); + let action = ResponseAction::new( + ResponseType::QuarantineContainer("container-1".to_string()), + "Quarantine".to_string(), + ); + + let result = executor.execute(&action); + let log = executor.get_log(); + + assert!(result.is_err()); + assert_eq!(log.len(), 1); + assert!(!log[0].success()); + assert!(log[0].error().is_some()); + } } diff --git a/web/package.json b/web/package.json index cb65949..c1fae63 100644 --- a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "stackdog-web", "description": "Stackdog Security Web Dashboard", - "version": "0.2.0", + "version": "0.2.1", "scripts": { "start": "cross-env REACT_APP_VERSION=$npm_package_version webpack serve --mode development", "build": "cross-env REACT_APP_VERSION=$npm_package_version webpack --mode production", From ca0944fd52321da7b2bfd98aeea72bb3e4fa0f05 Mon Sep 17 00:00:00 2001 From: vsilent Date: Fri, 3 Apr 2026 22:22:41 +0300 Subject: [PATCH 03/10] multiple updates, eBPF, container API quality is improved --- Cargo.toml | 1 + ebpf/.cargo/config | 5 - ebpf/rust-toolchain.toml | 3 + ebpf/src/main.rs | 9 +- ebpf/src/maps.rs | 125 +++++- ebpf/src/syscalls.rs | 164 ++++++- .../00000000000000_create_alerts/up.sql | 3 + src/alerting/alert.rs | 49 ++- src/alerting/mod.rs | 2 + src/alerting/notifications.rs | 304 +++++++++++-- src/alerting/rules.rs | 61 ++- src/api/alerts.rs | 57 ++- src/api/containers.rs | 198 ++++++--- src/api/security.rs | 83 +++- src/api/threats.rs | 164 +++++-- src/api/websocket.rs | 335 +++++++++++++-- src/baselines/learning.rs | 251 ++++++++++- src/cli.rs | 63 +++ src/collectors/docker_events.rs | 132 +++++- src/collectors/ebpf/enrichment.rs | 51 ++- src/collectors/ebpf/syscall_monitor.rs | 5 + src/collectors/ebpf/types.rs | 182 +++++++- src/collectors/network.rs | 104 ++++- src/database/baselines.rs | 154 ++++++- src/database/connection.rs | 23 + src/database/mod.rs | 2 + src/database/models/mod.rs | 110 ++++- src/database/repositories/alerts.rs | 402 ++++++++++++++++-- src/docker/containers.rs | 43 +- src/docker/mail_guard.rs | 24 +- src/events/syscall.rs | 82 ++++ src/lib.rs | 2 + src/main.rs | 26 ++ src/ml/anomaly.rs | 181 +++++++- src/ml/candle_backend.rs | 72 +++- src/ml/features.rs | 135 ++++++ src/ml/models/isolation_forest.rs | 321 +++++++++++++- src/ml/scorer.rs | 143 ++++++- src/models/api/alerts.rs | 17 + src/models/api/containers.rs | 14 +- src/models/api/security.rs | 35 +- src/response/mod.rs | 4 + src/response/pipeline.rs | 200 ++++++++- src/rules/builtin.rs | 91 +++- src/rules/threat_scorer.rs | 91 ++++ src/sniff/config.rs | 163 ++++++- src/sniff/mod.rs | 30 +- src/sniff/reporter.rs | 70 ++- web/src/components/ContainerList.tsx | 27 +- .../__tests__/ContainerList.test.tsx | 34 +- web/src/services/api.ts | 74 +++- web/src/types/containers.ts | 14 +- 52 files changed, 4564 insertions(+), 371 deletions(-) delete mode 100644 ebpf/.cargo/config create mode 100644 ebpf/rust-toolchain.toml diff --git a/Cargo.toml b/Cargo.toml index bdb3005..27ad71d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ zstd = "0.13" # Stream utilities futures-util = "0.3" +lettre = { version = "0.11", default-features = false, features = ["tokio1", "tokio1-rustls-tls", "builder", "smtp-transport"] } # eBPF (Linux only) [target.'cfg(target_os = "linux")'.dependencies] diff --git a/ebpf/.cargo/config b/ebpf/.cargo/config deleted file mode 100644 index d19f05d..0000000 --- a/ebpf/.cargo/config +++ /dev/null @@ -1,5 +0,0 @@ -[build] -target = ["bpfel-unknown-none"] - -[target.bpfel-unknown-none] -rustflags = ["-C", "link-arg=--Bstatic"] diff --git a/ebpf/rust-toolchain.toml b/ebpf/rust-toolchain.toml new file mode 100644 index 0000000..f70d225 --- /dev/null +++ b/ebpf/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly" +components = ["rust-src"] diff --git a/ebpf/src/main.rs b/ebpf/src/main.rs index e7a894d..04c6ee4 100644 --- a/ebpf/src/main.rs +++ b/ebpf/src/main.rs @@ -5,5 +5,10 @@ #![no_main] #![no_std] -#[no_mangle] -pub fn main() {} +mod maps; +mod syscalls; + +#[panic_handler] +fn panic(_info: &core::panic::PanicInfo<'_>) -> ! { + loop {} +} diff --git a/ebpf/src/maps.rs b/ebpf/src/maps.rs index 4acc9dc..1ff8d6b 100644 --- a/ebpf/src/maps.rs +++ b/ebpf/src/maps.rs @@ -2,8 +2,123 @@ //! //! Shared maps for eBPF programs -// TODO: Implement eBPF maps in TASK-003 -// This will include: -// - Event ring buffer for sending events to userspace -// - Hash maps for tracking state -// - Arrays for configuration +use aya_ebpf::{macros::map, maps::RingBuf}; + +#[repr(C)] +#[derive(Clone, Copy)] +pub union EbpfEventData { + pub execve: ExecveData, + pub connect: ConnectData, + pub openat: OpenatData, + pub ptrace: PtraceData, + pub raw: [u8; 264], +} + +impl EbpfEventData { + pub const fn empty() -> Self { + Self { raw: [0u8; 264] } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct EbpfSyscallEvent { + pub pid: u32, + pub uid: u32, + pub syscall_id: u32, + pub _pad: u32, + pub timestamp: u64, + pub comm: [u8; 16], + pub data: EbpfEventData, +} + +impl EbpfSyscallEvent { + pub const fn empty() -> Self { + Self { + pid: 0, + uid: 0, + syscall_id: 0, + _pad: 0, + timestamp: 0, + comm: [0u8; 16], + data: EbpfEventData::empty(), + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct ExecveData { + pub filename_len: u32, + pub filename: [u8; 128], + pub argc: u32, +} + +impl ExecveData { + pub const fn empty() -> Self { + Self { + filename_len: 0, + filename: [0u8; 128], + argc: 0, + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct ConnectData { + pub dst_ip: [u8; 16], + pub dst_port: u16, + pub family: u16, +} + +impl ConnectData { + pub const fn empty() -> Self { + Self { + dst_ip: [0u8; 16], + dst_port: 0, + family: 0, + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct OpenatData { + pub path_len: u32, + pub path: [u8; 256], + pub flags: u32, +} + +impl OpenatData { + pub const fn empty() -> Self { + Self { + path_len: 0, + path: [0u8; 256], + flags: 0, + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct PtraceData { + pub target_pid: u32, + pub request: u32, + pub addr: u64, + pub data: u64, +} + +impl PtraceData { + pub const fn empty() -> Self { + Self { + target_pid: 0, + request: 0, + addr: 0, + data: 0, + } + } +} + +#[map(name = "EVENTS")] +pub static EVENTS: RingBuf = RingBuf::with_byte_size(256 * 1024, 0); diff --git a/ebpf/src/syscalls.rs b/ebpf/src/syscalls.rs index 64d8de9..cdfbd06 100644 --- a/ebpf/src/syscalls.rs +++ b/ebpf/src/syscalls.rs @@ -2,10 +2,160 @@ //! //! Tracepoints for monitoring security-relevant syscalls -// TODO: Implement eBPF syscall monitoring programs in TASK-003 -// This will include: -// - execve/execveat monitoring -// - connect/accept/bind monitoring -// - open/openat monitoring -// - ptrace monitoring -// - mount/umount monitoring +use aya_ebpf::{ + helpers::{ + bpf_get_current_comm, bpf_probe_read_user, bpf_probe_read_user_buf, + bpf_probe_read_user_str_bytes, + }, + macros::tracepoint, + programs::TracePointContext, + EbpfContext, +}; + +use crate::maps::{ + ConnectData, EbpfEventData, EbpfSyscallEvent, ExecveData, OpenatData, PtraceData, EVENTS, +}; + +const SYSCALL_ARG_START: usize = 16; +const SYSCALL_ARG_SIZE: usize = 8; + +const SYS_EXECVE: u32 = 59; +const SYS_CONNECT: u32 = 42; +const SYS_OPENAT: u32 = 257; +const SYS_PTRACE: u32 = 101; + +const AF_INET: u16 = 2; +const AF_INET6: u16 = 10; +const MAX_ARGC_SCAN: usize = 16; + +#[tracepoint(name = "sys_enter_execve", category = "syscalls")] +pub fn trace_execve(ctx: TracePointContext) -> i32 { + let _ = unsafe { try_trace_execve(&ctx) }; + 0 +} + +#[tracepoint(name = "sys_enter_connect", category = "syscalls")] +pub fn trace_connect(ctx: TracePointContext) -> i32 { + let _ = unsafe { try_trace_connect(&ctx) }; + 0 +} + +#[tracepoint(name = "sys_enter_openat", category = "syscalls")] +pub fn trace_openat(ctx: TracePointContext) -> i32 { + let _ = unsafe { try_trace_openat(&ctx) }; + 0 +} + +#[tracepoint(name = "sys_enter_ptrace", category = "syscalls")] +pub fn trace_ptrace(ctx: TracePointContext) -> i32 { + let _ = unsafe { try_trace_ptrace(&ctx) }; + 0 +} + +unsafe fn try_trace_execve(ctx: &TracePointContext) -> Result<(), i64> { + let filename_ptr = read_u64_arg(ctx, 0)? as *const u8; + let argv_ptr = read_u64_arg(ctx, 1)? as *const u64; + let mut event = base_event(ctx, SYS_EXECVE); + let mut data = ExecveData::empty(); + + if !filename_ptr.is_null() { + if let Ok(bytes) = bpf_probe_read_user_str_bytes(filename_ptr, &mut data.filename) { + data.filename_len = bytes.len() as u32; + } + } + + data.argc = count_argv(argv_ptr).unwrap_or(0); + event.data = EbpfEventData { execve: data }; + submit_event(&event) +} + +unsafe fn try_trace_connect(ctx: &TracePointContext) -> Result<(), i64> { + let sockaddr_ptr = read_u64_arg(ctx, 1)? as *const u8; + if sockaddr_ptr.is_null() { + return Ok(()); + } + + let family = bpf_probe_read_user(sockaddr_ptr as *const u16)?; + let mut event = base_event(ctx, SYS_CONNECT); + let mut data = ConnectData::empty(); + data.family = family; + + if family == AF_INET { + data.dst_port = bpf_probe_read_user(sockaddr_ptr.add(2) as *const u16)?; + let mut addr = [0u8; 4]; + bpf_probe_read_user_buf(sockaddr_ptr.add(4), &mut addr)?; + data.dst_ip[..4].copy_from_slice(&addr); + } else if family == AF_INET6 { + data.dst_port = bpf_probe_read_user(sockaddr_ptr.add(2) as *const u16)?; + bpf_probe_read_user_buf(sockaddr_ptr.add(8), &mut data.dst_ip)?; + } + + event.data = EbpfEventData { connect: data }; + submit_event(&event) +} + +unsafe fn try_trace_openat(ctx: &TracePointContext) -> Result<(), i64> { + let pathname_ptr = read_u64_arg(ctx, 1)? as *const u8; + let flags = read_u64_arg(ctx, 2)? as u32; + let mut event = base_event(ctx, SYS_OPENAT); + let mut data = OpenatData::empty(); + data.flags = flags; + + if !pathname_ptr.is_null() { + if let Ok(bytes) = bpf_probe_read_user_str_bytes(pathname_ptr, &mut data.path) { + data.path_len = bytes.len() as u32; + } + } + + event.data = EbpfEventData { openat: data }; + submit_event(&event) +} + +unsafe fn try_trace_ptrace(ctx: &TracePointContext) -> Result<(), i64> { + let mut event = base_event(ctx, SYS_PTRACE); + let data = PtraceData { + request: read_u64_arg(ctx, 0)? as u32, + target_pid: read_u64_arg(ctx, 1)? as u32, + addr: read_u64_arg(ctx, 2)?, + data: read_u64_arg(ctx, 3)?, + }; + event.data = EbpfEventData { ptrace: data }; + submit_event(&event) +} + +fn base_event(ctx: &TracePointContext, syscall_id: u32) -> EbpfSyscallEvent { + let mut event = EbpfSyscallEvent::empty(); + event.pid = ctx.tgid(); + event.uid = ctx.uid(); + event.syscall_id = syscall_id; + event.timestamp = 0; + if let Ok(comm) = bpf_get_current_comm() { + event.comm = comm; + } + event +} + +fn submit_event(event: &EbpfSyscallEvent) -> Result<(), i64> { + EVENTS.output(event, 0) +} + +fn read_u64_arg(ctx: &TracePointContext, index: usize) -> Result { + unsafe { ctx.read_at::(SYSCALL_ARG_START + index * SYSCALL_ARG_SIZE) } +} + +unsafe fn count_argv(argv_ptr: *const u64) -> Result { + if argv_ptr.is_null() { + return Ok(0); + } + + let mut argc = 0u32; + while argc < MAX_ARGC_SCAN as u32 { + let arg_ptr = bpf_probe_read_user(argv_ptr.add(argc as usize))?; + if arg_ptr == 0 { + break; + } + argc += 1; + } + + Ok(argc) +} diff --git a/migrations/00000000000000_create_alerts/up.sql b/migrations/00000000000000_create_alerts/up.sql index 42dcc27..752ad63 100644 --- a/migrations/00000000000000_create_alerts/up.sql +++ b/migrations/00000000000000_create_alerts/up.sql @@ -14,3 +14,6 @@ CREATE TABLE IF NOT EXISTS alerts ( CREATE INDEX IF NOT EXISTS idx_alerts_status ON alerts(status); CREATE INDEX IF NOT EXISTS idx_alerts_severity ON alerts(severity); CREATE INDEX IF NOT EXISTS idx_alerts_timestamp ON alerts(timestamp); +CREATE INDEX IF NOT EXISTS idx_alerts_container_id + ON alerts(json_extract(metadata, '$.container_id')) + WHERE json_valid(metadata); diff --git a/src/alerting/alert.rs b/src/alerting/alert.rs index 76ef10e..311211a 100644 --- a/src/alerting/alert.rs +++ b/src/alerting/alert.rs @@ -9,7 +9,7 @@ use uuid::Uuid; use crate::events::security::SecurityEvent; /// Alert types -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum AlertType { ThreatDetected, AnomalyDetected, @@ -32,6 +32,22 @@ impl std::fmt::Display for AlertType { } } +impl std::str::FromStr for AlertType { + type Err = String; + + fn from_str(value: &str) -> Result { + match value { + "ThreatDetected" => Ok(Self::ThreatDetected), + "AnomalyDetected" => Ok(Self::AnomalyDetected), + "RuleViolation" => Ok(Self::RuleViolation), + "ThresholdExceeded" => Ok(Self::ThresholdExceeded), + "QuarantineApplied" => Ok(Self::QuarantineApplied), + "SystemEvent" => Ok(Self::SystemEvent), + _ => Err(format!("unknown alert type: {value}")), + } + } +} + /// Alert severity levels #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub enum AlertSeverity { @@ -54,6 +70,21 @@ impl std::fmt::Display for AlertSeverity { } } +impl std::str::FromStr for AlertSeverity { + type Err = String; + + fn from_str(value: &str) -> Result { + match value { + "Info" => Ok(Self::Info), + "Low" => Ok(Self::Low), + "Medium" => Ok(Self::Medium), + "High" => Ok(Self::High), + "Critical" => Ok(Self::Critical), + _ => Err(format!("unknown alert severity: {value}")), + } + } +} + /// Alert status #[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] pub enum AlertStatus { @@ -74,6 +105,20 @@ impl std::fmt::Display for AlertStatus { } } +impl std::str::FromStr for AlertStatus { + type Err = String; + + fn from_str(value: &str) -> Result { + match value { + "New" => Ok(Self::New), + "Acknowledged" => Ok(Self::Acknowledged), + "Resolved" => Ok(Self::Resolved), + "FalsePositive" => Ok(Self::FalsePositive), + _ => Err(format!("unknown alert status: {value}")), + } + } +} + /// Security alert #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Alert { @@ -113,7 +158,7 @@ impl Alert { /// Get alert type pub fn alert_type(&self) -> AlertType { - self.alert_type.clone() + self.alert_type } /// Get severity diff --git a/src/alerting/mod.rs b/src/alerting/mod.rs index 32231f2..ea6f4b4 100644 --- a/src/alerting/mod.rs +++ b/src/alerting/mod.rs @@ -6,6 +6,7 @@ pub mod alert; pub mod dedup; pub mod manager; pub mod notifications; +pub mod rules; /// Marker struct for module tests pub struct AlertingMarker; @@ -15,3 +16,4 @@ pub use alert::{Alert, AlertSeverity, AlertStatus, AlertType}; pub use dedup::{AlertDeduplicator, DedupConfig, DedupResult, Fingerprint}; pub use manager::{AlertManager, AlertStats}; pub use notifications::{NotificationChannel, NotificationConfig, NotificationResult}; +pub use rules::AlertRule; diff --git a/src/alerting/notifications.rs b/src/alerting/notifications.rs index d4ba3e5..ce2ae56 100644 --- a/src/alerting/notifications.rs +++ b/src/alerting/notifications.rs @@ -2,7 +2,10 @@ //! //! Notification channels for alert delivery -use anyhow::Result; +use anyhow::{Context, Result}; +use lettre::message::{Mailbox, MultiPart, SinglePart}; +use lettre::transport::smtp::authentication::Credentials; +use lettre::{AsyncSmtpTransport, AsyncTransport, Message, Tokio1Executor}; use crate::alerting::alert::{Alert, AlertSeverity}; @@ -12,10 +15,10 @@ pub struct NotificationConfig { slack_webhook: Option, smtp_host: Option, smtp_port: Option, - _smtp_user: Option, - _smtp_password: Option, + smtp_user: Option, + smtp_password: Option, webhook_url: Option, - _email_recipients: Vec, + email_recipients: Vec, } impl NotificationConfig { @@ -25,10 +28,10 @@ impl NotificationConfig { slack_webhook: None, smtp_host: None, smtp_port: None, - _smtp_user: None, - _smtp_password: None, + smtp_user: None, + smtp_password: None, webhook_url: None, - _email_recipients: Vec::new(), + email_recipients: Vec::new(), } } @@ -50,6 +53,24 @@ impl NotificationConfig { self } + /// Set SMTP user + pub fn with_smtp_user(mut self, user: String) -> Self { + self.smtp_user = Some(user); + self + } + + /// Set SMTP password + pub fn with_smtp_password(mut self, password: String) -> Self { + self.smtp_password = Some(password); + self + } + + /// Set email recipients + pub fn with_email_recipients(mut self, recipients: Vec) -> Self { + self.email_recipients = recipients; + self + } + /// Set webhook URL pub fn with_webhook_url(mut self, url: String) -> Self { self.webhook_url = Some(url); @@ -71,14 +92,55 @@ impl NotificationConfig { self.smtp_port } + /// Get SMTP user + pub fn smtp_user(&self) -> Option<&str> { + self.smtp_user.as_deref() + } + + /// Get SMTP password + pub fn smtp_password(&self) -> Option<&str> { + self.smtp_password.as_deref() + } + + /// Get email recipients + pub fn email_recipients(&self) -> &[String] { + &self.email_recipients + } + /// Get webhook URL pub fn webhook_url(&self) -> Option<&str> { self.webhook_url.as_deref() } + + /// Return only channels that are both policy-selected and actually configured. + pub fn configured_channels_for_severity( + &self, + severity: AlertSeverity, + ) -> Vec { + route_by_severity(severity) + .into_iter() + .filter(|channel| self.supports_channel(channel)) + .collect() + } + + fn supports_channel(&self, channel: &NotificationChannel) -> bool { + match channel { + NotificationChannel::Console => true, + NotificationChannel::Slack => self.slack_webhook.is_some(), + NotificationChannel::Webhook => self.webhook_url.is_some(), + NotificationChannel::Email => { + self.smtp_host.is_some() + && self.smtp_port.is_some() + && self.smtp_user.is_some() + && self.smtp_password.is_some() + && !self.email_recipients.is_empty() + } + } + } } /// Notification channel -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum NotificationChannel { Console, Slack, @@ -88,12 +150,16 @@ pub enum NotificationChannel { impl NotificationChannel { /// Send notification - pub fn send(&self, alert: &Alert, _config: &NotificationConfig) -> Result { + pub async fn send( + &self, + alert: &Alert, + config: &NotificationConfig, + ) -> Result { match self { NotificationChannel::Console => self.send_console(alert), - NotificationChannel::Slack => self.send_slack(alert, _config), - NotificationChannel::Email => self.send_email(alert, _config), - NotificationChannel::Webhook => self.send_webhook(alert, _config), + NotificationChannel::Slack => self.send_slack(alert, config).await, + NotificationChannel::Email => self.send_email(alert, config).await, + NotificationChannel::Webhook => self.send_webhook(alert, config).await, } } @@ -111,19 +177,23 @@ impl NotificationChannel { } /// Send to Slack via incoming webhook - fn send_slack(&self, alert: &Alert, config: &NotificationConfig) -> Result { + async fn send_slack( + &self, + alert: &Alert, + config: &NotificationConfig, + ) -> Result { if let Some(webhook_url) = config.slack_webhook() { let payload = build_slack_message(alert); log::debug!("Sending Slack notification to webhook"); log::trace!("Slack payload: {}", payload); - // Blocking HTTP POST — notification sending is synchronous in this codebase - let client = reqwest::blocking::Client::new(); + let client = reqwest::Client::new(); match client .post(webhook_url) .header("Content-Type", "application/json") .body(payload) .send() + .await { Ok(resp) => { if resp.status().is_success() { @@ -131,7 +201,7 @@ impl NotificationChannel { Ok(NotificationResult::Success("sent to Slack".to_string())) } else { let status = resp.status(); - let body = resp.text().unwrap_or_default(); + let body = resp.text().await.unwrap_or_default(); log::warn!("Slack API returned {}: {}", status, body); Ok(NotificationResult::Failure(format!( "Slack returned {}: {}", @@ -156,30 +226,101 @@ impl NotificationChannel { } /// Send via email - fn send_email(&self, alert: &Alert, config: &NotificationConfig) -> Result { - // In production, this would send SMTP email - // For now, just log - if config.smtp_host().is_some() { - log::info!("Would send email: {}", alert.message()); - Ok(NotificationResult::Success("sent via email".to_string())) - } else { - Ok(NotificationResult::Failure( + async fn send_email( + &self, + alert: &Alert, + config: &NotificationConfig, + ) -> Result { + match ( + config.smtp_host(), + config.smtp_port(), + config.smtp_user(), + config.smtp_password(), + ) { + (Some(host), Some(port), Some(user), Some(password)) + if !config.email_recipients().is_empty() => + { + let from: Mailbox = user + .parse() + .with_context(|| format!("invalid SMTP sender address: {user}"))?; + let recipients = config + .email_recipients() + .iter() + .map(|recipient| { + recipient + .parse::() + .with_context(|| format!("invalid SMTP recipient address: {recipient}")) + }) + .collect::>>()?; + + let mut message_builder = Message::builder().from(from).subject(format!( + "[Stackdog][{}] {}", + alert.severity(), + alert.alert_type() + )); + + for recipient in recipients { + message_builder = message_builder.to(recipient); + } + + let message = message_builder.multipart( + MultiPart::alternative() + .singlepart(SinglePart::plain(build_email_text(alert))) + .singlepart(SinglePart::html(build_email_html(alert))), + )?; + + let mailer = AsyncSmtpTransport::::relay(host)? + .port(port) + .credentials(Credentials::new(user.to_string(), password.to_string())) + .build(); + + match mailer.send(message).await { + Ok(_) => Ok(NotificationResult::Success("sent to email".to_string())), + Err(err) => Ok(NotificationResult::Failure(format!( + "SMTP delivery failed: {}", + err + ))), + } + } + _ => Ok(NotificationResult::Failure( "SMTP not configured".to_string(), - )) + )), } } /// Send to webhook - fn send_webhook( + async fn send_webhook( &self, alert: &Alert, config: &NotificationConfig, ) -> Result { - // In production, this would make HTTP POST - // For now, just log - if config.webhook_url().is_some() { - log::info!("Would send to webhook: {}", alert.message()); - Ok(NotificationResult::Success("sent to webhook".to_string())) + if let Some(webhook_url) = config.webhook_url() { + let payload = build_webhook_payload(alert); + let client = reqwest::Client::new(); + match client + .post(webhook_url) + .header("Content-Type", "application/json") + .body(payload) + .send() + .await + { + Ok(resp) => { + if resp.status().is_success() { + Ok(NotificationResult::Success("sent to webhook".to_string())) + } else { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + Ok(NotificationResult::Failure(format!( + "Webhook returned {}: {}", + status, body + ))) + } + } + Err(err) => Ok(NotificationResult::Failure(format!( + "Webhook request failed: {}", + err + ))), + } } else { Ok(NotificationResult::Failure( "Webhook URL not configured".to_string(), @@ -266,19 +407,36 @@ pub fn build_slack_message(alert: &Alert) -> String { /// Build webhook payload pub fn build_webhook_payload(alert: &Alert) -> String { + serde_json::json!({ + "alert_type": alert.alert_type().to_string(), + "severity": alert.severity().to_string(), + "message": alert.message(), + "timestamp": alert.timestamp().to_rfc3339(), + "status": alert.status().to_string(), + "metadata": alert.metadata(), + }) + .to_string() +} + +fn build_email_text(alert: &Alert) -> String { + format!( + "Stackdog Security Alert\n\nType: {}\nSeverity: {}\nStatus: {}\nTime: {}\n\n{}\n", + alert.alert_type(), + alert.severity(), + alert.status(), + alert.timestamp().to_rfc3339(), + alert.message(), + ) +} + +fn build_email_html(alert: &Alert) -> String { format!( - r#"{{ - "alert_type": "{:?} ", - "severity": "{}", - "message": "{}", - "timestamp": "{}", - "status": "{}" - }}"#, + "

Stackdog Security Alert

Type: {}

Severity: {}

Status: {}

Time: {}

{}

", alert.alert_type(), alert.severity(), + alert.status(), + alert.timestamp().to_rfc3339(), alert.message(), - alert.timestamp(), - alert.status() ) } @@ -286,8 +444,8 @@ pub fn build_webhook_payload(alert: &Alert) -> String { mod tests { use super::*; - #[test] - fn test_console_notification() { + #[tokio::test] + async fn test_console_notification() { let channel = NotificationChannel::Console; let alert = Alert::new( crate::alerting::alert::AlertType::ThreatDetected, @@ -295,7 +453,7 @@ mod tests { "Test".to_string(), ); - let result = channel.send(&alert, &NotificationConfig::default()); + let result = channel.send(&alert, &NotificationConfig::default()).await; assert!(result.is_ok()); } @@ -313,4 +471,64 @@ mod tests { let info_routes = route_by_severity(AlertSeverity::Info); assert_eq!(info_routes.len(), 1); } + + #[test] + fn test_build_webhook_payload_is_valid_json() { + let alert = Alert::new( + crate::alerting::alert::AlertType::ThreatDetected, + AlertSeverity::High, + "Webhook test".to_string(), + ); + + let payload = build_webhook_payload(&alert); + let json: serde_json::Value = serde_json::from_str(&payload).unwrap(); + assert_eq!(json["severity"], "High"); + assert_eq!(json["message"], "Webhook test"); + } + + #[tokio::test] + async fn test_email_channel_requires_recipients() { + let channel = NotificationChannel::Email; + let alert = Alert::new( + crate::alerting::alert::AlertType::ThreatDetected, + AlertSeverity::High, + "Email test".to_string(), + ); + + let result = channel + .send( + &alert, + &NotificationConfig::default() + .with_smtp_host("smtp.example.com".to_string()) + .with_smtp_port(587), + ) + .await + .unwrap(); + + assert!(matches!(result, NotificationResult::Failure(_))); + } + + #[test] + fn test_configured_channels_excludes_unconfigured_targets() { + let config = NotificationConfig::default().with_webhook_url("https://example.test".into()); + let channels = config.configured_channels_for_severity(AlertSeverity::Critical); + + assert!(channels.contains(&NotificationChannel::Console)); + assert!(channels.contains(&NotificationChannel::Webhook)); + assert!(!channels.contains(&NotificationChannel::Slack)); + assert!(!channels.contains(&NotificationChannel::Email)); + } + + #[test] + fn test_configured_channels_include_email_when_fully_configured() { + let config = NotificationConfig::default() + .with_smtp_host("smtp.example.com".into()) + .with_smtp_port(587) + .with_smtp_user("alerts@example.com".into()) + .with_smtp_password("secret".into()) + .with_email_recipients(vec!["security@example.com".into()]); + let channels = config.configured_channels_for_severity(AlertSeverity::Critical); + + assert!(channels.contains(&NotificationChannel::Email)); + } } diff --git a/src/alerting/rules.rs b/src/alerting/rules.rs index d78dfa5..87c6441 100644 --- a/src/alerting/rules.rs +++ b/src/alerting/rules.rs @@ -2,14 +2,48 @@ use anyhow::Result; +use crate::alerting::alert::AlertSeverity; +use crate::alerting::notifications::{route_by_severity, NotificationChannel}; + /// Alert rule +#[derive(Debug, Clone)] pub struct AlertRule { - // TODO: Implement in TASK-018 + minimum_severity: AlertSeverity, + channels: Vec, } impl AlertRule { pub fn new() -> Result { - Ok(Self {}) + Ok(Self { + minimum_severity: AlertSeverity::Low, + channels: route_by_severity(AlertSeverity::High), + }) + } + + pub fn with_minimum_severity(mut self, severity: AlertSeverity) -> Self { + self.minimum_severity = severity; + self + } + + pub fn with_channels(mut self, channels: Vec) -> Self { + self.channels = channels; + self + } + + pub fn matches(&self, severity: AlertSeverity) -> bool { + severity >= self.minimum_severity + } + + pub fn channels_for(&self, severity: AlertSeverity) -> Vec { + if self.matches(severity) { + if self.channels.is_empty() { + route_by_severity(severity) + } else { + self.channels.clone() + } + } else { + Vec::new() + } } } @@ -18,3 +52,26 @@ impl Default for AlertRule { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_alert_rule_matches_minimum_severity() { + let rule = AlertRule::default().with_minimum_severity(AlertSeverity::Medium); + assert!(rule.matches(AlertSeverity::High)); + assert!(!rule.matches(AlertSeverity::Low)); + } + + #[test] + fn test_alert_rule_uses_custom_channels() { + let rule = AlertRule::default() + .with_minimum_severity(AlertSeverity::Low) + .with_channels(vec![NotificationChannel::Webhook]); + + let channels = rule.channels_for(AlertSeverity::Critical); + assert_eq!(channels.len(), 1); + assert!(matches!(channels[0], NotificationChannel::Webhook)); + } +} diff --git a/src/api/alerts.rs b/src/api/alerts.rs index 5a80e37..6557208 100644 --- a/src/api/alerts.rs +++ b/src/api/alerts.rs @@ -1,9 +1,11 @@ //! Alerts API endpoints +use crate::api::websocket::{broadcast_event, broadcast_stats, WebSocketHubHandle}; use crate::database::{ create_sample_alert, get_alert_stats as db_get_alert_stats, list_alerts as db_list_alerts, update_alert_status, AlertFilter, DbPool, }; +use crate::models::api::alerts::AlertResponse; use actix_web::{web, HttpResponse, Responder}; use serde::Deserialize; @@ -24,7 +26,12 @@ pub async fn get_alerts(pool: web::Data, query: web::Query) }; match db_list_alerts(&pool, filter).await { - Ok(alerts) => HttpResponse::Ok().json(alerts), + Ok(alerts) => HttpResponse::Ok().json( + alerts + .into_iter() + .map(AlertResponse::from) + .collect::>(), + ), Err(e) => { log::error!("Failed to list alerts: {}", e); HttpResponse::InternalServerError().json(serde_json::json!({ @@ -61,12 +68,26 @@ pub async fn get_alert_stats(pool: web::Data) -> impl Responder { /// Acknowledge an alert /// /// POST /api/alerts/:id/acknowledge -pub async fn acknowledge_alert(pool: web::Data, path: web::Path) -> impl Responder { +pub async fn acknowledge_alert( + pool: web::Data, + hub: web::Data, + path: web::Path, +) -> impl Responder { let alert_id = path.into_inner(); match update_alert_status(&pool, &alert_id, "Acknowledged").await { Ok(()) => { log::info!("Acknowledged alert: {}", alert_id); + broadcast_event( + hub.get_ref(), + "alert:updated", + serde_json::json!({ + "id": alert_id, + "status": "Acknowledged" + }), + ) + .await; + let _ = broadcast_stats(hub.get_ref(), &pool).await; HttpResponse::Ok().json(serde_json::json!({ "success": true, "message": format!("Alert {} acknowledged", alert_id) @@ -91,6 +112,7 @@ pub struct ResolveRequest { pub async fn resolve_alert( pool: web::Data, + hub: web::Data, path: web::Path, body: web::Json, ) -> impl Responder { @@ -100,6 +122,17 @@ pub async fn resolve_alert( match update_alert_status(&pool, &alert_id, "Resolved").await { Ok(()) => { log::info!("Resolved alert {}: {}", alert_id, _note); + broadcast_event( + hub.get_ref(), + "alert:updated", + serde_json::json!({ + "id": alert_id, + "status": "Resolved", + "note": _note + }), + ) + .await; + let _ = broadcast_stats(hub.get_ref(), &pool).await; HttpResponse::Ok().json(serde_json::json!({ "success": true, "message": format!("Alert {} resolved", alert_id) @@ -115,18 +148,36 @@ pub async fn resolve_alert( } /// Seed database with sample alerts (for testing) -pub async fn seed_sample_alerts(pool: web::Data) -> impl Responder { +pub async fn seed_sample_alerts( + pool: web::Data, + hub: web::Data, +) -> impl Responder { use crate::database::create_alert; let mut created = Vec::new(); + let mut last_alert = None; for i in 0..5 { let alert = create_sample_alert(); if create_alert(&pool, alert).await.is_ok() { created.push(i); + last_alert = Some(i); } } + if !created.is_empty() { + broadcast_event( + hub.get_ref(), + "alert:created", + serde_json::json!({ + "created": created.len(), + "last_index": last_alert + }), + ) + .await; + let _ = broadcast_stats(hub.get_ref(), &pool).await; + } + HttpResponse::Ok().json(serde_json::json!({ "created": created.len(), "message": "Sample alerts created" diff --git a/src/api/containers.rs b/src/api/containers.rs index 85d76c2..9e7ad77 100644 --- a/src/api/containers.rs +++ b/src/api/containers.rs @@ -1,8 +1,12 @@ //! Containers API endpoints +use crate::api::websocket::{broadcast_event, broadcast_stats, WebSocketHubHandle}; use crate::database::DbPool; -use crate::docker::client::ContainerInfo; +use crate::docker::client::{ContainerInfo, ContainerStats}; use crate::docker::containers::ContainerManager; +use crate::models::api::containers::{ + ContainerResponse, ContainerSecurityStatus as ApiContainerSecurityStatus, NetworkActivity, +}; use actix_web::{web, HttpResponse, Responder}; use serde::Deserialize; @@ -21,54 +25,47 @@ pub async fn get_containers(pool: web::Data) -> impl Responder { Ok(m) => m, Err(e) => { log::error!("Failed to create container manager: {}", e); - // Return mock data if Docker not available - return HttpResponse::Ok().json(vec![serde_json::json!({ - "id": "mock-container-1", - "name": "web-server", - "image": "nginx:latest", - "status": "Running", - "security_status": { - "state": "Secure", - "threats": 0, - "vulnerabilities": 0 - }, - "risk_score": 10, - "network_activity": { - "inbound_connections": 5, - "outbound_connections": 3, - "blocked_connections": 0, - "suspicious_activity": false - } - })]); + return HttpResponse::ServiceUnavailable().json(serde_json::json!({ + "error": "Failed to connect to Docker" + })); } }; match manager.list_containers().await { Ok(containers) => { - // Convert to API response format - let response: Vec = containers - .iter() - .map(|c: &ContainerInfo| { - serde_json::json!({ - "id": c.id, - "name": c.name, - "image": c.image, - "status": c.status, - "security_status": { - "state": "Secure", - "threats": 0, - "vulnerabilities": 0 - }, - "risk_score": 0, - "network_activity": { - "inbound_connections": 0, - "outbound_connections": 0, - "blocked_connections": 0, - "suspicious_activity": false + let mut response = Vec::with_capacity(containers.len()); + for container in &containers { + let security = match manager.get_container_security_status(&container.id).await { + Ok(status) => status, + Err(err) => { + log::warn!( + "Failed to derive security status for container {}: {}", + container.id, + err + ); + crate::docker::containers::ContainerSecurityStatus { + container_id: container.id.clone(), + risk_score: 0, + threats: 0, + security_state: "Unknown".to_string(), } - }) - }) - .collect(); + } + }; + + let stats = match manager.get_container_stats(&container.id).await { + Ok(stats) => Some(stats), + Err(err) => { + log::warn!( + "Failed to load runtime stats for container {}: {}", + container.id, + err + ); + None + } + }; + + response.push(to_container_response(container, &security, stats.as_ref())); + } HttpResponse::Ok().json(response) } @@ -81,11 +78,43 @@ pub async fn get_containers(pool: web::Data) -> impl Responder { } } +fn to_container_response( + container: &ContainerInfo, + security: &crate::docker::containers::ContainerSecurityStatus, + stats: Option<&ContainerStats>, +) -> ContainerResponse { + ContainerResponse { + id: container.id.clone(), + name: container.name.clone(), + image: container.image.clone(), + status: container.status.clone(), + security_status: ApiContainerSecurityStatus { + state: security.security_state.clone(), + threats: security.threats, + vulnerabilities: None, + last_scan: None, + }, + risk_score: security.risk_score, + network_activity: NetworkActivity { + inbound_connections: None, + outbound_connections: None, + blocked_connections: None, + received_bytes: stats.map(|stats| stats.network_rx), + transmitted_bytes: stats.map(|stats| stats.network_tx), + received_packets: stats.map(|stats| stats.network_rx_packets), + transmitted_packets: stats.map(|stats| stats.network_tx_packets), + suspicious_activity: security.threats > 0 || security.security_state == "Quarantined", + }, + created_at: container.created.clone(), + } +} + /// Quarantine a container /// /// POST /api/containers/:id/quarantine pub async fn quarantine_container( pool: web::Data, + hub: web::Data, path: web::Path, body: web::Json, ) -> impl Responder { @@ -103,10 +132,22 @@ pub async fn quarantine_container( }; match manager.quarantine_container(&container_id, &reason).await { - Ok(()) => HttpResponse::Ok().json(serde_json::json!({ - "success": true, - "message": format!("Container {} quarantined", container_id) - })), + Ok(()) => { + broadcast_event( + hub.get_ref(), + "container:quarantined", + serde_json::json!({ + "container_id": container_id, + "reason": reason + }), + ) + .await; + let _ = broadcast_stats(hub.get_ref(), &pool).await; + HttpResponse::Ok().json(serde_json::json!({ + "success": true, + "message": format!("Container {} quarantined", container_id) + })) + } Err(e) => { log::error!("Failed to quarantine container: {}", e); HttpResponse::InternalServerError().json(serde_json::json!({ @@ -160,8 +201,66 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { mod tests { use super::*; use crate::database::{create_pool, init_database}; + use crate::docker::client::ContainerStats; use actix_web::{test, App}; + fn sample_container() -> ContainerInfo { + ContainerInfo { + id: "container-1".into(), + name: "web".into(), + image: "nginx:latest".into(), + status: "Running".into(), + created: "2026-01-01T00:00:00Z".into(), + network_settings: std::collections::HashMap::new(), + } + } + + fn sample_security() -> crate::docker::containers::ContainerSecurityStatus { + crate::docker::containers::ContainerSecurityStatus { + container_id: "container-1".into(), + risk_score: 42, + threats: 1, + security_state: "AtRisk".into(), + } + } + + #[actix_rt::test] + async fn test_to_container_response_uses_real_stats() { + let response = to_container_response( + &sample_container(), + &sample_security(), + Some(&ContainerStats { + cpu_percent: 0.0, + memory_usage: 0, + memory_limit: 0, + network_rx: 1024, + network_tx: 2048, + network_rx_packets: 5, + network_tx_packets: 9, + }), + ); + + assert_eq!(response.security_status.vulnerabilities, None); + assert_eq!(response.security_status.last_scan, None); + assert_eq!(response.network_activity.received_bytes, Some(1024)); + assert_eq!(response.network_activity.transmitted_bytes, Some(2048)); + assert_eq!(response.network_activity.received_packets, Some(5)); + assert_eq!(response.network_activity.transmitted_packets, Some(9)); + assert_eq!(response.network_activity.inbound_connections, None); + assert_eq!(response.network_activity.outbound_connections, None); + } + + #[actix_rt::test] + async fn test_to_container_response_leaves_missing_stats_unavailable() { + let response = to_container_response(&sample_container(), &sample_security(), None); + + assert_eq!(response.network_activity.received_bytes, None); + assert_eq!(response.network_activity.transmitted_bytes, None); + assert_eq!(response.network_activity.received_packets, None); + assert_eq!(response.network_activity.transmitted_packets, None); + assert_eq!(response.network_activity.blocked_connections, None); + } + #[actix_rt::test] async fn test_get_containers() { let pool = create_pool(":memory:").unwrap(); @@ -174,6 +273,9 @@ mod tests { let req = test::TestRequest::get().uri("/api/containers").to_request(); let resp = test::call_service(&app, req).await; - assert!(resp.status().is_success()); + assert!( + resp.status().is_success() + || resp.status() == actix_web::http::StatusCode::SERVICE_UNAVAILABLE + ); } } diff --git a/src/api/security.rs b/src/api/security.rs index 1c97f15..44a3945 100644 --- a/src/api/security.rs +++ b/src/api/security.rs @@ -1,14 +1,22 @@ //! Security API endpoints +use crate::database::{get_security_status_snapshot, DbPool, SecurityStatusSnapshot}; use crate::models::api::security::SecurityStatusResponse; use actix_web::{web, HttpResponse, Responder}; /// Get overall security status /// /// GET /api/security/status -pub async fn get_security_status() -> impl Responder { - let status = SecurityStatusResponse::new(); - HttpResponse::Ok().json(status) +pub async fn get_security_status(pool: web::Data) -> impl Responder { + match build_security_status(pool.get_ref()) { + Ok(status) => HttpResponse::Ok().json(status), + Err(err) => { + log::error!("Failed to build security status: {}", err); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to build security status" + })) + } + } } /// Configure security routes @@ -16,14 +24,40 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { cfg.service(web::scope("/api/security").route("/status", web::get().to(get_security_status))); } +pub(crate) fn build_security_status(pool: &DbPool) -> anyhow::Result { + let snapshot = get_security_status_snapshot(pool)?; + Ok(SecurityStatusResponse::from_state( + calculate_overall_score(&snapshot), + snapshot.active_threats, + snapshot.quarantined_containers, + snapshot.alerts_new, + snapshot.alerts_acknowledged, + )) +} + +fn calculate_overall_score(snapshot: &SecurityStatusSnapshot) -> u32 { + let penalty = snapshot.severity_breakdown.weighted_penalty() + + snapshot.quarantined_containers.saturating_mul(25) + + snapshot.alerts_acknowledged.saturating_mul(2); + 100u32.saturating_sub(penalty.min(100)) +} + #[cfg(test)] mod tests { use super::*; + use crate::alerting::alert::{AlertSeverity, AlertStatus, AlertType}; + use crate::database::models::{Alert, AlertMetadata}; + use crate::database::{create_alert, create_pool, init_database}; use actix_web::{test, App}; + use chrono::Utc; #[actix_rt::test] async fn test_get_security_status() { - let app = test::init_service(App::new().configure(configure_routes)).await; + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let pool_data = web::Data::new(pool); + let app = + test::init_service(App::new().app_data(pool_data).configure(configure_routes)).await; let req = test::TestRequest::get() .uri("/api/security/status") @@ -32,4 +66,45 @@ mod tests { assert!(resp.status().is_success()); } + + #[actix_rt::test] + async fn test_build_security_status_uses_alert_data() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + create_alert( + &pool, + Alert { + id: "a1".to_string(), + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::High, + message: "test".to_string(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: None, + }, + ) + .await + .unwrap(); + create_alert( + &pool, + Alert { + id: "a2".to_string(), + alert_type: AlertType::QuarantineApplied, + severity: AlertSeverity::High, + message: "container quarantined".to_string(), + status: AlertStatus::Acknowledged, + timestamp: Utc::now().to_rfc3339(), + metadata: Some(AlertMetadata::default().with_container_id("abc123")), + }, + ) + .await + .unwrap(); + + let status = build_security_status(&pool).unwrap(); + assert_eq!(status.active_threats, 1); + assert_eq!(status.quarantined_containers, 1); + assert_eq!(status.alerts_new, 1); + assert_eq!(status.alerts_acknowledged, 1); + assert!(status.overall_score < 100); + } } diff --git a/src/api/threats.rs b/src/api/threats.rs index a9a5886..2200638 100644 --- a/src/api/threats.rs +++ b/src/api/threats.rs @@ -1,5 +1,8 @@ //! Threats API endpoints +use crate::alerting::alert::{AlertSeverity, AlertStatus, AlertType}; +use crate::database::models::{Alert, AlertMetadata}; +use crate::database::{list_alerts as db_list_alerts, AlertFilter, DbPool}; use crate::models::api::threats::{ThreatResponse, ThreatStatisticsResponse}; use actix_web::{web, HttpResponse, Responder}; use std::collections::HashMap; @@ -7,45 +10,68 @@ use std::collections::HashMap; /// Get all threats /// /// GET /api/threats -pub async fn get_threats() -> impl Responder { - // TODO: Fetch from database when implemented - let threats = vec![ThreatResponse { - id: "threat-1".to_string(), - r#type: "CryptoMiner".to_string(), - severity: "High".to_string(), - score: 85, - source: "container-1".to_string(), - timestamp: chrono::Utc::now().to_rfc3339(), - status: "New".to_string(), - }]; - - HttpResponse::Ok().json(threats) +pub async fn get_threats(pool: web::Data) -> impl Responder { + match db_list_alerts(&pool, AlertFilter::default()).await { + Ok(alerts) => { + let threats = alerts + .into_iter() + .filter(|alert| is_threat_alert_type(alert.alert_type)) + .map(|alert| ThreatResponse { + id: alert.id, + r#type: alert.alert_type.to_string(), + severity: alert.severity.to_string(), + score: severity_to_score(alert.severity), + source: extract_source(alert.metadata.as_ref()), + timestamp: alert.timestamp, + status: alert.status.to_string(), + }) + .collect::>(); + + HttpResponse::Ok().json(threats) + } + Err(e) => { + log::error!("Failed to load threats: {}", e); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to load threats" + })) + } + } } /// Get threat statistics /// /// GET /api/threats/statistics -pub async fn get_threat_statistics() -> impl Responder { - let mut by_severity = HashMap::new(); - by_severity.insert("Info".to_string(), 1); - by_severity.insert("Low".to_string(), 2); - by_severity.insert("Medium".to_string(), 3); - by_severity.insert("High".to_string(), 3); - by_severity.insert("Critical".to_string(), 1); - - let mut by_type = HashMap::new(); - by_type.insert("CryptoMiner".to_string(), 3); - by_type.insert("ContainerEscape".to_string(), 2); - by_type.insert("NetworkScanner".to_string(), 5); - - let stats = ThreatStatisticsResponse { - total_threats: 10, - by_severity, - by_type, - trend: "stable".to_string(), - }; - - HttpResponse::Ok().json(stats) +pub async fn get_threat_statistics(pool: web::Data) -> impl Responder { + match db_list_alerts(&pool, AlertFilter::default()).await { + Ok(alerts) => { + let threats = alerts + .into_iter() + .filter(|alert| is_threat_alert_type(alert.alert_type)) + .collect::>(); + let mut by_severity = HashMap::new(); + let mut by_type = HashMap::new(); + + for alert in &threats { + *by_severity.entry(alert.severity.to_string()).or_insert(0) += 1; + *by_type.entry(alert.alert_type.to_string()).or_insert(0) += 1; + } + + let stats = ThreatStatisticsResponse { + total_threats: threats.len() as u32, + by_severity, + by_type, + trend: calculate_trend(&threats), + }; + + HttpResponse::Ok().json(stats) + } + Err(e) => { + log::error!("Failed to load threat statistics: {}", e); + HttpResponse::InternalServerError().json(serde_json::json!({ + "error": "Failed to load threat statistics" + })) + } + } } /// Configure threat routes @@ -64,7 +90,14 @@ mod tests { #[actix_rt::test] async fn test_get_threats() { - let app = test::init_service(App::new().configure(configure_routes)).await; + let pool = crate::database::create_pool(":memory:").unwrap(); + crate::database::init_database(&pool).unwrap(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(configure_routes), + ) + .await; let req = test::TestRequest::get().uri("/api/threats").to_request(); let resp = test::call_service(&app, req).await; @@ -74,7 +107,14 @@ mod tests { #[actix_rt::test] async fn test_get_threat_statistics() { - let app = test::init_service(App::new().configure(configure_routes)).await; + let pool = crate::database::create_pool(":memory:").unwrap(); + crate::database::init_database(&pool).unwrap(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(configure_routes), + ) + .await; let req = test::TestRequest::get() .uri("/api/threats/statistics") @@ -84,3 +124,55 @@ mod tests { assert!(resp.status().is_success()); } } + +fn severity_to_score(severity: AlertSeverity) -> u32 { + match severity { + AlertSeverity::Critical => 95, + AlertSeverity::High => 85, + AlertSeverity::Medium => 60, + AlertSeverity::Low => 30, + _ => 10, + } +} + +fn extract_source(metadata: Option<&AlertMetadata>) -> String { + metadata + .and_then(|value| { + value + .source + .as_ref() + .or(value.container_id.as_ref()) + .or(value.reason.as_ref()) + .cloned() + }) + .unwrap_or_else(|| "unknown".to_string()) +} + +fn is_threat_alert_type(alert_type: AlertType) -> bool { + matches!( + alert_type, + AlertType::ThreatDetected + | AlertType::AnomalyDetected + | AlertType::RuleViolation + | AlertType::ThresholdExceeded + ) +} + +fn calculate_trend(alerts: &[Alert]) -> String { + let unresolved = alerts + .iter() + .filter(|alert| alert.status != AlertStatus::Resolved) + .count(); + let resolved = alerts + .iter() + .filter(|alert| alert.status == AlertStatus::Resolved) + .count(); + + if unresolved > resolved { + "increasing".to_string() + } else if resolved > unresolved { + "decreasing".to_string() + } else { + "stable".to_string() + } +} diff --git a/src/api/websocket.rs b/src/api/websocket.rs index 106fe05..b37e622 100644 --- a/src/api/websocket.rs +++ b/src/api/websocket.rs @@ -1,49 +1,312 @@ -//! WebSocket handler for real-time updates -//! -//! Note: Full WebSocket implementation requires additional setup. -//! This is a placeholder that returns 426 Upgrade Required. -//! -//! TODO: Implement proper WebSocket support with: -//! - actix-web-actors with proper Actor trait implementation -//! - Or use tokio-tungstenite for lower-level WebSocket handling - -use actix_web::{http::StatusCode, web, Error, HttpRequest, HttpResponse}; -use log::info; - -/// WebSocket endpoint handler (placeholder) -/// -/// Returns 426 Upgrade Required to indicate WebSocket is not yet fully implemented -pub async fn websocket_handler(req: HttpRequest) -> Result { - info!( - "WebSocket connection attempt from: {:?}", - req.connection_info().peer_addr() - ); - - // Return upgrade required response - // Client should retry with proper WebSocket upgrade headers - Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) - .insert_header(("Upgrade", "websocket")) - .body("WebSocket upgrade not yet implemented - see documentation")) -} - -/// Configure WebSocket route +//! WebSocket handler for real-time updates. + +use std::collections::HashMap; +use std::time::Duration; + +use actix::prelude::*; +use actix_web::{web, Error, HttpRequest, HttpResponse}; +use actix_web_actors::ws; +use serde::Serialize; + +use crate::api::security::build_security_status; +use crate::database::DbPool; + +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); +const CLIENT_TIMEOUT: Duration = Duration::from_secs(15); + +#[derive(Debug, Clone, Serialize)] +pub struct WsEnvelope { + pub r#type: String, + pub payload: T, +} + +#[derive(Message)] +#[rtype(result = "()")] +pub struct WsMessage(pub String); + +#[derive(Message)] +#[rtype(usize)] +struct Connect { + addr: Recipient, +} + +#[derive(Message)] +#[rtype(result = "()")] +struct Disconnect { + id: usize, +} + +#[derive(Message)] +#[rtype(result = "()")] +pub struct BroadcastMessage { + pub event_type: String, + pub payload: serde_json::Value, +} + +pub struct WebSocketHub { + sessions: HashMap>, + next_id: usize, +} + +impl WebSocketHub { + pub fn new() -> Self { + Self { + sessions: HashMap::new(), + next_id: 1, + } + } + + fn broadcast_json(&self, message: &str) { + for recipient in self.sessions.values() { + recipient.do_send(WsMessage(message.to_string())); + } + } +} + +impl Default for WebSocketHub { + fn default() -> Self { + Self::new() + } +} + +impl Actor for WebSocketHub { + type Context = Context; +} + +impl Handler for WebSocketHub { + type Result = usize; + + fn handle(&mut self, msg: Connect, _: &mut Self::Context) -> Self::Result { + let id = self.next_id; + self.next_id += 1; + self.sessions.insert(id, msg.addr); + id + } +} + +impl Handler for WebSocketHub { + type Result = (); + + fn handle(&mut self, msg: Disconnect, _: &mut Self::Context) { + self.sessions.remove(&msg.id); + } +} + +impl Handler for WebSocketHub { + type Result = (); + + fn handle(&mut self, msg: BroadcastMessage, _: &mut Self::Context) { + let envelope = WsEnvelope { + r#type: msg.event_type, + payload: msg.payload, + }; + if let Ok(json) = serde_json::to_string(&envelope) { + self.broadcast_json(&json); + } + } +} + +pub type WebSocketHubHandle = Addr; + +pub struct WebSocketSession { + id: usize, + heartbeat: std::time::Instant, + hub: WebSocketHubHandle, + pool: DbPool, +} + +impl WebSocketSession { + fn new(hub: WebSocketHubHandle, pool: DbPool) -> Self { + Self { + id: 0, + heartbeat: std::time::Instant::now(), + hub, + pool, + } + } + + fn start_heartbeat(&self, ctx: &mut ws::WebsocketContext) { + ctx.run_interval(HEARTBEAT_INTERVAL, |actor, ctx| { + if std::time::Instant::now().duration_since(actor.heartbeat) > CLIENT_TIMEOUT { + actor.hub.do_send(Disconnect { id: actor.id }); + ctx.stop(); + return; + } + + ctx.ping(b""); + }); + } + + fn send_initial_snapshot(&self, ctx: &mut ws::WebsocketContext) { + if let Ok(message) = build_stats_message(&self.pool) { + ctx.text(message); + } + } +} + +impl Actor for WebSocketSession { + type Context = ws::WebsocketContext; + + fn started(&mut self, ctx: &mut Self::Context) { + self.start_heartbeat(ctx); + + let address = ctx.address(); + self.hub + .send(Connect { + addr: address.recipient(), + }) + .into_actor(self) + .map(|result, actor, ctx| { + if let Ok(id) = result { + actor.id = id; + actor.send_initial_snapshot(ctx); + } else { + ctx.stop(); + } + }) + .wait(ctx); + } + + fn stopped(&mut self, _: &mut Self::Context) { + self.hub.do_send(Disconnect { id: self.id }); + } +} + +impl Handler for WebSocketSession { + type Result = (); + + fn handle(&mut self, msg: WsMessage, ctx: &mut Self::Context) { + ctx.text(msg.0); + } +} + +impl StreamHandler> for WebSocketSession { + fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { + match msg { + Ok(ws::Message::Ping(payload)) => { + self.heartbeat = std::time::Instant::now(); + ctx.pong(&payload); + } + Ok(ws::Message::Pong(_)) => { + self.heartbeat = std::time::Instant::now(); + } + Ok(ws::Message::Text(_)) => {} + Ok(ws::Message::Binary(_)) => {} + Ok(ws::Message::Close(reason)) => { + ctx.close(reason); + ctx.stop(); + } + Ok(ws::Message::Continuation(_)) => {} + Ok(ws::Message::Nop) => {} + Err(_) => ctx.stop(), + } + } +} + +pub async fn websocket_handler( + req: HttpRequest, + stream: web::Payload, + hub: web::Data, + pool: web::Data, +) -> Result { + ws::start( + WebSocketSession::new(hub.get_ref().clone(), pool.get_ref().clone()), + &req, + stream, + ) +} + pub fn configure_routes(cfg: &mut web::ServiceConfig) { cfg.route("/ws", web::get().to(websocket_handler)); } +pub async fn broadcast_event( + hub: &WebSocketHubHandle, + event_type: impl Into, + payload: serde_json::Value, +) { + hub.do_send(BroadcastMessage { + event_type: event_type.into(), + payload, + }); +} + +pub async fn broadcast_stats(hub: &WebSocketHubHandle, pool: &DbPool) -> anyhow::Result<()> { + let message = build_stats_broadcast(pool).await?; + hub.do_send(message); + Ok(()) +} + +pub fn spawn_stats_broadcaster(hub: WebSocketHubHandle, pool: DbPool) { + actix_rt::spawn(async move { + let mut interval = actix_rt::time::interval(Duration::from_secs(10)); + loop { + interval.tick().await; + if let Err(err) = broadcast_stats(&hub, &pool).await { + log::debug!("Failed to broadcast websocket stats: {}", err); + } + } + }); +} + +async fn build_stats_broadcast(pool: &DbPool) -> anyhow::Result { + let status = build_security_status(pool)?; + Ok(BroadcastMessage { + event_type: "stats:updated".to_string(), + payload: serde_json::to_value(status)?, + }) +} + +fn build_stats_message(pool: &DbPool) -> anyhow::Result { + Ok(serde_json::to_string(&WsEnvelope { + r#type: "stats:updated".to_string(), + payload: build_security_status(pool)?, + })?) +} + #[cfg(test)] mod tests { use super::*; - use actix_web::{test, App}; + use crate::alerting::alert::{AlertSeverity, AlertStatus, AlertType}; + use crate::database::models::Alert; + use crate::database::{create_alert, create_pool, init_database}; + use chrono::Utc; #[actix_rt::test] - async fn test_websocket_endpoint_exists() { - let app = test::init_service(App::new().configure(configure_routes)).await; + async fn test_build_stats_message() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + create_alert( + &pool, + Alert { + id: "a1".to_string(), + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::High, + message: "test".to_string(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: None, + }, + ) + .await + .unwrap(); + + let message = build_stats_message(&pool).unwrap(); + assert!(message.contains("\"type\":\"stats:updated\"")); + assert!(message.contains("\"alerts_new\":1")); + } - let req = test::TestRequest::get().uri("/ws").to_request(); - let resp = test::call_service(&app, req).await; + #[actix_rt::test] + async fn test_broadcast_message_serialization() { + let envelope = WsEnvelope { + r#type: "alert:created".to_string(), + payload: serde_json::json!({ "id": "alert-1" }), + }; - // Should return switching protocols status - assert_eq!(resp.status(), 101); // 101 Switching Protocols + let json = serde_json::to_string(&envelope).unwrap(); + assert_eq!( + json, + "{\"type\":\"alert:created\",\"payload\":{\"id\":\"alert-1\"}}" + ); } } diff --git a/src/baselines/learning.rs b/src/baselines/learning.rs index 027efd6..83f885f 100644 --- a/src/baselines/learning.rs +++ b/src/baselines/learning.rs @@ -1,15 +1,196 @@ //! Baseline learning use anyhow::Result; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::events::security::SecurityEvent; +use crate::ml::features::SecurityFeatures; + +const FEATURE_NAMES: [&str; 4] = [ + "syscall_rate", + "network_rate", + "unique_processes", + "privileged_calls", +]; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FeatureSummary { + pub syscall_rate: f64, + pub network_rate: f64, + pub unique_processes: f64, + pub privileged_calls: f64, +} + +impl FeatureSummary { + pub fn from_vector(vector: [f64; 4]) -> Self { + Self { + syscall_rate: vector[0], + network_rate: vector[1], + unique_processes: vector[2], + privileged_calls: vector[3], + } + } + + pub fn as_vector(&self) -> [f64; 4] { + [ + self.syscall_rate, + self.network_rate, + self.unique_processes, + self.privileged_calls, + ] + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FeatureBaseline { + pub sample_count: u64, + pub mean: FeatureSummary, + pub stddev: FeatureSummary, + pub last_updated: DateTime, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BaselineDrift { + pub score: f64, + pub deviating_features: Vec, +} /// Baseline learner pub struct BaselineLearner { - // TODO: Implement in TASK-015 + baselines: HashMap, + deviation_threshold: f64, +} + +#[derive(Debug, Clone)] +struct RunningFeatureStats { + sample_count: u64, + mean: [f64; 4], + m2: [f64; 4], + last_updated: DateTime, +} + +impl Default for RunningFeatureStats { + fn default() -> Self { + Self { + sample_count: 0, + mean: [0.0; 4], + m2: [0.0; 4], + last_updated: Utc::now(), + } + } +} + +impl RunningFeatureStats { + fn observe(&mut self, values: [f64; 4]) { + self.sample_count += 1; + let count = self.sample_count as f64; + + for (idx, value) in values.iter().enumerate() { + let delta = value - self.mean[idx]; + self.mean[idx] += delta / count; + let delta2 = value - self.mean[idx]; + self.m2[idx] += delta * delta2; + } + + self.last_updated = Utc::now(); + } + + fn stddev(&self) -> [f64; 4] { + if self.sample_count < 2 { + return [0.0; 4]; + } + + let denominator = (self.sample_count - 1) as f64; + let mut result = [0.0; 4]; + + for (idx, value) in result.iter_mut().enumerate() { + *value = (self.m2[idx] / denominator).sqrt(); + } + + result + } + + fn to_baseline(&self) -> FeatureBaseline { + FeatureBaseline { + sample_count: self.sample_count, + mean: FeatureSummary::from_vector(self.mean), + stddev: FeatureSummary::from_vector(self.stddev()), + last_updated: self.last_updated, + } + } } impl BaselineLearner { pub fn new() -> Result { - Ok(Self {}) + Ok(Self { + baselines: HashMap::new(), + deviation_threshold: 3.0, + }) + } + + pub fn with_deviation_threshold(mut self, threshold: f64) -> Self { + self.deviation_threshold = threshold.max(0.5); + self + } + + pub fn observe(&mut self, scope: impl Into, features: &SecurityFeatures) { + let entry = self.baselines.entry(scope.into()).or_default(); + entry.observe(features.as_vector()); + } + + pub fn observe_events( + &mut self, + scope: impl Into, + events: &[SecurityEvent], + window_seconds: f64, + ) -> SecurityFeatures { + let features = SecurityFeatures::from_events(events, window_seconds); + self.observe(scope, &features); + features + } + + pub fn baseline(&self, scope: &str) -> Option { + self.baselines + .get(scope) + .map(RunningFeatureStats::to_baseline) + } + + pub fn scopes(&self) -> impl Iterator { + self.baselines.keys().map(String::as_str) + } + + pub fn detect_drift(&self, scope: &str, features: &SecurityFeatures) -> Option { + let baseline = self.baselines.get(scope)?; + if baseline.sample_count < 2 { + return None; + } + + let values = features.as_vector(); + let means = baseline.mean; + let stddevs = baseline.stddev(); + let mut total_deviation = 0.0; + let mut deviating_features = Vec::new(); + + for idx in 0..FEATURE_NAMES.len() { + let deviation = if stddevs[idx] > f64::EPSILON { + (values[idx] - means[idx]).abs() / stddevs[idx] + } else { + let scale = means[idx].abs().max(1.0); + (values[idx] - means[idx]).abs() / scale + }; + + total_deviation += deviation; + if deviation >= self.deviation_threshold { + deviating_features.push(FEATURE_NAMES[idx].to_string()); + } + } + + Some(BaselineDrift { + score: total_deviation / FEATURE_NAMES.len() as f64, + deviating_features, + }) } } @@ -18,3 +199,69 @@ impl Default for BaselineLearner { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::events::security::SecurityEvent; + use crate::events::syscall::{SyscallEvent, SyscallType}; + use chrono::Utc; + + fn feature(syscall_rate: f64, network_rate: f64, unique_processes: u32) -> SecurityFeatures { + SecurityFeatures { + syscall_rate, + network_rate, + unique_processes, + privileged_calls: 0, + } + } + + #[test] + fn test_baseline_collection() { + let mut learner = BaselineLearner::new().unwrap(); + learner.observe("global", &feature(10.0, 2.0, 3)); + learner.observe("global", &feature(12.0, 2.5, 4)); + + let baseline = learner.baseline("global").unwrap(); + assert_eq!(baseline.sample_count, 2); + assert_eq!(baseline.mean.syscall_rate, 11.0); + assert_eq!(baseline.mean.unique_processes, 3.5); + } + + #[test] + fn test_drift_detection_flags_outlier() { + let mut learner = BaselineLearner::new() + .unwrap() + .with_deviation_threshold(2.0); + learner.observe("global", &feature(10.0, 2.0, 3)); + learner.observe("global", &feature(11.0, 2.1, 3)); + learner.observe("global", &feature(9.5, 1.9, 2)); + + let drift = learner + .detect_drift("global", &feature(25.0, 9.0, 12)) + .unwrap(); + + assert!(drift.score > 2.0); + assert!(drift + .deviating_features + .contains(&"syscall_rate".to_string())); + assert!(drift + .deviating_features + .contains(&"network_rate".to_string())); + } + + #[test] + fn test_observe_events_extracts_features_before_learning() { + let mut learner = BaselineLearner::new().unwrap(); + let events = vec![ + SecurityEvent::Syscall(SyscallEvent::new(1, 0, SyscallType::Execve, Utc::now())), + SecurityEvent::Syscall(SyscallEvent::new(1, 0, SyscallType::Connect, Utc::now())), + ]; + + let features = learner.observe_events("container:abc", &events, 1.0); + let baseline = learner.baseline("container:abc").unwrap(); + + assert_eq!(features.syscall_rate, 2.0); + assert_eq!(baseline.sample_count, 1); + } +} diff --git a/src/cli.rs b/src/cli.rs index 9ff6579..c06de7c 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -56,6 +56,30 @@ pub enum Command { /// Slack webhook URL for alert notifications #[arg(long)] slack_webhook: Option, + + /// Generic webhook URL for alert notifications + #[arg(long)] + webhook_url: Option, + + /// SMTP host for email alert notifications + #[arg(long)] + smtp_host: Option, + + /// SMTP port for email alert notifications + #[arg(long)] + smtp_port: Option, + + /// SMTP username / sender address for email alert notifications + #[arg(long)] + smtp_user: Option, + + /// SMTP password for email alert notifications + #[arg(long)] + smtp_password: Option, + + /// Comma-separated email recipients for alert notifications + #[arg(long)] + email_recipients: Option, }, } @@ -93,6 +117,12 @@ mod tests { ai_model, ai_api_url, slack_webhook, + webhook_url, + smtp_host, + smtp_port, + smtp_user, + smtp_password, + email_recipients, }) => { assert!(!once); assert!(!consume); @@ -103,6 +133,12 @@ mod tests { assert!(ai_model.is_none()); assert!(ai_api_url.is_none()); assert!(slack_webhook.is_none()); + assert!(webhook_url.is_none()); + assert!(smtp_host.is_none()); + assert!(smtp_port.is_none()); + assert!(smtp_user.is_none()); + assert!(smtp_password.is_none()); + assert!(email_recipients.is_none()); } _ => panic!("Expected Sniff command"), } @@ -147,6 +183,18 @@ mod tests { "https://api.openai.com/v1", "--slack-webhook", "https://hooks.slack.com/services/T/B/xxx", + "--webhook-url", + "https://example.com/hooks/stackdog", + "--smtp-host", + "smtp.example.com", + "--smtp-port", + "587", + "--smtp-user", + "alerts@example.com", + "--smtp-password", + "secret", + "--email-recipients", + "soc@example.com,oncall@example.com", ]); match cli.command { Some(Command::Sniff { @@ -159,6 +207,12 @@ mod tests { ai_model, ai_api_url, slack_webhook, + webhook_url, + smtp_host, + smtp_port, + smtp_user, + smtp_password, + email_recipients, }) => { assert!(once); assert!(consume); @@ -172,6 +226,15 @@ mod tests { slack_webhook.unwrap(), "https://hooks.slack.com/services/T/B/xxx" ); + assert_eq!(webhook_url.unwrap(), "https://example.com/hooks/stackdog"); + assert_eq!(smtp_host.unwrap(), "smtp.example.com"); + assert_eq!(smtp_port.unwrap(), 587); + assert_eq!(smtp_user.unwrap(), "alerts@example.com"); + assert_eq!(smtp_password.unwrap(), "secret"); + assert_eq!( + email_recipients.unwrap(), + "soc@example.com,oncall@example.com" + ); } _ => panic!("Expected Sniff command"), } diff --git a/src/collectors/docker_events.rs b/src/collectors/docker_events.rs index 706d153..c360869 100644 --- a/src/collectors/docker_events.rs +++ b/src/collectors/docker_events.rs @@ -2,17 +2,54 @@ //! //! Streams events from Docker daemon using Bollard -use anyhow::Result; +use std::collections::HashMap; + +use anyhow::{Context, Result}; +use bollard::system::EventsOptions; +use bollard::{models::EventMessageTypeEnum, Docker}; +use chrono::{TimeZone, Utc}; +use futures_util::stream::StreamExt; + +use crate::events::security::{ContainerEvent, ContainerEventType}; /// Docker events collector pub struct DockerEventsCollector { - // TODO: Implement in TASK-007 + client: Docker, } impl DockerEventsCollector { pub fn new() -> Result { - // TODO: Implement - Ok(Self {}) + let client = + Docker::connect_with_local_defaults().context("Failed to connect to Docker daemon")?; + Ok(Self { client }) + } + + pub async fn read_events(&self, limit: usize) -> Result> { + let mut filters = HashMap::new(); + filters.insert("type".to_string(), vec!["container".to_string()]); + let mut stream = self.client.events(Some(EventsOptions:: { + since: None, + until: None, + filters, + })); + + let mut events = Vec::new(); + while events.len() < limit { + let Some(event) = stream.next().await else { + break; + }; + + let event = event.context("Failed to read Docker event")?; + if !matches!(event.typ, Some(EventMessageTypeEnum::CONTAINER)) { + continue; + } + + if let Some(mapped) = map_container_event(event) { + events.push(mapped); + } + } + + Ok(events) } } @@ -21,3 +58,90 @@ impl Default for DockerEventsCollector { Self::new().unwrap() } } + +fn map_container_event(event: bollard::models::EventMessage) -> Option { + let actor = event.actor?; + let container_id = actor.id?; + let action = event.action?; + let event_type = match action.as_str() { + "start" => ContainerEventType::Start, + "stop" | "die" | "kill" => ContainerEventType::Stop, + "create" => ContainerEventType::Create, + "destroy" | "remove" => ContainerEventType::Destroy, + "pause" => ContainerEventType::Pause, + "unpause" => ContainerEventType::Unpause, + _ => return None, + }; + + let timestamp = event + .time + .and_then(|secs| Utc.timestamp_opt(secs, 0).single()) + .unwrap_or_else(Utc::now); + let details = actor.attributes.and_then(|attributes| { + if attributes.is_empty() { + None + } else { + Some( + attributes + .into_iter() + .map(|(key, value)| format!("{}={}", key, value)) + .collect::>() + .join(","), + ) + } + }); + + Some(ContainerEvent { + container_id, + event_type, + timestamp, + details, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use bollard::models::{EventActor, EventMessage}; + + #[test] + fn test_map_container_start_event() { + let event = EventMessage { + typ: Some(EventMessageTypeEnum::CONTAINER), + action: Some("start".to_string()), + actor: Some(EventActor { + id: Some("abc123".to_string()), + attributes: Some(HashMap::from([( + "name".to_string(), + "wordpress".to_string(), + )])), + }), + time: Some(1_700_000_000), + ..Default::default() + }; + + let mapped = map_container_event(event).unwrap(); + assert_eq!(mapped.container_id, "abc123"); + assert_eq!(mapped.event_type, ContainerEventType::Start); + assert!(mapped + .details + .as_deref() + .unwrap_or_default() + .contains("name=wordpress")); + } + + #[test] + fn test_map_container_ignores_unknown_action() { + let event = EventMessage { + typ: Some(EventMessageTypeEnum::CONTAINER), + action: Some("rename".to_string()), + actor: Some(EventActor { + id: Some("abc123".to_string()), + attributes: None, + }), + ..Default::default() + }; + + assert!(map_container_event(event).is_none()); + } +} diff --git a/src/collectors/ebpf/enrichment.rs b/src/collectors/ebpf/enrichment.rs index 141c9da..5c4a5b3 100644 --- a/src/collectors/ebpf/enrichment.rs +++ b/src/collectors/ebpf/enrichment.rs @@ -2,7 +2,7 @@ //! //! Enriches syscall events with additional context (container ID, process info, etc.) -use crate::events::syscall::SyscallEvent; +use crate::events::syscall::{SyscallDetails, SyscallEvent}; use anyhow::Result; /// Event enricher @@ -40,6 +40,26 @@ impl EventEnricher { if event.comm.is_none() { event.comm = self.get_process_comm(event.pid); } + + if let Some(SyscallDetails::Exec { + filename, + args, + argc: _, + }) = event.details.as_mut() + { + if filename.is_none() { + *filename = self.get_process_exe(event.pid).or_else(|| { + self.get_process_cmdline(event.pid) + .and_then(|cmdline| cmdline.first().cloned()) + }); + } + + if args.is_empty() { + if let Some(cmdline) = self.get_process_cmdline(event.pid) { + *args = cmdline; + } + } + } } /// Get parent PID for a process @@ -103,6 +123,26 @@ impl EventEnricher { None } + /// Get full process command line arguments. + pub fn get_process_cmdline(&self, _pid: u32) -> Option> { + #[cfg(target_os = "linux")] + { + let cmdline_path = format!("/proc/{}/cmdline", _pid); + if let Ok(content) = std::fs::read(&cmdline_path) { + let args = content + .split(|byte| *byte == 0) + .filter(|segment| !segment.is_empty()) + .map(|segment| String::from_utf8_lossy(segment).to_string()) + .collect::>(); + if !args.is_empty() { + return Some(args); + } + } + } + + None + } + /// Get process working directory pub fn get_process_cwd(&self, _pid: u32) -> Option { #[cfg(target_os = "linux")] @@ -145,4 +185,13 @@ mod tests { let normalized = normalize_timestamp(now); assert_eq!(now, normalized); } + + #[test] + fn test_get_process_cmdline_current_process() { + let enricher = EventEnricher::new().unwrap(); + let _cmdline = enricher.get_process_cmdline(std::process::id()); + + #[cfg(target_os = "linux")] + assert!(_cmdline.is_some()); + } } diff --git a/src/collectors/ebpf/syscall_monitor.rs b/src/collectors/ebpf/syscall_monitor.rs index 46ddce2..a5d94b1 100644 --- a/src/collectors/ebpf/syscall_monitor.rs +++ b/src/collectors/ebpf/syscall_monitor.rs @@ -144,6 +144,11 @@ impl SyscallMonitor { let mut events = self.event_buffer.drain(); for event in &mut events { let _ = self._enricher.enrich(event); + if event.container_id.is_none() { + if let Some(detector) = &mut self._container_detector { + event.container_id = detector.detect_container(event.pid); + } + } } events diff --git a/src/collectors/ebpf/types.rs b/src/collectors/ebpf/types.rs index 3455d4a..9a034ff 100644 --- a/src/collectors/ebpf/types.rs +++ b/src/collectors/ebpf/types.rs @@ -2,6 +2,10 @@ //! //! Shared type definitions for eBPF programs and userspace +use std::net::{Ipv4Addr, Ipv6Addr}; + +use chrono::{TimeZone, Utc}; + /// eBPF syscall event structure /// /// This structure is shared between eBPF programs and userspace @@ -151,12 +155,16 @@ impl EbpfSyscallEvent { self.comm[..len].copy_from_slice(&comm[..len]); self.comm[len] = 0; } + + /// Convert this raw eBPF event to a userspace syscall event. + pub fn to_syscall_event(&self) -> crate::events::syscall::SyscallEvent { + to_syscall_event(self) + } } /// Convert eBPF event to userspace SyscallEvent pub fn to_syscall_event(ebpf_event: &EbpfSyscallEvent) -> crate::events::syscall::SyscallEvent { use crate::events::syscall::{SyscallEvent, SyscallType}; - use chrono::Utc; // Convert syscall_id to SyscallType let syscall_type = match ebpf_event.syscall_id { @@ -170,18 +178,111 @@ pub fn to_syscall_event(ebpf_event: &EbpfSyscallEvent) -> crate::events::syscall let mut event = SyscallEvent::new( ebpf_event.pid, ebpf_event.uid, - syscall_type, - Utc::now(), // Use current time (timestamp from eBPF may need conversion) + syscall_type.clone(), + timestamp_to_utc(ebpf_event.timestamp), ); event.comm = Some(ebpf_event.comm_str()); + event.details = match syscall_type { + SyscallType::Execve | SyscallType::Execveat => { + // SAFETY: We interpret the union according to the syscall type. + Some(exec_details(unsafe { &ebpf_event.data.execve })) + } + SyscallType::Connect => { + // SAFETY: We interpret the union according to the syscall type. + Some(connect_details(unsafe { &ebpf_event.data.connect })) + } + SyscallType::Openat => { + // SAFETY: We interpret the union according to the syscall type. + Some(openat_details(unsafe { &ebpf_event.data.openat })) + } + SyscallType::Ptrace => { + // SAFETY: We interpret the union according to the syscall type. + Some(ptrace_details(unsafe { &ebpf_event.data.ptrace })) + } + _ => None, + }; event } +fn timestamp_to_utc(timestamp_ns: u64) -> chrono::DateTime { + if timestamp_ns == 0 { + return chrono::Utc::now(); + } + + let seconds = (timestamp_ns / 1_000_000_000) as i64; + let nanos = (timestamp_ns % 1_000_000_000) as u32; + Utc.timestamp_opt(seconds, nanos) + .single() + .unwrap_or_else(Utc::now) +} + +fn exec_details(data: &ExecveData) -> crate::events::syscall::SyscallDetails { + crate::events::syscall::SyscallDetails::Exec { + filename: decode_string(&data.filename, Some(data.filename_len as usize)), + args: Vec::new(), + argc: data.argc, + } +} + +fn connect_details(data: &ConnectData) -> crate::events::syscall::SyscallDetails { + crate::events::syscall::SyscallDetails::Connect { + dst_addr: decode_ip(data), + dst_port: u16::from_be(data.dst_port), + family: data.family, + } +} + +fn openat_details(data: &OpenatData) -> crate::events::syscall::SyscallDetails { + crate::events::syscall::SyscallDetails::Openat { + path: decode_string(&data.path, Some(data.path_len as usize)), + flags: data.flags, + } +} + +fn ptrace_details(data: &PtraceData) -> crate::events::syscall::SyscallDetails { + crate::events::syscall::SyscallDetails::Ptrace { + target_pid: data.target_pid, + request: data.request, + addr: data.addr, + data: data.data, + } +} + +fn decode_string(bytes: &[u8], declared_len: Option) -> Option { + let first_nul = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len()); + let len = declared_len + .unwrap_or(first_nul) + .min(first_nul) + .min(bytes.len()); + if len == 0 { + return None; + } + + Some(String::from_utf8_lossy(&bytes[..len]).to_string()) +} + +fn decode_ip(data: &ConnectData) -> Option { + match data.family { + 2 => Some( + Ipv4Addr::new( + data.dst_ip[0], + data.dst_ip[1], + data.dst_ip[2], + data.dst_ip[3], + ) + .to_string(), + ), + 10 => Some(Ipv6Addr::from(data.dst_ip).to_string()), + _ => None, + } +} + #[cfg(test)] mod tests { use super::*; + use crate::events::syscall::SyscallDetails; #[test] fn test_event_creation() { @@ -219,4 +320,79 @@ mod tests { // Should be truncated to 15 chars + null assert_eq!(event.comm_str().len(), 15); } + + #[test] + fn test_to_syscall_event_preserves_exec_details() { + let mut event = EbpfSyscallEvent::new(1234, 1000, 59); + event.set_comm(b"php-fpm"); + event.timestamp = 1_700_000_000_123_456_789; + let mut filename = [0u8; 128]; + filename[..18].copy_from_slice(b"/usr/sbin/sendmail"); + event.data = EbpfEventData { + execve: ExecveData { + filename_len: 18, + filename, + argc: 2, + }, + }; + + let converted = event.to_syscall_event(); + assert_eq!(converted.comm.as_deref(), Some("php-fpm")); + match converted.details { + Some(SyscallDetails::Exec { filename, argc, .. }) => { + assert_eq!(filename.as_deref(), Some("/usr/sbin/sendmail")); + assert_eq!(argc, 2); + } + other => panic!("unexpected details: {:?}", other), + } + } + + #[test] + fn test_to_syscall_event_preserves_connect_details() { + let mut event = EbpfSyscallEvent::new(1234, 1000, 42); + event.data = EbpfEventData { + connect: ConnectData { + dst_ip: [192, 0, 2, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + dst_port: 587u16.to_be(), + family: 2, + }, + }; + + let converted = event.to_syscall_event(); + match converted.details { + Some(SyscallDetails::Connect { + dst_addr, + dst_port, + family, + }) => { + assert_eq!(dst_addr.as_deref(), Some("192.0.2.25")); + assert_eq!(dst_port, 587); + assert_eq!(family, 2); + } + other => panic!("unexpected details: {:?}", other), + } + } + + #[test] + fn test_to_syscall_event_preserves_openat_details() { + let mut event = EbpfSyscallEvent::new(1234, 1000, 257); + let mut path = [0u8; 256]; + path[..17].copy_from_slice(b"/etc/postfix/main"); + event.data = EbpfEventData { + openat: OpenatData { + path_len: 17, + path, + flags: 0o2, + }, + }; + + let converted = event.to_syscall_event(); + match converted.details { + Some(SyscallDetails::Openat { path, flags }) => { + assert_eq!(path.as_deref(), Some("/etc/postfix/main")); + assert_eq!(flags, 0o2); + } + other => panic!("unexpected details: {:?}", other), + } + } } diff --git a/src/collectors/network.rs b/src/collectors/network.rs index f956fd1..5cba009 100644 --- a/src/collectors/network.rs +++ b/src/collectors/network.rs @@ -3,20 +3,116 @@ //! Captures network traffic for security analysis use anyhow::Result; +use chrono::Utc; +use std::collections::HashMap; + +use crate::docker::{ContainerInfo, DockerClient}; +use crate::events::security::NetworkEvent; /// Network traffic collector pub struct NetworkCollector { - // TODO: Implement + client: DockerClient, + previous: HashMap, } impl NetworkCollector { - pub fn new() -> Result { - Ok(Self {}) + pub async fn new() -> Result { + Ok(Self { + client: DockerClient::new().await?, + previous: HashMap::new(), + }) + } + + pub async fn collect_outbound_events(&mut self) -> Result> { + let containers = self.client.list_containers(false).await?; + let mut events = Vec::new(); + + for container in containers { + if container.status != "Running" { + continue; + } + + let stats = self.client.get_container_stats(&container.id).await?; + let current = (stats.network_tx, stats.network_tx_packets); + let previous = self.previous.insert(container.id.clone(), current); + + if let Some((prev_tx_bytes, prev_tx_packets)) = previous { + let delta_bytes = current.0.saturating_sub(prev_tx_bytes); + let delta_packets = current.1.saturating_sub(prev_tx_packets); + if delta_bytes == 0 && delta_packets == 0 { + continue; + } + + if let Some(event) = build_network_event(&container, delta_bytes, delta_packets) { + events.push(event); + } + } + } + + Ok(events) } } impl Default for NetworkCollector { fn default() -> Self { - Self::new().unwrap() + panic!("Use NetworkCollector::new().await") + } +} + +fn build_network_event( + container: &ContainerInfo, + _delta_tx_bytes: u64, + _delta_tx_packets: u64, +) -> Option { + let src_ip = container + .network_settings + .values() + .find(|ip| !ip.is_empty()) + .cloned()?; + + Some(NetworkEvent { + src_ip, + dst_ip: "0.0.0.0".to_string(), + src_port: 0, + dst_port: 0, + protocol: "tcp".to_string(), + timestamp: Utc::now(), + container_id: Some(container.id.clone()), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_network_event_uses_container_ip() { + let container = ContainerInfo { + id: "abc123".to_string(), + name: "wordpress".to_string(), + image: "wordpress:latest".to_string(), + status: "Running".to_string(), + created: String::new(), + network_settings: HashMap::from([("bridge".to_string(), "172.17.0.5".to_string())]), + }; + + let event = build_network_event(&container, 64_000, 250).unwrap(); + assert_eq!(event.src_ip, "172.17.0.5"); + assert_eq!(event.container_id.as_deref(), Some("abc123")); + assert_eq!(event.dst_port, 0); + } + + #[test] + fn test_build_network_event_requires_ip() { + let container = ContainerInfo { + id: "abc123".to_string(), + name: "wordpress".to_string(), + image: "wordpress:latest".to_string(), + status: "Running".to_string(), + created: String::new(), + network_settings: HashMap::new(), + }; + + assert!(build_network_event(&container, 64_000, 250).is_none()); } } diff --git a/src/database/baselines.rs b/src/database/baselines.rs index 87ce277..a1f74d6 100644 --- a/src/database/baselines.rs +++ b/src/database/baselines.rs @@ -1,20 +1,162 @@ //! Baselines database operations +use crate::baselines::learning::{FeatureBaseline, FeatureSummary}; +use crate::database::connection::DbPool; use anyhow::Result; +use rusqlite::{params, OptionalExtension}; +use serde::{Deserialize, Serialize}; /// Baselines database manager pub struct BaselinesDb { - // TODO: Implement + pool: DbPool, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct StoredBaseline { + pub scope: String, + pub baseline: FeatureBaseline, } impl BaselinesDb { - pub fn new() -> Result { - Ok(Self {}) + pub fn new(pool: DbPool) -> Result { + Ok(Self { pool }) } + + pub fn save_baseline(&self, scope: &str, baseline: &FeatureBaseline) -> Result<()> { + let conn = self.pool.get()?; + conn.execute( + "INSERT INTO baselines (scope, sample_count, mean, stddev, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5) + ON CONFLICT(scope) DO UPDATE SET + sample_count = excluded.sample_count, + mean = excluded.mean, + stddev = excluded.stddev, + updated_at = excluded.updated_at", + params![ + scope, + baseline.sample_count as i64, + serde_json::to_string(&baseline.mean)?, + serde_json::to_string(&baseline.stddev)?, + baseline.last_updated.to_rfc3339(), + ], + )?; + + Ok(()) + } + + pub fn load_baseline(&self, scope: &str) -> Result> { + let conn = self.pool.get()?; + let row = conn + .query_row( + "SELECT sample_count, mean, stddev, updated_at FROM baselines WHERE scope = ?1", + params![scope], + |row| { + Ok(FeatureBaseline { + sample_count: row.get::<_, i64>(0)? as u64, + mean: serde_json::from_str::(&row.get::<_, String>(1)?) + .map_err(to_sql_error)?, + stddev: serde_json::from_str::(&row.get::<_, String>(2)?) + .map_err(to_sql_error)?, + last_updated: chrono::DateTime::parse_from_rfc3339( + &row.get::<_, String>(3)?, + ) + .map_err(to_sql_error)? + .with_timezone(&chrono::Utc), + }) + }, + ) + .optional()?; + + Ok(row) + } + + pub fn list_baselines(&self) -> Result> { + let conn = self.pool.get()?; + let mut stmt = conn.prepare( + "SELECT scope, sample_count, mean, stddev, updated_at + FROM baselines + ORDER BY updated_at DESC, scope ASC", + )?; + + let rows = stmt.query_map([], |row| { + Ok(StoredBaseline { + scope: row.get(0)?, + baseline: FeatureBaseline { + sample_count: row.get::<_, i64>(1)? as u64, + mean: serde_json::from_str::(&row.get::<_, String>(2)?) + .map_err(to_sql_error)?, + stddev: serde_json::from_str::(&row.get::<_, String>(3)?) + .map_err(to_sql_error)?, + last_updated: chrono::DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) + .map_err(to_sql_error)? + .with_timezone(&chrono::Utc), + }, + }) + })?; + + Ok(rows.collect::>>()?) + } + + pub fn delete_baseline(&self, scope: &str) -> Result<()> { + let conn = self.pool.get()?; + conn.execute("DELETE FROM baselines WHERE scope = ?1", params![scope])?; + Ok(()) + } +} + +fn to_sql_error(err: impl std::error::Error + Send + Sync + 'static) -> rusqlite::Error { + rusqlite::Error::ToSqlConversionFailure(Box::new(err)) } -impl Default for BaselinesDb { - fn default() -> Self { - Self::new().unwrap() +#[cfg(test)] +mod tests { + use super::*; + use crate::database::{create_pool, init_database}; + + fn sample_baseline() -> FeatureBaseline { + FeatureBaseline { + sample_count: 3, + mean: FeatureSummary { + syscall_rate: 8.5, + network_rate: 1.2, + unique_processes: 2.0, + privileged_calls: 0.5, + }, + stddev: FeatureSummary { + syscall_rate: 1.0, + network_rate: 0.2, + unique_processes: 0.5, + privileged_calls: 0.3, + }, + last_updated: chrono::Utc::now(), + } + } + + #[test] + fn test_baseline_persistence_round_trip() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let db = BaselinesDb::new(pool).unwrap(); + + db.save_baseline("global", &sample_baseline()).unwrap(); + let loaded = db.load_baseline("global").unwrap().unwrap(); + + assert_eq!(loaded.sample_count, 3); + assert_eq!(loaded.mean.syscall_rate, 8.5); + } + + #[test] + fn test_list_and_delete_baselines() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let db = BaselinesDb::new(pool).unwrap(); + + db.save_baseline("global", &sample_baseline()).unwrap(); + db.save_baseline("container:abc", &sample_baseline()) + .unwrap(); + + assert_eq!(db.list_baselines().unwrap().len(), 2); + db.delete_baseline("global").unwrap(); + assert!(db.load_baseline("global").unwrap().is_none()); } } diff --git a/src/database/connection.rs b/src/database/connection.rs index 4db4a27..2513cae 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -108,6 +108,12 @@ pub fn init_database(pool: &DbPool) -> Result<()> { "CREATE INDEX IF NOT EXISTS idx_alerts_timestamp ON alerts(timestamp)", [], ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_alerts_container_id + ON alerts(json_extract(metadata, '$.container_id')) + WHERE json_valid(metadata)", + [], + ); let _ = conn.execute( "CREATE INDEX IF NOT EXISTS idx_threats_status ON threats(status)", @@ -165,6 +171,23 @@ pub fn init_database(pool: &DbPool) -> Result<()> { [], ); + // Create baselines table + conn.execute( + "CREATE TABLE IF NOT EXISTS baselines ( + scope TEXT PRIMARY KEY, + sample_count INTEGER NOT NULL, + mean TEXT NOT NULL, + stddev TEXT NOT NULL, + updated_at TEXT NOT NULL + )", + [], + )?; + + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_baselines_updated_at ON baselines(updated_at)", + [], + ); + Ok(()) } diff --git a/src/database/mod.rs b/src/database/mod.rs index c8fa512..e55ccab 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,9 +1,11 @@ //! Database module +pub mod baselines; pub mod connection; pub mod models; pub mod repositories; +pub use baselines::*; pub use connection::{create_pool, init_database, DbPool}; pub use models::*; pub use repositories::alerts::*; diff --git a/src/database/models/mod.rs b/src/database/models/mod.rs index 8cd8fb5..f78053f 100644 --- a/src/database/models/mod.rs +++ b/src/database/models/mod.rs @@ -1,17 +1,119 @@ //! Database models +use std::collections::HashMap; + +use chrono::Utc; use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::alerting::alert::{AlertSeverity, AlertStatus, AlertType}; + +/// Structured alert metadata stored in the database as JSON. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct AlertMetadata { + #[serde(skip_serializing_if = "Option::is_none")] + pub container_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, + #[serde(flatten)] + pub extra: HashMap, +} + +impl AlertMetadata { + pub fn with_container_id(mut self, container_id: impl Into) -> Self { + self.container_id = Some(container_id.into()); + self + } + + pub fn with_source(mut self, source: impl Into) -> Self { + self.source = Some(source.into()); + self + } + + pub fn with_reason(mut self, reason: impl Into) -> Self { + self.reason = Some(reason.into()); + self + } + + pub fn is_empty(&self) -> bool { + self.container_id.is_none() + && self.source.is_none() + && self.reason.is_none() + && self.extra.is_empty() + } + + pub fn from_storage(raw: &str) -> Option { + let trimmed = raw.trim(); + if trimmed.is_empty() { + return None; + } + + serde_json::from_str(trimmed) + .ok() + .or_else(|| Self::from_legacy_pairs(trimmed)) + .or_else(|| Some(Self::default().with_reason(trimmed.to_string()))) + } + + fn from_legacy_pairs(raw: &str) -> Option { + let mut metadata = Self::default(); + let mut found_pair = false; + + for part in raw + .split(',') + .map(str::trim) + .filter(|part| !part.is_empty()) + { + let Some((key, value)) = part.split_once('=') else { + continue; + }; + + found_pair = true; + let value = value.trim().to_string(); + match key.trim() { + "container_id" => metadata.container_id = Some(value), + "source" => metadata.source = Some(value), + "reason" => metadata.reason = Some(value), + other => { + metadata.extra.insert(other.to_string(), value); + } + } + } + + found_pair.then_some(metadata) + } +} /// Alert model #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Alert { pub id: String, - pub alert_type: String, - pub severity: String, + pub alert_type: AlertType, + pub severity: AlertSeverity, pub message: String, - pub status: String, + pub status: AlertStatus, pub timestamp: String, - pub metadata: Option, + pub metadata: Option, +} + +impl Alert { + pub fn new(alert_type: AlertType, severity: AlertSeverity, message: impl Into) -> Self { + Self { + id: Uuid::new_v4().to_string(), + alert_type, + severity, + message: message.into(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: None, + } + } + + pub fn with_metadata(mut self, metadata: AlertMetadata) -> Self { + self.metadata = (!metadata.is_empty()).then_some(metadata); + self + } } /// Threat model diff --git a/src/database/repositories/alerts.rs b/src/database/repositories/alerts.rs index d6d7a88..fa9d98c 100644 --- a/src/database/repositories/alerts.rs +++ b/src/database/repositories/alerts.rs @@ -1,11 +1,11 @@ //! Alert repository using rusqlite +use crate::alerting::alert::{AlertSeverity, AlertStatus, AlertType}; use crate::database::connection::DbPool; -use crate::database::models::Alert; +use crate::database::models::{Alert, AlertMetadata}; use anyhow::Result; -use chrono::Utc; use rusqlite::params; -use uuid::Uuid; +use rusqlite::types::Type; /// Alert filter #[derive(Debug, Clone, Default)] @@ -23,33 +23,138 @@ pub struct AlertStats { pub resolved_count: i64, } +/// Severity breakdown for open security alerts. +#[derive(Debug, Clone, Default)] +pub struct SeverityBreakdown { + pub info_count: u32, + pub low_count: u32, + pub medium_count: u32, + pub high_count: u32, + pub critical_count: u32, +} + +impl SeverityBreakdown { + pub fn weighted_penalty(&self) -> u32 { + self.info_count + + self.low_count.saturating_mul(4) + + self.medium_count.saturating_mul(10) + + self.high_count.saturating_mul(20) + + self.critical_count.saturating_mul(35) + } +} + +/// Snapshot of current security status derived from persisted alerts. +#[derive(Debug, Clone, Default)] +pub struct SecurityStatusSnapshot { + pub alerts_new: u32, + pub alerts_acknowledged: u32, + pub active_threats: u32, + pub quarantined_containers: u32, + pub severity_breakdown: SeverityBreakdown, +} + +/// Alert summary for a single container. +#[derive(Debug, Clone, Default)] +pub struct ContainerAlertSummary { + pub active_threats: u32, + pub quarantined: bool, + pub severity_breakdown: SeverityBreakdown, + pub last_alert_at: Option, +} + +impl ContainerAlertSummary { + pub fn risk_score(&self) -> u32 { + let base = self.severity_breakdown.weighted_penalty(); + let quarantine_penalty = if self.quarantined { 25 } else { 0 }; + (base + quarantine_penalty).min(100) + } + + pub fn security_state(&self) -> &'static str { + if self.quarantined { + "Quarantined" + } else if self.active_threats > 0 { + "AtRisk" + } else { + "Secure" + } + } +} + fn map_alert_row(row: &rusqlite::Row) -> Result { + let alert_type = parse_alert_type(row.get::<_, String>(1)?, 1)?; + let severity = parse_alert_severity(row.get::<_, String>(2)?, 2)?; + let status = parse_alert_status(row.get::<_, String>(4)?, 4)?; + let metadata = row + .get::<_, Option>(6)? + .and_then(|raw| AlertMetadata::from_storage(&raw)); + Ok(Alert { id: row.get(0)?, - alert_type: row.get(1)?, - severity: row.get(2)?, + alert_type, + severity, message: row.get(3)?, - status: row.get(4)?, + status, timestamp: row.get(5)?, - metadata: row.get(6)?, + metadata, + }) +} + +fn parse_alert_type(value: String, column_index: usize) -> Result { + value.parse().map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + column_index, + Type::Text, + Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, err)), + ) }) } +fn parse_alert_severity( + value: String, + column_index: usize, +) -> Result { + value.parse().map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + column_index, + Type::Text, + Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, err)), + ) + }) +} + +fn parse_alert_status(value: String, column_index: usize) -> Result { + value.parse().map_err(|err| { + rusqlite::Error::FromSqlConversionFailure( + column_index, + Type::Text, + Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, err)), + ) + }) +} + +fn serialize_metadata(metadata: Option<&AlertMetadata>) -> Result> { + match metadata { + Some(metadata) if !metadata.is_empty() => Ok(Some(serde_json::to_string(metadata)?)), + _ => Ok(None), + } +} + /// Create a new alert pub async fn create_alert(pool: &DbPool, alert: Alert) -> Result { let conn = pool.get()?; + let metadata = serialize_metadata(alert.metadata.as_ref())?; conn.execute( "INSERT INTO alerts (id, alert_type, severity, message, status, timestamp, metadata) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", params![ - alert.id, - alert.alert_type, - alert.severity, - alert.message, - alert.status, - alert.timestamp, - alert.metadata + &alert.id, + alert.alert_type.to_string(), + alert.severity.to_string(), + &alert.message, + alert.status.to_string(), + &alert.timestamp, + metadata ], )?; @@ -167,17 +272,158 @@ pub async fn get_alert_stats(pool: &DbPool) -> Result { }) } +/// Get a live security status snapshot from persisted alerts. +pub fn get_security_status_snapshot(pool: &DbPool) -> Result { + let conn = pool.get()?; + let snapshot = conn.query_row( + "SELECT + COALESCE(SUM(CASE WHEN status = 'New' THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status = 'Acknowledged' THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' AND alert_type = 'QuarantineApplied' THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + AND severity = 'Info' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + AND severity = 'Low' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + AND severity = 'Medium' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + AND severity = 'High' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + AND severity = 'Critical' + THEN 1 ELSE 0 END), 0) + FROM alerts", + [], + |row| { + Ok(SecurityStatusSnapshot { + alerts_new: row.get::<_, i64>(0)?.max(0) as u32, + alerts_acknowledged: row.get::<_, i64>(1)?.max(0) as u32, + active_threats: row.get::<_, i64>(2)?.max(0) as u32, + quarantined_containers: row.get::<_, i64>(3)?.max(0) as u32, + severity_breakdown: SeverityBreakdown { + info_count: row.get::<_, i64>(4)?.max(0) as u32, + low_count: row.get::<_, i64>(5)?.max(0) as u32, + medium_count: row.get::<_, i64>(6)?.max(0) as u32, + high_count: row.get::<_, i64>(7)?.max(0) as u32, + critical_count: row.get::<_, i64>(8)?.max(0) as u32, + }, + }) + }, + )?; + + Ok(snapshot) +} + +/// Get alert-derived security summary for a specific container. +pub fn get_container_alert_summary( + pool: &DbPool, + container_id: &str, +) -> Result { + let conn = pool.get()?; + let legacy_metadata = format!("container_id={container_id}"); + let metadata_pattern = format!("%{legacy_metadata}%"); + let summary = conn.query_row( + "SELECT + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND alert_type IN ('ThreatDetected', 'AnomalyDetected', 'RuleViolation', 'ThresholdExceeded') + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND alert_type = 'QuarantineApplied' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND severity = 'Info' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND severity = 'Low' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND severity = 'Medium' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND severity = 'High' + THEN 1 ELSE 0 END), 0), + COALESCE(SUM(CASE WHEN status != 'Resolved' + AND ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) + AND severity = 'Critical' + THEN 1 ELSE 0 END), 0), + MAX(CASE WHEN ( + (json_valid(metadata) AND json_extract(metadata, '$.container_id') = ?1) + OR metadata = ?2 + OR metadata LIKE ?3 + ) THEN timestamp ELSE NULL END) + FROM alerts", + params![container_id, legacy_metadata, metadata_pattern], + |row| { + Ok(ContainerAlertSummary { + active_threats: row.get::<_, i64>(0)?.max(0) as u32, + quarantined: row.get::<_, i64>(1)?.max(0) > 0, + severity_breakdown: SeverityBreakdown { + info_count: row.get::<_, i64>(2)?.max(0) as u32, + low_count: row.get::<_, i64>(3)?.max(0) as u32, + medium_count: row.get::<_, i64>(4)?.max(0) as u32, + high_count: row.get::<_, i64>(5)?.max(0) as u32, + critical_count: row.get::<_, i64>(6)?.max(0) as u32, + }, + last_alert_at: row.get(7)?, + }) + }, + )?; + + Ok(summary) +} + /// Create a sample alert (for testing) pub fn create_sample_alert() -> Alert { - Alert { - id: Uuid::new_v4().to_string(), - alert_type: "ThreatDetected".to_string(), - severity: "High".to_string(), - message: "Suspicious activity detected".to_string(), - status: "New".to_string(), - timestamp: Utc::now().to_rfc3339(), - metadata: None, - } + Alert::new( + AlertType::ThreatDetected, + AlertSeverity::High, + "Suspicious activity detected", + ) } #[cfg(test)] @@ -185,6 +431,7 @@ mod tests { use super::*; use crate::database::connection::create_pool; use crate::database::connection::init_database; + use chrono::Utc; #[actix_rt::test] async fn test_create_and_list_alerts() { @@ -212,7 +459,7 @@ mod tests { .unwrap(); let updated = get_alert(&pool, &alert.id).await.unwrap().unwrap(); - assert_eq!(updated.status, "Acknowledged"); + assert_eq!(updated.status, AlertStatus::Acknowledged); } #[actix_rt::test] @@ -229,4 +476,111 @@ mod tests { assert_eq!(stats.total_count, 3); assert_eq!(stats.new_count, 3); } + + #[actix_rt::test] + async fn test_get_security_status_snapshot() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + create_alert( + &pool, + Alert { + id: "a1".to_string(), + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::Critical, + message: "critical".to_string(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: None, + }, + ) + .await + .unwrap(); + create_alert( + &pool, + Alert { + id: "a2".to_string(), + alert_type: AlertType::QuarantineApplied, + severity: AlertSeverity::High, + message: "q".to_string(), + status: AlertStatus::Acknowledged, + timestamp: Utc::now().to_rfc3339(), + metadata: Some(AlertMetadata::default().with_container_id("abc123")), + }, + ) + .await + .unwrap(); + + let snapshot = get_security_status_snapshot(&pool).unwrap(); + assert_eq!(snapshot.alerts_new, 1); + assert_eq!(snapshot.alerts_acknowledged, 1); + assert_eq!(snapshot.active_threats, 1); + assert_eq!(snapshot.quarantined_containers, 1); + assert_eq!(snapshot.severity_breakdown.critical_count, 1); + } + + #[actix_rt::test] + async fn test_get_container_alert_summary() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + create_alert( + &pool, + Alert { + id: "a1".to_string(), + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::High, + message: "threat".to_string(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: Some(AlertMetadata::default().with_container_id("abc123")), + }, + ) + .await + .unwrap(); + create_alert( + &pool, + Alert { + id: "a2".to_string(), + alert_type: AlertType::QuarantineApplied, + severity: AlertSeverity::High, + message: "quarantine".to_string(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: Some(AlertMetadata::default().with_container_id("abc123")), + }, + ) + .await + .unwrap(); + + let summary = get_container_alert_summary(&pool, "abc123").unwrap(); + assert_eq!(summary.active_threats, 1); + assert!(summary.quarantined); + assert_eq!(summary.security_state(), "Quarantined"); + assert!(summary.risk_score() > 0); + } + + #[actix_rt::test] + async fn test_get_container_alert_summary_supports_legacy_metadata() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + create_alert( + &pool, + Alert { + id: "legacy-a1".to_string(), + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::High, + message: "legacy threat".to_string(), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: Some(AlertMetadata::from_storage("container_id=legacy123").unwrap()), + }, + ) + .await + .unwrap(); + + let summary = get_container_alert_summary(&pool, "legacy123").unwrap(); + assert_eq!(summary.active_threats, 1); + } } diff --git a/src/docker/containers.rs b/src/docker/containers.rs index 7b4d072..a308706 100644 --- a/src/docker/containers.rs +++ b/src/docker/containers.rs @@ -1,11 +1,10 @@ //! Container management -use crate::database::models::Alert; -use crate::database::{create_alert, DbPool}; +use crate::alerting::alert::{AlertSeverity, AlertType}; +use crate::database::models::{Alert, AlertMetadata}; +use crate::database::{create_alert, get_container_alert_summary, DbPool}; use crate::docker::client::{ContainerInfo, DockerClient}; use anyhow::Result; -use chrono::Utc; -use uuid::Uuid; /// Container manager pub struct ContainerManager { @@ -44,15 +43,16 @@ impl ContainerManager { self.docker.quarantine_container(container_id).await?; // Create alert - let alert = Alert { - id: Uuid::new_v4().to_string(), - alert_type: "QuarantineApplied".to_string(), - severity: "High".to_string(), - message: format!("Container {} quarantined: {}", container_id, reason), - status: "New".to_string(), - timestamp: Utc::now().to_rfc3339(), - metadata: Some(format!("container_id={}", container_id)), - }; + let alert = Alert::new( + AlertType::QuarantineApplied, + AlertSeverity::High, + format!("Container {} quarantined: {}", container_id, reason), + ) + .with_metadata( + AlertMetadata::default() + .with_container_id(container_id) + .with_reason(reason), + ); let _ = create_alert(&self.pool, alert).await; @@ -80,22 +80,13 @@ impl ContainerManager { container_id: &str, ) -> Result { let _info = self.docker.get_container_info(container_id).await?; - - // Calculate risk score based on various factors - let risk_score = 0; - let threats = 0; - let security_state = "Secure"; - - // Check if running as root - // Check for privileged mode - // Check for exposed ports - // Check for volume mounts + let summary = get_container_alert_summary(&self.pool, container_id)?; Ok(ContainerSecurityStatus { container_id: container_id.to_string(), - risk_score, - threats, - security_state: security_state.to_string(), + risk_score: summary.risk_score(), + threats: summary.active_threats, + security_state: summary.security_state().to_string(), }) } } diff --git a/src/docker/mail_guard.rs b/src/docker/mail_guard.rs index 1a2a9c1..44ee927 100644 --- a/src/docker/mail_guard.rs +++ b/src/docker/mail_guard.rs @@ -1,11 +1,11 @@ use std::collections::{HashMap, HashSet}; use std::env; -use chrono::Utc; use tokio::time::{sleep, Duration}; -use uuid::Uuid; +use crate::alerting::alert::{AlertSeverity, AlertType}; use crate::database::models::Alert; +use crate::database::models::AlertMetadata; use crate::database::repositories::alerts::create_alert; use crate::database::DbPool; use crate::docker::client::{ContainerInfo, ContainerStats}; @@ -328,18 +328,20 @@ impl MailAbuseGuard { detector.mark_quarantined(&container.id); create_alert( pool, - Alert { - id: Uuid::new_v4().to_string(), - alert_type: "ThreatDetected".into(), - severity: "Critical".into(), - message: format!( + Alert::new( + AlertType::ThreatDetected, + AlertSeverity::Critical, + format!( "Mail abuse guard quarantined container {} ({})", container.name, container.id ), - status: "New".into(), - timestamp: Utc::now().to_rfc3339(), - metadata: Some(reason.clone()), - }, + ) + .with_metadata( + AlertMetadata::default() + .with_container_id(&container.id) + .with_source("mail-abuse-guard") + .with_reason(reason.clone()), + ), ) .await?; log::warn!("{}", reason); diff --git a/src/events/syscall.rs b/src/events/syscall.rs index ede04bf..3eb9641 100644 --- a/src/events/syscall.rs +++ b/src/events/syscall.rs @@ -49,6 +49,32 @@ pub struct SyscallEvent { pub timestamp: DateTime, pub container_id: Option, pub comm: Option, + pub details: Option, +} + +/// Syscall-specific details captured by eBPF or userspace enrichment. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum SyscallDetails { + Exec { + filename: Option, + args: Vec, + argc: u32, + }, + Connect { + dst_addr: Option, + dst_port: u16, + family: u16, + }, + Openat { + path: Option, + flags: u32, + }, + Ptrace { + target_pid: u32, + request: u32, + addr: u64, + data: u64, + }, } impl SyscallEvent { @@ -61,6 +87,7 @@ impl SyscallEvent { timestamp, container_id: None, comm: None, + details: None, } } @@ -78,6 +105,28 @@ impl SyscallEvent { pub fn uid(&self) -> Option { Some(self.uid) } + + /// Get exec details if this is an exec event. + pub fn exec_details(&self) -> Option<(&Option, &[String], u32)> { + match self.details.as_ref() { + Some(SyscallDetails::Exec { + filename, + args, + argc, + }) => Some((filename, args.as_slice(), *argc)), + _ => None, + } + } + + /// Get connect destination if this is a connect event. + pub fn connect_destination(&self) -> Option<(Option<&str>, u16)> { + match self.details.as_ref() { + Some(SyscallDetails::Connect { + dst_addr, dst_port, .. + }) => Some((dst_addr.as_deref(), *dst_port)), + _ => None, + } + } } /// Builder for SyscallEvent @@ -88,6 +137,7 @@ pub struct SyscallEventBuilder { timestamp: Option>, container_id: Option, comm: Option, + details: Option, } impl SyscallEventBuilder { @@ -99,6 +149,7 @@ impl SyscallEventBuilder { timestamp: None, container_id: None, comm: None, + details: None, } } @@ -132,6 +183,11 @@ impl SyscallEventBuilder { self } + pub fn details(mut self, details: Option) -> Self { + self.details = details; + self + } + pub fn build(self) -> SyscallEvent { SyscallEvent { pid: self.pid, @@ -140,6 +196,7 @@ impl SyscallEventBuilder { timestamp: self.timestamp.unwrap_or_else(Utc::now), container_id: self.container_id, comm: self.comm, + details: self.details, } } } @@ -174,7 +231,32 @@ mod tests { .pid(1234) .uid(1000) .syscall_type(SyscallType::Connect) + .details(Some(SyscallDetails::Connect { + dst_addr: Some("192.0.2.10".to_string()), + dst_port: 587, + family: 2, + })) .build(); assert_eq!(event.pid, 1234); + assert_eq!(event.connect_destination(), Some((Some("192.0.2.10"), 587))); + } + + #[test] + fn test_exec_details_accessor() { + let event = SyscallEvent::builder() + .pid(1234) + .uid(1000) + .syscall_type(SyscallType::Execve) + .details(Some(SyscallDetails::Exec { + filename: Some("/usr/sbin/sendmail".to_string()), + args: vec!["/usr/sbin/sendmail".to_string(), "-t".to_string()], + argc: 2, + })) + .build(); + + let (filename, args, argc) = event.exec_details().unwrap(); + assert_eq!(filename.as_deref(), Some("/usr/sbin/sendmail")); + assert_eq!(args, ["/usr/sbin/sendmail", "-t"]); + assert_eq!(argc, 2); } } diff --git a/src/lib.rs b/src/lib.rs index ca67009..4541644 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,6 +71,8 @@ pub use events::syscall::{SyscallEvent, SyscallType}; pub use alerting::{Alert, AlertSeverity, AlertStatus, AlertType}; pub use alerting::{AlertManager, AlertStats}; pub use alerting::{NotificationChannel, NotificationConfig}; +#[cfg(target_os = "linux")] +pub use response::{ActionPipeline, PipelineAction, PipelinePlan}; // Linux-specific pub use collectors::{EbpfLoader, SyscallMonitor}; diff --git a/src/main.rs b/src/main.rs index a0ce50d..c1dc1e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ extern crate tracing_subscriber; mod cli; +use actix::Actor; use actix_cors::Cors; use actix_web::{web, App, HttpServer}; use clap::Parser; @@ -80,6 +81,12 @@ async fn main() -> io::Result<()> { ai_model, ai_api_url, slack_webhook, + webhook_url, + smtp_host, + smtp_port, + smtp_user, + smtp_password, + email_recipients, }) => { let config = sniff::config::SniffConfig::from_env_and_args(sniff::config::SniffArgs { once, @@ -91,6 +98,12 @@ async fn main() -> io::Result<()> { ai_model: ai_model.as_deref(), ai_api_url: ai_api_url.as_deref(), slack_webhook: slack_webhook.as_deref(), + webhook_url: webhook_url.as_deref(), + smtp_host: smtp_host.as_deref(), + smtp_port, + smtp_user: smtp_user.as_deref(), + smtp_password: smtp_password.as_deref(), + email_recipients: email_recipients.as_deref(), }); run_sniff(config).await } @@ -154,10 +167,17 @@ async fn run_serve() -> io::Result<()> { info!("Starting HTTP server on {}...", app_url); let pool_data = web::Data::new(pool); + let websocket_hub = stackdog::api::websocket::WebSocketHub::new().start(); + stackdog::api::websocket::spawn_stats_broadcaster( + websocket_hub.clone(), + pool_data.get_ref().clone(), + ); + let websocket_hub_data = web::Data::new(websocket_hub); HttpServer::new(move || { App::new() .app_data(pool_data.clone()) + .app_data(websocket_hub_data.clone()) .wrap(Cors::permissive()) .wrap(actix_web::middleware::Logger::default()) .configure(stackdog::api::configure_all_routes) @@ -186,6 +206,12 @@ async fn run_sniff(config: sniff::config::SniffConfig) -> io::Result<()> { if config.slack_webhook.is_some() { info!("Slack: configured ✓"); } + if config.webhook_url.is_some() { + info!("Webhook: configured ✓"); + } + if config.smtp_host.is_some() && !config.email_recipients.is_empty() { + info!("Email: configured ✓"); + } let orchestrator = sniff::SniffOrchestrator::new(config).map_err(io::Error::other)?; diff --git a/src/ml/anomaly.rs b/src/ml/anomaly.rs index 71ff343..b28d00f 100644 --- a/src/ml/anomaly.rs +++ b/src/ml/anomaly.rs @@ -2,16 +2,137 @@ //! //! Detects anomalies in security events -use anyhow::Result; +use anyhow::{ensure, Result}; +use serde::{Deserialize, Serialize}; + +use crate::baselines::learning::{BaselineDrift, BaselineLearner}; +use crate::events::security::SecurityEvent; +use crate::ml::features::SecurityFeatures; +use crate::ml::models::isolation_forest::{IsolationForestConfig, IsolationForestModel}; +use crate::ml::scorer::{Scorer, ThreatScore}; + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DetectorConfig { + pub anomaly_threshold: f64, + pub drift_threshold: f64, + pub drift_weight: f64, + pub forest: IsolationForestConfig, +} + +impl Default for DetectorConfig { + fn default() -> Self { + Self { + anomaly_threshold: 0.65, + drift_threshold: 3.0, + drift_weight: 0.35, + forest: IsolationForestConfig::default(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct AnomalyAssessment { + pub anomaly_score: f64, + pub drift_score: Option, + pub combined_score: f64, + pub threat_score: ThreatScore, + pub is_anomalous: bool, + pub reasons: Vec, +} /// Anomaly detector pub struct AnomalyDetector { - // TODO: Implement in TASK-014 + config: DetectorConfig, + model: IsolationForestModel, + baseline_learner: BaselineLearner, + scorer: Scorer, } impl AnomalyDetector { pub fn new() -> Result { - Ok(Self {}) + Self::with_config(DetectorConfig::default()) + } + + pub fn with_config(config: DetectorConfig) -> Result { + let baseline_learner = + BaselineLearner::new()?.with_deviation_threshold(config.drift_threshold); + let scorer = Scorer::new()?.with_drift_weight(config.drift_weight); + + Ok(Self { + model: IsolationForestModel::with_config(config.forest.clone()), + baseline_learner, + scorer, + config, + }) + } + + pub fn train(&mut self, training_data: &[SecurityFeatures]) -> Result<()> { + ensure!(!training_data.is_empty(), "training data cannot be empty"); + self.model.fit(training_data); + Ok(()) + } + + pub fn learn_baseline(&mut self, scope: &str, samples: &[SecurityFeatures]) { + for sample in samples { + self.baseline_learner.observe(scope.to_string(), sample); + } + } + + pub fn assess(&self, scope: &str, features: &SecurityFeatures) -> Result { + let anomaly_score = self.model.score(features); + let drift = self.baseline_learner.detect_drift(scope, features); + Ok(self.build_assessment(anomaly_score, drift)) + } + + pub fn assess_events( + &self, + scope: &str, + events: &[SecurityEvent], + window_seconds: f64, + ) -> Result { + let features = SecurityFeatures::from_events(events, window_seconds); + self.assess(scope, &features) + } + + pub fn model(&self) -> &IsolationForestModel { + &self.model + } + + fn build_assessment( + &self, + anomaly_score: f64, + drift: Option, + ) -> AnomalyAssessment { + let mut reasons = Vec::new(); + if anomaly_score >= self.config.anomaly_threshold { + reasons.push(format!("isolation_forest_score={anomaly_score:.3}")); + } + + let drift_score = drift.as_ref().map(|drift| normalize_drift(drift.score)); + if let Some(drift) = drift + .as_ref() + .filter(|drift| !drift.deviating_features.is_empty()) + { + reasons.push(format!( + "baseline_drift={:.3} [{}]", + drift.score, + drift.deviating_features.join(", ") + )); + } + + let combined_score = self.scorer.combined_score(anomaly_score, drift_score); + let threat_score = self.scorer.score(anomaly_score, drift_score); + let is_anomalous = + combined_score >= self.config.anomaly_threshold || drift_score.unwrap_or(0.0) > 0.50; + + AnomalyAssessment { + anomaly_score, + drift_score, + combined_score, + threat_score, + is_anomalous, + reasons, + } } } @@ -20,3 +141,57 @@ impl Default for AnomalyDetector { Self::new().unwrap() } } + +fn normalize_drift(score: f64) -> f64 { + if score <= 0.0 { + 0.0 + } else { + (score / (score + 3.0)).clamp(0.0, 1.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn feature(syscall_rate: f64, network_rate: f64, unique_processes: u32) -> SecurityFeatures { + SecurityFeatures { + syscall_rate, + network_rate, + unique_processes, + privileged_calls: 0, + } + } + + #[test] + fn test_training_requires_samples() { + let mut detector = AnomalyDetector::new().unwrap(); + assert!(detector.train(&[]).is_err()); + } + + #[test] + fn test_detector_flags_real_outlier() { + let mut detector = AnomalyDetector::with_config(DetectorConfig { + anomaly_threshold: 0.55, + ..DetectorConfig::default() + }) + .unwrap(); + let baseline = vec![ + feature(10.0, 2.0, 3), + feature(10.5, 2.1, 3), + feature(9.8, 1.9, 2), + feature(10.2, 2.0, 3), + feature(10.1, 2.2, 3), + ]; + + detector.train(&baseline).unwrap(); + detector.learn_baseline("global", &baseline); + + let assessment = detector.assess("global", &feature(28.0, 9.0, 12)).unwrap(); + + assert!(assessment.is_anomalous); + assert!(assessment.combined_score >= 0.55); + assert!(assessment.threat_score >= ThreatScore::Medium); + assert!(!assessment.reasons.is_empty()); + } +} diff --git a/src/ml/candle_backend.rs b/src/ml/candle_backend.rs index 7802516..aacdade 100644 --- a/src/ml/candle_backend.rs +++ b/src/ml/candle_backend.rs @@ -4,14 +4,63 @@ use anyhow::Result; +use crate::ml::features::SecurityFeatures; + /// Candle ML backend pub struct CandleBackend { - // TODO: Implement in TASK-012 + input_size: usize, } impl CandleBackend { pub fn new() -> Result { - Ok(Self {}) + Ok(Self { input_size: 4 }) + } + + pub fn input_size(&self) -> usize { + self.input_size + } + + pub fn feature_vector(&self, features: &SecurityFeatures) -> Vec { + features + .as_vector() + .into_iter() + .map(|value| value as f32) + .collect() + } + + pub fn batch_feature_vectors(&self, batch: &[SecurityFeatures]) -> Vec> { + batch + .iter() + .map(|features| self.feature_vector(features)) + .collect() + } + + pub fn is_enabled(&self) -> bool { + cfg!(feature = "ml") + } + + #[cfg(feature = "ml")] + pub fn tensor_from_features(&self, features: &SecurityFeatures) -> Result { + let data = self.feature_vector(features); + Ok(candle_core::Tensor::from_vec( + data, + (1, self.input_size), + &candle_core::Device::Cpu, + )?) + } + + #[cfg(feature = "ml")] + pub fn tensor_from_batch(&self, batch: &[SecurityFeatures]) -> Result { + let data = self + .batch_feature_vectors(batch) + .into_iter() + .flatten() + .collect::>(); + Ok(candle_core::Tensor::from_vec( + data, + (batch.len(), self.input_size), + &candle_core::Device::Cpu, + )?) } } @@ -20,3 +69,22 @@ impl Default for CandleBackend { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_feature_vector_conversion() { + let backend = CandleBackend::new().unwrap(); + let features = SecurityFeatures { + syscall_rate: 4.0, + network_rate: 1.5, + unique_processes: 2, + privileged_calls: 1, + }; + + assert_eq!(backend.input_size(), 4); + assert_eq!(backend.feature_vector(&features), vec![4.0, 1.5, 2.0, 1.0]); + } +} diff --git a/src/ml/features.rs b/src/ml/features.rs index 8abe268..f87a7bf 100644 --- a/src/ml/features.rs +++ b/src/ml/features.rs @@ -2,7 +2,15 @@ //! //! Extracts features from security events for anomaly detection +use std::collections::HashSet; + +use serde::{Deserialize, Serialize}; + +use crate::events::security::SecurityEvent; +use crate::events::syscall::SyscallType; + /// Security features for ML model +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct SecurityFeatures { pub syscall_rate: f64, pub network_rate: f64, @@ -19,6 +27,85 @@ impl SecurityFeatures { privileged_calls: 0, } } + + /// Build a feature vector from a batch of security events observed over a window. + pub fn from_events(events: &[SecurityEvent], window_seconds: f64) -> Self { + if events.is_empty() { + return Self::default(); + } + + let effective_window = if window_seconds > 0.0 { + window_seconds + } else { + 1.0 + }; + + let mut syscall_count = 0usize; + let mut network_count = 0usize; + let mut unique_processes = HashSet::new(); + let mut privileged_calls = 0u32; + + for event in events { + match event { + SecurityEvent::Syscall(syscall) => { + syscall_count += 1; + unique_processes.insert(syscall.pid); + + if matches!( + syscall.syscall_type, + SyscallType::Ptrace + | SyscallType::Setuid + | SyscallType::Setgid + | SyscallType::Mount + | SyscallType::Umount + ) { + privileged_calls += 1; + } + + if matches!( + syscall.syscall_type, + SyscallType::Connect + | SyscallType::Accept + | SyscallType::Bind + | SyscallType::Listen + | SyscallType::Socket + | SyscallType::Sendto + ) { + network_count += 1; + } + } + SecurityEvent::Network(_) => { + network_count += 1; + } + SecurityEvent::Container(_) | SecurityEvent::Alert(_) => {} + } + } + + Self { + syscall_rate: syscall_count as f64 / effective_window, + network_rate: network_count as f64 / effective_window, + unique_processes: unique_processes.len() as u32, + privileged_calls, + } + } + + pub fn as_vector(&self) -> [f64; 4] { + [ + self.syscall_rate, + self.network_rate, + self.unique_processes as f64, + self.privileged_calls as f64, + ] + } + + pub fn from_vector(vector: [f64; 4]) -> Self { + Self { + syscall_rate: vector[0], + network_rate: vector[1], + unique_processes: vector[2].round().max(0.0) as u32, + privileged_calls: vector[3].round().max(0.0) as u32, + } + } } impl Default for SecurityFeatures { @@ -26,3 +113,51 @@ impl Default for SecurityFeatures { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::events::security::{NetworkEvent, SecurityEvent}; + use crate::events::syscall::{SyscallEvent, SyscallType}; + use chrono::Utc; + + #[test] + fn test_feature_vector_creation_from_events() { + let events = vec![ + SecurityEvent::Syscall(SyscallEvent::new(100, 0, SyscallType::Execve, Utc::now())), + SecurityEvent::Syscall(SyscallEvent::new(100, 0, SyscallType::Connect, Utc::now())), + SecurityEvent::Syscall(SyscallEvent::new(200, 0, SyscallType::Ptrace, Utc::now())), + SecurityEvent::Network(NetworkEvent { + src_ip: "10.0.0.2".to_string(), + dst_ip: "198.51.100.12".to_string(), + src_port: 40000, + dst_port: 443, + protocol: "tcp".to_string(), + timestamp: Utc::now(), + container_id: Some("abc".to_string()), + }), + ]; + + let features = SecurityFeatures::from_events(&events, 2.0); + + assert_eq!(features.syscall_rate, 1.5); + assert_eq!(features.network_rate, 1.0); + assert_eq!(features.unique_processes, 2); + assert_eq!(features.privileged_calls, 1); + } + + #[test] + fn test_feature_vector_round_trip() { + let features = SecurityFeatures { + syscall_rate: 12.5, + network_rate: 3.0, + unique_processes: 7, + privileged_calls: 2, + }; + + assert_eq!( + SecurityFeatures::from_vector(features.as_vector()), + features + ); + } +} diff --git a/src/ml/models/isolation_forest.rs b/src/ml/models/isolation_forest.rs index 9af19f7..ea8e7b4 100644 --- a/src/ml/models/isolation_forest.rs +++ b/src/ml/models/isolation_forest.rs @@ -2,14 +2,160 @@ //! //! Implementation of Isolation Forest for anomaly detection using Candle +use serde::{Deserialize, Serialize}; + +use crate::ml::features::SecurityFeatures; + /// Isolation Forest model +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct IsolationForestModel { - // TODO: Implement in TASK-014 + config: IsolationForestConfig, + trees: Vec, + sample_size: usize, +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct IsolationForestConfig { + pub trees: usize, + pub sample_size: usize, + pub max_depth: usize, + pub seed: u64, +} + +impl Default for IsolationForestConfig { + fn default() -> Self { + Self { + trees: 64, + sample_size: 32, + max_depth: 8, + seed: 0x5eed_cafe_d00d_f00d, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct IsolationTree { + root: IsolationNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +enum IsolationNode { + External { + size: usize, + }, + Internal { + feature: usize, + threshold: f64, + left: Box, + right: Box, + }, +} + +#[derive(Debug, Clone)] +struct SimpleRng { + state: u64, +} + +impl SimpleRng { + fn new(seed: u64) -> Self { + Self { state: seed } + } + + fn next_u64(&mut self) -> u64 { + self.state = self + .state + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + self.state + } + + fn gen_range_usize(&mut self, upper: usize) -> usize { + if upper <= 1 { + 0 + } else { + (self.next_u64() % upper as u64) as usize + } + } + + fn gen_range_f64(&mut self, min: f64, max: f64) -> f64 { + if (max - min).abs() <= f64::EPSILON { + min + } else { + let fraction = self.next_u64() as f64 / u64::MAX as f64; + min + fraction * (max - min) + } + } } impl IsolationForestModel { pub fn new() -> Self { - Self {} + Self::with_config(IsolationForestConfig::default()) + } + + pub fn with_config(config: IsolationForestConfig) -> Self { + Self { + config, + trees: Vec::new(), + sample_size: 0, + } + } + + pub fn fit(&mut self, dataset: &[SecurityFeatures]) { + self.trees.clear(); + if dataset.is_empty() { + self.sample_size = 0; + return; + } + + let rows = dataset + .iter() + .map(SecurityFeatures::as_vector) + .collect::>(); + + self.sample_size = self.config.sample_size.min(rows.len()).max(1); + let max_depth = self + .config + .max_depth + .max((self.sample_size as f64).log2().ceil() as usize); + + let mut rng = SimpleRng::new(self.config.seed); + self.trees = (0..self.config.trees) + .map(|_| { + let sample = sample_without_replacement(&rows, self.sample_size, &mut rng); + IsolationTree { + root: build_tree(&sample, 0, max_depth, &mut rng), + } + }) + .collect(); + } + + pub fn score(&self, sample: &SecurityFeatures) -> f64 { + if self.trees.is_empty() || self.sample_size <= 1 { + return 0.0; + } + + let vector = sample.as_vector(); + let average_path = self + .trees + .iter() + .map(|tree| path_length(&tree.root, &vector, 0)) + .sum::() + / self.trees.len() as f64; + + let normalization = average_path_length(self.sample_size); + if normalization <= f64::EPSILON { + 0.0 + } else { + 2f64.powf(-(average_path / normalization)).clamp(0.0, 1.0) + } + } + + pub fn is_trained(&self) -> bool { + !self.trees.is_empty() + } + + pub fn sample_size(&self) -> usize { + self.sample_size } } @@ -18,3 +164,174 @@ impl Default for IsolationForestModel { Self::new() } } + +fn sample_without_replacement( + data: &[[f64; 4]], + count: usize, + rng: &mut SimpleRng, +) -> Vec<[f64; 4]> { + if count >= data.len() { + return data.to_vec(); + } + + let mut indices: Vec = (0..data.len()).collect(); + for idx in 0..count { + let swap_idx = idx + rng.gen_range_usize(data.len() - idx); + indices.swap(idx, swap_idx); + } + + indices + .into_iter() + .take(count) + .map(|index| data[index]) + .collect() +} + +fn build_tree( + rows: &[[f64; 4]], + depth: usize, + max_depth: usize, + rng: &mut SimpleRng, +) -> IsolationNode { + if rows.len() <= 1 || depth >= max_depth { + return IsolationNode::External { size: rows.len() }; + } + + let varying_features = (0..4) + .filter_map(|feature| { + let (min, max) = min_max(rows, feature); + if (max - min).abs() > f64::EPSILON { + Some((feature, min, max)) + } else { + None + } + }) + .collect::>(); + + let Some(&(feature, min, max)) = + varying_features.get(rng.gen_range_usize(varying_features.len())) + else { + return IsolationNode::External { size: rows.len() }; + }; + + let threshold = rng.gen_range_f64(min, max); + let (left_rows, right_rows): (Vec<_>, Vec<_>) = rows + .iter() + .copied() + .partition(|row| row[feature] < threshold); + + if left_rows.is_empty() || right_rows.is_empty() { + return IsolationNode::External { size: rows.len() }; + } + + IsolationNode::Internal { + feature, + threshold, + left: Box::new(build_tree(&left_rows, depth + 1, max_depth, rng)), + right: Box::new(build_tree(&right_rows, depth + 1, max_depth, rng)), + } +} + +fn min_max(rows: &[[f64; 4]], feature: usize) -> (f64, f64) { + rows.iter() + .fold((f64::INFINITY, f64::NEG_INFINITY), |(min, max), row| { + (min.min(row[feature]), max.max(row[feature])) + }) +} + +fn path_length(node: &IsolationNode, sample: &[f64; 4], depth: usize) -> f64 { + match node { + IsolationNode::External { size } => depth as f64 + average_path_length(*size), + IsolationNode::Internal { + feature, + threshold, + left, + right, + } => { + if sample[*feature] < *threshold { + path_length(left, sample, depth + 1) + } else { + path_length(right, sample, depth + 1) + } + } + } +} + +fn average_path_length(sample_size: usize) -> f64 { + match sample_size { + 0 | 1 => 0.0, + 2 => 1.0, + n => { + let harmonic = (1..n).map(|value| 1.0 / value as f64).sum::(); + 2.0 * harmonic - (2.0 * (n - 1) as f64 / n as f64) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn feature(syscall_rate: f64, network_rate: f64, unique_processes: u32) -> SecurityFeatures { + SecurityFeatures { + syscall_rate, + network_rate, + unique_processes, + privileged_calls: 0, + } + } + + #[test] + fn test_anomaly_scoring_ranks_outlier_higher_than_inlier() { + let mut model = IsolationForestModel::with_config(IsolationForestConfig { + trees: 48, + sample_size: 16, + max_depth: 6, + seed: 42, + }); + let training = vec![ + feature(10.0, 2.0, 3), + feature(11.0, 2.1, 3), + feature(9.8, 1.9, 2), + feature(10.5, 2.2, 3), + feature(10.2, 2.0, 2), + feature(11.1, 1.8, 3), + feature(9.9, 2.3, 3), + feature(10.7, 2.0, 2), + ]; + model.fit(&training); + + let inlier = model.score(&feature(10.4, 2.1, 3)); + let outlier = model.score(&feature(30.0, 10.0, 15)); + + assert!(model.is_trained()); + assert!(outlier > inlier); + assert!(outlier > 0.50); + } + + #[test] + fn test_model_persistence_round_trip() { + let mut model = IsolationForestModel::with_config(IsolationForestConfig { + trees: 12, + sample_size: 8, + max_depth: 5, + seed: 99, + }); + let training = vec![ + feature(10.0, 2.0, 3), + feature(11.0, 2.2, 3), + feature(9.5, 1.9, 2), + feature(10.7, 2.1, 3), + ]; + model.fit(&training); + + let serialized = serde_json::to_string(&model).unwrap(); + let restored: IsolationForestModel = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(restored.sample_size(), model.sample_size()); + assert_eq!( + restored.score(&feature(25.0, 8.0, 10)), + model.score(&feature(25.0, 8.0, 10)) + ); + } +} diff --git a/src/ml/scorer.rs b/src/ml/scorer.rs index f331ac7..4f38666 100644 --- a/src/ml/scorer.rs +++ b/src/ml/scorer.rs @@ -2,10 +2,10 @@ //! //! Calculates threat scores from ML output -use anyhow::Result; +use anyhow::{ensure, Result}; /// Threat score levels -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ThreatScore { Normal, Low, @@ -14,14 +14,114 @@ pub enum ThreatScore { Critical, } +impl ThreatScore { + fn elevate(self) -> Self { + match self { + ThreatScore::Normal => ThreatScore::Low, + ThreatScore::Low => ThreatScore::Medium, + ThreatScore::Medium => ThreatScore::High, + ThreatScore::High | ThreatScore::Critical => ThreatScore::Critical, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ScoreThresholds { + pub low: f64, + pub medium: f64, + pub high: f64, + pub critical: f64, +} + +impl Default for ScoreThresholds { + fn default() -> Self { + Self { + low: 0.30, + medium: 0.50, + high: 0.75, + critical: 0.90, + } + } +} + /// Threat scorer pub struct Scorer { - // TODO: Implement in TASK-016 + thresholds: ScoreThresholds, + drift_weight: f64, } impl Scorer { pub fn new() -> Result { - Ok(Self {}) + Self::with_thresholds(ScoreThresholds::default()) + } + + pub fn with_thresholds(thresholds: ScoreThresholds) -> Result { + ensure!( + thresholds.low >= 0.0 + && thresholds.low <= thresholds.medium + && thresholds.medium <= thresholds.high + && thresholds.high <= thresholds.critical + && thresholds.critical <= 1.0, + "invalid score thresholds" + ); + + Ok(Self { + thresholds, + drift_weight: 0.35, + }) + } + + pub fn with_drift_weight(mut self, weight: f64) -> Self { + self.drift_weight = weight.clamp(0.0, 1.0); + self + } + + pub fn combined_score(&self, anomaly_score: f64, drift_score: Option) -> f64 { + let anomaly = anomaly_score.clamp(0.0, 1.0); + match drift_score { + Some(drift) => { + let drift = drift.clamp(0.0, 1.0); + ((1.0 - self.drift_weight) * anomaly + self.drift_weight * drift).clamp(0.0, 1.0) + } + None => anomaly, + } + } + + pub fn score(&self, anomaly_score: f64, drift_score: Option) -> ThreatScore { + let combined = self.combined_score(anomaly_score, drift_score); + + if combined >= self.thresholds.critical { + ThreatScore::Critical + } else if combined >= self.thresholds.high { + ThreatScore::High + } else if combined >= self.thresholds.medium { + ThreatScore::Medium + } else if combined >= self.thresholds.low { + ThreatScore::Low + } else { + ThreatScore::Normal + } + } + + pub fn aggregate(&self, scores: &[ThreatScore]) -> ThreatScore { + let Some(mut aggregate) = scores.iter().copied().max() else { + return ThreatScore::Normal; + }; + + let elevated_count = scores + .iter() + .filter(|score| **score >= ThreatScore::Medium) + .count(); + + if elevated_count >= 3 { + aggregate = aggregate.elevate(); + } + + aggregate + } + + pub fn threshold_exceeded(&self, score: ThreatScore, threshold: ThreatScore) -> bool { + score >= threshold } } @@ -30,3 +130,38 @@ impl Default for Scorer { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_threat_score_calculation() { + let scorer = Scorer::new().unwrap(); + assert_eq!(scorer.score(0.15, None), ThreatScore::Normal); + assert_eq!(scorer.score(0.35, None), ThreatScore::Low); + assert_eq!(scorer.score(0.60, None), ThreatScore::Medium); + assert_eq!(scorer.score(0.80, None), ThreatScore::High); + assert_eq!(scorer.score(0.95, None), ThreatScore::Critical); + } + + #[test] + fn test_score_aggregation() { + let scorer = Scorer::new().unwrap(); + let aggregated = scorer.aggregate(&[ + ThreatScore::Low, + ThreatScore::Medium, + ThreatScore::High, + ThreatScore::Medium, + ]); + + assert_eq!(aggregated, ThreatScore::Critical); + } + + #[test] + fn test_threshold_detection() { + let scorer = Scorer::new().unwrap(); + assert!(scorer.threshold_exceeded(ThreatScore::High, ThreatScore::Medium)); + assert!(!scorer.threshold_exceeded(ThreatScore::Low, ThreatScore::High)); + } +} diff --git a/src/models/api/alerts.rs b/src/models/api/alerts.rs index 0bf7459..ed2b932 100644 --- a/src/models/api/alerts.rs +++ b/src/models/api/alerts.rs @@ -1,5 +1,6 @@ //! Alert API response types +use crate::database::models::Alert; use serde::{Deserialize, Serialize}; /// Alert response @@ -41,3 +42,19 @@ impl Default for AlertStatsResponse { Self::new() } } + +impl From for AlertResponse { + fn from(alert: Alert) -> Self { + Self { + id: alert.id, + alert_type: alert.alert_type.to_string(), + severity: alert.severity.to_string(), + message: alert.message, + status: alert.status.to_string(), + timestamp: alert.timestamp, + metadata: alert + .metadata + .and_then(|metadata| serde_json::to_value(metadata).ok()), + } + } +} diff --git a/src/models/api/containers.rs b/src/models/api/containers.rs index ee75713..df041ef 100644 --- a/src/models/api/containers.rs +++ b/src/models/api/containers.rs @@ -20,16 +20,20 @@ pub struct ContainerResponse { pub struct ContainerSecurityStatus { pub state: String, pub threats: u32, - pub vulnerabilities: u32, - pub last_scan: String, + pub vulnerabilities: Option, + pub last_scan: Option, } /// Network activity #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NetworkActivity { - pub inbound_connections: u32, - pub outbound_connections: u32, - pub blocked_connections: u32, + pub inbound_connections: Option, + pub outbound_connections: Option, + pub blocked_connections: Option, + pub received_bytes: Option, + pub transmitted_bytes: Option, + pub received_packets: Option, + pub transmitted_packets: Option, pub suspicious_activity: bool, } diff --git a/src/models/api/security.rs b/src/models/api/security.rs index 62bb314..6144692 100644 --- a/src/models/api/security.rs +++ b/src/models/api/security.rs @@ -15,12 +15,22 @@ pub struct SecurityStatusResponse { impl SecurityStatusResponse { pub fn new() -> Self { + Self::from_state(100, 0, 0, 0, 0) + } + + pub fn from_state( + overall_score: u32, + active_threats: u32, + quarantined_containers: u32, + alerts_new: u32, + alerts_acknowledged: u32, + ) -> Self { Self { - overall_score: 75, - active_threats: 0, - quarantined_containers: 0, - alerts_new: 0, - alerts_acknowledged: 0, + overall_score, + active_threats, + quarantined_containers, + alerts_new, + alerts_acknowledged, last_updated: chrono::Utc::now().to_rfc3339(), } } @@ -31,3 +41,18 @@ impl Default for SecurityStatusResponse { Self::new() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_security_status_from_state() { + let status = SecurityStatusResponse::from_state(64, 2, 1, 3, 1); + assert_eq!(status.active_threats, 2); + assert_eq!(status.quarantined_containers, 1); + assert_eq!(status.alerts_new, 3); + assert_eq!(status.alerts_acknowledged, 1); + assert_eq!(status.overall_score, 64); + } +} diff --git a/src/response/mod.rs b/src/response/mod.rs index 6760278..7316c9f 100644 --- a/src/response/mod.rs +++ b/src/response/mod.rs @@ -3,7 +3,11 @@ //! Automated threat response actions pub mod actions; +#[cfg(target_os = "linux")] pub mod pipeline; /// Marker struct for module tests pub struct ResponseMarker; + +#[cfg(target_os = "linux")] +pub use pipeline::{ActionPipeline, PipelineAction, PipelinePlan}; diff --git a/src/response/pipeline.rs b/src/response/pipeline.rs index cd01e3a..f983267 100644 --- a/src/response/pipeline.rs +++ b/src/response/pipeline.rs @@ -1,15 +1,150 @@ //! Response action pipeline use anyhow::Result; +use std::collections::HashMap; -/// Action pipeline +use crate::firewall::{ResponseAction, ResponseChain, ResponseExecutor, ResponseType}; + +/// A named response template that can be executed directly or converted to a chain. +#[derive(Debug, Clone)] +pub struct PipelineAction { + name: String, + action: ResponseAction, +} + +impl PipelineAction { + pub fn new(name: impl Into, action: ResponseAction) -> Self { + Self { + name: name.into(), + action, + } + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn action(&self) -> &ResponseAction { + &self.action + } +} + +/// A reusable response plan composed of ordered actions. +#[derive(Debug, Clone)] +pub struct PipelinePlan { + name: String, + actions: Vec, + stop_on_failure: bool, +} + +impl PipelinePlan { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + actions: Vec::new(), + stop_on_failure: true, + } + } + + pub fn add_action(&mut self, action: PipelineAction) { + self.actions.push(action); + } + + pub fn set_stop_on_failure(&mut self, stop_on_failure: bool) { + self.stop_on_failure = stop_on_failure; + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn actions(&self) -> &[PipelineAction] { + &self.actions + } + + pub fn to_chain(&self) -> ResponseChain { + let mut chain = ResponseChain::new(self.name.clone()); + chain.set_stop_on_failure(self.stop_on_failure); + for action in &self.actions { + chain.add_action(action.action.clone()); + } + chain + } +} + +/// Action pipeline for reusable response orchestration. pub struct ActionPipeline { - // TODO: Implement in TASK-011 + executor: ResponseExecutor, + plans: HashMap, } impl ActionPipeline { pub fn new() -> Result { - Ok(Self {}) + Ok(Self { + executor: ResponseExecutor::new()?, + plans: HashMap::new(), + }) + } + + pub fn with_executor(executor: ResponseExecutor) -> Self { + Self { + executor, + plans: HashMap::new(), + } + } + + pub fn register_plan(&mut self, plan: PipelinePlan) { + self.plans.insert(plan.name().to_string(), plan); + } + + pub fn get_plan(&self, name: &str) -> Option<&PipelinePlan> { + self.plans.get(name) + } + + pub fn has_plan(&self, name: &str) -> bool { + self.plans.contains_key(name) + } + + pub fn execute_plan(&mut self, name: &str) -> Result<()> { + let plan = self + .plans + .get(name) + .ok_or_else(|| anyhow::anyhow!("Response plan not found: {}", name))?; + self.executor.execute_chain(&plan.to_chain()) + } + + pub fn execute_action(&mut self, action: &ResponseAction) -> Result<()> { + self.executor.execute(action) + } + + pub fn execution_log(&self) -> Vec { + self.executor.get_log() + } + + pub fn clear_execution_log(&mut self) { + self.executor.clear_log(); + } + + pub fn register_default_security_plans(&mut self) { + let mut quarantine_plan = PipelinePlan::new("quarantine-container"); + quarantine_plan.add_action(PipelineAction::new( + "quarantine", + ResponseAction::new( + ResponseType::QuarantineContainer("{{container_id}}".to_string()), + "Quarantine compromised container".to_string(), + ), + )); + self.register_plan(quarantine_plan); + + let mut block_mail_plan = PipelinePlan::new("block-mail-port"); + block_mail_plan.add_action(PipelineAction::new( + "block-port", + ResponseAction::new( + ResponseType::BlockPort(25), + "Block outbound SMTP traffic".to_string(), + ), + )); + self.register_plan(block_mail_plan); } } @@ -18,3 +153,62 @@ impl Default for ActionPipeline { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pipeline_plan_builds_chain() { + let mut plan = PipelinePlan::new("test-plan"); + plan.set_stop_on_failure(false); + plan.add_action(PipelineAction::new( + "log", + ResponseAction::new(ResponseType::LogAction("ok".to_string()), "Log".to_string()), + )); + + let chain = plan.to_chain(); + assert_eq!(chain.name(), "test-plan"); + assert_eq!(chain.action_count(), 1); + } + + #[test] + fn test_pipeline_registers_and_finds_plan() { + let mut pipeline = ActionPipeline::new().unwrap(); + let plan = PipelinePlan::new("mail-abuse"); + + pipeline.register_plan(plan); + + assert!(pipeline.has_plan("mail-abuse")); + assert!(pipeline.get_plan("mail-abuse").is_some()); + } + + #[test] + fn test_pipeline_execute_unknown_plan_fails() { + let mut pipeline = ActionPipeline::new().unwrap(); + let result = pipeline.execute_plan("missing"); + assert!(result.is_err()); + } + + #[test] + fn test_pipeline_execute_action_records_log() { + let mut pipeline = ActionPipeline::new().unwrap(); + let action = + ResponseAction::new(ResponseType::LogAction("ok".to_string()), "Log".to_string()); + + pipeline.execute_action(&action).unwrap(); + + let log = pipeline.execution_log(); + assert_eq!(log.len(), 1); + assert!(log[0].success()); + } + + #[test] + fn test_pipeline_register_default_security_plans() { + let mut pipeline = ActionPipeline::new().unwrap(); + pipeline.register_default_security_plans(); + + assert!(pipeline.has_plan("quarantine-container")); + assert!(pipeline.has_plan("block-mail-port")); + } +} diff --git a/src/rules/builtin.rs b/src/rules/builtin.rs index c5da7b1..f3d6ebf 100644 --- a/src/rules/builtin.rs +++ b/src/rules/builtin.rs @@ -3,7 +3,7 @@ //! Pre-defined rules for common security scenarios use crate::events::security::SecurityEvent; -use crate::events::syscall::SyscallType; +use crate::events::syscall::{SyscallDetails, SyscallType}; use crate::rules::rule::{Rule, RuleResult}; /// Syscall allowlist rule @@ -159,6 +159,53 @@ impl Rule for NetworkConnectionRule { } } +/// SMTP connection rule +/// Matches outbound connections to common mail submission ports. +pub struct SmtpConnectionRule { + ports: Vec, +} + +impl SmtpConnectionRule { + pub fn new() -> Self { + Self { + ports: vec![25, 465, 587, 2525], + } + } +} + +impl Default for SmtpConnectionRule { + fn default() -> Self { + Self::new() + } +} + +impl Rule for SmtpConnectionRule { + fn evaluate(&self, event: &SecurityEvent) -> RuleResult { + let SecurityEvent::Syscall(syscall_event) = event else { + return RuleResult::NoMatch; + }; + + if syscall_event.syscall_type != SyscallType::Connect { + return RuleResult::NoMatch; + } + + match syscall_event.details.as_ref() { + Some(SyscallDetails::Connect { dst_port, .. }) if self.ports.contains(dst_port) => { + RuleResult::Match + } + _ => RuleResult::NoMatch, + } + } + + fn name(&self) -> &str { + "smtp_connection" + } + + fn priority(&self) -> u32 { + 20 + } +} + /// File access rule /// Matches file-related syscalls pub struct FileAccessRule { @@ -205,7 +252,7 @@ impl Rule for FileAccessRule { #[cfg(test)] mod tests { use super::*; - use crate::events::syscall::SyscallEvent; + use crate::events::syscall::{SyscallDetails, SyscallEvent}; use chrono::Utc; #[test] @@ -231,4 +278,44 @@ mod tests { )); assert!(rule.evaluate(&event).is_match()); } + + #[test] + fn test_smtp_connection_rule_matches_mail_port() { + let rule = SmtpConnectionRule::new(); + let event = SecurityEvent::Syscall( + SyscallEvent::builder() + .pid(1234) + .uid(1000) + .syscall_type(SyscallType::Connect) + .timestamp(Utc::now()) + .details(Some(SyscallDetails::Connect { + dst_addr: Some("198.51.100.25".to_string()), + dst_port: 587, + family: 2, + })) + .build(), + ); + + assert!(rule.evaluate(&event).is_match()); + } + + #[test] + fn test_smtp_connection_rule_ignores_non_mail_port() { + let rule = SmtpConnectionRule::new(); + let event = SecurityEvent::Syscall( + SyscallEvent::builder() + .pid(1234) + .uid(1000) + .syscall_type(SyscallType::Connect) + .timestamp(Utc::now()) + .details(Some(SyscallDetails::Connect { + dst_addr: Some("198.51.100.25".to_string()), + dst_port: 443, + family: 2, + })) + .build(), + ); + + assert!(rule.evaluate(&event).is_no_match()); + } } diff --git a/src/rules/threat_scorer.rs b/src/rules/threat_scorer.rs index 7e7c30d..b1792ef 100644 --- a/src/rules/threat_scorer.rs +++ b/src/rules/threat_scorer.rs @@ -239,6 +239,9 @@ pub fn calculate_severity_from_scores(scores: &[ThreatScore]) -> Severity { #[cfg(test)] mod tests { use super::*; + use crate::events::security::SecurityEvent; + use crate::events::syscall::{SyscallDetails, SyscallEvent, SyscallType}; + use chrono::Utc; #[test] fn test_threat_score_creation() { @@ -277,4 +280,92 @@ mod tests { assert_eq!(config.multiplier(), 1.5); assert!(config.time_decay_enabled()); } + + fn syscall_event(syscall_type: SyscallType) -> SecurityEvent { + SyscallEvent::builder() + .pid(1234) + .uid(1000) + .syscall_type(syscall_type) + .timestamp(Utc::now()) + .build() + .into() + } + + #[test] + fn test_calculate_score_returns_zero_for_non_matching_event() { + let scorer = ThreatScorer::new(); + let event = SecurityEvent::Network(crate::events::security::NetworkEvent { + src_ip: "172.17.0.2".to_string(), + dst_ip: "198.51.100.10".to_string(), + src_port: 12345, + dst_port: 443, + protocol: "tcp".to_string(), + timestamp: Utc::now(), + container_id: Some("abc123".to_string()), + }); + + let score = scorer.calculate_score(&event); + assert_eq!(score.value(), 0); + assert_eq!(score.severity(), Severity::Info); + } + + #[test] + fn test_calculate_score_for_builtin_signature_match() { + let scorer = ThreatScorer::new(); + let event = syscall_event(SyscallType::Ptrace); + + let score = scorer.calculate_score(&event); + assert_eq!(score.value(), 47); + assert_eq!(score.severity(), Severity::Medium); + assert!(!score.is_critical()); + } + + #[test] + fn test_calculate_score_respects_config_multiplier() { + let scorer = ThreatScorer::with_config( + ScoringConfig::default() + .with_base_score(80) + .with_multiplier(1.25), + ); + let event = syscall_event(SyscallType::Connect); + + let score = scorer.calculate_score(&event); + assert_eq!(score.value(), 67); + assert_eq!(score.severity(), Severity::Medium); + } + + #[test] + fn test_calculate_score_for_smtp_connect_event_uses_builtin_connect_signature() { + let scorer = ThreatScorer::new(); + let event = SecurityEvent::Syscall( + SyscallEvent::builder() + .pid(1234) + .uid(1000) + .syscall_type(SyscallType::Connect) + .timestamp(Utc::now()) + .details(Some(SyscallDetails::Connect { + dst_addr: Some("198.51.100.25".to_string()), + dst_port: 587, + family: 2, + })) + .build(), + ); + + let score = scorer.calculate_score(&event); + assert_eq!(score.value(), 33); + assert_eq!(score.severity(), Severity::Low); + } + + #[test] + fn test_calculate_cumulative_score_applies_average_and_bonus() { + let scorer = ThreatScorer::new(); + let events = vec![ + syscall_event(SyscallType::Ptrace), + syscall_event(SyscallType::Connect), + ]; + + let score = scorer.calculate_cumulative_score(&events); + assert_eq!(score.value(), 40); + assert_eq!(score.severity(), Severity::Medium); + } } diff --git a/src/sniff/config.rs b/src/sniff/config.rs index 6ddef85..cee76bf 100644 --- a/src/sniff/config.rs +++ b/src/sniff/config.rs @@ -52,6 +52,16 @@ pub struct SniffConfig { pub slack_webhook: Option, /// Generic webhook URL for alert notifications pub webhook_url: Option, + /// SMTP host for email notifications + pub smtp_host: Option, + /// SMTP port for email notifications + pub smtp_port: Option, + /// SMTP username / sender address for email notifications + pub smtp_user: Option, + /// SMTP password for email notifications + pub smtp_password: Option, + /// Email recipients for alert notifications + pub email_recipients: Vec, } /// Arguments for building a SniffConfig @@ -65,6 +75,12 @@ pub struct SniffArgs<'a> { pub ai_model: Option<&'a str>, pub ai_api_url: Option<&'a str>, pub slack_webhook: Option<&'a str>, + pub webhook_url: Option<&'a str>, + pub smtp_host: Option<&'a str>, + pub smtp_port: Option, + pub smtp_user: Option<&'a str>, + pub smtp_password: Option<&'a str>, + pub email_recipients: Option<&'a str>, } impl SniffConfig { @@ -130,7 +146,39 @@ impl SniffConfig { .slack_webhook .map(|s| s.to_string()) .or_else(|| env::var("STACKDOG_SLACK_WEBHOOK_URL").ok()), - webhook_url: env::var("STACKDOG_WEBHOOK_URL").ok(), + webhook_url: args + .webhook_url + .map(|s| s.to_string()) + .or_else(|| env::var("STACKDOG_WEBHOOK_URL").ok()), + smtp_host: args + .smtp_host + .map(|s| s.to_string()) + .or_else(|| env::var("STACKDOG_SMTP_HOST").ok()), + smtp_port: args.smtp_port.or_else(|| { + env::var("STACKDOG_SMTP_PORT") + .ok() + .and_then(|v| v.parse().ok()) + }), + smtp_user: args + .smtp_user + .map(|s| s.to_string()) + .or_else(|| env::var("STACKDOG_SMTP_USER").ok()), + smtp_password: args + .smtp_password + .map(|s| s.to_string()) + .or_else(|| env::var("STACKDOG_SMTP_PASSWORD").ok()), + email_recipients: args + .email_recipients + .map(|s| s.to_string()) + .or_else(|| env::var("STACKDOG_EMAIL_RECIPIENTS").ok()) + .map(|recipients| { + recipients + .split(',') + .map(|recipient| recipient.trim().to_string()) + .filter(|recipient| !recipient.is_empty()) + .collect() + }) + .unwrap_or_default(), } } } @@ -153,6 +201,11 @@ mod tests { env::remove_var("STACKDOG_SNIFF_INTERVAL"); env::remove_var("STACKDOG_SLACK_WEBHOOK_URL"); env::remove_var("STACKDOG_WEBHOOK_URL"); + env::remove_var("STACKDOG_SMTP_HOST"); + env::remove_var("STACKDOG_SMTP_PORT"); + env::remove_var("STACKDOG_SMTP_USER"); + env::remove_var("STACKDOG_SMTP_PASSWORD"); + env::remove_var("STACKDOG_EMAIL_RECIPIENTS"); } #[test] @@ -179,6 +232,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert!(!config.once); assert!(!config.consume); @@ -206,6 +265,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert!(config.once); @@ -232,6 +297,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert!(config @@ -268,6 +339,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert_eq!(config.ai_api_url, "https://api.openai.com/v1"); assert_eq!(config.ai_api_key, Some("sk-test123".into())); @@ -293,6 +370,12 @@ mod tests { ai_model: Some("qwen2.5-coder:latest"), ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); // "ollama" maps to OpenAi internally (same API protocol) assert_eq!(config.ai_provider, AiProvider::OpenAi); @@ -319,6 +402,12 @@ mod tests { ai_model: Some("llama3"), ai_api_url: Some("http://localhost:11434/v1"), slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); // CLI args take priority over env vars assert_eq!(config.ai_model, "llama3"); @@ -342,6 +431,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: Some("https://hooks.slack.com/services/T/B/xxx"), + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert_eq!( config.slack_webhook.as_deref(), @@ -370,6 +465,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert_eq!( config.slack_webhook.as_deref(), @@ -398,6 +499,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: Some("https://hooks.slack.com/services/T/B/cli"), + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); assert_eq!( config.slack_webhook.as_deref(), @@ -406,4 +513,58 @@ mod tests { clear_sniff_env(); } + + #[test] + fn test_notification_channels_from_env() { + let _lock = ENV_MUTEX.lock().unwrap(); + clear_sniff_env(); + env::set_var( + "STACKDOG_WEBHOOK_URL", + "https://example.test/hooks/stackdog", + ); + env::set_var("STACKDOG_SMTP_HOST", "smtp.example.com"); + env::set_var("STACKDOG_SMTP_PORT", "2525"); + env::set_var("STACKDOG_SMTP_USER", "alerts@example.com"); + env::set_var("STACKDOG_SMTP_PASSWORD", "secret"); + env::set_var( + "STACKDOG_EMAIL_RECIPIENTS", + "soc@example.com, oncall@example.com", + ); + + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, + }); + + assert_eq!( + config.webhook_url.as_deref(), + Some("https://example.test/hooks/stackdog") + ); + assert_eq!(config.smtp_host.as_deref(), Some("smtp.example.com")); + assert_eq!(config.smtp_port, Some(2525)); + assert_eq!(config.smtp_user.as_deref(), Some("alerts@example.com")); + assert_eq!(config.smtp_password.as_deref(), Some("secret")); + assert_eq!( + config.email_recipients, + vec![ + "soc@example.com".to_string(), + "oncall@example.com".to_string() + ] + ); + + clear_sniff_env(); + } } diff --git a/src/sniff/mod.rs b/src/sniff/mod.rs index e9c4299..79fe8a5 100644 --- a/src/sniff/mod.rs +++ b/src/sniff/mod.rs @@ -40,6 +40,22 @@ impl SniffOrchestrator { if let Some(ref url) = config.webhook_url { notification_config = notification_config.with_webhook_url(url.clone()); } + if let Some(ref host) = config.smtp_host { + notification_config = notification_config.with_smtp_host(host.clone()); + } + if let Some(port) = config.smtp_port { + notification_config = notification_config.with_smtp_port(port); + } + if let Some(ref user) = config.smtp_user { + notification_config = notification_config.with_smtp_user(user.clone()); + } + if let Some(ref password) = config.smtp_password { + notification_config = notification_config.with_smtp_password(password.clone()); + } + if !config.email_recipients.is_empty() { + notification_config = + notification_config.with_email_recipients(config.email_recipients.clone()); + } let reporter = Reporter::new(notification_config); Ok(Self { @@ -155,7 +171,7 @@ impl SniffOrchestrator { // 5. Report log::debug!("Step 5: reporting results..."); - let report = self.reporter.report(&summary, Some(&self.pool))?; + let report = self.reporter.report(&summary, Some(&self.pool)).await?; result.anomalies_found += report.anomalies_reported; // 6. Consume (if enabled) @@ -260,6 +276,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); config.database_url = ":memory:".into(); @@ -289,6 +311,12 @@ mod tests { ai_model: None, ai_api_url: None, slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, }); config.database_url = ":memory:".into(); diff --git a/src/sniff/reporter.rs b/src/sniff/reporter.rs index ab8c73a..192aabf 100644 --- a/src/sniff/reporter.rs +++ b/src/sniff/reporter.rs @@ -4,7 +4,7 @@ //! them via the existing notification channels. use crate::alerting::alert::{Alert, AlertSeverity, AlertType}; -use crate::alerting::notifications::{route_by_severity, NotificationConfig}; +use crate::alerting::notifications::{NotificationConfig, NotificationResult}; use crate::database::connection::DbPool; use crate::database::repositories::log_sources; use crate::sniff::analyzer::{AnomalySeverity, LogSummary}; @@ -33,7 +33,11 @@ impl Reporter { } /// Report a log summary: persist to DB and send anomaly alerts - pub fn report(&self, summary: &LogSummary, pool: Option<&DbPool>) -> Result { + pub async fn report( + &self, + summary: &LogSummary, + pool: Option<&DbPool>, + ) -> Result { let mut alerts_sent = 0; // Persist summary to database @@ -76,11 +80,16 @@ impl Reporter { ); // Route to appropriate notification channels - let channels = route_by_severity(alert_severity); + let channels = self + .notification_config + .configured_channels_for_severity(alert_severity); log::debug!("Routing alert to {} notification channels", channels.len()); for channel in &channels { - match channel.send(&alert, &self.notification_config) { - Ok(_) => alerts_sent += 1, + match channel.send(&alert, &self.notification_config).await { + Ok(NotificationResult::Success(_)) => alerts_sent += 1, + Ok(NotificationResult::Failure(message)) => { + log::warn!("Notification channel reported failure: {}", message) + } Err(e) => log::warn!("Failed to send notification: {}", e), } } @@ -153,18 +162,18 @@ mod tests { ); } - #[test] - fn test_report_no_anomalies() { + #[tokio::test] + async fn test_report_no_anomalies() { let reporter = Reporter::new(NotificationConfig::default()); let summary = make_summary(vec![]); - let result = reporter.report(&summary, None).unwrap(); + let result = reporter.report(&summary, None).await.unwrap(); assert_eq!(result.anomalies_reported, 0); assert_eq!(result.notifications_sent, 0); assert!(!result.summary_persisted); } - #[test] - fn test_report_with_anomalies_sends_alerts() { + #[tokio::test] + async fn test_report_with_anomalies_sends_alerts() { let reporter = Reporter::new(NotificationConfig::default()); let summary = make_summary(vec![LogAnomaly { description: "High error rate".into(), @@ -172,21 +181,20 @@ mod tests { sample_line: "ERROR: connection failed".into(), }]); - let result = reporter.report(&summary, None).unwrap(); + let result = reporter.report(&summary, None).await.unwrap(); assert_eq!(result.anomalies_reported, 1); - // Console channel is always available, so at least 1 notification sent - assert!(result.notifications_sent >= 1); + assert_eq!(result.notifications_sent, 1); } - #[test] - fn test_report_persists_to_database() { + #[tokio::test] + async fn test_report_persists_to_database() { let pool = create_pool(":memory:").unwrap(); init_database(&pool).unwrap(); let reporter = Reporter::new(NotificationConfig::default()); let summary = make_summary(vec![]); - let result = reporter.report(&summary, Some(&pool)).unwrap(); + let result = reporter.report(&summary, Some(&pool)).await.unwrap(); assert!(result.summary_persisted); // Verify summary was stored @@ -195,8 +203,8 @@ mod tests { assert_eq!(summaries[0].total_entries, 100); } - #[test] - fn test_report_multiple_anomalies() { + #[tokio::test] + async fn test_report_multiple_anomalies() { let reporter = Reporter::new(NotificationConfig::default()); let summary = make_summary(vec![ LogAnomaly { @@ -211,18 +219,34 @@ mod tests { }, ]); - let result = reporter.report(&summary, None).unwrap(); + let result = reporter.report(&summary, None).await.unwrap(); assert_eq!(result.anomalies_reported, 2); - assert!(result.notifications_sent >= 2); + assert_eq!(result.notifications_sent, 2); } - #[test] - fn test_reporter_new() { + #[tokio::test] + async fn test_reporter_new() { let config = NotificationConfig::default(); let reporter = Reporter::new(config); // Just ensure it constructs without error let summary = make_summary(vec![]); - let result = reporter.report(&summary, None); + let result = reporter.report(&summary, None).await; assert!(result.is_ok()); } + + #[tokio::test] + async fn test_report_does_not_count_delivery_failures_as_sent() { + let reporter = Reporter::new( + NotificationConfig::default().with_slack_webhook("http://127.0.0.1:1".into()), + ); + let summary = make_summary(vec![LogAnomaly { + description: "High error rate".into(), + severity: AnomalySeverity::High, + sample_line: "ERROR: connection failed".into(), + }]); + + let result = reporter.report(&summary, None).await.unwrap(); + assert_eq!(result.anomalies_reported, 1); + assert_eq!(result.notifications_sent, 1); + } } diff --git a/web/src/components/ContainerList.tsx b/web/src/components/ContainerList.tsx index 04e5874..39defae 100644 --- a/web/src/components/ContainerList.tsx +++ b/web/src/components/ContainerList.tsx @@ -78,6 +78,18 @@ const ContainerList: React.FC = () => { return '#e74c3c'; }; + const formatCount = (value: number | null) => (value === null ? 'n/a' : value.toLocaleString()); + + const formatBytes = (value: number | null) => { + if (value === null) return 'n/a'; + if (value < 1024) return `${value} B`; + if (value < 1024 * 1024) return `${(value / 1024).toFixed(1)} KB`; + if (value < 1024 * 1024 * 1024) return `${(value / (1024 * 1024)).toFixed(1)} MB`; + return `${(value / (1024 * 1024 * 1024)).toFixed(1)} GB`; + }; + + const formatDateTime = (value: string | null) => (value ? new Date(value).toLocaleString() : 'Unavailable'); + return ( @@ -127,9 +139,9 @@ const ContainerList: React.FC = () => {

- 📥 {container.networkActivity.inboundConnections} | - 📤 {container.networkActivity.outboundConnections} | - 🚫 {container.networkActivity.blockedConnections} + ⬇ {formatCount(container.networkActivity.receivedPackets)} pkts | + ⬆ {formatCount(container.networkActivity.transmittedPackets)} pkts | + 🚫 {formatCount(container.networkActivity.blockedConnections)} {container.networkActivity.suspiciousActivity && ( Suspicious @@ -190,8 +202,13 @@ const ContainerList: React.FC = () => {

Security: {selectedContainer.securityStatus.state}

Risk Score: {selectedContainer.riskScore}

Threats: {selectedContainer.securityStatus.threats}

-

Vulnerabilities: {selectedContainer.securityStatus.vulnerabilities}

-

Last Scan: {new Date(selectedContainer.securityStatus.lastScan).toLocaleString()}

+

Vulnerabilities: {selectedContainer.securityStatus.vulnerabilities ?? 'Unavailable'}

+

Last Scan: {formatDateTime(selectedContainer.securityStatus.lastScan)}

+

RX Traffic: {formatBytes(selectedContainer.networkActivity.receivedBytes)}

+

TX Traffic: {formatBytes(selectedContainer.networkActivity.transmittedBytes)}

+

RX Packets: {formatCount(selectedContainer.networkActivity.receivedPackets)}

+

TX Packets: {formatCount(selectedContainer.networkActivity.transmittedPackets)}

+

Blocked Connections: {formatCount(selectedContainer.networkActivity.blockedConnections)}

)} diff --git a/web/src/components/__tests__/ContainerList.test.tsx b/web/src/components/__tests__/ContainerList.test.tsx index b26ccad..802fdf2 100644 --- a/web/src/components/__tests__/ContainerList.test.tsx +++ b/web/src/components/__tests__/ContainerList.test.tsx @@ -15,14 +15,18 @@ const mockContainers = [ securityStatus: { state: 'Secure' as const, threats: 0, - vulnerabilities: 0, - lastScan: new Date().toISOString(), + vulnerabilities: null, + lastScan: null, }, riskScore: 10, networkActivity: { - inboundConnections: 5, - outboundConnections: 3, - blockedConnections: 0, + inboundConnections: null, + outboundConnections: null, + blockedConnections: null, + receivedBytes: 1024, + transmittedBytes: 2048, + receivedPackets: 5, + transmittedPackets: 3, suspiciousActivity: false, }, createdAt: new Date().toISOString(), @@ -35,14 +39,18 @@ const mockContainers = [ securityStatus: { state: 'AtRisk' as const, threats: 2, - vulnerabilities: 1, - lastScan: new Date().toISOString(), + vulnerabilities: null, + lastScan: null, }, riskScore: 65, networkActivity: { - inboundConnections: 10, - outboundConnections: 5, - blockedConnections: 2, + inboundConnections: null, + outboundConnections: null, + blockedConnections: null, + receivedBytes: 4096, + transmittedBytes: 8192, + receivedPackets: 10, + transmittedPackets: 5, suspiciousActivity: true, }, createdAt: new Date().toISOString(), @@ -158,8 +166,8 @@ describe('ContainerList Component', () => { }); // Should show network activity details - expect(screen.getByText('10')).toBeInTheDocument(); // Inbound - expect(screen.getByText('5')).toBeInTheDocument(); // Outbound - expect(screen.getByText('2')).toBeInTheDocument(); // Blocked + expect(screen.getByText(/10 pkts/)).toBeInTheDocument(); + expect(screen.getAllByText(/5 pkts/).length).toBeGreaterThan(0); + expect(screen.getAllByText(/n\/a/).length).toBeGreaterThan(0); }); }); diff --git a/web/src/services/api.ts b/web/src/services/api.ts index 53c40b0..cb5d64c 100644 --- a/web/src/services/api.ts +++ b/web/src/services/api.ts @@ -27,6 +27,17 @@ class ApiService { }); } + private firstNumber(...values: unknown[]): number | null { + return (values.find((value) => typeof value === 'number') as number | undefined) ?? null; + } + + private firstString(...values: unknown[]): string | null { + return ( + (values.find((value) => typeof value === 'string' && value.length > 0) as string | undefined) ?? + null + ); + } + // Security Status async getSecurityStatus(): Promise { const response = await this.api.get('/security/status'); @@ -77,25 +88,50 @@ class ApiService { const securityStatus = item.securityStatus ?? item.security_status ?? {}; const networkActivity = item.networkActivity ?? item.network_activity ?? {}; - return { - id: item.id ?? '', - name: item.name ?? item.id ?? 'unknown', - image: item.image ?? 'unknown', - status: item.status ?? 'Running', - securityStatus: { - state: securityStatus.state ?? 'Secure', - threats: securityStatus.threats ?? 0, - vulnerabilities: securityStatus.vulnerabilities ?? 0, - lastScan: securityStatus.lastScan ?? new Date().toISOString(), - }, - riskScore: item.riskScore ?? item.risk_score ?? 0, - networkActivity: { - inboundConnections: networkActivity.inboundConnections ?? networkActivity.inbound_connections ?? 0, - outboundConnections: networkActivity.outboundConnections ?? networkActivity.outbound_connections ?? 0, - blockedConnections: networkActivity.blockedConnections ?? networkActivity.blocked_connections ?? 0, - suspiciousActivity: networkActivity.suspiciousActivity ?? networkActivity.suspicious_activity ?? false, - }, - createdAt: item.createdAt ?? item.created_at ?? new Date().toISOString(), + return { + id: item.id ?? '', + name: item.name ?? item.id ?? 'unknown', + image: item.image ?? 'unknown', + status: item.status ?? 'Running', + securityStatus: { + state: securityStatus.state ?? 'Secure', + threats: securityStatus.threats ?? 0, + vulnerabilities: this.firstNumber(securityStatus.vulnerabilities), + lastScan: this.firstString(securityStatus.lastScan, securityStatus.last_scan), + }, + riskScore: item.riskScore ?? item.risk_score ?? 0, + networkActivity: { + inboundConnections: this.firstNumber( + networkActivity.inboundConnections, + networkActivity.inbound_connections, + ), + outboundConnections: this.firstNumber( + networkActivity.outboundConnections, + networkActivity.outbound_connections, + ), + blockedConnections: this.firstNumber( + networkActivity.blockedConnections, + networkActivity.blocked_connections, + ), + receivedBytes: this.firstNumber( + networkActivity.receivedBytes, + networkActivity.received_bytes, + ), + transmittedBytes: this.firstNumber( + networkActivity.transmittedBytes, + networkActivity.transmitted_bytes, + ), + receivedPackets: this.firstNumber( + networkActivity.receivedPackets, + networkActivity.received_packets, + ), + transmittedPackets: this.firstNumber( + networkActivity.transmittedPackets, + networkActivity.transmitted_packets, + ), + suspiciousActivity: networkActivity.suspiciousActivity ?? networkActivity.suspicious_activity ?? false, + }, + createdAt: item.createdAt ?? item.created_at ?? new Date().toISOString(), } as Container; }); } diff --git a/web/src/types/containers.ts b/web/src/types/containers.ts index 03787d8..4044216 100644 --- a/web/src/types/containers.ts +++ b/web/src/types/containers.ts @@ -16,14 +16,18 @@ export type ContainerStatus = 'Running' | 'Stopped' | 'Paused' | 'Quarantined'; export interface SecurityStatus { state: 'Secure' | 'AtRisk' | 'Compromised' | 'Quarantined'; threats: number; - vulnerabilities: number; - lastScan: string; + vulnerabilities: number | null; + lastScan: string | null; } export interface NetworkActivity { - inboundConnections: number; - outboundConnections: number; - blockedConnections: number; + inboundConnections: number | null; + outboundConnections: number | null; + blockedConnections: number | null; + receivedBytes: number | null; + transmittedBytes: number | null; + receivedPackets: number | null; + transmittedPackets: number | null; suspiciousActivity: boolean; } From c845579917d564047bacd687a4ee44db3b31ce02 Mon Sep 17 00:00:00 2001 From: vsilent Date: Sat, 4 Apr 2026 11:42:33 +0300 Subject: [PATCH 04/10] tests, ip_ban engine implemented, frontend dashboard improvements --- .../down.sql | 4 + .../00000000000003_create_ip_offenses/up.sql | 18 ++ src/api/alerts.rs | 6 +- src/correlator/engine.rs | 117 ++++++- src/database/connection.rs | 31 ++ src/database/events.rs | 116 ++++++- src/database/mod.rs | 3 + src/database/repositories/alerts.rs | 7 + src/database/repositories/mod.rs | 2 + src/database/repositories/offenses.rs | 291 +++++++++++++++++ src/firewall/response.rs | 11 +- src/ip_ban/config.rs | 50 +++ src/ip_ban/engine.rs | 297 ++++++++++++++++++ src/ip_ban/mod.rs | 5 + src/lib.rs | 1 + src/main.rs | 19 ++ src/sniff/mod.rs | 177 +++++++++++ tests/api/alerts_api_test.rs | 212 ++++++++++++- tests/api/containers_api_test.rs | 66 +++- tests/api/mod.rs | 2 +- tests/api/security_api_test.rs | 77 ++++- tests/api/threats_api_test.rs | 105 ++++++- tests/firewall/response_test.rs | 12 + tests/integration.rs | 1 + web/package.json | 64 ++-- .../components/__tests__/AlertPanel.test.tsx | 55 ++-- .../__tests__/ContainerList.test.tsx | 21 +- .../components/__tests__/Dashboard.test.tsx | 99 ++++++ .../__tests__/SecurityScore.test.tsx | 28 ++ .../components/__tests__/ThreatMap.test.tsx | 41 +-- web/src/services/__tests__/security.test.ts | 120 ++++++- web/src/services/__tests__/websocket.test.ts | 170 +++++----- web/src/services/api.ts | 50 ++- web/src/setupTests.ts | 8 + 34 files changed, 2029 insertions(+), 257 deletions(-) create mode 100644 migrations/00000000000003_create_ip_offenses/down.sql create mode 100644 migrations/00000000000003_create_ip_offenses/up.sql create mode 100644 src/database/repositories/offenses.rs create mode 100644 src/ip_ban/config.rs create mode 100644 src/ip_ban/engine.rs create mode 100644 src/ip_ban/mod.rs create mode 100644 web/src/components/__tests__/Dashboard.test.tsx create mode 100644 web/src/components/__tests__/SecurityScore.test.tsx diff --git a/migrations/00000000000003_create_ip_offenses/down.sql b/migrations/00000000000003_create_ip_offenses/down.sql new file mode 100644 index 0000000..f1bb943 --- /dev/null +++ b/migrations/00000000000003_create_ip_offenses/down.sql @@ -0,0 +1,4 @@ +DROP INDEX IF EXISTS idx_ip_offenses_last_seen; +DROP INDEX IF EXISTS idx_ip_offenses_status; +DROP INDEX IF EXISTS idx_ip_offenses_ip; +DROP TABLE IF EXISTS ip_offenses; diff --git a/migrations/00000000000003_create_ip_offenses/up.sql b/migrations/00000000000003_create_ip_offenses/up.sql new file mode 100644 index 0000000..a800425 --- /dev/null +++ b/migrations/00000000000003_create_ip_offenses/up.sql @@ -0,0 +1,18 @@ +CREATE TABLE IF NOT EXISTS ip_offenses ( + id TEXT PRIMARY KEY, + ip_address TEXT NOT NULL, + source_type TEXT NOT NULL, + container_id TEXT, + offense_count INTEGER NOT NULL DEFAULT 1, + first_seen TEXT NOT NULL, + last_seen TEXT NOT NULL, + blocked_until TEXT, + status TEXT NOT NULL DEFAULT 'Active', + reason TEXT NOT NULL, + metadata TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_ip_offenses_ip ON ip_offenses(ip_address); +CREATE INDEX IF NOT EXISTS idx_ip_offenses_status ON ip_offenses(status); +CREATE INDEX IF NOT EXISTS idx_ip_offenses_last_seen ON ip_offenses(last_seen); diff --git a/src/api/alerts.rs b/src/api/alerts.rs index 6557208..22e3e1a 100644 --- a/src/api/alerts.rs +++ b/src/api/alerts.rs @@ -50,7 +50,8 @@ pub async fn get_alert_stats(pool: web::Data) -> impl Responder { "total_count": stats.total_count, "new_count": stats.new_count, "acknowledged_count": stats.acknowledged_count, - "resolved_count": stats.resolved_count + "resolved_count": stats.resolved_count, + "false_positive_count": stats.false_positive_count })), Err(e) => { log::error!("Failed to get alert stats: {}", e); @@ -59,7 +60,8 @@ pub async fn get_alert_stats(pool: web::Data) -> impl Responder { "total_count": 0, "new_count": 0, "acknowledged_count": 0, - "resolved_count": 0 + "resolved_count": 0, + "false_positive_count": 0 })) } } diff --git a/src/correlator/engine.rs b/src/correlator/engine.rs index f0fbb66..a95fae5 100644 --- a/src/correlator/engine.rs +++ b/src/correlator/engine.rs @@ -1,15 +1,68 @@ //! Event correlation engine +use crate::events::security::SecurityEvent; use anyhow::Result; +use chrono::Duration; +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct CorrelatedEventGroup { + pub correlation_key: String, + pub events: Vec, +} /// Event correlation engine pub struct CorrelationEngine { - // TODO: Implement in TASK-017 + window: Duration, } impl CorrelationEngine { pub fn new() -> Result { - Ok(Self {}) + Ok(Self { + window: Duration::minutes(5), + }) + } + + pub fn correlate(&self, events: &[SecurityEvent]) -> Vec { + let mut grouped: HashMap> = HashMap::new(); + + for event in events { + if let Some(key) = self.correlation_key(event) { + grouped.entry(key).or_default().push(event.clone()); + } + } + + grouped + .into_iter() + .filter_map(|(correlation_key, mut grouped_events)| { + grouped_events.sort_by_key(SecurityEvent::timestamp); + let first = grouped_events.first()?.timestamp(); + let last = grouped_events.last()?.timestamp(); + if grouped_events.len() >= 2 && (last - first) <= self.window { + Some(CorrelatedEventGroup { + correlation_key, + events: grouped_events, + }) + } else { + None + } + }) + .collect() + } + + fn correlation_key(&self, event: &SecurityEvent) -> Option { + match event { + SecurityEvent::Syscall(event) => Some(format!("pid:{}", event.pid)), + SecurityEvent::Container(event) => Some(format!("container:{}", event.container_id)), + SecurityEvent::Network(event) => event + .container_id + .as_ref() + .map(|container_id| format!("container:{container_id}")), + SecurityEvent::Alert(event) => event + .source_event_id + .as_ref() + .map(|source_event_id| format!("source:{source_event_id}")), + } } } @@ -18,3 +71,63 @@ impl Default for CorrelationEngine { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::events::security::{ContainerEvent, ContainerEventType, SecurityEvent}; + use crate::events::syscall::{SyscallEvent, SyscallType}; + use chrono::{Duration, Utc}; + + #[test] + fn test_correlates_syscall_events_by_pid_within_window() { + let engine = CorrelationEngine::new().unwrap(); + let now = Utc::now(); + let events = vec![ + SecurityEvent::Syscall(SyscallEvent::new(4242, 1000, SyscallType::Execve, now)), + SecurityEvent::Syscall(SyscallEvent::new( + 4242, + 1000, + SyscallType::Open, + now + Duration::seconds(10), + )), + SecurityEvent::Syscall(SyscallEvent::new(7, 1000, SyscallType::Execve, now)), + ]; + + let groups = engine.correlate(&events); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].correlation_key, "pid:4242"); + assert_eq!(groups[0].events.len(), 2); + } + + #[test] + fn test_correlates_container_events_by_container_id() { + let engine = CorrelationEngine::new().unwrap(); + let now = Utc::now(); + let events = vec![ + SecurityEvent::Container(ContainerEvent { + container_id: "container-1".into(), + event_type: ContainerEventType::Start, + timestamp: now, + details: None, + }), + SecurityEvent::Container(ContainerEvent { + container_id: "container-1".into(), + event_type: ContainerEventType::Stop, + timestamp: now + Duration::seconds(30), + details: Some("manual stop".into()), + }), + SecurityEvent::Container(ContainerEvent { + container_id: "container-2".into(), + event_type: ContainerEventType::Start, + timestamp: now, + details: None, + }), + ]; + + let groups = engine.correlate(&events); + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].correlation_key, "container:container-1"); + assert_eq!(groups[0].events.len(), 2); + } +} diff --git a/src/database/connection.rs b/src/database/connection.rs index 2513cae..e684cfa 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -188,6 +188,37 @@ pub fn init_database(pool: &DbPool) -> Result<()> { [], ); + conn.execute( + "CREATE TABLE IF NOT EXISTS ip_offenses ( + id TEXT PRIMARY KEY, + ip_address TEXT NOT NULL, + source_type TEXT NOT NULL, + container_id TEXT, + offense_count INTEGER NOT NULL DEFAULT 1, + first_seen TEXT NOT NULL, + last_seen TEXT NOT NULL, + blocked_until TEXT, + status TEXT NOT NULL DEFAULT 'Active', + reason TEXT NOT NULL, + metadata TEXT, + created_at TEXT DEFAULT CURRENT_TIMESTAMP + )", + [], + )?; + + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_ip_offenses_ip ON ip_offenses(ip_address)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_ip_offenses_status ON ip_offenses(status)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_ip_offenses_last_seen ON ip_offenses(last_seen)", + [], + ); + Ok(()) } diff --git a/src/database/events.rs b/src/database/events.rs index f116833..91ed614 100644 --- a/src/database/events.rs +++ b/src/database/events.rs @@ -1,15 +1,55 @@ //! Security events database operations +use crate::events::security::SecurityEvent; use anyhow::Result; +use chrono::{DateTime, Utc}; +use std::sync::{Arc, RwLock}; /// Events database manager pub struct EventsDb { - // TODO: Implement + events: Arc>>, } impl EventsDb { pub fn new() -> Result { - Ok(Self {}) + Ok(Self { + events: Arc::new(RwLock::new(Vec::new())), + }) + } + + pub fn insert(&self, event: SecurityEvent) -> Result<()> { + self.events.write().unwrap().push(event); + Ok(()) + } + + pub fn list(&self) -> Result> { + Ok(self.events.read().unwrap().clone()) + } + + pub fn events_since(&self, since: DateTime) -> Result> { + Ok(self + .events + .read() + .unwrap() + .iter() + .filter(|event| event.timestamp() >= since) + .cloned() + .collect()) + } + + pub fn events_for_pid(&self, pid: u32) -> Result> { + Ok(self + .events + .read() + .unwrap() + .iter() + .filter(|event| event.pid() == Some(pid)) + .cloned() + .collect()) + } + + pub fn len(&self) -> usize { + self.events.read().unwrap().len() } } @@ -18,3 +58,75 @@ impl Default for EventsDb { Self::new().unwrap() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::events::security::{ + AlertEvent, AlertSeverity, AlertType, ContainerEvent, ContainerEventType, + }; + use crate::events::syscall::{SyscallEvent, SyscallType}; + use chrono::{Duration, Utc}; + + #[test] + fn test_events_db_stores_and_queries_events_since_timestamp() { + let db = EventsDb::new().unwrap(); + let old_time = Utc::now() - Duration::minutes(10); + let recent_time = Utc::now(); + + db.insert(SecurityEvent::Alert(AlertEvent { + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::High, + message: "old event".into(), + timestamp: old_time, + source_event_id: None, + })) + .unwrap(); + db.insert(SecurityEvent::Alert(AlertEvent { + alert_type: AlertType::AnomalyDetected, + severity: AlertSeverity::Critical, + message: "recent event".into(), + timestamp: recent_time, + source_event_id: None, + })) + .unwrap(); + + let recent = db.events_since(Utc::now() - Duration::minutes(1)).unwrap(); + assert_eq!(recent.len(), 1); + match &recent[0] { + SecurityEvent::Alert(event) => assert_eq!(event.message, "recent event"), + other => panic!("unexpected event: {other:?}"), + } + } + + #[test] + fn test_events_db_filters_events_by_pid() { + let db = EventsDb::new().unwrap(); + db.insert(SecurityEvent::Syscall(SyscallEvent::new( + 42, + 1000, + SyscallType::Execve, + Utc::now(), + ))) + .unwrap(); + db.insert(SecurityEvent::Container(ContainerEvent { + container_id: "container-1".into(), + event_type: ContainerEventType::Start, + timestamp: Utc::now(), + details: None, + })) + .unwrap(); + db.insert(SecurityEvent::Syscall(SyscallEvent::new( + 7, + 1000, + SyscallType::Open, + Utc::now(), + ))) + .unwrap(); + + let pid_events = db.events_for_pid(42).unwrap(); + assert_eq!(pid_events.len(), 1); + assert_eq!(pid_events[0].pid(), Some(42)); + assert_eq!(db.len(), 3); + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index e55ccab..f2a871f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -2,13 +2,16 @@ pub mod baselines; pub mod connection; +pub mod events; pub mod models; pub mod repositories; pub use baselines::*; pub use connection::{create_pool, init_database, DbPool}; +pub use events::*; pub use models::*; pub use repositories::alerts::*; +pub use repositories::offenses::*; /// Marker struct for module tests pub struct DatabaseMarker; diff --git a/src/database/repositories/alerts.rs b/src/database/repositories/alerts.rs index fa9d98c..a541340 100644 --- a/src/database/repositories/alerts.rs +++ b/src/database/repositories/alerts.rs @@ -21,6 +21,7 @@ pub struct AlertStats { pub new_count: i64, pub acknowledged_count: i64, pub resolved_count: i64, + pub false_positive_count: i64, } /// Severity breakdown for open security alerts. @@ -263,12 +264,18 @@ pub async fn get_alert_stats(pool: &DbPool) -> Result { [], |row| row.get(0), )?; + let false_positive: i64 = conn.query_row( + "SELECT COUNT(*) FROM alerts WHERE status = 'FalsePositive'", + [], + |row| row.get(0), + )?; Ok(AlertStats { total_count: total, new_count: new, acknowledged_count: ack, resolved_count: resolved, + false_positive_count: false_positive, }) } diff --git a/src/database/repositories/mod.rs b/src/database/repositories/mod.rs index 8f790f5..cf98a45 100644 --- a/src/database/repositories/mod.rs +++ b/src/database/repositories/mod.rs @@ -2,5 +2,7 @@ pub mod alerts; pub mod log_sources; +pub mod offenses; pub use alerts::*; +pub use offenses::*; diff --git a/src/database/repositories/offenses.rs b/src/database/repositories/offenses.rs new file mode 100644 index 0000000..143470d --- /dev/null +++ b/src/database/repositories/offenses.rs @@ -0,0 +1,291 @@ +//! Persistent IP ban offense tracking. + +use crate::database::connection::DbPool; +use anyhow::Result; +use chrono::{DateTime, Utc}; +use rusqlite::params; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum OffenseStatus { + Active, + Blocked, + Released, +} + +impl std::fmt::Display for OffenseStatus { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Active => write!(f, "Active"), + Self::Blocked => write!(f, "Blocked"), + Self::Released => write!(f, "Released"), + } + } +} + +impl std::str::FromStr for OffenseStatus { + type Err = String; + + fn from_str(value: &str) -> Result { + match value { + "Active" => Ok(Self::Active), + "Blocked" => Ok(Self::Blocked), + "Released" => Ok(Self::Released), + _ => Err(format!("unknown offense status: {value}")), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct OffenseMetadata { + #[serde(skip_serializing_if = "Option::is_none")] + pub source_path: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sample_line: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IpOffenseRecord { + pub id: String, + pub ip_address: String, + pub source_type: String, + pub container_id: Option, + pub offense_count: u32, + pub first_seen: String, + pub last_seen: String, + pub blocked_until: Option, + pub status: OffenseStatus, + pub reason: String, + pub metadata: Option, +} + +#[derive(Debug, Clone)] +pub struct NewIpOffense { + pub id: String, + pub ip_address: String, + pub source_type: String, + pub container_id: Option, + pub first_seen: DateTime, + pub reason: String, + pub metadata: Option, +} + +fn serialize_metadata(metadata: Option<&OffenseMetadata>) -> Result> { + match metadata { + Some(metadata) => Ok(Some(serde_json::to_string(metadata)?)), + None => Ok(None), + } +} + +fn parse_metadata(value: Option) -> Option { + value.and_then(|raw| serde_json::from_str(&raw).ok()) +} + +fn parse_status(value: String) -> Result { + value.parse().map_err(|err: String| { + rusqlite::Error::FromSqlConversionFailure( + 8, + rusqlite::types::Type::Text, + Box::new(std::io::Error::new(std::io::ErrorKind::InvalidData, err)), + ) + }) +} + +fn map_row(row: &rusqlite::Row) -> Result { + Ok(IpOffenseRecord { + id: row.get(0)?, + ip_address: row.get(1)?, + source_type: row.get(2)?, + container_id: row.get(3)?, + offense_count: row.get::<_, i64>(4)?.max(0) as u32, + first_seen: row.get(5)?, + last_seen: row.get(6)?, + blocked_until: row.get(7)?, + status: parse_status(row.get(8)?)?, + reason: row.get(9)?, + metadata: parse_metadata(row.get(10)?), + }) +} + +pub fn insert_offense(pool: &DbPool, offense: &NewIpOffense) -> Result<()> { + let conn = pool.get()?; + conn.execute( + "INSERT INTO ip_offenses ( + id, ip_address, source_type, container_id, offense_count, + first_seen, last_seen, blocked_until, status, reason, metadata + ) VALUES (?1, ?2, ?3, ?4, 1, ?5, ?5, NULL, 'Active', ?6, ?7)", + params![ + offense.id, + offense.ip_address, + offense.source_type, + offense.container_id, + offense.first_seen.to_rfc3339(), + offense.reason, + serialize_metadata(offense.metadata.as_ref())?, + ], + )?; + Ok(()) +} + +pub fn find_recent_offenses( + pool: &DbPool, + ip_address: &str, + source_type: &str, + since: DateTime, +) -> Result> { + let conn = pool.get()?; + let mut stmt = conn.prepare( + "SELECT + id, ip_address, source_type, container_id, offense_count, + first_seen, last_seen, blocked_until, status, reason, metadata + FROM ip_offenses + WHERE ip_address = ?1 + AND source_type = ?2 + AND last_seen >= ?3 + ORDER BY last_seen DESC", + )?; + + let rows = stmt.query_map( + params![ip_address, source_type, since.to_rfc3339()], + map_row, + )?; + let mut offenses = Vec::new(); + for row in rows { + offenses.push(row?); + } + Ok(offenses) +} + +pub fn active_block_for_ip(pool: &DbPool, ip_address: &str) -> Result> { + let conn = pool.get()?; + let mut stmt = conn.prepare( + "SELECT + id, ip_address, source_type, container_id, offense_count, + first_seen, last_seen, blocked_until, status, reason, metadata + FROM ip_offenses + WHERE ip_address = ?1 AND status = 'Blocked' + ORDER BY last_seen DESC + LIMIT 1", + )?; + + match stmt.query_row(params![ip_address], map_row) { + Ok(record) => Ok(Some(record)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(err) => Err(err.into()), + } +} + +pub fn mark_blocked( + pool: &DbPool, + ip_address: &str, + source_type: &str, + blocked_until: DateTime, +) -> Result<()> { + let conn = pool.get()?; + conn.execute( + "UPDATE ip_offenses + SET status = 'Blocked', blocked_until = ?1 + WHERE ip_address = ?2 AND source_type = ?3 AND status = 'Active'", + params![blocked_until.to_rfc3339(), ip_address, source_type], + )?; + Ok(()) +} + +pub fn expired_blocks(pool: &DbPool, now: DateTime) -> Result> { + let conn = pool.get()?; + let mut stmt = conn.prepare( + "SELECT + id, ip_address, source_type, container_id, offense_count, + first_seen, last_seen, blocked_until, status, reason, metadata + FROM ip_offenses + WHERE status = 'Blocked' + AND blocked_until IS NOT NULL + AND blocked_until <= ?1 + ORDER BY blocked_until ASC", + )?; + + let rows = stmt.query_map(params![now.to_rfc3339()], map_row)?; + let mut offenses = Vec::new(); + for row in rows { + offenses.push(row?); + } + Ok(offenses) +} + +pub fn mark_released(pool: &DbPool, offense_id: &str) -> Result<()> { + let conn = pool.get()?; + conn.execute( + "UPDATE ip_offenses SET status = 'Released' WHERE id = ?1", + params![offense_id], + )?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::database::{create_pool, init_database}; + use chrono::Duration; + + #[test] + fn test_insert_and_find_offense() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + insert_offense( + &pool, + &NewIpOffense { + id: "o1".into(), + ip_address: "192.0.2.10".into(), + source_type: "sniff".into(), + container_id: None, + first_seen: Utc::now(), + reason: "Repeated ssh failures".into(), + metadata: Some(OffenseMetadata { + source_path: Some("/var/log/auth.log".into()), + sample_line: None, + }), + }, + ) + .unwrap(); + + let offenses = find_recent_offenses( + &pool, + "192.0.2.10", + "sniff", + Utc::now() - Duration::minutes(1), + ) + .unwrap(); + assert_eq!(offenses.len(), 1); + assert_eq!(offenses[0].status, OffenseStatus::Active); + } + + #[test] + fn test_mark_blocked_and_released() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let now = Utc::now(); + + insert_offense( + &pool, + &NewIpOffense { + id: "o2".into(), + ip_address: "192.0.2.20".into(), + source_type: "sniff".into(), + container_id: None, + first_seen: now, + reason: "test".into(), + metadata: None, + }, + ) + .unwrap(); + + mark_blocked(&pool, "192.0.2.20", "sniff", now + Duration::minutes(5)).unwrap(); + assert!(active_block_for_ip(&pool, "192.0.2.20").unwrap().is_some()); + + let expired = expired_blocks(&pool, now + Duration::minutes(10)).unwrap(); + assert_eq!(expired.len(), 1); + mark_released(&pool, &expired[0].id).unwrap(); + assert!(active_block_for_ip(&pool, "192.0.2.20").unwrap().is_none()); + } +} diff --git a/src/firewall/response.rs b/src/firewall/response.rs index 52e9cd8..2626eef 100644 --- a/src/firewall/response.rs +++ b/src/firewall/response.rs @@ -403,14 +403,15 @@ mod tests { } #[test] - fn test_quarantine_action_returns_error_when_container_blocking_missing() { + fn test_quarantine_action_returns_actionable_error() { let action = ResponseAction::new( ResponseType::QuarantineContainer("container-1".to_string()), "Quarantine".to_string(), ); - let result = action.execute(); - assert!(result.is_err()); + let error = action.execute().unwrap_err().to_string(); + assert!(error.contains("Docker-based container quarantine flow")); + assert!(error.contains("container-1")); } #[test] @@ -476,5 +477,9 @@ mod tests { assert_eq!(log.len(), 1); assert!(!log[0].success()); assert!(log[0].error().is_some()); + assert!(log[0] + .error() + .unwrap() + .contains("Docker-based container quarantine flow")); } } diff --git a/src/ip_ban/config.rs b/src/ip_ban/config.rs new file mode 100644 index 0000000..b2f04ed --- /dev/null +++ b/src/ip_ban/config.rs @@ -0,0 +1,50 @@ +use std::env; + +#[derive(Debug, Clone)] +pub struct IpBanConfig { + pub enabled: bool, + pub max_retries: u32, + pub find_time_secs: u64, + pub ban_time_secs: u64, + pub unban_check_interval_secs: u64, +} + +impl IpBanConfig { + pub fn from_env() -> Self { + Self { + enabled: parse_bool_env("STACKDOG_IP_BAN_ENABLED", true), + max_retries: parse_u32_env("STACKDOG_IP_BAN_MAX_RETRIES", 5), + find_time_secs: parse_u64_env("STACKDOG_IP_BAN_FIND_TIME_SECS", 300), + ban_time_secs: parse_u64_env("STACKDOG_IP_BAN_BAN_TIME_SECS", 1800), + unban_check_interval_secs: parse_u64_env( + "STACKDOG_IP_BAN_UNBAN_CHECK_INTERVAL_SECS", + 60, + ), + } + } +} + +fn parse_bool_env(name: &str, default: bool) -> bool { + env::var(name) + .ok() + .and_then(|value| match value.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Some(true), + "0" | "false" | "no" | "off" => Some(false), + _ => None, + }) + .unwrap_or(default) +} + +fn parse_u64_env(name: &str, default: u64) -> u64 { + env::var(name) + .ok() + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(default) +} + +fn parse_u32_env(name: &str, default: u32) -> u32 { + env::var(name) + .ok() + .and_then(|value| value.trim().parse::().ok()) + .unwrap_or(default) +} diff --git a/src/ip_ban/engine.rs b/src/ip_ban/engine.rs new file mode 100644 index 0000000..0225ea8 --- /dev/null +++ b/src/ip_ban/engine.rs @@ -0,0 +1,297 @@ +use crate::alerting::{AlertSeverity, AlertType}; +use crate::database::models::{Alert, AlertMetadata}; +use crate::database::repositories::offenses::{ + active_block_for_ip, expired_blocks, find_recent_offenses, insert_offense, mark_blocked, + mark_released, NewIpOffense, OffenseMetadata, +}; +use crate::database::{create_alert, DbPool}; +use crate::ip_ban::config::IpBanConfig; +use anyhow::Result; +use chrono::{Duration, Utc}; +use uuid::Uuid; + +#[derive(Debug, Clone)] +pub struct OffenseInput { + pub ip_address: String, + pub source_type: String, + pub reason: String, + pub severity: AlertSeverity, + pub container_id: Option, + pub source_path: Option, + pub sample_line: Option, +} + +pub struct IpBanEngine { + pool: DbPool, + config: IpBanConfig, +} + +impl IpBanEngine { + pub fn new(pool: DbPool, config: IpBanConfig) -> Self { + Self { pool, config } + } + + pub fn config(&self) -> &IpBanConfig { + &self.config + } + + pub async fn record_offense(&self, offense: OffenseInput) -> Result { + if active_block_for_ip(&self.pool, &offense.ip_address)?.is_some() { + return Ok(false); + } + + let now = Utc::now(); + insert_offense( + &self.pool, + &NewIpOffense { + id: Uuid::new_v4().to_string(), + ip_address: offense.ip_address.clone(), + source_type: offense.source_type.clone(), + container_id: offense.container_id.clone(), + first_seen: now, + reason: offense.reason.clone(), + metadata: Some(OffenseMetadata { + source_path: offense.source_path.clone(), + sample_line: offense.sample_line.clone(), + }), + }, + )?; + + let recent = find_recent_offenses( + &self.pool, + &offense.ip_address, + &offense.source_type, + now - Duration::seconds(self.config.find_time_secs as i64), + )?; + + if recent.len() as u32 >= self.config.max_retries { + self.block_ip(&offense, now).await?; + return Ok(true); + } + + Ok(false) + } + + pub async fn unban_expired(&self) -> Result { + let now = Utc::now(); + let expired = expired_blocks(&self.pool, now)?; + let mut released = 0; + + for offense in expired { + #[cfg(target_os = "linux")] + self.with_firewall_backend(|backend| backend.unblock_ip(&offense.ip_address))?; + + mark_released(&self.pool, &offense.id)?; + create_alert( + &self.pool, + Alert::new( + AlertType::SystemEvent, + AlertSeverity::Info, + format!("Released IP ban for {}", offense.ip_address), + ) + .with_metadata( + AlertMetadata::default() + .with_source("ip_ban") + .with_reason(format!("Released expired ban for {}", offense.ip_address)), + ), + ) + .await?; + released += 1; + } + + Ok(released) + } + + async fn block_ip(&self, offense: &OffenseInput, now: chrono::DateTime) -> Result<()> { + #[cfg(target_os = "linux")] + self.with_firewall_backend(|backend| backend.block_ip(&offense.ip_address))?; + + let blocked_until = now + Duration::seconds(self.config.ban_time_secs as i64); + mark_blocked( + &self.pool, + &offense.ip_address, + &offense.source_type, + blocked_until, + )?; + + create_alert( + &self.pool, + Alert::new( + AlertType::ThresholdExceeded, + offense.severity, + format!( + "Blocked IP {} after repeated {} offenses", + offense.ip_address, offense.source_type + ), + ) + .with_metadata({ + let mut metadata = AlertMetadata::default() + .with_source("ip_ban") + .with_reason(offense.reason.clone()); + if let Some(container_id) = &offense.container_id { + metadata = metadata.with_container_id(container_id.clone()); + } + metadata + }), + ) + .await?; + + Ok(()) + } + + #[cfg(target_os = "linux")] + fn with_firewall_backend(&self, action: F) -> Result<()> + where + F: FnOnce(&dyn crate::firewall::FirewallBackend) -> Result<()>, + { + if let Ok(mut backend) = crate::firewall::NfTablesBackend::new() { + backend.initialize()?; + return action(&backend); + } + + let mut backend = crate::firewall::IptablesBackend::new()?; + backend.initialize()?; + action(&backend) + } + + pub fn extract_ip_candidates(line: &str) -> Vec { + line.split(|ch: char| !(ch.is_ascii_digit() || ch == '.')) + .filter(|part| !part.is_empty()) + .filter(|part| is_ipv4(part)) + .map(str::to_string) + .collect() + } +} + +fn is_ipv4(value: &str) -> bool { + let parts = value.split('.').collect::>(); + parts.len() == 4 + && parts + .iter() + .all(|part| !part.is_empty() && part.parse::().is_ok()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::database::repositories::offenses::find_recent_offenses; + use crate::database::repositories::offenses::OffenseStatus; + use crate::database::{create_pool, init_database, list_alerts, AlertFilter}; + use chrono::Utc; + + #[actix_rt::test] + async fn test_extract_ip_candidates() { + let ips = IpBanEngine::extract_ip_candidates( + "Failed password for root from 192.0.2.4 port 51234 ssh2", + ); + assert_eq!(ips, vec!["192.0.2.4".to_string()]); + } + + #[actix_rt::test] + async fn test_record_offense_blocks_after_threshold() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let engine = IpBanEngine::new( + pool.clone(), + IpBanConfig { + enabled: true, + max_retries: 2, + find_time_secs: 300, + ban_time_secs: 60, + unban_check_interval_secs: 60, + }, + ); + + let first = engine + .record_offense(OffenseInput { + ip_address: "192.0.2.44".into(), + source_type: "sniff".into(), + reason: "Failed ssh login".into(), + severity: AlertSeverity::High, + container_id: None, + source_path: Some("/var/log/auth.log".into()), + sample_line: Some("Failed password from 192.0.2.44".into()), + }) + .await + .unwrap(); + let second = engine + .record_offense(OffenseInput { + ip_address: "192.0.2.44".into(), + source_type: "sniff".into(), + reason: "Failed ssh login".into(), + severity: AlertSeverity::High, + container_id: None, + source_path: Some("/var/log/auth.log".into()), + sample_line: Some("Failed password from 192.0.2.44".into()), + }) + .await + .unwrap(); + + assert!(!first); + assert!(second); + assert!(active_block_for_ip(&pool, "192.0.2.44").unwrap().is_some()); + } + + #[actix_rt::test] + async fn test_unban_expired_releases_ban_and_emits_release_alert() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let engine = IpBanEngine::new( + pool.clone(), + IpBanConfig { + enabled: true, + max_retries: 1, + find_time_secs: 300, + ban_time_secs: 0, + unban_check_interval_secs: 60, + }, + ); + + let blocked = engine + .record_offense(OffenseInput { + ip_address: "192.0.2.55".into(), + source_type: "sniff".into(), + reason: "Repeated ssh login failure".into(), + severity: AlertSeverity::Critical, + container_id: None, + source_path: Some("/var/log/auth.log".into()), + sample_line: Some("Failed password from 192.0.2.55".into()), + }) + .await + .unwrap(); + assert!(blocked); + + let released = engine.unban_expired().await.unwrap(); + assert_eq!(released, 1); + assert!(active_block_for_ip(&pool, "192.0.2.55").unwrap().is_none()); + + let offenses = find_recent_offenses( + &pool, + "192.0.2.55", + "sniff", + Utc::now() - Duration::minutes(5), + ) + .unwrap(); + assert_eq!(offenses.len(), 1); + assert_eq!(offenses[0].status, OffenseStatus::Released); + + let alerts = list_alerts(&pool, AlertFilter::default()).await.unwrap(); + assert_eq!(alerts.len(), 2); + assert_eq!(alerts[0].alert_type.to_string(), "SystemEvent"); + assert_eq!(alerts[0].message, "Released IP ban for 192.0.2.55"); + assert_eq!( + alerts[0] + .metadata + .as_ref() + .and_then(|metadata| metadata.source.as_deref()), + Some("ip_ban") + ); + assert_eq!( + alerts[0] + .metadata + .as_ref() + .and_then(|metadata| metadata.reason.as_deref()), + Some("Released expired ban for 192.0.2.55") + ); + } +} diff --git a/src/ip_ban/mod.rs b/src/ip_ban/mod.rs new file mode 100644 index 0000000..a34f677 --- /dev/null +++ b/src/ip_ban/mod.rs @@ -0,0 +1,5 @@ +pub mod config; +pub mod engine; + +pub use config::IpBanConfig; +pub use engine::{IpBanEngine, OffenseInput}; diff --git a/src/lib.rs b/src/lib.rs index 4541644..1f663f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,6 +51,7 @@ pub mod baselines; pub mod correlator; pub mod database; pub mod docker; +pub mod ip_ban; pub mod ml; pub mod response; diff --git a/src/main.rs b/src/main.rs index c1dc1e6..a9083c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -144,6 +144,25 @@ async fn run_serve() -> io::Result<()> { info!("Mail abuse guard disabled"); } + let ip_ban_config = stackdog::ip_ban::IpBanConfig::from_env(); + if ip_ban_config.enabled { + let ip_ban_pool = pool.clone(); + actix_rt::spawn(async move { + let engine = stackdog::ip_ban::IpBanEngine::new(ip_ban_pool, ip_ban_config); + loop { + if let Err(err) = engine.unban_expired().await { + log::warn!("IP ban unban pass failed: {}", err); + } + tokio::time::sleep(tokio::time::Duration::from_secs( + engine.config().unban_check_interval_secs, + )) + .await; + } + }); + } else { + info!("IP ban backend disabled"); + } + info!("🎉 Stackdog Security ready!"); info!(""); info!("API Endpoints:"); diff --git a/src/sniff/mod.rs b/src/sniff/mod.rs index 79fe8a5..4b5b1ea 100644 --- a/src/sniff/mod.rs +++ b/src/sniff/mod.rs @@ -13,6 +13,7 @@ pub mod reporter; use crate::alerting::notifications::NotificationConfig; use crate::database::connection::{create_pool, init_database, DbPool}; use crate::database::repositories::log_sources as log_sources_repo; +use crate::ip_ban::{IpBanConfig, IpBanEngine, OffenseInput}; use crate::sniff::analyzer::{LogAnalyzer, PatternAnalyzer}; use crate::sniff::config::SniffConfig; use crate::sniff::consumer::LogConsumer; @@ -26,6 +27,7 @@ pub struct SniffOrchestrator { config: SniffConfig, pool: DbPool, reporter: Reporter, + ip_ban: Option, } impl SniffOrchestrator { @@ -57,11 +59,16 @@ impl SniffOrchestrator { notification_config.with_email_recipients(config.email_recipients.clone()); } let reporter = Reporter::new(notification_config); + let ip_ban_config = IpBanConfig::from_env(); + let ip_ban = ip_ban_config + .enabled + .then(|| IpBanEngine::new(pool.clone(), ip_ban_config)); Ok(Self { config, pool, reporter, + ip_ban, }) } @@ -173,6 +180,9 @@ impl SniffOrchestrator { log::debug!("Step 5: reporting results..."); let report = self.reporter.report(&summary, Some(&self.pool)).await?; result.anomalies_found += report.anomalies_reported; + if let Some(engine) = &self.ip_ban { + self.apply_ip_ban(&summary, engine).await?; + } // 6. Consume (if enabled) if let Some(ref mut cons) = consumer { @@ -209,6 +219,37 @@ impl SniffOrchestrator { Ok(result) } + async fn apply_ip_ban( + &self, + summary: &analyzer::LogSummary, + engine: &IpBanEngine, + ) -> Result<()> { + for anomaly in &summary.anomalies { + let severity = match anomaly.severity { + analyzer::AnomalySeverity::Low => crate::alerting::AlertSeverity::Low, + analyzer::AnomalySeverity::Medium => crate::alerting::AlertSeverity::Medium, + analyzer::AnomalySeverity::High => crate::alerting::AlertSeverity::High, + analyzer::AnomalySeverity::Critical => crate::alerting::AlertSeverity::Critical, + }; + + for ip in IpBanEngine::extract_ip_candidates(&anomaly.sample_line) { + engine + .record_offense(OffenseInput { + ip_address: ip, + source_type: "sniff".into(), + reason: anomaly.description.clone(), + severity, + container_id: None, + source_path: None, + sample_line: Some(anomaly.sample_line.clone()), + }) + .await?; + } + } + + Ok(()) + } + /// Run the sniff loop (continuous or one-shot) pub async fn run(&self) -> Result<()> { log::info!("🔍 Sniff orchestrator started"); @@ -254,6 +295,51 @@ pub struct SniffPassResult { #[cfg(test)] mod tests { use super::*; + use crate::database::repositories::offenses::{active_block_for_ip, find_recent_offenses}; + use crate::database::{list_alerts, AlertFilter}; + use crate::ip_ban::{IpBanConfig, IpBanEngine}; + use crate::sniff::analyzer::{AnomalySeverity, LogAnomaly, LogSummary}; + use chrono::Utc; + + fn memory_sniff_config() -> SniffConfig { + let mut config = SniffConfig::from_env_and_args(config::SniffArgs { + once: true, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + webhook_url: None, + smtp_host: None, + smtp_port: None, + smtp_user: None, + smtp_password: None, + email_recipients: None, + }); + config.database_url = ":memory:".into(); + config + } + + fn make_summary(sample_line: &str, severity: analyzer::AnomalySeverity) -> LogSummary { + LogSummary { + source_id: "test-source".into(), + period_start: Utc::now(), + period_end: Utc::now(), + total_entries: 1, + summary_text: "Suspicious login activity".into(), + error_count: 1, + warning_count: 0, + key_events: vec!["Failed password attempts".into()], + anomalies: vec![LogAnomaly { + description: "Repeated failed ssh login".into(), + severity, + sample_line: sample_line.into(), + }], + } + } #[test] fn test_sniff_pass_result_default() { @@ -326,4 +412,95 @@ mod tests { assert!(result.sources_found >= 1); assert!(result.total_entries >= 3); } + + #[actix_rt::test] + async fn test_apply_ip_ban_records_offense_metadata_from_anomaly() { + let orchestrator = SniffOrchestrator::new(memory_sniff_config()).unwrap(); + let engine = IpBanEngine::new( + orchestrator.pool.clone(), + IpBanConfig { + enabled: true, + max_retries: 2, + find_time_secs: 300, + ban_time_secs: 60, + unban_check_interval_secs: 60, + }, + ); + let summary = make_summary( + "Failed password for root from 192.0.2.80 port 2222 ssh2", + AnomalySeverity::High, + ); + + orchestrator.apply_ip_ban(&summary, &engine).await.unwrap(); + + let offenses = find_recent_offenses( + &orchestrator.pool, + "192.0.2.80", + "sniff", + Utc::now() - chrono::Duration::minutes(5), + ) + .unwrap(); + assert_eq!(offenses.len(), 1); + assert_eq!(offenses[0].reason, "Repeated failed ssh login"); + assert_eq!( + offenses[0] + .metadata + .as_ref() + .and_then(|metadata| metadata.sample_line.as_deref()), + Some("Failed password for root from 192.0.2.80 port 2222 ssh2") + ); + assert!(active_block_for_ip(&orchestrator.pool, "192.0.2.80") + .unwrap() + .is_none()); + } + + #[actix_rt::test] + async fn test_apply_ip_ban_blocks_and_emits_alert_after_repeated_anomalies() { + let orchestrator = SniffOrchestrator::new(memory_sniff_config()).unwrap(); + let engine = IpBanEngine::new( + orchestrator.pool.clone(), + IpBanConfig { + enabled: true, + max_retries: 2, + find_time_secs: 300, + ban_time_secs: 60, + unban_check_interval_secs: 60, + }, + ); + let summary = make_summary( + "Failed password for root from 192.0.2.81 port 3333 ssh2", + AnomalySeverity::Critical, + ); + + orchestrator.apply_ip_ban(&summary, &engine).await.unwrap(); + orchestrator.apply_ip_ban(&summary, &engine).await.unwrap(); + + assert!(active_block_for_ip(&orchestrator.pool, "192.0.2.81") + .unwrap() + .is_some()); + + let alerts = list_alerts(&orchestrator.pool, AlertFilter::default()) + .await + .unwrap(); + assert_eq!(alerts.len(), 1); + assert_eq!(alerts[0].alert_type.to_string(), "ThresholdExceeded"); + assert_eq!( + alerts[0].message, + "Blocked IP 192.0.2.81 after repeated sniff offenses" + ); + assert_eq!( + alerts[0] + .metadata + .as_ref() + .and_then(|metadata| metadata.source.as_deref()), + Some("ip_ban") + ); + assert_eq!( + alerts[0] + .metadata + .as_ref() + .and_then(|metadata| metadata.reason.as_deref()), + Some("Repeated failed ssh login") + ); + } } diff --git a/tests/api/alerts_api_test.rs b/tests/api/alerts_api_test.rs index c27dfa3..025d517 100644 --- a/tests/api/alerts_api_test.rs +++ b/tests/api/alerts_api_test.rs @@ -1,40 +1,228 @@ //! Alerts API tests +use actix::Actor; +use actix_web::{test, web, App}; +use serde_json::Value; +use stackdog::alerting::{AlertSeverity, AlertStatus, AlertType}; +use stackdog::api::{alerts, websocket::WebSocketHub}; +use stackdog::database::models::{Alert, AlertMetadata}; +use stackdog::database::{create_alert, create_pool, init_database}; + #[cfg(test)] mod tests { + use super::*; + #[actix_rt::test] async fn test_list_alerts() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let mut alert = Alert::new( + AlertType::ThreatDetected, + AlertSeverity::High, + "Critical test alert", + ) + .with_metadata(AlertMetadata::default().with_source("tests")); + alert.status = AlertStatus::New; + create_alert(&pool, alert).await.unwrap(); + + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(alerts::configure_routes), + ) + .await; + let req = test::TestRequest::get().uri("/api/alerts").to_request(); + let body: Vec = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body.len(), 1); + assert_eq!(body[0]["alert_type"], "ThreatDetected"); + assert_eq!(body[0]["severity"], "High"); + assert_eq!(body[0]["status"], "New"); + assert_eq!(body[0]["metadata"]["source"], "tests"); } #[actix_rt::test] async fn test_list_alerts_filter_by_severity() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + let mut high = Alert::new(AlertType::ThreatDetected, AlertSeverity::High, "High"); + high.status = AlertStatus::New; + create_alert(&pool, high).await.unwrap(); + + let mut low = Alert::new(AlertType::ThreatDetected, AlertSeverity::Low, "Low"); + low.status = AlertStatus::New; + create_alert(&pool, low).await.unwrap(); + + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(alerts::configure_routes), + ) + .await; + let req = test::TestRequest::get() + .uri("/api/alerts?severity=High") + .to_request(); + let body: Vec = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body.len(), 1); + assert_eq!(body[0]["message"], "High"); } #[actix_rt::test] async fn test_list_alerts_filter_by_status() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + let mut new_alert = Alert::new(AlertType::ThreatDetected, AlertSeverity::High, "New alert"); + new_alert.status = AlertStatus::New; + create_alert(&pool, new_alert).await.unwrap(); + + let mut acknowledged = Alert::new( + AlertType::RuleViolation, + AlertSeverity::Medium, + "Acknowledged alert", + ); + acknowledged.status = AlertStatus::Acknowledged; + create_alert(&pool, acknowledged).await.unwrap(); + + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(alerts::configure_routes), + ) + .await; + let req = test::TestRequest::get() + .uri("/api/alerts?status=Acknowledged") + .to_request(); + let body: Vec = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body.len(), 1); + assert_eq!(body[0]["status"], "Acknowledged"); + assert_eq!(body[0]["message"], "Acknowledged alert"); } #[actix_rt::test] async fn test_get_alert_stats() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + let statuses = [ + AlertStatus::New, + AlertStatus::Acknowledged, + AlertStatus::Resolved, + AlertStatus::FalsePositive, + ]; + for status in statuses { + let mut alert = Alert::new( + AlertType::ThreatDetected, + AlertSeverity::High, + format!("{status}"), + ); + alert.status = status; + create_alert(&pool, alert).await.unwrap(); + } + + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(alerts::configure_routes), + ) + .await; + let req = test::TestRequest::get() + .uri("/api/alerts/stats") + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body["total_count"], 4); + assert_eq!(body["new_count"], 1); + assert_eq!(body["acknowledged_count"], 1); + assert_eq!(body["resolved_count"], 1); + assert_eq!(body["false_positive_count"], 1); } #[actix_rt::test] async fn test_acknowledge_alert() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let alert = create_alert( + &pool, + Alert::new( + AlertType::ThreatDetected, + AlertSeverity::High, + "Needs acknowledgement", + ), + ) + .await + .unwrap(); + + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool.clone())) + .app_data(web::Data::new(hub)) + .configure(alerts::configure_routes), + ) + .await; + let req = test::TestRequest::post() + .uri(&format!("/api/alerts/{}/acknowledge", alert.id)) + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + let req = test::TestRequest::get() + .uri("/api/alerts?status=Acknowledged") + .to_request(); + let alerts: Vec = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body["success"], true); + assert_eq!(alerts.len(), 1); + assert_eq!(alerts[0]["id"], alert.id); } #[actix_rt::test] async fn test_resolve_alert() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let alert = create_alert( + &pool, + Alert::new( + AlertType::RuleViolation, + AlertSeverity::Medium, + "Needs resolution", + ), + ) + .await + .unwrap(); + + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool.clone())) + .app_data(web::Data::new(hub)) + .configure(alerts::configure_routes), + ) + .await; + let req = test::TestRequest::post() + .uri(&format!("/api/alerts/{}/resolve", alert.id)) + .set_json(serde_json::json!({ "note": "resolved in test" })) + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + let req = test::TestRequest::get() + .uri("/api/alerts?status=Resolved") + .to_request(); + let alerts: Vec = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body["success"], true); + assert_eq!(alerts.len(), 1); + assert_eq!(alerts[0]["id"], alert.id); } } diff --git a/tests/api/containers_api_test.rs b/tests/api/containers_api_test.rs index 036f1ac..76ab108 100644 --- a/tests/api/containers_api_test.rs +++ b/tests/api/containers_api_test.rs @@ -1,22 +1,76 @@ //! Containers API tests +use actix::Actor; +use actix_web::{http::StatusCode, test, web, App}; +use serde_json::Value; +use stackdog::api::{containers, websocket::WebSocketHub}; +use stackdog::database::{create_pool, init_database}; + #[cfg(test)] mod tests { + use super::*; + #[actix_rt::test] async fn test_list_containers() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(containers::configure_routes), + ) + .await; + let req = test::TestRequest::get().uri("/api/containers").to_request(); + let resp = test::call_service(&app, req).await; + + assert!(matches!( + resp.status(), + StatusCode::OK | StatusCode::SERVICE_UNAVAILABLE + )); } #[actix_rt::test] async fn test_quarantine_container() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(containers::configure_routes), + ) + .await; + let req = test::TestRequest::post() + .uri("/api/containers/container-1/quarantine") + .set_json(serde_json::json!({ "reason": "integration-test" })) + .to_request(); + let resp = test::call_service(&app, req).await; + let body: Value = test::read_body_json(resp).await; + + assert!(body.get("success").is_some() || body.get("error").is_some()); } #[actix_rt::test] async fn test_release_container() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let hub = WebSocketHub::new().start(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .app_data(web::Data::new(hub)) + .configure(containers::configure_routes), + ) + .await; + let req = test::TestRequest::post() + .uri("/api/containers/container-1/release") + .to_request(); + let resp = test::call_service(&app, req).await; + let body: Value = test::read_body_json(resp).await; + + assert!(body.get("success").is_some() || body.get("error").is_some()); } } diff --git a/tests/api/mod.rs b/tests/api/mod.rs index 8302790..63d8ea5 100644 --- a/tests/api/mod.rs +++ b/tests/api/mod.rs @@ -1,7 +1,7 @@ //! API integration tests -mod security_api_test; mod alerts_api_test; mod containers_api_test; +mod security_api_test; mod threats_api_test; mod websocket_test; diff --git a/tests/api/security_api_test.rs b/tests/api/security_api_test.rs index 2c086c4..7afd5c5 100644 --- a/tests/api/security_api_test.rs +++ b/tests/api/security_api_test.rs @@ -1,7 +1,11 @@ //! Security API tests -use actix_web::{test, App}; -use serde_json::json; +use actix_web::{test, web, App}; +use serde_json::Value; +use stackdog::alerting::{AlertSeverity, AlertStatus, AlertType}; +use stackdog::api::security; +use stackdog::database::models::{Alert, AlertMetadata}; +use stackdog::database::{create_alert, create_pool, init_database}; #[cfg(test)] mod tests { @@ -9,15 +13,72 @@ mod tests { #[actix_rt::test] async fn test_get_security_status() { - // TODO: Implement when API is ready - // This test will verify the security status endpoint - assert!(true, "Test placeholder - implement when API endpoints are ready"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + create_alert( + &pool, + Alert::new( + AlertType::ThreatDetected, + AlertSeverity::High, + "Open threat", + ), + ) + .await + .unwrap(); + let mut quarantine = Alert::new( + AlertType::QuarantineApplied, + AlertSeverity::High, + "Container quarantined", + ) + .with_metadata(AlertMetadata::default().with_container_id("container-1")); + quarantine.status = AlertStatus::Acknowledged; + create_alert(&pool, quarantine).await.unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(security::configure_routes), + ) + .await; + + let req = test::TestRequest::get() + .uri("/api/security/status") + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body["active_threats"], 1); + assert_eq!(body["quarantined_containers"], 1); + assert_eq!(body["alerts_new"], 1); + assert_eq!(body["alerts_acknowledged"], 1); + assert!(body["overall_score"].as_u64().unwrap() < 100); + assert!(body["last_updated"].as_str().is_some()); } #[actix_rt::test] async fn test_security_status_format() { - // TODO: Implement when API is ready - // This test will verify the response format - assert!(true, "Test placeholder - implement when API endpoints are ready"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(security::configure_routes), + ) + .await; + + let req = test::TestRequest::get() + .uri("/api/security/status") + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + for key in [ + "overall_score", + "active_threats", + "quarantined_containers", + "alerts_new", + "alerts_acknowledged", + "last_updated", + ] { + assert!(body.get(key).is_some(), "missing key {key}"); + } } } diff --git a/tests/api/threats_api_test.rs b/tests/api/threats_api_test.rs index 21f6c6b..d0edc4c 100644 --- a/tests/api/threats_api_test.rs +++ b/tests/api/threats_api_test.rs @@ -1,22 +1,115 @@ //! Threats API tests +use actix_web::{test, web, App}; +use serde_json::Value; +use stackdog::alerting::{AlertSeverity, AlertStatus, AlertType}; +use stackdog::api::threats; +use stackdog::database::models::{Alert, AlertMetadata}; +use stackdog::database::{create_alert, create_pool, init_database}; + #[cfg(test)] mod tests { + use super::*; + #[actix_rt::test] async fn test_list_threats() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + create_alert( + &pool, + Alert::new( + AlertType::ThresholdExceeded, + AlertSeverity::Critical, + "Blocked IP", + ) + .with_metadata(AlertMetadata::default().with_source("ip_ban")), + ) + .await + .unwrap(); + create_alert( + &pool, + Alert::new(AlertType::SystemEvent, AlertSeverity::Info, "Ignore me"), + ) + .await + .unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(threats::configure_routes), + ) + .await; + + let req = test::TestRequest::get().uri("/api/threats").to_request(); + let body: Vec = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body.len(), 1); + assert_eq!(body[0]["type"], "ThresholdExceeded"); + assert_eq!(body[0]["severity"], "Critical"); + assert_eq!(body[0]["score"], 95); + assert_eq!(body[0]["source"], "ip_ban"); } #[actix_rt::test] async fn test_get_threat_statistics() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + let mut unresolved = Alert::new( + AlertType::ThreatDetected, + AlertSeverity::High, + "Open threat", + ); + unresolved.status = AlertStatus::New; + create_alert(&pool, unresolved).await.unwrap(); + + let mut resolved = Alert::new( + AlertType::RuleViolation, + AlertSeverity::Medium, + "Resolved threat", + ); + resolved.status = AlertStatus::Resolved; + create_alert(&pool, resolved).await.unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(threats::configure_routes), + ) + .await; + + let req = test::TestRequest::get() + .uri("/api/threats/statistics") + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + assert_eq!(body["total_threats"], 2); + assert_eq!(body["by_severity"]["High"], 1); + assert_eq!(body["by_severity"]["Medium"], 1); + assert_eq!(body["by_type"]["ThreatDetected"], 1); + assert_eq!(body["by_type"]["RuleViolation"], 1); + assert_eq!(body["trend"], "stable"); } #[actix_rt::test] async fn test_statistics_format() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(threats::configure_routes), + ) + .await; + + let req = test::TestRequest::get() + .uri("/api/threats/statistics") + .to_request(); + let body: Value = test::call_and_read_body_json(&app, req).await; + + for key in ["total_threats", "by_severity", "by_type", "trend"] { + assert!(body.get(key).is_some(), "missing key {key}"); + } } } diff --git a/tests/firewall/response_test.rs b/tests/firewall/response_test.rs index c4bdd4a..bcd5c19 100644 --- a/tests/firewall/response_test.rs +++ b/tests/firewall/response_test.rs @@ -138,6 +138,18 @@ fn test_response_from_alert() { assert!(action.description().contains("Critical threat")); } +#[test] +fn test_quarantine_response_is_explicitly_unsupported_in_sync_pipeline() { + let action = ResponseAction::new( + ResponseType::QuarantineContainer("test-container".to_string()), + "Quarantine container".to_string(), + ); + + let error = action.execute().unwrap_err().to_string(); + assert!(error.contains("Docker-based container quarantine flow")); + assert!(error.contains("test-container")); +} + #[test] fn test_response_retry() { let mut action = ResponseAction::new( diff --git a/tests/integration.rs b/tests/integration.rs index 3c1529d..2cc2b82 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -2,6 +2,7 @@ //! //! These tests verify that multiple components work together correctly. +mod api; mod collectors; mod events; mod structure; diff --git a/web/package.json b/web/package.json index c1fae63..ff62a3b 100644 --- a/web/package.json +++ b/web/package.json @@ -5,62 +5,64 @@ "scripts": { "start": "cross-env REACT_APP_VERSION=$npm_package_version webpack serve --mode development", "build": "cross-env REACT_APP_VERSION=$npm_package_version webpack --mode production", - "test": "jest --config jest.json", - "coverage": "jest --config jest.json --collect-coverage", + "test": "jest --config jest.config.js", + "coverage": "jest --config jest.config.js --collect-coverage", "lint": "eslint src --ext .ts,.tsx" }, "dependencies": { - "react": "^18.2.0", - "react-dom": "^18.2.0", - "react-router-dom": "^6.20.0", "@reduxjs/toolkit": "^1.9.7", - "react-redux": "^8.1.3", - "redux-saga": "^1.2.3", + "archiver": "^6.0.1", "axios": "^1.6.2", - "recharts": "^2.10.3", "bootstrap": "^5.3.2", - "react-bootstrap": "^2.9.1", - "styled-components": "^6.1.2", "date-fns": "^2.30.0", "lodash": "^4.17.21", - "uuid": "^9.0.1", - "archiver": "^6.0.1" + "react": "^18.2.0", + "react-bootstrap": "^2.9.1", + "react-dom": "^18.2.0", + "react-redux": "^8.1.3", + "react-router-dom": "^6.20.0", + "recharts": "^2.10.3", + "redux-saga": "^1.2.3", + "styled-components": "^6.1.2", + "uuid": "^9.0.1" }, "devDependencies": { "@babel/core": "^7.23.5", - "@types/react": "^18.2.43", - "@types/react-dom": "^18.2.17", + "@testing-library/jest-dom": "^6.1.5", + "@testing-library/react": "^14.1.2", + "@types/archiver": "^6.0.2", "@types/jest": "^29.5.11", - "@types/node": "^20.10.4", "@types/lodash": "^4.14.202", + "@types/node": "^20.10.4", + "@types/react": "^18.2.43", + "@types/react-dom": "^18.2.17", "@types/uuid": "^9.0.7", - "@types/archiver": "^6.0.2", "@types/webpack": "^5.28.5", "@types/webpack-dev-server": "^4.7.2", "@types/webpack-env": "^1.18.4", - "@testing-library/react": "^14.1.2", - "@testing-library/jest-dom": "^6.1.5", + "@typescript-eslint/eslint-plugin": "^6.14.0", + "@typescript-eslint/parser": "^6.14.0", "babel-loader": "^9.1.3", - "ts-loader": "^9.5.1", - "typescript": "^5.3.3", - "ts-node": "^10.9.2", - "webpack": "^5.89.0", - "webpack-cli": "^5.1.4", - "webpack-dev-server": "^4.15.1", - "html-webpack-plugin": "^5.5.4", "clean-webpack-plugin": "^4.0.0", "copy-webpack-plugin": "^11.0.0", - "terser-webpack-plugin": "^5.3.9", "cross-env": "^7.0.3", - "jest": "^29.7.0", - "ts-jest": "^29.1.1", + "css-loader": "^7.1.2", "eslint": "^8.55.0", - "@typescript-eslint/parser": "^6.14.0", - "@typescript-eslint/eslint-plugin": "^6.14.0", "eslint-plugin-react": "^7.33.2", "eslint-plugin-react-hooks": "^4.6.0", + "html-webpack-plugin": "^5.5.4", + "identity-obj-proxy": "^3.0.0", + "jest": "^29.7.0", + "jest-environment-jsdom": "^30.3.0", "style-loader": "^4.0.0", - "css-loader": "^7.1.2" + "terser-webpack-plugin": "^5.3.9", + "ts-jest": "^29.1.1", + "ts-loader": "^9.5.1", + "ts-node": "^10.9.2", + "typescript": "^5.3.3", + "webpack": "^5.89.0", + "webpack-cli": "^5.1.4", + "webpack-dev-server": "^4.15.1" }, "browserslist": { "production": [ diff --git a/web/src/components/__tests__/AlertPanel.test.tsx b/web/src/components/__tests__/AlertPanel.test.tsx index fec05ab..c231457 100644 --- a/web/src/components/__tests__/AlertPanel.test.tsx +++ b/web/src/components/__tests__/AlertPanel.test.tsx @@ -29,6 +29,7 @@ const mockAlerts = [ describe('AlertPanel Component', () => { beforeEach(() => { + jest.clearAllMocks(); (apiService.getAlerts as jest.Mock).mockResolvedValue(mockAlerts); (apiService.getAlertStats as jest.Mock).mockResolvedValue({ totalCount: 10, @@ -36,14 +37,15 @@ describe('AlertPanel Component', () => { acknowledgedCount: 3, resolvedCount: 2, }); + (webSocketService.connect as jest.Mock).mockResolvedValue(undefined); + (webSocketService.subscribe as jest.Mock).mockReturnValue(() => {}); + (webSocketService.disconnect as jest.Mock).mockImplementation(() => {}); }); test('lists alerts correctly', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); - }); + expect(await screen.findByText('Suspicious activity detected')).toBeInTheDocument(); expect(screen.getByText('Rule violation detected')).toBeInTheDocument(); }); @@ -51,29 +53,27 @@ describe('AlertPanel Component', () => { test('filters alerts by severity', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); - }); + expect(await screen.findByText('Suspicious activity detected')).toBeInTheDocument(); const severityFilter = screen.getByLabelText('Filter by severity'); fireEvent.change(severityFilter, { target: { value: 'High' } }); - // Should only show High severity alerts - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); + await waitFor(() => { + expect(apiService.getAlerts).toHaveBeenLastCalledWith({ severity: ['High'] }); + }); }); test('filters alerts by status', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); - }); + expect(await screen.findByText('Suspicious activity detected')).toBeInTheDocument(); const statusFilter = screen.getByLabelText('Filter by status'); fireEvent.change(statusFilter, { target: { value: 'New' } }); - // Should only show New alerts - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); + await waitFor(() => { + expect(apiService.getAlerts).toHaveBeenLastCalledWith({ status: ['New'] }); + }); }); test('acknowledge alert works', async () => { @@ -81,11 +81,9 @@ describe('AlertPanel Component', () => { render(); - await waitFor(() => { - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); - }); + expect(await screen.findByText('Suspicious activity detected')).toBeInTheDocument(); - const acknowledgeButton = screen.getByText('Acknowledge'); + const acknowledgeButton = screen.getAllByText('Acknowledge')[0]; fireEvent.click(acknowledgeButton); await waitFor(() => { @@ -98,15 +96,13 @@ describe('AlertPanel Component', () => { render(); - await waitFor(() => { - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); - }); + expect(await screen.findByText('Suspicious activity detected')).toBeInTheDocument(); - const resolveButton = screen.getByText('Resolve'); + const resolveButton = screen.getAllByText('Resolve')[0]; fireEvent.click(resolveButton); await waitFor(() => { - expect(apiService.resolveAlert).toHaveBeenCalledWith('alert-1', expect.any(String)); + expect(apiService.resolveAlert).toHaveBeenCalledWith('alert-1', 'Resolved via dashboard'); }); }); @@ -136,9 +132,7 @@ describe('AlertPanel Component', () => { render(); - await waitFor(() => { - expect(screen.getByText('Alert 0')).toBeInTheDocument(); - }); + expect(await screen.findByText('Alert 0')).toBeInTheDocument(); // Should show first 10 alerts expect(screen.getByText('Alert 0')).toBeInTheDocument(); @@ -148,7 +142,6 @@ describe('AlertPanel Component', () => { const nextPageButton = screen.getByText('Next'); fireEvent.click(nextPageButton); - // Should show next 10 alerts await waitFor(() => { expect(screen.getByText('Alert 10')).toBeInTheDocument(); }); @@ -159,18 +152,18 @@ describe('AlertPanel Component', () => { render(); - await waitFor(() => { - expect(screen.getByText('Suspicious activity detected')).toBeInTheDocument(); - }); + expect(await screen.findByText('Suspicious activity detected')).toBeInTheDocument(); const selectAllCheckbox = screen.getByLabelText('Select all alerts'); fireEvent.click(selectAllCheckbox); - const bulkAcknowledgeButton = screen.getByText('Acknowledge Selected'); + const bulkAcknowledgeButton = await screen.findByText(/Acknowledge Selected/); fireEvent.click(bulkAcknowledgeButton); await waitFor(() => { - expect(apiService.acknowledgeAlert).toHaveBeenCalled(); + expect(apiService.acknowledgeAlert).toHaveBeenCalledTimes(2); + expect(apiService.acknowledgeAlert).toHaveBeenCalledWith('alert-1'); + expect(apiService.acknowledgeAlert).toHaveBeenCalledWith('alert-2'); }); }); }); diff --git a/web/src/components/__tests__/ContainerList.test.tsx b/web/src/components/__tests__/ContainerList.test.tsx index 802fdf2..58caaf0 100644 --- a/web/src/components/__tests__/ContainerList.test.tsx +++ b/web/src/components/__tests__/ContainerList.test.tsx @@ -59,6 +59,7 @@ const mockContainers = [ describe('ContainerList Component', () => { beforeEach(() => { + jest.clearAllMocks(); (apiService.getContainers as jest.Mock).mockResolvedValue(mockContainers); }); @@ -75,12 +76,10 @@ describe('ContainerList Component', () => { test('shows security status per container', async () => { render(); - await waitFor(() => { - expect(screen.getByText('web-server')).toBeInTheDocument(); - }); + expect(await screen.findByText('web-server')).toBeInTheDocument(); expect(screen.getByText('Secure')).toBeInTheDocument(); - expect(screen.getByText('At Risk')).toBeInTheDocument(); + expect(screen.getByText('AtRisk')).toBeInTheDocument(); }); test('displays risk scores', async () => { @@ -99,11 +98,9 @@ describe('ContainerList Component', () => { render(); - await waitFor(() => { - expect(screen.getByText('database')).toBeInTheDocument(); - }); + expect(await screen.findByText('database')).toBeInTheDocument(); - const quarantineButton = screen.getByText('Quarantine'); + const quarantineButton = screen.getAllByText('Quarantine')[1]; fireEvent.click(quarantineButton); // Should show confirmation modal @@ -146,14 +143,14 @@ describe('ContainerList Component', () => { test('filters by status', async () => { render(); - await waitFor(() => { - expect(screen.getByText('web-server')).toBeInTheDocument(); - }); + expect(await screen.findByText('web-server')).toBeInTheDocument(); const statusFilter = screen.getByLabelText('Filter by status'); fireEvent.change(statusFilter, { target: { value: 'Running' } }); - // Should only show Running containers + await waitFor(() => { + expect(apiService.getContainers).toHaveBeenCalledTimes(2); + }); expect(screen.getByText('web-server')).toBeInTheDocument(); expect(screen.getByText('database')).toBeInTheDocument(); }); diff --git a/web/src/components/__tests__/Dashboard.test.tsx b/web/src/components/__tests__/Dashboard.test.tsx new file mode 100644 index 0000000..3e552d4 --- /dev/null +++ b/web/src/components/__tests__/Dashboard.test.tsx @@ -0,0 +1,99 @@ +import React from 'react'; +import { act, render, screen, waitFor } from '@testing-library/react'; +import Dashboard from '../Dashboard'; +import apiService from '../../services/api'; +import webSocketService from '../../services/websocket'; + +jest.mock('../../services/api'); +jest.mock('../../services/websocket'); +jest.mock('../AlertPanel', () => () =>
AlertPanel
); +jest.mock('../ContainerList', () => () =>
ContainerList
); +jest.mock('../ThreatMap', () => () =>
ThreatMap
); +jest.mock('../SecurityScore', () => ({ score }: { score: number }) => ( +
SecurityScore:{score}
+)); + +describe('Dashboard Component', () => { + const baseStatus = { + overallScore: 88, + activeThreats: 2, + quarantinedContainers: 1, + alertsNew: 4, + alertsAcknowledged: 3, + lastUpdated: '2026-04-04T08:00:00.000Z', + }; + + const subscriptions = new Map void>(); + + beforeEach(() => { + jest.clearAllMocks(); + subscriptions.clear(); + (apiService.getSecurityStatus as jest.Mock).mockResolvedValue(baseStatus); + (webSocketService.connect as jest.Mock).mockResolvedValue(undefined); + (webSocketService.subscribe as jest.Mock).mockImplementation((event, handler) => { + subscriptions.set(event, handler); + return () => subscriptions.delete(event); + }); + (webSocketService.disconnect as jest.Mock).mockImplementation(() => {}); + }); + + test('loads and displays security status summary', async () => { + render(); + + expect(await screen.findByText('SecurityScore:88')).toBeInTheDocument(); + expect(screen.getByText('2')).toBeInTheDocument(); + expect(screen.getByText('1')).toBeInTheDocument(); + expect(screen.getByText('4')).toBeInTheDocument(); + expect(screen.getByText('AlertPanel')).toBeInTheDocument(); + expect(screen.getByText('ContainerList')).toBeInTheDocument(); + expect(screen.getByText('ThreatMap')).toBeInTheDocument(); + }); + + test('shows an error state when status loading fails', async () => { + (apiService.getSecurityStatus as jest.Mock).mockRejectedValue(new Error('boom')); + + render(); + + expect(await screen.findByText('Failed to load security status')).toBeInTheDocument(); + }); + + test('applies websocket stats updates to the rendered summary', async () => { + render(); + + expect(await screen.findByText('SecurityScore:88')).toBeInTheDocument(); + + await act(async () => { + subscriptions.get('stats:updated')?.({ + overallScore: 65, + activeThreats: 5, + alertsNew: 6, + }); + }); + + expect(screen.getByText('SecurityScore:65')).toBeInTheDocument(); + expect(screen.getByText('5')).toBeInTheDocument(); + expect(screen.getByText('6')).toBeInTheDocument(); + }); + + test('refreshes security status when an alert is created and disconnects on unmount', async () => { + const { unmount } = render(); + + expect(await screen.findByText('SecurityScore:88')).toBeInTheDocument(); + + (apiService.getSecurityStatus as jest.Mock).mockResolvedValueOnce({ + ...baseStatus, + activeThreats: 3, + }); + + await act(async () => { + subscriptions.get('alert:created')?.(); + }); + + await waitFor(() => { + expect(apiService.getSecurityStatus).toHaveBeenCalledTimes(2); + }); + + unmount(); + expect(webSocketService.disconnect).toHaveBeenCalled(); + }); +}); diff --git a/web/src/components/__tests__/SecurityScore.test.tsx b/web/src/components/__tests__/SecurityScore.test.tsx new file mode 100644 index 0000000..bca2067 --- /dev/null +++ b/web/src/components/__tests__/SecurityScore.test.tsx @@ -0,0 +1,28 @@ +import React from 'react'; +import { render, screen } from '@testing-library/react'; +import SecurityScore from '../SecurityScore'; + +describe('SecurityScore Component', () => { + test('renders secure label for high scores', () => { + render(); + + expect(screen.getByText('88')).toBeInTheDocument(); + expect(screen.getByText('Secure')).toBeInTheDocument(); + }); + + test('renders moderate and at-risk thresholds correctly', () => { + const { rerender } = render(); + expect(screen.getByText('Moderate')).toBeInTheDocument(); + + rerender(); + expect(screen.getByText('At Risk')).toBeInTheDocument(); + }); + + test('renders critical label and gauge rotation for low scores', () => { + const { container } = render(); + + expect(screen.getByText('Critical')).toBeInTheDocument(); + const gaugeFill = container.querySelector('.gauge-fill'); + expect(gaugeFill).toHaveStyle({ transform: 'rotate(-54deg)' }); + }); +}); diff --git a/web/src/components/__tests__/ThreatMap.test.tsx b/web/src/components/__tests__/ThreatMap.test.tsx index 8ee0290..36112cf 100644 --- a/web/src/components/__tests__/ThreatMap.test.tsx +++ b/web/src/components/__tests__/ThreatMap.test.tsx @@ -55,6 +55,7 @@ const mockStatistics = { describe('ThreatMap Component', () => { beforeEach(() => { + jest.clearAllMocks(); (apiService.getThreats as jest.Mock).mockResolvedValue(mockThreats); (apiService.getThreatStatistics as jest.Mock).mockResolvedValue(mockStatistics); }); @@ -62,9 +63,7 @@ describe('ThreatMap Component', () => { test('displays threat type distribution', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Threat Type Distribution')).toBeInTheDocument(); - }); + expect(await screen.findByText('Threat Type Distribution')).toBeInTheDocument(); expect(screen.getByText('CryptoMiner')).toBeInTheDocument(); expect(screen.getByText('ContainerEscape')).toBeInTheDocument(); @@ -74,46 +73,34 @@ describe('ThreatMap Component', () => { test('displays severity breakdown', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Severity Breakdown')).toBeInTheDocument(); - }); + expect(await screen.findByText('Severity Breakdown')).toBeInTheDocument(); - expect(screen.getByText('Critical')).toBeInTheDocument(); - expect(screen.getByText('High')).toBeInTheDocument(); - expect(screen.getByText('Medium')).toBeInTheDocument(); - expect(screen.getByText('Low')).toBeInTheDocument(); - expect(screen.getByText('Info')).toBeInTheDocument(); + expect(screen.getByText('Recent Threats')).toBeInTheDocument(); + expect(screen.getByText('Score: 95')).toBeInTheDocument(); }); test('displays threat timeline', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Threat Timeline')).toBeInTheDocument(); - }); + expect(await screen.findByText('Threat Timeline')).toBeInTheDocument(); - // Timeline should show threats over time - expect(screen.getByText('Total Threats: 10')).toBeInTheDocument(); + expect(screen.getByText('Total Threats')).toBeInTheDocument(); + expect(screen.getByText('10')).toBeInTheDocument(); }); test('charts are interactive', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Threat Type Distribution')).toBeInTheDocument(); - }); + expect(await screen.findByText('Threat Type Distribution')).toBeInTheDocument(); - // Hover over chart element (simulated) - const chartElement = screen.getByText('CryptoMiner: 3'); - expect(chartElement).toBeInTheDocument(); + expect(screen.getByText('Score: 85')).toBeInTheDocument(); + expect(screen.getAllByText('container-1')).toHaveLength(2); }); test('filters by date range', async () => { render(); - await waitFor(() => { - expect(screen.getByText('Threat Type Distribution')).toBeInTheDocument(); - }); + expect(await screen.findByText('Threat Type Distribution')).toBeInTheDocument(); const dateFromInput = screen.getByLabelText('From'); const dateToInput = screen.getByLabelText('To'); @@ -121,9 +108,9 @@ describe('ThreatMap Component', () => { fireEvent.change(dateFromInput, { target: { value: '2026-01-01' } }); fireEvent.change(dateToInput, { target: { value: '2026-12-31' } }); - // Should filter threats by date range await waitFor(() => { - expect(apiService.getThreats).toHaveBeenCalled(); + expect(apiService.getThreats).toHaveBeenCalledTimes(3); + expect(apiService.getThreatStatistics).toHaveBeenCalledTimes(3); }); }); }); diff --git a/web/src/services/__tests__/security.test.ts b/web/src/services/__tests__/security.test.ts index 4d12f3d..f547314 100644 --- a/web/src/services/__tests__/security.test.ts +++ b/web/src/services/__tests__/security.test.ts @@ -1,5 +1,4 @@ import apiService from '../api'; -import { AlertSeverity, AlertStatus } from '../../types/alerts'; // Mock axios jest.mock('axios', () => ({ @@ -14,14 +13,14 @@ describe('API Service', () => { jest.clearAllMocks(); }); - test('fetches security status from API', async () => { + test('maps snake_case security status fields to camelCase', async () => { const mockStatus = { - overallScore: 85, - activeThreats: 3, - quarantinedContainers: 1, - alertsNew: 5, - alertsAcknowledged: 2, - lastUpdated: new Date().toISOString(), + overall_score: 85, + active_threats: 3, + quarantined_containers: 1, + alerts_new: 5, + alerts_acknowledged: 2, + last_updated: new Date().toISOString(), }; (apiService.api.get as jest.Mock).mockResolvedValue({ data: mockStatus }); @@ -29,27 +28,95 @@ describe('API Service', () => { const status = await apiService.getSecurityStatus(); expect(apiService.api.get).toHaveBeenCalledWith('/security/status'); - expect(status).toEqual(mockStatus); + expect(status).toEqual({ + overallScore: 85, + activeThreats: 3, + quarantinedContainers: 1, + alertsNew: 5, + alertsAcknowledged: 2, + lastUpdated: mockStatus.last_updated, + }); }); - test('fetches alerts from API', async () => { + test('maps snake_case alerts and alert stats from the API', async () => { const mockAlerts = [ { id: 'alert-1', - alertType: 'ThreatDetected', + alert_type: 'ThreatDetected', severity: 'High', message: 'Test alert', status: 'New', timestamp: new Date().toISOString(), + metadata: { source: 'api' }, }, ]; + const mockAlertStats = { + total_count: 8, + new_count: 5, + acknowledged_count: 2, + resolved_count: 1, + }; - (apiService.api.get as jest.Mock).mockResolvedValue({ data: mockAlerts }); + (apiService.api.get as jest.Mock) + .mockResolvedValueOnce({ data: mockAlerts }) + .mockResolvedValueOnce({ data: mockAlertStats }); const alerts = await apiService.getAlerts(); + const stats = await apiService.getAlertStats(); expect(apiService.api.get).toHaveBeenCalledWith('/alerts', expect.anything()); - expect(alerts).toEqual(mockAlerts); + expect(apiService.api.get).toHaveBeenCalledWith('/alerts/stats'); + expect(alerts).toEqual([ + { + id: 'alert-1', + alertType: 'ThreatDetected', + severity: 'High', + message: 'Test alert', + status: 'New', + timestamp: mockAlerts[0].timestamp, + metadata: { source: 'api' }, + }, + ]); + expect(stats).toEqual({ + totalCount: 8, + newCount: 5, + acknowledgedCount: 2, + resolvedCount: 1, + falsePositiveCount: 0, + }); + }); + + test('maps snake_case threat statistics from the API', async () => { + const mockThreatStats = { + total_threats: 3, + by_severity: { + Critical: 1, + High: 2, + }, + by_type: { + ThreatDetected: 2, + ThresholdExceeded: 1, + }, + trend: 'increasing', + }; + + (apiService.api.get as jest.Mock).mockResolvedValue({ data: mockThreatStats }); + + const stats = await apiService.getThreatStatistics(); + + expect(apiService.api.get).toHaveBeenCalledWith('/threats/statistics'); + expect(stats).toEqual({ + totalThreats: 3, + bySeverity: { + Critical: 1, + High: 2, + }, + byType: { + ThreatDetected: 2, + ThresholdExceeded: 1, + }, + trend: 'increasing', + }); }); test('acknowledges alert via API', async () => { @@ -86,7 +153,32 @@ describe('API Service', () => { const containers = await apiService.getContainers(); expect(apiService.api.get).toHaveBeenCalledWith('/containers'); - expect(containers).toEqual(mockContainers); + expect(containers).toEqual([ + { + id: 'container-1', + name: 'test-container', + image: 'unknown', + status: 'Running', + securityStatus: { + state: 'Secure', + threats: 0, + vulnerabilities: null, + lastScan: null, + }, + riskScore: 10, + networkActivity: { + inboundConnections: null, + outboundConnections: null, + blockedConnections: null, + receivedBytes: null, + transmittedBytes: null, + receivedPackets: null, + transmittedPackets: null, + suspiciousActivity: false, + }, + createdAt: expect.any(String), + }, + ]); }); test('quarantines container via API', async () => { diff --git a/web/src/services/__tests__/websocket.test.ts b/web/src/services/__tests__/websocket.test.ts index 272a8a9..b711977 100644 --- a/web/src/services/__tests__/websocket.test.ts +++ b/web/src/services/__tests__/websocket.test.ts @@ -1,118 +1,110 @@ -import { WebSocketService, webSocketService } from '../websocket'; +import { WebSocketService } from '../websocket'; describe('WebSocket Service', () => { let ws: WebSocketService; + const originalWebSocket = global.WebSocket; + + const createMockSocket = (readyState: number = WebSocket.CONNECTING) => ({ + onopen: null as (() => void) | null, + onmessage: null as ((event: MessageEvent) => void) | null, + onclose: null as (() => void) | null, + onerror: null as ((event: Event) => void) | null, + readyState, + send: jest.fn(), + close: jest.fn(), + }); + + const installWebSocketMock = (...sockets: ReturnType[]) => { + let index = 0; + const mockConstructor = jest.fn().mockImplementation(() => { + const socket = sockets[Math.min(index, sockets.length - 1)]; + index += 1; + return socket as any; + }); + Object.assign(mockConstructor, { + CONNECTING: 0, + OPEN: 1, + CLOSING: 2, + CLOSED: 3, + }); + global.WebSocket = mockConstructor as unknown as typeof WebSocket; + return mockConstructor; + }; beforeEach(() => { ws = new WebSocketService('ws://test-server'); jest.clearAllMocks(); }); + afterEach(() => { + jest.useRealTimers(); + global.WebSocket = originalWebSocket; + }); + test('connects to WebSocket server', async () => { - const mockWs = { - onopen: null as (() => void) | null, - onmessage: null as ((event: any) => void) | null, - onclose: null as (() => void) | null, - onerror: null as ((event: any) => void) | null, - readyState: WebSocket.OPEN, - send: jest.fn(), - close: jest.fn(), - }; - - jest.spyOn(global, 'WebSocket').mockImplementation(() => mockWs as any); + const mockWs = createMockSocket(WebSocket.OPEN); + const webSocketCtor = installWebSocketMock(mockWs); const connectPromise = ws.connect(); - // Simulate connection open mockWs.onopen!(); await connectPromise; - expect(global.WebSocket).toHaveBeenCalledWith('ws://test-server'); + expect(webSocketCtor).toHaveBeenCalledWith('ws://test-server'); }); test('receives real-time updates', async () => { - const mockWs = { - onopen: null as (() => void) | null, - onmessage: null as ((event: any) => void) | null, - onclose: null as (() => void) | null, - onerror: null as ((event: any) => void) | null, - readyState: WebSocket.OPEN, - send: jest.fn(), - close: jest.fn(), - }; - - jest.spyOn(global, 'WebSocket').mockImplementation(() => mockWs as any); + const mockWs = createMockSocket(WebSocket.OPEN); + installWebSocketMock(mockWs); const handler = jest.fn(); ws.subscribe('alert:created', handler); - await ws.connect(); + const connectPromise = ws.connect(); + mockWs.onopen!(); + await connectPromise; - // Simulate message received mockWs.onmessage!({ data: JSON.stringify({ type: 'alert:created', payload: { id: 'alert-1', message: 'Test' }, }), - }); + } as MessageEvent); expect(handler).toHaveBeenCalledWith({ id: 'alert-1', message: 'Test' }); }); test('handles connection errors', async () => { - const mockWs = { - onopen: null as (() => void) | null, - onmessage: null as ((event: any) => void) | null, - onclose: null as (() => void) | null, - onerror: null as ((event: any) => void) | null, - readyState: WebSocket.CLOSED, - send: jest.fn(), - close: jest.fn(), - }; - - jest.spyOn(global, 'WebSocket').mockImplementation(() => mockWs as any); - - const errorHandler = jest.fn(); - - try { - await ws.connect(); - } catch (error) { - errorHandler(error); - } - - // Simulate error - mockWs.onerror!({ message: 'Connection failed' }); - - expect(errorHandler).toHaveBeenCalled(); - }); + const mockWs = createMockSocket(WebSocket.CLOSED); + const webSocketCtor = installWebSocketMock(mockWs); - test('reconnects on disconnect', async () => { - jest.useFakeTimers(); - - const mockWs = { - onopen: null as (() => void) | null, - onmessage: null as ((event: any) => void) | null, - onclose: null as (() => void) | null, - onerror: null as ((event: any) => void) | null, - readyState: WebSocket.OPEN, - send: jest.fn(), - close: jest.fn(), - }; + const connectPromise = ws.connect(); + mockWs.onerror!(new Event('error')); + await connectPromise; - jest.spyOn(global, 'WebSocket').mockImplementation(() => mockWs as any); + expect(ws.isConnected()).toBe(false); await ws.connect(); - // Simulate disconnect - mockWs.onclose!(); + expect(webSocketCtor).toHaveBeenCalledTimes(1); + }); - // Fast-forward time - jest.advanceTimersByTime(2000); + test('reconnects on disconnect', async () => { + jest.useFakeTimers(); + const firstSocket = createMockSocket(WebSocket.OPEN); + const secondSocket = createMockSocket(WebSocket.OPEN); - expect(global.WebSocket).toHaveBeenCalledTimes(2); + const webSocketCtor = installWebSocketMock(firstSocket, secondSocket); - jest.useRealTimers(); + const connectPromise = ws.connect(); + firstSocket.onopen!(); + await connectPromise; + + firstSocket.onclose!(); + jest.advanceTimersByTime(1000); + + expect(webSocketCtor).toHaveBeenCalledTimes(2); }); test('subscribes to events', () => { @@ -133,19 +125,12 @@ describe('WebSocket Service', () => { }); test('sends messages', async () => { - const mockWs = { - onopen: null as (() => void) | null, - onmessage: null as ((event: any) => void) | null, - onclose: null as (() => void) | null, - onerror: null as ((event: any) => void) | null, - readyState: WebSocket.OPEN, - send: jest.fn(), - close: jest.fn(), - }; - - jest.spyOn(global, 'WebSocket').mockImplementation(() => mockWs as any); + const mockWs = createMockSocket(WebSocket.OPEN); + installWebSocketMock(mockWs); - await ws.connect(); + const connectPromise = ws.connect(); + mockWs.onopen!(); + await connectPromise; ws.send('alert:created', { id: 'alert-1' }); @@ -155,21 +140,14 @@ describe('WebSocket Service', () => { }); test('checks connection status', async () => { - const mockWs = { - onopen: null as (() => void) | null, - onmessage: null as ((event: any) => void) | null, - onclose: null as (() => void) | null, - onerror: null as ((event: any) => void) | null, - readyState: WebSocket.OPEN, - send: jest.fn(), - close: jest.fn(), - }; - - jest.spyOn(global, 'WebSocket').mockImplementation(() => mockWs as any); + const mockWs = createMockSocket(WebSocket.OPEN); + installWebSocketMock(mockWs); expect(ws.isConnected()).toBe(false); - await ws.connect(); + const connectPromise = ws.connect(); + mockWs.onopen!(); + await connectPromise; expect(ws.isConnected()).toBe(true); }); diff --git a/web/src/services/api.ts b/web/src/services/api.ts index cb5d64c..b570314 100644 --- a/web/src/services/api.ts +++ b/web/src/services/api.ts @@ -38,10 +38,52 @@ class ApiService { ); } + private normalizeSecurityStatus(payload: Record): SecurityStatus { + return { + overallScore: (payload.overallScore ?? payload.overall_score ?? 0) as number, + activeThreats: (payload.activeThreats ?? payload.active_threats ?? 0) as number, + quarantinedContainers: (payload.quarantinedContainers ?? payload.quarantined_containers ?? 0) as number, + alertsNew: (payload.alertsNew ?? payload.alerts_new ?? 0) as number, + alertsAcknowledged: (payload.alertsAcknowledged ?? payload.alerts_acknowledged ?? 0) as number, + lastUpdated: (payload.lastUpdated ?? payload.last_updated ?? new Date().toISOString()) as string, + }; + } + + private normalizeThreatStatistics(payload: Record): ThreatStatistics { + return { + totalThreats: (payload.totalThreats ?? payload.total_threats ?? 0) as number, + bySeverity: (payload.bySeverity ?? payload.by_severity ?? {}) as ThreatStatistics['bySeverity'], + byType: (payload.byType ?? payload.by_type ?? {}) as Record, + trend: (payload.trend ?? 'stable') as ThreatStatistics['trend'], + }; + } + + private normalizeAlert(payload: Record): Alert { + return { + id: (payload.id ?? '') as string, + alertType: (payload.alertType ?? payload.alert_type ?? 'SystemEvent') as Alert['alertType'], + severity: (payload.severity ?? 'Info') as Alert['severity'], + message: (payload.message ?? '') as string, + status: (payload.status ?? 'New') as Alert['status'], + timestamp: (payload.timestamp ?? new Date().toISOString()) as string, + metadata: payload.metadata as Record | undefined, + }; + } + + private normalizeAlertStats(payload: Record): AlertStats { + return { + totalCount: (payload.totalCount ?? payload.total_count ?? 0) as number, + newCount: (payload.newCount ?? payload.new_count ?? 0) as number, + acknowledgedCount: (payload.acknowledgedCount ?? payload.acknowledged_count ?? 0) as number, + resolvedCount: (payload.resolvedCount ?? payload.resolved_count ?? 0) as number, + falsePositiveCount: (payload.falsePositiveCount ?? payload.false_positive_count ?? 0) as number, + }; + } + // Security Status async getSecurityStatus(): Promise { const response = await this.api.get('/security/status'); - return response.data; + return this.normalizeSecurityStatus(response.data as Record); } async getThreats(): Promise { @@ -51,7 +93,7 @@ class ApiService { async getThreatStatistics(): Promise { const response = await this.api.get('/threats/statistics'); - return response.data; + return this.normalizeThreatStatistics(response.data as Record); } // Alerts @@ -64,12 +106,12 @@ class ApiService { filter.status.forEach(s => params.append('status', s)); } const response = await this.api.get('/alerts', { params }); - return response.data; + return (response.data as Array>).map((alert) => this.normalizeAlert(alert)); } async getAlertStats(): Promise { const response = await this.api.get('/alerts/stats'); - return response.data; + return this.normalizeAlertStats(response.data as Record); } async acknowledgeAlert(alertId: string): Promise { diff --git a/web/src/setupTests.ts b/web/src/setupTests.ts index 68cddd9..5b7c924 100644 --- a/web/src/setupTests.ts +++ b/web/src/setupTests.ts @@ -21,5 +21,13 @@ class MockWebSocket { global.WebSocket = MockWebSocket as unknown as typeof WebSocket; +class MockResizeObserver { + observe = jest.fn(); + unobserve = jest.fn(); + disconnect = jest.fn(); +} + +global.ResizeObserver = MockResizeObserver as unknown as typeof ResizeObserver; + // Mock fetch global.fetch = jest.fn(); From 1c30685a21fb3640227031fdf5e077f12b4b2a4f Mon Sep 17 00:00:00 2001 From: vsilent Date: Sat, 4 Apr 2026 21:10:44 +0300 Subject: [PATCH 05/10] logs, containers in ui, ports, ws --- Cargo.toml | 2 + docker-compose.yml | 2 +- src/api/containers.rs | 26 +++- src/api/logs.rs | 62 +++++++- tests/api/websocket_test.rs | 141 ++++++++++++++++-- web/src/components/ContainerList.tsx | 17 ++- .../__tests__/ContainerList.test.tsx | 20 +++ web/src/services/__tests__/ports.test.ts | 16 ++ web/src/services/api.ts | 3 +- web/src/services/ports.ts | 10 ++ web/src/services/websocket.ts | 4 +- 11 files changed, 283 insertions(+), 20 deletions(-) create mode 100644 web/src/services/__tests__/ports.test.ts create mode 100644 web/src/services/ports.ts diff --git a/Cargo.toml b/Cargo.toml index 27ad71d..e2e6c0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,8 @@ ebpf = [] # Testing tokio-test = "0.4" tempfile = "3" +actix-test = "0.1" +awc = "3" # Benchmarking criterion = { version = "0.5", features = ["html_reports"] } diff --git a/docker-compose.yml b/docker-compose.yml index 289a2fe..34a821b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,7 +19,7 @@ services: echo "Starting Stackdog..." cargo run --bin stackdog ports: - - "${APP_PORT:-8080}:${APP_PORT:-8080}" + - "${APP_PORT:-5000}:${APP_PORT:-5000}" env_file: - .env environment: diff --git a/src/api/containers.rs b/src/api/containers.rs index 9e7ad77..6864c88 100644 --- a/src/api/containers.rs +++ b/src/api/containers.rs @@ -83,11 +83,17 @@ fn to_container_response( security: &crate::docker::containers::ContainerSecurityStatus, stats: Option<&ContainerStats>, ) -> ContainerResponse { + let effective_status = if security.security_state == "Quarantined" { + "Quarantined".to_string() + } else { + container.status.clone() + }; + ContainerResponse { id: container.id.clone(), name: container.name.clone(), image: container.image.clone(), - status: container.status.clone(), + status: effective_status, security_status: ApiContainerSecurityStatus { state: security.security_state.clone(), threats: security.threats, @@ -261,6 +267,24 @@ mod tests { assert_eq!(response.network_activity.blocked_connections, None); } + #[actix_rt::test] + async fn test_to_container_response_marks_quarantined_status_from_security_state() { + let response = to_container_response( + &sample_container(), + &crate::docker::containers::ContainerSecurityStatus { + container_id: "container-1".into(), + risk_score: 88, + threats: 3, + security_state: "Quarantined".into(), + }, + None, + ); + + assert_eq!(response.status, "Quarantined"); + assert_eq!(response.security_status.state, "Quarantined"); + assert!(response.network_activity.suspicious_activity); + } + #[actix_rt::test] async fn test_get_containers() { let pool = create_pool(":memory:").unwrap(); diff --git a/src/api/logs.rs b/src/api/logs.rs index 5468fa7..47d465a 100644 --- a/src/api/logs.rs +++ b/src/api/logs.rs @@ -38,7 +38,7 @@ pub async fn list_sources(pool: web::Data) -> impl Responder { /// /// GET /api/logs/sources/{path} pub async fn get_source(pool: web::Data, path: web::Path) -> impl Responder { - match log_sources::get_log_source_by_path(&pool, &path) { + match log_sources::get_log_source_by_path(&pool, &path.into_inner()) { Ok(Some(source)) => HttpResponse::Ok().json(source), Ok(None) => HttpResponse::NotFound().json(serde_json::json!({ "error": "Log source not found" @@ -77,7 +77,7 @@ pub async fn add_source( /// /// DELETE /api/logs/sources/{path} pub async fn delete_source(pool: web::Data, path: web::Path) -> impl Responder { - match log_sources::delete_log_source(&pool, &path) { + match log_sources::delete_log_source(&pool, &path.into_inner()) { Ok(_) => HttpResponse::NoContent().finish(), Err(e) => { log::error!("Failed to delete log source: {}", e); @@ -136,8 +136,8 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { web::scope("/api/logs") .route("/sources", web::get().to(list_sources)) .route("/sources", web::post().to(add_source)) - .route("/sources/{path}", web::get().to(get_source)) - .route("/sources/{path}", web::delete().to(delete_source)) + .route("/sources/{path:.*}", web::get().to(get_source)) + .route("/sources/{path:.*}", web::delete().to(delete_source)) .route("/summaries", web::get().to(list_summaries)), ); } @@ -236,6 +236,33 @@ mod tests { assert_eq!(resp.status(), 404); } + #[actix_rt::test] + async fn test_get_source_with_full_filesystem_path() { + let pool = setup_pool(); + let source = LogSource::new( + LogSourceType::CustomFile, + "/var/log/app.log".into(), + "App Log".into(), + ); + log_sources::upsert_log_source(&pool, &source).unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool)) + .configure(configure_routes), + ) + .await; + + let req = test::TestRequest::get() + .uri("/api/logs/sources//var/log/app.log") + .to_request(); + let resp = test::call_service(&app, req).await; + assert_eq!(resp.status(), 200); + + let body: serde_json::Value = test::read_body_json(resp).await; + assert_eq!(body["path_or_id"], "/var/log/app.log"); + } + #[actix_rt::test] async fn test_delete_source() { let pool = setup_pool(); @@ -262,6 +289,33 @@ mod tests { assert_eq!(resp.status(), 204); } + #[actix_rt::test] + async fn test_delete_source_with_full_filesystem_path() { + let pool = setup_pool(); + let source = LogSource::new( + LogSourceType::CustomFile, + "/var/log/app.log".into(), + "App Log".into(), + ); + log_sources::upsert_log_source(&pool, &source).unwrap(); + + let app = test::init_service( + App::new() + .app_data(web::Data::new(pool.clone())) + .configure(configure_routes), + ) + .await; + + let req = test::TestRequest::delete() + .uri("/api/logs/sources//var/log/app.log") + .to_request(); + let resp = test::call_service(&app, req).await; + assert_eq!(resp.status(), 204); + + let stored = log_sources::get_log_source_by_path(&pool, "/var/log/app.log").unwrap(); + assert!(stored.is_none()); + } + #[actix_rt::test] async fn test_list_summaries_empty() { let pool = setup_pool(); diff --git a/tests/api/websocket_test.rs b/tests/api/websocket_test.rs index 10bbbca..7a6d827 100644 --- a/tests/api/websocket_test.rs +++ b/tests/api/websocket_test.rs @@ -1,22 +1,145 @@ //! WebSocket API tests +use actix::Actor; +use actix_test::start; +use actix_web::{web, App}; +use awc::ws::Frame; +use chrono::Utc; +use futures_util::StreamExt; +use serde_json::Value; +use stackdog::alerting::alert::{AlertSeverity, AlertStatus, AlertType}; +use stackdog::api::websocket::{self, WebSocketHub}; +use stackdog::database::models::Alert; +use stackdog::database::{create_alert, create_pool, init_database}; + +async fn read_text_frame(framed: &mut S) -> Value +where + S: futures_util::Stream> + Unpin, +{ + loop { + match framed + .next() + .await + .expect("expected websocket frame") + .expect("valid websocket frame") + { + Frame::Text(bytes) => { + return serde_json::from_slice(&bytes).expect("valid websocket json"); + } + Frame::Ping(_) | Frame::Pong(_) => continue, + other => panic!("unexpected websocket frame: {other:?}"), + } + } +} + +fn sample_alert(id: &str) -> Alert { + Alert { + id: id.to_string(), + alert_type: AlertType::ThreatDetected, + severity: AlertSeverity::High, + message: format!("alert-{id}"), + status: AlertStatus::New, + timestamp: Utc::now().to_rfc3339(), + metadata: None, + } +} + #[cfg(test)] mod tests { + use super::*; + #[actix_rt::test] - async fn test_websocket_connection() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + async fn test_websocket_connection_receives_initial_stats_snapshot() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + create_alert(&pool, sample_alert("a1")).await.unwrap(); + + let hub = WebSocketHub::new().start(); + let pool_for_app = pool.clone(); + let hub_for_app = hub.clone(); + let server = start(move || { + App::new() + .app_data(web::Data::new(pool_for_app.clone())) + .app_data(web::Data::new(hub_for_app.clone())) + .configure(websocket::configure_routes) + }); + + let (_response, mut framed) = awc::Client::new() + .ws(server.url("/ws")) + .connect() + .await + .unwrap(); + + let message = read_text_frame(&mut framed).await; + assert_eq!(message["type"], "stats:updated"); + assert_eq!(message["payload"]["alerts_new"], 1); } #[actix_rt::test] - async fn test_websocket_subscribe() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + async fn test_websocket_receives_broadcast_events() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + let hub = WebSocketHub::new().start(); + let pool_for_app = pool.clone(); + let hub_for_app = hub.clone(); + let server = start(move || { + App::new() + .app_data(web::Data::new(pool_for_app.clone())) + .app_data(web::Data::new(hub_for_app.clone())) + .configure(websocket::configure_routes) + }); + + let (_response, mut framed) = awc::Client::new() + .ws(server.url("/ws")) + .connect() + .await + .unwrap(); + + let _initial = read_text_frame(&mut framed).await; + + websocket::broadcast_event( + &hub, + "alert:created", + serde_json::json!({ "id": "alert-1" }), + ) + .await; + + let message = read_text_frame(&mut framed).await; + assert_eq!(message["type"], "alert:created"); + assert_eq!(message["payload"]["id"], "alert-1"); } #[actix_rt::test] - async fn test_websocket_receive_events() { - // TODO: Implement when API is ready - assert!(true, "Test placeholder"); + async fn test_websocket_receives_broadcast_stats_updates() { + let pool = create_pool(":memory:").unwrap(); + init_database(&pool).unwrap(); + + let hub = WebSocketHub::new().start(); + let pool_for_app = pool.clone(); + let hub_for_app = hub.clone(); + let server = start(move || { + App::new() + .app_data(web::Data::new(pool_for_app.clone())) + .app_data(web::Data::new(hub_for_app.clone())) + .configure(websocket::configure_routes) + }); + + let (_response, mut framed) = awc::Client::new() + .ws(server.url("/ws")) + .connect() + .await + .unwrap(); + + let initial = read_text_frame(&mut framed).await; + assert_eq!(initial["type"], "stats:updated"); + assert_eq!(initial["payload"]["alerts_new"], 0); + + create_alert(&pool, sample_alert("a2")).await.unwrap(); + websocket::broadcast_stats(&hub, &pool).await.unwrap(); + + let updated = read_text_frame(&mut framed).await; + assert_eq!(updated["type"], "stats:updated"); + assert_eq!(updated["payload"]["alerts_new"], 1); } } diff --git a/web/src/components/ContainerList.tsx b/web/src/components/ContainerList.tsx index 39defae..42ff229 100644 --- a/web/src/components/ContainerList.tsx +++ b/web/src/components/ContainerList.tsx @@ -119,9 +119,17 @@ const ContainerList: React.FC = () => {
{containers.map((container) => (
+ {(() => { + const isQuarantined = + container.status === 'Quarantined' || container.securityStatus.state === 'Quarantined'; + + return ( + <>
{container.name}
- {container.status} + + {isQuarantined ? 'Quarantined' : container.status} +

Image: {container.image}

@@ -159,7 +167,7 @@ const ContainerList: React.FC = () => { > Details - {container.status === 'Running' && ( + {!isQuarantined && container.status === 'Running' && ( )} - {container.status === 'Quarantined' && ( + {isQuarantined && ( )}
+ + ); + })()}
))}
diff --git a/web/src/components/__tests__/ContainerList.test.tsx b/web/src/components/__tests__/ContainerList.test.tsx index 58caaf0..cebcd94 100644 --- a/web/src/components/__tests__/ContainerList.test.tsx +++ b/web/src/components/__tests__/ContainerList.test.tsx @@ -140,6 +140,26 @@ describe('ContainerList Component', () => { }); }); + test('shows release action when security state is quarantined', async () => { + const quarantinedBySecurityState = { + ...mockContainers[0], + status: 'Running' as const, + securityStatus: { + ...mockContainers[0].securityStatus, + state: 'Quarantined' as const, + }, + }; + + (apiService.getContainers as jest.Mock).mockResolvedValue([quarantinedBySecurityState]); + + render(); + + expect(await screen.findByText('web-server')).toBeInTheDocument(); + expect(screen.getAllByText('Quarantined').length).toBeGreaterThanOrEqual(2); + expect(screen.getByText('Release')).toBeInTheDocument(); + expect(screen.queryByText('Quarantine')).not.toBeInTheDocument(); + }); + test('filters by status', async () => { render(); diff --git a/web/src/services/__tests__/ports.test.ts b/web/src/services/__tests__/ports.test.ts new file mode 100644 index 0000000..8d60f74 --- /dev/null +++ b/web/src/services/__tests__/ports.test.ts @@ -0,0 +1,16 @@ +import { DEFAULT_API_PORT, resolveApiPort } from '../ports'; + +describe('port configuration', () => { + test('uses the backend default port when no frontend override is set', () => { + expect(DEFAULT_API_PORT).toBe('5000'); + expect(resolveApiPort({})).toBe('5000'); + }); + + test('prefers explicit frontend port overrides', () => { + expect(resolveApiPort({ REACT_APP_API_PORT: '7000', APP_PORT: '5000' })).toBe('7000'); + }); + + test('falls back to APP_PORT when frontend override is absent', () => { + expect(resolveApiPort({ APP_PORT: '6000' })).toBe('6000'); + }); +}); diff --git a/web/src/services/api.ts b/web/src/services/api.ts index b570314..c6b4dc7 100644 --- a/web/src/services/api.ts +++ b/web/src/services/api.ts @@ -2,6 +2,7 @@ import axios, { AxiosInstance } from 'axios'; import { SecurityStatus, Threat, ThreatStatistics } from '../types/security'; import { Alert, AlertStats, AlertFilter } from '../types/alerts'; import { Container, QuarantineRequest } from '../types/containers'; +import { resolveApiPort } from './ports'; type EnvLike = { REACT_APP_API_URL?: string; @@ -11,7 +12,7 @@ type EnvLike = { const env = ((globalThis as unknown as { __STACKDOG_ENV__?: EnvLike }).__STACKDOG_ENV__ ?? {}) as EnvLike; -const apiPort = env.REACT_APP_API_PORT || env.APP_PORT || '5555'; +const apiPort = resolveApiPort(env); const API_BASE_URL = env.REACT_APP_API_URL || `http://localhost:${apiPort}/api`; class ApiService { diff --git a/web/src/services/ports.ts b/web/src/services/ports.ts new file mode 100644 index 0000000..e36b378 --- /dev/null +++ b/web/src/services/ports.ts @@ -0,0 +1,10 @@ +export type PortEnvLike = { + APP_PORT?: string; + REACT_APP_API_PORT?: string; +}; + +export const DEFAULT_API_PORT = '5000'; + +export function resolveApiPort(env: PortEnvLike): string { + return env.REACT_APP_API_PORT || env.APP_PORT || DEFAULT_API_PORT; +} diff --git a/web/src/services/websocket.ts b/web/src/services/websocket.ts index 7513591..18c75fc 100644 --- a/web/src/services/websocket.ts +++ b/web/src/services/websocket.ts @@ -1,3 +1,5 @@ +import { resolveApiPort } from './ports'; + type WebSocketEvent = | 'threat:detected' | 'alert:created' @@ -31,7 +33,7 @@ export class WebSocketService { constructor(url?: string) { const env = ((globalThis as { __STACKDOG_ENV__?: EnvLike }).__STACKDOG_ENV__ ?? {}) as EnvLike; - const apiPort = env.REACT_APP_API_PORT || env.APP_PORT || '5555'; + const apiPort = resolveApiPort(env); this.url = url || env.REACT_APP_WS_URL || `ws://localhost:${apiPort}/ws`; } From 400a99ff9914df28179d69cda57dc7aa6db29cea Mon Sep 17 00:00:00 2001 From: vsilent Date: Sat, 4 Apr 2026 22:06:30 +0300 Subject: [PATCH 06/10] context.to_string() --- src/firewall/iptables.rs | 2 +- src/firewall/nftables.rs | 5 ++++- src/ip_ban/engine.rs | 3 +++ 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/firewall/iptables.rs b/src/firewall/iptables.rs index 544b202..7df60ed 100644 --- a/src/firewall/iptables.rs +++ b/src/firewall/iptables.rs @@ -49,7 +49,7 @@ impl IptablesBackend { let output = Command::new("iptables") .args(args) .output() - .context(context)?; + .context(context.to_string())?; if !output.status.success() { anyhow::bail!("{}", String::from_utf8_lossy(&output.stderr).trim()); diff --git a/src/firewall/nftables.rs b/src/firewall/nftables.rs index 7c7703d..495404a 100644 --- a/src/firewall/nftables.rs +++ b/src/firewall/nftables.rs @@ -70,7 +70,10 @@ pub struct NfTablesBackend { impl NfTablesBackend { fn run_nft(&self, args: &[&str], context: &str) -> Result<()> { - let output = Command::new("nft").args(args).output().context(context)?; + let output = Command::new("nft") + .args(args) + .output() + .context(context.to_string())?; if !output.status.success() { anyhow::bail!("{}", String::from_utf8_lossy(&output.stderr).trim()); diff --git a/src/ip_ban/engine.rs b/src/ip_ban/engine.rs index 0225ea8..4635d6b 100644 --- a/src/ip_ban/engine.rs +++ b/src/ip_ban/engine.rs @@ -10,6 +10,9 @@ use anyhow::Result; use chrono::{Duration, Utc}; use uuid::Uuid; +#[cfg(target_os = "linux")] +use crate::firewall::backend::FirewallBackend; + #[derive(Debug, Clone)] pub struct OffenseInput { pub ip_address: String, From 5ee12674cc9df3e5df9bd4fe6792af871b2ae126 Mon Sep 17 00:00:00 2001 From: vsilent Date: Sat, 4 Apr 2026 22:09:21 +0300 Subject: [PATCH 07/10] clippy fix --- src/database/events.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/database/events.rs b/src/database/events.rs index 91ed614..260865e 100644 --- a/src/database/events.rs +++ b/src/database/events.rs @@ -51,6 +51,10 @@ impl EventsDb { pub fn len(&self) -> usize { self.events.read().unwrap().len() } + + pub fn is_empty(&self) -> bool { + self.events.read().unwrap().is_empty() + } } impl Default for EventsDb { @@ -102,6 +106,8 @@ mod tests { #[test] fn test_events_db_filters_events_by_pid() { let db = EventsDb::new().unwrap(); + assert!(db.is_empty()); + db.insert(SecurityEvent::Syscall(SyscallEvent::new( 42, 1000, @@ -128,5 +134,6 @@ mod tests { assert_eq!(pid_events.len(), 1); assert_eq!(pid_events[0].pid(), Some(42)); assert_eq!(db.len(), 3); + assert!(!db.is_empty()); } } From 49d4488020dc0f993127558e6aea5120fd9f8c45 Mon Sep 17 00:00:00 2001 From: vsilent Date: Sat, 4 Apr 2026 22:23:33 +0300 Subject: [PATCH 08/10] clippy fix --- src/cli.rs | 127 ++++++++++++++++++++++++++-------------------------- src/main.rs | 48 +++++++------------- 2 files changed, 80 insertions(+), 95 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index c06de7c..db9c18a 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -3,7 +3,7 @@ //! Defines the command-line interface using clap derive macros. //! Supports `serve` (HTTP server) and `sniff` (log analysis) subcommands. -use clap::{Parser, Subcommand}; +use clap::{Args, Parser, Subcommand}; /// Stackdog Security — Docker & Linux server security platform #[derive(Parser, Debug)] @@ -20,67 +20,70 @@ pub enum Command { Serve, /// Sniff and analyze logs from Docker containers and system sources - Sniff { - /// Run a single scan/analysis pass, then exit - #[arg(long)] - once: bool, + Sniff(Box), +} + +#[derive(Args, Debug, Clone)] +pub struct SniffCommand { + /// Run a single scan/analysis pass, then exit + #[arg(long)] + pub once: bool, - /// Consume logs: archive to zstd, then purge originals to free disk - #[arg(long)] - consume: bool, + /// Consume logs: archive to zstd, then purge originals to free disk + #[arg(long)] + pub consume: bool, - /// Output directory for consumed logs - #[arg(long, default_value = "./stackdog-logs/")] - output: String, + /// Output directory for consumed logs + #[arg(long, default_value = "./stackdog-logs/")] + pub output: String, - /// Additional log file paths to watch (comma-separated) - #[arg(long)] - sources: Option, + /// Additional log file paths to watch (comma-separated) + #[arg(long)] + pub sources: Option, - /// Poll interval in seconds - #[arg(long, default_value = "30")] - interval: u64, + /// Poll interval in seconds + #[arg(long, default_value = "30")] + pub interval: u64, - /// AI provider: "openai", "ollama", or "candle" - #[arg(long)] - ai_provider: Option, + /// AI provider: "openai", "ollama", or "candle" + #[arg(long)] + pub ai_provider: Option, - /// AI model name (e.g. "gpt-4o-mini", "qwen2.5-coder:latest", "llama3") - #[arg(long)] - ai_model: Option, + /// AI model name (e.g. "gpt-4o-mini", "qwen2.5-coder:latest", "llama3") + #[arg(long)] + pub ai_model: Option, - /// AI API URL (e.g. "http://localhost:11434/v1" for Ollama) - #[arg(long)] - ai_api_url: Option, + /// AI API URL (e.g. "http://localhost:11434/v1" for Ollama) + #[arg(long)] + pub ai_api_url: Option, - /// Slack webhook URL for alert notifications - #[arg(long)] - slack_webhook: Option, + /// Slack webhook URL for alert notifications + #[arg(long)] + pub slack_webhook: Option, - /// Generic webhook URL for alert notifications - #[arg(long)] - webhook_url: Option, + /// Generic webhook URL for alert notifications + #[arg(long)] + pub webhook_url: Option, - /// SMTP host for email alert notifications - #[arg(long)] - smtp_host: Option, + /// SMTP host for email alert notifications + #[arg(long)] + pub smtp_host: Option, - /// SMTP port for email alert notifications - #[arg(long)] - smtp_port: Option, + /// SMTP port for email alert notifications + #[arg(long)] + pub smtp_port: Option, - /// SMTP username / sender address for email alert notifications - #[arg(long)] - smtp_user: Option, + /// SMTP username / sender address for email alert notifications + #[arg(long)] + pub smtp_user: Option, - /// SMTP password for email alert notifications - #[arg(long)] - smtp_password: Option, + /// SMTP password for email alert notifications + #[arg(long)] + pub smtp_password: Option, - /// Comma-separated email recipients for alert notifications - #[arg(long)] - email_recipients: Option, - }, + /// Comma-separated email recipients for alert notifications + #[arg(long)] + pub email_recipients: Option, } #[cfg(test)] @@ -107,7 +110,8 @@ mod tests { fn test_sniff_subcommand_defaults() { let cli = Cli::parse_from(["stackdog", "sniff"]); match cli.command { - Some(Command::Sniff { + Some(Command::Sniff(sniff)) => { + let SniffCommand { once, consume, output, @@ -123,7 +127,7 @@ mod tests { smtp_user, smtp_password, email_recipients, - }) => { + } = *sniff; assert!(!once); assert!(!consume); assert_eq!(output, "./stackdog-logs/"); @@ -148,7 +152,7 @@ mod tests { fn test_sniff_with_once_flag() { let cli = Cli::parse_from(["stackdog", "sniff", "--once"]); match cli.command { - Some(Command::Sniff { once, .. }) => assert!(once), + Some(Command::Sniff(sniff)) => assert!(sniff.once), _ => panic!("Expected Sniff command"), } } @@ -157,7 +161,7 @@ mod tests { fn test_sniff_with_consume_flag() { let cli = Cli::parse_from(["stackdog", "sniff", "--consume"]); match cli.command { - Some(Command::Sniff { consume, .. }) => assert!(consume), + Some(Command::Sniff(sniff)) => assert!(sniff.consume), _ => panic!("Expected Sniff command"), } } @@ -197,7 +201,8 @@ mod tests { "soc@example.com,oncall@example.com", ]); match cli.command { - Some(Command::Sniff { + Some(Command::Sniff(sniff)) => { + let SniffCommand { once, consume, output, @@ -213,7 +218,7 @@ mod tests { smtp_user, smtp_password, email_recipients, - }) => { + } = *sniff; assert!(once); assert!(consume); assert_eq!(output, "/tmp/logs/"); @@ -244,8 +249,8 @@ mod tests { fn test_sniff_with_candle_provider() { let cli = Cli::parse_from(["stackdog", "sniff", "--ai-provider", "candle"]); match cli.command { - Some(Command::Sniff { ai_provider, .. }) => { - assert_eq!(ai_provider.unwrap(), "candle"); + Some(Command::Sniff(sniff)) => { + assert_eq!(sniff.ai_provider.as_deref(), Some("candle")); } _ => panic!("Expected Sniff command"), } @@ -263,13 +268,9 @@ mod tests { "qwen2.5-coder:latest", ]); match cli.command { - Some(Command::Sniff { - ai_provider, - ai_model, - .. - }) => { - assert_eq!(ai_provider.unwrap(), "ollama"); - assert_eq!(ai_model.unwrap(), "qwen2.5-coder:latest"); + Some(Command::Sniff(sniff)) => { + assert_eq!(sniff.ai_provider.as_deref(), Some("ollama")); + assert_eq!(sniff.ai_model.as_deref(), Some("qwen2.5-coder:latest")); } _ => panic!("Expected Sniff command"), } diff --git a/src/main.rs b/src/main.rs index a9083c6..041c13d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -71,39 +71,23 @@ async fn main() -> io::Result<()> { info!("Architecture: {}", std::env::consts::ARCH); match cli.command { - Some(Command::Sniff { - once, - consume, - output, - sources, - interval, - ai_provider, - ai_model, - ai_api_url, - slack_webhook, - webhook_url, - smtp_host, - smtp_port, - smtp_user, - smtp_password, - email_recipients, - }) => { + Some(Command::Sniff(sniff)) => { let config = sniff::config::SniffConfig::from_env_and_args(sniff::config::SniffArgs { - once, - consume, - output: &output, - sources: sources.as_deref(), - interval, - ai_provider: ai_provider.as_deref(), - ai_model: ai_model.as_deref(), - ai_api_url: ai_api_url.as_deref(), - slack_webhook: slack_webhook.as_deref(), - webhook_url: webhook_url.as_deref(), - smtp_host: smtp_host.as_deref(), - smtp_port, - smtp_user: smtp_user.as_deref(), - smtp_password: smtp_password.as_deref(), - email_recipients: email_recipients.as_deref(), + once: sniff.once, + consume: sniff.consume, + output: &sniff.output, + sources: sniff.sources.as_deref(), + interval: sniff.interval, + ai_provider: sniff.ai_provider.as_deref(), + ai_model: sniff.ai_model.as_deref(), + ai_api_url: sniff.ai_api_url.as_deref(), + slack_webhook: sniff.slack_webhook.as_deref(), + webhook_url: sniff.webhook_url.as_deref(), + smtp_host: sniff.smtp_host.as_deref(), + smtp_port: sniff.smtp_port, + smtp_user: sniff.smtp_user.as_deref(), + smtp_password: sniff.smtp_password.as_deref(), + email_recipients: sniff.email_recipients.as_deref(), }); run_sniff(config).await } From d790554a2f1a955c53cdf1b9920e823fdf76524b Mon Sep 17 00:00:00 2001 From: vsilent Date: Sat, 4 Apr 2026 22:40:04 +0300 Subject: [PATCH 09/10] clippy fix --- src/cli.rs | 64 +++++++++++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index db9c18a..7b6e4fa 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -112,22 +112,22 @@ mod tests { match cli.command { Some(Command::Sniff(sniff)) => { let SniffCommand { - once, - consume, - output, - sources, - interval, - ai_provider, - ai_model, - ai_api_url, - slack_webhook, - webhook_url, - smtp_host, - smtp_port, - smtp_user, - smtp_password, - email_recipients, - } = *sniff; + once, + consume, + output, + sources, + interval, + ai_provider, + ai_model, + ai_api_url, + slack_webhook, + webhook_url, + smtp_host, + smtp_port, + smtp_user, + smtp_password, + email_recipients, + } = *sniff; assert!(!once); assert!(!consume); assert_eq!(output, "./stackdog-logs/"); @@ -203,22 +203,22 @@ mod tests { match cli.command { Some(Command::Sniff(sniff)) => { let SniffCommand { - once, - consume, - output, - sources, - interval, - ai_provider, - ai_model, - ai_api_url, - slack_webhook, - webhook_url, - smtp_host, - smtp_port, - smtp_user, - smtp_password, - email_recipients, - } = *sniff; + once, + consume, + output, + sources, + interval, + ai_provider, + ai_model, + ai_api_url, + slack_webhook, + webhook_url, + smtp_host, + smtp_port, + smtp_user, + smtp_password, + email_recipients, + } = *sniff; assert!(once); assert!(consume); assert_eq!(output, "/tmp/logs/"); From c316f94a61b2b7c95c65fb50040e1661847fe107 Mon Sep 17 00:00:00 2001 From: vsilent Date: Sun, 5 Apr 2026 15:37:45 +0300 Subject: [PATCH 10/10] ip_ban::engine::tests --- ebpf/.cargo/config.toml | 7 +++++++ src/firewall/response.rs | 12 +++++++---- src/ip_ban/engine.rs | 44 ++++++++++++++++++++++++++++++++++++---- src/sniff/mod.rs | 28 ++++++++++++++++++++++++- 4 files changed, 82 insertions(+), 9 deletions(-) create mode 100644 ebpf/.cargo/config.toml diff --git a/ebpf/.cargo/config.toml b/ebpf/.cargo/config.toml new file mode 100644 index 0000000..7f0e2a7 --- /dev/null +++ b/ebpf/.cargo/config.toml @@ -0,0 +1,7 @@ +[build] +target = ["bpfel-unknown-none"] + +[target.bpfel-unknown-none] + +[unstable] +build-std = ["core"] diff --git a/src/firewall/response.rs b/src/firewall/response.rs index 2626eef..6a32d75 100644 --- a/src/firewall/response.rs +++ b/src/firewall/response.rs @@ -33,6 +33,13 @@ pub struct ResponseAction { } impl ResponseAction { + fn quarantine_container_error(container_id: &str) -> anyhow::Error { + anyhow::anyhow!( + "Docker-based container quarantine flow is required for {} because firewall backends do not implement container-specific quarantine. Use the Docker/API quarantine path instead.", + container_id + ) + } + fn preferred_backend() -> Result> { if let Ok(mut backend) = NfTablesBackend::new() { backend.initialize()?; @@ -105,10 +112,7 @@ impl ResponseAction { let backend = Self::preferred_backend()?; backend.block_port(*port) } - ResponseType::QuarantineContainer(id) => { - let backend = Self::preferred_backend()?; - backend.block_container(id) - } + ResponseType::QuarantineContainer(id) => Err(Self::quarantine_container_error(id)), ResponseType::KillProcess(pid) => { let output = Command::new("kill") .args(["-TERM", &pid.to_string()]) diff --git a/src/ip_ban/engine.rs b/src/ip_ban/engine.rs index 4635d6b..60dff26 100644 --- a/src/ip_ban/engine.rs +++ b/src/ip_ban/engine.rs @@ -181,6 +181,19 @@ mod tests { use crate::database::repositories::offenses::OffenseStatus; use crate::database::{create_pool, init_database, list_alerts, AlertFilter}; use chrono::Utc; + #[cfg(target_os = "linux")] + use std::process::Command; + + #[cfg(target_os = "linux")] + fn running_as_root() -> bool { + Command::new("id") + .arg("-u") + .output() + .ok() + .and_then(|output| String::from_utf8(output.stdout).ok()) + .map(|stdout| stdout.trim() == "0") + .unwrap_or(false) + } #[actix_rt::test] async fn test_extract_ip_candidates() { @@ -227,10 +240,21 @@ mod tests { source_path: Some("/var/log/auth.log".into()), sample_line: Some("Failed password from 192.0.2.44".into()), }) - .await - .unwrap(); + .await; assert!(!first); + #[cfg(target_os = "linux")] + if !running_as_root() { + let error = second.unwrap_err().to_string(); + assert!( + error.contains("Operation not permitted") + || error.contains("Permission denied") + || error.contains("you must be root") + ); + return; + } + + let second = second.unwrap(); assert!(second); assert!(active_block_for_ip(&pool, "192.0.2.44").unwrap().is_some()); } @@ -260,8 +284,20 @@ mod tests { source_path: Some("/var/log/auth.log".into()), sample_line: Some("Failed password from 192.0.2.55".into()), }) - .await - .unwrap(); + .await; + + #[cfg(target_os = "linux")] + if !running_as_root() { + let error = blocked.unwrap_err().to_string(); + assert!( + error.contains("Operation not permitted") + || error.contains("Permission denied") + || error.contains("you must be root") + ); + return; + } + + let blocked = blocked.unwrap(); assert!(blocked); let released = engine.unban_expired().await.unwrap(); diff --git a/src/sniff/mod.rs b/src/sniff/mod.rs index 4b5b1ea..8a3d07d 100644 --- a/src/sniff/mod.rs +++ b/src/sniff/mod.rs @@ -300,6 +300,19 @@ mod tests { use crate::ip_ban::{IpBanConfig, IpBanEngine}; use crate::sniff::analyzer::{AnomalySeverity, LogAnomaly, LogSummary}; use chrono::Utc; + #[cfg(target_os = "linux")] + use std::process::Command; + + #[cfg(target_os = "linux")] + fn running_as_root() -> bool { + Command::new("id") + .arg("-u") + .output() + .ok() + .and_then(|output| String::from_utf8(output.stdout).ok()) + .map(|stdout| stdout.trim() == "0") + .unwrap_or(false) + } fn memory_sniff_config() -> SniffConfig { let mut config = SniffConfig::from_env_and_args(config::SniffArgs { @@ -473,7 +486,20 @@ mod tests { ); orchestrator.apply_ip_ban(&summary, &engine).await.unwrap(); - orchestrator.apply_ip_ban(&summary, &engine).await.unwrap(); + let second_attempt = orchestrator.apply_ip_ban(&summary, &engine).await; + + #[cfg(target_os = "linux")] + if !running_as_root() { + let error = second_attempt.unwrap_err().to_string(); + assert!( + error.contains("Operation not permitted") + || error.contains("Permission denied") + || error.contains("you must be root") + ); + return; + } + + second_attempt.unwrap(); assert!(active_block_for_ip(&orchestrator.pool, "192.0.2.81") .unwrap()