diff --git a/crates/stackforge-core/src/flow/icmp_state.rs b/crates/stackforge-core/src/flow/icmp_state.rs new file mode 100644 index 0000000..e682e27 --- /dev/null +++ b/crates/stackforge-core/src/flow/icmp_state.rs @@ -0,0 +1,144 @@ +use std::time::Duration; + +use crate::Packet; + +use super::config::FlowConfig; +use super::state::ConversationStatus; + +/// ICMP/ICMPv6 conversation state. +/// +/// Tracks ICMP-specific metadata for echo request/reply pairs and other ICMP types. +/// Echo requests and replies are correlated using the ICMP identifier field. +#[derive(Debug, Clone)] +pub struct IcmpFlowState { + /// ICMP type (e.g., 8 for Echo Request, 0 for Echo Reply). + pub icmp_type: u8, + /// ICMP code. + pub icmp_code: u8, + /// ICMP identifier (for echo, timestamp, and other types that use it). + pub identifier: Option, + /// Number of echo requests (type 8 for ICMP, 128 for ICMPv6). + pub request_count: u64, + /// Number of echo replies (type 0 for ICMP, 129 for ICMPv6). + pub reply_count: u64, + /// Last sequence number seen in an echo packet. + pub last_seq: Option, + /// Conversation status. + pub status: ConversationStatus, +} + +impl IcmpFlowState { + #[must_use] + pub fn new(icmp_type: u8, icmp_code: u8) -> Self { + Self { + icmp_type, + icmp_code, + identifier: None, + request_count: 0, + reply_count: 0, + last_seq: None, + status: ConversationStatus::Active, + } + } + + /// Update state when a new ICMP packet is received. + /// + /// Increments request or reply count based on ICMP type, and updates + /// the identifier and sequence number fields if present. + pub fn process_packet(&mut self, packet: &Packet, buf: &[u8], icmp_type: u8, icmp_code: u8) { + // Update type/code on every packet (they should be consistent) + self.icmp_type = icmp_type; + self.icmp_code = icmp_code; + + // Get ICMP layer bounds to extract fields + if let Some(icmp_layer) = crate::layer::LayerKind::Icmp + .try_into() + .ok() + .and_then(|kind| packet.get_layer(kind)) + { + let icmp_start = icmp_layer.start; + + // Extract identifier (bytes 4-5) if present + if buf.len() >= icmp_start + 6 { + self.identifier = Some(u16::from_be_bytes([ + buf[icmp_start + 4], + buf[icmp_start + 5], + ])); + } + + // Extract sequence number (bytes 6-7) if present + if buf.len() >= icmp_start + 8 { + self.last_seq = Some(u16::from_be_bytes([ + buf[icmp_start + 6], + buf[icmp_start + 7], + ])); + } + + // Count requests and replies based on ICMP type + match icmp_type { + 8 => { + // ICMP Echo Request + self.request_count += 1; + }, + 0 => { + // ICMP Echo Reply + self.reply_count += 1; + }, + 128 => { + // ICMPv6 Echo Request + self.request_count += 1; + }, + 129 => { + // ICMPv6 Echo Reply + self.reply_count += 1; + }, + _ => { + // Other ICMP types: no counting + }, + } + } + + self.status = ConversationStatus::Active; + } + + /// Check whether this flow has timed out. + #[must_use] + pub fn check_timeout(&self, last_seen: Duration, now: Duration, config: &FlowConfig) -> bool { + // ICMP uses UDP timeout + now.saturating_sub(last_seen) > config.udp_timeout + } +} + +impl Default for IcmpFlowState { + fn default() -> Self { + Self::new(0, 0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_icmp_state_new() { + let state = IcmpFlowState::new(8, 0); + assert_eq!(state.icmp_type, 8); + assert_eq!(state.icmp_code, 0); + assert_eq!(state.request_count, 0); + assert_eq!(state.reply_count, 0); + assert_eq!(state.identifier, None); + assert_eq!(state.last_seq, None); + } + + #[test] + fn test_icmp_timeout() { + let config = FlowConfig::default(); // 120s UDP timeout + let state = IcmpFlowState::new(8, 0); + + // Not timed out + assert!(!state.check_timeout(Duration::from_secs(100), Duration::from_secs(200), &config)); + + // Timed out + assert!(state.check_timeout(Duration::from_secs(100), Duration::from_secs(300), &config)); + } +} diff --git a/crates/stackforge-core/src/flow/key.rs b/crates/stackforge-core/src/flow/key.rs index fda292a..c8ce30a 100644 --- a/crates/stackforge-core/src/flow/key.rs +++ b/crates/stackforge-core/src/flow/key.rs @@ -291,7 +291,59 @@ pub fn extract_key(packet: &Packet) -> Result<(CanonicalKey, FlowDirection), Flo .map_err(|e| FlowError::PacketError(e.into()))?; (sport, dport) }, - // ICMP and other protocols have no ports + TransportProtocol::Icmp => { + // For ICMP, use identifier (for echo/timestamp types) for both ports + // (symmetric), or type+code as port substitute for other types. + // Using identifier symmetrically ensures request and reply have + // the same canonical key regardless of direction. + if let Some(icmp_layer) = packet.get_layer(LayerKind::Icmp) { + if buf.len() >= icmp_layer.start + 8 { + let icmp_type = buf[icmp_layer.start]; + let is_echo = icmp_type == 0 || icmp_type == 8; + if is_echo { + let id = u16::from_be_bytes([ + buf[icmp_layer.start + 4], + buf[icmp_layer.start + 5], + ]); + (id, id) // Use identifier symmetrically for both ports + } else { + let code = buf[icmp_layer.start + 1]; + (icmp_type as u16, code as u16) + } + } else { + (0u16, 0u16) + } + } else { + (0u16, 0u16) + } + }, + TransportProtocol::Icmpv6 => { + // For ICMPv6, use identifier (for echo/timestamp types) for both ports + // (symmetric), or type+code as port substitute for other types. + // Using identifier symmetrically ensures request and reply have + // the same canonical key regardless of direction. + if let Some(icmpv6_layer) = packet.get_layer(LayerKind::Icmpv6) { + if buf.len() >= icmpv6_layer.start + 8 { + let icmpv6_type = buf[icmpv6_layer.start]; + let is_echo = icmpv6_type == 128 || icmpv6_type == 129; + if is_echo { + let id = u16::from_be_bytes([ + buf[icmpv6_layer.start + 4], + buf[icmpv6_layer.start + 5], + ]); + (id, id) // Use identifier symmetrically for both ports + } else { + let code = buf[icmpv6_layer.start + 1]; + (icmpv6_type as u16, code as u16) + } + } else { + (0u16, 0u16) + } + } else { + (0u16, 0u16) + } + }, + // Other protocols have no ports _ => (0u16, 0u16), }; diff --git a/crates/stackforge-core/src/flow/mod.rs b/crates/stackforge-core/src/flow/mod.rs index 642d79b..9892dcc 100644 --- a/crates/stackforge-core/src/flow/mod.rs +++ b/crates/stackforge-core/src/flow/mod.rs @@ -27,6 +27,7 @@ pub mod config; pub mod error; +pub mod icmp_state; pub mod key; pub mod state; pub mod table; @@ -37,6 +38,7 @@ pub mod udp_state; // Re-exports pub use config::FlowConfig; pub use error::FlowError; +pub use icmp_state::IcmpFlowState; pub use key::{ CanonicalKey, FlowDirection, TransportProtocol, ZWaveKey, extract_key, extract_zwave_key, }; diff --git a/crates/stackforge-core/src/flow/state.rs b/crates/stackforge-core/src/flow/state.rs index b57f08b..de8dffb 100644 --- a/crates/stackforge-core/src/flow/state.rs +++ b/crates/stackforge-core/src/flow/state.rs @@ -1,6 +1,7 @@ use std::time::Duration; use super::config::FlowConfig; +use super::icmp_state::IcmpFlowState; use super::key::{CanonicalKey, FlowDirection, TransportProtocol}; use super::tcp_state::TcpConversationState; use super::udp_state::UdpFlowState; @@ -75,9 +76,13 @@ pub enum ProtocolState { Tcp(TcpConversationState), /// UDP pseudo-conversation with timeout tracking. Udp(UdpFlowState), + /// ICMP conversation with echo request/reply tracking. + Icmp(IcmpFlowState), + /// ICMPv6 conversation with echo request/reply tracking. + Icmpv6(IcmpFlowState), /// Z-Wave wireless conversation with home ID and node tracking. ZWave(ZWaveFlowState), - /// Other protocols (ICMP, etc.) — no specific state tracking. + /// Other protocols — no specific state tracking. Other, } @@ -124,6 +129,8 @@ impl ConversationState { let protocol_state = match key.protocol { TransportProtocol::Tcp => ProtocolState::Tcp(TcpConversationState::new()), TransportProtocol::Udp => ProtocolState::Udp(UdpFlowState::new()), + TransportProtocol::Icmp => ProtocolState::Icmp(IcmpFlowState::new(0, 0)), + TransportProtocol::Icmpv6 => ProtocolState::Icmpv6(IcmpFlowState::new(0, 0)), _ => ProtocolState::Other, }; @@ -230,6 +237,12 @@ impl ConversationState { ProtocolState::Udp(udp) => { self.status = udp.status; }, + ProtocolState::Icmp(icmp) => { + self.status = icmp.status; + }, + ProtocolState::Icmpv6(icmpv6) => { + self.status = icmpv6.status; + }, ProtocolState::ZWave(_) => {}, ProtocolState::Other => {}, } @@ -250,6 +263,7 @@ impl ConversationState { } }, ProtocolState::Udp(_) => elapsed > config.udp_timeout, + ProtocolState::Icmp(_) | ProtocolState::Icmpv6(_) => elapsed > config.udp_timeout, ProtocolState::ZWave(_) => elapsed > config.udp_timeout, ProtocolState::Other => elapsed > config.udp_timeout, } diff --git a/crates/stackforge-core/src/flow/table.rs b/crates/stackforge-core/src/flow/table.rs index 7a4a0b8..8d595dc 100644 --- a/crates/stackforge-core/src/flow/table.rs +++ b/crates/stackforge-core/src/flow/table.rs @@ -85,6 +85,26 @@ impl ConversationTable { ProtocolState::Udp(udp_state) => { udp_state.process_packet(); }, + ProtocolState::Icmp(icmp_state) => { + // Get ICMP type and code from buffer + if let Some(icmp_layer) = packet.get_layer(crate::layer::LayerKind::Icmp) { + if buf.len() >= icmp_layer.start + 2 { + let icmp_type = buf[icmp_layer.start]; + let icmp_code = buf[icmp_layer.start + 1]; + icmp_state.process_packet(packet, buf, icmp_type, icmp_code); + } + } + }, + ProtocolState::Icmpv6(icmpv6_state) => { + // Get ICMPv6 type and code from buffer + if let Some(icmpv6_layer) = packet.get_layer(crate::layer::LayerKind::Icmpv6) { + if buf.len() >= icmpv6_layer.start + 2 { + let icmpv6_type = buf[icmpv6_layer.start]; + let icmpv6_code = buf[icmpv6_layer.start + 1]; + icmpv6_state.process_packet(packet, buf, icmpv6_type, icmpv6_code); + } + } + }, ProtocolState::ZWave(_) => {}, ProtocolState::Other => {}, } diff --git a/src/lib.rs b/src/lib.rs index 77948fc..0f6d12f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4179,6 +4179,66 @@ impl PyConversation { } } + /// ICMP type number, or None for non-ICMP flows. + #[getter] + fn icmp_type(&self) -> Option { + match &self.inner.protocol_state { + stackforge_core::ProtocolState::Icmp(icmp) + | stackforge_core::ProtocolState::Icmpv6(icmp) => Some(icmp.icmp_type), + _ => None, + } + } + + /// ICMP code, or None for non-ICMP flows. + #[getter] + fn icmp_code(&self) -> Option { + match &self.inner.protocol_state { + stackforge_core::ProtocolState::Icmp(icmp) + | stackforge_core::ProtocolState::Icmpv6(icmp) => Some(icmp.icmp_code), + _ => None, + } + } + + /// ICMP identifier (for echo sessions), or None for non-ICMP flows. + #[getter] + fn icmp_identifier(&self) -> Option { + match &self.inner.protocol_state { + stackforge_core::ProtocolState::Icmp(icmp) + | stackforge_core::ProtocolState::Icmpv6(icmp) => icmp.identifier, + _ => None, + } + } + + /// ICMP echo request count, or None for non-ICMP flows. + #[getter] + fn icmp_request_count(&self) -> Option { + match &self.inner.protocol_state { + stackforge_core::ProtocolState::Icmp(icmp) + | stackforge_core::ProtocolState::Icmpv6(icmp) => Some(icmp.request_count), + _ => None, + } + } + + /// ICMP echo reply count, or None for non-ICMP flows. + #[getter] + fn icmp_reply_count(&self) -> Option { + match &self.inner.protocol_state { + stackforge_core::ProtocolState::Icmp(icmp) + | stackforge_core::ProtocolState::Icmpv6(icmp) => Some(icmp.reply_count), + _ => None, + } + } + + /// ICMP last sequence number seen, or None for non-ICMP flows. + #[getter] + fn icmp_last_seq(&self) -> Option { + match &self.inner.protocol_state { + stackforge_core::ProtocolState::Icmp(icmp) + | stackforge_core::ProtocolState::Icmpv6(icmp) => icmp.last_seq, + _ => None, + } + } + /// Reassembled forward TCP stream data, or None. #[getter] fn reassembled_forward<'py>( diff --git a/tests/python/test_flow.py b/tests/python/test_flow.py index e023954..6c513c8 100644 --- a/tests/python/test_flow.py +++ b/tests/python/test_flow.py @@ -5,6 +5,7 @@ import pytest from stackforge import ( + ICMP, IP, TCP, UDP, @@ -368,3 +369,137 @@ def test_http_pcap_conversations(self): assert conv.dst_port >= 0 assert conv.forward_bytes + conv.reverse_bytes == conv.total_bytes assert conv.forward_packets + conv.reverse_packets == conv.total_packets + + +# ============================================================================ +# Test: ICMP flow extraction +# ============================================================================ + + +def _build_icmp_echo_packet(src_ip, dst_ip, icmp_id, seq, is_reply=False): + """Build an ICMP Echo Request or Reply packet.""" + if is_reply: + icmp_layer = ICMP.echo_reply(id=icmp_id, seq=seq) + else: + icmp_layer = ICMP.echo_request(id=icmp_id, seq=seq) + pkt = Ether() / IP(src=src_ip, dst=dst_ip) / icmp_layer + built = pkt.build() + built.parse() + return built + + +class TestICMPFlows: + def test_icmp_echo_pair_correlation(self): + """Test that ICMP echo request and reply are correlated as a single flow.""" + # Create echo request and reply + req = _build_icmp_echo_packet("192.168.1.1", "192.168.1.2", icmp_id=1234, seq=1) + reply = _build_icmp_echo_packet( + "192.168.1.2", "192.168.1.1", icmp_id=1234, seq=1, is_reply=True + ) + + conversations = extract_flows_from_packets([req, reply]) + + # Should have exactly 1 ICMP conversation (both directions same flow) + icmp_flows = [c for c in conversations if c.protocol == "ICMP"] + assert len(icmp_flows) == 1, f"Expected 1 ICMP flow, got {len(icmp_flows)}" + + conv = icmp_flows[0] + + # Verify protocol + assert conv.protocol == "ICMP" + + # Verify ICMP-specific fields + assert conv.icmp_type is not None + assert conv.icmp_code is not None + assert conv.icmp_identifier == 1234 + assert conv.icmp_request_count == 1 + assert conv.icmp_reply_count == 1 + assert conv.icmp_last_seq == 1 + + # Verify packet counts (1 forward, 1 reverse) + assert conv.total_packets == 2 + assert conv.forward_packets == 1 + assert conv.reverse_packets == 1 + + def test_icmp_multiple_sequences(self): + """Test ICMP flow with multiple sequence numbers.""" + packets = [] + for seq in range(1, 4): + req = _build_icmp_echo_packet("10.0.0.1", "10.0.0.2", icmp_id=5678, seq=seq) + reply = _build_icmp_echo_packet( + "10.0.0.2", "10.0.0.1", icmp_id=5678, seq=seq, is_reply=True + ) + packets.extend([req, reply]) + + conversations = extract_flows_from_packets(packets) + icmp_flows = [c for c in conversations if c.protocol == "ICMP"] + + assert len(icmp_flows) == 1 + conv = icmp_flows[0] + + assert conv.icmp_identifier == 5678 + assert conv.icmp_request_count == 3 + assert conv.icmp_reply_count == 3 + assert conv.icmp_last_seq == 3 + assert conv.total_packets == 6 + + def test_icmp_different_identifiers_separate_flows(self): + """Test that ICMP flows with different identifiers are tracked separately.""" + # Two separate echo sessions + packets = [] + for icmp_id in [1111, 2222]: + req = _build_icmp_echo_packet("10.0.0.1", "10.0.0.2", icmp_id=icmp_id, seq=1) + reply = _build_icmp_echo_packet( + "10.0.0.2", "10.0.0.1", icmp_id=icmp_id, seq=1, is_reply=True + ) + packets.extend([req, reply]) + + conversations = extract_flows_from_packets(packets) + icmp_flows = [c for c in conversations if c.protocol == "ICMP"] + + # Should have 2 separate conversations (different identifiers) + assert len(icmp_flows) == 2, f"Expected 2 ICMP flows, got {len(icmp_flows)}" + + # Verify each has its own identifier + identifiers = sorted([c.icmp_identifier for c in icmp_flows]) + assert identifiers == [1111, 2222] + + def test_icmp_non_echo_fields(self): + """Test that non-echo ICMP types are tracked.""" + # Create a dest unreachable (type 3, code 1 = host unreachable) + pkt = Ether() / IP(src="10.0.0.1", dst="10.0.0.2") / ICMP(type=3, code=1) + built = pkt.build() + built.parse() + + conversations = extract_flows_from_packets([built]) + icmp_flows = [c for c in conversations if c.protocol == "ICMP"] + + assert len(icmp_flows) == 1 + conv = icmp_flows[0] + + assert conv.icmp_type == 3 + assert conv.icmp_code == 1 + # For non-echo, request_count and reply_count should be 0 + assert conv.icmp_request_count == 0 + assert conv.icmp_reply_count == 0 + + def test_icmp_python_getters_none_for_non_icmp(self): + """Test that ICMP getters return None for non-ICMP conversations.""" + # Create a UDP packet + pkt = Ether() / IP(src="10.0.0.1", dst="10.0.0.2") / UDP(sport=53, dport=12345) + built = pkt.build() + built.parse() + + conversations = extract_flows_from_packets([built]) + udp_flows = [c for c in conversations if c.protocol == "UDP"] + + assert len(udp_flows) == 1 + conv = udp_flows[0] + + # All ICMP properties should be None for UDP + assert conv.icmp_type is None + assert conv.icmp_code is None + assert conv.icmp_identifier is None + assert conv.icmp_request_count is None + assert conv.icmp_reply_count is None + assert conv.icmp_last_seq is None