diff --git a/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/AbstractThriftMessageClassFinder.java b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/AbstractThriftMessageClassFinder.java new file mode 100644 index 00000000000..0499fd279be --- /dev/null +++ b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/AbstractThriftMessageClassFinder.java @@ -0,0 +1,110 @@ +/* + * Copyright 2019 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ================================================================================================= +package com.linecorp.armeria.common.thrift.text; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; +import java.util.function.Supplier; + +import javax.annotation.Nullable; + +import org.apache.thrift.TApplicationException; +import org.apache.thrift.TBase; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.linecorp.armeria.internal.thrift.TApplicationExceptions; + +abstract class AbstractThriftMessageClassFinder implements Supplier> { + private static final Logger logger = LoggerFactory.getLogger(AbstractThriftMessageClassFinder.class); + + @Nullable + static Class getClassByName(String className) { + try { + return Class.forName(className); + } catch (ClassNotFoundException ex) { + logger.warn("Can't find a class: {}", className, ex); + } + return null; + } + + @Nullable + static Class getMatchedClass(@Nullable Class clazz) { + if (clazz == null) { + return null; + } + // Note, we need to check + // if the class is abstract, because abstract class does not have metaDataMap + // if the class has no-arg constructor, because FieldMetaData.getStructMetaDataMap + // calls clazz.newInstance + if (isTBase(clazz) && !isAbstract(clazz) && hasNoArgConstructor(clazz)) { + return clazz; + } + + if (isTApplicationException(clazz)) { + return clazz; + } + + if (isTApplicationExceptions(clazz)) { + return TApplicationException.class; + } + + return null; + } + + static boolean isTBase(Class clazz) { + return TBase.class.isAssignableFrom(clazz); + } + + private static boolean isTApplicationExceptions(Class clazz) { + return clazz == TApplicationExceptions.class; + } + + private static boolean isTApplicationException(Class clazz) { + return TApplicationException.class.isAssignableFrom(clazz); + } + + private static boolean isAbstract(Class clazz) { + return Modifier.isAbstract(clazz.getModifiers()); + } + + private static boolean hasNoArgConstructor(Class clazz) { + final Constructor[] allConstructors = clazz.getConstructors(); + for (Constructor ctor : allConstructors) { + final Class[] pType = ctor.getParameterTypes(); + if (pType.length == 0) { + return true; + } + } + + return false; + } +} diff --git a/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/DefaultThriftMessageClassFinder.java b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/DefaultThriftMessageClassFinder.java new file mode 100644 index 00000000000..9e663d02e05 --- /dev/null +++ b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/DefaultThriftMessageClassFinder.java @@ -0,0 +1,55 @@ +/* + * Copyright 2019 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ================================================================================================= +package com.linecorp.armeria.common.thrift.text; + +import javax.annotation.Nullable; + +final class DefaultThriftMessageClassFinder extends AbstractThriftMessageClassFinder { + + @Nullable + @Override + public Class get() { + final StackTraceElement[] frames = + Thread.currentThread().getStackTrace(); + + for (StackTraceElement f : frames) { + final String className = f.getClassName(); + final Class matchedClazz = getMatchedClass(getClassByName(className)); + + if (matchedClazz != null) { + return matchedClazz; + } + } + + return null; + } +} + diff --git a/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StackWalkingThriftMessageClassFinder.java b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StackWalkingThriftMessageClassFinder.java new file mode 100644 index 00000000000..d79db546b9d --- /dev/null +++ b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StackWalkingThriftMessageClassFinder.java @@ -0,0 +1,142 @@ +/* + * Copyright 2019 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// ================================================================================================= +// Copyright 2011 Twitter, Inc. +// ------------------------------------------------------------------------------------------------- +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this work except in compliance with the License. +// You may obtain a copy of the License in the LICENSE file, or at: +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ================================================================================================= +package com.linecorp.armeria.common.thrift.text; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodHandles.Lookup; +import java.lang.invoke.MethodType; +import java.util.Arrays; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Stream; + +import javax.annotation.Nullable; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class StackWalkingThriftMessageClassFinder extends AbstractThriftMessageClassFinder { + + private static final Logger logger = + LoggerFactory.getLogger(StackWalkingThriftMessageClassFinder.class); + + private static final String INVOKING_FAIL_MSG = + "Failed to invoke StackWalker.StackFrame.getDeclaringClass():"; + + @Nullable + private final Function, Class> walkHandler; + private final MethodHandle walkMH; + private final Object stackWalker; + + StackWalkingThriftMessageClassFinder() throws Throwable { + final ClassLoader classLoader = ClassLoader.getSystemClassLoader(); + final Lookup lookup = MethodHandles.lookup(); + + final Class stackWalkerClass = classLoader.loadClass("java.lang.StackWalker"); + walkMH = lookup.findVirtual(stackWalkerClass, + "walk", + MethodType.methodType(Object.class, Function.class)); + + final Class stackFrameClass = classLoader.loadClass("java.lang.StackWalker$StackFrame"); + Function> getClassByStackFrameTemp; + Object instance; + + try { + // StackWalker instantiate with RETAIN_CLASS_REFERENCE option + final Class Option = classLoader.loadClass("java.lang.StackWalker$Option"); + final MethodHandle getInstanceMH = + lookup.findStatic(stackWalkerClass, + "getInstance", + MethodType.methodType(stackWalkerClass, Option)); + final Enum RETAIN_CLASS_REFERENCE = + Arrays.stream((Enum[]) Option.getEnumConstants()) + .filter(op -> "RETAIN_CLASS_REFERENCE".equals(op.name())) + .findFirst().orElseGet(null); + + if (RETAIN_CLASS_REFERENCE == null) { + throw new IllegalStateException("Failed to get RETAIN_CLASS_REFERENCE option"); + } + instance = getInstanceMH.invoke(RETAIN_CLASS_REFERENCE); + final MethodHandle getDeclaringClassMH = lookup.findVirtual(stackFrameClass, + "getDeclaringClass", + MethodType.methodType(Class.class)); + + getClassByStackFrameTemp = stackFrame -> { + try { + return getMatchedClass((Class) getDeclaringClassMH.invoke(stackFrame)); + } catch (Throwable t) { + logger.warn(INVOKING_FAIL_MSG, t); + } + return null; + }; + } catch (Throwable throwable) { + // StackWalker instantiate without option + logger.warn("Falling back to StackWalker without option:", throwable); + final MethodHandle getInstanceMH = + lookup.findStatic(stackWalkerClass, + "getInstance", + MethodType.methodType(stackWalkerClass)); + final MethodHandle getClassNameMH = + lookup.findVirtual(stackFrameClass, "getClassName", MethodType.methodType(String.class)); + + instance = getInstanceMH.invoke(); + getClassByStackFrameTemp = stackFrame -> { + try { + return getMatchedClass(getClassByName(getClassNameMH.invoke(stackFrame).toString())); + } catch (Throwable t) { + logger.warn(INVOKING_FAIL_MSG, t); + } + return null; + }; + } + + stackWalker = instance; + + final Function> getClassByStackFrame = getClassByStackFrameTemp; + walkHandler = stackFrameStream -> + stackFrameStream.map(getClassByStackFrame) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + } + + @Nullable + @Override + public Class get() { + try { + return (Class) walkMH.invoke(stackWalker, walkHandler); + } catch (Throwable t) { + throw new IllegalStateException("Failed to invoke StackWalker.walk():", t); + } + } +} + diff --git a/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StructContext.java b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StructContext.java index 51819573db9..9a6d98b4eb2 100644 --- a/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StructContext.java +++ b/thrift/src/main/java/com/linecorp/armeria/common/thrift/text/StructContext.java @@ -30,15 +30,15 @@ // ================================================================================================= package com.linecorp.armeria.common.thrift.text; -import java.lang.reflect.Constructor; -import java.lang.reflect.Modifier; +import static com.linecorp.armeria.common.thrift.text.AbstractThriftMessageClassFinder.isTBase; + import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; +import java.util.function.Supplier; import javax.annotation.Nullable; -import org.apache.thrift.TApplicationException; import org.apache.thrift.TBase; import org.apache.thrift.TException; import org.apache.thrift.TFieldIdEnum; @@ -56,7 +56,7 @@ import com.fasterxml.jackson.databind.JsonNode; -import com.linecorp.armeria.internal.thrift.TApplicationExceptions; +import com.linecorp.armeria.common.util.SystemInfo; /** * A struct parsing context. Builds a map from field name to TField. @@ -65,6 +65,21 @@ */ class StructContext extends PairContext { private static final Logger log = LoggerFactory.getLogger(StructContext.class); + private static final Supplier> thriftMessageClassFinder; + + static { + Supplier> supplier = null; + if (SystemInfo.javaVersion() >= 9) { + try { + supplier = new StackWalkingThriftMessageClassFinder(); + } catch (Throwable t) { + log.warn("Failed to initialize StackWalkingThriftMessageClassFinder. " + + "Falling back to DefaultThriftMessageClassFinder:", t); + } + } + + thriftMessageClassFinder = supplier != null ? supplier : new DefaultThriftMessageClassFinder(); + } // When processing a given thrift struct, we need certain information // for every field in that struct. We store that here, in a map @@ -132,63 +147,13 @@ protected Class getClassByFieldName(String fieldName) { * TProtocol.writeStructBegin(), rather than relying on the stack trace. */ private static Class getCurrentThriftMessageClass() { - final StackTraceElement[] frames = - Thread.currentThread().getStackTrace(); + final Class clazz = thriftMessageClassFinder.get(); - for (StackTraceElement f : frames) { - final String className = f.getClassName(); - - try { - final Class clazz = Class.forName(className); - - // Note, we need to check - // if the class is abstract, because abstract class does not have metaDataMap - // if the class has no-arg constructor, because FieldMetaData.getStructMetaDataMap - // calls clazz.newInstance - if (isTBase(clazz) && !isAbstract(clazz) && hasNoArgConstructor(clazz)) { - return clazz; - } - - if (isTApplicationException(clazz)) { - return clazz; - } - - if (isTApplicationExceptions(clazz)) { - return TApplicationException.class; - } - } catch (ClassNotFoundException ex) { - log.warn("Can't find class: " + className, ex); - } - } - throw new RuntimeException("Must call (indirectly) from a TBase/TApplicationException object."); - } - - private static boolean isTBase(Class clazz) { - return TBase.class.isAssignableFrom(clazz); - } - - private static boolean isTApplicationException(Class clazz) { - return TApplicationException.class.isAssignableFrom(clazz); - } - - private static boolean isTApplicationExceptions(Class clazz) { - return clazz == TApplicationExceptions.class; - } - - private static boolean isAbstract(Class clazz) { - return Modifier.isAbstract(clazz.getModifiers()); - } - - private static boolean hasNoArgConstructor(Class clazz) { - final Constructor[] allConstructors = clazz.getConstructors(); - for (Constructor ctor : allConstructors) { - final Class[] pType = ctor.getParameterTypes(); - if (pType.length == 0) { - return true; - } + if (clazz == null) { + throw new RuntimeException("Must call (indirectly) from a TBase/TApplicationException object."); } - return false; + return clazz; } /** diff --git a/thrift/src/test/java/com/linecorp/armeria/common/thrift/text/ThriftMessageClassFinderTest.java b/thrift/src/test/java/com/linecorp/armeria/common/thrift/text/ThriftMessageClassFinderTest.java new file mode 100644 index 00000000000..dffe21844d4 --- /dev/null +++ b/thrift/src/test/java/com/linecorp/armeria/common/thrift/text/ThriftMessageClassFinderTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2019 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package com.linecorp.armeria.common.thrift.text; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.security.Permission; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import com.linecorp.armeria.common.thrift.text.RpcDebugService.doDebug_args; +import com.linecorp.armeria.common.util.SystemInfo; + +class ThriftMessageClassFinderTest { + @ParameterizedTest(name = "testThriftMessageClassFinder {index}: finder={0}") + @MethodSource("testThriftMessageClassParameters") + void testThriftMessageClassFinder(Supplier> thriftMessageClassFinder) { + assertThat(thriftMessageClassFinder.get()).isNull(); + assertThat(new MockArgs().proxy(thriftMessageClassFinder)).isNotNull(); + } + + private static Stream testThriftMessageClassParameters() throws Throwable { + Stream parameters = Stream.of(Arguments.of(new DefaultThriftMessageClassFinder())); + + if (SystemInfo.javaVersion() >= 9) { + parameters = Stream.concat( + parameters, + Stream.of( + Arguments.of(new StackWalkingThriftMessageClassFinder()), + Arguments.of(getNoOptionStackWalkerInstance()))); + } + return parameters; + } + + private static Supplier> getNoOptionStackWalkerInstance() throws Throwable { + final SecurityManager SM = System.getSecurityManager(); + System.setSecurityManager(new SecurityManager() { + @Override + public void checkPermission(Permission perm) { + // `getStackWalkerWithClassReference` is called by RETAIN_CLASS_REFERENCE option. + if (perm.getName().equals("getStackWalkerWithClassReference")) { + throw new SecurityException("Failing SecurityManage.checkPermission() for unit testing"); + } + } + + @Override + public void checkPermission(Permission perm, Object context) { + } + + @Override + public void checkExit(int status) { + } + }); + + final StackWalkingThriftMessageClassFinder noOptionStackWalkingThriftMessageClassFinder = + new StackWalkingThriftMessageClassFinder(); + System.setSecurityManager(SM); + + return noOptionStackWalkingThriftMessageClassFinder; + } + + public static class MockArgs extends doDebug_args { + public Class proxy(Supplier> currentThriftMessageClassFinder) { + return currentThriftMessageClassFinder.get(); + } + } +}