diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 0d309828c6d..b9c6a77982a 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -17,12 +17,14 @@ package io.grpc.netty; import io.grpc.ChannelLogger; +import io.grpc.internal.ObjectPool; import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler; import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.concurrent.Executor; /** * Internal accessor for {@link ProtocolNegotiators}. @@ -35,9 +37,12 @@ private InternalProtocolNegotiators() {} * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, + ObjectPool executorPool) { + final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, + executorPool); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -58,6 +63,15 @@ public void close() { return new TlsNegotiator(); } + + /** + * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will + * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} + * may happen immediately, even before the TLS Handshake is complete. + */ + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { + return tls(sslContext, null); + } /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java index 75ec7347bb5..90956907bfe 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java @@ -29,13 +29,11 @@ import io.grpc.MethodDescriptor; import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyChannelBuilder; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.nio.NioEventLoopGroup; -import io.netty.channel.socket.nio.NioSocketChannel; -import io.netty.util.concurrent.DefaultThreadFactory; import java.time.Duration; import java.util.Optional; import java.util.concurrent.ConcurrentMap; +import java.util.logging.Level; +import java.util.logging.Logger; import javax.annotation.concurrent.ThreadSafe; /** @@ -61,7 +59,6 @@ public final class S2AHandshakerServiceChannel { private static final ConcurrentMap> SHARED_RESOURCE_CHANNELS = Maps.newConcurrentMap(); - private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2); private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); /** @@ -95,41 +92,34 @@ public ChannelResource(String targetAddress, Optional channe } /** - * Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code - * targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup} - * instance to avoid blocking. + * Creates a {@code HandshakerServiceChannel} instance to the service running at {@code + * targetAddress}. */ @Override public Channel create() { - EventLoopGroup eventLoopGroup = - new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true)); ManagedChannel channel = null; if (channelCredentials.isPresent()) { // Create a secure channel. channel = NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get()) - .channelType(NioSocketChannel.class) .directExecutor() - .eventLoopGroup(eventLoopGroup) .build(); } else { // Create a plaintext channel. channel = NettyChannelBuilder.forTarget(targetAddress) - .channelType(NioSocketChannel.class) .directExecutor() - .eventLoopGroup(eventLoopGroup) .usePlaintext() .build(); } - return EventLoopHoldingChannel.create(channel, eventLoopGroup); + return HandshakerServiceChannel.create(channel); } - /** Destroys a {@code EventLoopHoldingChannel} instance. */ + /** Destroys a {@code HandshakerServiceChannel} instance. */ @Override public void close(Channel instanceChannel) { checkNotNull(instanceChannel); - EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel; + HandshakerServiceChannel channel = (HandshakerServiceChannel) instanceChannel; channel.close(); } @@ -140,23 +130,21 @@ public String toString() { } /** - * Manages a channel using a {@link ManagedChannel} instance that belong to the {@code - * EventLoopGroup} thread pool. + * Manages a channel using a {@link ManagedChannel} instance. */ @VisibleForTesting - static class EventLoopHoldingChannel extends Channel { + static class HandshakerServiceChannel extends Channel { + private static final Logger logger = + Logger.getLogger(S2AHandshakerServiceChannel.class.getName()); private final ManagedChannel delegate; - private final EventLoopGroup eventLoopGroup; - static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + static HandshakerServiceChannel create(ManagedChannel delegate) { checkNotNull(delegate); - checkNotNull(eventLoopGroup); - return new EventLoopHoldingChannel(delegate, eventLoopGroup); + return new HandshakerServiceChannel(delegate); } - private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + private HandshakerServiceChannel(ManagedChannel delegate) { this.delegate = delegate; - this.eventLoopGroup = eventLoopGroup; } /** @@ -178,16 +166,12 @@ public ClientCall newCall( @SuppressWarnings("FutureReturnValueIgnored") public void close() { delegate.shutdownNow(); - boolean isDelegateTerminated; try { - isDelegateTerminated = - delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS); + delegate.awaitTermination(CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } catch (InterruptedException e) { - isDelegateTerminated = false; + Thread.currentThread().interrupt(); + logger.log(Level.WARNING, "Channel to S2A was not shutdown."); } - long quietPeriodSeconds = isDelegateTerminated ? 0 : 1; - eventLoopGroup.shutdownGracefully( - quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } } diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index 25d1e325ea8..14bdc05238d 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -29,7 +29,9 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.ThreadSafe; import io.grpc.Channel; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; @@ -227,7 +229,10 @@ protected void handlerAdded0(ChannelHandlerContext ctx) { @Override public void onSuccess(SslContext sslContext) { ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls( + sslContext, + SharedResourcePool.forResource(GrpcUtil.SHARED_CHANNEL_EXECUTOR)) + .newHandler(grpcHandler); // Remove the bufferReads handler and delegate the rest of the handshake to the TLS // handler. diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java index 57288be1b6f..dc5909442bf 100644 --- a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java @@ -18,11 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; -import static java.util.concurrent.TimeUnit.SECONDS; import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import io.grpc.CallOptions; import io.grpc.Channel; @@ -39,15 +35,13 @@ import io.grpc.benchmarks.Utils; import io.grpc.internal.SharedResourceHolder.Resource; import io.grpc.netty.NettyServerBuilder; -import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel.HandshakerServiceChannel; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; -import io.netty.channel.EventLoopGroup; import java.io.File; -import java.time.Duration; import java.util.Optional; import java.util.concurrent.TimeUnit; import org.junit.Before; @@ -60,8 +54,6 @@ @RunWith(JUnit4.class) public final class S2AHandshakerServiceChannelTest { @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); - private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); - private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); private Server mtlsServer; private Server plaintextServer; @@ -191,7 +183,7 @@ public void close_mtlsSuccess() throws Exception { } /** - * Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to + * Verifies that an {@code HandshakerServiceChannel}'s {@code newCall} method can be used to * perform a simple RPC. */ @Test @@ -201,7 +193,7 @@ public void newCall_performSimpleRpcSuccess() { "localhost:" + plaintextServer.getPort(), /* s2aChannelCredentials= */ Optional.empty()); Channel channel = resource.create(); - assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channel).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) .isEqualToDefaultInstance(); @@ -214,53 +206,49 @@ public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception { S2AHandshakerServiceChannel.getChannelResource( "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); Channel channel = resource.create(); - assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channel).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) .isEqualToDefaultInstance(); } - /** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */ + /** Creates a {@code HandshakerServiceChannel} instance and verifies its authority. */ @Test public void authority_success() throws Exception { ManagedChannel channel = new FakeManagedChannel(true); - EventLoopHoldingChannel eventLoopHoldingChannel = - EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + HandshakerServiceChannel eventLoopHoldingChannel = + HandshakerServiceChannel.create(channel); assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel"); } /** - * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates - * successfully. + * Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} + * terminates successfully. */ @Test public void close_withDelegateTerminatedSuccess() throws Exception { ManagedChannel channel = new FakeManagedChannel(true); - EventLoopHoldingChannel eventLoopHoldingChannel = - EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + HandshakerServiceChannel eventLoopHoldingChannel = + HandshakerServiceChannel.create(channel); eventLoopHoldingChannel.close(); assertThat(channel.isShutdown()).isTrue(); - verify(mockEventLoopGroup, times(1)) - .shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } /** - * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not + * Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} does not * terminate successfully. */ @Test public void close_withDelegateTerminatedFailure() throws Exception { ManagedChannel channel = new FakeManagedChannel(false); - EventLoopHoldingChannel eventLoopHoldingChannel = - EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + HandshakerServiceChannel eventLoopHoldingChannel = + HandshakerServiceChannel.create(channel); eventLoopHoldingChannel.close(); assertThat(channel.isShutdown()).isTrue(); - verify(mockEventLoopGroup, times(1)) - .shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); } /** - * Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same + * Creates and closes a {@code HandshakerServiceChannel}, creates a new channel from the same * resource, and verifies that this second channel is useable. */ @Test @@ -273,7 +261,7 @@ public void create_succeedsAfterCloseIsCalledOnce() throws Exception { resource.close(channelOne); Channel channelTwo = resource.create(); - assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channelTwo) .unaryRpc(SimpleRequest.getDefaultInstance())) @@ -291,7 +279,7 @@ public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception { resource.close(channelOne); Channel channelTwo = resource.create(); - assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class); assertThat( SimpleServiceGrpc.newBlockingStub(channelTwo) .unaryRpc(SimpleRequest.getDefaultInstance()))