diff --git a/src/asttools.rs b/src/asttools.rs
new file mode 100644
index 000000000..868a8bdf5
--- /dev/null
+++ b/src/asttools.rs
@@ -0,0 +1,117 @@
+//! AST traversal and analysis utilities.
+//!
+//! This module provides helper functions and macros for traversing
+//! and analyzing the AST tree structure.
+
+use crate::node::Node;
+
+/// Gets an ancestor at a specific level above the current node.
+///
+/// # Arguments
+/// * `node` - The starting node
+/// * `level` - How many levels up to traverse (0 returns the node itself)
+///
+/// # Returns
+/// The ancestor node at the specified level, or None if the tree isn't deep enough.
+///
+/// # Example
+/// ```ignore
+/// // Get the grandparent (2 levels up)
+/// if let Some(grandparent) = get_parent(&node, 2) {
+///     println!("Grandparent kind: {}", grandparent.kind());
+/// }
+/// ```
+pub fn get_parent<'a>(node: &Node<'a>, level: usize) -> Option<Node<'a>> {
+    let mut level = level;
+    let mut current = *node;
+    while level != 0 {
+        current = current.parent()?;
+        level -= 1;
+    }
+    Some(current)
+}
+
+/// Checks if a node has specific ancestors in sequence.
+///
+/// The comma-separated patterns after the first group are matched, in order,
+/// against successive ancestors starting at the immediate parent; the leading
+/// `|`-separated group is then matched against the next ancestor above them.
+///
+/// # Example
+/// ```ignore
+/// // True when the parent is a Function whose own parent is a Class or Struct
+/// let is_method = has_ancestors!(node, Class | Struct, Function);
+/// ```
+#[macro_export]
+macro_rules! has_ancestors {
+    ($node:expr, $( $typs:pat_param )|+, $( $typ:pat_param ),+) => {{
+        let mut res = false;
+        loop {
+            let mut node = *$node;
+            $(
+                if let Some(parent) = node.parent() {
+                    match parent.kind_id().into() {
+                        $typ => {
+                            node = parent;
+                        },
+                        _ => {
+                            break;
+                        }
+                    }
+                } else {
+                    break;
+                }
+            )*
+            if let Some(parent) = node.parent() {
+                match parent.kind_id().into() {
+                    $( $typs )|+ => {
+                        res = true;
+                    },
+                    _ => {}
+                }
+            }
+            break;
+        }
+        res
+    }};
+}
+
+/// Counts specific ancestors matching a pattern until a stop condition.
+///
+/// Walks up the tree counting ancestors that match the given patterns (skipping any
+/// for which `<$checker>::is_else_if` is true) until an ancestor matches a stop pattern.
+///
+/// # Example
+/// ```ignore
+/// // Count nested ifs up to a function boundary; `Checker` provides `is_else_if`
+/// let nesting = count_specific_ancestors!(node, Checker, If | ElseIf, Function | Method);
+/// ```
+#[macro_export]
+macro_rules! count_specific_ancestors {
+    ($node:expr, $checker:ty, $( $typs:pat_param )|*, $( $stops:pat_param )|*) => {{
+        let mut count = 0;
+        let mut node = *$node;
+        while let Some(parent) = node.parent() {
+            match parent.kind_id().into() {
+                $( $typs )|* => {
+                    if !<$checker>::is_else_if(&parent) {
+                        count += 1;
+                    }
+                },
+                $( $stops )|* => break,
+                _ => {}
+            }
+            node = parent;
+        }
+        count
+    }};
+}
+
+#[cfg(test)]
+mod tests {
+    #[test]
+    fn test_get_parent_level_zero() {
+        // Level 0 should return the same node
+        // (actual test would need a real node)
+    }
+}
diff --git a/src/getter.rs b/src/getter.rs
index c4cc4fc86..c30d0ae40 100644
--- a/src/getter.rs
+++ b/src/getter.rs
@@ -438,31 +438,30 @@ impl Getter for CppCode {
                     return std::str::from_utf8(code).ok();
                 }
                 // we're in a function_definition so need to get the declarator
-                if let Some(declarator) = node.child_by_field_name("declarator") {
-                    let declarator_node = declarator;
-                    if let Some(fd) = declarator_node.first_occurrence(|id| {
+                if let Some(declarator) = node.child_by_field_name("declarator")
+                    && let Some(fd) = declarator.first_occurrence(|id| {
                         Cpp::FunctionDeclarator == id
                             || Cpp::FunctionDeclarator2 == id
                             || Cpp::FunctionDeclarator3 == id
-                    }) && let Some(first) = fd.child(0)
-                    {
-                        match first.kind_id().into() {
-                            Cpp::TypeIdentifier
-                            | Cpp::Identifier
-                            | Cpp::FieldIdentifier
-                            | Cpp::DestructorName
-                            | Cpp::OperatorName
-                            | Cpp::QualifiedIdentifier
-                            | Cpp::QualifiedIdentifier2
-                            | Cpp::QualifiedIdentifier3
-                            | Cpp::QualifiedIdentifier4
-                            | Cpp::TemplateFunction
-                            | Cpp::TemplateMethod => {
-                                let code = &code[first.start_byte()..first.end_byte()];
-                                return std::str::from_utf8(code).ok();
-                            }
-                            _ => {}
+                    })
+                    && let Some(first) = fd.child(0)
+                {
+                    match first.kind_id().into() {
+                        Cpp::TypeIdentifier
+                        | Cpp::Identifier
+                        | Cpp::FieldIdentifier
+                        | Cpp::DestructorName
+                        | Cpp::OperatorName
+                        | Cpp::QualifiedIdentifier
+                        | Cpp::QualifiedIdentifier2
+                        | Cpp::QualifiedIdentifier3
+                        | Cpp::QualifiedIdentifier4
+                        | Cpp::TemplateFunction
+                        | Cpp::TemplateMethod => {
+                            let code = &code[first.start_byte()..first.end_byte()];
+                            return std::str::from_utf8(code).ok();
                         }
+                        _ => {}
                     }
                 }
             }
diff --git a/src/lib.rs b/src/lib.rs
index 923210d28..cdd7ae037 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -84,6 +84,10 @@ pub use crate::function::*;
 mod ast;
 pub use crate::ast::*;
 
+/// AST traversal and analysis utilities.
+pub mod asttools;
+pub use crate::asttools::get_parent;
+
 mod count;
 pub use crate::count::*;
 
diff --git a/src/node.rs b/src/node.rs
index 49b8df104..97777e47c 100644
--- a/src/node.rs
+++ b/src/node.rs
@@ -37,49 +37,60 @@ impl<'a> Node<'a> {
         self.0.has_error()
     }
 
-    pub(crate) fn id(&self) -> usize {
+    /// Returns a numeric id for this node that is unique within its tree.
+    pub fn id(&self) -> usize {
         self.0.id()
     }
 
-    pub(crate) fn kind(&self) -> &'static str {
+    /// Returns the node's type as a string.
+    pub fn kind(&self) -> &'static str {
         self.0.kind()
     }
 
-    pub(crate) fn kind_id(&self) -> u16 {
+    /// Returns the node's type as a numeric id.
+    pub fn kind_id(&self) -> u16 {
         self.0.kind_id()
     }
 
-    pub(crate) fn utf8_text(&self, data: &'a [u8]) -> Option<&'a str> {
+    /// Returns the node's text as a UTF-8 string, if valid.
+    pub fn utf8_text(&self, data: &'a [u8]) -> Option<&'a str> {
         self.0.utf8_text(data).ok()
     }
 
-    pub(crate) fn start_byte(&self) -> usize {
+    /// Returns the byte offset where this node starts.
+    pub fn start_byte(&self) -> usize {
         self.0.start_byte()
     }
 
-    pub(crate) fn end_byte(&self) -> usize {
+    /// Returns the byte offset where this node ends.
+    pub fn end_byte(&self) -> usize {
         self.0.end_byte()
     }
 
-    pub(crate) fn start_position(&self) -> (usize, usize) {
+    /// Returns the (row, column) position where this node starts.
+    pub fn start_position(&self) -> (usize, usize) {
         let temp = self.0.start_position();
         (temp.row, temp.column)
     }
 
-    pub(crate) fn end_position(&self) -> (usize, usize) {
+    /// Returns the (row, column) position where this node ends.
+    pub fn end_position(&self) -> (usize, usize) {
         let temp = self.0.end_position();
         (temp.row, temp.column)
     }
 
-    pub(crate) fn start_row(&self) -> usize {
+    /// Returns the row number where this node starts.
+    pub fn start_row(&self) -> usize {
         self.0.start_position().row
     }
 
-    pub(crate) fn end_row(&self) -> usize {
+    /// Returns the row number where this node ends.
+    pub fn end_row(&self) -> usize {
         self.0.end_position().row
     }
 
-    pub(crate) fn parent(&self) -> Option<Node<'a>> {
+    /// Returns this node's parent, if any.
+    pub fn parent(&self) -> Option<Node<'a>> {
         self.0.parent().map(Node)
     }
 
@@ -183,6 +194,21 @@ impl<'a> Node<'a> {
         }
         res
     }
+
+    /// Checks if this node has any ancestor that meets the given predicate.
+    ///
+    /// Traverses up the tree from this node's parent to the root,
+    /// returning true if any ancestor satisfies the predicate.
+    pub fn has_ancestor<F: Fn(&Node<'a>) -> bool>(&self, pred: F) -> bool {
+        let mut node = *self;
+        while let Some(parent) = node.parent() {
+            if pred(&parent) {
+                return true;
+            }
+            node = parent;
+        }
+        false
+    }
 }
 
 /// An `AST` cursor.
@@ -236,6 +262,35 @@ impl<'a> Search<'a> for Node<'a> {
         None
     }
 
+    fn all_occurrences(&self, pred: fn(u16) -> bool) -> Vec<Node<'a>> {
+        let mut cursor = self.cursor();
+        let mut stack = Vec::new();
+        let mut children = Vec::new();
+        let mut results = Vec::new();
+
+        stack.push(*self);
+
+        while let Some(node) = stack.pop() {
+            if pred(node.kind_id()) {
+                results.push(node);
+            }
+            cursor.reset(&node);
+            if cursor.goto_first_child() {
+                loop {
+                    children.push(cursor.node());
+                    if !cursor.goto_next_sibling() {
+                        break;
+                    }
+                }
+                for child in children.drain(..).rev() {
+                    stack.push(child);
+                }
+            }
+        }
+
+        results
+    }
+
     fn act_on_node(&self, action: &mut dyn FnMut(&Node<'a>)) {
         let mut cursor = self.cursor();
         let mut stack = Vec::new();
diff --git a/src/traits.rs b/src/traits.rs
index 16d4ed9cb..bf115fdb1 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -66,9 +66,20 @@ pub trait ParserTrait {
     fn get_filters(&self, filters: &[String]) -> Filter;
 }
 
-pub(crate) trait Search<'a> {
+/// Search trait for AST node traversal.
+pub trait Search<'a> {
+    /// Starting from this node, gets the first occurrence that meets the predicate.
     fn first_occurrence(&self, pred: fn(u16) -> bool) -> Option<Node<'a>>;
+
+    /// Starting from this node, gets all nodes that meet the given predicate.
+    fn all_occurrences(&self, pred: fn(u16) -> bool) -> Vec<Node<'a>>;
+
+    /// Apply the given predicate on this node and all descendants.
     fn act_on_node(&self, pred: &mut dyn FnMut(&Node<'a>));
+
+    /// Starting from this node, gets the first child that meets the predicate.
     fn first_child(&self, pred: fn(u16) -> bool) -> Option<Node<'a>>;
+
+    /// Apply the given action on node's immediate children.
     fn act_on_child(&self, action: &mut dyn FnMut(&Node<'a>));
 }