diff --git a/lib/pg_query/param_refs.rb b/lib/pg_query/param_refs.rb index ab393f5d..0a9bd74a 100644 --- a/lib/pg_query/param_refs.rb +++ b/lib/pg_query/param_refs.rb @@ -3,7 +3,7 @@ class ParserResult def param_refs # rubocop:disable Metrics/CyclomaticComplexity results = [] - treewalker! @tree do |_, _, node, location| + treewalker_with_location! @tree do |_, _, node, location| case node when PgQuery::ParamRef # Ignore param refs inside type casts, as these are already handled diff --git a/lib/pg_query/treewalker.rb b/lib/pg_query/treewalker.rb index e62b5858..c59a60a5 100644 --- a/lib/pg_query/treewalker.rb +++ b/lib/pg_query/treewalker.rb @@ -5,15 +5,17 @@ class ParserResult # If you pass a block with 1 argument, you will get each node. # If you pass a block with 4 arguments, you will get each parent_node, parent_field, node and location. # + # If sufficent for the use case, the 1 argument block approach is recommended, since its faster. + # # Location uniquely identifies a given node within the parse tree. This is a stable identifier across # multiple parser runs, assuming the same pg_query release and no modifications to the parse tree. def walk!(&block) if block.arity == 1 - treewalker!(@tree) do |_, _, node, _| + treewalker!(@tree) do |node| yield(node) end else - treewalker!(@tree) do |parent_node, parent_field, node, location| + treewalker_with_location!(@tree) do |parent_node, parent_field, node, location| yield(parent_node, parent_field, node, location) end end @@ -22,6 +24,34 @@ def walk!(&block) private def treewalker!(tree) # rubocop:disable Metrics/CyclomaticComplexity + nodes = [tree.dup] + + loop do + parent_node = nodes.shift + + case parent_node + when Google::Protobuf::MessageExts + parent_node.class.descriptor.each do |field_descriptor| + node = field_descriptor.get(parent_node) + next if node.nil? + yield(node) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField) + + nodes << node unless node.nil? + end + when Google::Protobuf::RepeatedField + parent_node.each do |node| + next if node.nil? + yield(node) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField) + + nodes << node unless node.nil? + end + end + + break if nodes.empty? + end + end + + def treewalker_with_location!(tree) # rubocop:disable Metrics/CyclomaticComplexity nodes = [[tree.dup, []]] loop do @@ -29,11 +59,12 @@ def treewalker!(tree) # rubocop:disable Metrics/CyclomaticComplexity case parent_node when Google::Protobuf::MessageExts - parent_node.to_h.keys.each do |parent_field| - node = parent_node[parent_field.to_s] + parent_node.class.descriptor.each do |field_descriptor| + parent_field = field_descriptor.name + node = parent_node[parent_field] next if node.nil? - location = parent_location + [parent_field] - yield(parent_node, parent_field, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField) + location = parent_location + [parent_field.to_sym] + yield(parent_node, parent_field.to_sym, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField) nodes << [node, location] unless node.nil? end @@ -52,7 +83,7 @@ def treewalker!(tree) # rubocop:disable Metrics/CyclomaticComplexity end def find_tree_location(tree, searched_location) - treewalker! tree do |parent_node, parent_field, node, location| + treewalker_with_location! tree do |parent_node, parent_field, node, location| next unless location == searched_location yield(parent_node, parent_field, node) end diff --git a/lib/pg_query/truncate.rb b/lib/pg_query/truncate.rb index 37cc1968..c7383ba8 100644 --- a/lib/pg_query/truncate.rb +++ b/lib/pg_query/truncate.rb @@ -60,7 +60,7 @@ def truncate(max_length) # rubocop:disable Metrics/CyclomaticComplexity def find_possible_truncations # rubocop:disable Metrics/CyclomaticComplexity truncations = [] - treewalker! @tree do |node, k, v, location| + treewalker_with_location! @tree do |node, k, v, location| case k when :target_list next unless node.is_a?(PgQuery::SelectStmt) || node.is_a?(PgQuery::UpdateStmt) || node.is_a?(PgQuery::OnConflictClause) diff --git a/spec/lib/treewalker_spec.rb b/spec/lib/treewalker_spec.rb index 3b7870c3..9784231c 100644 --- a/spec/lib/treewalker_spec.rb +++ b/spec/lib/treewalker_spec.rb @@ -1,6 +1,6 @@ require 'spec_helper' -describe PgQuery, '.treewalker' do +describe PgQuery, '#walk!' do it 'walks nodes contained in repeated fields' do locations = [] described_class.parse("SELECT to_timestamp($1)").walk! do |_, _, _, location|