Skip to content

Commit

Permalink
S2AHandshakerServiceChannel doesn't use custom event loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehta19 committed Sep 19, 2024
1 parent 9b0c19e commit 03ced73
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 61 deletions.
32 changes: 32 additions & 0 deletions netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
package io.grpc.netty;

import io.grpc.ChannelLogger;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
import io.grpc.netty.ProtocolNegotiators.GrpcNegotiationHandler;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.concurrent.Executor;

/**
* Internal accessor for {@link ProtocolNegotiators}.
Expand All @@ -31,6 +33,36 @@ public final class InternalProtocolNegotiators {

private InternalProtocolNegotiators() {}

/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
* may happen immediately, even before the TLS Handshake is complete.
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {

@Override
public AsciiString scheme() {
return negotiator.scheme();
}

@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
return negotiator.newHandler(grpcHandler);
}

@Override
public void close() {
negotiator.close();
}
}

return new TlsNegotiator();
}

/**
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@
import io.grpc.MethodDescriptor;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -61,7 +57,6 @@
public final class S2AHandshakerServiceChannel {
private static final ConcurrentMap<String, Resource<Channel>> SHARED_RESOURCE_CHANNELS =
Maps.newConcurrentMap();
private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2);
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);

/**
Expand Down Expand Up @@ -95,41 +90,35 @@ public ChannelResource(String targetAddress, Optional<ChannelCredentials> channe
}

/**
* Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code
* Creates a {@code HandshakerServiceChannel} instance to the service running at {@code
* targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup}
* instance to avoid blocking.
*/
@Override
public Channel create() {
EventLoopGroup eventLoopGroup =
new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true));
ManagedChannel channel = null;
if (channelCredentials.isPresent()) {
// Create a secure channel.
channel =
NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get())
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.build();
} else {
// Create a plaintext channel.
channel =
NettyChannelBuilder.forTarget(targetAddress)
.channelType(NioSocketChannel.class)
.directExecutor()
.eventLoopGroup(eventLoopGroup)
.usePlaintext()
.build();
}
return EventLoopHoldingChannel.create(channel, eventLoopGroup);
return HandshakerServiceChannel.create(channel);
}

/** Destroys a {@code EventLoopHoldingChannel} instance. */
/** Destroys a {@code HandshakerServiceChannel} instance. */
@Override
public void close(Channel instanceChannel) {
checkNotNull(instanceChannel);
EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel;
HandshakerServiceChannel channel = (HandshakerServiceChannel) instanceChannel;
channel.close();
}

Expand All @@ -140,23 +129,19 @@ public String toString() {
}

/**
* Manages a channel using a {@link ManagedChannel} instance that belong to the {@code
* EventLoopGroup} thread pool.
* Manages a channel using a {@link ManagedChannel} instance.
*/
@VisibleForTesting
static class EventLoopHoldingChannel extends Channel {
static class HandshakerServiceChannel extends Channel {
private final ManagedChannel delegate;
private final EventLoopGroup eventLoopGroup;

static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
static HandshakerServiceChannel create(ManagedChannel delegate) {
checkNotNull(delegate);
checkNotNull(eventLoopGroup);
return new EventLoopHoldingChannel(delegate, eventLoopGroup);
return new HandshakerServiceChannel(delegate);
}

private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) {
private HandshakerServiceChannel(ManagedChannel delegate) {
this.delegate = delegate;
this.eventLoopGroup = eventLoopGroup;
}

/**
Expand All @@ -178,16 +163,11 @@ public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
@SuppressWarnings("FutureReturnValueIgnored")
public void close() {
delegate.shutdownNow();
boolean isDelegateTerminated;
try {
isDelegateTerminated =
delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS);
delegate.awaitTermination(CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
} catch (InterruptedException e) {
isDelegateTerminated = false;
Thread.currentThread().interrupt();
}
long quietPeriodSeconds = isDelegateTerminated ? 0 : 1;
eventLoopGroup.shutdownGracefully(
quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.common.util.concurrent.MoreExecutors;
import com.google.errorprone.annotations.ThreadSafe;
import io.grpc.Channel;
import io.grpc.internal.FixedObjectPool;
import io.grpc.internal.ObjectPool;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalProtocolNegotiator;
Expand Down Expand Up @@ -227,7 +228,9 @@ protected void handlerAdded0(ChannelHandlerContext ctx) {
@Override
public void onSuccess(SslContext sslContext) {
ChannelHandler handler =
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
InternalProtocolNegotiators.tls(
sslContext, new FixedObjectPool<>(Executors.newFixedThreadPool(1)))
.newHandler(grpcHandler);

// Remove the bufferReads handler and delegate the rest of the handshake to the TLS
// handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -39,15 +35,13 @@
import io.grpc.benchmarks.Utils;
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel;
import io.grpc.s2a.channel.S2AHandshakerServiceChannel.HandshakerServiceChannel;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.netty.channel.EventLoopGroup;
import java.io.File;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
Expand All @@ -60,8 +54,6 @@
@RunWith(JUnit4.class)
public final class S2AHandshakerServiceChannelTest {
@ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);
private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class);
private Server mtlsServer;
private Server plaintextServer;

Expand Down Expand Up @@ -191,7 +183,7 @@ public void close_mtlsSuccess() throws Exception {
}

/**
* Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to
* Verifies that an {@code HandshakerServiceChannel}'s {@code newCall} method can be used to
* perform a simple RPC.
*/
@Test
Expand All @@ -201,7 +193,7 @@ public void newCall_performSimpleRpcSuccess() {
"localhost:" + plaintextServer.getPort(),
/* s2aChannelCredentials= */ Optional.empty());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
Expand All @@ -214,53 +206,49 @@ public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception {
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
Channel channel = resource.create();
assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channel).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()))
.isEqualToDefaultInstance();
}

/** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */
/** Creates a {@code HandshakerServiceChannel} instance and verifies its authority. */
@Test
public void authority_success() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel");
}

/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates
* successfully.
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel}
* terminates successfully.
*/
@Test
public void close_withDelegateTerminatedSuccess() throws Exception {
ManagedChannel channel = new FakeManagedChannel(true);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}

/**
* Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not
* Creates and closes a {@code HandshakerServiceChannel} when its {@code ManagedChannel} does not
* terminate successfully.
*/
@Test
public void close_withDelegateTerminatedFailure() throws Exception {
ManagedChannel channel = new FakeManagedChannel(false);
EventLoopHoldingChannel eventLoopHoldingChannel =
EventLoopHoldingChannel.create(channel, mockEventLoopGroup);
HandshakerServiceChannel eventLoopHoldingChannel =
HandshakerServiceChannel.create(channel);
eventLoopHoldingChannel.close();
assertThat(channel.isShutdown()).isTrue();
verify(mockEventLoopGroup, times(1))
.shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS);
}

/**
* Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same
* Creates and closes a {@code HandshakerServiceChannel}, creates a new channel from the same
* resource, and verifies that this second channel is useable.
*/
@Test
Expand All @@ -273,7 +261,7 @@ public void create_succeedsAfterCloseIsCalledOnce() throws Exception {
resource.close(channelOne);

Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
Expand All @@ -291,7 +279,7 @@ public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception {
resource.close(channelOne);

Channel channelTwo = resource.create();
assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class);
assertThat(channelTwo).isInstanceOf(HandshakerServiceChannel.class);
assertThat(
SimpleServiceGrpc.newBlockingStub(channelTwo)
.unaryRpc(SimpleRequest.getDefaultInstance()))
Expand Down

0 comments on commit 03ced73

Please sign in to comment.