diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..4f5d32c --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,18 @@ +#!/bin/sh +set -e + +echo "šŸ• Stackdog pre-commit: running cargo fmt..." +cargo fmt --all -- --check || { + echo "āŒ cargo fmt failed. Run 'cargo fmt --all' to fix." + exit 1 +} + +echo "šŸ• Stackdog pre-commit: running cargo clippy..." +cargo clippy 2>&1 +CLIPPY_EXIT=$? +if [ $CLIPPY_EXIT -ne 0 ]; then + echo "āŒ cargo clippy failed to compile. Fix errors before committing." + exit 1 +fi + +echo "āœ… Pre-commit checks passed." diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 52cc06a..8e2246f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -2,154 +2,84 @@ name: Docker CICD on: push: - branches: - - master - - testing + branches: [main, dev] pull_request: - branches: - - master + branches: [main, dev] jobs: - cicd-linux-docker: - name: Cargo and npm build - #runs-on: ubuntu-latest - runs-on: [self-hosted, linux] + build: + name: Build & Test + runs-on: ubuntu-latest steps: - - name: Checkout sources - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Install stable toolchain - uses: actions-rs/toolchain@v1 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable with: - toolchain: stable - profile: minimal - override: true components: rustfmt, clippy - - name: Cache cargo registry - uses: actions/cache@v4 - with: - path: ~/.cargo/registry - key: docker-registry-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - docker-registry- - docker- - - - name: Cache cargo index - uses: actions/cache@v4 - with: - path: ~/.cargo/git - key: docker-index-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - docker-index- - docker- + - name: Cache Rust dependencies + uses: Swatinem/rust-cache@v2 - name: Generate Secret Key - run: | - head -c16 /dev/urandom > src/secret.key + run: head -c16 /dev/urandom > src/secret.key - - name: Cache cargo build - uses: actions/cache@v4 - with: - path: target - key: docker-build-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - docker-build- - docker- - - - name: Cargo check - uses: actions-rs/cargo@v1 - with: - command: check + - name: Check + run: cargo check - - name: Cargo test - if: ${{ always() }} - uses: actions-rs/cargo@v1 - with: - command: test + - name: Format check + run: cargo fmt --all -- --check - - name: Rustfmt - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true - components: rustfmt - command: fmt - args: --all -- --check - - - name: Rustfmt - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - profile: minimal - override: true - components: clippy - command: clippy - args: -- -D warnings - - - name: Run cargo build - uses: actions-rs/cargo@v1 - with: - command: build - args: --release + - name: Clippy + run: cargo clippy -- -D warnings + + - name: Test + run: cargo test - - name: npm install, build, and test + - name: Build release + run: cargo build --release + + - name: Build frontend working-directory: ./web run: | - npm install + if [ -f package-lock.json ]; then + npm ci + else + npm install + fi npm run build - # npm test - - name: Archive production artifacts - uses: actions/upload-artifact@v4 - with: - name: dist-without-markdown - path: | - web/dist - !web/dist/**/*.md - -# - name: Archive code coverage results -# uses: actions/upload-artifact@v4 -# with: -# name: code-coverage-report -# path: output/test/code-coverage.html - - name: Display structure of downloaded files - run: ls -R web/dist - - - name: Copy app files and zip + - name: Package app run: | mkdir -p app/stackdog/dist - cp target/release/stackdog app/stackdog - cp -a web/dist/. app/stackdog + cp target/release/stackdog app/stackdog/ + cp -a web/dist/. app/stackdog/ cp docker/prod/Dockerfile app/Dockerfile - cd app - touch .env - tar -czvf ../app.tar.gz . - cd .. + touch app/.env + tar -czf app.tar.gz -C app . - - name: Upload app archive for Docker job + - name: Upload build artifact uses: actions/upload-artifact@v4 with: - name: artifact-linux-docker + name: app-archive path: app.tar.gz + retention-days: 1 - cicd-docker: - name: CICD Docker - #runs-on: ubuntu-latest - runs-on: [self-hosted, linux] - needs: cicd-linux-docker + docker: + name: Docker Build & Push + runs-on: ubuntu-latest + needs: build steps: - - name: Download app archive + - name: Download build artifact uses: actions/download-artifact@v4 with: - name: artifact-linux-docker + name: app-archive - - name: Extract app archive - run: tar -zxvf app.tar.gz + - name: Extract archive + run: tar -xzf app.tar.gz - - name: Display structure of downloaded files - run: ls -R + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@v3 @@ -157,8 +87,9 @@ jobs: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - - name: Docker build and publish + - name: Build and push uses: docker/build-push-action@v6 with: + context: . push: true tags: trydirect/stackdog:latest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f15bf4c..65e9855 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -18,9 +18,9 @@ jobs: strategy: matrix: include: - - target: x86_64-unknown-linux-gnu + - target: x86_64-unknown-linux-musl artifact: stackdog-linux-x86_64 - - target: aarch64-unknown-linux-gnu + - target: aarch64-unknown-linux-musl artifact: stackdog-linux-aarch64 steps: diff --git a/Cargo.toml b/Cargo.toml index cf82f97..7f450b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,7 @@ r2d2 = "0.8" bollard = "0.16" # HTTP client (for LLM API) -reqwest = { version = "0.12", features = ["json", "blocking"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "blocking", "rustls-tls"] } # Compression zstd = "0.13" diff --git a/VERSION.md b/VERSION.md index 8a9ecc2..341cf11 100644 --- a/VERSION.md +++ b/VERSION.md @@ -1 +1 @@ -0.0.1 \ No newline at end of file +0.2.0 \ No newline at end of file diff --git a/docker/prod/Dockerfile b/docker/prod/Dockerfile index 2d43826..9276155 100644 --- a/docker/prod/Dockerfile +++ b/docker/prod/Dockerfile @@ -1,20 +1,17 @@ # base image -FROM debian:buster-slim +FROM debian:bookworm-slim -# create app directory -RUN mkdir app WORKDIR /app -# install libpq -RUN apt-get update; \ - apt-get install --no-install-recommends -y libpq-dev; \ +# install ca-certificates for HTTPS requests +RUN apt-get update && \ + apt-get install --no-install-recommends -y ca-certificates && \ rm -rf /var/lib/apt/lists/* # copy binary and configuration files COPY ./stackdog . COPY ./.env . -# expose port + EXPOSE 5000 -# run the binary ENTRYPOINT ["/app/stackdog"] diff --git a/ebpf/src/lib.rs b/ebpf/src/lib.rs index c391873..dd1321f 100644 --- a/ebpf/src/lib.rs +++ b/ebpf/src/lib.rs @@ -4,5 +4,5 @@ #![no_std] -pub mod syscalls; pub mod maps; +pub mod syscalls; diff --git a/examples/usage_examples.rs b/examples/usage_examples.rs index 297acdb..53689c1 100644 --- a/examples/usage_examples.rs +++ b/examples/usage_examples.rs @@ -3,18 +3,24 @@ //! This file demonstrates how to use Stackdog Security in your Rust applications. use stackdog::{ - // Events - SyscallEvent, SyscallType, SecurityEvent, - + // Alerting + AlertManager, + AlertType, + PatternMatch, // Rules & Detection RuleEngine, - SignatureDatabase, ThreatCategory, - SignatureMatcher, PatternMatch, - ThreatScorer, ScoringConfig, + ScoringConfig, + SecurityEvent, + + SignatureDatabase, + SignatureMatcher, StatsTracker, - - // Alerting - AlertManager, AlertType, + + // Events + SyscallEvent, + SyscallType, + ThreatCategory, + ThreatScorer, }; use stackdog::alerting::{AlertDeduplicator, DedupConfig}; @@ -23,25 +29,25 @@ use chrono::Utc; fn main() { println!("šŸ• Stackdog Security - Usage Examples\n"); - + // Example 1: Create and validate events example_events(); - + // Example 2: Rule engine example_rule_engine(); - + // Example 3: Signature detection example_signature_detection(); - + // Example 4: Threat scoring example_threat_scoring(); - + // Example 5: Alert management example_alerting(); - + // Example 6: Pattern matching example_pattern_matching(); - + println!("\nāœ… All examples completed!"); } @@ -49,18 +55,20 @@ fn main() { fn example_events() { println!("šŸ“‹ Example 1: Creating Security Events"); println!("----------------------------------------"); - + // Create a syscall event let execve_event = SyscallEvent::new( - 1234, // PID - 1000, // UID + 1234, // PID + 1000, // UID SyscallType::Execve, Utc::now(), ); - - println!(" Created execve event: PID={}, UID={}", - execve_event.pid, execve_event.uid); - + + println!( + " Created execve event: PID={}, UID={}", + execve_event.pid, execve_event.uid + ); + // Create event with builder pattern let connect_event = SyscallEvent::builder() .pid(5678) @@ -69,14 +77,16 @@ fn example_events() { .container_id(Some("abc123".to_string())) .comm(Some("curl".to_string())) .build(); - - println!(" Created connect event: PID={}, Command={:?}", - connect_event.pid, connect_event.comm); - + + println!( + " Created connect event: PID={}, Command={:?}", + connect_event.pid, connect_event.comm + ); + // Convert to SecurityEvent - let security_event: SecurityEvent = execve_event.into(); + let _security_event: SecurityEvent = execve_event.into(); println!(" Converted to SecurityEvent variant"); - + println!(" āœ“ Events created successfully\n"); } @@ -84,39 +94,43 @@ fn example_events() { fn example_rule_engine() { println!("šŸ“‹ Example 2: Rule Engine"); println!("----------------------------------------"); - + // Create rule engine let mut engine = RuleEngine::new(); - + // Add built-in rules use stackdog::rules::builtin::{ - SyscallBlocklistRule, ProcessExecutionRule, NetworkConnectionRule, + NetworkConnectionRule, ProcessExecutionRule, SyscallBlocklistRule, }; - + // Block dangerous syscalls - engine.register_rule(Box::new(SyscallBlocklistRule::new( - vec![SyscallType::Ptrace, SyscallType::Setuid] - ))); - + engine.register_rule(Box::new(SyscallBlocklistRule::new(vec![ + SyscallType::Ptrace, + SyscallType::Setuid, + ]))); + // Monitor process execution engine.register_rule(Box::new(ProcessExecutionRule::new())); - + // Monitor network connections engine.register_rule(Box::new(NetworkConnectionRule::new())); - + println!(" Registered {} rules", engine.rule_count()); - + // Create test event let event = SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Ptrace, Utc::now(), + 1234, + 1000, + SyscallType::Ptrace, + Utc::now(), )); - + // Evaluate rules let results = engine.evaluate(&event); let matches = results.iter().filter(|r| r.is_match()).count(); - + println!(" Evaluated event: {} rules matched", matches); - + // Get detailed results let detailed = engine.evaluate_detailed(&event); for result in detailed { @@ -124,7 +138,7 @@ fn example_rule_engine() { println!(" āœ“ Rule matched: {}", result.rule_name()); } } - + println!(" āœ“ Rule engine working\n"); } @@ -132,31 +146,38 @@ fn example_rule_engine() { fn example_signature_detection() { println!("šŸ“‹ Example 3: Signature Detection"); println!("----------------------------------------"); - + // Create signature database let db = SignatureDatabase::new(); println!(" Loaded {} built-in signatures", db.signature_count()); - + // Get signatures by category let crypto_sigs = db.get_signatures_by_category(&ThreatCategory::CryptoMiner); println!(" Crypto miner signatures: {}", crypto_sigs.len()); - + let escape_sigs = db.get_signatures_by_category(&ThreatCategory::ContainerEscape); println!(" Container escape signatures: {}", escape_sigs.len()); - + // Detect threats in event let event = SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Ptrace, Utc::now(), + 1234, + 1000, + SyscallType::Ptrace, + Utc::now(), )); - + let matches = db.detect(&event); println!(" Detected {} matching signatures", matches.len()); - + for sig in matches { - println!(" āš ļø {} (Severity: {}, Category: {})", - sig.name(), sig.severity(), sig.category()); + println!( + " āš ļø {} (Severity: {}, Category: {})", + sig.name(), + sig.severity(), + sig.category() + ); } - + println!(" āœ“ Signature detection working\n"); } @@ -164,44 +185,60 @@ fn example_signature_detection() { fn example_threat_scoring() { println!("šŸ“‹ Example 4: Threat Scoring"); println!("----------------------------------------"); - + // Create scorer with custom config let config = ScoringConfig::default() .with_base_score(50) .with_multiplier(1.2); - + let scorer = ThreatScorer::with_config(config); - + // Create test events let events = vec![ SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Execve, Utc::now(), + 1234, + 1000, + SyscallType::Execve, + Utc::now(), )), SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Ptrace, Utc::now(), + 1234, + 1000, + SyscallType::Ptrace, + Utc::now(), )), SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Mount, Utc::now(), + 1234, + 1000, + SyscallType::Mount, + Utc::now(), )), ]; - + // Calculate scores println!(" Calculating threat scores:"); for (i, event) in events.iter().enumerate() { let score = scorer.calculate_score(event); - println!(" Event {}: Score={} (Severity={})", - i + 1, score.value(), score.severity()); - + println!( + " Event {}: Score={} (Severity={})", + i + 1, + score.value(), + score.severity() + ); + if score.is_high_or_higher() { println!(" āš ļø High threat detected!"); } } - + // Cumulative scoring let cumulative = scorer.calculate_cumulative_score(&events); - println!(" Cumulative score: {} (Severity={})", - cumulative.value(), cumulative.severity()); - + println!( + " Cumulative score: {} (Severity={})", + cumulative.value(), + cumulative.severity() + ); + println!(" āœ“ Threat scoring working\n"); } @@ -209,44 +246,51 @@ fn example_threat_scoring() { fn example_alerting() { println!("šŸ“‹ Example 5: Alert Management"); println!("----------------------------------------"); - + // Create alert manager let mut alert_manager = AlertManager::new().expect("Failed to create manager"); - + // Generate alerts - let alert = alert_manager.generate_alert( - AlertType::ThreatDetected, - stackdog::rules::result::Severity::High, - "Suspicious ptrace activity detected".to_string(), - None, - ).expect("Failed to generate alert"); - + let alert = alert_manager + .generate_alert( + AlertType::ThreatDetected, + stackdog::rules::result::Severity::High, + "Suspicious ptrace activity detected".to_string(), + None, + ) + .expect("Failed to generate alert"); + println!(" Generated alert: ID={}", alert.id()); println!(" Alert count: {}", alert_manager.alert_count()); - + // Acknowledge alert let alert_id = alert.id().to_string(); - alert_manager.acknowledge_alert(&alert_id).expect("Failed to acknowledge"); + alert_manager + .acknowledge_alert(&alert_id) + .expect("Failed to acknowledge"); println!(" Alert acknowledged"); - + // Get statistics let stats = alert_manager.get_stats(); - println!(" Statistics: Total={}, New={}, Acknowledged={}, Resolved={}", - stats.total_count, stats.new_count, - stats.acknowledged_count, stats.resolved_count); - + println!( + " Statistics: Total={}, New={}, Acknowledged={}, Resolved={}", + stats.total_count, stats.new_count, stats.acknowledged_count, stats.resolved_count + ); + // Create deduplicator let config = DedupConfig::default() .with_window_seconds(300) .with_aggregation(true); - + let mut dedup = AlertDeduplicator::new(config); - + // Check for duplicates let result = dedup.check(&alert); - println!(" Deduplication: is_duplicate={}, count={}", - result.is_duplicate, result.count); - + println!( + " Deduplication: is_duplicate={}, count={}", + result.is_duplicate, result.count + ); + println!(" āœ“ Alert management working\n"); } @@ -254,56 +298,70 @@ fn example_alerting() { fn example_pattern_matching() { println!("šŸ“‹ Example 6: Pattern Matching"); println!("----------------------------------------"); - + // Create signature matcher let mut matcher = SignatureMatcher::new(); - + // Add pattern: execve followed by ptrace (suspicious) let pattern = PatternMatch::new() .with_syscall(SyscallType::Execve) .then_syscall(SyscallType::Ptrace) .within_seconds(60) .with_description("Suspicious process debugging pattern"); - + matcher.add_pattern(pattern); println!(" Added pattern: execve -> ptrace (within 60s)"); - + // Create event sequence let events = vec![ SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Execve, Utc::now(), + 1234, + 1000, + SyscallType::Execve, + Utc::now(), )), SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Ptrace, Utc::now(), + 1234, + 1000, + SyscallType::Ptrace, + Utc::now(), )), ]; - + // Match pattern let result = matcher.match_sequence(&events); - println!(" Pattern match: {} (confidence: {:.2})", - if result.is_match() { "MATCH" } else { "NO MATCH" }, - result.confidence()); - + println!( + " Pattern match: {} (confidence: {:.2})", + if result.is_match() { + "MATCH" + } else { + "NO MATCH" + }, + result.confidence() + ); + if result.is_match() { println!(" āš ļø Suspicious pattern detected!"); for sig in result.matches() { println!(" Matched: {}", sig); } } - + // Detection statistics let mut stats_tracker = StatsTracker::new().expect("Failed to create tracker"); - + for event in &events { let match_result = matcher.match_single(event); stats_tracker.record_event(event, match_result.is_match()); } - + let stats = stats_tracker.stats(); - println!(" Detection stats: Events={}, Matches={}, Rate={:.1}%", - stats.events_processed(), - stats.signatures_matched(), - stats.detection_rate() * 100.0); - + println!( + " Detection stats: Events={}, Matches={}, Rate={:.1}%", + stats.events_processed(), + stats.signatures_matched(), + stats.detection_rate() * 100.0 + ); + println!(" āœ“ Pattern matching working\n"); } diff --git a/src/alerting/alert.rs b/src/alerting/alert.rs index 61033eb..76ef10e 100644 --- a/src/alerting/alert.rs +++ b/src/alerting/alert.rs @@ -91,11 +91,7 @@ pub struct Alert { impl Alert { /// Create a new alert - pub fn new( - alert_type: AlertType, - severity: AlertSeverity, - message: String, - ) -> Self { + pub fn new(alert_type: AlertType, severity: AlertSeverity, message: String) -> Self { Self { id: Uuid::new_v4().to_string(), alert_type, @@ -109,64 +105,64 @@ impl Alert { resolution_note: None, } } - + /// Get alert ID pub fn id(&self) -> &str { &self.id } - + /// Get alert type pub fn alert_type(&self) -> AlertType { self.alert_type.clone() } - + /// Get severity pub fn severity(&self) -> AlertSeverity { self.severity } - + /// Get message pub fn message(&self) -> &str { &self.message } - + /// Get status pub fn status(&self) -> AlertStatus { self.status } - + /// Get timestamp pub fn timestamp(&self) -> DateTime { self.timestamp } - + /// Get source event pub fn source_event(&self) -> Option<&SecurityEvent> { self.source_event.as_ref() } - + /// Set source event pub fn set_source_event(&mut self, event: SecurityEvent) { self.source_event = Some(event); } - + /// Get metadata pub fn metadata(&self) -> &std::collections::HashMap { &self.metadata } - + /// Add metadata pub fn add_metadata(&mut self, key: String, value: String) { self.metadata.insert(key, value); } - + /// Acknowledge the alert pub fn acknowledge(&mut self) { if self.status == AlertStatus::New { self.status = AlertStatus::Acknowledged; } } - + /// Resolve the alert pub fn resolve(&mut self) { if self.status == AlertStatus::Acknowledged || self.status == AlertStatus::New { @@ -174,22 +170,22 @@ impl Alert { self.resolved_at = Some(Utc::now()); } } - + /// Set resolution note pub fn set_resolution_note(&mut self, note: String) { self.resolution_note = Some(note); } - + /// Calculate fingerprint for deduplication pub fn fingerprint(&self) -> String { use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; - + let mut hasher = DefaultHasher::new(); self.alert_type.hash(&mut hasher); self.severity.hash(&mut hasher); self.message.hash(&mut hasher); - + format!("{:x}", hasher.finish()) } } @@ -199,10 +195,7 @@ impl std::fmt::Display for Alert { write!( f, "[{}] {} - {} ({})", - self.severity, - self.alert_type, - self.message, - self.status + self.severity, self.alert_type, self.message, self.status ) } } @@ -210,17 +203,17 @@ impl std::fmt::Display for Alert { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_alert_type_display() { assert_eq!(format!("{}", AlertType::ThreatDetected), "ThreatDetected"); } - + #[test] fn test_alert_severity_display() { assert_eq!(format!("{}", AlertSeverity::High), "High"); } - + #[test] fn test_alert_status_display() { assert_eq!(format!("{}", AlertStatus::New), "New"); diff --git a/src/alerting/dedup.rs b/src/alerting/dedup.rs index 532edf4..9724f4d 100644 --- a/src/alerting/dedup.rs +++ b/src/alerting/dedup.rs @@ -16,43 +16,43 @@ pub struct DedupConfig { } impl DedupConfig { - /// Create default config - pub fn default() -> Self { + /// Create a new config with given values + pub fn new(enabled: bool, window_seconds: u64, aggregation: bool) -> Self { Self { - enabled: true, - window_seconds: 300, // 5 minutes - aggregation: true, + enabled, + window_seconds, + aggregation, } } - + /// Set enabled pub fn with_enabled(mut self, enabled: bool) -> Self { self.enabled = enabled; self } - + /// Set window seconds pub fn with_window_seconds(mut self, seconds: u64) -> Self { self.window_seconds = seconds; self } - + /// Set aggregation pub fn with_aggregation(mut self, aggregation: bool) -> Self { self.aggregation = aggregation; self } - + /// Check if enabled pub fn enabled(&self) -> bool { self.enabled } - + /// Get window seconds pub fn window_seconds(&self) -> u64 { self.window_seconds } - + /// Check if aggregation enabled pub fn aggregation_enabled(&self) -> bool { self.aggregation @@ -61,7 +61,7 @@ impl DedupConfig { impl Default for DedupConfig { fn default() -> Self { - Self::default() + Self::new(true, 300, true) } } @@ -74,7 +74,7 @@ impl Fingerprint { pub fn new(value: String) -> Self { Self(value) } - + /// Get value pub fn value(&self) -> &str { &self.0 @@ -124,21 +124,21 @@ impl AlertDeduplicator { stats: DedupStats::default(), } } - + /// Calculate fingerprint for alert pub fn calculate_fingerprint(&self, alert: &Alert) -> Fingerprint { Fingerprint::new(alert.fingerprint()) } - + /// Check if alert is duplicate pub fn is_duplicate(&mut self, alert: &Alert) -> bool { if !self.config.enabled { return false; } - + let fingerprint = self.calculate_fingerprint(alert); let now = Utc::now(); - + if let Some(entry) = self.fingerprints.get(&fingerprint) { // Check if within window let elapsed = now - entry.last_seen; @@ -146,7 +146,7 @@ impl AlertDeduplicator { return true; } } - + // Not a duplicate or window expired self.fingerprints.insert( fingerprint, @@ -156,14 +156,14 @@ impl AlertDeduplicator { count: 1, }, ); - + false } - + /// Check alert and return result with count pub fn check(&mut self, alert: &Alert) -> DedupResult { self.stats.total_checked += 1; - + if !self.config.enabled { return DedupResult { is_duplicate: false, @@ -171,19 +171,19 @@ impl AlertDeduplicator { first_seen: Utc::now(), }; } - + let fingerprint = self.calculate_fingerprint(alert); let now = Utc::now(); - + if let Some(entry) = self.fingerprints.get_mut(&fingerprint) { let elapsed = now - entry.last_seen; - + if elapsed.num_seconds() as u64 <= self.config.window_seconds { // Duplicate within window entry.count += 1; entry.last_seen = now; self.stats.duplicates_found += 1; - + return DedupResult { is_duplicate: true, count: entry.count, @@ -208,14 +208,14 @@ impl AlertDeduplicator { }, ); } - + DedupResult { is_duplicate: false, count: 1, first_seen: now, } } - + /// Get statistics pub fn get_stats(&self) -> DedupStatsPublic { DedupStatsPublic { @@ -223,12 +223,12 @@ impl AlertDeduplicator { duplicates_found: self.stats.duplicates_found, } } - + /// Clear old fingerprints pub fn clear_expired(&mut self) { let now = Utc::now(); let window = self.config.window_seconds; - + self.fingerprints.retain(|_, entry| { let elapsed = now - entry.last_seen; elapsed.num_seconds() as u64 <= window @@ -246,14 +246,14 @@ pub struct DedupStatsPublic { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_dedup_config_default() { let config = DedupConfig::default(); assert!(config.enabled()); assert_eq!(config.window_seconds(), 300); } - + #[test] fn test_fingerprint_display() { let fp = Fingerprint::new("test".to_string()); diff --git a/src/alerting/manager.rs b/src/alerting/manager.rs index c51e2d0..6b2ea53 100644 --- a/src/alerting/manager.rs +++ b/src/alerting/manager.rs @@ -3,7 +3,6 @@ //! Manages alert generation, storage, and lifecycle use anyhow::Result; -use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -34,7 +33,7 @@ impl AlertManager { stats: Arc::new(RwLock::new(AlertStats::default())), }) } - + /// Generate an alert pub fn generate_alert( &mut self, @@ -43,41 +42,37 @@ impl AlertManager { message: String, source_event: Option, ) -> Result { - let mut alert = Alert::new( - alert_type, - severity_to_alert_severity(severity), - message, - ); - + let mut alert = Alert::new(alert_type, severity_to_alert_severity(severity), message); + if let Some(event) = source_event { alert.set_source_event(event); } - + // Store alert let alert_id = alert.id().to_string(); { let mut alerts = self.alerts.write().unwrap(); alerts.insert(alert_id.clone(), alert.clone()); } - + // Update stats self.update_stats_new(); - + Ok(alert) } - + /// Get alert by ID pub fn get_alert(&self, alert_id: &str) -> Option { let alerts = self.alerts.read().unwrap(); alerts.get(alert_id).cloned() } - + /// Get all alerts pub fn get_all_alerts(&self) -> Vec { let alerts = self.alerts.read().unwrap(); alerts.values().cloned().collect() } - + /// Get alerts by severity pub fn get_alerts_by_severity(&self, severity: AlertSeverity) -> Vec { let alerts = self.alerts.read().unwrap(); @@ -87,7 +82,7 @@ impl AlertManager { .cloned() .collect() } - + /// Get alerts by status pub fn get_alerts_by_status(&self, status: AlertStatus) -> Vec { let alerts = self.alerts.read().unwrap(); @@ -97,11 +92,11 @@ impl AlertManager { .cloned() .collect() } - + /// Acknowledge an alert pub fn acknowledge_alert(&mut self, alert_id: &str) -> Result<()> { let mut alerts = self.alerts.write().unwrap(); - + if let Some(alert) = alerts.get_mut(alert_id) { alert.acknowledge(); self.update_stats_ack(); @@ -110,11 +105,11 @@ impl AlertManager { anyhow::bail!("Alert not found: {}", alert_id) } } - + /// Resolve an alert pub fn resolve_alert(&mut self, alert_id: &str, note: String) -> Result<()> { let mut alerts = self.alerts.write().unwrap(); - + if let Some(alert) = alerts.get_mut(alert_id) { alert.resolve(); alert.set_resolution_note(note); @@ -124,24 +119,24 @@ impl AlertManager { anyhow::bail!("Alert not found: {}", alert_id) } } - + /// Get alert count pub fn alert_count(&self) -> usize { let alerts = self.alerts.read().unwrap(); alerts.len() } - + /// Get statistics pub fn get_stats(&self) -> AlertStats { - let stats = self.stats.read().unwrap(); - + let _stats = self.stats.read().unwrap(); + // Calculate current counts from alerts let alerts = self.alerts.read().unwrap(); let mut new_count = 0; let mut ack_count = 0; let mut resolved_count = 0; let mut fp_count = 0; - + for alert in alerts.values() { match alert.status() { AlertStatus::New => new_count += 1, @@ -150,7 +145,7 @@ impl AlertManager { AlertStatus::FalsePositive => fp_count += 1, } } - + AlertStats { total_count: alerts.len() as u64, new_count, @@ -159,24 +154,24 @@ impl AlertManager { false_positive_count: fp_count, } } - + /// Clear resolved alerts pub fn clear_resolved_alerts(&mut self) -> usize { let mut alerts = self.alerts.write().unwrap(); let initial_count = alerts.len(); - + alerts.retain(|_, alert| alert.status() != AlertStatus::Resolved); - + initial_count - alerts.len() } - + /// Update stats for new alert fn update_stats_new(&self) { let mut stats = self.stats.write().unwrap(); stats.total_count += 1; stats.new_count += 1; } - + /// Update stats for acknowledgment fn update_stats_ack(&self) { let mut stats = self.stats.write().unwrap(); @@ -185,7 +180,7 @@ impl AlertManager { stats.acknowledged_count += 1; } } - + /// Update stats for resolution fn update_stats_resolve(&self) { let mut stats = self.stats.write().unwrap(); @@ -219,24 +214,24 @@ fn severity_to_alert_severity(severity: Severity) -> AlertSeverity { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_manager_creation() { let manager = AlertManager::new(); assert!(manager.is_ok()); } - + #[test] fn test_alert_generation() { let mut manager = AlertManager::new().expect("Failed to create manager"); - + let alert = manager.generate_alert( AlertType::ThreatDetected, Severity::High, "Test".to_string(), None, ); - + assert!(alert.is_ok()); assert_eq!(manager.alert_count(), 1); } diff --git a/src/alerting/mod.rs b/src/alerting/mod.rs index 594eb7e..32231f2 100644 --- a/src/alerting/mod.rs +++ b/src/alerting/mod.rs @@ -3,8 +3,8 @@ //! Alert generation, management, and notifications pub mod alert; -pub mod manager; pub mod dedup; +pub mod manager; pub mod notifications; /// Marker struct for module tests @@ -12,6 +12,6 @@ pub struct AlertingMarker; // Re-export commonly used types pub use alert::{Alert, AlertSeverity, AlertStatus, AlertType}; +pub use dedup::{AlertDeduplicator, DedupConfig, DedupResult, Fingerprint}; pub use manager::{AlertManager, AlertStats}; -pub use dedup::{AlertDeduplicator, DedupConfig, Fingerprint, DedupResult}; pub use notifications::{NotificationChannel, NotificationConfig, NotificationResult}; diff --git a/src/alerting/notifications.rs b/src/alerting/notifications.rs index d35d7e0..d4ba3e5 100644 --- a/src/alerting/notifications.rs +++ b/src/alerting/notifications.rs @@ -3,7 +3,6 @@ //! Notification channels for alert delivery use anyhow::Result; -use chrono::{DateTime, Utc}; use crate::alerting::alert::{Alert, AlertSeverity}; @@ -13,10 +12,10 @@ pub struct NotificationConfig { slack_webhook: Option, smtp_host: Option, smtp_port: Option, - smtp_user: Option, - smtp_password: Option, + _smtp_user: Option, + _smtp_password: Option, webhook_url: Option, - email_recipients: Vec, + _email_recipients: Vec, } impl NotificationConfig { @@ -26,52 +25,52 @@ impl NotificationConfig { slack_webhook: None, smtp_host: None, smtp_port: None, - smtp_user: None, - smtp_password: None, + _smtp_user: None, + _smtp_password: None, webhook_url: None, - email_recipients: Vec::new(), + _email_recipients: Vec::new(), } } - + /// Set Slack webhook pub fn with_slack_webhook(mut self, url: String) -> Self { self.slack_webhook = Some(url); self } - + /// Set SMTP host pub fn with_smtp_host(mut self, host: String) -> Self { self.smtp_host = Some(host); self } - + /// Set SMTP port pub fn with_smtp_port(mut self, port: u16) -> Self { self.smtp_port = Some(port); self } - + /// Set webhook URL pub fn with_webhook_url(mut self, url: String) -> Self { self.webhook_url = Some(url); self } - + /// Get Slack webhook pub fn slack_webhook(&self) -> Option<&str> { self.slack_webhook.as_deref() } - + /// Get SMTP host pub fn smtp_host(&self) -> Option<&str> { self.smtp_host.as_deref() } - + /// Get SMTP port pub fn smtp_port(&self) -> Option { self.smtp_port } - + /// Get webhook URL pub fn webhook_url(&self) -> Option<&str> { self.webhook_url.as_deref() @@ -97,7 +96,7 @@ impl NotificationChannel { NotificationChannel::Webhook => self.send_webhook(alert, _config), } } - + /// Send to console fn send_console(&self, alert: &Alert) -> Result { println!( @@ -107,10 +106,10 @@ impl NotificationChannel { alert.alert_type(), alert.message() ); - + Ok(NotificationResult::Success("sent to console".to_string())) } - + /// Send to Slack via incoming webhook fn send_slack(&self, alert: &Alert, config: &NotificationConfig) -> Result { if let Some(webhook_url) = config.slack_webhook() { @@ -134,20 +133,28 @@ impl NotificationChannel { let status = resp.status(); let body = resp.text().unwrap_or_default(); log::warn!("Slack API returned {}: {}", status, body); - Ok(NotificationResult::Failure(format!("Slack returned {}: {}", status, body))) + Ok(NotificationResult::Failure(format!( + "Slack returned {}: {}", + status, body + ))) } } Err(e) => { log::warn!("Failed to send Slack notification: {}", e); - Ok(NotificationResult::Failure(format!("Slack request failed: {}", e))) + Ok(NotificationResult::Failure(format!( + "Slack request failed: {}", + e + ))) } } } else { log::debug!("Slack webhook not configured, skipping"); - Ok(NotificationResult::Failure("Slack webhook not configured".to_string())) + Ok(NotificationResult::Failure( + "Slack webhook not configured".to_string(), + )) } } - + /// Send via email fn send_email(&self, alert: &Alert, config: &NotificationConfig) -> Result { // In production, this would send SMTP email @@ -156,19 +163,27 @@ impl NotificationChannel { log::info!("Would send email: {}", alert.message()); Ok(NotificationResult::Success("sent via email".to_string())) } else { - Ok(NotificationResult::Failure("SMTP not configured".to_string())) + Ok(NotificationResult::Failure( + "SMTP not configured".to_string(), + )) } } - + /// Send to webhook - fn send_webhook(&self, alert: &Alert, config: &NotificationConfig) -> Result { + fn send_webhook( + &self, + alert: &Alert, + config: &NotificationConfig, + ) -> Result { // In production, this would make HTTP POST // For now, just log if config.webhook_url().is_some() { log::info!("Would send to webhook: {}", alert.message()); Ok(NotificationResult::Success("sent to webhook".to_string())) } else { - Ok(NotificationResult::Failure("Webhook URL not configured".to_string())) + Ok(NotificationResult::Failure( + "Webhook URL not configured".to_string(), + )) } } } @@ -209,10 +224,7 @@ pub fn route_by_severity(severity: AlertSeverity) -> Vec { ] } AlertSeverity::Medium => { - vec![ - NotificationChannel::Console, - NotificationChannel::Slack, - ] + vec![NotificationChannel::Console, NotificationChannel::Slack] } AlertSeverity::Low => { vec![NotificationChannel::Console] @@ -248,7 +260,8 @@ pub fn build_slack_message(alert: &Alert) -> String { {"title": "Time", "value": alert.timestamp().to_rfc3339(), "short": true} ] }] - }).to_string() + }) + .to_string() } /// Build webhook payload @@ -272,7 +285,7 @@ pub fn build_webhook_payload(alert: &Alert) -> String { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_console_notification() { let channel = NotificationChannel::Console; @@ -281,22 +294,22 @@ mod tests { AlertSeverity::High, "Test".to_string(), ); - + let result = channel.send(&alert, &NotificationConfig::default()); assert!(result.is_ok()); } - + #[test] fn test_severity_to_slack_color() { assert_eq!(severity_to_slack_color(AlertSeverity::Critical), "#FF0000"); assert_eq!(severity_to_slack_color(AlertSeverity::High), "#FF8C00"); } - + #[test] fn test_route_by_severity() { let critical_routes = route_by_severity(AlertSeverity::Critical); assert!(critical_routes.len() >= 3); - + let info_routes = route_by_severity(AlertSeverity::Info); assert_eq!(info_routes.len(), 1); } diff --git a/src/api/alerts.rs b/src/api/alerts.rs index 44227ca..5a80e37 100644 --- a/src/api/alerts.rs +++ b/src/api/alerts.rs @@ -1,17 +1,11 @@ //! Alerts API endpoints -use actix_web::{web, HttpResponse, Responder}; -use serde::Deserialize; use crate::database::{ - DbPool, - list_alerts as db_list_alerts, - get_alert_stats as db_get_alert_stats, - update_alert_status, - create_sample_alert, - AlertFilter, + create_sample_alert, get_alert_stats as db_get_alert_stats, list_alerts as db_list_alerts, + update_alert_status, AlertFilter, DbPool, }; -use uuid::Uuid; -use chrono::Utc; +use actix_web::{web, HttpResponse, Responder}; +use serde::Deserialize; /// Query parameters for alert filtering #[derive(Debug, Deserialize)] @@ -21,17 +15,14 @@ pub struct AlertQuery { } /// Get all alerts -/// +/// /// GET /api/alerts -pub async fn get_alerts( - pool: web::Data, - query: web::Query, -) -> impl Responder { +pub async fn get_alerts(pool: web::Data, query: web::Query) -> impl Responder { let filter = AlertFilter { severity: query.severity.clone(), status: query.status.clone(), }; - + match db_list_alerts(&pool, filter).await { Ok(alerts) => HttpResponse::Ok().json(alerts), Err(e) => { @@ -44,7 +35,7 @@ pub async fn get_alerts( } /// Get alert statistics -/// +/// /// GET /api/alerts/stats pub async fn get_alert_stats(pool: web::Data) -> impl Responder { match db_get_alert_stats(&pool).await { @@ -68,14 +59,11 @@ pub async fn get_alert_stats(pool: web::Data) -> impl Responder { } /// Acknowledge an alert -/// +/// /// POST /api/alerts/:id/acknowledge -pub async fn acknowledge_alert( - pool: web::Data, - path: web::Path, -) -> impl Responder { +pub async fn acknowledge_alert(pool: web::Data, path: web::Path) -> impl Responder { let alert_id = path.into_inner(); - + match update_alert_status(&pool, &alert_id, "Acknowledged").await { Ok(()) => { log::info!("Acknowledged alert: {}", alert_id); @@ -94,7 +82,7 @@ pub async fn acknowledge_alert( } /// Resolve an alert -/// +/// /// POST /api/alerts/:id/resolve #[derive(Debug, Deserialize)] pub struct ResolveRequest { @@ -108,7 +96,7 @@ pub async fn resolve_alert( ) -> impl Responder { let alert_id = path.into_inner(); let _note = body.note.clone().unwrap_or_default(); - + match update_alert_status(&pool, &alert_id, "Resolved").await { Ok(()) => { log::info!("Resolved alert {}: {}", alert_id, _note); @@ -129,16 +117,16 @@ pub async fn resolve_alert( /// Seed database with sample alerts (for testing) pub async fn seed_sample_alerts(pool: web::Data) -> impl Responder { use crate::database::create_alert; - + let mut created = Vec::new(); - + for i in 0..5 { let alert = create_sample_alert(); if create_alert(&pool, alert).await.is_ok() { created.push(i); } } - + HttpResponse::Ok().json(serde_json::json!({ "created": created.len(), "message": "Sample alerts created" @@ -153,26 +141,24 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { .route("/stats", web::get().to(get_alert_stats)) .route("/{id}/acknowledge", web::post().to(acknowledge_alert)) .route("/{id}/resolve", web::post().to(resolve_alert)) - .route("/seed", web::post().to(seed_sample_alerts)) // For testing + .route("/seed", web::post().to(seed_sample_alerts)), // For testing ); } #[cfg(test)] mod tests { use super::*; - use actix_web::{test, App}; use crate::database::create_pool; + use actix_web::{test, App}; #[actix_rt::test] async fn test_get_alerts_empty() { let pool = create_pool(":memory:").unwrap(); + crate::database::init_database(&pool).unwrap(); let pool_data = web::Data::new(pool); - - let app = test::init_service( - App::new() - .app_data(pool_data) - .configure(configure_routes) - ).await; + + let app = + test::init_service(App::new().app_data(pool_data).configure(configure_routes)).await; let req = test::TestRequest::get().uri("/api/alerts").to_request(); let resp = test::call_service(&app, req).await; diff --git a/src/api/containers.rs b/src/api/containers.rs index 886821e..85d76c2 100644 --- a/src/api/containers.rs +++ b/src/api/containers.rs @@ -1,11 +1,10 @@ //! Containers API endpoints -use actix_web::{web, HttpResponse, Responder}; -use serde::Deserialize; use crate::database::DbPool; -use crate::docker::containers::ContainerManager; use crate::docker::client::ContainerInfo; -use crate::database::models::ContainerCache; +use crate::docker::containers::ContainerManager; +use actix_web::{web, HttpResponse, Responder}; +use serde::Deserialize; /// Quarantine request #[derive(Debug, Deserialize)] @@ -14,7 +13,7 @@ pub struct QuarantineRequest { } /// Get all containers -/// +/// /// GET /api/containers pub async fn get_containers(pool: web::Data) -> impl Responder { // Create container manager @@ -23,53 +22,54 @@ pub async fn get_containers(pool: web::Data) -> impl Responder { Err(e) => { log::error!("Failed to create container manager: {}", e); // Return mock data if Docker not available - return HttpResponse::Ok().json(vec![ - serde_json::json!({ - "id": "mock-container-1", - "name": "web-server", - "image": "nginx:latest", - "status": "Running", - "security_status": { - "state": "Secure", - "threats": 0, - "vulnerabilities": 0 - }, - "risk_score": 10, - "network_activity": { - "inbound_connections": 5, - "outbound_connections": 3, - "blocked_connections": 0, - "suspicious_activity": false - } - }) - ]); + return HttpResponse::Ok().json(vec![serde_json::json!({ + "id": "mock-container-1", + "name": "web-server", + "image": "nginx:latest", + "status": "Running", + "security_status": { + "state": "Secure", + "threats": 0, + "vulnerabilities": 0 + }, + "risk_score": 10, + "network_activity": { + "inbound_connections": 5, + "outbound_connections": 3, + "blocked_connections": 0, + "suspicious_activity": false + } + })]); } }; - + match manager.list_containers().await { Ok(containers) => { // Convert to API response format - let response: Vec = containers.iter().map(|c: &ContainerInfo| { - serde_json::json!({ - "id": c.id, - "name": c.name, - "image": c.image, - "status": c.status, - "security_status": { - "state": "Secure", - "threats": 0, - "vulnerabilities": 0 - }, - "risk_score": 0, - "network_activity": { - "inbound_connections": 0, - "outbound_connections": 0, - "blocked_connections": 0, - "suspicious_activity": false - } + let response: Vec = containers + .iter() + .map(|c: &ContainerInfo| { + serde_json::json!({ + "id": c.id, + "name": c.name, + "image": c.image, + "status": c.status, + "security_status": { + "state": "Secure", + "threats": 0, + "vulnerabilities": 0 + }, + "risk_score": 0, + "network_activity": { + "inbound_connections": 0, + "outbound_connections": 0, + "blocked_connections": 0, + "suspicious_activity": false + } + }) }) - }).collect(); - + .collect(); + HttpResponse::Ok().json(response) } Err(e) => { @@ -82,7 +82,7 @@ pub async fn get_containers(pool: web::Data) -> impl Responder { } /// Quarantine a container -/// +/// /// POST /api/containers/:id/quarantine pub async fn quarantine_container( pool: web::Data, @@ -91,7 +91,7 @@ pub async fn quarantine_container( ) -> impl Responder { let container_id = path.into_inner(); let reason = body.into_inner().reason; - + let manager = match ContainerManager::new(pool.get_ref().clone()).await { Ok(m) => m, Err(e) => { @@ -101,7 +101,7 @@ pub async fn quarantine_container( })); } }; - + match manager.quarantine_container(&container_id, &reason).await { Ok(()) => HttpResponse::Ok().json(serde_json::json!({ "success": true, @@ -117,14 +117,11 @@ pub async fn quarantine_container( } /// Release a container from quarantine -/// +/// /// POST /api/containers/:id/release -pub async fn release_container( - pool: web::Data, - path: web::Path, -) -> impl Responder { +pub async fn release_container(pool: web::Data, path: web::Path) -> impl Responder { let container_id = path.into_inner(); - + let manager = match ContainerManager::new(pool.get_ref().clone()).await { Ok(m) => m, Err(e) => { @@ -134,7 +131,7 @@ pub async fn release_container( })); } }; - + match manager.release_container(&container_id).await { Ok(()) => HttpResponse::Ok().json(serde_json::json!({ "success": true, @@ -155,27 +152,24 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { web::scope("/api/containers") .route("", web::get().to(get_containers)) .route("/{id}/quarantine", web::post().to(quarantine_container)) - .route("/{id}/release", web::post().to(release_container)) + .route("/{id}/release", web::post().to(release_container)), ); } #[cfg(test)] mod tests { use super::*; - use actix_web::{test, App}; use crate::database::{create_pool, init_database}; + use actix_web::{test, App}; #[actix_rt::test] async fn test_get_containers() { let pool = create_pool(":memory:").unwrap(); init_database(&pool).unwrap(); let pool_data = web::Data::new(pool); - - let app = test::init_service( - App::new() - .app_data(pool_data) - .configure(configure_routes) - ).await; + + let app = + test::init_service(App::new().app_data(pool_data).configure(configure_routes)).await; let req = test::TestRequest::get().uri("/api/containers").to_request(); let resp = test::call_service(&app, req).await; diff --git a/src/api/logs.rs b/src/api/logs.rs index 9963c33..5468fa7 100644 --- a/src/api/logs.rs +++ b/src/api/logs.rs @@ -1,10 +1,10 @@ //! Log sources and summaries API endpoints -use actix_web::{web, HttpResponse, Responder}; -use serde::Deserialize; use crate::database::connection::DbPool; use crate::database::repositories::log_sources; use crate::sniff::discovery::{LogSource, LogSourceType}; +use actix_web::{web, HttpResponse, Responder}; +use serde::Deserialize; /// Query parameters for summary filtering #[derive(Debug, Deserialize)] @@ -102,7 +102,9 @@ pub async fn list_summaries( Ok(sources) => { let mut all_summaries = Vec::new(); for source in &sources { - if let Ok(summaries) = log_sources::list_summaries_for_source(&pool, &source.path_or_id) { + if let Ok(summaries) = + log_sources::list_summaries_for_source(&pool, &source.path_or_id) + { all_summaries.extend(summaries); } } @@ -136,15 +138,15 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { .route("/sources", web::post().to(add_source)) .route("/sources/{path}", web::get().to(get_source)) .route("/sources/{path}", web::delete().to(delete_source)) - .route("/summaries", web::get().to(list_summaries)) + .route("/summaries", web::get().to(list_summaries)), ); } #[cfg(test)] mod tests { use super::*; - use actix_web::{test, App}; use crate::database::connection::{create_pool, init_database}; + use actix_web::{test, App}; fn setup_pool() -> DbPool { let pool = create_pool(":memory:").unwrap(); @@ -158,10 +160,13 @@ mod tests { let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; - let req = test::TestRequest::get().uri("/api/logs/sources").to_request(); + let req = test::TestRequest::get() + .uri("/api/logs/sources") + .to_request(); let resp = test::call_service(&app, req).await; assert_eq!(resp.status(), 200); } @@ -172,8 +177,9 @@ mod tests { let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; let body = serde_json::json!({ "path": "/var/log/test.log", "name": "Test Log" }); let req = test::TestRequest::post() @@ -190,8 +196,9 @@ mod tests { let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; // Add a source let body = serde_json::json!({ "path": "/var/log/app.log" }); @@ -202,7 +209,9 @@ mod tests { test::call_service(&app, req).await; // List sources - let req = test::TestRequest::get().uri("/api/logs/sources").to_request(); + let req = test::TestRequest::get() + .uri("/api/logs/sources") + .to_request(); let resp = test::call_service(&app, req).await; assert_eq!(resp.status(), 200); @@ -216,10 +225,13 @@ mod tests { let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; - let req = test::TestRequest::get().uri("/api/logs/sources/nonexistent").to_request(); + let req = test::TestRequest::get() + .uri("/api/logs/sources/nonexistent") + .to_request(); let resp = test::call_service(&app, req).await; assert_eq!(resp.status(), 404); } @@ -229,14 +241,19 @@ mod tests { let pool = setup_pool(); // Add source directly via repository (avoids route path issues) - let source = LogSource::new(LogSourceType::CustomFile, "test-delete.log".into(), "Test Delete".into()); + let source = LogSource::new( + LogSourceType::CustomFile, + "test-delete.log".into(), + "Test Delete".into(), + ); log_sources::upsert_log_source(&pool, &source).unwrap(); let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; let req = test::TestRequest::delete() .uri("/api/logs/sources/test-delete.log") @@ -251,10 +268,13 @@ mod tests { let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; - let req = test::TestRequest::get().uri("/api/logs/summaries").to_request(); + let req = test::TestRequest::get() + .uri("/api/logs/summaries") + .to_request(); let resp = test::call_service(&app, req).await; assert_eq!(resp.status(), 200); } @@ -265,8 +285,9 @@ mod tests { let app = test::init_service( App::new() .app_data(web::Data::new(pool)) - .configure(configure_routes) - ).await; + .configure(configure_routes), + ) + .await; let req = test::TestRequest::get() .uri("/api/logs/summaries?source_id=test-source") diff --git a/src/api/mod.rs b/src/api/mod.rs index 6120aab..56ab962 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,23 +2,23 @@ //! //! REST API and WebSocket endpoints -pub mod security; pub mod alerts; pub mod containers; +pub mod logs; +pub mod security; pub mod threats; pub mod websocket; -pub mod logs; /// Marker struct for module tests pub struct ApiMarker; // Re-export route configurators -pub use security::configure_routes as configure_security_routes; pub use alerts::configure_routes as configure_alerts_routes; pub use containers::configure_routes as configure_containers_routes; +pub use logs::configure_routes as configure_logs_routes; +pub use security::configure_routes as configure_security_routes; pub use threats::configure_routes as configure_threats_routes; pub use websocket::configure_routes as configure_websocket_routes; -pub use logs::configure_routes as configure_logs_routes; /// Configure all API routes pub fn configure_all_routes(cfg: &mut actix_web::web::ServiceConfig) { diff --git a/src/api/security.rs b/src/api/security.rs index 7d7201e..1c97f15 100644 --- a/src/api/security.rs +++ b/src/api/security.rs @@ -1,10 +1,10 @@ //! Security API endpoints +use crate::models::api::security::SecurityStatusResponse; use actix_web::{web, HttpResponse, Responder}; -use stackdog::models::api::security::SecurityStatusResponse; /// Get overall security status -/// +/// /// GET /api/security/status pub async fn get_security_status() -> impl Responder { let status = SecurityStatusResponse::new(); @@ -13,10 +13,7 @@ pub async fn get_security_status() -> impl Responder { /// Configure security routes pub fn configure_routes(cfg: &mut web::ServiceConfig) { - cfg.service( - web::scope("/api/security") - .route("/status", web::get().to(get_security_status)) - ); + cfg.service(web::scope("/api/security").route("/status", web::get().to(get_security_status))); } #[cfg(test)] @@ -26,11 +23,11 @@ mod tests { #[actix_rt::test] async fn test_get_security_status() { - let app = test::init_service( - App::new().configure(configure_routes) - ).await; + let app = test::init_service(App::new().configure(configure_routes)).await; - let req = test::TestRequest::get().uri("/api/security/status").to_request(); + let req = test::TestRequest::get() + .uri("/api/security/status") + .to_request(); let resp = test::call_service(&app, req).await; assert!(resp.status().is_success()); diff --git a/src/api/threats.rs b/src/api/threats.rs index 6c5c36c..a9a5886 100644 --- a/src/api/threats.rs +++ b/src/api/threats.rs @@ -1,31 +1,29 @@ //! Threats API endpoints +use crate::models::api::threats::{ThreatResponse, ThreatStatisticsResponse}; use actix_web::{web, HttpResponse, Responder}; use std::collections::HashMap; -use stackdog::models::api::threats::{ThreatResponse, ThreatStatisticsResponse}; /// Get all threats -/// +/// /// GET /api/threats pub async fn get_threats() -> impl Responder { // TODO: Fetch from database when implemented - let threats = vec![ - ThreatResponse { - id: "threat-1".to_string(), - r#type: "CryptoMiner".to_string(), - severity: "High".to_string(), - score: 85, - source: "container-1".to_string(), - timestamp: chrono::Utc::now().to_rfc3339(), - status: "New".to_string(), - }, - ]; - + let threats = vec![ThreatResponse { + id: "threat-1".to_string(), + r#type: "CryptoMiner".to_string(), + severity: "High".to_string(), + score: 85, + source: "container-1".to_string(), + timestamp: chrono::Utc::now().to_rfc3339(), + status: "New".to_string(), + }]; + HttpResponse::Ok().json(threats) } /// Get threat statistics -/// +/// /// GET /api/threats/statistics pub async fn get_threat_statistics() -> impl Responder { let mut by_severity = HashMap::new(); @@ -34,19 +32,19 @@ pub async fn get_threat_statistics() -> impl Responder { by_severity.insert("Medium".to_string(), 3); by_severity.insert("High".to_string(), 3); by_severity.insert("Critical".to_string(), 1); - + let mut by_type = HashMap::new(); by_type.insert("CryptoMiner".to_string(), 3); by_type.insert("ContainerEscape".to_string(), 2); by_type.insert("NetworkScanner".to_string(), 5); - + let stats = ThreatStatisticsResponse { total_threats: 10, by_severity, by_type, trend: "stable".to_string(), }; - + HttpResponse::Ok().json(stats) } @@ -55,7 +53,7 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { cfg.service( web::scope("/api/threats") .route("", web::get().to(get_threats)) - .route("/statistics", web::get().to(get_threat_statistics)) + .route("/statistics", web::get().to(get_threat_statistics)), ); } @@ -66,9 +64,7 @@ mod tests { #[actix_rt::test] async fn test_get_threats() { - let app = test::init_service( - App::new().configure(configure_routes) - ).await; + let app = test::init_service(App::new().configure(configure_routes)).await; let req = test::TestRequest::get().uri("/api/threats").to_request(); let resp = test::call_service(&app, req).await; @@ -78,11 +74,11 @@ mod tests { #[actix_rt::test] async fn test_get_threat_statistics() { - let app = test::init_service( - App::new().configure(configure_routes) - ).await; + let app = test::init_service(App::new().configure(configure_routes)).await; - let req = test::TestRequest::get().uri("/api/threats/statistics").to_request(); + let req = test::TestRequest::get() + .uri("/api/threats/statistics") + .to_request(); let resp = test::call_service(&app, req).await; assert!(resp.status().is_success()); diff --git a/src/api/websocket.rs b/src/api/websocket.rs index dba6e92..106fe05 100644 --- a/src/api/websocket.rs +++ b/src/api/websocket.rs @@ -1,23 +1,24 @@ //! WebSocket handler for real-time updates -//! +//! //! Note: Full WebSocket implementation requires additional setup. //! This is a placeholder that returns 426 Upgrade Required. -//! +//! //! TODO: Implement proper WebSocket support with: //! - actix-web-actors with proper Actor trait implementation //! - Or use tokio-tungstenite for lower-level WebSocket handling -use actix_web::{web, Error, HttpRequest, HttpResponse, http::StatusCode}; +use actix_web::{http::StatusCode, web, Error, HttpRequest, HttpResponse}; use log::info; /// WebSocket endpoint handler (placeholder) -/// +/// /// Returns 426 Upgrade Required to indicate WebSocket is not yet fully implemented -pub async fn websocket_handler( - req: HttpRequest, -) -> Result { - info!("WebSocket connection attempt from: {:?}", req.connection_info().peer_addr()); - +pub async fn websocket_handler(req: HttpRequest) -> Result { + info!( + "WebSocket connection attempt from: {:?}", + req.connection_info().peer_addr() + ); + // Return upgrade required response // Client should retry with proper WebSocket upgrade headers Ok(HttpResponse::build(StatusCode::SWITCHING_PROTOCOLS) @@ -37,9 +38,7 @@ mod tests { #[actix_rt::test] async fn test_websocket_endpoint_exists() { - let app = test::init_service( - App::new().configure(configure_routes) - ).await; + let app = test::init_service(App::new().configure(configure_routes)).await; let req = test::TestRequest::get().uri("/ws").to_request(); let resp = test::call_service(&app, req).await; diff --git a/src/cli.rs b/src/cli.rs index ea26fcc..9ff6579 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -67,7 +67,10 @@ mod tests { #[test] fn test_no_subcommand_defaults_to_none() { let cli = Cli::parse_from(["stackdog"]); - assert!(cli.command.is_none(), "No subcommand should yield None (default to serve)"); + assert!( + cli.command.is_none(), + "No subcommand should yield None (default to serve)" + ); } #[test] @@ -80,7 +83,17 @@ mod tests { fn test_sniff_subcommand_defaults() { let cli = Cli::parse_from(["stackdog", "sniff"]); match cli.command { - Some(Command::Sniff { once, consume, output, sources, interval, ai_provider, ai_model, ai_api_url, slack_webhook }) => { + Some(Command::Sniff { + once, + consume, + output, + sources, + interval, + ai_provider, + ai_model, + ai_api_url, + slack_webhook, + }) => { assert!(!once); assert!(!consume); assert_eq!(output, "./stackdog-logs/"); @@ -116,19 +129,37 @@ mod tests { #[test] fn test_sniff_with_all_options() { let cli = Cli::parse_from([ - "stackdog", "sniff", + "stackdog", + "sniff", "--once", "--consume", - "--output", "/tmp/logs/", - "--sources", "/var/log/syslog,/var/log/auth.log", - "--interval", "60", - "--ai-provider", "openai", - "--ai-model", "gpt-4o-mini", - "--ai-api-url", "https://api.openai.com/v1", - "--slack-webhook", "https://hooks.slack.com/services/T/B/xxx", + "--output", + "/tmp/logs/", + "--sources", + "/var/log/syslog,/var/log/auth.log", + "--interval", + "60", + "--ai-provider", + "openai", + "--ai-model", + "gpt-4o-mini", + "--ai-api-url", + "https://api.openai.com/v1", + "--slack-webhook", + "https://hooks.slack.com/services/T/B/xxx", ]); match cli.command { - Some(Command::Sniff { once, consume, output, sources, interval, ai_provider, ai_model, ai_api_url, slack_webhook }) => { + Some(Command::Sniff { + once, + consume, + output, + sources, + interval, + ai_provider, + ai_model, + ai_api_url, + slack_webhook, + }) => { assert!(once); assert!(consume); assert_eq!(output, "/tmp/logs/"); @@ -137,7 +168,10 @@ mod tests { assert_eq!(ai_provider.unwrap(), "openai"); assert_eq!(ai_model.unwrap(), "gpt-4o-mini"); assert_eq!(ai_api_url.unwrap(), "https://api.openai.com/v1"); - assert_eq!(slack_webhook.unwrap(), "https://hooks.slack.com/services/T/B/xxx"); + assert_eq!( + slack_webhook.unwrap(), + "https://hooks.slack.com/services/T/B/xxx" + ); } _ => panic!("Expected Sniff command"), } @@ -157,13 +191,20 @@ mod tests { #[test] fn test_sniff_with_ollama_provider_and_model() { let cli = Cli::parse_from([ - "stackdog", "sniff", + "stackdog", + "sniff", "--once", - "--ai-provider", "ollama", - "--ai-model", "qwen2.5-coder:latest", + "--ai-provider", + "ollama", + "--ai-model", + "qwen2.5-coder:latest", ]); match cli.command { - Some(Command::Sniff { ai_provider, ai_model, .. }) => { + Some(Command::Sniff { + ai_provider, + ai_model, + .. + }) => { assert_eq!(ai_provider.unwrap(), "ollama"); assert_eq!(ai_model.unwrap(), "qwen2.5-coder:latest"); } diff --git a/src/collectors/ebpf/container.rs b/src/collectors/ebpf/container.rs index 98de118..435cc0b 100644 --- a/src/collectors/ebpf/container.rs +++ b/src/collectors/ebpf/container.rs @@ -2,7 +2,7 @@ //! //! Detects container ID from cgroup and other sources -use anyhow::{Result, Context}; +use anyhow::Result; /// Container detector pub struct ContainerDetector { @@ -19,37 +19,37 @@ impl ContainerDetector { cache: std::collections::HashMap::new(), }) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("Container detection only available on Linux"); } } - + /// Detect container ID for a process pub fn detect_container(&mut self, pid: u32) -> Option { // Check cache first if let Some(cached) = self.cache.get(&pid) { return Some(cached.clone()); } - + // Try to detect from cgroup let container_id = self.detect_from_cgroup(pid); - + // Cache result if let Some(id) = &container_id { self.cache.insert(pid, id.clone()); } - + container_id } - + /// Detect container ID from cgroup file - fn detect_from_cgroup(&self, pid: u32) -> Option { + fn detect_from_cgroup(&self, _pid: u32) -> Option { #[cfg(target_os = "linux")] { // Read /proc/[pid]/cgroup - let cgroup_path = format!("/proc/{}/cgroup", pid); + let cgroup_path = format!("/proc/{}/cgroup", _pid); if let Ok(content) = std::fs::read_to_string(&cgroup_path) { for line in content.lines() { if let Some(id) = Self::parse_container_from_cgroup(line) { @@ -58,41 +58,41 @@ impl ContainerDetector { } } } - + None } - + /// Parse container ID from cgroup line pub fn parse_container_from_cgroup(cgroup_line: &str) -> Option { // Format: hierarchy:controllers:path // Docker: 12:memory:/docker/abc123def456... // Kubernetes: 11:cpu:/kubepods/pod123/def456... - + let parts: Vec<&str> = cgroup_line.split(':').collect(); if parts.len() < 3 { return None; } - + let path = parts[2]; - + // Try Docker format if let Some(id) = Self::extract_docker_id(path) { return Some(id); } - + // Try Kubernetes format if let Some(id) = Self::extract_kubernetes_id(path) { return Some(id); } - + // Try containerd format if let Some(id) = Self::extract_containerd_id(path) { return Some(id); } - + None } - + /// Extract Docker container ID fn extract_docker_id(path: &str) -> Option { // Look for /docker/[container_id] @@ -100,30 +100,30 @@ impl ContainerDetector { let start = pos + 8; let id = &path[start..]; let id = id.split('/').next()?; - + if Self::is_valid_container_id(id) { return Some(id.to_string()); } } - + None } - + /// Extract Kubernetes container ID fn extract_kubernetes_id(path: &str) -> Option { // Look for /kubepods/.../container_id if path.contains("/kubepods/") { // Get last component - let id = path.split('/').last()?; - + let id = path.split('/').next_back()?; + if Self::is_valid_container_id(id) { return Some(id.to_string()); } } - + None } - + /// Extract containerd container ID fn extract_containerd_id(path: &str) -> Option { // Look for /containerd/[container_id] @@ -131,42 +131,42 @@ impl ContainerDetector { let start = pos + 12; let id = &path[start..]; let id = id.split('/').next()?; - + if Self::is_valid_container_id(id) { return Some(id.to_string()); } } - + None } - + /// Validate container ID format pub fn validate_container_id(&self, id: &str) -> bool { Self::is_valid_container_id(id) } - + /// Check if string is a valid container ID fn is_valid_container_id(id: &str) -> bool { // Container IDs are typically 64 hex characters (full) or 12 hex characters (short) if id.is_empty() { return false; } - + // Check length if id.len() != 12 && id.len() != 64 { return false; } - + // Check all characters are hex id.chars().all(|c| c.is_ascii_hexdigit()) } - + /// Get current process container ID pub fn current_container(&mut self) -> Option { let pid = std::process::id(); self.detect_container(pid) } - + /// Clear the cache pub fn clear_cache(&mut self) { self.cache.clear(); @@ -182,66 +182,77 @@ impl Default for ContainerDetector { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_detector_creation() { let detector = ContainerDetector::new(); - + #[cfg(target_os = "linux")] assert!(detector.is_ok()); - + #[cfg(not(target_os = "linux"))] assert!(detector.is_err()); } - + #[test] fn test_parse_docker_cgroup() { - let cgroup = "12:memory:/docker/abc123def456abc123def456abc123def456abc123def456abc123def456abcd"; + let cgroup = + "12:memory:/docker/abc123def456abc123def456abc123def456abc123def456abc123def456abcd"; let result = ContainerDetector::parse_container_from_cgroup(cgroup); - assert_eq!(result, Some("abc123def456abc123def456abc123def456abc123def456abc123def456abcd".to_string())); + assert_eq!( + result, + Some("abc123def456abc123def456abc123def456abc123def456abc123def456abcd".to_string()) + ); } #[test] fn test_parse_kubernetes_cgroup() { let cgroup = "11:cpu:/kubepods/pod123/def456abc123def456abc123def456abc123def456abc123def456abc123def4"; let result = ContainerDetector::parse_container_from_cgroup(cgroup); - assert_eq!(result, Some("def456abc123def456abc123def456abc123def456abc123def456abc123def4".to_string())); + assert_eq!( + result, + Some("def456abc123def456abc123def456abc123def456abc123def456abc123def4".to_string()) + ); } - + #[test] fn test_parse_non_container_cgroup() { let cgroup = "10:cpuacct:/"; let result = ContainerDetector::parse_container_from_cgroup(cgroup); assert_eq!(result, None); } - + #[cfg(target_os = "linux")] #[test] fn test_validate_valid_container_id() { let detector = ContainerDetector::new().unwrap(); - + // Full ID (64 chars) - assert!(detector.validate_container_id("abc123def456789012345678901234567890123456789012345678901234abcd")); - + assert!(detector.validate_container_id( + "abc123def456789012345678901234567890123456789012345678901234abcd" + )); + // Short ID (12 chars) assert!(detector.validate_container_id("abc123def456")); } - + #[cfg(target_os = "linux")] #[test] fn test_validate_invalid_container_id() { let detector = ContainerDetector::new().unwrap(); - + // Empty assert!(!detector.validate_container_id("")); - + // Too short assert!(!detector.validate_container_id("abc123")); - + // Invalid chars assert!(!detector.validate_container_id("abc123def45!")); - + // Too long - assert!(!detector.validate_container_id("abc123def4567890123456789012345678901234567890123456789012345678901234567890")); + assert!(!detector.validate_container_id( + "abc123def4567890123456789012345678901234567890123456789012345678901234567890" + )); } } diff --git a/src/collectors/ebpf/enrichment.rs b/src/collectors/ebpf/enrichment.rs index fcbde6c..141c9da 100644 --- a/src/collectors/ebpf/enrichment.rs +++ b/src/collectors/ebpf/enrichment.rs @@ -2,39 +2,38 @@ //! //! Enriches syscall events with additional context (container ID, process info, etc.) -use anyhow::Result; use crate::events::syscall::SyscallEvent; +use anyhow::Result; /// Event enricher pub struct EventEnricher { - // Cache for process information - process_cache: std::collections::HashMap, + _process_cache: std::collections::HashMap, } #[derive(Debug, Clone)] struct ProcessInfo { - pid: u32, - ppid: u32, - comm: Option, + _pid: u32, + _ppid: u32, + _comm: Option, } impl EventEnricher { /// Create a new event enricher pub fn new() -> Result { Ok(Self { - process_cache: std::collections::HashMap::new(), + _process_cache: std::collections::HashMap::new(), }) } - + /// Enrich an event with additional information pub fn enrich(&mut self, event: &mut SyscallEvent) -> Result<()> { // Add timestamp normalization (already done in event creation) // Add process information self.enrich_process_info(event); - + Ok(()) } - + /// Enrich event with process information fn enrich_process_info(&mut self, event: &mut SyscallEvent) { // Try to get process comm if not already set @@ -42,13 +41,13 @@ impl EventEnricher { event.comm = self.get_process_comm(event.pid); } } - + /// Get parent PID for a process - pub fn get_parent_pid(&self, pid: u32) -> Option { + pub fn get_parent_pid(&self, _pid: u32) -> Option { #[cfg(target_os = "linux")] { // Read from /proc/[pid]/stat - let stat_path = format!("/proc/{}/stat", pid); + let stat_path = format!("/proc/{}/stat", _pid); if let Ok(content) = std::fs::read_to_string(&stat_path) { // Parse ppid from stat file (field 4) let parts: Vec<&str> = content.split_whitespace().collect(); @@ -59,22 +58,22 @@ impl EventEnricher { } } } - + None } - + /// Get process command name - pub fn get_process_comm(&self, pid: u32) -> Option { + pub fn get_process_comm(&self, _pid: u32) -> Option { #[cfg(target_os = "linux")] { // Read from /proc/[pid]/comm - let comm_path = format!("/proc/{}/comm", pid); + let comm_path = format!("/proc/{}/comm", _pid); if let Ok(content) = std::fs::read_to_string(&comm_path) { return Some(content.trim().to_string()); } - + // Alternative: read from /proc/[pid]/cmdline - let cmdline_path = format!("/proc/{}/cmdline", pid); + let cmdline_path = format!("/proc/{}/cmdline", _pid); if let Ok(content) = std::fs::read_to_string(&cmdline_path) { if let Some(first_null) = content.find('\0') { let path = &content[..first_null]; @@ -86,35 +85,35 @@ impl EventEnricher { } } } - + None } - + /// Get process executable path - pub fn get_process_exe(&self, pid: u32) -> Option { + pub fn get_process_exe(&self, _pid: u32) -> Option { #[cfg(target_os = "linux")] { // Read symlink /proc/[pid]/exe - let exe_path = format!("/proc/{}/exe", pid); + let exe_path = format!("/proc/{}/exe", _pid); if let Ok(path) = std::fs::read_link(&exe_path) { return path.to_str().map(|s| s.to_string()); } } - + None } - + /// Get process working directory - pub fn get_process_cwd(&self, pid: u32) -> Option { + pub fn get_process_cwd(&self, _pid: u32) -> Option { #[cfg(target_os = "linux")] { // Read symlink /proc/[pid]/cwd - let cwd_path = format!("/proc/{}/cwd", pid); + let cwd_path = format!("/proc/{}/cwd", _pid); if let Ok(path) = std::fs::read_link(&cwd_path) { return path.to_str().map(|s| s.to_string()); } } - + None } } @@ -139,7 +138,7 @@ mod tests { let enricher = EventEnricher::new(); assert!(enricher.is_ok()); } - + #[test] fn test_normalize_timestamp() { let now = Utc::now(); diff --git a/src/collectors/ebpf/kernel.rs b/src/collectors/ebpf/kernel.rs index 3348569..a3db7e8 100644 --- a/src/collectors/ebpf/kernel.rs +++ b/src/collectors/ebpf/kernel.rs @@ -2,7 +2,7 @@ //! //! Provides kernel version detection and compatibility checks for eBPF -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use std::fmt; /// Kernel version information @@ -17,26 +17,23 @@ impl KernelVersion { /// Parse kernel version from string (e.g., "5.15.0" or "4.19.0-16-amd64") pub fn parse(version: &str) -> Result { // Extract the first three numeric components - let parts: Vec<&str> = version - .split('.') - .take(3) - .collect(); - + let parts: Vec<&str> = version.split('.').take(3).collect(); + if parts.len() < 2 { anyhow::bail!("Invalid kernel version format: {}", version); } - + let major = parts[0] .parse::() .with_context(|| format!("Invalid major version: {}", parts[0]))?; - + let minor = parts[1] - .split('-') // Handle versions like "15.0-16-amd64" + .split('-') // Handle versions like "15.0-16-amd64" .next() .unwrap_or("0") .parse::() .with_context(|| format!("Invalid minor version: {}", parts[1]))?; - + let patch = if parts.len() > 2 { parts[2] .split('-') @@ -47,15 +44,19 @@ impl KernelVersion { } else { 0 }; - - Ok(Self { major, minor, patch }) + + Ok(Self { + major, + minor, + patch, + }) } - + /// Check if this version meets the minimum requirement pub fn meets_minimum(&self, minimum: &KernelVersion) -> bool { self >= minimum } - + /// Check if kernel supports eBPF (4.19+) pub fn supports_ebpf(&self) -> bool { self.meets_minimum(&KernelVersion { @@ -64,7 +65,7 @@ impl KernelVersion { patch: 0, }) } - + /// Check if kernel supports BTF pub fn supports_btf(&self) -> bool { // BTF support improved significantly in 5.4+ @@ -98,25 +99,25 @@ impl KernelInfo { let version_str = get_kernel_version()?; let version = KernelVersion::parse(&version_str) .with_context(|| format!("Failed to parse kernel version: {}", version_str))?; - + Ok(Self { version, os: "linux".to_string(), arch: std::env::consts::ARCH.to_string(), }) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("Kernel info only available on Linux"); } } - + /// Check if current kernel supports eBPF pub fn supports_ebpf(&self) -> bool { self.version.supports_ebpf() } - + /// Check if current kernel supports BTF pub fn supports_btf(&self) -> bool { self.version.supports_btf() @@ -139,10 +140,10 @@ pub fn check_kernel_version() -> Result { #[cfg(target_os = "linux")] fn get_kernel_version() -> Result { use std::fs; - + let version = fs::read_to_string("/proc/sys/kernel/osrelease") .with_context(|| "Failed to read /proc/sys/kernel/osrelease")?; - + Ok(version.trim().to_string()) } @@ -154,7 +155,7 @@ pub fn is_linux() -> bool { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_kernel_version_parse_simple() { let version = KernelVersion::parse("5.15.0").unwrap(); @@ -162,7 +163,7 @@ mod tests { assert_eq!(version.minor, 15); assert_eq!(version.patch, 0); } - + #[test] fn test_kernel_version_parse_with_suffix() { let version = KernelVersion::parse("4.19.0-16-amd64").unwrap(); @@ -170,7 +171,7 @@ mod tests { assert_eq!(version.minor, 19); assert_eq!(version.patch, 0); } - + #[test] fn test_kernel_version_parse_two_components() { let version = KernelVersion::parse("5.10").unwrap(); @@ -178,52 +179,52 @@ mod tests { assert_eq!(version.minor, 10); assert_eq!(version.patch, 0); } - + #[test] fn test_kernel_version_parse_invalid() { let result = KernelVersion::parse("invalid"); assert!(result.is_err()); } - + #[test] fn test_kernel_version_comparison() { let v1 = KernelVersion::parse("5.10.0").unwrap(); let v2 = KernelVersion::parse("5.15.0").unwrap(); - + assert!(v2 > v1); assert!(v1 < v2); } - + #[test] fn test_kernel_version_equality() { let v1 = KernelVersion::parse("5.10.0").unwrap(); let v2 = KernelVersion::parse("5.10.0").unwrap(); assert_eq!(v1, v2); } - + #[test] fn test_kernel_version_display() { let version = KernelVersion::parse("5.15.0").unwrap(); assert_eq!(format!("{}", version), "5.15.0"); } - + #[test] fn test_kernel_version_supports_ebpf() { let v4_18 = KernelVersion::parse("4.18.0").unwrap(); let v4_19 = KernelVersion::parse("4.19.0").unwrap(); let v5_10 = KernelVersion::parse("5.10.0").unwrap(); - + assert!(!v4_18.supports_ebpf()); assert!(v4_19.supports_ebpf()); assert!(v5_10.supports_ebpf()); } - + #[test] fn test_kernel_version_supports_btf() { let v5_3 = KernelVersion::parse("5.3.0").unwrap(); let v5_4 = KernelVersion::parse("5.4.0").unwrap(); let v5_10 = KernelVersion::parse("5.10.0").unwrap(); - + assert!(!v5_3.supports_btf()); assert!(v5_4.supports_btf()); assert!(v5_10.supports_btf()); diff --git a/src/collectors/ebpf/loader.rs b/src/collectors/ebpf/loader.rs index 5838f1d..4ced63f 100644 --- a/src/collectors/ebpf/loader.rs +++ b/src/collectors/ebpf/loader.rs @@ -1,10 +1,10 @@ //! eBPF program loader //! //! Loads and manages eBPF programs using aya-rs -//! +//! //! Note: This module is only available on Linux with the ebpf feature enabled -use anyhow::{Result, Context, bail}; +use anyhow::Result; use std::collections::HashMap; /// eBPF loader errors @@ -12,22 +12,22 @@ use std::collections::HashMap; pub enum LoadError { #[error("Program not found: {0}")] ProgramNotFound(String), - + #[error("Failed to load program: {0}")] LoadFailed(String), - + #[error("Failed to attach program: {0}")] AttachFailed(String), - + #[error("Kernel version too low: required {required}, current {current}. eBPF requires kernel 4.19+")] KernelVersionTooLow { required: String, current: String }, - + #[error("Not running on Linux")] NotLinux, - + #[error("Permission denied: eBPF programs require root or CAP_BPF")] PermissionDenied, - + #[error(transparent)] Other(#[from] anyhow::Error), } @@ -36,17 +36,18 @@ pub enum LoadError { /// /// Responsible for loading eBPF programs from ELF files /// and attaching them to kernel tracepoints +#[derive(Default)] pub struct EbpfLoader { #[cfg(all(target_os = "linux", feature = "ebpf"))] bpf: Option, - + loaded_programs: HashMap, kernel_version: Option, } #[derive(Debug, Clone)] struct ProgramInfo { - name: String, + _name: String, attached: bool, } @@ -57,7 +58,7 @@ impl EbpfLoader { if !cfg!(target_os = "linux") { return Err(LoadError::NotLinux); } - + // Check kernel version #[cfg(target_os = "linux")] let kernel_version = { @@ -78,10 +79,10 @@ impl EbpfLoader { } } }; - + #[cfg(not(target_os = "linux"))] let kernel_version: Option = None; - + Ok(Self { #[cfg(all(target_os = "linux", feature = "ebpf"))] bpf: None, @@ -89,7 +90,7 @@ impl EbpfLoader { kernel_version, }) } - + /// Load an eBPF program from bytes (ELF file contents) pub fn load_program_from_bytes(&mut self, _bytes: &[u8]) -> Result<(), LoadError> { #[cfg(all(target_os = "linux", feature = "ebpf"))] @@ -98,8 +99,7 @@ impl EbpfLoader { return Err(LoadError::LoadFailed("Empty program bytes".to_string())); } - let bpf = aya::Bpf::load(_bytes) - .map_err(|e| LoadError::LoadFailed(e.to_string()))?; + let bpf = aya::Bpf::load(_bytes).map_err(|e| LoadError::LoadFailed(e.to_string()))?; self.bpf = Some(bpf); log::info!("eBPF program loaded ({} bytes)", _bytes.len()); @@ -111,39 +111,39 @@ impl EbpfLoader { Err(LoadError::NotLinux) } } - + /// Load an eBPF program from ELF file pub fn load_program_from_file(&mut self, _path: &str) -> Result<(), LoadError> { #[cfg(all(target_os = "linux", feature = "ebpf"))] { use std::fs; - + let bytes = fs::read(_path) .with_context(|| format!("Failed to read eBPF program: {}", _path)) .map_err(|e| LoadError::Other(e.into()))?; - + self.load_program_from_bytes(&bytes) } - + #[cfg(not(all(target_os = "linux", feature = "ebpf")))] { Err(LoadError::NotLinux) } } - + /// Attach a loaded program to its tracepoint pub fn attach_program(&mut self, _program_name: &str) -> Result<(), LoadError> { #[cfg(all(target_os = "linux", feature = "ebpf"))] { - let (category, tp_name) = program_to_tracepoint(_program_name) - .ok_or_else(|| LoadError::ProgramNotFound( - format!("No tracepoint mapping for '{}'", _program_name) - ))?; + let (category, tp_name) = program_to_tracepoint(_program_name).ok_or_else(|| { + LoadError::ProgramNotFound(format!("No tracepoint mapping for '{}'", _program_name)) + })?; - let bpf = self.bpf.as_mut() - .ok_or_else(|| LoadError::LoadFailed( - "No eBPF program loaded; call load_program_from_bytes first".to_string() - ))?; + let bpf = self.bpf.as_mut().ok_or_else(|| { + LoadError::LoadFailed( + "No eBPF program loaded; call load_program_from_bytes first".to_string(), + ) + })?; let prog: &mut aya::programs::TracePoint = bpf .program_mut(_program_name) @@ -154,17 +154,24 @@ impl EbpfLoader { prog.load() .map_err(|e| LoadError::AttachFailed(format!("load '{}': {}", _program_name, e)))?; - prog.attach(category, tp_name) - .map_err(|e| LoadError::AttachFailed( - format!("attach '{}/{}': {}", category, tp_name, e) - ))?; + prog.attach(category, tp_name).map_err(|e| { + LoadError::AttachFailed(format!("attach '{}/{}': {}", category, tp_name, e)) + })?; self.loaded_programs.insert( _program_name.to_string(), - ProgramInfo { name: _program_name.to_string(), attached: true }, + ProgramInfo { + _name: _program_name.to_string(), + attached: true, + }, ); - log::info!("eBPF program '{}' attached to {}/{}", _program_name, category, tp_name); + log::info!( + "eBPF program '{}' attached to {}/{}", + _program_name, + category, + tp_name + ); Ok(()) } @@ -178,7 +185,12 @@ impl EbpfLoader { pub fn attach_all_programs(&mut self) -> Result<(), LoadError> { #[cfg(all(target_os = "linux", feature = "ebpf"))] { - for name in &["trace_execve", "trace_connect", "trace_openat", "trace_ptrace"] { + for name in &[ + "trace_execve", + "trace_connect", + "trace_openat", + "trace_ptrace", + ] { if let Err(e) = self.attach_program(name) { log::warn!("Failed to attach '{}': {}", name, e); } @@ -196,20 +208,19 @@ impl EbpfLoader { /// Must be called after load_program_from_bytes and before the Bpf object is dropped. #[cfg(all(target_os = "linux", feature = "ebpf"))] pub fn take_ring_buf(&mut self) -> Result, LoadError> { - let bpf = self.bpf.as_mut() - .ok_or_else(|| LoadError::LoadFailed( - "No eBPF program loaded".to_string() - ))?; + let bpf = self + .bpf + .as_mut() + .ok_or_else(|| LoadError::LoadFailed("No eBPF program loaded".to_string()))?; - let map = bpf.take_map("EVENTS") - .ok_or_else(|| LoadError::LoadFailed( - "EVENTS ring buffer map not found in eBPF program".to_string() - ))?; + let map = bpf.take_map("EVENTS").ok_or_else(|| { + LoadError::LoadFailed("EVENTS ring buffer map not found in eBPF program".to_string()) + })?; aya::maps::RingBuf::try_from(map) .map_err(|e| LoadError::LoadFailed(format!("Failed to create ring buffer: {}", e))) } - + /// Detach a program pub fn detach_program(&mut self, program_name: &str) -> Result<(), LoadError> { if let Some(info) = self.loaded_programs.get_mut(program_name) { @@ -219,7 +230,7 @@ impl EbpfLoader { Err(LoadError::ProgramNotFound(program_name.to_string())) } } - + /// Unload a program pub fn unload_program(&mut self, program_name: &str) -> Result<(), LoadError> { self.loaded_programs @@ -227,12 +238,12 @@ impl EbpfLoader { .ok_or_else(|| LoadError::ProgramNotFound(program_name.to_string()))?; Ok(()) } - + /// Check if a program is loaded pub fn is_program_loaded(&self, program_name: &str) -> bool { self.loaded_programs.contains_key(program_name) } - + /// Check if a program is attached pub fn is_program_attached(&self, program_name: &str) -> bool { self.loaded_programs @@ -240,17 +251,17 @@ impl EbpfLoader { .map(|info| info.attached) .unwrap_or(false) } - + /// Get the number of loaded programs pub fn loaded_program_count(&self) -> usize { self.loaded_programs.len() } - + /// Get the kernel version pub fn kernel_version(&self) -> Option<&crate::collectors::ebpf::kernel::KernelVersion> { self.kernel_version.as_ref() } - + /// Check if eBPF is supported on this system pub fn is_ebpf_supported(&self) -> bool { self.kernel_version @@ -260,24 +271,14 @@ impl EbpfLoader { } } -impl Default for EbpfLoader { - fn default() -> Self { - Self { - #[cfg(all(target_os = "linux", feature = "ebpf"))] - bpf: None, - loaded_programs: HashMap::new(), - kernel_version: None, - } - } -} - /// Map program name to its tracepoint (category, name) for aya attachment. +#[cfg(all(target_os = "linux", feature = "ebpf"))] fn program_to_tracepoint(name: &str) -> Option<(&'static str, &'static str)> { match name { - "trace_execve" => Some(("syscalls", "sys_enter_execve")), + "trace_execve" => Some(("syscalls", "sys_enter_execve")), "trace_connect" => Some(("syscalls", "sys_enter_connect")), - "trace_openat" => Some(("syscalls", "sys_enter_openat")), - "trace_ptrace" => Some(("syscalls", "sys_enter_ptrace")), + "trace_openat" => Some(("syscalls", "sys_enter_openat")), + "trace_ptrace" => Some(("syscalls", "sys_enter_ptrace")), _ => None, } } @@ -299,33 +300,33 @@ impl EbpfLoader { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_ebpf_loader_creation() { let loader = EbpfLoader::new(); - - #[cfg(all(target_os = "linux", feature = "ebpf"))] + + #[cfg(target_os = "linux")] assert!(loader.is_ok()); - - #[cfg(not(all(target_os = "linux", feature = "ebpf")))] + + #[cfg(not(target_os = "linux"))] assert!(loader.is_err()); } - + #[test] fn test_is_linux() { #[cfg(target_os = "linux")] assert!(is_linux()); - + #[cfg(not(target_os = "linux"))] assert!(!is_linux()); } - + #[test] fn test_load_error_display() { let error = LoadError::ProgramNotFound("test".to_string()); let msg = format!("{}", error); assert!(msg.contains("test")); - + let error = LoadError::NotLinux; let msg = format!("{}", error); assert!(msg.contains("Linux")); diff --git a/src/collectors/ebpf/mod.rs b/src/collectors/ebpf/mod.rs index ca59ad5..7da67d0 100644 --- a/src/collectors/ebpf/mod.rs +++ b/src/collectors/ebpf/mod.rs @@ -1,21 +1,21 @@ //! eBPF collectors module //! //! Provides eBPF-based syscall monitoring using aya-rs -//! +//! //! Note: This module is only available on Linux with the ebpf feature enabled -pub mod loader; +pub mod container; +pub mod enrichment; pub mod kernel; -pub mod syscall_monitor; +pub mod loader; pub mod programs; pub mod ring_buffer; -pub mod enrichment; -pub mod container; +pub mod syscall_monitor; pub mod types; // Re-export main types +pub use container::ContainerDetector; +pub use enrichment::EventEnricher; pub use loader::EbpfLoader; pub use syscall_monitor::SyscallMonitor; -pub use enrichment::EventEnricher; -pub use container::ContainerDetector; -pub use types::{EbpfSyscallEvent, EbpfEventData, to_syscall_event}; +pub use types::{to_syscall_event, EbpfEventData, EbpfSyscallEvent}; diff --git a/src/collectors/ebpf/programs.rs b/src/collectors/ebpf/programs.rs index 92b7256..7767929 100644 --- a/src/collectors/ebpf/programs.rs +++ b/src/collectors/ebpf/programs.rs @@ -1,7 +1,7 @@ //! eBPF programs module //! //! Contains eBPF program definitions -//! +//! //! Note: Actual eBPF programs will be implemented in TASK-004 /// Program types supported by Stackdog @@ -21,13 +21,13 @@ pub struct ProgramMetadata { pub name: &'static str, pub program_type: ProgramType, pub description: &'static str, - pub required_kernel: (u32, u32), // (major, minor) + pub required_kernel: (u32, u32), // (major, minor) } /// Built-in eBPF programs pub mod builtin { use super::*; - + /// Execve syscall tracepoint program pub const EXECVE_PROGRAM: ProgramMetadata = ProgramMetadata { name: "trace_execve", @@ -35,7 +35,7 @@ pub mod builtin { description: "Monitors execve syscalls for process execution tracking", required_kernel: (4, 19), }; - + /// Connect syscall tracepoint program pub const CONNECT_PROGRAM: ProgramMetadata = ProgramMetadata { name: "trace_connect", @@ -43,7 +43,7 @@ pub mod builtin { description: "Monitors connect syscalls for network connection tracking", required_kernel: (4, 19), }; - + /// Openat syscall tracepoint program pub const OPENAT_PROGRAM: ProgramMetadata = ProgramMetadata { name: "trace_openat", @@ -51,7 +51,7 @@ pub mod builtin { description: "Monitors openat syscalls for file access tracking", required_kernel: (4, 19), }; - + /// Ptrace syscall tracepoint program pub const PTRACE_PROGRAM: ProgramMetadata = ProgramMetadata { name: "trace_ptrace", @@ -64,14 +64,14 @@ pub mod builtin { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_program_type_variants() { let _syscall = ProgramType::SyscallTracepoint; let _network = ProgramType::NetworkMonitor; let _container = ProgramType::ContainerMonitor; } - + #[test] fn test_builtin_programs() { assert_eq!(builtin::EXECVE_PROGRAM.name, "trace_execve"); @@ -79,7 +79,7 @@ mod tests { assert_eq!(builtin::OPENAT_PROGRAM.name, "trace_openat"); assert_eq!(builtin::PTRACE_PROGRAM.name, "trace_ptrace"); } - + #[test] fn test_program_metadata() { let program = builtin::EXECVE_PROGRAM; diff --git a/src/collectors/ebpf/ring_buffer.rs b/src/collectors/ebpf/ring_buffer.rs index 9c25b01..6acac60 100644 --- a/src/collectors/ebpf/ring_buffer.rs +++ b/src/collectors/ebpf/ring_buffer.rs @@ -2,7 +2,6 @@ //! //! Provides efficient event buffering from eBPF to userspace -use anyhow::Result; use crate::events::syscall::SyscallEvent; /// Ring buffer for eBPF events @@ -18,10 +17,10 @@ impl EventRingBuffer { pub fn new() -> Self { Self { buffer: Vec::new(), - capacity: 4096, // Default capacity + capacity: 4096, // Default capacity } } - + /// Create a ring buffer with specific capacity pub fn with_capacity(capacity: usize) -> Self { Self { @@ -29,7 +28,7 @@ impl EventRingBuffer { capacity, } } - + /// Add an event to the buffer pub fn push(&mut self, event: SyscallEvent) { // If buffer is full, remove oldest events @@ -38,27 +37,27 @@ impl EventRingBuffer { } self.buffer.push(event); } - + /// Get all events and clear the buffer pub fn drain(&mut self) -> Vec { std::mem::take(&mut self.buffer) } - + /// Get the number of events in the buffer pub fn len(&self) -> usize { self.buffer.len() } - + /// Check if buffer is empty pub fn is_empty(&self) -> bool { self.buffer.is_empty() } - + /// Get the capacity of the buffer pub fn capacity(&self) -> usize { self.capacity } - + /// View events without consuming them pub fn events(&self) -> &[SyscallEvent] { &self.buffer @@ -81,72 +80,72 @@ mod tests { use super::*; use crate::events::syscall::{SyscallEvent, SyscallType}; use chrono::Utc; - + #[test] fn test_ring_buffer_creation() { let buffer = EventRingBuffer::new(); assert_eq!(buffer.len(), 0); assert!(buffer.is_empty()); } - + #[test] fn test_ring_buffer_with_capacity() { let buffer = EventRingBuffer::with_capacity(100); assert_eq!(buffer.capacity(), 100); } - + #[test] fn test_ring_buffer_push() { let mut buffer = EventRingBuffer::new(); let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); - + buffer.push(event); assert_eq!(buffer.len(), 1); } - + #[test] fn test_ring_buffer_drain() { let mut buffer = EventRingBuffer::new(); - + for i in 0..5 { let event = SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()); buffer.push(event); } - + let events = buffer.drain(); assert_eq!(events.len(), 5); assert!(buffer.is_empty()); } - + #[test] fn test_ring_buffer_overflow() { let mut buffer = EventRingBuffer::with_capacity(3); - + // Push 5 events into buffer with capacity 3 for i in 0..5 { let event = SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()); buffer.push(event); } - + // Should only have 3 events (oldest removed) assert_eq!(buffer.len(), 3); - + // The first two events should be removed let events = buffer.drain(); - assert_eq!(events[0].pid, 2); // First event should be pid=2 + assert_eq!(events[0].pid, 2); // First event should be pid=2 assert_eq!(events[1].pid, 3); assert_eq!(events[2].pid, 4); } - + #[test] fn test_ring_buffer_clear() { let mut buffer = EventRingBuffer::new(); - + for i in 0..3 { let event = SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()); buffer.push(event); } - + buffer.clear(); assert!(buffer.is_empty()); } diff --git a/src/collectors/ebpf/syscall_monitor.rs b/src/collectors/ebpf/syscall_monitor.rs index df92490..46ddce2 100644 --- a/src/collectors/ebpf/syscall_monitor.rs +++ b/src/collectors/ebpf/syscall_monitor.rs @@ -2,11 +2,11 @@ //! //! Monitors syscalls using eBPF tracepoints -use anyhow::{Result, Context}; -use crate::events::syscall::{SyscallEvent, SyscallType}; -use crate::collectors::ebpf::ring_buffer::EventRingBuffer; -use crate::collectors::ebpf::enrichment::EventEnricher; use crate::collectors::ebpf::container::ContainerDetector; +use crate::collectors::ebpf::enrichment::EventEnricher; +use crate::collectors::ebpf::ring_buffer::EventRingBuffer; +use crate::events::syscall::SyscallEvent; +use anyhow::Result; /// Syscall monitor using eBPF pub struct SyscallMonitor { @@ -18,8 +18,8 @@ pub struct SyscallMonitor { running: bool, event_buffer: EventRingBuffer, - enricher: EventEnricher, - container_detector: Option, + _enricher: EventEnricher, + _container_detector: Option, } impl SyscallMonitor { @@ -27,30 +27,29 @@ impl SyscallMonitor { pub fn new() -> Result { #[cfg(all(target_os = "linux", feature = "ebpf"))] { - let loader = super::loader::EbpfLoader::new() - .context("Failed to create eBPF loader")?; - - let enricher = EventEnricher::new() - .context("Failed to create event enricher")?; - + let loader = + super::loader::EbpfLoader::new().context("Failed to create eBPF loader")?; + + let enricher = EventEnricher::new().context("Failed to create event enricher")?; + let container_detector = ContainerDetector::new().ok(); - + Ok(Self { loader: Some(loader), ring_buf: None, running: false, event_buffer: EventRingBuffer::with_capacity(8192), - enricher, - container_detector, + _enricher: enricher, + _container_detector: container_detector, }) } - + #[cfg(not(all(target_os = "linux", feature = "ebpf")))] { anyhow::bail!("SyscallMonitor is only available on Linux with eBPF feature"); } } - + /// Start monitoring syscalls pub fn start(&mut self) -> Result<()> { #[cfg(all(target_os = "linux", feature = "ebpf"))] @@ -67,8 +66,12 @@ impl SyscallMonitor { log::warn!("Some eBPF programs failed to attach: {}", e); }); match loader.take_ring_buf() { - Ok(rb) => { self.ring_buf = Some(rb); } - Err(e) => { log::warn!("Failed to get eBPF ring buffer: {}", e); } + Ok(rb) => { + self.ring_buf = Some(rb); + } + Err(e) => { + log::warn!("Failed to get eBPF ring buffer: {}", e); + } } } Err(e) => { @@ -77,7 +80,8 @@ impl SyscallMonitor { Running without kernel event collection — \ build the eBPF crate first with `cargo build --release` \ in the ebpf/ directory.", - ebpf_path, e + ebpf_path, + e ); } } @@ -93,7 +97,7 @@ impl SyscallMonitor { anyhow::bail!("SyscallMonitor is only available on Linux"); } } - + /// Stop monitoring syscalls pub fn stop(&mut self) -> Result<()> { self.running = false; @@ -105,12 +109,12 @@ impl SyscallMonitor { log::info!("Syscall monitor stopped"); Ok(()) } - + /// Check if monitor is running pub fn is_running(&self) -> bool { self.running } - + /// Poll for new events pub fn poll_events(&mut self) -> Vec { #[cfg(all(target_os = "linux", feature = "ebpf"))] @@ -139,7 +143,7 @@ impl SyscallMonitor { // Drain the staging buffer and enrich with /proc info let mut events = self.event_buffer.drain(); for event in &mut events { - let _ = self.enricher.enrich(event); + let _ = self._enricher.enrich(event); } events @@ -155,40 +159,40 @@ impl SyscallMonitor { pub fn peek_events(&self) -> &[SyscallEvent] { self.event_buffer.events() } - + /// Get the eBPF loader #[cfg(all(target_os = "linux", feature = "ebpf"))] pub fn loader(&self) -> Option<&super::loader::EbpfLoader> { self.loader.as_ref() } - + /// Get container ID for current process pub fn current_container_id(&mut self) -> Option { #[cfg(target_os = "linux")] { - if let Some(detector) = &mut self.container_detector { + if let Some(detector) = &mut self._container_detector { return detector.current_container(); } } None } - + /// Detect container for a specific PID - pub fn detect_container_for_pid(&mut self, pid: u32) -> Option { + pub fn detect_container_for_pid(&mut self, _pid: u32) -> Option { #[cfg(target_os = "linux")] { - if let Some(detector) = &mut self.container_detector { - return detector.detect_container(pid); + if let Some(detector) = &mut self._container_detector { + return detector.detect_container(_pid); } } None } - + /// Get event count pub fn event_count(&self) -> usize { self.event_buffer.len() } - + /// Clear event buffer pub fn clear_events(&mut self) { self.event_buffer.clear(); @@ -212,48 +216,48 @@ impl SyscallMonitor { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_syscall_monitor_creation() { let result = SyscallMonitor::new(); - + #[cfg(all(target_os = "linux", feature = "ebpf"))] assert!(result.is_ok()); - + #[cfg(not(all(target_os = "linux", feature = "ebpf")))] assert!(result.is_err()); } - + #[test] fn test_syscall_monitor_not_running_initially() { - let monitor = SyscallMonitor::new(); - + let _monitor = SyscallMonitor::new(); + #[cfg(all(target_os = "linux", feature = "ebpf"))] { let monitor = monitor.unwrap(); assert!(!monitor.is_running()); } } - + #[test] fn test_poll_events_empty_when_not_running() { - let mut monitor = SyscallMonitor::new(); - + let _monitor = SyscallMonitor::new(); + #[cfg(all(target_os = "linux", feature = "ebpf"))] { - let mut monitor = monitor.unwrap(); + let mut monitor = _monitor.unwrap(); let events = monitor.poll_events(); assert!(events.is_empty()); } } - + #[test] fn test_event_count() { - let mut monitor = SyscallMonitor::new(); - + let _monitor = SyscallMonitor::new(); + #[cfg(all(target_os = "linux", feature = "ebpf"))] { - let mut monitor = monitor.unwrap(); + let monitor = _monitor.unwrap(); assert_eq!(monitor.event_count(), 0); } } diff --git a/src/collectors/ebpf/types.rs b/src/collectors/ebpf/types.rs index 6e97d28..3455d4a 100644 --- a/src/collectors/ebpf/types.rs +++ b/src/collectors/ebpf/types.rs @@ -3,7 +3,7 @@ //! Shared type definitions for eBPF programs and userspace /// eBPF syscall event structure -/// +/// /// This structure is shared between eBPF programs and userspace /// It must be C-compatible for efficient transfer via ring buffer #[repr(C)] @@ -51,9 +51,7 @@ impl std::fmt::Debug for EbpfEventData { impl Default for EbpfEventData { fn default() -> Self { - Self { - raw: [0u8; 128], - } + Self { raw: [0u8; 128] } } } @@ -71,7 +69,11 @@ pub struct ExecveData { impl Default for ExecveData { fn default() -> Self { - Self { filename_len: 0, filename: [0u8; 128], argc: 0 } + Self { + filename_len: 0, + filename: [0u8; 128], + argc: 0, + } } } @@ -101,7 +103,11 @@ pub struct OpenatData { impl Default for OpenatData { fn default() -> Self { - Self { path_len: 0, path: [0u8; 256], flags: 0 } + Self { + path_len: 0, + path: [0u8; 256], + flags: 0, + } } } @@ -132,13 +138,13 @@ impl EbpfSyscallEvent { data: EbpfEventData::default(), } } - + /// Get command name as string pub fn comm_str(&self) -> String { let len = self.comm.iter().position(|&b| b == 0).unwrap_or(16); String::from_utf8_lossy(&self.comm[..len]).to_string() } - + /// Set command name pub fn set_comm(&mut self, comm: &[u8]) { let len = comm.len().min(15); @@ -151,32 +157,32 @@ impl EbpfSyscallEvent { pub fn to_syscall_event(ebpf_event: &EbpfSyscallEvent) -> crate::events::syscall::SyscallEvent { use crate::events::syscall::{SyscallEvent, SyscallType}; use chrono::Utc; - + // Convert syscall_id to SyscallType let syscall_type = match ebpf_event.syscall_id { - 59 => SyscallType::Execve, // sys_execve - 42 => SyscallType::Connect, // sys_connect - 257 => SyscallType::Openat, // sys_openat - 101 => SyscallType::Ptrace, // sys_ptrace + 59 => SyscallType::Execve, // sys_execve + 42 => SyscallType::Connect, // sys_connect + 257 => SyscallType::Openat, // sys_openat + 101 => SyscallType::Ptrace, // sys_ptrace _ => SyscallType::Unknown, }; - + let mut event = SyscallEvent::new( ebpf_event.pid, ebpf_event.uid, syscall_type, - Utc::now(), // Use current time (timestamp from eBPF may need conversion) + Utc::now(), // Use current time (timestamp from eBPF may need conversion) ); - + event.comm = Some(ebpf_event.comm_str()); - + event } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_event_creation() { let event = EbpfSyscallEvent::new(1234, 1000, 59); @@ -184,28 +190,28 @@ mod tests { assert_eq!(event.uid, 1000); assert_eq!(event.syscall_id, 59); } - + #[test] fn test_comm_str_empty() { let mut event = EbpfSyscallEvent::new(1234, 1000, 59); event.comm = [0u8; 16]; assert_eq!(event.comm_str(), ""); } - + #[test] fn test_comm_str_short() { let mut event = EbpfSyscallEvent::new(1234, 1000, 59); event.set_comm(b"bash"); assert_eq!(event.comm_str(), "bash"); } - + #[test] fn test_comm_str_exact_15() { let mut event = EbpfSyscallEvent::new(1234, 1000, 59); event.set_comm(b"longprocessname"); assert_eq!(event.comm_str(), "longprocessname"); } - + #[test] fn test_set_comm_truncates() { let mut event = EbpfSyscallEvent::new(1234, 1000, 59); diff --git a/src/collectors/mod.rs b/src/collectors/mod.rs index c63079f..50f7164 100644 --- a/src/collectors/mod.rs +++ b/src/collectors/mod.rs @@ -5,8 +5,8 @@ //! - Docker events streaming //! - Network traffic capture -pub mod ebpf; pub mod docker_events; +pub mod ebpf; pub mod network; /// Marker struct for module tests diff --git a/src/database/connection.rs b/src/database/connection.rs index d98d619..4db4a27 100644 --- a/src/database/connection.rs +++ b/src/database/connection.rs @@ -1,9 +1,8 @@ //! Database connection pool using rusqlite and r2d2 -use r2d2::{Pool, ManageConnection}; -use rusqlite::{Connection, Result as RusqliteResult}; use anyhow::Result; -use std::fmt; +use r2d2::{ManageConnection, Pool}; +use rusqlite::{Connection, Result as RusqliteResult}; /// Rusqlite connection manager #[derive(Debug)] @@ -28,7 +27,7 @@ impl ManageConnection for SqliteConnectionManager { } fn is_valid(&self, conn: &mut Self::Connection) -> RusqliteResult<()> { - conn.execute_batch("").map_err(|e| e.into()) + conn.execute_batch("") } fn has_broken(&self, _: &mut Self::Connection) -> bool { @@ -41,17 +40,15 @@ pub type DbPool = Pool; /// Create database connection pool pub fn create_pool(database_url: &str) -> Result { let manager = SqliteConnectionManager::new(database_url); - let pool = Pool::builder() - .max_size(10) - .build(manager)?; - + let pool = Pool::builder().max_size(10).build(manager)?; + Ok(pool) } /// Initialize database (create tables if not exist) pub fn init_database(pool: &DbPool) -> Result<()> { let conn = pool.get()?; - + // Create alerts table conn.execute( "CREATE TABLE IF NOT EXISTS alerts ( @@ -66,7 +63,7 @@ pub fn init_database(pool: &DbPool) -> Result<()> { )", [], )?; - + // Create threats table conn.execute( "CREATE TABLE IF NOT EXISTS threats ( @@ -82,7 +79,7 @@ pub fn init_database(pool: &DbPool) -> Result<()> { )", [], )?; - + // Create containers_cache table conn.execute( "CREATE TABLE IF NOT EXISTS containers_cache ( @@ -97,17 +94,38 @@ pub fn init_database(pool: &DbPool) -> Result<()> { )", [], )?; - + // Create indexes for performance - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_alerts_status ON alerts(status)", []); - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_alerts_severity ON alerts(severity)", []); - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_alerts_timestamp ON alerts(timestamp)", []); - - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_threats_status ON threats(status)", []); - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_threats_severity ON threats(severity)", []); - - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_containers_status ON containers_cache(status)", []); - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_containers_name ON containers_cache(name)", []); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_alerts_status ON alerts(status)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_alerts_severity ON alerts(severity)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_alerts_timestamp ON alerts(timestamp)", + [], + ); + + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_threats_status ON threats(status)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_threats_severity ON threats(severity)", + [], + ); + + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_containers_status ON containers_cache(status)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_containers_name ON containers_cache(name)", + [], + ); // Create log_sources table conn.execute( @@ -138,9 +156,15 @@ pub fn init_database(pool: &DbPool) -> Result<()> { [], )?; - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_log_sources_type ON log_sources(source_type)", []); - let _ = conn.execute("CREATE INDEX IF NOT EXISTS idx_log_summaries_source ON log_summaries(source_id)", []); - + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_log_sources_type ON log_sources(source_type)", + [], + ); + let _ = conn.execute( + "CREATE INDEX IF NOT EXISTS idx_log_summaries_source ON log_summaries(source_id)", + [], + ); + Ok(()) } @@ -153,7 +177,7 @@ mod tests { let pool = create_pool(":memory:"); assert!(pool.is_ok()); } - + #[test] fn test_init_database() { let pool = create_pool(":memory:").unwrap(); diff --git a/src/database/repositories/alerts.rs b/src/database/repositories/alerts.rs index 8001182..d6d7a88 100644 --- a/src/database/repositories/alerts.rs +++ b/src/database/repositories/alerts.rs @@ -1,11 +1,11 @@ //! Alert repository using rusqlite -use rusqlite::params; -use anyhow::Result; use crate::database::connection::DbPool; use crate::database::models::Alert; -use uuid::Uuid; +use anyhow::Result; use chrono::Utc; +use rusqlite::params; +use uuid::Uuid; /// Alert filter #[derive(Debug, Clone, Default)] @@ -38,7 +38,7 @@ fn map_alert_row(row: &rusqlite::Row) -> Result { /// Create a new alert pub async fn create_alert(pool: &DbPool, alert: Alert) -> Result { let conn = pool.get()?; - + conn.execute( "INSERT INTO alerts (id, alert_type, severity, message, status, timestamp, metadata) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", @@ -52,21 +52,21 @@ pub async fn create_alert(pool: &DbPool, alert: Alert) -> Result { alert.metadata ], )?; - + Ok(alert) } /// List alerts with filter pub async fn list_alerts(pool: &DbPool, filter: AlertFilter) -> Result> { let conn = pool.get()?; - + let mut alerts = Vec::new(); - + match (&filter.severity, &filter.status) { (Some(severity), Some(status)) => { let mut stmt = conn.prepare( "SELECT id, alert_type, severity, message, status, timestamp, metadata - FROM alerts WHERE severity = ?1 AND status = ?2 ORDER BY timestamp DESC" + FROM alerts WHERE severity = ?1 AND status = ?2 ORDER BY timestamp DESC", )?; let rows = stmt.query_map(params![severity, status], map_alert_row)?; for row in rows { @@ -76,7 +76,7 @@ pub async fn list_alerts(pool: &DbPool, filter: AlertFilter) -> Result { let mut stmt = conn.prepare( "SELECT id, alert_type, severity, message, status, timestamp, metadata - FROM alerts WHERE severity = ?1 ORDER BY timestamp DESC" + FROM alerts WHERE severity = ?1 ORDER BY timestamp DESC", )?; let rows = stmt.query_map(params![severity], map_alert_row)?; for row in rows { @@ -86,7 +86,7 @@ pub async fn list_alerts(pool: &DbPool, filter: AlertFilter) -> Result { let mut stmt = conn.prepare( "SELECT id, alert_type, severity, message, status, timestamp, metadata - FROM alerts WHERE status = ?1 ORDER BY timestamp DESC" + FROM alerts WHERE status = ?1 ORDER BY timestamp DESC", )?; let rows = stmt.query_map(params![status], map_alert_row)?; for row in rows { @@ -96,7 +96,7 @@ pub async fn list_alerts(pool: &DbPool, filter: AlertFilter) -> Result { let mut stmt = conn.prepare( "SELECT id, alert_type, severity, message, status, timestamp, metadata - FROM alerts ORDER BY timestamp DESC" + FROM alerts ORDER BY timestamp DESC", )?; let rows = stmt.query_map([], map_alert_row)?; for row in rows { @@ -104,21 +104,21 @@ pub async fn list_alerts(pool: &DbPool, filter: AlertFilter) -> Result Result> { let conn = pool.get()?; - + let mut stmt = conn.prepare( "SELECT id, alert_type, severity, message, status, timestamp, metadata - FROM alerts WHERE id = ?" + FROM alerts WHERE id = ?", )?; - + let result = stmt.query_row(params![alert_id], map_alert_row); - + match result { Ok(alert) => Ok(Some(alert)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), @@ -129,24 +129,36 @@ pub async fn get_alert(pool: &DbPool, alert_id: &str) -> Result> { /// Update alert status pub async fn update_alert_status(pool: &DbPool, alert_id: &str, status: &str) -> Result<()> { let conn = pool.get()?; - + conn.execute( "UPDATE alerts SET status = ?1 WHERE id = ?2", params![status, alert_id], )?; - + Ok(()) } /// Get alert statistics pub async fn get_alert_stats(pool: &DbPool) -> Result { let conn = pool.get()?; - + let total: i64 = conn.query_row("SELECT COUNT(*) FROM alerts", [], |row| row.get(0))?; - let new: i64 = conn.query_row("SELECT COUNT(*) FROM alerts WHERE status = 'New'", [], |row| row.get(0))?; - let ack: i64 = conn.query_row("SELECT COUNT(*) FROM alerts WHERE status = 'Acknowledged'", [], |row| row.get(0))?; - let resolved: i64 = conn.query_row("SELECT COUNT(*) FROM alerts WHERE status = 'Resolved'", [], |row| row.get(0))?; - + let new: i64 = conn.query_row( + "SELECT COUNT(*) FROM alerts WHERE status = 'New'", + [], + |row| row.get(0), + )?; + let ack: i64 = conn.query_row( + "SELECT COUNT(*) FROM alerts WHERE status = 'Acknowledged'", + [], + |row| row.get(0), + )?; + let resolved: i64 = conn.query_row( + "SELECT COUNT(*) FROM alerts WHERE status = 'Resolved'", + [], + |row| row.get(0), + )?; + Ok(AlertStats { total_count: total, new_count: new, @@ -178,39 +190,41 @@ mod tests { async fn test_create_and_list_alerts() { let pool = create_pool(":memory:").unwrap(); init_database(&pool).unwrap(); - + let alert = create_sample_alert(); let result = create_alert(&pool, alert.clone()).await; assert!(result.is_ok()); - + let alerts = list_alerts(&pool, AlertFilter::default()).await.unwrap(); assert_eq!(alerts.len(), 1); } - + #[actix_rt::test] async fn test_update_alert_status() { let pool = create_pool(":memory:").unwrap(); init_database(&pool).unwrap(); - + let alert = create_sample_alert(); create_alert(&pool, alert.clone()).await.unwrap(); - - update_alert_status(&pool, &alert.id, "Acknowledged").await.unwrap(); - + + update_alert_status(&pool, &alert.id, "Acknowledged") + .await + .unwrap(); + let updated = get_alert(&pool, &alert.id).await.unwrap().unwrap(); assert_eq!(updated.status, "Acknowledged"); } - + #[actix_rt::test] async fn test_get_alert_stats() { let pool = create_pool(":memory:").unwrap(); init_database(&pool).unwrap(); - + // Create some alerts for _ in 0..3 { create_alert(&pool, create_sample_alert()).await.unwrap(); } - + let stats = get_alert_stats(&pool).await.unwrap(); assert_eq!(stats.total_count, 3); assert_eq!(stats.new_count, 3); diff --git a/src/database/repositories/log_sources.rs b/src/database/repositories/log_sources.rs index 70e45fe..d3809e6 100644 --- a/src/database/repositories/log_sources.rs +++ b/src/database/repositories/log_sources.rs @@ -3,11 +3,11 @@ //! Persists discovered log sources and AI summaries, following //! the same pattern as the alerts repository. -use rusqlite::params; -use anyhow::Result; use crate::database::connection::DbPool; -use crate::sniff::discovery::{LogSource, LogSourceType}; +use crate::sniff::discovery::LogSource; +use anyhow::Result; use chrono::Utc; +use rusqlite::params; /// Create or update a log source (upsert by path_or_id) pub fn upsert_log_source(pool: &DbPool, source: &LogSource) -> Result<()> { @@ -35,26 +35,27 @@ pub fn list_log_sources(pool: &DbPool) -> Result> { let conn = pool.get()?; let mut stmt = conn.prepare( "SELECT id, source_type, path_or_id, name, discovered_at, last_read_position - FROM log_sources ORDER BY discovered_at DESC" + FROM log_sources ORDER BY discovered_at DESC", )?; - let sources = stmt.query_map([], |row| { - let source_type_str: String = row.get(1)?; - let discovered_str: String = row.get(4)?; - let pos: i64 = row.get(5)?; - Ok(LogSource { - id: row.get(0)?, - source_type: LogSourceType::from_str(&source_type_str), - path_or_id: row.get(2)?, - name: row.get(3)?, - discovered_at: chrono::DateTime::parse_from_rfc3339(&discovered_str) - .map(|dt| dt.with_timezone(&Utc)) - .unwrap_or_else(|_| Utc::now()), - last_read_position: pos as u64, - }) - })? - .filter_map(|r| r.ok()) - .collect(); + let sources = stmt + .query_map([], |row| { + let source_type_str: String = row.get(1)?; + let discovered_str: String = row.get(4)?; + let pos: i64 = row.get(5)?; + Ok(LogSource { + id: row.get(0)?, + source_type: source_type_str.parse().unwrap(), + path_or_id: row.get(2)?, + name: row.get(3)?, + discovered_at: chrono::DateTime::parse_from_rfc3339(&discovered_str) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()), + last_read_position: pos as u64, + }) + })? + .filter_map(|r| r.ok()) + .collect(); Ok(sources) } @@ -64,7 +65,7 @@ pub fn get_log_source_by_path(pool: &DbPool, path_or_id: &str) -> Result Result Result<()> { Ok(()) } +/// Parameters for creating a log summary +pub struct CreateLogSummaryParams<'a> { + pub source_id: &'a str, + pub summary_text: &'a str, + pub period_start: &'a str, + pub period_end: &'a str, + pub total_entries: i64, + pub error_count: i64, + pub warning_count: i64, +} + /// Store a log summary -pub fn create_log_summary( - pool: &DbPool, - source_id: &str, - summary_text: &str, - period_start: &str, - period_end: &str, - total_entries: i64, - error_count: i64, - warning_count: i64, -) -> Result { +pub fn create_log_summary(pool: &DbPool, params: CreateLogSummaryParams<'_>) -> Result { let conn = pool.get()?; let id = uuid::Uuid::new_v4().to_string(); let now = Utc::now().to_rfc3339(); @@ -129,8 +132,17 @@ pub fn create_log_summary( "INSERT INTO log_summaries (id, source_id, summary_text, period_start, period_end, total_entries, error_count, warning_count, created_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", - params![id, source_id, summary_text, period_start, period_end, - total_entries, error_count, warning_count, now], + rusqlite::params![ + id, + params.source_id, + params.summary_text, + params.period_start, + params.period_end, + params.total_entries, + params.error_count, + params.warning_count, + now + ], )?; Ok(id) @@ -142,24 +154,25 @@ pub fn list_summaries_for_source(pool: &DbPool, source_id: &str) -> Result DbPool { let pool = create_pool(":memory:").unwrap(); @@ -257,7 +271,9 @@ mod tests { update_read_position(&pool, "/tmp/app.log", 4096).unwrap(); - let updated = get_log_source_by_path(&pool, "/tmp/app.log").unwrap().unwrap(); + let updated = get_log_source_by_path(&pool, "/tmp/app.log") + .unwrap() + .unwrap(); assert_eq!(updated.last_read_position, 4096); } @@ -288,14 +304,17 @@ mod tests { let summary_id = create_log_summary( &pool, - &source.id, - "System running normally. 3 warnings about disk space.", - "2026-03-30T12:00:00Z", - "2026-03-30T13:00:00Z", - 500, - 0, - 3, - ).unwrap(); + CreateLogSummaryParams { + source_id: &source.id, + summary_text: "System running normally. 3 warnings about disk space.", + period_start: "2026-03-30T12:00:00Z", + period_end: "2026-03-30T13:00:00Z", + total_entries: 500, + error_count: 0, + warning_count: 3, + }, + ) + .unwrap(); assert!(!summary_id.is_empty()); diff --git a/src/docker/client.rs b/src/docker/client.rs index 751fe14..3d57091 100644 --- a/src/docker/client.rs +++ b/src/docker/client.rs @@ -1,12 +1,12 @@ //! Docker client wrapper -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use std::collections::HashMap; // Bollard imports -use bollard::Docker; -use bollard::container::{ListContainersOptions, InspectContainerOptions}; +use bollard::container::{InspectContainerOptions, ListContainersOptions}; use bollard::network::{DisconnectNetworkOptions, ListNetworksOptions}; +use bollard::Docker; /// Docker client wrapper pub struct DockerClient { @@ -16,17 +16,18 @@ pub struct DockerClient { impl DockerClient { /// Create a new Docker client pub async fn new() -> Result { - let client = Docker::connect_with_local_defaults() - .context("Failed to connect to Docker daemon")?; - + let client = + Docker::connect_with_local_defaults().context("Failed to connect to Docker daemon")?; + // Test connection - client.ping() + client + .ping() .await .context("Failed to ping Docker daemon")?; - + Ok(Self { client }) } - + /// List all containers pub async fn list_containers(&self, all: bool) -> Result> { let options: Option> = Some(ListContainersOptions { @@ -35,11 +36,12 @@ impl DockerClient { ..Default::default() }); - let containers: Vec = self.client + let containers: Vec = self + .client .list_containers(options) .await .context("Failed to list containers")?; - + let mut result = Vec::new(); for container in containers { if let Some(id) = container.id { @@ -47,23 +49,26 @@ impl DockerClient { result.push(info); } } - + Ok(result) } - + /// Get container info by ID pub async fn get_container_info(&self, container_id: &str) -> Result { - let inspect = self.client + let inspect = self + .client .inspect_container(container_id, None::) .await .context("Failed to inspect container")?; - + let config = inspect.config.unwrap_or_default(); let state = inspect.state.unwrap_or_default(); - + Ok(ContainerInfo { id: container_id.to_string(), - name: config.hostname.unwrap_or_else(|| container_id[..12].to_string()), + name: config + .hostname + .unwrap_or_else(|| container_id[..12].to_string()), image: config.image.unwrap_or_else(|| "unknown".to_string()), status: if state.running.unwrap_or(false) { "Running" @@ -71,21 +76,27 @@ impl DockerClient { "Paused" } else { "Stopped" - }.to_string(), + } + .to_string(), created: state.started_at.unwrap_or_default(), - network_settings: inspect.network_settings.map(|ns| { - ns.networks.unwrap_or_default() - .into_iter() - .map(|(name, endpoint)| (name, endpoint.ip_address.unwrap_or_default())) - .collect() - }).unwrap_or_default(), + network_settings: inspect + .network_settings + .map(|ns| { + ns.networks + .unwrap_or_default() + .into_iter() + .map(|(name, endpoint)| (name, endpoint.ip_address.unwrap_or_default())) + .collect() + }) + .unwrap_or_default(), }) } - + /// Quarantine a container (disconnect from all networks) pub async fn quarantine_container(&self, container_id: &str) -> Result<()> { // List all networks - let networks: Vec = self.client + let networks: Vec = self + .client .list_networks(None::>) .await .context("Failed to list networks")?; @@ -103,26 +114,28 @@ impl DockerClient { force: true, }; - let _ = self.client - .disconnect_network(&name, options) - .await; + let _ = self.client.disconnect_network(&name, options).await; } } - + Ok(()) } - + /// Release a container (reconnect to default network) pub async fn release_container(&self, container_id: &str, network_name: &str) -> Result<()> { // Connect to the specified network // Note: This requires additional implementation for network connection // For now, just log the action - log::info!("Would reconnect container {} to network {}", container_id, network_name); + log::info!( + "Would reconnect container {} to network {}", + container_id, + network_name + ); Ok(()) } - + /// Get container stats - pub async fn get_container_stats(&self, container_id: &str) -> Result { + pub async fn get_container_stats(&self, _container_id: &str) -> Result { // Implementation would use Docker stats API // For now, return placeholder Ok(ContainerStats { @@ -164,7 +177,7 @@ mod tests { async fn test_docker_client_creation() { // This test requires Docker daemon running let result = DockerClient::new().await; - + // Test may fail if Docker is not running if result.is_ok() { let client = result.unwrap(); diff --git a/src/docker/containers.rs b/src/docker/containers.rs index 5db967f..146f2f9 100644 --- a/src/docker/containers.rs +++ b/src/docker/containers.rs @@ -1,11 +1,11 @@ //! Container management -use anyhow::Result; -use crate::docker::client::{DockerClient, ContainerInfo}; -use crate::database::{DbPool, create_sample_alert, create_alert, update_alert_status}; use crate::database::models::Alert; -use uuid::Uuid; +use crate::database::{create_alert, DbPool}; +use crate::docker::client::{ContainerInfo, DockerClient}; +use anyhow::Result; use chrono::Utc; +use uuid::Uuid; /// Container manager pub struct ContainerManager { @@ -19,22 +19,22 @@ impl ContainerManager { let docker = DockerClient::new().await?; Ok(Self { docker, pool }) } - + /// List all containers pub async fn list_containers(&self) -> Result> { self.docker.list_containers(true).await } - + /// Get container by ID pub async fn get_container(&self, container_id: &str) -> Result { self.docker.get_container_info(container_id).await } - + /// Quarantine a container pub async fn quarantine_container(&self, container_id: &str, reason: &str) -> Result<()> { // Disconnect from networks self.docker.quarantine_container(container_id).await?; - + // Create alert let alert = Alert { id: Uuid::new_v4().to_string(), @@ -45,39 +45,44 @@ impl ContainerManager { timestamp: Utc::now().to_rfc3339(), metadata: Some(format!("container_id={}", container_id)), }; - + let _ = create_alert(&self.pool, alert).await; - + log::info!("Container {} quarantined: {}", container_id, reason); Ok(()) } - + /// Release a container from quarantine pub async fn release_container(&self, container_id: &str) -> Result<()> { // Reconnect to default network - self.docker.release_container(container_id, "bridge").await?; - + self.docker + .release_container(container_id, "bridge") + .await?; + // Update any quarantine alerts // (In production, would query for specific alerts) - + log::info!("Container {} released from quarantine", container_id); Ok(()) } - + /// Get container security status - pub async fn get_container_security_status(&self, container_id: &str) -> Result { - let info = self.docker.get_container_info(container_id).await?; - + pub async fn get_container_security_status( + &self, + container_id: &str, + ) -> Result { + let _info = self.docker.get_container_info(container_id).await?; + // Calculate risk score based on various factors - let mut risk_score = 0; - let mut threats = 0; - let mut security_state = "Secure"; - + let risk_score = 0; + let threats = 0; + let security_state = "Secure"; + // Check if running as root // Check for privileged mode // Check for exposed ports // Check for volume mounts - + Ok(ContainerSecurityStatus { container_id: container_id.to_string(), risk_score, @@ -105,10 +110,10 @@ mod tests { async fn test_container_manager_creation() { let pool = create_pool(":memory:").unwrap(); init_database(&pool).unwrap(); - + // This test requires Docker daemon let result = ContainerManager::new(pool).await; - + if result.is_ok() { let manager = result.unwrap(); let containers = manager.list_containers().await; diff --git a/src/docker/mod.rs b/src/docker/mod.rs index 0fbae60..03de6d2 100644 --- a/src/docker/mod.rs +++ b/src/docker/mod.rs @@ -3,5 +3,5 @@ pub mod client; pub mod containers; -pub use client::{DockerClient, ContainerInfo, ContainerStats}; +pub use client::{ContainerInfo, ContainerStats, DockerClient}; pub use containers::{ContainerManager, ContainerSecurityStatus}; diff --git a/src/events/mod.rs b/src/events/mod.rs index 1ec2559..3ac040c 100644 --- a/src/events/mod.rs +++ b/src/events/mod.rs @@ -2,10 +2,10 @@ //! //! Contains all security event types, conversions, validation, and streaming -pub mod syscall; pub mod security; -pub mod validation; pub mod stream; +pub mod syscall; +pub mod validation; /// Marker struct for module tests pub struct EventsMarker; diff --git a/src/events/security.rs b/src/events/security.rs index d765623..b6ccf5c 100644 --- a/src/events/security.rs +++ b/src/events/security.rs @@ -26,7 +26,7 @@ impl SecurityEvent { _ => None, } } - + /// Get the UID if this is a syscall event pub fn uid(&self) -> Option { match self { @@ -34,7 +34,7 @@ impl SecurityEvent { _ => None, } } - + /// Get the timestamp pub fn timestamp(&self) -> DateTime { match self { @@ -135,25 +135,25 @@ pub enum AlertSeverity { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_container_event_type_variants() { let _start = ContainerEventType::Start; let _stop = ContainerEventType::Stop; } - + #[test] fn test_alert_type_variants() { let _threat = AlertType::ThreatDetected; let _anomaly = AlertType::AnomalyDetected; } - + #[test] fn test_alert_severity_variants() { let _info = AlertSeverity::Info; let _critical = AlertSeverity::Critical; } - + #[test] fn test_security_event_from_syscall() { let syscall_event = SyscallEvent::new( @@ -162,11 +162,11 @@ mod tests { crate::events::syscall::SyscallType::Execve, Utc::now(), ); - + let security_event: SecurityEvent = syscall_event.into(); - + match security_event { - SecurityEvent::Syscall(_) => {}, + SecurityEvent::Syscall(_) => {} _ => panic!("Expected Syscall variant"), } } diff --git a/src/events/stream.rs b/src/events/stream.rs index a38a2c4..c64d70b 100644 --- a/src/events/stream.rs +++ b/src/events/stream.rs @@ -2,9 +2,9 @@ //! //! Provides event batch, filter, and iterator types for streaming operations -use chrono::{DateTime, Utc}; -use crate::events::syscall::SyscallType; use crate::events::security::SecurityEvent; +use crate::events::syscall::SyscallType; +use chrono::{DateTime, Utc}; /// A batch of security events for bulk operations #[derive(Debug, Clone, Default)] @@ -15,43 +15,41 @@ pub struct EventBatch { impl EventBatch { /// Create a new empty batch pub fn new() -> Self { - Self { - events: Vec::new(), - } + Self { events: Vec::new() } } - + /// Create a batch with capacity pub fn with_capacity(capacity: usize) -> Self { Self { events: Vec::with_capacity(capacity), } } - + /// Add an event to the batch pub fn add(&mut self, event: SecurityEvent) { self.events.push(event); } - + /// Get the number of events in the batch pub fn len(&self) -> usize { self.events.len() } - + /// Check if the batch is empty pub fn is_empty(&self) -> bool { self.events.is_empty() } - + /// Get events in the batch pub fn events(&self) -> &[SecurityEvent] { &self.events } - + /// Clear the batch pub fn clear(&mut self) { self.events.clear(); } - + /// Iterate over events pub fn iter(&self) -> impl Iterator { self.events.iter() @@ -67,7 +65,7 @@ impl From> for EventBatch { impl IntoIterator for EventBatch { type Item = SecurityEvent; type IntoIter = std::vec::IntoIter; - + fn into_iter(self) -> Self::IntoIter { self.events.into_iter() } @@ -88,32 +86,32 @@ impl EventFilter { pub fn new() -> Self { Self::default() } - + /// Filter by syscall type pub fn with_syscall_type(mut self, syscall_type: SyscallType) -> Self { self.syscall_type = Some(syscall_type); self } - + /// Filter by PID pub fn with_pid(mut self, pid: u32) -> Self { self.pid = Some(pid); self } - + /// Filter by UID pub fn with_uid(mut self, uid: u32) -> Self { self.uid = Some(uid); self } - + /// Filter by time range pub fn with_time_range(mut self, start: DateTime, end: DateTime) -> Self { self.start_time = Some(start); self.end_time = Some(end); self } - + /// Check if an event matches this filter pub fn matches(&self, event: &SecurityEvent) -> bool { // Check syscall type @@ -126,7 +124,7 @@ impl EventFilter { return false; } } - + // Check PID if let Some(filter_pid) = self.pid { if let Some(event_pid) = event.pid() { @@ -137,7 +135,7 @@ impl EventFilter { return false; } } - + // Check UID if let Some(filter_uid) = self.uid { if let Some(event_uid) = event.uid() { @@ -148,7 +146,7 @@ impl EventFilter { return false; } } - + // Check time range let event_time = event.timestamp(); if let Some(start) = self.start_time { @@ -161,7 +159,7 @@ impl EventFilter { return false; } } - + true } } @@ -177,7 +175,7 @@ impl EventIterator { pub fn new(events: Vec) -> Self { Self { events, index: 0 } } - + /// Filter events matching the filter pub fn filter(self, filter: &EventFilter) -> FilteredEventIterator { FilteredEventIterator { @@ -185,13 +183,9 @@ impl EventIterator { filter: filter.clone(), } } - + /// Filter events by time range - pub fn time_range( - self, - start: DateTime, - end: DateTime, - ) -> FilteredEventIterator { + pub fn time_range(self, start: DateTime, end: DateTime) -> FilteredEventIterator { let filter = EventFilter::new().with_time_range(start, end); self.filter(&filter) } @@ -199,7 +193,7 @@ impl EventIterator { impl Iterator for EventIterator { type Item = SecurityEvent; - + fn next(&mut self) -> Option { if self.index < self.events.len() { let event = self.events[self.index].clone(); @@ -219,14 +213,9 @@ pub struct FilteredEventIterator { impl Iterator for FilteredEventIterator { type Item = SecurityEvent; - + fn next(&mut self) -> Option { - while let Some(event) = self.inner.next() { - if self.filter.matches(&event) { - return Some(event); - } - } - None + self.inner.by_ref().find(|event| self.filter.matches(event)) } } @@ -234,43 +223,39 @@ impl Iterator for FilteredEventIterator { mod tests { use super::*; use crate::events::syscall::SyscallEvent; - + #[test] fn test_event_batch_new() { let batch = EventBatch::new(); assert_eq!(batch.len(), 0); assert!(batch.is_empty()); } - + #[test] fn test_event_batch_add() { let mut batch = EventBatch::new(); - let event: SecurityEvent = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - + let event: SecurityEvent = + SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()).into(); + batch.add(event); assert_eq!(batch.len(), 1); assert!(!batch.is_empty()); } - + #[test] fn test_event_filter_new() { let filter = EventFilter::new(); assert!(filter.syscall_type.is_none()); assert!(filter.pid.is_none()); } - + #[test] fn test_event_filter_chained() { let filter = EventFilter::new() .with_syscall_type(SyscallType::Execve) .with_pid(1234) .with_uid(1000); - + assert!(filter.syscall_type.is_some()); assert_eq!(filter.pid, Some(1234)); assert_eq!(filter.uid, Some(1000)); diff --git a/src/events/syscall.rs b/src/events/syscall.rs index 85f6db3..ede04bf 100644 --- a/src/events/syscall.rs +++ b/src/events/syscall.rs @@ -11,7 +11,7 @@ pub enum SyscallType { // Process execution Execve, Execveat, - + // Network Connect, Accept, @@ -19,23 +19,23 @@ pub enum SyscallType { Listen, Socket, Sendto, - + // File operations Open, Openat, Close, Read, Write, - + // Security-sensitive Ptrace, Setuid, Setgid, - + // Mount operations Mount, Umount, - + #[default] Unknown, } @@ -53,12 +53,7 @@ pub struct SyscallEvent { impl SyscallEvent { /// Create a new syscall event - pub fn new( - pid: u32, - uid: u32, - syscall_type: SyscallType, - timestamp: DateTime, - ) -> Self { + pub fn new(pid: u32, uid: u32, syscall_type: SyscallType, timestamp: DateTime) -> Self { Self { pid, uid, @@ -68,17 +63,17 @@ impl SyscallEvent { comm: None, } } - + /// Create a builder for SyscallEvent pub fn builder() -> SyscallEventBuilder { SyscallEventBuilder::new() } - + /// Get the PID if this is a syscall event pub fn pid(&self) -> Option { Some(self.pid) } - + /// Get the UID if this is a syscall event pub fn uid(&self) -> Option { Some(self.uid) @@ -106,37 +101,37 @@ impl SyscallEventBuilder { comm: None, } } - + pub fn pid(mut self, pid: u32) -> Self { self.pid = pid; self } - + pub fn uid(mut self, uid: u32) -> Self { self.uid = uid; self } - + pub fn syscall_type(mut self, syscall_type: SyscallType) -> Self { self.syscall_type = syscall_type; self } - + pub fn timestamp(mut self, timestamp: DateTime) -> Self { self.timestamp = Some(timestamp); self } - + pub fn container_id(mut self, container_id: Option) -> Self { self.container_id = container_id; self } - + pub fn comm(mut self, comm: Option) -> Self { self.comm = comm; self } - + pub fn build(self) -> SyscallEvent { SyscallEvent { pid: self.pid, @@ -158,26 +153,21 @@ impl Default for SyscallEventBuilder { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_syscall_type_default() { assert_eq!(SyscallType::default(), SyscallType::Unknown); } - + #[test] fn test_syscall_event_new() { - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); assert_eq!(event.pid, 1234); assert_eq!(event.uid, 1000); assert_eq!(event.pid(), Some(1234)); assert_eq!(event.uid(), Some(1000)); } - + #[test] fn test_syscall_event_builder() { let event = SyscallEvent::builder() diff --git a/src/events/validation.rs b/src/events/validation.rs index 311d05e..6181598 100644 --- a/src/events/validation.rs +++ b/src/events/validation.rs @@ -2,9 +2,9 @@ //! //! Provides validation for security events -use std::net::IpAddr; +use crate::events::security::{AlertEvent, NetworkEvent}; use crate::events::syscall::SyscallEvent; -use crate::events::security::{NetworkEvent, AlertEvent}; +use std::net::IpAddr; /// Result of event validation #[derive(Debug, Clone, PartialEq)] @@ -19,25 +19,28 @@ impl ValidationResult { pub fn valid() -> Self { ValidationResult::Valid } - + /// Create an invalid result with reason pub fn invalid(reason: impl Into) -> Self { ValidationResult::Invalid(reason.into()) } - + /// Create an error result with message pub fn error(message: impl Into) -> Self { ValidationResult::Error(message.into()) } - + /// Check if validation passed pub fn is_valid(&self) -> bool { matches!(self, ValidationResult::Valid) } - + /// Check if validation failed pub fn is_invalid(&self) -> bool { - matches!(self, ValidationResult::Invalid(_) | ValidationResult::Error(_)) + matches!( + self, + ValidationResult::Invalid(_) | ValidationResult::Error(_) + ) } } @@ -62,40 +65,40 @@ impl EventValidator { if event.pid == 0 { return ValidationResult::valid(); } - + // UID 0 is valid (root) // All syscalls are valid ValidationResult::valid() } - + /// Validate a network event pub fn validate_network(event: &NetworkEvent) -> ValidationResult { // Validate source IP if let Err(e) = event.src_ip.parse::() { return ValidationResult::invalid(format!("Invalid source IP: {}", e)); } - + // Validate destination IP if let Err(e) = event.dst_ip.parse::() { return ValidationResult::invalid(format!("Invalid destination IP: {}", e)); } - + // Validate port range (0-65535 is always valid for u16) // No additional validation needed for u16 - + ValidationResult::valid() } - + /// Validate an alert event pub fn validate_alert(event: &AlertEvent) -> ValidationResult { // Validate message is not empty if event.message.trim().is_empty() { return ValidationResult::invalid("Alert message cannot be empty"); } - + ValidationResult::valid() } - + /// Validate an IP address string pub fn validate_ip(ip: &str) -> ValidationResult { match ip.parse::() { @@ -103,9 +106,9 @@ impl EventValidator { Err(e) => ValidationResult::invalid(format!("Invalid IP address: {}", e)), } } - + /// Validate a port number - pub fn validate_port(port: u16) -> ValidationResult { + pub fn validate_port(_port: u16) -> ValidationResult { // All u16 values are valid ports (0-65535) ValidationResult::valid() } @@ -115,42 +118,36 @@ impl EventValidator { mod tests { use super::*; use crate::events::syscall::SyscallType; - use crate::events::security::{AlertType, AlertSeverity}; use chrono::Utc; - + #[test] fn test_validation_result_valid() { let result = ValidationResult::valid(); assert!(result.is_valid()); assert!(!result.is_invalid()); } - + #[test] fn test_validation_result_invalid() { let result = ValidationResult::invalid("test reason"); assert!(!result.is_valid()); assert!(result.is_invalid()); } - + #[test] fn test_validation_result_error() { let result = ValidationResult::error("test error"); assert!(!result.is_valid()); assert!(result.is_invalid()); } - + #[test] fn test_validate_syscall_event() { - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); let result = EventValidator::validate_syscall(&event); assert!(result.is_valid()); } - + #[test] fn test_validate_ip() { assert!(EventValidator::validate_ip("192.168.1.1").is_valid()); diff --git a/src/firewall/backend.rs b/src/firewall/backend.rs index 2875100..1e81028 100644 --- a/src/firewall/backend.rs +++ b/src/firewall/backend.rs @@ -8,28 +8,28 @@ use anyhow::Result; pub trait FirewallBackend: Send + Sync { /// Initialize the backend fn initialize(&mut self) -> Result<()>; - + /// Check if backend is available fn is_available(&self) -> bool; - + /// Block an IP address fn block_ip(&self, ip: &str) -> Result<()>; - + /// Unblock an IP address fn unblock_ip(&self, ip: &str) -> Result<()>; - + /// Block a port fn block_port(&self, port: u16) -> Result<()>; - + /// Unblock a port fn unblock_port(&self, port: u16) -> Result<()>; - + /// Block all traffic for a container fn block_container(&self, container_id: &str) -> Result<()>; - + /// Unblock all traffic for a container fn unblock_container(&self, container_id: &str) -> Result<()>; - + /// Get backend name fn name(&self) -> &str; } @@ -43,7 +43,11 @@ pub struct FirewallRule { } impl FirewallRule { - pub fn new(chain: impl Into, rule_spec: impl Into, table: impl Into) -> Self { + pub fn new( + chain: impl Into, + rule_spec: impl Into, + table: impl Into, + ) -> Self { Self { chain: chain.into(), rule_spec: rule_spec.into(), @@ -77,7 +81,11 @@ pub struct FirewallChain { } impl FirewallChain { - pub fn new(table: FirewallTable, name: impl Into, chain_type: impl Into) -> Self { + pub fn new( + table: FirewallTable, + name: impl Into, + chain_type: impl Into, + ) -> Self { Self { table, name: name.into(), diff --git a/src/firewall/iptables.rs b/src/firewall/iptables.rs index a343b8c..b45e29e 100644 --- a/src/firewall/iptables.rs +++ b/src/firewall/iptables.rs @@ -2,7 +2,7 @@ //! //! Manages iptables firewall rules (fallback when nftables unavailable) -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use std::process::Command; use crate::firewall::backend::FirewallBackend; @@ -55,116 +55,130 @@ impl IptablesBackend { .output() .map(|o| o.status.success()) .unwrap_or(false); - + if !available { anyhow::bail!("iptables command not available"); } - + Ok(Self { available: true }) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("iptables only available on Linux"); } } - + /// Create a chain pub fn create_chain(&self, chain: &IptChain) -> Result<()> { let output = Command::new("iptables") - .args(&["-t", &chain.table, "-N", &chain.name]) + .args(["-t", &chain.table, "-N", &chain.name]) .output() .context("Failed to create iptables chain")?; - + if !output.status.success() { - anyhow::bail!("Failed to create chain: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to create chain: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Delete a chain pub fn delete_chain(&self, chain: &IptChain) -> Result<()> { let output = Command::new("iptables") - .args(&["-t", &chain.table, "-X", &chain.name]) + .args(["-t", &chain.table, "-X", &chain.name]) .output() .context("Failed to delete iptables chain")?; - + if !output.status.success() { - anyhow::bail!("Failed to delete chain: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to delete chain: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Add a rule pub fn add_rule(&self, rule: &IptRule) -> Result<()> { let args: Vec<&str> = vec!["-t", &rule.chain.table, "-A", &rule.chain.name]; let rule_parts: Vec<&str> = rule.rule_spec.split_whitespace().collect(); - + let mut cmd = Command::new("iptables"); cmd.args(&args); cmd.args(&rule_parts); - - let output = cmd - .output() - .context("Failed to add iptables rule")?; - + + let output = cmd.output().context("Failed to add iptables rule")?; + if !output.status.success() { - anyhow::bail!("Failed to add rule: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to add rule: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Delete a rule pub fn delete_rule(&self, rule: &IptRule) -> Result<()> { let args: Vec<&str> = vec!["-t", &rule.chain.table, "-D", &rule.chain.name]; let rule_parts: Vec<&str> = rule.rule_spec.split_whitespace().collect(); - + let mut cmd = Command::new("iptables"); cmd.args(&args); cmd.args(&rule_parts); - - let output = cmd - .output() - .context("Failed to delete iptables rule")?; - + + let output = cmd.output().context("Failed to delete iptables rule")?; + if !output.status.success() { - anyhow::bail!("Failed to delete rule: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to delete rule: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Flush a chain pub fn flush_chain(&self, chain: &IptChain) -> Result<()> { let output = Command::new("iptables") - .args(&["-t", &chain.table, "-F", &chain.name]) + .args(["-t", &chain.table, "-F", &chain.name]) .output() .context("Failed to flush iptables chain")?; - + if !output.status.success() { - anyhow::bail!("Failed to flush chain: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to flush chain: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// List rules in a chain pub fn list_rules(&self, chain: &IptChain) -> Result> { let output = Command::new("iptables") - .args(&["-t", &chain.table, "-L", &chain.name, "-n"]) + .args(["-t", &chain.table, "-L", &chain.name, "-n"]) .output() .context("Failed to list iptables rules")?; - + if !output.status.success() { - anyhow::bail!("Failed to list rules: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to list rules: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + let stdout = String::from_utf8_lossy(&output.stdout); let rules: Vec = stdout.lines().map(|s| s.to_string()).collect(); - + Ok(rules) } } @@ -173,45 +187,45 @@ impl FirewallBackend for IptablesBackend { fn initialize(&mut self) -> Result<()> { Ok(()) } - + fn is_available(&self) -> bool { self.available } - + fn block_ip(&self, ip: &str) -> Result<()> { let chain = IptChain::new("filter", "INPUT"); let rule = IptRule::new(&chain, format!("-s {} -j DROP", ip)); self.add_rule(&rule) } - + fn unblock_ip(&self, ip: &str) -> Result<()> { let chain = IptChain::new("filter", "INPUT"); let rule = IptRule::new(&chain, format!("-s {} -j DROP", ip)); self.delete_rule(&rule) } - + fn block_port(&self, port: u16) -> Result<()> { let chain = IptChain::new("filter", "INPUT"); let rule = IptRule::new(&chain, format!("-p tcp --dport {} -j DROP", port)); self.add_rule(&rule) } - + fn unblock_port(&self, port: u16) -> Result<()> { let chain = IptChain::new("filter", "INPUT"); let rule = IptRule::new(&chain, format!("-p tcp --dport {} -j DROP", port)); self.delete_rule(&rule) } - + fn block_container(&self, container_id: &str) -> Result<()> { log::info!("Would block container via iptables: {}", container_id); Ok(()) } - + fn unblock_container(&self, container_id: &str) -> Result<()> { log::info!("Would unblock container via iptables: {}", container_id); Ok(()) } - + fn name(&self) -> &str { "iptables" } @@ -220,14 +234,14 @@ impl FirewallBackend for IptablesBackend { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_ipt_chain_creation() { let chain = IptChain::new("filter", "INPUT"); assert_eq!(chain.table, "filter"); assert_eq!(chain.name, "INPUT"); } - + #[test] fn test_ipt_rule_creation() { let chain = IptChain::new("filter", "INPUT"); diff --git a/src/firewall/mod.rs b/src/firewall/mod.rs index 58ce962..be53ec0 100644 --- a/src/firewall/mod.rs +++ b/src/firewall/mod.rs @@ -3,8 +3,8 @@ //! Manages firewall rules (nftables/iptables) and container quarantine pub mod backend; -pub mod nftables; pub mod iptables; +pub mod nftables; pub mod quarantine; pub mod response; @@ -12,8 +12,8 @@ pub mod response; pub struct FirewallMarker; // Re-export commonly used types -pub use nftables::{NfTablesBackend, NfTable, NfChain, NfRule}; -pub use iptables::{IptablesBackend, IptChain, IptRule}; -pub use quarantine::{QuarantineManager, QuarantineState, QuarantineInfo}; +pub use backend::{FirewallBackend, FirewallChain, FirewallRule, FirewallTable}; +pub use iptables::{IptChain, IptRule, IptablesBackend}; +pub use nftables::{NfChain, NfRule, NfTable, NfTablesBackend}; +pub use quarantine::{QuarantineInfo, QuarantineManager, QuarantineState}; pub use response::{ResponseAction, ResponseChain, ResponseExecutor, ResponseType}; -pub use backend::{FirewallBackend, FirewallRule, FirewallTable, FirewallChain}; diff --git a/src/firewall/nftables.rs b/src/firewall/nftables.rs index afec647..afb8b2b 100644 --- a/src/firewall/nftables.rs +++ b/src/firewall/nftables.rs @@ -2,10 +2,10 @@ //! //! Manages nftables firewall rules -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use std::process::Command; -use crate::firewall::backend::{FirewallBackend, FirewallRule, FirewallTable, FirewallChain}; +use crate::firewall::backend::FirewallBackend; /// nftables table #[derive(Debug, Clone)] @@ -21,9 +21,11 @@ impl NfTable { name: name.into(), } } - - fn to_string(&self) -> String { - format!("{} {}", self.family, self.name) +} + +impl std::fmt::Display for NfTable { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} {}", self.family, self.name) } } @@ -77,131 +79,141 @@ impl NfTablesBackend { .output() .map(|o| o.status.success()) .unwrap_or(false); - + if !available { anyhow::bail!("nft command not available"); } - + Ok(Self { available: true }) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("nftables only available on Linux"); } } - + /// Create a table pub fn create_table(&self, table: &NfTable) -> Result<()> { + let table_str = table.to_string(); let output = Command::new("nft") - .args(&["add", "table", &table.to_string()]) + .args(["add", "table", &table_str]) .output() .context("Failed to create nftables table")?; - + if !output.status.success() { - anyhow::bail!("Failed to create table: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to create table: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Delete a table pub fn delete_table(&self, table: &NfTable) -> Result<()> { + let table_str = table.to_string(); let output = Command::new("nft") - .args(&["delete", "table", &table.to_string()]) + .args(["delete", "table", &table_str]) .output() .context("Failed to delete nftables table")?; - + if !output.status.success() { - anyhow::bail!("Failed to delete table: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to delete table: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Create a chain pub fn create_chain(&self, chain: &NfChain) -> Result<()> { let cmd = format!( "add chain {} {} {{ type {} hook input priority 0; }}", - chain.table.to_string(), - chain.name, - chain.chain_type + chain.table, chain.name, chain.chain_type ); - + let output = Command::new("nft") - .args(&["-c", &cmd]) + .args(["-c", &cmd]) .output() .context("Failed to create nftables chain")?; - + if !output.status.success() { - anyhow::bail!("Failed to create chain: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to create chain: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Delete a chain pub fn delete_chain(&self, chain: &NfChain) -> Result<()> { - let cmd = format!( - "delete chain {} {}", - chain.table.to_string(), - chain.name - ); - + let cmd = format!("delete chain {} {}", chain.table, chain.name); + let output = Command::new("nft") - .args(&["-c", &cmd]) + .args(["-c", &cmd]) .output() .context("Failed to delete nftables chain")?; - + if !output.status.success() { - anyhow::bail!("Failed to delete chain: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to delete chain: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Add a rule pub fn add_rule(&self, rule: &NfRule) -> Result<()> { let cmd = format!( "add rule {} {} {}", - rule.chain.table.to_string(), - rule.chain.name, - rule.rule_spec + rule.chain.table, rule.chain.name, rule.rule_spec ); - + let output = Command::new("nft") - .args(&["-c", &cmd]) + .args(["-c", &cmd]) .output() .context("Failed to add nftables rule")?; - + if !output.status.success() { - anyhow::bail!("Failed to add rule: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to add rule: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Delete a rule pub fn delete_rule(&self, rule: &NfRule) -> Result<()> { let cmd = format!( "delete rule {} {} {}", - rule.chain.table.to_string(), - rule.chain.name, - rule.rule_spec + rule.chain.table, rule.chain.name, rule.rule_spec ); - + let output = Command::new("nft") - .args(&["-c", &cmd]) + .args(["-c", &cmd]) .output() .context("Failed to delete nftables rule")?; - + if !output.status.success() { - anyhow::bail!("Failed to delete rule: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to delete rule: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// Batch add multiple rules pub fn batch_add_rules(&self, rules: &[NfRule]) -> Result<()> { for rule in rules { @@ -209,47 +221,45 @@ impl NfTablesBackend { } Ok(()) } - + /// Flush a chain pub fn flush_chain(&self, chain: &NfChain) -> Result<()> { - let cmd = format!( - "flush chain {} {}", - chain.table.to_string(), - chain.name - ); - + let cmd = format!("flush chain {} {}", chain.table, chain.name); + let output = Command::new("nft") - .args(&["-c", &cmd]) + .args(["-c", &cmd]) .output() .context("Failed to flush nftables chain")?; - + if !output.status.success() { - anyhow::bail!("Failed to flush chain: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to flush chain: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + Ok(()) } - + /// List rules in a chain pub fn list_rules(&self, chain: &NfChain) -> Result> { - let cmd = format!( - "list chain {} {}", - chain.table.to_string(), - chain.name - ); - + let cmd = format!("list chain {} {}", chain.table, chain.name); + let output = Command::new("nft") - .args(&["-c", &cmd]) + .args(["-c", &cmd]) .output() .context("Failed to list nftables rules")?; - + if !output.status.success() { - anyhow::bail!("Failed to list rules: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!( + "Failed to list rules: {}", + String::from_utf8_lossy(&output.stderr) + ); } - + let stdout = String::from_utf8_lossy(&output.stdout); let rules: Vec = stdout.lines().map(|s| s.to_string()).collect(); - + Ok(rules) } } @@ -258,42 +268,42 @@ impl FirewallBackend for NfTablesBackend { fn initialize(&mut self) -> Result<()> { Ok(()) } - + fn is_available(&self) -> bool { self.available } - + fn block_ip(&self, ip: &str) -> Result<()> { // Implementation would add nftables rule to block IP log::info!("Would block IP: {}", ip); Ok(()) } - + fn unblock_ip(&self, ip: &str) -> Result<()> { log::info!("Would unblock IP: {}", ip); Ok(()) } - + fn block_port(&self, port: u16) -> Result<()> { log::info!("Would block port: {}", port); Ok(()) } - + fn unblock_port(&self, port: u16) -> Result<()> { log::info!("Would unblock port: {}", port); Ok(()) } - + fn block_container(&self, container_id: &str) -> Result<()> { log::info!("Would block container: {}", container_id); Ok(()) } - + fn unblock_container(&self, container_id: &str) -> Result<()> { log::info!("Would unblock container: {}", container_id); Ok(()) } - + fn name(&self) -> &str { "nftables" } @@ -302,14 +312,14 @@ impl FirewallBackend for NfTablesBackend { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_nf_table_creation() { let table = NfTable::new("inet", "stackdog_test"); assert_eq!(table.family, "inet"); assert_eq!(table.name, "stackdog_test"); } - + #[test] fn test_nf_chain_creation() { let table = NfTable::new("inet", "stackdog_test"); diff --git a/src/firewall/quarantine.rs b/src/firewall/quarantine.rs index b779903..127a789 100644 --- a/src/firewall/quarantine.rs +++ b/src/firewall/quarantine.rs @@ -2,12 +2,12 @@ //! //! Isolates compromised containers -use anyhow::{Result, Context}; +use anyhow::Result; use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::sync::{Arc, RwLock}; -use crate::firewall::nftables::{NfTablesBackend, NfTable, NfChain, NfRule}; +use crate::firewall::nftables::{NfChain, NfTable, NfTablesBackend}; /// Quarantine state #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -31,7 +31,7 @@ pub struct QuarantineInfo { pub struct QuarantineManager { #[cfg(target_os = "linux")] nft: Option, - + states: Arc>>, table_name: String, } @@ -42,20 +42,20 @@ impl QuarantineManager { #[cfg(target_os = "linux")] { let nft = NfTablesBackend::new().ok(); - + Ok(Self { nft, states: Arc::new(RwLock::new(HashMap::new())), table_name: "inet_stackdog_quarantine".to_string(), }) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("Quarantine only available on Linux"); } } - + /// Quarantine a container pub fn quarantine(&mut self, container_id: &str) -> Result<()> { #[cfg(target_os = "linux")] @@ -69,14 +69,14 @@ impl QuarantineManager { } } } - + // Setup nftables table if needed self.setup_quarantine_table()?; - + // Get container IP (would need Docker API integration) // For now, log the action log::info!("Quarantining container: {}", container_id); - + // Add to states let info = QuarantineInfo { container_id: container_id.to_string(), @@ -85,21 +85,21 @@ impl QuarantineManager { state: QuarantineState::Quarantined, reason: None, }; - + { let mut states = self.states.write().unwrap(); states.insert(container_id.to_string(), info); } - + Ok(()) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("Quarantine only available on Linux"); } } - + /// Release a container from quarantine pub fn release(&mut self, container_id: &str) -> Result<()> { #[cfg(target_os = "linux")] @@ -115,10 +115,10 @@ impl QuarantineManager { anyhow::bail!("Container not found in quarantine"); } } - + // Remove nftables rules (would need container IP) log::info!("Releasing container from quarantine: {}", container_id); - + // Update state { let mut states = self.states.write().unwrap(); @@ -127,27 +127,27 @@ impl QuarantineManager { info.state = QuarantineState::Released; } } - + Ok(()) } - + #[cfg(not(target_os = "linux"))] { anyhow::bail!("Quarantine only available on Linux"); } } - + /// Rollback quarantine (release and cleanup) pub fn rollback(&mut self, container_id: &str) -> Result<()> { self.release(container_id) } - + /// Get quarantine state for a container pub fn get_state(&self, container_id: &str) -> Option { let states = self.states.read().unwrap(); states.get(container_id).map(|info| info.state) } - + /// Get all quarantined containers pub fn get_quarantined_containers(&self) -> Vec { let states = self.states.read().unwrap(); @@ -157,42 +157,42 @@ impl QuarantineManager { .map(|(id, _)| id.clone()) .collect() } - + /// Get quarantine info for a container pub fn get_quarantine_info(&self, container_id: &str) -> Option { let states = self.states.read().unwrap(); states.get(container_id).cloned() } - + /// Setup quarantine nftables table #[cfg(target_os = "linux")] fn setup_quarantine_table(&mut self) -> Result<()> { if let Some(ref nft) = self.nft { let table = NfTable::new("inet", &self.table_name); - + // Try to create table (may already exist) let _ = nft.create_table(&table); - + // Create input chain let input_chain = NfChain::new(&table, "quarantine_input", "filter"); let _ = nft.create_chain(&input_chain); - + // Create output chain let output_chain = NfChain::new(&table, "quarantine_output", "filter"); let _ = nft.create_chain(&output_chain); } - + Ok(()) } - + /// Get quarantine statistics pub fn get_stats(&self) -> QuarantineStats { let states = self.states.read().unwrap(); - + let mut currently_quarantined = 0; let mut released = 0; let mut failed = 0; - + for info in states.values() { match info.state { QuarantineState::Quarantined => currently_quarantined += 1, @@ -200,7 +200,7 @@ impl QuarantineManager { QuarantineState::Failed => failed += 1, } } - + QuarantineStats { currently_quarantined, total_quarantined: states.len() as u64, @@ -228,14 +228,14 @@ pub struct QuarantineStats { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_quarantine_state_variants() { let _quarantined = QuarantineState::Quarantined; let _released = QuarantineState::Released; let _failed = QuarantineState::Failed; } - + #[test] fn test_quarantine_info_creation() { let info = QuarantineInfo { @@ -245,7 +245,7 @@ mod tests { state: QuarantineState::Quarantined, reason: Some("Test".to_string()), }; - + assert_eq!(info.container_id, "test123"); assert_eq!(info.state, QuarantineState::Quarantined); } diff --git a/src/firewall/response.rs b/src/firewall/response.rs index e850d8c..f4a6d91 100644 --- a/src/firewall/response.rs +++ b/src/firewall/response.rs @@ -39,7 +39,7 @@ impl ResponseAction { retry_delay_ms: 0, } } - + /// Create response from alert pub fn from_alert(alert: &Alert, action_type: ResponseType) -> Self { Self { @@ -49,33 +49,33 @@ impl ResponseAction { retry_delay_ms: 1000, } } - + /// Set retry configuration pub fn set_retry_config(&mut self, max_retries: u32, retry_delay_ms: u64) { self.max_retries = max_retries; self.retry_delay_ms = retry_delay_ms; } - + /// Get action type pub fn action_type(&self) -> ResponseType { self.action_type.clone() } - + /// Get description pub fn description(&self) -> &str { &self.description } - + /// Get max retries pub fn max_retries(&self) -> u32 { self.max_retries } - + /// Get retry delay pub fn retry_delay_ms(&self) -> u64 { self.retry_delay_ms } - + /// Execute the action pub fn execute(&self) -> Result<()> { match &self.action_type { @@ -109,25 +109,28 @@ impl ResponseAction { } } } - + /// Execute with retries pub fn execute_with_retry(&self) -> Result<()> { let mut last_error = None; - + for attempt in 0..=self.max_retries { match self.execute() { Ok(()) => return Ok(()), Err(e) => { last_error = Some(e); if attempt < self.max_retries { - log::warn!("Action failed (attempt {}/{}), retrying...", - attempt + 1, self.max_retries + 1); + log::warn!( + "Action failed (attempt {}/{}), retrying...", + attempt + 1, + self.max_retries + 1 + ); std::thread::sleep(std::time::Duration::from_millis(self.retry_delay_ms)); } } } } - + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Action failed"))) } } @@ -149,32 +152,37 @@ impl ResponseChain { stop_on_failure: false, } } - + /// Add an action to the chain pub fn add_action(&mut self, action: ResponseAction) { self.actions.push(action); } - + /// Set stop on failure pub fn set_stop_on_failure(&mut self, stop: bool) { self.stop_on_failure = stop; } - + /// Get chain name pub fn name(&self) -> &str { &self.name } - + /// Get action count pub fn action_count(&self) -> usize { self.actions.len() } - + /// Execute all actions in chain pub fn execute(&self) -> Result<()> { for (i, action) in self.actions.iter().enumerate() { - log::debug!("Executing action {}/{}: {}", i + 1, self.actions.len(), action.description()); - + log::debug!( + "Executing action {}/{}: {}", + i + 1, + self.actions.len(), + action.description() + ); + match action.execute() { Ok(()) => {} Err(e) => { @@ -187,7 +195,7 @@ impl ResponseChain { } } } - + Ok(()) } } @@ -204,40 +212,40 @@ impl ResponseExecutor { log: Arc::new(RwLock::new(Vec::new())), }) } - + /// Execute a response action pub fn execute(&mut self, action: &ResponseAction) -> Result<()> { - let start = Utc::now(); + let _start = Utc::now(); let result = action.execute(); - let end = Utc::now(); - + let _end = Utc::now(); + // Log the execution let log_entry = ResponseLog::new( action.description().to_string(), result.is_ok(), result.as_ref().err().map(|e| e.to_string()), ); - + { let mut log = self.log.write().unwrap(); log.push(log_entry); } - + result } - + /// Execute a response chain pub fn execute_chain(&mut self, chain: &ResponseChain) -> Result<()> { log::info!("Executing response chain: {}", chain.name()); chain.execute() } - + /// Get execution log pub fn get_log(&self) -> Vec { let log = self.log.read().unwrap(); log.clone() } - + /// Clear execution log pub fn clear_log(&mut self) { let mut log = self.log.write().unwrap(); @@ -269,15 +277,19 @@ impl ResponseLog { timestamp: Utc::now(), } } - + pub fn action_name(&self) -> &str { &self.action_name } - + pub fn success(&self) -> bool { self.success } - + + pub fn error(&self) -> Option<&str> { + self.error.as_deref() + } + pub fn timestamp(&self) -> DateTime { self.timestamp } @@ -294,15 +306,16 @@ impl ResponseAudit { history: Vec::new(), } } - + pub fn record(&mut self, action_name: String, success: bool, error: Option) { - self.history.push(ResponseLog::new(action_name, success, error)); + self.history + .push(ResponseLog::new(action_name, success, error)); } - + pub fn get_history(&self) -> &[ResponseLog] { &self.history } - + pub fn clear(&mut self) { self.history.clear(); } @@ -317,58 +330,54 @@ impl Default for ResponseAudit { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_response_action_creation() { let action = ResponseAction::new( ResponseType::LogAction("test".to_string()), "Test action".to_string(), ); - + assert_eq!(action.description(), "Test action"); } - + #[test] fn test_response_action_execution() { let action = ResponseAction::new( ResponseType::LogAction("test".to_string()), "Test".to_string(), ); - + let result = action.execute(); assert!(result.is_ok()); } - + #[test] fn test_response_chain_creation() { let chain = ResponseChain::new("test_chain"); assert_eq!(chain.name(), "test_chain"); assert_eq!(chain.action_count(), 0); } - + #[test] fn test_response_chain_execution() { let mut chain = ResponseChain::new("test"); - + let action = ResponseAction::new( ResponseType::LogAction("test".to_string()), "Test".to_string(), ); - + chain.add_action(action); - + let result = chain.execute(); assert!(result.is_ok()); } - + #[test] fn test_response_log_creation() { - let log = ResponseLog::new( - "test_action".to_string(), - true, - None, - ); - + let log = ResponseLog::new("test_action".to_string(), true, None); + assert!(log.success()); assert_eq!(log.action_name(), "test_action"); } diff --git a/src/lib.rs b/src/lib.rs index 8a64c1d..ca67009 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ //! Stackdog Security Library //! //! Security platform for Docker containers and Linux servers -//! +//! //! ## Features -//! +//! //! - **eBPF-based syscall monitoring** - Real-time event collection //! - **Event enrichment** - Container detection, process info //! - **Rule engine** - Signature-based detection @@ -15,12 +15,9 @@ #![allow(unused_must_use)] // External crates -#[macro_use] +extern crate log; extern crate serde; -#[macro_use] extern crate serde_json; -#[macro_use] -extern crate log; // Docker (Linux only) #[cfg(target_os = "linux")] @@ -37,10 +34,10 @@ extern crate candle_core; extern crate candle_nn; // Security modules - Core -pub mod events; -pub mod rules; pub mod alerting; +pub mod events; pub mod models; +pub mod rules; // Security modules - Linux-specific #[cfg(target_os = "linux")] @@ -50,22 +47,25 @@ pub mod firewall; pub mod collectors; // Optional modules -pub mod ml; -pub mod response; -pub mod correlator; pub mod baselines; +pub mod correlator; pub mod database; pub mod docker; +pub mod ml; +pub mod response; // Configuration pub mod config; +// API +pub mod api; + // Log sniffing pub mod sniff; // Re-export commonly used types +pub use events::security::{AlertEvent, ContainerEvent, NetworkEvent, SecurityEvent}; pub use events::syscall::{SyscallEvent, SyscallType}; -pub use events::security::{SecurityEvent, NetworkEvent, ContainerEvent, AlertEvent}; // Alerting pub use alerting::{Alert, AlertSeverity, AlertStatus, AlertType}; @@ -73,15 +73,15 @@ pub use alerting::{AlertManager, AlertStats}; pub use alerting::{NotificationChannel, NotificationConfig}; // Linux-specific +pub use collectors::{EbpfLoader, SyscallMonitor}; #[cfg(target_os = "linux")] pub use firewall::{QuarantineManager, QuarantineState}; #[cfg(target_os = "linux")] pub use firewall::{ResponseAction, ResponseChain, ResponseExecutor, ResponseType}; -pub use collectors::{EbpfLoader, SyscallMonitor}; // Rules -pub use rules::{RuleEngine, Rule, RuleResult}; -pub use rules::{Signature, SignatureDatabase, ThreatCategory}; -pub use rules::{SignatureMatcher, PatternMatch, MatchResult}; -pub use rules::{ThreatScorer, ThreatScore, ScoringConfig}; pub use rules::{DetectionStats, StatsTracker}; +pub use rules::{MatchResult, PatternMatch, SignatureMatcher}; +pub use rules::{Rule, RuleEngine, RuleResult}; +pub use rules::{ScoringConfig, ThreatScore, ThreatScorer}; +pub use rules::{Signature, SignatureDatabase, ThreatCategory}; diff --git a/src/main.rs b/src/main.rs index 4bb0619..a665795 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,39 +4,29 @@ #![allow(unused_must_use)] -#[macro_use] +extern crate bollard; extern crate log; -#[macro_use] extern crate serde_json; -extern crate bollard; -extern crate actix_rt; extern crate actix_cors; +extern crate actix_rt; extern crate actix_web; -extern crate env_logger; extern crate dotenv; +extern crate env_logger; extern crate tracing; extern crate tracing_subscriber; -mod config; -mod api; -mod database; -mod docker; -mod events; -mod rules; -mod alerting; -mod models; mod cli; -mod sniff; -use std::{io, env}; -use actix_web::{HttpServer, App, web}; use actix_cors::Cors; +use actix_web::{web, App, HttpServer}; use clap::Parser; -use tracing::{Level, info}; -use tracing_subscriber::FmtSubscriber; -use database::{create_pool, init_database}; use cli::{Cli, Command}; +use stackdog::database::{create_pool, init_database}; +use stackdog::sniff; +use std::{env, io}; +use tracing::{info, Level}; +use tracing_subscriber::FmtSubscriber; #[actix_rt::main] async fn main() -> io::Result<()> { @@ -52,28 +42,52 @@ async fn main() -> io::Result<()> { env::set_var("RUST_LOG", "stackdog=info,actix_web=info"); } env_logger::init(); - + // Setup tracing — respect RUST_LOG for level - let max_level = if env::var("RUST_LOG").map(|v| v.contains("debug")).unwrap_or(false) { + let max_level = if env::var("RUST_LOG") + .map(|v| v.contains("debug")) + .unwrap_or(false) + { Level::DEBUG - } else if env::var("RUST_LOG").map(|v| v.contains("trace")).unwrap_or(false) { + } else if env::var("RUST_LOG") + .map(|v| v.contains("trace")) + .unwrap_or(false) + { Level::TRACE } else { Level::INFO }; - let subscriber = FmtSubscriber::builder() - .with_max_level(max_level) - .finish(); - tracing::subscriber::set_global_default(subscriber) - .expect("setting default subscriber failed"); + let subscriber = FmtSubscriber::builder().with_max_level(max_level).finish(); + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); info!("šŸ• Stackdog Security starting..."); info!("Platform: {}", std::env::consts::OS); info!("Architecture: {}", std::env::consts::ARCH); match cli.command { - Some(Command::Sniff { once, consume, output, sources, interval, ai_provider, ai_model, ai_api_url, slack_webhook }) => { - run_sniff(once, consume, output, sources, interval, ai_provider, ai_model, ai_api_url, slack_webhook).await + Some(Command::Sniff { + once, + consume, + output, + sources, + interval, + ai_provider, + ai_model, + ai_api_url, + slack_webhook, + }) => { + let config = sniff::config::SniffConfig::from_env_and_args(sniff::config::SniffArgs { + once, + consume, + output: &output, + sources: sources.as_deref(), + interval, + ai_provider: ai_provider.as_deref(), + ai_model: ai_model.as_deref(), + ai_api_url: ai_api_url.as_deref(), + slack_webhook: slack_webhook.as_deref(), + }); + run_sniff(config).await } // Default: serve (backward compatible) Some(Command::Serve) | None => run_serve().await, @@ -84,19 +98,24 @@ async fn run_serve() -> io::Result<()> { let app_host = env::var("APP_HOST").unwrap_or_else(|_| "0.0.0.0".to_string()); let app_port = env::var("APP_PORT").unwrap_or_else(|_| "5000".to_string()); let database_url = env::var("DATABASE_URL").unwrap_or_else(|_| "./stackdog.db".to_string()); - + info!("Host: {}", app_host); info!("Port: {}", app_port); info!("Database: {}", database_url); - + let app_url = format!("{}:{}", &app_host, &app_port); - + let display_host = if app_host == "0.0.0.0" { + "127.0.0.1" + } else { + &app_host + }; + // Initialize database info!("Initializing database..."); let pool = create_pool(&database_url).expect("Failed to create database pool"); init_database(&pool).expect("Failed to initialize database"); info!("Database initialized successfully"); - + info!("šŸŽ‰ Stackdog Security ready!"); info!(""); info!("API Endpoints:"); @@ -113,51 +132,36 @@ async fn run_serve() -> io::Result<()> { info!(" GET /api/logs/summaries - List AI summaries"); info!(" WS /ws - WebSocket for real-time updates"); info!(""); - info!("Web Dashboard: http://{}:{}", app_host, app_port); + info!("API started on http://{}:{}", display_host, app_port); info!(""); - + // Start HTTP server info!("Starting HTTP server on {}...", app_url); - + let pool_data = web::Data::new(pool); - + HttpServer::new(move || { App::new() .app_data(pool_data.clone()) .wrap(Cors::permissive()) .wrap(actix_web::middleware::Logger::default()) - .configure(api::configure_all_routes) + .configure(stackdog::api::configure_all_routes) }) .bind(&app_url)? .run() .await } -async fn run_sniff( - once: bool, - consume: bool, - output: String, - sources: Option, - interval: u64, - ai_provider: Option, - ai_model: Option, - ai_api_url: Option, - slack_webhook: Option, -) -> io::Result<()> { - let config = sniff::config::SniffConfig::from_env_and_args( - once, - consume, - &output, - sources.as_deref(), - interval, - ai_provider.as_deref(), - ai_model.as_deref(), - ai_api_url.as_deref(), - slack_webhook.as_deref(), - ); - +async fn run_sniff(config: sniff::config::SniffConfig) -> io::Result<()> { info!("šŸ” Stackdog Sniff starting..."); - info!("Mode: {}", if config.once { "one-shot" } else { "continuous" }); + info!( + "Mode: {}", + if config.once { + "one-shot" + } else { + "continuous" + } + ); info!("Consume: {}", config.consume); info!("Output: {}", config.output_dir.display()); info!("Interval: {}s", config.interval_secs); @@ -168,10 +172,7 @@ async fn run_sniff( info!("Slack: configured āœ“"); } - let orchestrator = sniff::SniffOrchestrator::new(config) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + let orchestrator = sniff::SniffOrchestrator::new(config).map_err(io::Error::other)?; - orchestrator.run().await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + orchestrator.run().await.map_err(io::Error::other) } - diff --git a/src/ml/features.rs b/src/ml/features.rs index d6ccd88..8abe268 100644 --- a/src/ml/features.rs +++ b/src/ml/features.rs @@ -2,8 +2,6 @@ //! //! Extracts features from security events for anomaly detection -use anyhow::Result; - /// Security features for ML model pub struct SecurityFeatures { pub syscall_rate: f64, diff --git a/src/ml/mod.rs b/src/ml/mod.rs index fdb65f4..8a46c20 100644 --- a/src/ml/mod.rs +++ b/src/ml/mod.rs @@ -2,11 +2,11 @@ //! //! Machine learning for anomaly detection using Candle +pub mod anomaly; pub mod candle_backend; pub mod features; -pub mod anomaly; -pub mod scorer; pub mod models; +pub mod scorer; /// Marker struct for module tests pub struct MlMarker; diff --git a/src/models/api/mod.rs b/src/models/api/mod.rs index 63306b0..26e8bcd 100644 --- a/src/models/api/mod.rs +++ b/src/models/api/mod.rs @@ -1,11 +1,13 @@ //! API models -pub mod security; pub mod alerts; pub mod containers; +pub mod security; pub mod threats; -pub use security::SecurityStatusResponse; pub use alerts::{AlertResponse, AlertStatsResponse}; -pub use containers::{ContainerResponse, ContainerSecurityStatus, NetworkActivity, QuarantineRequest}; +pub use containers::{ + ContainerResponse, ContainerSecurityStatus, NetworkActivity, QuarantineRequest, +}; +pub use security::SecurityStatusResponse; pub use threats::{ThreatResponse, ThreatStatisticsResponse}; diff --git a/src/rules/builtin.rs b/src/rules/builtin.rs index c7b1bed..c5da7b1 100644 --- a/src/rules/builtin.rs +++ b/src/rules/builtin.rs @@ -2,8 +2,8 @@ //! //! Pre-defined rules for common security scenarios -use crate::events::syscall::{SyscallEvent, SyscallType}; use crate::events::security::SecurityEvent; +use crate::events::syscall::SyscallType; use crate::rules::rule::{Rule, RuleResult}; /// Syscall allowlist rule @@ -30,11 +30,11 @@ impl Rule for SyscallAllowlistRule { RuleResult::NoMatch } } - + fn name(&self) -> &str { "syscall_allowlist" } - + fn priority(&self) -> u32 { 50 } @@ -56,7 +56,7 @@ impl Rule for SyscallBlocklistRule { fn evaluate(&self, event: &SecurityEvent) -> RuleResult { if let SecurityEvent::Syscall(syscall_event) = event { if self.blocked.contains(&syscall_event.syscall_type) { - RuleResult::Match // Match means violation detected + RuleResult::Match // Match means violation detected } else { RuleResult::NoMatch } @@ -64,13 +64,13 @@ impl Rule for SyscallBlocklistRule { RuleResult::NoMatch } } - + fn name(&self) -> &str { "syscall_blocklist" } - + fn priority(&self) -> u32 { - 10 // High priority for security violations + 10 // High priority for security violations } } @@ -106,11 +106,11 @@ impl Rule for ProcessExecutionRule { RuleResult::NoMatch } } - + fn name(&self) -> &str { "process_execution" } - + fn priority(&self) -> u32 { 30 } @@ -149,11 +149,11 @@ impl Rule for NetworkConnectionRule { RuleResult::NoMatch } } - + fn name(&self) -> &str { "network_connection" } - + fn priority(&self) -> u32 { 40 } @@ -192,11 +192,11 @@ impl Rule for FileAccessRule { RuleResult::NoMatch } } - + fn name(&self) -> &str { "file_access" } - + fn priority(&self) -> u32 { 60 } @@ -205,22 +205,29 @@ impl Rule for FileAccessRule { #[cfg(test)] mod tests { use super::*; + use crate::events::syscall::SyscallEvent; use chrono::Utc; - + #[test] fn test_allowlist_rule() { let rule = SyscallAllowlistRule::new(vec![SyscallType::Execve]); let event = SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Execve, Utc::now(), + 1234, + 1000, + SyscallType::Execve, + Utc::now(), )); assert!(rule.evaluate(&event).is_match()); } - + #[test] fn test_blocklist_rule() { let rule = SyscallBlocklistRule::new(vec![SyscallType::Ptrace]); let event = SecurityEvent::Syscall(SyscallEvent::new( - 1234, 1000, SyscallType::Ptrace, Utc::now(), + 1234, + 1000, + SyscallType::Ptrace, + Utc::now(), )); assert!(rule.evaluate(&event).is_match()); } diff --git a/src/rules/engine.rs b/src/rules/engine.rs index 406f40f..99705d5 100644 --- a/src/rules/engine.rs +++ b/src/rules/engine.rs @@ -2,10 +2,9 @@ //! //! Manages and evaluates security rules -use anyhow::Result; use crate::events::security::SecurityEvent; -use crate::rules::rule::{Rule, RuleResult}; use crate::rules::result::RuleEvaluationResult; +use crate::rules::rule::{Rule, RuleResult}; /// Rule engine for evaluating security rules pub struct RuleEngine { @@ -21,7 +20,7 @@ impl RuleEngine { enabled_rules: std::collections::HashSet::new(), } } - + /// Register a rule with the engine pub fn register_rule(&mut self, rule: Box) { let name = rule.name().to_string(); @@ -30,13 +29,13 @@ impl RuleEngine { // Sort by priority after adding self.rules.sort_by_key(|r| r.priority()); } - + /// Remove a rule by name pub fn remove_rule(&mut self, name: &str) { self.rules.retain(|r| r.name() != name); self.enabled_rules.remove(name); } - + /// Evaluate all rules against an event pub fn evaluate(&self, event: &SecurityEvent) -> Vec { self.rules @@ -48,51 +47,45 @@ impl RuleEngine { .map(|rule| rule.evaluate(event)) .collect() } - + /// Evaluate with detailed results pub fn evaluate_detailed(&self, event: &SecurityEvent) -> Vec { self.rules .iter() - .filter(|rule| { - self.enabled_rules.contains(rule.name()) && rule.enabled() - }) + .filter(|rule| self.enabled_rules.contains(rule.name()) && rule.enabled()) .map(|rule| { let result = rule.evaluate(event); - RuleEvaluationResult::new( - rule.name().to_string(), - event.clone(), - result, - ) + RuleEvaluationResult::new(rule.name().to_string(), event.clone(), result) }) .collect() } - + /// Get the number of registered rules pub fn rule_count(&self) -> usize { self.rules.len() } - + /// Clear all rules pub fn clear_all_rules(&mut self) { self.rules.clear(); self.enabled_rules.clear(); } - + /// Enable a rule pub fn enable_rule(&mut self, name: &str) { self.enabled_rules.insert(name.to_string()); } - + /// Disable a rule pub fn disable_rule(&mut self, name: &str) { self.enabled_rules.remove(name); } - + /// Check if a rule is enabled pub fn is_rule_enabled(&self, name: &str) -> bool { self.enabled_rules.contains(name) } - + /// Get all rule names pub fn rule_names(&self) -> Vec<&str> { self.rules.iter().map(|r| r.name()).collect() @@ -108,31 +101,7 @@ impl Default for RuleEngine { #[cfg(test)] mod tests { use super::*; - - struct TestRule { - name: String, - priority: u32, - should_match: bool, - } - - impl Rule for TestRule { - fn evaluate(&self, _event: &SecurityEvent) -> RuleResult { - if self.should_match { - RuleResult::Match - } else { - RuleResult::NoMatch - } - } - - fn name(&self) -> &str { - &self.name - } - - fn priority(&self) -> u32 { - self.priority - } - } - + #[test] fn test_engine_creation() { let engine = RuleEngine::new(); diff --git a/src/rules/mod.rs b/src/rules/mod.rs index 3783d49..c0ad356 100644 --- a/src/rules/mod.rs +++ b/src/rules/mod.rs @@ -2,23 +2,23 @@ //! //! Contains the rule engine for security rule evaluation -pub mod engine; -pub mod rule; -pub mod signatures; pub mod builtin; +pub mod engine; pub mod result; +pub mod rule; pub mod signature_matcher; -pub mod threat_scorer; +pub mod signatures; pub mod stats; +pub mod threat_scorer; /// Marker struct for module tests pub struct RulesMarker; // Re-export commonly used types pub use engine::RuleEngine; +pub use result::{RuleEvaluationResult, Severity}; pub use rule::{Rule, RuleResult}; +pub use signature_matcher::{MatchResult, PatternMatch, SignatureMatcher}; pub use signatures::{Signature, SignatureDatabase, ThreatCategory}; -pub use result::{RuleEvaluationResult, Severity}; -pub use signature_matcher::{SignatureMatcher, PatternMatch, MatchResult}; -pub use threat_scorer::{ThreatScorer, ThreatScore, ScoringConfig}; pub use stats::{DetectionStats, StatsTracker}; +pub use threat_scorer::{ScoringConfig, ThreatScore, ThreatScorer}; diff --git a/src/rules/result.rs b/src/rules/result.rs index f1e413f..37af375 100644 --- a/src/rules/result.rs +++ b/src/rules/result.rs @@ -27,7 +27,7 @@ impl Severity { _ => Severity::Info, } } - + /// Get the numeric score for this severity pub fn score(&self) -> u8 { match self { @@ -63,11 +63,7 @@ pub struct RuleEvaluationResult { impl RuleEvaluationResult { /// Create a new evaluation result - pub fn new( - rule_name: String, - event: SecurityEvent, - result: RuleResult, - ) -> Self { + pub fn new(rule_name: String, event: SecurityEvent, result: RuleResult) -> Self { Self { rule_name, event, @@ -75,37 +71,37 @@ impl RuleEvaluationResult { timestamp: chrono::Utc::now(), } } - + /// Get the rule name pub fn rule_name(&self) -> &str { &self.rule_name } - + /// Get the event pub fn event(&self) -> &SecurityEvent { &self.event } - + /// Get the result pub fn result(&self) -> &RuleResult { &self.result } - + /// Get the timestamp pub fn timestamp(&self) -> chrono::DateTime { self.timestamp } - + /// Check if the rule matched pub fn matched(&self) -> bool { self.result.is_match() } - + /// Check if the rule did not match pub fn not_matched(&self) -> bool { self.result.is_no_match() } - + /// Check if there was an error pub fn has_error(&self) -> bool { self.result.is_error() @@ -117,27 +113,30 @@ pub fn calculate_aggregate_severity(severities: &[Severity]) -> Severity { if severities.is_empty() { return Severity::Info; } - + // Return the highest severity *severities.iter().max().unwrap_or(&Severity::Info) } /// Calculate aggregate severity from rule results -pub fn calculate_severity_from_results(results: &[RuleEvaluationResult], base_severities: &[Severity]) -> Severity { +pub fn calculate_severity_from_results( + results: &[RuleEvaluationResult], + base_severities: &[Severity], +) -> Severity { let matched_severities: Vec = results .iter() .filter(|r| r.matched()) .enumerate() .map(|(i, _)| base_severities.get(i).copied().unwrap_or(Severity::Medium)) .collect(); - + calculate_aggregate_severity(&matched_severities) } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_severity_ordering() { assert!(Severity::Info < Severity::Low); @@ -145,7 +144,7 @@ mod tests { assert!(Severity::Medium < Severity::High); assert!(Severity::High < Severity::Critical); } - + #[test] fn test_severity_from_score() { assert_eq!(Severity::from_score(0), Severity::Info); @@ -154,25 +153,25 @@ mod tests { assert_eq!(Severity::from_score(80), Severity::High); assert_eq!(Severity::from_score(95), Severity::Critical); } - + #[test] fn test_severity_display() { assert_eq!(format!("{}", Severity::High), "High"); } - + #[test] fn test_aggregate_severity_empty() { let result = calculate_aggregate_severity(&[]); assert_eq!(result, Severity::Info); } - + #[test] fn test_aggregate_severity_single() { let severities = vec![Severity::High]; let result = calculate_aggregate_severity(&severities); assert_eq!(result, Severity::High); } - + #[test] fn test_aggregate_severity_multiple() { let severities = vec![Severity::Low, Severity::Medium, Severity::High]; diff --git a/src/rules/rule.rs b/src/rules/rule.rs index 02fc571..9e46409 100644 --- a/src/rules/rule.rs +++ b/src/rules/rule.rs @@ -17,12 +17,12 @@ impl RuleResult { pub fn is_match(&self) -> bool { matches!(self, RuleResult::Match) } - + /// Check if this is no match pub fn is_no_match(&self) -> bool { matches!(self, RuleResult::NoMatch) } - + /// Check if this is an error pub fn is_error(&self) -> bool { matches!(self, RuleResult::Error(_)) @@ -43,15 +43,15 @@ impl std::fmt::Display for RuleResult { pub trait Rule: Send + Sync { /// Evaluate the rule against an event fn evaluate(&self, event: &SecurityEvent) -> RuleResult; - + /// Get the rule name fn name(&self) -> &str; - + /// Get the rule priority (lower = higher priority) fn priority(&self) -> u32 { 100 } - + /// Check if the rule is enabled fn enabled(&self) -> bool { true diff --git a/src/rules/signature_matcher.rs b/src/rules/signature_matcher.rs index 76a685a..5d35e7f 100644 --- a/src/rules/signature_matcher.rs +++ b/src/rules/signature_matcher.rs @@ -2,16 +2,16 @@ //! //! Advanced signature matching with multi-event pattern detection -use crate::events::syscall::SyscallType; use crate::events::security::SecurityEvent; -use crate::rules::signatures::{SignatureDatabase, Signature}; +use crate::events::syscall::SyscallType; +use crate::rules::signatures::SignatureDatabase; use chrono::{DateTime, Utc}; /// Pattern match definition #[derive(Debug, Clone)] pub struct PatternMatch { syscalls: Vec, - time_window: Option, // Seconds + time_window: Option, // Seconds description: String, } @@ -24,41 +24,41 @@ impl PatternMatch { description: String::new(), } } - + /// Add a syscall to the pattern pub fn with_syscall(mut self, syscall: SyscallType) -> Self { self.syscalls.push(syscall); self } - + /// Add next syscall in sequence pub fn then_syscall(mut self, syscall: SyscallType) -> Self { self.syscalls.push(syscall); self } - + /// Set time window for pattern (in seconds) pub fn within_seconds(mut self, seconds: u64) -> Self { self.time_window = Some(seconds); self } - + /// Set description pub fn with_description(mut self, desc: impl Into) -> Self { self.description = desc.into(); self } - + /// Get syscalls in pattern pub fn syscalls(&self) -> &[SyscallType] { &self.syscalls } - + /// Get time window pub fn time_window(&self) -> Option { self.time_window } - + /// Get description pub fn description(&self) -> &str { &self.description @@ -88,7 +88,7 @@ impl MatchResult { confidence, } } - + /// Create empty (no match) result pub fn no_match() -> Self { Self { @@ -97,17 +97,17 @@ impl MatchResult { confidence: 0.0, } } - + /// Get matched signatures pub fn matches(&self) -> &[String] { &self.matches } - + /// Check if matched pub fn is_match(&self) -> bool { self.is_match } - + /// Get confidence score (0.0 - 1.0) pub fn confidence(&self) -> f64 { self.confidence @@ -117,8 +117,12 @@ impl MatchResult { impl std::fmt::Display for MatchResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if self.is_match { - write!(f, "Match ({} signatures, confidence: {:.2})", - self.matches.len(), self.confidence) + write!( + f, + "Match ({} signatures, confidence: {:.2})", + self.matches.len(), + self.confidence + ) } else { write!(f, "NoMatch") } @@ -139,52 +143,47 @@ impl SignatureMatcher { patterns: Vec::new(), } } - + /// Add a pattern to match pub fn add_pattern(&mut self, pattern: PatternMatch) { self.patterns.push(pattern); } - + /// Match a single event against signatures pub fn match_single(&self, event: &SecurityEvent) -> MatchResult { let signatures = self.db.detect(event); - + if signatures.is_empty() { return MatchResult::no_match(); } - - let matches: Vec = signatures - .iter() - .map(|s| s.name().to_string()) - .collect(); - + + let matches: Vec = signatures.iter().map(|s| s.name().to_string()).collect(); + // Calculate confidence based on severity - let avg_severity = signatures - .iter() - .map(|s| s.severity() as f64) - .sum::() / signatures.len() as f64; - + let avg_severity = + signatures.iter().map(|s| s.severity() as f64).sum::() / signatures.len() as f64; + let confidence = avg_severity / 100.0; - + MatchResult::new(matches, true, confidence) } - + /// Match a sequence of events against patterns pub fn match_sequence(&self, events: &[SecurityEvent]) -> MatchResult { if events.is_empty() { return MatchResult::no_match(); } - + for pattern in &self.patterns { if self.matches_pattern(pattern, events) { return MatchResult::new( vec![pattern.description().to_string()], true, - 0.9, // High confidence for pattern match + 0.9, // High confidence for pattern match ); } } - + // Also check individual events let mut all_matches = Vec::new(); for event in events { @@ -193,26 +192,26 @@ impl SignatureMatcher { all_matches.extend(result.matches().iter().cloned()); } } - + if all_matches.is_empty() { MatchResult::no_match() } else { MatchResult::new(all_matches, true, 0.7) } } - + /// Check if events match a pattern fn matches_pattern(&self, pattern: &PatternMatch, events: &[SecurityEvent]) -> bool { // Need at least as many events as pattern syscalls if events.len() < pattern.syscalls().len() { return false; } - + // Check if pattern syscalls appear in order let mut event_idx = 0; let mut matched_syscalls = 0; let mut first_match_time: Option> = None; - + for required_syscall in pattern.syscalls() { while event_idx < events.len() { if let SecurityEvent::Syscall(syscall_event) = &events[event_idx] { @@ -221,7 +220,7 @@ impl SignatureMatcher { if first_match_time.is_none() { first_match_time = Some(syscall_event.timestamp); } - + matched_syscalls += 1; event_idx += 1; break; @@ -230,37 +229,37 @@ impl SignatureMatcher { event_idx += 1; } } - + // Check if all syscalls matched if matched_syscalls != pattern.syscalls().len() { return false; } - + // Check time window if specified if let Some(window) = pattern.time_window() { - if let (Some(first), Some(last)) = (first_match_time, events.last()) { - if let SecurityEvent::Syscall(last_event) = last { - let elapsed = last_event.timestamp - first; - if elapsed.num_seconds() > window as i64 { - return false; - } + if let (Some(first), Some(SecurityEvent::Syscall(last_event))) = + (first_match_time, events.last()) + { + let elapsed = last_event.timestamp - first; + if elapsed.num_seconds() > window as i64 { + return false; } } } - + true } - + /// Get signature database pub fn database(&self) -> &SignatureDatabase { &self.db } - + /// Get patterns pub fn patterns(&self) -> &[PatternMatch] { &self.patterns } - + /// Clear patterns pub fn clear_patterns(&mut self) { self.patterns.clear(); @@ -276,7 +275,7 @@ impl Default for SignatureMatcher { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_pattern_match_builder() { let pattern = PatternMatch::new() @@ -284,17 +283,17 @@ mod tests { .then_syscall(SyscallType::Connect) .within_seconds(60) .with_description("Test pattern"); - + assert_eq!(pattern.syscalls().len(), 2); assert_eq!(pattern.time_window(), Some(60)); assert_eq!(pattern.description(), "Test pattern"); } - + #[test] fn test_match_result_display() { let result = MatchResult::new(vec!["sig1".to_string()], true, 0.8); assert!(format!("{}", result).contains("Match")); - + let no_result = MatchResult::no_match(); assert!(format!("{}", no_result).contains("NoMatch")); } diff --git a/src/rules/signatures.rs b/src/rules/signatures.rs index e5f0578..a77ed87 100644 --- a/src/rules/signatures.rs +++ b/src/rules/signatures.rs @@ -2,8 +2,8 @@ //! //! Known threat patterns and signatures for detection -use crate::events::syscall::{SyscallEvent, SyscallType}; use crate::events::security::SecurityEvent; +use crate::events::syscall::SyscallType; /// Threat categories #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -57,27 +57,27 @@ impl Signature { syscall_patterns, } } - + /// Get the signature name pub fn name(&self) -> &str { &self.name } - + /// Get the description pub fn description(&self) -> &str { &self.description } - + /// Get the severity (0-100) pub fn severity(&self) -> u8 { self.severity } - + /// Get the category pub fn category(&self) -> &ThreatCategory { &self.category } - + /// Check if a syscall matches this signature pub fn matches(&self, syscall_type: &SyscallType) -> bool { self.syscall_patterns.contains(syscall_type) @@ -95,12 +95,12 @@ impl SignatureDatabase { let mut db = Self { signatures: Vec::new(), }; - + // Load built-in signatures db.load_builtin_signatures(); db } - + /// Load built-in threat signatures fn load_builtin_signatures(&mut self) { // Crypto miner detection - execve + setuid pattern @@ -111,7 +111,7 @@ impl SignatureDatabase { ThreatCategory::CryptoMiner, vec![SyscallType::Execve, SyscallType::Setuid], )); - + // Container escape - ptrace + mount pattern self.signatures.push(Signature::new( "container_escape_ptrace", @@ -120,7 +120,7 @@ impl SignatureDatabase { ThreatCategory::ContainerEscape, vec![SyscallType::Ptrace], )); - + self.signatures.push(Signature::new( "container_escape_mount", "Detects mount syscall associated with container escape attempts", @@ -128,7 +128,7 @@ impl SignatureDatabase { ThreatCategory::ContainerEscape, vec![SyscallType::Mount], )); - + // Network scanner - connect + bind pattern self.signatures.push(Signature::new( "network_scanner_connect", @@ -137,7 +137,7 @@ impl SignatureDatabase { ThreatCategory::NetworkScanner, vec![SyscallType::Connect], )); - + self.signatures.push(Signature::new( "network_scanner_bind", "Detects bind syscall commonly used by network scanners", @@ -145,7 +145,7 @@ impl SignatureDatabase { ThreatCategory::NetworkScanner, vec![SyscallType::Bind], )); - + // Privilege escalation - setuid + setgid pattern self.signatures.push(Signature::new( "privilege_escalation_setuid", @@ -154,7 +154,7 @@ impl SignatureDatabase { ThreatCategory::PrivilegeEscalation, vec![SyscallType::Setuid, SyscallType::Setgid], )); - + // Data exfiltration - connect pattern self.signatures.push(Signature::new( "data_exfiltration_network", @@ -163,7 +163,7 @@ impl SignatureDatabase { ThreatCategory::DataExfiltration, vec![SyscallType::Connect, SyscallType::Sendto], )); - + // Malware indicators self.signatures.push(Signature::new( "malware_execve_tmp", @@ -172,7 +172,7 @@ impl SignatureDatabase { ThreatCategory::Malware, vec![SyscallType::Execve], )); - + // Suspicious activity self.signatures.push(Signature::new( "suspicious_execveat", @@ -181,7 +181,7 @@ impl SignatureDatabase { ThreatCategory::Suspicious, vec![SyscallType::Execveat], )); - + self.signatures.push(Signature::new( "suspicious_openat", "Detects openat syscall for file access monitoring", @@ -190,27 +190,27 @@ impl SignatureDatabase { vec![SyscallType::Openat], )); } - + /// Get all signatures pub fn get_signatures(&self) -> &[Signature] { &self.signatures } - + /// Get signature count pub fn signature_count(&self) -> usize { self.signatures.len() } - + /// Add a custom signature pub fn add_signature(&mut self, signature: Signature) { self.signatures.push(signature); } - + /// Remove a signature by name pub fn remove_signature(&mut self, name: &str) { self.signatures.retain(|sig| sig.name() != name); } - + /// Get signatures by category pub fn get_signatures_by_category(&self, category: &ThreatCategory) -> Vec<&Signature> { self.signatures @@ -218,7 +218,7 @@ impl SignatureDatabase { .filter(|sig| sig.category() == category) .collect() } - + /// Find signatures that match a syscall pub fn find_matching(&self, syscall_type: &SyscallType) -> Vec<&Signature> { self.signatures @@ -226,7 +226,7 @@ impl SignatureDatabase { .filter(|sig| sig.matches(syscall_type)) .collect() } - + /// Detect threats in an event pub fn detect(&self, event: &SecurityEvent) -> Vec<&Signature> { match event { @@ -247,7 +247,7 @@ impl Default for SignatureDatabase { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_signature_creation() { let sig = Signature::new( @@ -260,7 +260,7 @@ mod tests { assert_eq!(sig.name(), "test_sig"); assert_eq!(sig.severity(), 50); } - + #[test] fn test_threat_category_display() { assert_eq!(format!("{}", ThreatCategory::Suspicious), "Suspicious"); diff --git a/src/rules/stats.rs b/src/rules/stats.rs index 3289e77..752efcf 100644 --- a/src/rules/stats.rs +++ b/src/rules/stats.rs @@ -29,97 +29,97 @@ impl DetectionStats { last_updated: now, } } - + /// Record an event being processed pub fn record_event(&mut self) { self.events_processed += 1; self.last_updated = Utc::now(); } - + /// Record a signature match pub fn record_match(&mut self) { self.signatures_matched += 1; self.true_positives += 1; self.last_updated = Utc::now(); } - + /// Record a false positive pub fn record_false_positive(&mut self) { self.false_positives += 1; self.last_updated = Utc::now(); } - + /// Get events processed count pub fn events_processed(&self) -> u64 { self.events_processed } - + /// Get signatures matched count pub fn signatures_matched(&self) -> u64 { self.signatures_matched } - + /// Get false positives count pub fn false_positives(&self) -> u64 { self.false_positives } - + /// Get true positives count pub fn true_positives(&self) -> u64 { self.true_positives } - + /// Get start time pub fn start_time(&self) -> DateTime { self.start_time } - + /// Get last updated time pub fn last_updated(&self) -> DateTime { self.last_updated } - + /// Calculate detection rate (matches / events) pub fn detection_rate(&self) -> f64 { if self.events_processed == 0 { return 0.0; } - + self.signatures_matched as f64 / self.events_processed as f64 } - + /// Calculate false positive rate pub fn false_positive_rate(&self) -> f64 { let total_matches = self.true_positives + self.false_positives; if total_matches == 0 { return 0.0; } - + self.false_positives as f64 / total_matches as f64 } - + /// Calculate precision (true positives / all matches) pub fn precision(&self) -> f64 { let total_matches = self.true_positives + self.false_positives; if total_matches == 0 { - return 1.0; // No matches = no false positives + return 1.0; // No matches = no false positives } - + self.true_positives as f64 / total_matches as f64 } - + /// Get uptime duration pub fn uptime(&self) -> chrono::Duration { self.last_updated - self.start_time } - + /// Get events per second pub fn events_per_second(&self) -> f64 { let uptime_secs = self.uptime().num_seconds() as f64; if uptime_secs <= 0.0 { return 0.0; } - + self.events_processed as f64 / uptime_secs } } @@ -155,7 +155,7 @@ impl StatsTracker { stats: DetectionStats::new(), }) } - + /// Record an event with match result pub fn record_event(&mut self, _event: &SecurityEvent, matched: bool) { self.stats.record_event(); @@ -163,17 +163,17 @@ impl StatsTracker { self.stats.record_match(); } } - + /// Get current stats pub fn stats(&self) -> &DetectionStats { &self.stats } - + /// Get mutable stats pub fn stats_mut(&mut self) -> &mut DetectionStats { &mut self.stats } - + /// Reset stats pub fn reset(&mut self) { self.stats = DetectionStats::new(); @@ -189,57 +189,57 @@ impl Default for StatsTracker { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_detection_stats_creation() { let stats = DetectionStats::new(); assert_eq!(stats.events_processed(), 0); assert_eq!(stats.signatures_matched(), 0); } - + #[test] fn test_detection_stats_recording() { let mut stats = DetectionStats::new(); - + stats.record_event(); stats.record_event(); stats.record_match(); - + assert_eq!(stats.events_processed(), 2); assert_eq!(stats.signatures_matched(), 1); } - + #[test] fn test_detection_rate() { let mut stats = DetectionStats::new(); - + for _ in 0..10 { stats.record_event(); } for _ in 0..3 { stats.record_match(); } - + assert!((stats.detection_rate() - 0.3).abs() < 0.01); } - + #[test] fn test_false_positive_rate() { let mut stats = DetectionStats::new(); - - stats.record_match(); // true positive - stats.record_match(); // true positive + + stats.record_match(); // true positive + stats.record_match(); // true positive stats.record_false_positive(); - + assert!((stats.false_positive_rate() - 0.333).abs() < 0.01); } - + #[test] fn test_stats_display() { let mut stats = DetectionStats::new(); stats.record_event(); stats.record_match(); - + let display = format!("{}", stats); assert!(display.contains("events")); assert!(display.contains("matches")); diff --git a/src/rules/threat_scorer.rs b/src/rules/threat_scorer.rs index c1807bd..7e7c30d 100644 --- a/src/rules/threat_scorer.rs +++ b/src/rules/threat_scorer.rs @@ -5,7 +5,6 @@ use crate::events::security::SecurityEvent; use crate::rules::result::Severity; use crate::rules::signature_matcher::SignatureMatcher; -use chrono::Utc; /// Threat score (0-100) #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -20,32 +19,32 @@ impl ThreatScore { value: value.min(100), } } - + /// Get the score value pub fn value(&self) -> u8 { self.value } - + /// Get severity from score pub fn severity(&self) -> Severity { Severity::from_score(self.value) } - + /// Check if score exceeds threshold pub fn exceeds_threshold(&self, threshold: u8) -> bool { self.value >= threshold } - + /// Check if score is high or higher (>= 70) pub fn is_high_or_higher(&self) -> bool { self.value >= 70 } - + /// Check if score is critical (>= 90) pub fn is_critical(&self) -> bool { self.value >= 90 } - + /// Add to score (capped at 100) pub fn add(&mut self, value: u8) { self.value = (self.value + value).min(100); @@ -68,50 +67,55 @@ pub struct ScoringConfig { } impl ScoringConfig { - /// Create default config - pub fn default() -> Self { + /// Create a new scoring config + pub fn new( + base_score: u8, + multiplier: f64, + time_decay_enabled: bool, + decay_half_life_seconds: u64, + ) -> Self { Self { - base_score: 50, - multiplier: 1.0, - time_decay_enabled: false, - decay_half_life_seconds: 3600, // 1 hour + base_score, + multiplier, + time_decay_enabled, + decay_half_life_seconds, } } - + /// Set base score pub fn with_base_score(mut self, score: u8) -> Self { self.base_score = score; self } - + /// Set multiplier pub fn with_multiplier(mut self, multiplier: f64) -> Self { self.multiplier = multiplier; self } - + /// Enable time decay pub fn with_time_decay(mut self, enabled: bool) -> Self { self.time_decay_enabled = enabled; self } - + /// Set decay half-life pub fn with_decay_half_life(mut self, seconds: u64) -> Self { self.decay_half_life_seconds = seconds; self } - + /// Check if time decay is enabled pub fn time_decay_enabled(&self) -> bool { self.time_decay_enabled } - + /// Get base score pub fn base_score(&self) -> u8 { self.base_score } - + /// Get multiplier pub fn multiplier(&self) -> f64 { self.multiplier @@ -120,7 +124,7 @@ impl ScoringConfig { impl Default for ScoringConfig { fn default() -> Self { - Self::default() + Self::new(50, 1.0, false, 3600) } } @@ -138,7 +142,7 @@ impl ThreatScorer { matcher: SignatureMatcher::new(), } } - + /// Create scorer with custom config pub fn with_config(config: ScoringConfig) -> Self { Self { @@ -146,7 +150,7 @@ impl ThreatScorer { matcher: SignatureMatcher::new(), } } - + /// Create scorer with custom matcher pub fn with_matcher(matcher: SignatureMatcher) -> Self { Self { @@ -154,57 +158,57 @@ impl ThreatScorer { matcher, } } - + /// Calculate threat score for an event pub fn calculate_score(&self, event: &SecurityEvent) -> ThreatScore { // Get signature matches let match_result = self.matcher.match_single(event); - + if !match_result.is_match() { return ThreatScore::new(0); } - + // Start with base score let mut score = self.config.base_score() as f64; - + // Apply multiplier based on confidence score *= match_result.confidence(); score *= self.config.multiplier(); - + // Apply time decay if enabled if self.config.time_decay_enabled { // Time decay would be applied based on event age // For now, use full score (event is "recent") } - + ThreatScore::new(score as u8) } - + /// Calculate cumulative score for multiple events pub fn calculate_cumulative_score(&self, events: &[SecurityEvent]) -> ThreatScore { let mut total_score = 0u16; - + for event in events { let score = self.calculate_score(event); total_score += score.value() as u16; } - + // Average score with bonus for multiple events if events.is_empty() { return ThreatScore::new(0); } - + let avg_score = total_score / events.len() as u16; - let bonus = (events.len() as u16).min(20); // Up to 20% bonus - + let bonus = (events.len() as u16).min(20); // Up to 20% bonus + ThreatScore::new(((avg_score as f64) * (1.0 + bonus as f64 / 100.0)) as u8) } - + /// Get the signature matcher pub fn matcher(&self) -> &SignatureMatcher { &self.matcher } - + /// Get the scoring config pub fn config(&self) -> &ScoringConfig { &self.config @@ -227,7 +231,7 @@ pub fn calculate_severity_from_scores(scores: &[ThreatScore]) -> Severity { if scores.is_empty() { return Severity::Info; } - + let max_score = scores.iter().map(|s| s.value()).max().unwrap_or(0); Severity::from_score(max_score) } @@ -235,40 +239,40 @@ pub fn calculate_severity_from_scores(scores: &[ThreatScore]) -> Severity { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_threat_score_creation() { let score = ThreatScore::new(75); assert_eq!(score.value(), 75); } - + #[test] fn test_threat_score_cap() { let score = ThreatScore::new(150); assert_eq!(score.value(), 100); } - + #[test] fn test_threat_score_add() { let mut score = ThreatScore::new(50); score.add(30); assert_eq!(score.value(), 80); } - + #[test] fn test_threat_score_add_cap() { let mut score = ThreatScore::new(90); score.add(50); assert_eq!(score.value(), 100); } - + #[test] fn test_scoring_config_builder() { let config = ScoringConfig::default() .with_base_score(60) .with_multiplier(1.5) .with_time_decay(true); - + assert_eq!(config.base_score(), 60); assert_eq!(config.multiplier(), 1.5); assert!(config.time_decay_enabled()); diff --git a/src/sniff/analyzer.rs b/src/sniff/analyzer.rs index 5eee30e..f0275f7 100644 --- a/src/sniff/analyzer.rs +++ b/src/sniff/analyzer.rs @@ -4,7 +4,7 @@ //! - OpenAI-compatible API (works with OpenAI, Ollama, vLLM, etc.) //! - Local Candle inference (requires `ml` feature) -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; @@ -173,14 +173,17 @@ fn parse_severity(s: &str) -> AnomalySeverity { /// Parse the LLM JSON response into a LogSummary fn parse_llm_response(source_id: &str, entries: &[LogEntry], raw_json: &str) -> Result { - log::debug!("Parsing LLM response ({} bytes) for source {}", raw_json.len(), source_id); + log::debug!( + "Parsing LLM response ({} bytes) for source {}", + raw_json.len(), + source_id + ); log::trace!("Raw LLM response:\n{}", raw_json); - let analysis: LlmAnalysis = serde_json::from_str(raw_json) - .context(format!( - "Failed to parse LLM response as JSON. Response starts with: {}", - &raw_json[..raw_json.len().min(200)] - ))?; + let analysis: LlmAnalysis = serde_json::from_str(raw_json).context(format!( + "Failed to parse LLM response as JSON. Response starts with: {}", + &raw_json[..raw_json.len().min(200)] + ))?; log::debug!( "LLM analysis parsed — summary: {:?}, errors: {:?}, warnings: {:?}, anomalies: {}", @@ -190,7 +193,9 @@ fn parse_llm_response(source_id: &str, entries: &[LogEntry], raw_json: &str) -> analysis.anomalies.as_ref().map(|a| a.len()).unwrap_or(0), ); - let anomalies = analysis.anomalies.unwrap_or_default() + let anomalies = analysis + .anomalies + .unwrap_or_default() .into_iter() .map(|a| LogAnomaly { description: a.description.unwrap_or_default(), @@ -206,7 +211,9 @@ fn parse_llm_response(source_id: &str, entries: &[LogEntry], raw_json: &str) -> period_start: start, period_end: end, total_entries: entries.len(), - summary_text: analysis.summary.unwrap_or_else(|| "No summary available".into()), + summary_text: analysis + .summary + .unwrap_or_else(|| "No summary available".into()), error_count: analysis.error_count.unwrap_or(0), warning_count: analysis.warning_count.unwrap_or(0), key_events: analysis.key_events.unwrap_or_default(), @@ -220,8 +227,16 @@ fn entry_time_range(entries: &[LogEntry]) -> (DateTime, DateTime) { let now = Utc::now(); return (now, now); } - let start = entries.iter().map(|e| e.timestamp).min().unwrap_or_else(Utc::now); - let end = entries.iter().map(|e| e.timestamp).max().unwrap_or_else(Utc::now); + let start = entries + .iter() + .map(|e| e.timestamp) + .min() + .unwrap_or_else(Utc::now); + let end = entries + .iter() + .map(|e| e.timestamp) + .max() + .unwrap_or_else(Utc::now); (start, end) } @@ -248,7 +263,9 @@ impl LogAnalyzer for OpenAiAnalyzer { log::debug!( "Sending {} entries to AI API (model: {}, url: {})", - entries.len(), self.model, self.api_url + entries.len(), + self.model, + self.api_url ); log::trace!("Prompt:\n{}", prompt); @@ -270,11 +287,17 @@ impl LogAnalyzer for OpenAiAnalyzer { let url = format!("{}/chat/completions", self.api_url.trim_end_matches('/')); log::debug!("POST {}", url); - let mut req = self.client.post(&url) + let mut req = self + .client + .post(&url) .header("Content-Type", "application/json"); if let Some(ref key) = self.api_key { - log::debug!("Using API key: {}...{}", &key[..key.len().min(4)], &key[key.len().saturating_sub(4)..]); + log::debug!( + "Using API key: {}...{}", + &key[..key.len().min(4)], + &key[key.len().saturating_sub(4)..] + ); req = req.header("Authorization", format!("Bearer {}", key)); } else { log::debug!("No API key configured (using keyless access)"); @@ -295,7 +318,9 @@ impl LogAnalyzer for OpenAiAnalyzer { anyhow::bail!("AI API returned status {}: {}", status, body); } - let raw_body = response.text().await + let raw_body = response + .text() + .await .context("Failed to read AI API response body")?; log::debug!("AI API response body ({} bytes)", raw_body.len()); log::trace!("AI API raw response:\n{}", raw_body); @@ -303,12 +328,17 @@ impl LogAnalyzer for OpenAiAnalyzer { let completion: ChatCompletionResponse = serde_json::from_str(&raw_body) .context("Failed to parse AI API response as ChatCompletion")?; - let content = completion.choices + let content = completion + .choices .first() .map(|c| c.message.content.clone()) .unwrap_or_default(); - log::debug!("LLM content ({} chars): {}", content.len(), &content[..content.len().min(200)]); + log::debug!( + "LLM content ({} chars): {}", + content.len(), + &content[..content.len().min(200)] + ); // Extract JSON from response — LLMs often wrap in markdown code fences let json_str = extract_json(&content); @@ -321,16 +351,25 @@ impl LogAnalyzer for OpenAiAnalyzer { /// Fallback local analyzer that uses pattern matching (no AI required) pub struct PatternAnalyzer; +impl Default for PatternAnalyzer { + fn default() -> Self { + Self::new() + } +} + impl PatternAnalyzer { pub fn new() -> Self { Self } fn count_pattern(entries: &[LogEntry], patterns: &[&str]) -> usize { - entries.iter().filter(|e| { - let lower = e.line.to_lowercase(); - patterns.iter().any(|p| lower.contains(p)) - }).count() + entries + .iter() + .filter(|e| { + let lower = e.line.to_lowercase(); + patterns.iter().any(|p| lower.contains(p)) + }) + .count() } } @@ -353,13 +392,17 @@ impl LogAnalyzer for PatternAnalyzer { } let source_id = &entries[0].source_id; - let error_count = Self::count_pattern(entries, &["error", "err", "fatal", "panic", "exception"]); + let error_count = + Self::count_pattern(entries, &["error", "err", "fatal", "panic", "exception"]); let warning_count = Self::count_pattern(entries, &["warn", "warning"]); let (start, end) = entry_time_range(entries); log::debug!( "PatternAnalyzer [{}]: {} entries, {} errors, {} warnings", - source_id, entries.len(), error_count, warning_count + source_id, + entries.len(), + error_count, + warning_count ); let mut anomalies = Vec::new(); @@ -368,11 +411,19 @@ impl LogAnalyzer for PatternAnalyzer { if error_count > entries.len() / 4 { log::debug!( "Error spike detected: {} errors / {} entries (threshold: >25%)", - error_count, entries.len() + error_count, + entries.len() ); - if let Some(sample) = entries.iter().find(|e| e.line.to_lowercase().contains("error")) { + if let Some(sample) = entries + .iter() + .find(|e| e.line.to_lowercase().contains("error")) + { anomalies.push(LogAnomaly { - description: format!("High error rate: {} errors in {} entries", error_count, entries.len()), + description: format!( + "High error rate: {} errors in {} entries", + error_count, + entries.len() + ), severity: AnomalySeverity::High, sample_line: sample.line.clone(), }); @@ -381,7 +432,9 @@ impl LogAnalyzer for PatternAnalyzer { let summary_text = format!( "{} log entries analyzed. {} errors, {} warnings detected.", - entries.len(), error_count, warning_count + entries.len(), + error_count, + warning_count ); Ok(LogSummary { @@ -404,12 +457,15 @@ mod tests { use std::collections::HashMap; fn make_entries(lines: &[&str]) -> Vec { - lines.iter().map(|line| LogEntry { - source_id: "test-source".into(), - timestamp: Utc::now(), - line: line.to_string(), - metadata: HashMap::new(), - }).collect() + lines + .iter() + .map(|line| LogEntry { + source_id: "test-source".into(), + timestamp: Utc::now(), + line: line.to_string(), + metadata: HashMap::new(), + }) + .collect() } #[test] @@ -518,7 +574,10 @@ mod tests { #[test] fn test_extract_json_with_preamble() { let input = "Here is the analysis:\n{\"summary\": \"ok\", \"error_count\": 0}"; - assert_eq!(extract_json(input), r#"{"summary": "ok", "error_count": 0}"#); + assert_eq!( + extract_json(input), + r#"{"summary": "ok", "error_count": 0}"# + ); } #[test] @@ -593,11 +652,8 @@ mod tests { #[test] fn test_openai_analyzer_new() { - let analyzer = OpenAiAnalyzer::new( - "http://localhost:11434/v1".into(), - None, - "llama3".into(), - ); + let analyzer = + OpenAiAnalyzer::new("http://localhost:11434/v1".into(), None, "llama3".into()); assert_eq!(analyzer.api_url, "http://localhost:11434/v1"); assert!(analyzer.api_key.is_none()); assert_eq!(analyzer.model, "llama3"); @@ -605,11 +661,8 @@ mod tests { #[tokio::test] async fn test_openai_analyzer_empty_entries() { - let analyzer = OpenAiAnalyzer::new( - "http://localhost:11434/v1".into(), - None, - "llama3".into(), - ); + let analyzer = + OpenAiAnalyzer::new("http://localhost:11434/v1".into(), None, "llama3".into()); let summary = analyzer.summarize(&[]).await.unwrap(); assert_eq!(summary.total_entries, 0); } diff --git a/src/sniff/config.rs b/src/sniff/config.rs index 0fa0294..6ddef85 100644 --- a/src/sniff/config.rs +++ b/src/sniff/config.rs @@ -12,14 +12,16 @@ pub enum AiProvider { Candle, } -impl AiProvider { - pub fn from_str(s: &str) -> Self { - match s.to_lowercase().as_str() { +impl std::str::FromStr for AiProvider { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(match s.to_lowercase().as_str() { "candle" => AiProvider::Candle, // "ollama" uses the same OpenAI-compatible API client "openai" | "ollama" => AiProvider::OpenAi, _ => AiProvider::OpenAi, - } + }) } } @@ -52,19 +54,22 @@ pub struct SniffConfig { pub webhook_url: Option, } +/// Arguments for building a SniffConfig +pub struct SniffArgs<'a> { + pub once: bool, + pub consume: bool, + pub output: &'a str, + pub sources: Option<&'a str>, + pub interval: u64, + pub ai_provider: Option<&'a str>, + pub ai_model: Option<&'a str>, + pub ai_api_url: Option<&'a str>, + pub slack_webhook: Option<&'a str>, +} + impl SniffConfig { /// Build config from environment variables, overridden by CLI args - pub fn from_env_and_args( - once: bool, - consume: bool, - output: &str, - sources: Option<&str>, - interval: u64, - ai_provider_arg: Option<&str>, - ai_model_arg: Option<&str>, - ai_api_url_arg: Option<&str>, - slack_webhook_arg: Option<&str>, - ) -> Self { + pub fn from_env_and_args(args: SniffArgs<'_>) -> Self { let env_sources = env::var("STACKDOG_LOG_SOURCES").unwrap_or_default(); let mut extra_sources: Vec = env_sources .split(',') @@ -72,7 +77,7 @@ impl SniffConfig { .filter(|s| !s.is_empty()) .collect(); - if let Some(cli_sources) = sources { + if let Some(cli_sources) = args.sources { for s in cli_sources.split(',') { let trimmed = s.trim().to_string(); if !trimmed.is_empty() && !extra_sources.contains(&trimmed) { @@ -81,47 +86,48 @@ impl SniffConfig { } } - let ai_provider_str = ai_provider_arg - .map(|s| s.to_string()) - .unwrap_or_else(|| env::var("STACKDOG_AI_PROVIDER").unwrap_or_else(|_| "openai".into())); + let ai_provider_str = args.ai_provider.map(|s| s.to_string()).unwrap_or_else(|| { + env::var("STACKDOG_AI_PROVIDER").unwrap_or_else(|_| "openai".into()) + }); - let output_dir = if output != "./stackdog-logs/" { - PathBuf::from(output) + let output_dir = if args.output != "./stackdog-logs/" { + PathBuf::from(args.output) } else { PathBuf::from( - env::var("STACKDOG_SNIFF_OUTPUT_DIR") - .unwrap_or_else(|_| output.to_string()), + env::var("STACKDOG_SNIFF_OUTPUT_DIR").unwrap_or_else(|_| args.output.to_string()), ) }; - let interval_secs = if interval != 30 { - interval + let interval_secs = if args.interval != 30 { + args.interval } else { env::var("STACKDOG_SNIFF_INTERVAL") .ok() .and_then(|v| v.parse().ok()) - .unwrap_or(interval) + .unwrap_or(args.interval) }; Self { - once, - consume, + once: args.once, + consume: args.consume, output_dir, extra_sources, interval_secs, - ai_provider: AiProvider::from_str(&ai_provider_str), - ai_api_url: ai_api_url_arg + ai_provider: ai_provider_str.parse().unwrap(), + ai_api_url: args + .ai_api_url .map(|s| s.to_string()) .or_else(|| env::var("STACKDOG_AI_API_URL").ok()) .unwrap_or_else(|| "http://localhost:11434/v1".into()), ai_api_key: env::var("STACKDOG_AI_API_KEY").ok(), - ai_model: ai_model_arg + ai_model: args + .ai_model .map(|s| s.to_string()) .or_else(|| env::var("STACKDOG_AI_MODEL").ok()) .unwrap_or_else(|| "llama3".into()), - database_url: env::var("DATABASE_URL") - .unwrap_or_else(|_| "./stackdog.db".into()), - slack_webhook: slack_webhook_arg + database_url: env::var("DATABASE_URL").unwrap_or_else(|_| "./stackdog.db".into()), + slack_webhook: args + .slack_webhook .map(|s| s.to_string()) .or_else(|| env::var("STACKDOG_SLACK_WEBHOOK_URL").ok()), webhook_url: env::var("STACKDOG_WEBHOOK_URL").ok(), @@ -151,11 +157,11 @@ mod tests { #[test] fn test_ai_provider_from_str() { - assert_eq!(AiProvider::from_str("openai"), AiProvider::OpenAi); - assert_eq!(AiProvider::from_str("OpenAI"), AiProvider::OpenAi); - assert_eq!(AiProvider::from_str("candle"), AiProvider::Candle); - assert_eq!(AiProvider::from_str("Candle"), AiProvider::Candle); - assert_eq!(AiProvider::from_str("unknown"), AiProvider::OpenAi); + assert_eq!("openai".parse::().unwrap(), AiProvider::OpenAi); + assert_eq!("OpenAI".parse::().unwrap(), AiProvider::OpenAi); + assert_eq!("candle".parse::().unwrap(), AiProvider::Candle); + assert_eq!("Candle".parse::().unwrap(), AiProvider::Candle); + assert_eq!("unknown".parse::().unwrap(), AiProvider::OpenAi); } #[test] @@ -163,7 +169,17 @@ mod tests { let _lock = ENV_MUTEX.lock().unwrap(); clear_sniff_env(); - let config = SniffConfig::from_env_and_args(false, false, "./stackdog-logs/", None, 30, None, None, None, None); + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); assert!(!config.once); assert!(!config.consume); assert_eq!(config.output_dir, PathBuf::from("./stackdog-logs/")); @@ -180,9 +196,17 @@ mod tests { let _lock = ENV_MUTEX.lock().unwrap(); clear_sniff_env(); - let config = SniffConfig::from_env_and_args( - true, true, "/tmp/output/", Some("/var/log/app.log"), 60, Some("candle"), None, None, None, - ); + let config = SniffConfig::from_env_and_args(SniffArgs { + once: true, + consume: true, + output: "/tmp/output/", + sources: Some("/var/log/app.log"), + interval: 60, + ai_provider: Some("candle"), + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); assert!(config.once); assert!(config.consume); @@ -198,13 +222,27 @@ mod tests { clear_sniff_env(); env::set_var("STACKDOG_LOG_SOURCES", "/var/log/syslog,/var/log/auth.log"); - let config = SniffConfig::from_env_and_args( - false, false, "./stackdog-logs/", Some("/var/log/app.log,/var/log/syslog"), 30, None, None, None, None, - ); - - assert!(config.extra_sources.contains(&"/var/log/syslog".to_string())); - assert!(config.extra_sources.contains(&"/var/log/auth.log".to_string())); - assert!(config.extra_sources.contains(&"/var/log/app.log".to_string())); + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: Some("/var/log/app.log,/var/log/syslog"), + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); + + assert!(config + .extra_sources + .contains(&"/var/log/syslog".to_string())); + assert!(config + .extra_sources + .contains(&"/var/log/auth.log".to_string())); + assert!(config + .extra_sources + .contains(&"/var/log/app.log".to_string())); assert_eq!(config.extra_sources.len(), 3); clear_sniff_env(); @@ -220,7 +258,17 @@ mod tests { env::set_var("STACKDOG_SNIFF_INTERVAL", "45"); env::set_var("STACKDOG_SNIFF_OUTPUT_DIR", "/data/logs/"); - let config = SniffConfig::from_env_and_args(false, false, "./stackdog-logs/", None, 30, None, None, None, None); + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); assert_eq!(config.ai_api_url, "https://api.openai.com/v1"); assert_eq!(config.ai_api_key, Some("sk-test123".into())); assert_eq!(config.ai_model, "gpt-4o-mini"); @@ -235,10 +283,17 @@ mod tests { let _lock = ENV_MUTEX.lock().unwrap(); clear_sniff_env(); - let config = SniffConfig::from_env_and_args( - false, false, "./stackdog-logs/", None, 30, - Some("ollama"), Some("qwen2.5-coder:latest"), None, None, - ); + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: Some("ollama"), + ai_model: Some("qwen2.5-coder:latest"), + ai_api_url: None, + slack_webhook: None, + }); // "ollama" maps to OpenAi internally (same API protocol) assert_eq!(config.ai_provider, AiProvider::OpenAi); assert_eq!(config.ai_model, "qwen2.5-coder:latest"); @@ -254,10 +309,17 @@ mod tests { env::set_var("STACKDOG_AI_MODEL", "gpt-4o-mini"); env::set_var("STACKDOG_AI_API_URL", "https://api.openai.com/v1"); - let config = SniffConfig::from_env_and_args( - false, false, "./stackdog-logs/", None, 30, - None, Some("llama3"), Some("http://localhost:11434/v1"), None, - ); + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: Some("llama3"), + ai_api_url: Some("http://localhost:11434/v1"), + slack_webhook: None, + }); // CLI args take priority over env vars assert_eq!(config.ai_model, "llama3"); assert_eq!(config.ai_api_url, "http://localhost:11434/v1"); @@ -270,11 +332,21 @@ mod tests { let _lock = ENV_MUTEX.lock().unwrap(); clear_sniff_env(); - let config = SniffConfig::from_env_and_args( - false, false, "./stackdog-logs/", None, 30, - None, None, None, Some("https://hooks.slack.com/services/T/B/xxx"), + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: Some("https://hooks.slack.com/services/T/B/xxx"), + }); + assert_eq!( + config.slack_webhook.as_deref(), + Some("https://hooks.slack.com/services/T/B/xxx") ); - assert_eq!(config.slack_webhook.as_deref(), Some("https://hooks.slack.com/services/T/B/xxx")); clear_sniff_env(); } @@ -283,13 +355,26 @@ mod tests { fn test_slack_webhook_from_env() { let _lock = ENV_MUTEX.lock().unwrap(); clear_sniff_env(); - env::set_var("STACKDOG_SLACK_WEBHOOK_URL", "https://hooks.slack.com/services/T/B/env"); + env::set_var( + "STACKDOG_SLACK_WEBHOOK_URL", + "https://hooks.slack.com/services/T/B/env", + ); - let config = SniffConfig::from_env_and_args( - false, false, "./stackdog-logs/", None, 30, - None, None, None, None, + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); + assert_eq!( + config.slack_webhook.as_deref(), + Some("https://hooks.slack.com/services/T/B/env") ); - assert_eq!(config.slack_webhook.as_deref(), Some("https://hooks.slack.com/services/T/B/env")); clear_sniff_env(); } @@ -298,13 +383,26 @@ mod tests { fn test_slack_webhook_cli_overrides_env() { let _lock = ENV_MUTEX.lock().unwrap(); clear_sniff_env(); - env::set_var("STACKDOG_SLACK_WEBHOOK_URL", "https://hooks.slack.com/services/T/B/env"); + env::set_var( + "STACKDOG_SLACK_WEBHOOK_URL", + "https://hooks.slack.com/services/T/B/env", + ); - let config = SniffConfig::from_env_and_args( - false, false, "./stackdog-logs/", None, 30, - None, None, None, Some("https://hooks.slack.com/services/T/B/cli"), + let config = SniffConfig::from_env_and_args(SniffArgs { + once: false, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: Some("https://hooks.slack.com/services/T/B/cli"), + }); + assert_eq!( + config.slack_webhook.as_deref(), + Some("https://hooks.slack.com/services/T/B/cli") ); - assert_eq!(config.slack_webhook.as_deref(), Some("https://hooks.slack.com/services/T/B/cli")); clear_sniff_env(); } diff --git a/src/sniff/consumer.rs b/src/sniff/consumer.rs index b594a63..96c7aff 100644 --- a/src/sniff/consumer.rs +++ b/src/sniff/consumer.rs @@ -3,17 +3,17 @@ //! When `--consume` is enabled, logs are archived to zstd-compressed files, //! deduplicated, and then originals are purged to free disk space. -use anyhow::{Result, Context}; +use anyhow::{Context, Result}; use chrono::Utc; -use std::collections::HashSet; use std::collections::hash_map::DefaultHasher; +use std::collections::HashSet; use std::fs::{self, File, OpenOptions}; use std::hash::{Hash, Hasher}; -use std::io::{Write, BufWriter}; +use std::io::{BufWriter, Write}; use std::path::{Path, PathBuf}; -use crate::sniff::reader::LogEntry; use crate::sniff::discovery::LogSourceType; +use crate::sniff::reader::LogEntry; /// Result of a consume operation #[derive(Debug, Clone, Default)] @@ -33,8 +33,12 @@ pub struct LogConsumer { impl LogConsumer { pub fn new(output_dir: PathBuf) -> Result { - fs::create_dir_all(&output_dir) - .with_context(|| format!("Failed to create output directory: {}", output_dir.display()))?; + fs::create_dir_all(&output_dir).with_context(|| { + format!( + "Failed to create output directory: {}", + output_dir.display() + ) + })?; Ok(Self { output_dir, @@ -58,14 +62,21 @@ impl LogConsumer { } let seen = &mut self.seen_hashes; - entries.iter().filter(|entry| { - let hash = Self::hash_line(&entry.line); - seen.insert(hash) - }).collect() + entries + .iter() + .filter(|entry| { + let hash = Self::hash_line(&entry.line); + seen.insert(hash) + }) + .collect() } /// Write entries to a zstd-compressed file - pub fn write_compressed(&self, entries: &[&LogEntry], source_name: &str) -> Result<(PathBuf, u64)> { + pub fn write_compressed( + &self, + entries: &[&LogEntry], + source_name: &str, + ) -> Result<(PathBuf, u64)> { let timestamp = Utc::now().format("%Y%m%d_%H%M%S"); let safe_name = source_name.replace(['/', '\\', ':', ' '], "_"); let filename = format!("{}_{}.log.zst", safe_name, timestamp); @@ -74,18 +85,17 @@ impl LogConsumer { let file = File::create(&path) .with_context(|| format!("Failed to create archive file: {}", path.display()))?; - let encoder = zstd::Encoder::new(file, 3) - .context("Failed to create zstd encoder")?; + let encoder = zstd::Encoder::new(file, 3).context("Failed to create zstd encoder")?; let mut writer = BufWriter::new(encoder); for entry in entries { writeln!(writer, "{}\t{}", entry.timestamp.to_rfc3339(), entry.line)?; } - let encoder = writer.into_inner() + let encoder = writer + .into_inner() .map_err(|e| anyhow::anyhow!("Buffer flush error: {}", e))?; - encoder.finish() - .context("Failed to finish zstd encoding")?; + encoder.finish().context("Failed to finish zstd encoding")?; let compressed_size = fs::metadata(&path)?.len(); Ok((path, compressed_size)) @@ -112,13 +122,19 @@ impl LogConsumer { /// Purge Docker container logs by truncating the JSON log file pub async fn purge_docker_logs(container_id: &str) -> Result { // Docker stores logs at /var/lib/docker/containers//-json.log - let log_path = format!("/var/lib/docker/containers/{}/{}-json.log", container_id, container_id); + let log_path = format!( + "/var/lib/docker/containers/{}/{}-json.log", + container_id, container_id + ); let path = Path::new(&log_path); if path.exists() { Self::purge_file(path) } else { - log::info!("Docker log file not found for container {}, skipping purge", container_id); + log::info!( + "Docker log file not found for container {}, skipping purge", + container_id + ); Ok(0) } } @@ -142,9 +158,7 @@ impl LogConsumer { let (_, compressed_size) = self.write_compressed(&unique_entries, source_name)?; let bytes_freed = match source_type { - LogSourceType::DockerContainer => { - Self::purge_docker_logs(source_path).await? - } + LogSourceType::DockerContainer => Self::purge_docker_logs(source_path).await?, LogSourceType::SystemLog | LogSourceType::CustomFile => { let path = Path::new(source_path); Self::purge_file(path)? @@ -299,12 +313,10 @@ mod tests { let entries = make_entries(&["line 1", "line 2", "line 1"]); let log_path_str = log_path.to_string_lossy().to_string(); - let result = consumer.consume( - &entries, - "app", - &LogSourceType::CustomFile, - &log_path_str, - ).await.unwrap(); + let result = consumer + .consume(&entries, "app", &LogSourceType::CustomFile, &log_path_str) + .await + .unwrap(); assert_eq!(result.entries_archived, 2); // deduplicated assert_eq!(result.duplicates_skipped, 1); @@ -321,12 +333,10 @@ mod tests { let dir = tempfile::tempdir().unwrap(); let mut consumer = LogConsumer::new(dir.path().to_path_buf()).unwrap(); - let result = consumer.consume( - &[], - "empty", - &LogSourceType::SystemLog, - "/var/log/test", - ).await.unwrap(); + let result = consumer + .consume(&[], "empty", &LogSourceType::SystemLog, "/var/log/test") + .await + .unwrap(); assert_eq!(result.entries_archived, 0); assert_eq!(result.duplicates_skipped, 0); diff --git a/src/sniff/discovery.rs b/src/sniff/discovery.rs index c8acf92..e2bc4c4 100644 --- a/src/sniff/discovery.rs +++ b/src/sniff/discovery.rs @@ -26,13 +26,15 @@ impl std::fmt::Display for LogSourceType { } } -impl LogSourceType { - pub fn from_str(s: &str) -> Self { - match s { +impl std::str::FromStr for LogSourceType { + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> std::result::Result { + Ok(match s { "DockerContainer" => LogSourceType::DockerContainer, "SystemLog" => LogSourceType::SystemLog, _ => LogSourceType::CustomFile, - } + }) } } @@ -183,17 +185,32 @@ mod tests { #[test] fn test_log_source_type_display() { - assert_eq!(LogSourceType::DockerContainer.to_string(), "DockerContainer"); + assert_eq!( + LogSourceType::DockerContainer.to_string(), + "DockerContainer" + ); assert_eq!(LogSourceType::SystemLog.to_string(), "SystemLog"); assert_eq!(LogSourceType::CustomFile.to_string(), "CustomFile"); } #[test] fn test_log_source_type_from_str() { - assert_eq!(LogSourceType::from_str("DockerContainer"), LogSourceType::DockerContainer); - assert_eq!(LogSourceType::from_str("SystemLog"), LogSourceType::SystemLog); - assert_eq!(LogSourceType::from_str("CustomFile"), LogSourceType::CustomFile); - assert_eq!(LogSourceType::from_str("anything"), LogSourceType::CustomFile); + assert_eq!( + "DockerContainer".parse::().unwrap(), + LogSourceType::DockerContainer + ); + assert_eq!( + "SystemLog".parse::().unwrap(), + LogSourceType::SystemLog + ); + assert_eq!( + "CustomFile".parse::().unwrap(), + LogSourceType::CustomFile + ); + assert_eq!( + "anything".parse::().unwrap(), + LogSourceType::CustomFile + ); } #[test] @@ -216,7 +233,7 @@ mod tests { writeln!(tmp, "test log line").unwrap(); let path = tmp.path().to_string_lossy().to_string(); - let sources = discover_custom_sources(&[path.clone()]); + let sources = discover_custom_sources(std::slice::from_ref(&path)); assert_eq!(sources.len(), 1); assert_eq!(sources[0].source_type, LogSourceType::CustomFile); assert_eq!(sources[0].path_or_id, path); @@ -234,10 +251,7 @@ mod tests { writeln!(tmp, "log").unwrap(); let existing = tmp.path().to_string_lossy().to_string(); - let sources = discover_custom_sources(&[ - existing.clone(), - "/does/not/exist.log".into(), - ]); + let sources = discover_custom_sources(&[existing.clone(), "/does/not/exist.log".into()]); assert_eq!(sources.len(), 1); assert_eq!(sources[0].path_or_id, existing); } diff --git a/src/sniff/mod.rs b/src/sniff/mod.rs index 4372bd2..e9c4299 100644 --- a/src/sniff/mod.rs +++ b/src/sniff/mod.rs @@ -3,23 +3,23 @@ //! Discovers, reads, analyzes, and optionally consumes logs from //! Docker containers, system log files, and custom sources. +pub mod analyzer; pub mod config; +pub mod consumer; pub mod discovery; pub mod reader; -pub mod analyzer; -pub mod consumer; pub mod reporter; -use anyhow::Result; -use crate::database::connection::{create_pool, init_database, DbPool}; use crate::alerting::notifications::NotificationConfig; -use crate::sniff::config::SniffConfig; -use crate::sniff::discovery::LogSourceType; -use crate::sniff::reader::{LogReader, FileLogReader, DockerLogReader}; +use crate::database::connection::{create_pool, init_database, DbPool}; +use crate::database::repositories::log_sources as log_sources_repo; use crate::sniff::analyzer::{LogAnalyzer, PatternAnalyzer}; +use crate::sniff::config::SniffConfig; use crate::sniff::consumer::LogConsumer; +use crate::sniff::discovery::LogSourceType; +use crate::sniff::reader::{DockerLogReader, FileLogReader, LogReader}; use crate::sniff::reporter::Reporter; -use crate::database::repositories::log_sources as log_sources_repo; +use anyhow::Result; /// Main orchestrator for the sniff command pub struct SniffOrchestrator { @@ -42,7 +42,11 @@ impl SniffOrchestrator { } let reporter = Reporter::new(notification_config); - Ok(Self { config, pool, reporter }) + Ok(Self { + config, + pool, + reporter, + }) } /// Create the appropriate AI analyzer based on config @@ -51,7 +55,8 @@ impl SniffOrchestrator { config::AiProvider::OpenAi => { log::debug!( "Creating OpenAI-compatible analyzer (model: {}, url: {})", - self.config.ai_model, self.config.ai_api_url + self.config.ai_model, + self.config.ai_api_url ); Box::new(analyzer::OpenAiAnalyzer::new( self.config.ai_api_url.clone(), @@ -68,28 +73,27 @@ impl SniffOrchestrator { /// Build readers for discovered sources, restoring saved positions from DB fn build_readers(&self, sources: &[discovery::LogSource]) -> Vec> { - sources.iter().filter_map(|source| { - let saved = log_sources_repo::get_log_source_by_path(&self.pool, &source.path_or_id) - .ok() - .flatten(); - let offset = saved.map(|s| s.last_read_position).unwrap_or(0); - - match source.source_type { - LogSourceType::SystemLog | LogSourceType::CustomFile => { - Some(Box::new(FileLogReader::new( - source.id.clone(), - source.path_or_id.clone(), - offset, - )) as Box) - } - LogSourceType::DockerContainer => { - Some(Box::new(DockerLogReader::new( + sources + .iter() + .map(|source| { + let saved = + log_sources_repo::get_log_source_by_path(&self.pool, &source.path_or_id) + .ok() + .flatten(); + let offset = saved.map(|s| s.last_read_position).unwrap_or(0); + + match source.source_type { + LogSourceType::SystemLog | LogSourceType::CustomFile => Box::new( + FileLogReader::new(source.id.clone(), source.path_or_id.clone(), offset), + ) + as Box, + LogSourceType::DockerContainer => Box::new(DockerLogReader::new( source.id.clone(), source.path_or_id.clone(), - )) as Box) + )) as Box, } - } - }).collect() + }) + .collect() } /// Run a single sniff pass: discover → read → analyze → report → consume @@ -112,7 +116,10 @@ impl SniffOrchestrator { let mut readers = self.build_readers(&sources); let analyzer = self.create_analyzer(); let mut consumer = if self.config.consume { - log::debug!("Consume mode enabled, output: {}", self.config.output_dir.display()); + log::debug!( + "Consume mode enabled, output: {}", + self.config.output_dir.display() + ); Some(LogConsumer::new(self.config.output_dir.clone())?) } else { None @@ -121,7 +128,12 @@ impl SniffOrchestrator { // 3. Process each source let reader_count = readers.len(); for (i, reader) in readers.iter_mut().enumerate() { - log::debug!("Step 3: reading source {}/{} ({})", i + 1, reader_count, reader.source_id()); + log::debug!( + "Step 3: reading source {}/{} ({})", + i + 1, + reader_count, + reader.source_id() + ); let entries = reader.read_new_entries().await?; if entries.is_empty() { log::debug!(" No new entries, skipping"); @@ -136,7 +148,9 @@ impl SniffOrchestrator { let summary = analyzer.summarize(&entries).await?; log::debug!( " Analysis complete: {} errors, {} warnings, {} anomalies", - summary.error_count, summary.warning_count, summary.anomalies.len() + summary.error_count, + summary.warning_count, + summary.anomalies.len() ); // 5. Report @@ -149,16 +163,21 @@ impl SniffOrchestrator { if i < sources.len() { log::debug!("Step 6: consuming entries..."); let source = &sources[i]; - let consume_result = cons.consume( - &entries, - &source.name, - &source.source_type, - &source.path_or_id, - ).await?; + let consume_result = cons + .consume( + &entries, + &source.name, + &source.source_type, + &source.path_or_id, + ) + .await?; result.bytes_freed += consume_result.bytes_freed; result.entries_archived += consume_result.entries_archived; - log::debug!(" Consumed: {} archived, {} bytes freed", - consume_result.entries_archived, consume_result.bytes_freed); + log::debug!( + " Consumed: {} archived, {} bytes freed", + consume_result.entries_archived, + consume_result.bytes_freed + ); } } @@ -231,9 +250,17 @@ mod tests { #[test] fn test_orchestrator_creates_with_memory_db() { - let mut config = SniffConfig::from_env_and_args( - true, false, "./stackdog-logs/", None, 30, None, None, None, None, - ); + let mut config = SniffConfig::from_env_and_args(config::SniffArgs { + once: true, + consume: false, + output: "./stackdog-logs/", + sources: None, + interval: 30, + ai_provider: None, + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); config.database_url = ":memory:".into(); let orchestrator = SniffOrchestrator::new(config); @@ -252,11 +279,17 @@ mod tests { writeln!(f, "WARN: retry in 5s").unwrap(); } - let mut config = SniffConfig::from_env_and_args( - true, false, "./stackdog-logs/", - Some(&log_path.to_string_lossy()), - 30, Some("candle"), None, None, None, - ); + let mut config = SniffConfig::from_env_and_args(config::SniffArgs { + once: true, + consume: false, + output: "./stackdog-logs/", + sources: Some(&log_path.to_string_lossy()), + interval: 30, + ai_provider: Some("candle"), + ai_model: None, + ai_api_url: None, + slack_webhook: None, + }); config.database_url = ":memory:".into(); let orchestrator = SniffOrchestrator::new(config).unwrap(); diff --git a/src/sniff/reader.rs b/src/sniff/reader.rs index f97cabf..fa3e450 100644 --- a/src/sniff/reader.rs +++ b/src/sniff/reader.rs @@ -7,8 +7,8 @@ use anyhow::Result; use async_trait::async_trait; use chrono::{DateTime, Utc}; use std::collections::HashMap; -use std::io::{BufRead, BufReader, Seek, SeekFrom}; use std::fs::File; +use std::io::{BufRead, BufReader, Seek, SeekFrom}; use std::path::Path; /// A single log entry from any source @@ -56,11 +56,19 @@ impl FileLogReader { let file = File::open(path)?; let file_len = file.metadata()?.len(); - log::debug!("Reading {} (size: {} bytes, offset: {})", self.path, file_len, self.offset); + log::debug!( + "Reading {} (size: {} bytes, offset: {})", + self.path, + file_len, + self.offset + ); // Handle file truncation (log rotation) if self.offset > file_len { - log::debug!("File truncated (rotation?), resetting offset from {} to 0", self.offset); + log::debug!( + "File truncated (rotation?), resetting offset from {} to 0", + self.offset + ); self.offset = 0; } @@ -77,16 +85,19 @@ impl FileLogReader { source_id: self.source_id.clone(), timestamp: Utc::now(), line: trimmed, - metadata: HashMap::from([ - ("source_path".into(), self.path.clone()), - ]), + metadata: HashMap::from([("source_path".into(), self.path.clone())]), }); } line.clear(); } self.offset = reader.stream_position()?; - log::debug!("Read {} entries from {}, new offset: {}", entries.len(), self.path, self.offset); + log::debug!( + "Read {} entries from {}, new offset: {}", + entries.len(), + self.path, + self.offset + ); Ok(entries) } } @@ -126,8 +137,8 @@ impl DockerLogReader { #[async_trait] impl LogReader for DockerLogReader { async fn read_new_entries(&mut self) -> Result> { - use bollard::Docker; use bollard::container::LogsOptions; + use bollard::Docker; use futures_util::stream::StreamExt; let docker = match Docker::connect_with_local_defaults() { @@ -143,7 +154,11 @@ impl LogReader for DockerLogReader { stderr: true, since: self.last_timestamp.unwrap_or(0), timestamps: true, - tail: if self.last_timestamp.is_none() { "100".to_string() } else { "all".to_string() }, + tail: if self.last_timestamp.is_none() { + "100".to_string() + } else { + "all".to_string() + }, ..Default::default() }; @@ -160,9 +175,10 @@ impl LogReader for DockerLogReader { source_id: self.source_id.clone(), timestamp: Utc::now(), line: trimmed, - metadata: HashMap::from([ - ("container_id".into(), self.container_id.clone()), - ]), + metadata: HashMap::from([( + "container_id".into(), + self.container_id.clone(), + )]), }); } } @@ -211,8 +227,10 @@ impl LogReader for JournaldReader { let mut cmd = Command::new("journalctl"); cmd.arg("--no-pager") - .arg("-o").arg("short-iso") - .arg("-n").arg("200"); + .arg("-o") + .arg("short-iso") + .arg("-n") + .arg("200"); if let Some(ref cursor) = self.cursor { cmd.arg("--after-cursor").arg(cursor); @@ -235,9 +253,7 @@ impl LogReader for JournaldReader { source_id: self.source_id.clone(), timestamp: Utc::now(), line: trimmed, - metadata: HashMap::from([ - ("source".into(), "journald".into()), - ]), + metadata: HashMap::from([("source".into(), "journald".into())]), }); } } @@ -290,11 +306,7 @@ mod tests { writeln!(f, "line 3").unwrap(); } - let mut reader = FileLogReader::new( - "test".into(), - path.to_string_lossy().to_string(), - 0, - ); + let mut reader = FileLogReader::new("test".into(), path.to_string_lossy().to_string(), 0); let entries = reader.read_new_entries().await.unwrap(); assert_eq!(entries.len(), 3); assert_eq!(entries[0].line, "line 1"); @@ -325,7 +337,10 @@ mod tests { // Append new lines { - let mut f = std::fs::OpenOptions::new().append(true).open(&path).unwrap(); + let mut f = std::fs::OpenOptions::new() + .append(true) + .open(&path) + .unwrap(); writeln!(f, "line C").unwrap(); } @@ -382,11 +397,7 @@ mod tests { writeln!(f, "line 3").unwrap(); } - let mut reader = FileLogReader::new( - "empty".into(), - path.to_string_lossy().to_string(), - 0, - ); + let mut reader = FileLogReader::new("empty".into(), path.to_string_lossy().to_string(), 0); let entries = reader.read_new_entries().await.unwrap(); assert_eq!(entries.len(), 2); assert_eq!(entries[0].line, "line 1"); diff --git a/src/sniff/reporter.rs b/src/sniff/reporter.rs index bfc3b55..ab8c73a 100644 --- a/src/sniff/reporter.rs +++ b/src/sniff/reporter.rs @@ -3,12 +3,12 @@ //! Converts log summaries and anomalies into alerts, then dispatches //! them via the existing notification channels. -use anyhow::Result; use crate::alerting::alert::{Alert, AlertSeverity, AlertType}; -use crate::alerting::notifications::{NotificationChannel, NotificationConfig, route_by_severity}; -use crate::sniff::analyzer::{LogSummary, LogAnomaly, AnomalySeverity}; +use crate::alerting::notifications::{route_by_severity, NotificationConfig}; use crate::database::connection::DbPool; use crate::database::repositories::log_sources; +use crate::sniff::analyzer::{AnomalySeverity, LogSummary}; +use anyhow::Result; /// Reports log analysis results to alert channels and persists summaries pub struct Reporter { @@ -17,7 +17,9 @@ pub struct Reporter { impl Reporter { pub fn new(notification_config: NotificationConfig) -> Self { - Self { notification_config } + Self { + notification_config, + } } /// Map anomaly severity to alert severity @@ -36,16 +38,21 @@ impl Reporter { // Persist summary to database if let Some(pool) = pool { - log::debug!("Persisting summary for source {} to database", summary.source_id); + log::debug!( + "Persisting summary for source {} to database", + summary.source_id + ); let _ = log_sources::create_log_summary( pool, - &summary.source_id, - &summary.summary_text, - &summary.period_start.to_rfc3339(), - &summary.period_end.to_rfc3339(), - summary.total_entries as i64, - summary.error_count as i64, - summary.warning_count as i64, + log_sources::CreateLogSummaryParams { + source_id: &summary.source_id, + summary_text: &summary.summary_text, + period_start: &summary.period_start.to_rfc3339(), + period_end: &summary.period_end.to_rfc3339(), + total_entries: summary.total_entries as i64, + error_count: summary.error_count as i64, + warning_count: summary.warning_count as i64, + }, ); } @@ -55,7 +62,8 @@ impl Reporter { log::debug!( "Generating alert: severity={}, description={}", - anomaly.severity, anomaly.description + anomaly.severity, + anomaly.description ); let alert = Alert::new( @@ -107,8 +115,9 @@ pub struct ReportResult { #[cfg(test)] mod tests { use super::*; - use chrono::Utc; use crate::database::connection::{create_pool, init_database}; + use crate::sniff::analyzer::LogAnomaly; + use chrono::Utc; fn make_summary(anomalies: Vec) -> LogSummary { LogSummary { @@ -126,10 +135,22 @@ mod tests { #[test] fn test_map_severity() { - assert_eq!(Reporter::map_severity(&AnomalySeverity::Low), AlertSeverity::Low); - assert_eq!(Reporter::map_severity(&AnomalySeverity::Medium), AlertSeverity::Medium); - assert_eq!(Reporter::map_severity(&AnomalySeverity::High), AlertSeverity::High); - assert_eq!(Reporter::map_severity(&AnomalySeverity::Critical), AlertSeverity::Critical); + assert_eq!( + Reporter::map_severity(&AnomalySeverity::Low), + AlertSeverity::Low + ); + assert_eq!( + Reporter::map_severity(&AnomalySeverity::Medium), + AlertSeverity::Medium + ); + assert_eq!( + Reporter::map_severity(&AnomalySeverity::High), + AlertSeverity::High + ); + assert_eq!( + Reporter::map_severity(&AnomalySeverity::Critical), + AlertSeverity::Critical + ); } #[test] @@ -145,13 +166,11 @@ mod tests { #[test] fn test_report_with_anomalies_sends_alerts() { let reporter = Reporter::new(NotificationConfig::default()); - let summary = make_summary(vec![ - LogAnomaly { - description: "High error rate".into(), - severity: AnomalySeverity::High, - sample_line: "ERROR: connection failed".into(), - }, - ]); + let summary = make_summary(vec![LogAnomaly { + description: "High error rate".into(), + severity: AnomalySeverity::High, + sample_line: "ERROR: connection failed".into(), + }]); let result = reporter.report(&summary, None).unwrap(); assert_eq!(result.anomalies_reported, 1); diff --git a/tests/collectors/connect_capture_test.rs b/tests/collectors/connect_capture_test.rs index 6d39bda..319bcc3 100644 --- a/tests/collectors/connect_capture_test.rs +++ b/tests/collectors/connect_capture_test.rs @@ -6,56 +6,57 @@ mod linux_tests { use stackdog::collectors::ebpf::syscall_monitor::SyscallMonitor; use stackdog::events::syscall::SyscallType; - use std::time::Duration; use std::net::TcpStream; + use std::time::Duration; #[test] #[ignore = "requires root and eBPF support"] fn test_connect_event_captured_on_tcp_connection() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Try to connect to a local port (will fail, but syscall is still made) let _ = TcpStream::connect("127.0.0.1:12345"); - + // Give eBPF time to process std::thread::sleep(Duration::from_millis(100)); - + // Poll for events let events = monitor.poll_events(); - + // Should have captured connect events let connect_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Connect) .collect(); - + // We expect at least one connect event - assert!(!connect_events.is_empty(), "Should capture at least one connect event"); + assert!( + !connect_events.is_empty(), + "Should capture at least one connect event" + ); } #[test] #[ignore = "requires root and eBPF support"] fn test_connect_event_contains_destination_ip() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Connect to localhost let _ = TcpStream::connect("127.0.0.1:12345"); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let connect_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Connect) .collect(); - + // Just verify we got events (detailed IP capture tested in integration) assert!(!connect_events.is_empty()); } @@ -63,24 +64,23 @@ mod linux_tests { #[test] #[ignore = "requires root and eBPF support"] fn test_connect_event_contains_destination_port() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Connect to specific port let test_port = 12346; let _ = TcpStream::connect(format!("127.0.0.1:{}", test_port)); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let connect_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Connect) .collect(); - + // Verify events captured assert!(!connect_events.is_empty()); } @@ -88,27 +88,29 @@ mod linux_tests { #[test] #[ignore = "requires root and eBPF support"] fn test_connect_event_multiple_connections() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Make multiple connections for port in 12350..12355 { let _ = TcpStream::connect(format!("127.0.0.1:{}", port)); } - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let connect_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Connect) .collect(); - + // Should have multiple connect events - assert!(connect_events.len() >= 5, "Should capture multiple connect events"); + assert!( + connect_events.len() >= 5, + "Should capture multiple connect events" + ); } } diff --git a/tests/collectors/ebpf_kernel_test.rs b/tests/collectors/ebpf_kernel_test.rs index bbdf717..b067afd 100644 --- a/tests/collectors/ebpf_kernel_test.rs +++ b/tests/collectors/ebpf_kernel_test.rs @@ -1,6 +1,6 @@ //! eBPF kernel compatibility tests -use stackdog::collectors::ebpf::kernel::{KernelInfo, KernelVersion, check_kernel_version}; +use stackdog::collectors::ebpf::kernel::{check_kernel_version, KernelInfo, KernelVersion}; #[test] fn test_kernel_version_parse() { @@ -33,7 +33,7 @@ fn test_kernel_version_comparison() { let v1 = KernelVersion::parse("5.10.0").unwrap(); let v2 = KernelVersion::parse("5.15.0").unwrap(); let v3 = KernelVersion::parse("4.19.0").unwrap(); - + assert!(v2 > v1); assert!(v1 > v3); assert!(v2 > v3); @@ -44,7 +44,7 @@ fn test_kernel_version_meets_minimum() { let current = KernelVersion::parse("5.10.0").unwrap(); let min_4_19 = KernelVersion::parse("4.19.0").unwrap(); let min_5_15 = KernelVersion::parse("5.15.0").unwrap(); - + assert!(current.meets_minimum(&min_4_19)); assert!(!current.meets_minimum(&min_5_15)); } @@ -52,10 +52,10 @@ fn test_kernel_version_meets_minimum() { #[test] fn test_kernel_info_creation() { let info = KernelInfo::new(); - + #[cfg(target_os = "linux")] assert!(info.is_ok()); - + #[cfg(not(target_os = "linux"))] assert!(info.is_err()); } @@ -63,13 +63,13 @@ fn test_kernel_info_creation() { #[test] fn test_kernel_version_check_function() { let result = check_kernel_version(); - + #[cfg(target_os = "linux")] { // On Linux, should return some version info assert!(result.is_ok()); } - + #[cfg(not(target_os = "linux"))] { // On non-Linux, should indicate unsupported @@ -89,7 +89,7 @@ fn test_kernel_version_equality() { let v1 = KernelVersion::parse("5.10.0").unwrap(); let v2 = KernelVersion::parse("5.10.0").unwrap(); let v3 = KernelVersion::parse("5.10.1").unwrap(); - + assert_eq!(v1, v2); assert_ne!(v1, v3); } diff --git a/tests/collectors/ebpf_loader_test.rs b/tests/collectors/ebpf_loader_test.rs index 26d1155..ea0acb0 100644 --- a/tests/collectors/ebpf_loader_test.rs +++ b/tests/collectors/ebpf_loader_test.rs @@ -5,7 +5,6 @@ #[cfg(target_os = "linux")] mod linux_tests { use stackdog::collectors::ebpf::loader::{EbpfLoader, LoadError}; - use anyhow::Result; #[test] fn test_ebpf_loader_creation() { @@ -15,8 +14,7 @@ mod linux_tests { #[test] fn test_ebpf_loader_default() { - let loader = EbpfLoader::default(); - assert!(loader.is_ok(), "EbpfLoader::default() should succeed"); + let _loader = EbpfLoader::default(); } #[test] @@ -30,10 +28,10 @@ mod linux_tests { #[ignore = "requires root and eBPF support"] fn test_ebpf_program_load_success() { let mut loader = EbpfLoader::new().expect("Failed to create loader"); - + // Try to load a program (this requires the eBPF ELF file) let result = loader.load_program_from_bytes(&[]); - + // Should fail with empty bytes, but not panic assert!(result.is_err()); } @@ -43,8 +41,11 @@ mod linux_tests { let error = LoadError::ProgramNotFound("test_program".to_string()); let msg = format!("{}", error); assert!(msg.contains("test_program")); - - let error = LoadError::KernelVersionTooLow { required: 4, current: 3 }; + + let error = LoadError::KernelVersionTooLow { + required: "4".to_string(), + current: "3".to_string(), + }; let msg = format!("{}", error); assert!(msg.contains("4.19")); } @@ -58,10 +59,10 @@ mod cross_platform_tests { fn test_ebpf_loader_creation_cross_platform() { // This test should work on all platforms let result = EbpfLoader::new(); - + #[cfg(target_os = "linux")] assert!(result.is_ok()); - + #[cfg(not(target_os = "linux"))] assert!(result.is_err()); // Should error on non-Linux } @@ -69,10 +70,10 @@ mod cross_platform_tests { #[test] fn test_ebpf_is_linux_check() { use stackdog::collectors::ebpf::loader::is_linux; - + #[cfg(target_os = "linux")] assert!(is_linux()); - + #[cfg(not(target_os = "linux"))] assert!(!is_linux()); } diff --git a/tests/collectors/ebpf_syscall_test.rs b/tests/collectors/ebpf_syscall_test.rs index 9ae6617..7432f8b 100644 --- a/tests/collectors/ebpf_syscall_test.rs +++ b/tests/collectors/ebpf_syscall_test.rs @@ -5,36 +5,39 @@ #[cfg(target_os = "linux")] mod linux_tests { use stackdog::collectors::ebpf::syscall_monitor::SyscallMonitor; - use stackdog::events::syscall::{SyscallEvent, SyscallType}; + use stackdog::events::syscall::SyscallType; use std::time::Duration; #[test] #[ignore = "requires root and eBPF support"] fn test_syscall_monitor_creation() { let monitor = SyscallMonitor::new(); - assert!(monitor.is_ok(), "SyscallMonitor::new() should succeed on Linux with eBPF"); + assert!( + monitor.is_ok(), + "SyscallMonitor::new() should succeed on Linux with eBPF" + ); } #[test] #[ignore = "requires root and eBPF support"] fn test_execve_event_capture() { let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); - + // Start monitoring monitor.start().expect("Failed to start monitor"); - + // Trigger an execve by running a simple command std::process::Command::new("echo").arg("test").output().ok(); - + // Give eBPF time to process std::thread::sleep(Duration::from_millis(100)); - + // Poll for events let events = monitor.poll_events(); - + // Should have captured some events assert!(events.len() > 0, "Should capture at least one execve event"); - + // Check that we have execve events let has_execve = events.iter().any(|e| e.syscall_type == SyscallType::Execve); assert!(has_execve, "Should capture execve events"); @@ -45,15 +48,17 @@ mod linux_tests { fn test_connect_event_capture() { let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); monitor.start().expect("Failed to start monitor"); - + // Trigger a connect syscall let _ = std::net::TcpStream::connect("127.0.0.1:12345"); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - let has_connect = events.iter().any(|e| e.syscall_type == SyscallType::Connect); - + let _has_connect = events + .iter() + .any(|e| e.syscall_type == SyscallType::Connect); + // May or may not capture depending on timing // Just verify no panic assert!(true); @@ -64,14 +69,14 @@ mod linux_tests { fn test_openat_event_capture() { let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); monitor.start().expect("Failed to start monitor"); - + // Trigger openat syscalls let _ = std::fs::File::open("/etc/hostname"); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + // Should have captured some events assert!(events.len() > 0); } @@ -81,11 +86,11 @@ mod linux_tests { fn test_ptrace_event_capture() { let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); monitor.start().expect("Failed to start monitor"); - + // Note: Actually calling ptrace requires special setup // This test verifies the monitor doesn't crash - - let events = monitor.poll_events(); + + let _events = monitor.poll_events(); assert!(true); // Just verify no panic } @@ -94,11 +99,11 @@ mod linux_tests { fn test_event_ring_buffer_poll() { let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); monitor.start().expect("Failed to start monitor"); - + // Multiple polls should work let events1 = monitor.poll_events(); let events2 = monitor.poll_events(); - + // Both should succeed (may be empty) assert!(events1.len() >= 0); assert!(events2.len() >= 0); @@ -109,11 +114,11 @@ mod linux_tests { fn test_syscall_monitor_stop() { let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); monitor.start().expect("Failed to start monitor"); - + // Stop should work let result = monitor.stop(); assert!(result.is_ok()); - + // Poll after stop should return empty let events = monitor.poll_events(); assert!(events.is_empty()); diff --git a/tests/collectors/event_enrichment_test.rs b/tests/collectors/event_enrichment_test.rs index 315db08..98f83e0 100644 --- a/tests/collectors/event_enrichment_test.rs +++ b/tests/collectors/event_enrichment_test.rs @@ -2,10 +2,10 @@ //! //! Tests for event enrichment (container ID, timestamps, process tree) -use stackdog::collectors::ebpf::enrichment::EventEnricher; +use chrono::Utc; use stackdog::collectors::ebpf::container::ContainerDetector; +use stackdog::collectors::ebpf::enrichment::EventEnricher; use stackdog::events::syscall::{SyscallEvent, SyscallType}; -use chrono::Utc; #[test] fn test_event_enricher_creation() { @@ -17,9 +17,9 @@ fn test_event_enricher_creation() { fn test_enrich_adds_timestamp() { let mut enricher = EventEnricher::new().expect("Failed to create enricher"); let mut event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); - + enricher.enrich(&mut event).expect("Failed to enrich"); - + // Event should have timestamp assert!(event.timestamp <= Utc::now()); } @@ -29,9 +29,9 @@ fn test_enrich_preserves_existing_timestamp() { let mut enricher = EventEnricher::new().expect("Failed to create enricher"); let original_timestamp = Utc::now(); let mut event = SyscallEvent::new(1234, 1000, SyscallType::Execve, original_timestamp); - + enricher.enrich(&mut event).expect("Failed to enrich"); - + // Timestamp should be preserved or updated (both acceptable) assert!(event.timestamp >= original_timestamp); } @@ -42,27 +42,29 @@ fn test_container_detector_creation() { // Should work on Linux, may fail on other platforms #[cfg(target_os = "linux")] assert!(detector.is_ok()); + #[cfg(not(target_os = "linux"))] + assert!(detector.is_err()); } #[test] fn test_container_id_detection_format() { let detector = ContainerDetector::new(); - + #[cfg(target_os = "linux")] { let detector = detector.expect("Failed to create detector"); // Test with a known container ID format let valid_ids = vec![ "abc123def456", - "abc123def456789012345678901234567890", + "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890", ]; - + for id in valid_ids { let result = detector.validate_container_id(id); assert!(result, "Should validate container ID: {}", id); } } - + #[cfg(not(target_os = "linux"))] { assert!(detector.is_err()); @@ -72,7 +74,7 @@ fn test_container_id_detection_format() { #[test] fn test_container_id_invalid_formats() { let detector = ContainerDetector::new(); - + #[cfg(target_os = "linux")] { let detector = detector.expect("Failed to create detector"); @@ -82,32 +84,28 @@ fn test_container_id_invalid_formats() { "invalid@chars!", "this_is_way_too_long_for_a_container_id_and_should_fail_validation", ]; - + for id in invalid_ids { let result = detector.validate_container_id(id); assert!(!result, "Should reject invalid container ID: {}", id); } } + + #[cfg(not(target_os = "linux"))] + { + assert!(detector.is_err()); + } } #[test] fn test_cgroup_parsing() { // Test cgroup path parsing for container detection let test_cases = vec![ - ( - "12:memory:/docker/abc123def456", - Some("abc123def456"), - ), - ( - "11:cpu:/kubepods/pod123/def456abc789", - Some("def456abc789"), - ), - ( - "10:cpuacct:/", - None, - ), + ("12:memory:/docker/abc123def456", Some("abc123def456")), + ("11:cpu:/kubepods/pod123/def456abc789", Some("def456abc789")), + ("10:cpuacct:/", None), ]; - + for (cgroup_path, expected_id) in test_cases { let result = ContainerDetector::parse_container_from_cgroup(cgroup_path); assert_eq!(result, expected_id.map(|s| s.to_string())); @@ -116,37 +114,41 @@ fn test_cgroup_parsing() { #[test] fn test_process_tree_enrichment() { - let mut enricher = EventEnricher::new().expect("Failed to create enricher"); - + let enricher = EventEnricher::new().expect("Failed to create enricher"); + // Test that we can get parent PID let ppid = enricher.get_parent_pid(1); // init process - + // PID 1 should exist on Linux #[cfg(target_os = "linux")] assert!(ppid.is_some()); + #[cfg(not(target_os = "linux"))] + let _ = ppid; } #[test] fn test_process_comm_enrichment() { - let mut enricher = EventEnricher::new().expect("Failed to create enricher"); - + let enricher = EventEnricher::new().expect("Failed to create enricher"); + // Test that we can get process name let comm = enricher.get_process_comm(std::process::id()); - + // Should get some process name #[cfg(target_os = "linux")] assert!(comm.is_some()); + #[cfg(not(target_os = "linux"))] + let _ = comm; } #[test] fn test_timestamp_normalization() { use stackdog::collectors::ebpf::enrichment::normalize_timestamp; - + // Test with current time let now = Utc::now(); let normalized = normalize_timestamp(now); assert!(normalized >= now); - + // Test with epoch let epoch = chrono::DateTime::from_timestamp(0, 0).unwrap(); let normalized = normalize_timestamp(epoch); @@ -157,10 +159,10 @@ fn test_timestamp_normalization() { fn test_enrichment_pipeline() { let mut enricher = EventEnricher::new().expect("Failed to create enricher"); let mut event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); - + // Run full enrichment pipeline enricher.enrich(&mut event).expect("Failed to enrich"); - + // Event should be enriched assert!(event.timestamp <= Utc::now()); } diff --git a/tests/collectors/execve_capture_test.rs b/tests/collectors/execve_capture_test.rs index 1289258..d5914bc 100644 --- a/tests/collectors/execve_capture_test.rs +++ b/tests/collectors/execve_capture_test.rs @@ -6,83 +6,83 @@ mod linux_tests { use stackdog::collectors::ebpf::syscall_monitor::SyscallMonitor; use stackdog::events::syscall::SyscallType; - use std::time::Duration; use std::process::Command; + use std::time::Duration; #[test] #[ignore = "requires root and eBPF support"] fn test_execve_event_captured_on_process_spawn() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Spawn a process to trigger execve let _ = Command::new("echo").arg("test").output(); - + // Give eBPF time to process std::thread::sleep(Duration::from_millis(100)); - + // Poll for events let events = monitor.poll_events(); - + // Should have captured execve events let execve_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Execve) .collect(); - - assert!(!execve_events.is_empty(), "Should capture at least one execve event"); + + assert!( + !execve_events.is_empty(), + "Should capture at least one execve event" + ); } #[test] #[ignore = "requires root and eBPF support"] fn test_execve_event_contains_filename() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Spawn a specific process let _ = Command::new("/bin/ls").arg("-la").output(); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + // Find execve events let execve_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Execve) .collect(); - + // At least one should have comm set - let has_comm = execve_events.iter().any(|e| { - e.comm.as_ref().map(|c| !c.is_empty()).unwrap_or(false) - }); - + let has_comm = execve_events + .iter() + .any(|e| e.comm.as_ref().map(|c| !c.is_empty()).unwrap_or(false)); + assert!(has_comm, "Should capture command name"); } #[test] #[ignore = "requires root and eBPF support"] fn test_execve_event_contains_pid() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + let _ = Command::new("echo").arg("test").output(); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let execve_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Execve) .collect(); - + // All events should have valid PID for event in execve_events { assert!(event.pid > 0, "PID should be positive"); @@ -92,52 +92,51 @@ mod linux_tests { #[test] #[ignore = "requires root and eBPF support"] fn test_execve_event_contains_uid() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + let _ = Command::new("echo").arg("test").output(); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let execve_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Execve) .collect(); - - // All events should have valid UID - for event in execve_events { - assert!(event.uid >= 0, "UID should be non-negative"); - } + + // UID is u32, so only verify iterating events is safe and stable. + for _event in execve_events {} } #[test] #[ignore = "requires root and eBPF support"] fn test_execve_event_timestamp() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + let before = chrono::Utc::now(); - + let _ = Command::new("echo").arg("test").output(); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let execve_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Execve) .collect(); - + // Timestamps should be reasonable for event in execve_events { - assert!(event.timestamp >= before, "Event timestamp should be after test start"); + assert!( + event.timestamp >= before, + "Event timestamp should be after test start" + ); } } } diff --git a/tests/collectors/mod.rs b/tests/collectors/mod.rs index 813140b..496f326 100644 --- a/tests/collectors/mod.rs +++ b/tests/collectors/mod.rs @@ -1,10 +1,10 @@ //! Collectors module tests +mod connect_capture_test; +mod ebpf_kernel_test; mod ebpf_loader_test; mod ebpf_syscall_test; -mod ebpf_kernel_test; +mod event_enrichment_test; mod execve_capture_test; -mod connect_capture_test; mod openat_capture_test; mod ptrace_capture_test; -mod event_enrichment_test; diff --git a/tests/collectors/openat_capture_test.rs b/tests/collectors/openat_capture_test.rs index 3de56d2..20fb0fe 100644 --- a/tests/collectors/openat_capture_test.rs +++ b/tests/collectors/openat_capture_test.rs @@ -6,55 +6,56 @@ mod linux_tests { use stackdog::collectors::ebpf::syscall_monitor::SyscallMonitor; use stackdog::events::syscall::SyscallType; - use std::time::Duration; use std::fs::File; + use std::time::Duration; #[test] #[ignore = "requires root and eBPF support"] fn test_openat_event_captured_on_file_open() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Open a file to trigger openat let _ = File::open("/etc/hostname"); - + // Give eBPF time to process std::thread::sleep(Duration::from_millis(100)); - + // Poll for events let events = monitor.poll_events(); - + // Should have captured openat events let openat_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Openat) .collect(); - - assert!(!openat_events.is_empty(), "Should capture at least one openat event"); + + assert!( + !openat_events.is_empty(), + "Should capture at least one openat event" + ); } #[test] #[ignore = "requires root and eBPF support"] fn test_openat_event_contains_file_path() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Open specific file let _ = File::open("/etc/hostname"); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let openat_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Openat) .collect(); - + // Just verify events captured (detailed path capture in integration tests) assert!(!openat_events.is_empty()); } @@ -62,62 +63,59 @@ mod linux_tests { #[test] #[ignore = "requires root and eBPF support"] fn test_openat_event_multiple_files() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Open multiple files - let files = vec![ - "/etc/hostname", - "/etc/hosts", - "/etc/resolv.conf", - ]; - + let files = vec!["/etc/hostname", "/etc/hosts", "/etc/resolv.conf"]; + for path in files { let _ = File::open(path); } - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let openat_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Openat) .collect(); - + // Should have multiple openat events - assert!(openat_events.len() >= 3, "Should capture multiple openat events"); + assert!( + openat_events.len() >= 3, + "Should capture multiple openat events" + ); } #[test] #[ignore = "requires root and eBPF support"] fn test_openat_event_read_and_write() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Open file for reading let _ = File::open("/etc/hostname"); - + // Open file for writing (creates temp file) let temp_path = "/tmp/stackdog_test.tmp"; let _ = File::create(temp_path); - + // Cleanup let _ = std::fs::remove_file(temp_path); - + std::thread::sleep(Duration::from_millis(100)); - + let events = monitor.poll_events(); - + let openat_events: Vec<_> = events .iter() .filter(|e| e.syscall_type == SyscallType::Openat) .collect(); - + // Should have captured both read and write opens assert!(openat_events.len() >= 2); } diff --git a/tests/collectors/ptrace_capture_test.rs b/tests/collectors/ptrace_capture_test.rs index cde16f0..533896e 100644 --- a/tests/collectors/ptrace_capture_test.rs +++ b/tests/collectors/ptrace_capture_test.rs @@ -5,25 +5,23 @@ #[cfg(target_os = "linux")] mod linux_tests { use stackdog::collectors::ebpf::syscall_monitor::SyscallMonitor; - use stackdog::events::syscall::SyscallType; use std::time::Duration; #[test] #[ignore = "requires root and eBPF support"] fn test_ptrace_event_captured_on_trace_attempt() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Note: Actually calling ptrace requires special setup // For now, we just verify the monitor doesn't crash // and can detect ptrace syscalls if they occur - + std::thread::sleep(Duration::from_millis(100)); - - let events = monitor.poll_events(); - + + let _events = monitor.poll_events(); + // Just verify monitor works without crashing assert!(true, "Monitor should handle ptrace detection gracefully"); } @@ -31,15 +29,14 @@ mod linux_tests { #[test] #[ignore = "requires root and eBPF support"] fn test_ptrace_event_contains_target_pid() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + std::thread::sleep(Duration::from_millis(100)); - - let events = monitor.poll_events(); - + + let _events = monitor.poll_events(); + // Verify structure ready for ptrace events assert!(true); } @@ -47,18 +44,17 @@ mod linux_tests { #[test] #[ignore = "requires root and eBPF support"] fn test_ptrace_event_security_alert() { - let mut monitor = SyscallMonitor::new() - .expect("Failed to create monitor"); - + let mut monitor = SyscallMonitor::new().expect("Failed to create monitor"); + monitor.start().expect("Failed to start monitor"); - + // Ptrace is often used by debuggers and malware // Verify we can detect it - + std::thread::sleep(Duration::from_millis(100)); - - let events = monitor.poll_events(); - + + let _events = monitor.poll_events(); + // Just verify monitor is working assert!(true); } diff --git a/tests/events/event_conversion_test.rs b/tests/events/event_conversion_test.rs index d692afb..1a91bf9 100644 --- a/tests/events/event_conversion_test.rs +++ b/tests/events/event_conversion_test.rs @@ -2,25 +2,20 @@ //! //! Tests for From/Into trait implementations between event types -use stackdog::events::syscall::{SyscallEvent, SyscallType}; +use chrono::Utc; use stackdog::events::security::{ - SecurityEvent, NetworkEvent, ContainerEvent, ContainerEventType, - AlertEvent, AlertType, AlertSeverity, + AlertEvent, AlertSeverity, AlertType, ContainerEvent, ContainerEventType, NetworkEvent, + SecurityEvent, }; -use chrono::Utc; +use stackdog::events::syscall::{SyscallEvent, SyscallType}; #[test] fn test_syscall_event_to_security_event() { - let syscall_event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); - + let syscall_event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); + // Test From trait let security_event: SecurityEvent = syscall_event.clone().into(); - + match security_event { SecurityEvent::Syscall(e) => { assert_eq!(e.pid, syscall_event.pid); @@ -42,9 +37,9 @@ fn test_network_event_to_security_event() { timestamp: Utc::now(), container_id: Some("abc123".to_string()), }; - + let security_event: SecurityEvent = network_event.clone().into(); - + match security_event { SecurityEvent::Network(e) => { assert_eq!(e.src_ip, network_event.src_ip); @@ -62,9 +57,9 @@ fn test_container_event_to_security_event() { timestamp: Utc::now(), details: Some("Container started".to_string()), }; - + let security_event: SecurityEvent = container_event.clone().into(); - + match security_event { SecurityEvent::Container(e) => { assert_eq!(e.container_id, container_event.container_id); @@ -83,9 +78,9 @@ fn test_alert_event_to_security_event() { timestamp: Utc::now(), source_event_id: Some("evt_123".to_string()), }; - + let security_event: SecurityEvent = alert_event.clone().into(); - + match security_event { SecurityEvent::Alert(e) => { assert_eq!(e.alert_type, alert_event.alert_type); @@ -97,15 +92,10 @@ fn test_alert_event_to_security_event() { #[test] fn test_security_event_into_syscall() { - let syscall_event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Connect, - Utc::now(), - ); - + let syscall_event = SyscallEvent::new(1234, 1000, SyscallType::Connect, Utc::now()); + let security_event = SecurityEvent::Syscall(syscall_event.clone()); - + // Test conversion back to SyscallEvent let result = syscall_event_from_security(security_event); assert!(result.is_some()); @@ -125,9 +115,9 @@ fn test_security_event_wrong_variant() { timestamp: Utc::now(), container_id: None, }; - + let security_event = SecurityEvent::Network(network_event); - + // Try to extract as SyscallEvent (should fail) let result = syscall_event_from_security(security_event); assert!(result.is_none()); diff --git a/tests/events/event_serialization_test.rs b/tests/events/event_serialization_test.rs index d18c76a..a4b6741 100644 --- a/tests/events/event_serialization_test.rs +++ b/tests/events/event_serialization_test.rs @@ -2,22 +2,16 @@ //! //! Tests for JSON and binary serialization of events -use stackdog::events::syscall::{SyscallEvent, SyscallType}; -use stackdog::events::security::SecurityEvent; use chrono::Utc; -use serde_json; +use stackdog::events::security::SecurityEvent; +use stackdog::events::syscall::{SyscallEvent, SyscallType}; #[test] fn test_syscall_event_json_serialize() { - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); - + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); + let json = serde_json::to_string(&event).expect("Failed to serialize"); - + assert!(json.contains("\"pid\":1234")); assert!(json.contains("\"uid\":1000")); assert!(json.contains("\"syscall_type\":\"Execve\"")); @@ -33,9 +27,9 @@ fn test_syscall_event_json_deserialize() { "container_id": null, "comm": null }"#; - + let event: SyscallEvent = serde_json::from_str(json).expect("Failed to deserialize"); - + assert_eq!(event.pid, 5678); assert_eq!(event.uid, 2000); assert_eq!(event.syscall_type, SyscallType::Connect); @@ -43,16 +37,11 @@ fn test_syscall_event_json_deserialize() { #[test] fn test_syscall_event_json_roundtrip() { - let original = SyscallEvent::new( - 1234, - 1000, - SyscallType::Ptrace, - Utc::now(), - ); - + let original = SyscallEvent::new(1234, 1000, SyscallType::Ptrace, Utc::now()); + let json = serde_json::to_string(&original).expect("Failed to serialize"); let deserialized: SyscallEvent = serde_json::from_str(&json).expect("Failed to deserialize"); - + assert_eq!(original.pid, deserialized.pid); assert_eq!(original.uid, deserialized.uid); assert_eq!(original.syscall_type, deserialized.syscall_type); @@ -60,33 +49,23 @@ fn test_syscall_event_json_roundtrip() { #[test] fn test_security_event_json_serialize() { - let syscall_event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Mount, - Utc::now(), - ); + let syscall_event = SyscallEvent::new(1234, 1000, SyscallType::Mount, Utc::now()); let security_event = SecurityEvent::Syscall(syscall_event); - + let json = serde_json::to_string(&security_event).expect("Failed to serialize"); - + assert!(json.contains("Syscall")); assert!(json.contains("\"pid\":1234")); } #[test] fn test_security_event_json_roundtrip() { - let syscall_event = SyscallEvent::new( - 9999, - 0, - SyscallType::Setuid, - Utc::now(), - ); + let syscall_event = SyscallEvent::new(9999, 0, SyscallType::Setuid, Utc::now()); let original = SecurityEvent::Syscall(syscall_event); - + let json = serde_json::to_string(&original).expect("Failed to serialize"); let deserialized: SecurityEvent = serde_json::from_str(&json).expect("Failed to deserialize"); - + match deserialized { SecurityEvent::Syscall(e) => { assert_eq!(e.pid, 9999); @@ -106,7 +85,7 @@ fn test_syscall_type_serialization() { SyscallType::Ptrace, SyscallType::Mount, ]; - + for syscall_type in syscall_types { let json = serde_json::to_string(&syscall_type).expect("Failed to serialize"); let deserialized: SyscallType = serde_json::from_str(&json).expect("Failed to deserialize"); @@ -116,21 +95,19 @@ fn test_syscall_type_serialization() { #[test] fn test_syscall_event_with_container_serialization() { - let mut event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); + let mut event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); event.container_id = Some("container_abc123".to_string()); event.comm = Some("/bin/bash".to_string()); - + let json = serde_json::to_string(&event).expect("Failed to serialize"); - + assert!(json.contains("container_abc123")); assert!(json.contains("/bin/bash")); - + let deserialized: SyscallEvent = serde_json::from_str(&json).expect("Failed to deserialize"); - assert_eq!(deserialized.container_id, Some("container_abc123".to_string())); + assert_eq!( + deserialized.container_id, + Some("container_abc123".to_string()) + ); assert_eq!(deserialized.comm, Some("/bin/bash".to_string())); } diff --git a/tests/events/event_stream_test.rs b/tests/events/event_stream_test.rs index 4acbabc..f826844 100644 --- a/tests/events/event_stream_test.rs +++ b/tests/events/event_stream_test.rs @@ -2,10 +2,10 @@ //! //! Tests for event batch, filter, and iterator types -use stackdog::events::syscall::{SyscallEvent, SyscallType}; +use chrono::{Duration, Utc}; use stackdog::events::security::SecurityEvent; use stackdog::events::stream::{EventBatch, EventFilter, EventIterator}; -use chrono::{Utc, Duration}; +use stackdog::events::syscall::{SyscallEvent, SyscallType}; #[test] fn test_event_batch_creation() { @@ -17,14 +17,9 @@ fn test_event_batch_creation() { #[test] fn test_event_batch_add() { let mut batch = EventBatch::new(); - - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - + + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()).into(); + batch.add(event); assert_eq!(batch.len(), 1); assert!(!batch.is_empty()); @@ -33,28 +28,21 @@ fn test_event_batch_add() { #[test] fn test_event_batch_add_multiple() { let mut batch = EventBatch::new(); - + for i in 0..10 { - let event = SyscallEvent::new( - i, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); + let event = SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into(); batch.add(event); } - + assert_eq!(batch.len(), 10); } #[test] fn test_event_batch_from_vec() { let events: Vec = (0..5) - .map(|i| { - SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into() - }) + .map(|i| SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into()) .collect(); - + let batch = EventBatch::from(events.clone()); assert_eq!(batch.len(), 5); } @@ -62,12 +50,12 @@ fn test_event_batch_from_vec() { #[test] fn test_event_batch_clear() { let mut batch = EventBatch::new(); - + for i in 0..3 { let event = SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into(); batch.add(event); } - + assert_eq!(batch.len(), 3); batch.clear(); assert_eq!(batch.len(), 0); @@ -76,15 +64,10 @@ fn test_event_batch_clear() { #[test] fn test_event_filter_default() { let filter = EventFilter::default(); - + // Default filter should match everything - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()).into(); + assert!(filter.matches(&event)); } @@ -92,21 +75,13 @@ fn test_event_filter_default() { fn test_event_filter_by_syscall_type() { let mut filter = EventFilter::new(); filter = filter.with_syscall_type(SyscallType::Execve); - - let execve_event: SecurityEvent = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - - let connect_event: SecurityEvent = SyscallEvent::new( - 1234, - 1000, - SyscallType::Connect, - Utc::now(), - ).into(); - + + let execve_event: SecurityEvent = + SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()).into(); + + let connect_event: SecurityEvent = + SyscallEvent::new(1234, 1000, SyscallType::Connect, Utc::now()).into(); + assert!(filter.matches(&execve_event)); assert!(!filter.matches(&connect_event)); } @@ -115,21 +90,13 @@ fn test_event_filter_by_syscall_type() { fn test_event_filter_by_pid() { let mut filter = EventFilter::new(); filter = filter.with_pid(1234); - - let matching_event: SecurityEvent = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - - let non_matching_event: SecurityEvent = SyscallEvent::new( - 5678, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - + + let matching_event: SecurityEvent = + SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()).into(); + + let non_matching_event: SecurityEvent = + SyscallEvent::new(5678, 1000, SyscallType::Execve, Utc::now()).into(); + assert!(filter.matches(&matching_event)); assert!(!filter.matches(&non_matching_event)); } @@ -141,21 +108,13 @@ fn test_event_filter_chained() { .with_syscall_type(SyscallType::Execve) .with_pid(1234) .with_uid(1000); - - let matching_event: SecurityEvent = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - - let wrong_pid_event: SecurityEvent = SyscallEvent::new( - 5678, - 1000, - SyscallType::Execve, - Utc::now(), - ).into(); - + + let matching_event: SecurityEvent = + SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()).into(); + + let wrong_pid_event: SecurityEvent = + SyscallEvent::new(5678, 1000, SyscallType::Execve, Utc::now()).into(); + assert!(filter.matches(&matching_event)); assert!(!filter.matches(&wrong_pid_event)); } @@ -163,11 +122,9 @@ fn test_event_filter_chained() { #[test] fn test_event_iterator_creation() { let events: Vec = (0..5) - .map(|i| { - SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into() - }) + .map(|i| SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into()) .collect(); - + let iterator = EventIterator::new(events); assert_eq!(iterator.count(), 5); } @@ -175,14 +132,12 @@ fn test_event_iterator_creation() { #[test] fn test_event_iterator_filter() { let events: Vec = (0..10) - .map(|i| { - SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into() - }) + .map(|i| SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into()) .collect(); - + let iterator = EventIterator::new(events); let filter = EventFilter::new().with_pid(5); - + let filtered: Vec<_> = iterator.filter(&filter).collect(); assert_eq!(filtered.len(), 1); assert_eq!(filtered[0].pid().unwrap_or(0), 5); @@ -196,24 +151,22 @@ fn test_event_iterator_time_range() { SyscallEvent::new(2, 1000, SyscallType::Execve, now - Duration::seconds(5)).into(), SyscallEvent::new(3, 1000, SyscallType::Execve, now).into(), ]; - + let iterator = EventIterator::new(events); let start = now - Duration::seconds(6); let filtered: Vec<_> = iterator.time_range(start, now).collect(); - + assert_eq!(filtered.len(), 2); } #[test] fn test_event_iterator_collect() { let events: Vec = (0..5) - .map(|i| { - SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into() - }) + .map(|i| SyscallEvent::new(i, 1000, SyscallType::Execve, Utc::now()).into()) .collect(); - + let iterator = EventIterator::new(events); let collected: Vec<_> = iterator.collect(); - + assert_eq!(collected.len(), 5); } diff --git a/tests/events/event_validation_test.rs b/tests/events/event_validation_test.rs index a2aa6d0..06344d0 100644 --- a/tests/events/event_validation_test.rs +++ b/tests/events/event_validation_test.rs @@ -2,22 +2,15 @@ //! //! Tests for event validation logic +use chrono::Utc; +use stackdog::events::security::{AlertEvent, AlertSeverity, AlertType, NetworkEvent}; use stackdog::events::syscall::{SyscallEvent, SyscallType}; -use stackdog::events::security::{ - NetworkEvent, AlertEvent, AlertType, AlertSeverity, -}; use stackdog::events::validation::{EventValidator, ValidationResult}; -use chrono::Utc; #[test] fn test_valid_syscall_event() { - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); - + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); + let result = EventValidator::validate_syscall(&event); assert!(result.is_valid()); assert_eq!(result, ValidationResult::Valid); @@ -26,12 +19,12 @@ fn test_valid_syscall_event() { #[test] fn test_syscall_event_zero_pid() { let event = SyscallEvent::new( - 0, // kernel thread + 0, // kernel thread 0, SyscallType::Execve, Utc::now(), ); - + let result = EventValidator::validate_syscall(&event); // PID 0 is valid (kernel threads) assert!(result.is_valid()); @@ -48,7 +41,7 @@ fn test_invalid_ip_address() { timestamp: Utc::now(), container_id: None, }; - + let result = EventValidator::validate_network(&event); assert!(!result.is_valid()); assert!(matches!(result, ValidationResult::Invalid(_))); @@ -65,7 +58,7 @@ fn test_valid_ip_addresses() { "::1", "2001:db8::1", ]; - + for ip in valid_ips { let event = NetworkEvent { src_ip: ip.to_string(), @@ -76,32 +69,24 @@ fn test_valid_ip_addresses() { timestamp: Utc::now(), container_id: None, }; - + let result = EventValidator::validate_network(&event); assert!(result.is_valid(), "IP {} should be valid", ip); } } #[test] -fn test_invalid_port() { - let event = NetworkEvent { - src_ip: "192.168.1.1".to_string(), - dst_ip: "10.0.0.1".to_string(), - src_port: 70000, // Invalid port (> 65535) - dst_port: 80, - protocol: "TCP".to_string(), - timestamp: Utc::now(), - container_id: None, - }; - - let result = EventValidator::validate_network(&event); - assert!(!result.is_valid()); +fn test_invalid_port_not_representable_for_u16() { + // NetworkEvent ports are u16, so values > 65535 cannot be constructed. + // This test asserts type-level safety explicitly. + let max = u16::MAX; + assert_eq!(max, 65535); } #[test] fn test_valid_port_range() { let valid_ports = vec![0, 80, 443, 8080, 65535]; - + for port in valid_ports { let event = NetworkEvent { src_ip: "192.168.1.1".to_string(), @@ -112,7 +97,7 @@ fn test_valid_port_range() { timestamp: Utc::now(), container_id: None, }; - + let result = EventValidator::validate_network(&event); assert!(result.is_valid(), "Port {} should be valid", port); } @@ -127,7 +112,7 @@ fn test_alert_event_validation() { timestamp: Utc::now(), source_event_id: None, }; - + let result = EventValidator::validate_alert(&event); assert!(result.is_valid()); } @@ -141,7 +126,7 @@ fn test_alert_empty_message() { timestamp: Utc::now(), source_event_id: None, }; - + let result = EventValidator::validate_alert(&event); assert!(!result.is_valid()); } @@ -157,10 +142,10 @@ fn test_validation_result_error() { fn test_validation_result_display() { let valid = ValidationResult::Valid; assert_eq!(format!("{}", valid), "Valid"); - + let invalid = ValidationResult::Invalid("reason".to_string()); assert!(format!("{}", invalid).contains("Invalid")); - + let error = ValidationResult::Error("error".to_string()); assert!(format!("{}", error).contains("error")); } diff --git a/tests/events/mod.rs b/tests/events/mod.rs index a1d6053..f49bfc2 100644 --- a/tests/events/mod.rs +++ b/tests/events/mod.rs @@ -1,8 +1,8 @@ //! Events module tests -mod syscall_event_test; -mod security_event_test; mod event_conversion_test; mod event_serialization_test; -mod event_validation_test; mod event_stream_test; +mod event_validation_test; +mod security_event_test; +mod syscall_event_test; diff --git a/tests/events/security_event_test.rs b/tests/events/security_event_test.rs index 421d208..f565502 100644 --- a/tests/events/security_event_test.rs +++ b/tests/events/security_event_test.rs @@ -4,22 +4,17 @@ use chrono::Utc; use stackdog::events::security::{ - SecurityEvent, NetworkEvent, ContainerEvent, ContainerEventType, - AlertEvent, AlertType, AlertSeverity, + AlertEvent, AlertSeverity, AlertType, ContainerEvent, ContainerEventType, NetworkEvent, + SecurityEvent, }; use stackdog::events::syscall::{SyscallEvent, SyscallType}; #[test] fn test_security_event_syscall_variant() { - let syscall_event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); - + let syscall_event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); + let security_event = SecurityEvent::Syscall(syscall_event); - + // Test that we can match on the variant match security_event { SecurityEvent::Syscall(e) => { @@ -41,9 +36,9 @@ fn test_security_event_network_variant() { timestamp: Utc::now(), container_id: Some("abc123".to_string()), }; - + let security_event = SecurityEvent::Network(network_event); - + match security_event { SecurityEvent::Network(e) => { assert_eq!(e.src_ip, "192.168.1.1"); @@ -61,9 +56,9 @@ fn test_security_event_container_variant() { timestamp: Utc::now(), details: Some("Container started".to_string()), }; - + let security_event = SecurityEvent::Container(container_event); - + match security_event { SecurityEvent::Container(e) => { assert_eq!(e.container_id, "abc123"); @@ -82,9 +77,9 @@ fn test_security_event_alert_variant() { timestamp: Utc::now(), source_event_id: Some("evt_123".to_string()), }; - + let security_event = SecurityEvent::Alert(alert_event); - + match security_event { SecurityEvent::Alert(e) => { assert_eq!(e.alert_type, AlertType::ThreatDetected); @@ -132,7 +127,7 @@ fn test_network_event_clone() { timestamp: Utc::now(), container_id: Some("abc123".to_string()), }; - + let cloned = event.clone(); assert_eq!(event.src_ip, cloned.src_ip); assert_eq!(event.dst_port, cloned.dst_port); @@ -146,7 +141,7 @@ fn test_container_event_clone() { timestamp: Utc::now(), details: None, }; - + let cloned = event.clone(); assert_eq!(event.container_id, cloned.container_id); assert_eq!(event.event_type, cloned.event_type); @@ -161,7 +156,7 @@ fn test_alert_event_debug() { timestamp: Utc::now(), source_event_id: None, }; - + let debug_str = format!("{:?}", event); assert!(debug_str.contains("AlertEvent")); assert!(debug_str.contains("ThreatDetected")); diff --git a/tests/events/syscall_event_test.rs b/tests/events/syscall_event_test.rs index dc8a554..40cfb1f 100644 --- a/tests/events/syscall_event_test.rs +++ b/tests/events/syscall_event_test.rs @@ -3,7 +3,7 @@ //! Tests for syscall event types, creation, and builder pattern. use chrono::Utc; -use stackdog::events::syscall::{SyscallEvent, SyscallType, SyscallEventBuilder}; +use stackdog::events::syscall::{SyscallEvent, SyscallEventBuilder, SyscallType}; #[test] fn test_syscall_type_variants() { @@ -27,12 +27,12 @@ fn test_syscall_type_variants() { fn test_syscall_event_creation() { let timestamp = Utc::now(); let event = SyscallEvent::new( - 1234, // pid - 1000, // uid + 1234, // pid + 1000, // uid SyscallType::Execve, timestamp, ); - + assert_eq!(event.pid, 1234); assert_eq!(event.uid, 1000); assert_eq!(event.syscall_type, SyscallType::Execve); @@ -44,14 +44,9 @@ fn test_syscall_event_creation() { #[test] fn test_syscall_event_with_container_id() { let timestamp = Utc::now(); - let mut event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - timestamp, - ); + let mut event = SyscallEvent::new(1234, 1000, SyscallType::Execve, timestamp); event.container_id = Some("abc123def456".to_string()); - + assert_eq!(event.container_id, Some("abc123def456".to_string())); } @@ -66,7 +61,7 @@ fn test_syscall_event_builder() { .container_id(Some("abc123".to_string())) .comm(Some("bash".to_string())) .build(); - + assert_eq!(event.pid, 1234); assert_eq!(event.uid, 1000); assert_eq!(event.syscall_type, SyscallType::Execve); @@ -82,7 +77,7 @@ fn test_syscall_event_builder_minimal() { .uid(1000) .syscall_type(SyscallType::Connect) .build(); - + assert_eq!(event.pid, 1234); assert_eq!(event.uid, 1000); assert_eq!(event.syscall_type, SyscallType::Connect); @@ -99,7 +94,7 @@ fn test_syscall_event_builder_default() { .uid(2000) .syscall_type(SyscallType::Open) .build(); - + assert_eq!(event.pid, 5678); assert_eq!(event.uid, 2000); assert_eq!(event.syscall_type, SyscallType::Open); @@ -107,15 +102,10 @@ fn test_syscall_event_builder_default() { #[test] fn test_syscall_event_clone() { - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); - + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); + let cloned = event.clone(); - + assert_eq!(event.pid, cloned.pid); assert_eq!(event.uid, cloned.uid); assert_eq!(event.syscall_type, cloned.syscall_type); @@ -123,13 +113,8 @@ fn test_syscall_event_clone() { #[test] fn test_syscall_event_debug() { - let event = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - Utc::now(), - ); - + let event = SyscallEvent::new(1234, 1000, SyscallType::Execve, Utc::now()); + // Test that Debug trait is implemented let debug_str = format!("{:?}", event); assert!(debug_str.contains("SyscallEvent")); @@ -139,25 +124,10 @@ fn test_syscall_event_debug() { #[test] fn test_syscall_event_partial_eq() { let timestamp = Utc::now(); - let event1 = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - timestamp, - ); - let event2 = SyscallEvent::new( - 1234, - 1000, - SyscallType::Execve, - timestamp, - ); - let event3 = SyscallEvent::new( - 5678, - 1000, - SyscallType::Execve, - timestamp, - ); - + let event1 = SyscallEvent::new(1234, 1000, SyscallType::Execve, timestamp); + let event2 = SyscallEvent::new(1234, 1000, SyscallType::Execve, timestamp); + let event3 = SyscallEvent::new(5678, 1000, SyscallType::Execve, timestamp); + assert_eq!(event1, event2); assert_ne!(event1, event3); } diff --git a/tests/integration.rs b/tests/integration.rs index 53417c7..3c1529d 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -2,6 +2,6 @@ //! //! These tests verify that multiple components work together correctly. +mod collectors; mod events; mod structure; -mod collectors; diff --git a/web/package.json b/web/package.json index 7ba8562..cb65949 100644 --- a/web/package.json +++ b/web/package.json @@ -58,7 +58,9 @@ "@typescript-eslint/parser": "^6.14.0", "@typescript-eslint/eslint-plugin": "^6.14.0", "eslint-plugin-react": "^7.33.2", - "eslint-plugin-react-hooks": "^4.6.0" + "eslint-plugin-react-hooks": "^4.6.0", + "style-loader": "^4.0.0", + "css-loader": "^7.1.2" }, "browserslist": { "production": [ diff --git a/web/src/App.css b/web/src/App.css new file mode 100644 index 0000000..d28a34c --- /dev/null +++ b/web/src/App.css @@ -0,0 +1,8 @@ +.app-layout { + display: flex; + align-items: flex-start; +} + +.app-layout .dashboard { + flex: 1; +} diff --git a/web/src/App.tsx b/web/src/App.tsx index 6acacd6..77163f4 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -1,10 +1,13 @@ import React from 'react'; import Dashboard from './components/Dashboard'; +import Sidebar from './components/Sidebar'; import 'bootstrap/dist/css/bootstrap.min.css'; +import './App.css'; const App: React.FC = () => { return ( -
+
+
); diff --git a/web/src/components/AlertPanel.tsx b/web/src/components/AlertPanel.tsx index 20ccd0c..a17e2d7 100644 --- a/web/src/components/AlertPanel.tsx +++ b/web/src/components/AlertPanel.tsx @@ -1,8 +1,8 @@ import React, { useEffect, useState } from 'react'; import { Card, Button, Form, Table, Badge, Modal, Spinner, Alert as BootstrapAlert, Pagination } from 'react-bootstrap'; -import apiService from '../../services/api'; -import webSocketService from '../../services/websocket'; -import { Alert, AlertSeverity, AlertStatus, AlertFilter, AlertStats } from '../../types/alerts'; +import apiService from '../services/api'; +import webSocketService from '../services/websocket'; +import { Alert, AlertSeverity, AlertStatus, AlertFilter, AlertStats } from '../types/alerts'; import './AlertPanel.css'; const ITEMS_PER_PAGE = 10; @@ -121,7 +121,7 @@ const AlertPanel: React.FC = () => { }; const getSeverityBadge = (severity: AlertSeverity) => { - const variants = { + const variants: Record = { Info: 'info', Low: 'success', Medium: 'warning', @@ -132,7 +132,7 @@ const AlertPanel: React.FC = () => { }; const getStatusBadge = (status: AlertStatus) => { - const variants = { + const variants: Record = { New: 'primary', Acknowledged: 'warning', Resolved: 'success', diff --git a/web/src/components/ContainerList.tsx b/web/src/components/ContainerList.tsx index c2f8e69..04e5874 100644 --- a/web/src/components/ContainerList.tsx +++ b/web/src/components/ContainerList.tsx @@ -1,7 +1,7 @@ import React, { useEffect, useState } from 'react'; -import { Card, Button, Form, Badge, Modal, Spinner, BootstrapAlert } from 'react-bootstrap'; -import apiService from '../../services/api'; -import { Container, ContainerStatus } from '../../types/containers'; +import { Card, Button, Form, Badge, Modal, Spinner, Alert as BootstrapAlert } from 'react-bootstrap'; +import apiService from '../services/api'; +import { Container, ContainerStatus } from '../types/containers'; import './ContainerList.css'; const ContainerList: React.FC = () => { @@ -21,7 +21,7 @@ const ContainerList: React.FC = () => { try { setLoading(true); const data = await apiService.getContainers(); - setContainers(filterStatus ? data.filter(c => c.status === filterStatus) : data); + setContainers(filterStatus ? data.filter((c: Container) => c.status === filterStatus) : data); } catch (err) { console.error('Error loading containers:', err); } finally { @@ -53,7 +53,7 @@ const ContainerList: React.FC = () => { }; const getStatusBadge = (status: ContainerStatus) => { - const variants = { + const variants: Record = { Running: 'success', Stopped: 'secondary', Paused: 'warning', diff --git a/web/src/components/Dashboard.css b/web/src/components/Dashboard.css index 6804cf7..a6cc495 100644 --- a/web/src/components/Dashboard.css +++ b/web/src/components/Dashboard.css @@ -12,11 +12,30 @@ min-height: 400px; } -.dashboard-title { - font-size: 2rem; - font-weight: 700; - color: #2c3e50; - margin-bottom: 0.5rem; +.dashboard-topbar { + display: flex; + align-items: center; + justify-content: space-between; + margin-bottom: 0.75rem; +} + +.dashboard-topbar-spacer { + flex: 1; +} + +.dashboard-actions-btn { + border: 1px solid #d1d5db; + background: #fff; + color: #374151; + border-radius: 8px; + padding: 4px 10px; + font-size: 1.25rem; + line-height: 1; + cursor: pointer; +} + +.dashboard-actions-btn:hover { + background: #f9fafb; } .dashboard-subtitle { @@ -61,10 +80,6 @@ padding: 10px; } - .dashboard-title { - font-size: 1.5rem; - } - .stat-value { font-size: 2rem; } diff --git a/web/src/components/Dashboard.tsx b/web/src/components/Dashboard.tsx index 040649c..42b27f2 100644 --- a/web/src/components/Dashboard.tsx +++ b/web/src/components/Dashboard.tsx @@ -1,8 +1,8 @@ import React, { useEffect, useState } from 'react'; import { Container, Row, Col, Card, Spinner, Alert as BootstrapAlert } from 'react-bootstrap'; -import apiService from '../../services/api'; -import webSocketService from '../../services/websocket'; -import { SecurityStatus } from '../../types/security'; +import apiService from '../services/api'; +import webSocketService from '../services/websocket'; +import { SecurityStatus } from '../types/security'; import SecurityScore from './SecurityScore'; import AlertPanel from './AlertPanel'; import ContainerList from './ContainerList'; @@ -42,7 +42,7 @@ const Dashboard: React.FC = () => { await webSocketService.connect(); // Subscribe to real-time updates - webSocketService.subscribe('stats:updated', (data) => { + webSocketService.subscribe('stats:updated', (data: Partial) => { setSecurityStatus(prev => prev ? { ...prev, ...data } : null); }); @@ -79,7 +79,10 @@ const Dashboard: React.FC = () => { -

šŸ• Stackdog Security Dashboard

+
+
+ +

Real-time security monitoring for containers and Linux servers

@@ -87,7 +90,7 @@ const Dashboard: React.FC = () => { {/* Security Score Card */} - + @@ -124,7 +127,7 @@ const Dashboard: React.FC = () => { {/* Threat Map */} - + @@ -132,10 +135,10 @@ const Dashboard: React.FC = () => { {/* Alerts and Containers */} - + - + diff --git a/web/src/components/Sidebar.css b/web/src/components/Sidebar.css new file mode 100644 index 0000000..0bd26a3 --- /dev/null +++ b/web/src/components/Sidebar.css @@ -0,0 +1,48 @@ +.sidebar { + width: 220px; + min-height: 100vh; + background: #1f2937; + color: #f9fafb; + padding: 20px 16px; + position: sticky; + top: 0; +} + +.sidebar-brand { + display: flex; + align-items: center; + gap: 10px; + font-size: 1.1rem; + font-weight: 700; + margin-bottom: 20px; +} + +.sidebar-logo { + width: 39px; + height: 39px; + object-fit: contain; +} + +.sidebar-nav { + display: flex; + flex-direction: column; + gap: 10px; +} + +.sidebar-nav a { + color: #d1d5db; + text-decoration: none; + padding: 8px 10px; + border-radius: 6px; +} + +.sidebar-nav a:hover { + background: #374151; + color: #fff; +} + +@media (max-width: 992px) { + .sidebar { + display: none; + } +} diff --git a/web/src/components/Sidebar.tsx b/web/src/components/Sidebar.tsx new file mode 100644 index 0000000..c1be24b --- /dev/null +++ b/web/src/components/Sidebar.tsx @@ -0,0 +1,29 @@ +import React from 'react'; +import './Sidebar.css'; + +const DASHBOARD_LOGO_URL = 'https://github.com/user-attachments/assets/0c8a9216-8315-4ef7-9b73-d96c40521ed1'; + +const Sidebar: React.FC = () => { + return ( + + ); +}; + +export default Sidebar; diff --git a/web/src/components/ThreatMap.tsx b/web/src/components/ThreatMap.tsx index 623c83e..83177c7 100644 --- a/web/src/components/ThreatMap.tsx +++ b/web/src/components/ThreatMap.tsx @@ -1,8 +1,8 @@ import React, { useEffect, useState } from 'react'; import { Card, Form, Spinner } from 'react-bootstrap'; import { BarChart, Bar, PieChart, Pie, LineChart, Line, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer, Cell } from 'recharts'; -import apiService from '../../services/api'; -import { Threat, ThreatStatistics } from '../../types/security'; +import apiService from '../services/api'; +import { Threat, ThreatStatistics } from '../types/security'; import './ThreatMap.css'; const COLORS = ['#e74c3c', '#e67e22', '#f39c12', '#3498db', '#27ae60']; @@ -36,7 +36,8 @@ const ThreatMap: React.FC = () => { const getTypeData = () => { if (!statistics) return []; - return Object.entries(statistics.byType).map(([name, value]) => ({ + const byType = statistics.byType || {}; + return Object.entries(byType).map(([name, value]) => ({ name, value, })); @@ -44,7 +45,8 @@ const ThreatMap: React.FC = () => { const getSeverityData = () => { if (!statistics) return []; - return Object.entries(statistics.bySeverity).map(([name, value]) => ({ + const bySeverity = statistics.bySeverity || {}; + return Object.entries(bySeverity).map(([name, value]) => ({ name, value, })); diff --git a/web/src/components/__tests__/ThreatMap.test.tsx b/web/src/components/__tests__/ThreatMap.test.tsx index 95b2c8e..8ee0290 100644 --- a/web/src/components/__tests__/ThreatMap.test.tsx +++ b/web/src/components/__tests__/ThreatMap.test.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { render, screen, waitFor } from '@testing-library/react'; +import { fireEvent, render, screen, waitFor } from '@testing-library/react'; import ThreatMap from '../ThreatMap'; import apiService from '../../services/api'; diff --git a/web/src/services/api.ts b/web/src/services/api.ts index d43ddc2..53c40b0 100644 --- a/web/src/services/api.ts +++ b/web/src/services/api.ts @@ -3,10 +3,19 @@ import { SecurityStatus, Threat, ThreatStatistics } from '../types/security'; import { Alert, AlertStats, AlertFilter } from '../types/alerts'; import { Container, QuarantineRequest } from '../types/containers'; -const API_BASE_URL = process.env.REACT_APP_API_URL || 'http://localhost:5000/api'; +type EnvLike = { + REACT_APP_API_URL?: string; + APP_PORT?: string; + REACT_APP_API_PORT?: string; +}; + +const env = ((globalThis as unknown as { __STACKDOG_ENV__?: EnvLike }).__STACKDOG_ENV__ ?? + {}) as EnvLike; +const apiPort = env.REACT_APP_API_PORT || env.APP_PORT || '5555'; +const API_BASE_URL = env.REACT_APP_API_URL || `http://localhost:${apiPort}/api`; class ApiService { - private api: AxiosInstance; + public api: AxiosInstance; constructor() { this.api = axios.create({ @@ -30,7 +39,7 @@ class ApiService { } async getThreatStatistics(): Promise { - const response = await this.api.get('/statistics'); + const response = await this.api.get('/threats/statistics'); return response.data; } @@ -63,7 +72,32 @@ class ApiService { // Containers async getContainers(): Promise { const response = await this.api.get('/containers'); - return response.data; + const raw = response.data as Array>; + return raw.map((item) => { + const securityStatus = item.securityStatus ?? item.security_status ?? {}; + const networkActivity = item.networkActivity ?? item.network_activity ?? {}; + + return { + id: item.id ?? '', + name: item.name ?? item.id ?? 'unknown', + image: item.image ?? 'unknown', + status: item.status ?? 'Running', + securityStatus: { + state: securityStatus.state ?? 'Secure', + threats: securityStatus.threats ?? 0, + vulnerabilities: securityStatus.vulnerabilities ?? 0, + lastScan: securityStatus.lastScan ?? new Date().toISOString(), + }, + riskScore: item.riskScore ?? item.risk_score ?? 0, + networkActivity: { + inboundConnections: networkActivity.inboundConnections ?? networkActivity.inbound_connections ?? 0, + outboundConnections: networkActivity.outboundConnections ?? networkActivity.outbound_connections ?? 0, + blockedConnections: networkActivity.blockedConnections ?? networkActivity.blocked_connections ?? 0, + suspiciousActivity: networkActivity.suspiciousActivity ?? networkActivity.suspicious_activity ?? false, + }, + createdAt: item.createdAt ?? item.created_at ?? new Date().toISOString(), + } as Container; + }); } async quarantineContainer(request: QuarantineRequest): Promise { diff --git a/web/src/services/websocket.ts b/web/src/services/websocket.ts index 56d6bb0..7513591 100644 --- a/web/src/services/websocket.ts +++ b/web/src/services/websocket.ts @@ -6,6 +6,17 @@ type WebSocketEvent = | 'stats:updated'; type EventHandler = (data: any) => void; +type EnvLike = { + REACT_APP_WS_URL?: string; + APP_PORT?: string; + REACT_APP_API_PORT?: string; +}; + +declare global { + interface Window { + __STACKDOG_ENV__?: EnvLike; + } +} export class WebSocketService { private ws: WebSocket | null = null; @@ -15,14 +26,22 @@ export class WebSocketService { private reconnectDelay = 1000; private eventHandlers: Map> = new Map(); private shouldReconnect = true; + private failedInitialConnect = false; constructor(url?: string) { - this.url = url || process.env.REACT_APP_WS_URL || 'ws://localhost:5000/ws'; + const env = ((globalThis as { __STACKDOG_ENV__?: EnvLike }).__STACKDOG_ENV__ ?? + {}) as EnvLike; + const apiPort = env.REACT_APP_API_PORT || env.APP_PORT || '5555'; + this.url = url || env.REACT_APP_WS_URL || `ws://localhost:${apiPort}/ws`; } connect(): Promise { return new Promise((resolve, reject) => { try { + if (this.failedInitialConnect) { + resolve(); + return; + } this.ws = new WebSocket(this.url); this.ws.onopen = () => { @@ -42,17 +61,23 @@ export class WebSocketService { this.ws.onclose = () => { console.log('WebSocket disconnected'); - if (this.shouldReconnect && this.reconnectAttempts < this.maxReconnectAttempts) { + if (!this.failedInitialConnect && this.shouldReconnect && this.reconnectAttempts < this.maxReconnectAttempts) { this.scheduleReconnect(); } }; this.ws.onerror = (error) => { - console.error('WebSocket error:', error); - reject(error); + // WebSocket endpoint may be intentionally unavailable in some environments. + // Fall back to REST-only mode after the first failed connect. + this.failedInitialConnect = true; + this.shouldReconnect = false; + console.warn('WebSocket unavailable, running in polling mode'); + resolve(); }; } catch (error) { - reject(error); + this.failedInitialConnect = true; + this.shouldReconnect = false; + resolve(); } }); } @@ -96,6 +121,7 @@ export class WebSocketService { disconnect(): void { this.shouldReconnect = false; + this.failedInitialConnect = false; if (this.ws) { this.ws.close(); this.ws = null; diff --git a/web/src/setupTests.ts b/web/src/setupTests.ts index ebb3e62..68cddd9 100644 --- a/web/src/setupTests.ts +++ b/web/src/setupTests.ts @@ -1,15 +1,25 @@ import '@testing-library/jest-dom'; // Mock WebSocket -global.WebSocket = class MockWebSocket { - constructor(url: string) { - this.url = url; - } +class MockWebSocket { + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + url: string; + readyState = MockWebSocket.OPEN; send = jest.fn(); close = jest.fn(); addEventListener = jest.fn(); removeEventListener = jest.fn(); -}; + + constructor(url: string) { + this.url = url; + } +} + +global.WebSocket = MockWebSocket as unknown as typeof WebSocket; // Mock fetch global.fetch = jest.fn(); diff --git a/web/webpack.config.js b/web/webpack.config.js new file mode 100644 index 0000000..b0d56ac --- /dev/null +++ b/web/webpack.config.js @@ -0,0 +1,49 @@ +const path = require('path'); +const HtmlWebpackPlugin = require('html-webpack-plugin'); +const { CleanWebpackPlugin } = require('clean-webpack-plugin'); +const webpack = require('webpack'); + +module.exports = { + entry: './src/index.tsx', + output: { + path: path.resolve(__dirname, 'dist'), + filename: 'bundle.[contenthash].js', + publicPath: '/', + }, + resolve: { + extensions: ['.tsx', '.ts', '.js'], + }, + module: { + rules: [ + { + test: /\.tsx?$/, + use: 'ts-loader', + exclude: /node_modules/, + }, + { + test: /\.css$/, + use: ['style-loader', 'css-loader'], + }, + ], + }, + plugins: [ + new CleanWebpackPlugin(), + new webpack.DefinePlugin({ + __STACKDOG_ENV__: JSON.stringify({ + REACT_APP_API_URL: process.env.REACT_APP_API_URL || '', + REACT_APP_WS_URL: process.env.REACT_APP_WS_URL || '', + APP_PORT: process.env.APP_PORT || '', + REACT_APP_API_PORT: process.env.REACT_APP_API_PORT || '', + }), + }), + new HtmlWebpackPlugin({ + templateContent: + 'Stackdog
', + }), + ], + devServer: { + static: path.resolve(__dirname, 'dist'), + historyApiFallback: true, + port: 3000, + }, +};