diff --git a/lib/src/main/java/thriftlabs/thriftfmt/FormatterUtil.java b/lib/src/main/java/thriftlabs/thriftfmt/FormatterUtil.java index 4252101..94ea342 100644 --- a/lib/src/main/java/thriftlabs/thriftfmt/FormatterUtil.java +++ b/lib/src/main/java/thriftlabs/thriftfmt/FormatterUtil.java @@ -1,8 +1,11 @@ package thriftlabs.thriftfmt; import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.function.BiPredicate; +import java.util.function.Consumer; +import java.util.function.Function; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.tree.ParseTree; @@ -12,7 +15,7 @@ import thriftlabs.thriftparser.ThriftParser.FieldContext; public final class FormatterUtil { - private static final int FAKE_NODE_LINE_NO = -1; + public static final int FAKE_NODE_LINE_NO = -1; public static boolean isToken(ParseTree node, String text) { return node instanceof TerminalNode && ((TerminalNode) node).getSymbol().getText().equals(text); @@ -217,4 +220,28 @@ private static ParseTree[] getSubArray(ParseTree[] nodes, int startIndex) { System.arraycopy(nodes, startIndex, subArray, 0, nodes.length - startIndex); return subArray; } + + // 遍历节点 + public static void walkNode(ParseTree root, Consumer callback) { + LinkedList stack = new LinkedList<>(); + stack.add(root); + + while (!stack.isEmpty()) { + ParseTree node = stack.removeFirst(); // 移除并获取第一个元素 + if (node == null) { + break; + } + + callback.accept(root); // 调用回调函数 + List children = getNodeChildren(node); // 获取子节点 + for (ParseTree child : children) { + stack.add(child); // 添加子节点到栈中 + } + } + } + + public static boolean isFunctionOrThrowsListNode(ParseTree node) { + return node instanceof ThriftParser.Function_Context || + node instanceof ThriftParser.Throws_listContext; + } } diff --git a/lib/src/main/java/thriftlabs/thriftfmt/ThriftFormatter.java b/lib/src/main/java/thriftlabs/thriftfmt/ThriftFormatter.java index 5122460..59f3cfc 100644 --- a/lib/src/main/java/thriftlabs/thriftfmt/ThriftFormatter.java +++ b/lib/src/main/java/thriftlabs/thriftfmt/ThriftFormatter.java @@ -1,6 +1,79 @@ package thriftlabs.thriftfmt; +import java.util.HashMap; +import java.util.Map; + +import org.antlr.v4.runtime.CommonToken; +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.RuleContext; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.antlr.v4.runtime.tree.TerminalNodeImpl; + +import thriftlabs.thriftparser.Thrift; +import thriftlabs.thriftparser.ThriftParser; + public class ThriftFormatter extends PureThriftFormatter { - public ThriftFormatter() { + + private Thrift.ParserResult data; + private ThriftParser.DocumentContext document; + private int lastTokenIndex = -1; + private int fieldCommentPadding = 0; + private int fieldAlignAssignPadding = 0; + private Map fieldAlignPaddingMap; + + public ThriftFormatter(Thrift.ParserResult data) { + this.data = data; + this.document = data.document; + this.fieldAlignPaddingMap = new HashMap<>(); + } + + public String format() { + patch(); + return formatNode(document); + } + + private void patch() { + FormatterUtil.walkNode(this.document, node -> this.patchFieldRequired(node)); + } + + private void patchFieldRequired(ParseTree node) { + // 检查是否是 FieldContext 的实例 + if (!(node instanceof ThriftParser.FieldContext)) { + return; + } + ThriftParser.FieldContext field = (ThriftParser.FieldContext) node; + + // 检查父节点是否为 undefined 或者是 FunctionOrThrowsListNode 的实例 + if (field.getParent() == null || FormatterUtil.isFunctionOrThrowsListNode(field.getParent())) { + return; + } + + int i; + // 遍历子节点 + for (i = 0; i < field.getChildCount(); i++) { + ParseTree child = field.getChild(i); + if (child instanceof ThriftParser.Field_reqContext) { + return; // 如果已经有 Field_reqContext,返回 + } + if (child instanceof ThriftParser.Field_typeContext) { + break; // 找到 Field_typeContext,停止循环 + } + } + + // 创建伪节点 + CommonToken fakeToken = new CommonToken(ThriftParser.T__20, "required"); + fakeToken.setLine(FormatterUtil.FAKE_NODE_LINE_NO); + fakeToken.setCharPositionInLine(FormatterUtil.FAKE_NODE_LINE_NO); + fakeToken.setTokenIndex(-1); + TerminalNode fakeNode = new TerminalNodeImpl(fakeToken); + ThriftParser.Field_reqContext fakeReq = new ThriftParser.Field_reqContext(field, 0); + + fakeNode.setParent(fakeReq); + fakeReq.addChild(fakeNode); + fakeReq.setParent(field); + + // 在子节点的指定位置插入 fakeReq + field.children.add(i, fakeReq); } } diff --git a/lib/src/test/java/thriftlabs/thriftfmt/ThriftFormatterTest.java b/lib/src/test/java/thriftlabs/thriftfmt/ThriftFormatterTest.java index 9783b3a..d6499f2 100644 --- a/lib/src/test/java/thriftlabs/thriftfmt/ThriftFormatterTest.java +++ b/lib/src/test/java/thriftlabs/thriftfmt/ThriftFormatterTest.java @@ -6,7 +6,7 @@ public class ThriftFormatterTest { @Test public void someLibraryMethodReturnsTrue() { - var classUnderTest = new ThriftFormatter(); + // var classUnderTest = new ThriftFormatter(); // assertTrue("someLibraryMethod should return 'true'", // classUnderTest.someLibraryMethod()); }