diff --git a/src/models/concrete_syntax.rs b/src/models/concrete_syntax.rs index d85e937a7..81323e32c 100644 --- a/src/models/concrete_syntax.rs +++ b/src/models/concrete_syntax.rs @@ -25,6 +25,7 @@ use crate::models::matches::Match; // Precompile the regex outside the function lazy_static! { static ref RE_VAR: Regex = Regex::new(r"^:\[(?P\w+)\]").unwrap(); + static ref RE_VAR_PLUS: Regex = Regex::new(r"^:\[(?P\w+)\+\]").unwrap(); } // Struct to avoid dealing with lifetimes @@ -34,15 +35,21 @@ pub struct CapturedNode { text: String, } +#[derive(Clone, PartialEq, Eq)] +struct MatchResult { + mapping: HashMap, + range: Range, +} + pub(crate) fn get_all_matches_for_concrete_syntax( - node: &Node, code_str: &[u8], meta: &ConcreteSyntax, recursive: bool, - replace_node: Option, + node: &Node, code_str: &[u8], cs: &ConcreteSyntax, recursive: bool, replace_node: Option, ) -> (Vec, bool) { let mut matches: Vec = Vec::new(); - if let (mut match_map, true) = get_matches_for_node(&mut node.walk(), code_str, meta) { + if let Some(match_result) = match_sequential_siblings(&mut node.walk(), code_str, cs) { let replace_node_key = replace_node.clone().unwrap_or("*".to_string()); - + let mut match_map = match_result.mapping; + let range = match_result.range; let replace_node_match = if replace_node_key != "*" { match_map .get(&replace_node_key) @@ -52,8 +59,8 @@ pub(crate) fn get_all_matches_for_concrete_syntax( }) } else { CapturedNode { - range: Range::from(node.range()), - text: node.utf8_text(code_str).unwrap().to_string(), + range, + text: get_code_from_range(range.start_byte, range.end_byte, code_str), } }; @@ -71,7 +78,7 @@ pub(crate) fn get_all_matches_for_concrete_syntax( let mut cursor = node.walk(); for child in node.children(&mut cursor) { if let (mut inner_matches, true) = - get_all_matches_for_concrete_syntax(&child, code_str, meta, recursive, replace_node.clone()) + get_all_matches_for_concrete_syntax(&child, code_str, cs, recursive, replace_node.clone()) { 
matches.append(&mut inner_matches); } @@ -82,134 +89,277 @@ pub(crate) fn get_all_matches_for_concrete_syntax( (matches, !is_empty) } -/// `find_next_sibling` navigates the cursor through the tree to the next sibling. -/// If no sibling exists, it ascends the tree until it can move laterally or until it reaches the root. +fn find_next_sibling_or_ancestor_sibling(cursor: &mut TreeCursor) -> bool { + while !cursor.goto_next_sibling() { + if !cursor.goto_parent() { + return false; + } + } + true +} + +/// Attempts to match a given ConcreteSyntax template against a sequence of sibling nodes +/// in an Abstract Syntax Tree (AST). /// /// # Arguments /// -/// * `cursor` - A mutable reference to a `TreeCursor` used to navigate the tree. +/// * `cursor` - A mutable reference to the TreeCursor, which is used to navigate the AST. +/// * `source_code` - A slice of bytes representing the source code being analyzed. +/// * `cs` - A reference to the ConcreteSyntax template used for matching. /// -/// The function mutates the cursor position. If no further siblings exist, the cursor ends at the root. -fn find_next_sibling(cursor: &mut TreeCursor) { - while !cursor.goto_next_sibling() { - if !cursor.goto_parent() { - break; +/// # Returns +/// +/// An Option wrapping a MatchResult: +/// * `mapping` - A HashMap where keys are variable names from the template and values are CapturedNode instances representing matched AST nodes. +/// * `range` - The range spanning the matched sibling nodes, from the start of the first one to the end of the last one. +/// * `None` is returned instead when no match was found. +/// +/// # Algorithm +/// +/// 1. Initialize cursor to the first child and iterate through siblings. +/// 2. Use `get_matches_for_subsequence_of_nodes` to attempt matching the template against a sequence of subtrees starting at each sibling. +/// 3. If a match is found, determine the range of the matched subtrees (e.g., [2nd, ..., 4th]), and return the match mapping and range. +/// 4. If no match is found, return None. 
+fn match_sequential_siblings( + cursor: &mut TreeCursor, source_code: &[u8], cs: &ConcreteSyntax, +) -> Option { + let parent_node = cursor.node(); + let mut child_seq_match_start = 0; + + if cursor.goto_first_child() { + // Iterate through siblings to find a match + loop { + // Clone the cursor in order to attempt matching the sequence starting at cursor.node + // Cloning here is necessary, otherwise we won't be able to advance to the next sibling if the matching fails + let mut tmp_cursor = cursor.clone(); + let (mapping, indx) = + get_matches_for_subsequence_of_nodes(&mut tmp_cursor, source_code, cs, true, &parent_node); + + // If we got the index of the last matched sibling, that means the matching was successful. + if let Some(last_node_index) = indx { + // Determine the last matched node. Remember, we are matching subsequences of children [n ... k] + let last_node = parent_node.child(last_node_index); + let range = Range::span_ranges(cursor.node().range(), last_node.unwrap().range()); + if last_node_index != child_seq_match_start || parent_node.child_count() == 1 { + return Some(MatchResult { mapping, range }); + } + return None; + } + + child_seq_match_start += 1; + if !cursor.goto_next_sibling() { + break; + } } - } + } // Not currently handling matching of leaf nodes. Current semantics would never match it anyway. + None } /// This function performs the actual matching of the ConcreteSyntax pattern against a syntax tree /// node. The matching is done in the following way: /// +/// /// - If the ConcreteSyntax is empty and all the nodes have been visited, then we found a match! +/// Otherwise, if we ran out of nodes to match, and the template is not empty, then we failed /// /// - If the ConcreteSyntax starts with `:[variable]`, the function tries to match the variable -/// against all possible AST nodes starting at the current's cursor position (i.e., the node itself, -/// its first child, the child of the first child, and so on.) 
-/// If it succeeds, it advances the ConcreteSyntax by the length of the matched -/// AST node and calls itself recursively to try to match the rest of the ConcreteSyntax. +/// against all possible AST nodes starting at the current's cursor position (i.e., the node itself +/// and all of its siblings, its first child and respective siblings, the child of the first child, and so on.) +/// If it succeeds, it advances the ConcreteSyntax by the length of the matched sequence of +/// AST nodes, and calls itself recursively to try to match the rest of the ConcreteSyntax. /// -/// - If the ConcreteSyntax doesn't start with `:[variable]`, the function checks if the node is a leaf -/// (i.e., has no children). If it is, and its text starts with the metasyyntax, we match the text, -/// and advance to the next immediate node (i.e., it's sibling or it's parent's sibling). If does not -/// match we cannot match the meta syntax template. +/// - If the ConcreteSyntax doesn't start with `:[variable]`, the function checks if the node **is a leaf** +/// (i.e., has no children). If it is, and the leaf node matches the concrete syntax, we match it against +/// the concrete syntax and advance to the next immediate node. If the leaf does not match the concrete syntax, +/// then our matching has failed. /// -/// - If the ConcreteSyntax doesn't start with `:[variable]` and the node is not a leaf, the function +/// - If the ConcreteSyntax doesn't start with `:[variable]` and the node **is not a leaf**, the function /// moves the cursor to the first child of the node and calls itself recursively to try to match /// the ConcreteSyntax. 
-pub(crate) fn get_matches_for_node( - cursor: &mut TreeCursor, source_code: &[u8], meta: &ConcreteSyntax, -) -> (HashMap, bool) { - let match_template = meta.0.as_str(); +pub(crate) fn get_matches_for_subsequence_of_nodes( + cursor: &mut TreeCursor, source_code: &[u8], cs: &ConcreteSyntax, nodes_left_to_match: bool, + top_node: &Node, +) -> (HashMap, Option) { + let match_template = cs.0.as_str(); if match_template.is_empty() { - return ( - HashMap::new(), - !cursor.goto_next_sibling() && !cursor.goto_parent(), - ); + if !nodes_left_to_match { + return (HashMap::new(), Some(top_node.child_count() - 1)); + } + let index = find_last_matched_node(cursor, top_node); + return (HashMap::new(), index); + } else if !nodes_left_to_match { + return (HashMap::new(), None); } let mut node = cursor.node(); - // Skip comment nodes always while node.kind().contains("comment") && cursor.goto_next_sibling() { node = cursor.node(); } - // In case the template starts with :[var_name], we try match - if let Some(caps) = RE_VAR.captures(match_template) { - let var_name = &caps["var_name"]; - let meta_adv_len = caps[0].len(); - let meta_advanced = ConcreteSyntax( - match_template[meta_adv_len..] 
- .to_string() - .trim_start() - .to_string(), - ); + if let Some(caps) = RE_VAR_PLUS.captures(match_template) { + // If template starts with a one-or-more template variable `:[var+]` + handle_template_variable_matching(cursor, source_code, top_node, caps, match_template, true) + } else if let Some(caps) = RE_VAR.captures(match_template) { + // If template starts with a single-node template variable `:[var]` + handle_template_variable_matching(cursor, source_code, top_node, caps, match_template, false) + } else if node.child_count() == 0 { + // If the current node is a leaf + return handle_leaf_node(cursor, source_code, match_template, top_node); + } else { + // If the current node is an intermediate node + cursor.goto_first_child(); + return get_matches_for_subsequence_of_nodes(cursor, source_code, cs, true, top_node); + } +} + +/// This function does the template variable matching against entire tree nodes. +/// Keep in mind that it will only attempt to match the template variables against nodes +/// at either the current level of the traversal, or its children. It can also operate on +/// single node templates :[args], and multiple nodes templates :[args+]. +/// +/// For successful matches, it returns the assignment of each template variable against a +/// particular range. The Option indicates whether a match was successful, and keeps +/// track of the last sibling node that was matched (with respect to the match_sequential_siblings function) +fn handle_template_variable_matching( + cursor: &mut TreeCursor, source_code: &[u8], top_node: &Node, caps: regex::Captures, + match_template: &str, one_plus: bool, +) -> (HashMap, Option) { + let var_name = &caps["var_name"]; + let cs_adv_len = caps[0].len(); + let cs_advanced = ConcreteSyntax( + match_template[cs_adv_len..] + .to_string() + .trim_start() + .to_string(), + ); - // If we need to match a variable `:[var]`, we can match it against the next node or any of it's - // first children. We need to try all possibilities. 
+ // Matching :[var] against a sequence of nodes [first_node, ... last_node] + loop { + let first_node = cursor.node(); + let mut last_node = first_node; + + // Determine whether a next node exists: + let mut next_node_cursor = cursor.clone(); + let mut should_match = find_next_sibling_or_ancestor_sibling(&mut next_node_cursor); + // At this point next_node_cursor either points to the first sibling of the first node, + // or the first node itself, if such sibling no longer exists + + // Intentionally setting is_final_sibling to false regardless of should_match, due to the logic of handling the last iteration + let mut is_final_sibling = false; loop { - let mut tmp_cursor = cursor.clone(); - let current_node = cursor.node(); - let current_node_code = current_node.utf8_text(source_code).unwrap(); - find_next_sibling(&mut tmp_cursor); - - // Support for trailing commas - // This skips trailing commas as we are parsing through the match template - // Skips the comma node if the template doesn't contain it. - let next_node = tmp_cursor.node(); - let next_node_text = next_node.utf8_text(source_code).unwrap(); - if next_node_text == "," && !meta_advanced.0.starts_with(',') { - find_next_sibling(&mut tmp_cursor); // Skip comma - } + let mut tmp_cursor = next_node_cursor.clone(); - if let (mut recursive_matches, true) = - get_matches_for_node(&mut tmp_cursor, source_code, &meta_advanced) + if let (mut recursive_matches, Some(last_matched_node_idx)) = + get_matches_for_subsequence_of_nodes( + &mut tmp_cursor, + source_code, + &cs_advanced, + should_match, + top_node, + ) { - // If we already matched this variable, we need to make sure that the match is the same. Otherwise, we were unsuccessful. - // No other way of unrolling exists. 
+ // Continuous code range that :[var] is matching from [first, ..., last] + let matched_code = get_code_from_range( + first_node.range().start_byte, + last_node.range().end_byte, + source_code, + ); + + // Check if :[var] was already matched against some code range + // If it did, and it is not the same, we return unsuccessful if recursive_matches.contains_key(var_name) - && recursive_matches[var_name].text.trim() != current_node_code.trim() + && recursive_matches[var_name].text.trim() != matched_code.trim() { - return (HashMap::new(), false); + return (HashMap::new(), None); } + + // Otherwise insert it recursive_matches.insert( var_name.to_string(), CapturedNode { - range: Range::from(current_node.range()), - text: current_node_code.to_string(), + range: Range::span_ranges(first_node.range(), last_node.range()), + text: matched_code, }, ); - return (recursive_matches, true); + return (recursive_matches, Some(last_matched_node_idx)); } - if !cursor.goto_first_child() { + // Append an extra node to match with :[var]. Remember we had advanced next_node_cursor before, + // therefore we cannot advance it again, otherwise we would skip nodes. + // We only attempt to append an extra code if we are in one_plus matching mode. + last_node = next_node_cursor.node(); + if is_final_sibling { break; } - } - } else if node.child_count() == 0 { - let code = node.utf8_text(source_code).unwrap().trim(); - if match_template.starts_with(code) && !code.is_empty() { - let advance_by = code.len(); - // Can only advance if there is still enough chars to consume - if advance_by > match_template.len() { - return (HashMap::new(), false); + + // This is used for the final iteration. We need to determine if there are any other nodes + // left to match, to inform our next recursive call. 
We do this by calling find_next_sibling_or_ancestor_sibling + // to move the cursor to the parent and find the next sibling at another level, + // since at this level we already matched everything + is_final_sibling = !next_node_cursor.goto_next_sibling(); + if is_final_sibling { + should_match = find_next_sibling_or_ancestor_sibling(&mut next_node_cursor); + } + + if !one_plus { + break; } - let meta_substring = ConcreteSyntax( - match_template[advance_by..] - .to_string() - .trim_start() - .to_owned(), - ); - find_next_sibling(cursor); - return get_matches_for_node(cursor, source_code, &meta_substring); } - } else { - cursor.goto_first_child(); - return get_matches_for_node(cursor, source_code, meta); + + // Move one level down, to attempt to match the template variable :[var] against smaller nodes. + if !cursor.goto_first_child() { + break; + } + } + (HashMap::new(), None) +} + +fn handle_leaf_node( + cursor: &mut TreeCursor, source_code: &[u8], match_template: &str, top_node: &Node, +) -> (HashMap, Option) { + let code = cursor.node().utf8_text(source_code).unwrap().trim(); + if match_template.starts_with(code) && !code.is_empty() { + let advance_by = code.len(); + // Can only advance if there is still enough chars to consume + if advance_by > match_template.len() { + return (HashMap::new(), None); + } + let cs_substring = ConcreteSyntax( + match_template[advance_by..] + .to_string() + .trim_start() + .to_owned(), + ); + let should_match = find_next_sibling_or_ancestor_sibling(cursor); + return get_matches_for_subsequence_of_nodes( + cursor, + source_code, + &cs_substring, + should_match, + top_node, + ); } - (HashMap::new(), false) + (HashMap::new(), None) +} + +/// Finds the index of the last matched node relative to the `match_sequential_siblings` function. +/// +/// This function checks if the matching concluded on a child of the node where `match_sequential_siblings` +/// was invoked. If so, it returns the index of that child. 
+fn find_last_matched_node(cursor: &mut TreeCursor, parent_node: &Node) -> Option { + parent_node + .children(&mut parent_node.walk()) + .enumerate() + .filter(|&(_i, child)| child == cursor.node()) + .map(|(i, _child)| i - 1) + .next() +} + +fn get_code_from_range(start_byte: usize, end_byte: usize, source_code: &[u8]) -> String { + let text_slice = &source_code[start_byte..end_byte]; + String::from_utf8_lossy(text_slice).to_string() } #[cfg(test)] diff --git a/src/models/matches.rs b/src/models/matches.rs index 60e96878b..47fbdeb96 100644 --- a/src/models/matches.rs +++ b/src/models/matches.rs @@ -321,6 +321,21 @@ impl Range { end_point: position_for_offset(source_code.as_bytes(), mtch.end()), } } + + /// Creates a new range that spans from the beginning of the start range + /// to the end of the end range. + /// + /// This function is useful for creating a range that covers the span + /// from the start of one range to the end of another, regardless of whether + /// the ranges are contiguous. + pub(crate) fn span_ranges(left: tree_sitter::Range, right: tree_sitter::Range) -> Self { + Self { + start_byte: left.start_byte, + end_byte: right.end_byte, + start_point: left.start_point.into(), + end_point: right.end_point.into(), + } + } } // Finds the position (col and row number) for a given offset. 
diff --git a/src/models/unit_tests/concrete_syntax_test.rs b/src/models/unit_tests/concrete_syntax_test.rs index dd7e39af5..2d607f5fc 100644 --- a/src/models/unit_tests/concrete_syntax_test.rs +++ b/src/models/unit_tests/concrete_syntax_test.rs @@ -25,13 +25,8 @@ fn run_test( let tree = parser.parse(code.as_bytes(), None).unwrap(); let meta = ConcreteSyntax(String::from(pattern)); - let (matches, _is_match_found) = get_all_matches_for_concrete_syntax( - &tree.root_node().child(0).unwrap(), - code.as_bytes(), - &meta, - true, - None, - ); + let (matches, _is_match_found) = + get_all_matches_for_concrete_syntax(&tree.root_node(), code.as_bytes(), &meta, true, None); assert_eq!(matches.len(), expected_matches); @@ -84,11 +79,88 @@ fn test_no_match() { fn test_trailing_comma() { run_test( "a.foo(x, // something about the first argument - y, // something about the second argumet + y, // something about the second argument );", - ":[var].foo(:[arg1], :[arg2])", + ":[var].foo(:[arg1], :[arg2+])", 2, - vec![vec![("var", "a"), ("arg1", "x"), ("arg2", "y")]], + vec![vec![("var", "a"), ("arg1", "x"), ("arg2", "y,")]], GO, ); } + +#[test] +fn test_sequential_siblings_matching() { + run_test( + "a.foo(x, y, z);", + ":[var].foo(:[arg1+], z)", + 2, + vec![vec![("var", "a"), ("arg1", "x, y")]], + GO, + ); +} + +#[test] +fn test_sequential_siblings_stmts() { + // Find all usages of foo, whose last element is z. + run_test( + "{ int x = 2; x = x + 1; while(x > 0) { x = x - 1} } ", + "int :[stmt1] = 2; \ + :[stmt2] = :[stmt2] + 1;", + 1, + vec![vec![("stmt1", "x"), ("stmt2", "x")]], + JAVA, + ); +} + +#[test] +fn test_sequential_siblings_stmts2() { + // Find all usages of foo, whose last element is z. 
+ run_test( + "x.foo(1,2,3,4);", + ":[var].foo(:[args+]);", + 2, + vec![vec![("var", "x"), ("args", "1,2,3,4")]], + JAVA, + ); +} + +#[test] +fn test_complex_template() { + // Test matching the given code against the template + run_test( + "void main() { + // Some comment + int some = 0; + while(some < 100) { + float length = 3.14; + float area = length * length; + some++; + }}", + "int :[var] = 0; + while(:[var] < 100) { + :[body+] + :[var] ++; + }", + 1, + vec![vec![ + ("var", "some"), + ( + "body", + "float length = 3.14;\n float area = length * length;", + ), + ]], + JAVA, + ); +} + +#[test] +fn test_match_anything() { + // Test matching the given code against the template + run_test( + "public static void main(String args) { }", + ":[x]", + 1, + vec![vec![("x", "public static void main(String args) { }")]], + JAVA, + ); +} diff --git a/src/tests/test_piranha_java.rs b/src/tests/test_piranha_java.rs index bd299597c..e27714b9f 100644 --- a/src/tests/test_piranha_java.rs +++ b/src/tests/test_piranha_java.rs @@ -426,8 +426,8 @@ fn test_dyn_rule() { let rule = piranha_rule! { name = "match_class", query = "cs println(:[xs])", - replace_node = "xs", - replace = "@xs, 2" + replace_node = "*", + replace = "println2(@xs, 2)" }; let piranha_arguments = PiranhaArgumentsBuilder::default()