Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

s2a: Address minor comments on PR#11113 #11540

Merged
merged 11 commits into from
Sep 27, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ChannelCredentials;
Expand All @@ -35,14 +34,13 @@
import io.netty.util.concurrent.DefaultThreadFactory;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.ConcurrentMap;
import javax.annotation.concurrent.ThreadSafe;

/**
* Provides APIs for managing gRPC channels to S2A servers. Each channel is local and plaintext. If
* credentials are provided, they are used to secure the channel.
* Provides APIs for managing gRPC channels to an S2A server. Each channel is local and plaintext.
* If credentials are provided, they are used to secure the channel.
*
* <p>This is done as follows: for each S2A server, provides an implementation of gRPC's {@link
* <p>This is done as follows: for an S2A server, provides an implementation of gRPC's {@link
* SharedResourceHolder.Resource} interface called a {@code Resource<Channel>}. A {@code
* Resource<Channel>} is a factory for creating gRPC channels to the S2A server at a given address,
* and a channel must be returned to the {@code Resource<Channel>} when it is no longer needed.
Expand All @@ -59,8 +57,6 @@
*/
@ThreadSafe
public final class S2AHandshakerServiceChannel {
private static final ConcurrentMap<String, Resource<Channel>> SHARED_RESOURCE_CHANNELS =
Maps.newConcurrentMap();
private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2);
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);

Expand All @@ -76,8 +72,7 @@ public final class S2AHandshakerServiceChannel {
public static Resource<Channel> getChannelResource(
String s2aAddress, Optional<ChannelCredentials> s2aChannelCredentials) {
checkNotNull(s2aAddress);
return SHARED_RESOURCE_CHANNELS.computeIfAbsent(
s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials));
return new ChannelResource(s2aAddress, s2aChannelCredentials);
}

/**
Expand Down
28 changes: 2 additions & 26 deletions s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,11 @@

package io.grpc.s2a.handshaker;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;

/** Converts proto messages to Netty strings. */
final class ProtoUtil {
/**
* Converts {@link Ciphersuite} to its {@link String} representation.
*
* @param ciphersuite the {@link Ciphersuite} to be converted.
* @return a {@link String} representing the ciphersuite.
* @throws AssertionError if the {@link Ciphersuite} is not one of the supported ciphersuites.
*/
static String convertCiphersuite(Ciphersuite ciphersuite) {
switch (ciphersuite) {
case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256";
case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384";
case CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
return "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256";
case CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
case CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384";
case CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
return "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256";
default:
throw new AssertionError(
String.format("Ciphersuite %d is not supported.", ciphersuite.getNumber()));
}
}

/**
* Converts a {@link TLSVersion} object to its {@link String} representation.
Expand All @@ -54,6 +29,7 @@ static String convertCiphersuite(Ciphersuite ciphersuite) {
* @return a {@link String} representation of the TLS version.
* @throws AssertionError if the {@code tlsVersion} is not one of the supported TLS versions.
*/
@VisibleForTesting
static String convertTlsProtocolVersion(TLSVersion tlsVersion) {
switch (tlsVersion) {
case TLS_VERSION_1_3:
Expand Down
15 changes: 8 additions & 7 deletions s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ BlockingQueue<Result> getResponses() {
* @throws IOException if an unexpected response is received, or if the {@code reader} or {@code
* writer} calls their {@code onError} method.
*/
@SuppressWarnings("CheckReturnValue")
public SessionResp send(SessionReq req) throws IOException, InterruptedException {
if (doneWriting && doneReading) {
logger.log(Level.INFO, "Stream to the S2A is closed.");
Expand All @@ -92,9 +93,8 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException
createWriterIfNull();
if (!responses.isEmpty()) {
IOException exception = null;
SessionResp resp = null;
try {
resp = responses.take().getResultOrThrow();
responses.take().getResultOrThrow();
} catch (IOException e) {
exception = e;
}
Expand All @@ -104,14 +104,15 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException
"Received an unexpected response from a host at the S2A's address. The S2A might be"
+ " unavailable."
+ exception.getMessage());
} else {
throw new IOException("Received an unexpected response from a host at the S2A's address.");
}
return resp;
}
try {
writer.onNext(req);
} catch (RuntimeException e) {
writer.onError(e);
responses.offer(Result.createWithThrowable(e));
responses.add(Result.createWithThrowable(e));
}
try {
return responses.take().getResultOrThrow();
Expand Down Expand Up @@ -159,7 +160,7 @@ private class Reader implements StreamObserver<SessionResp> {
@Override
public void onNext(SessionResp resp) {
verify(!doneReading);
responses.offer(Result.createWithResponse(resp));
responses.add(Result.createWithResponse(resp));
}

/**
Expand All @@ -169,7 +170,7 @@ public void onNext(SessionResp resp) {
*/
@Override
public void onError(Throwable t) {
responses.offer(Result.createWithThrowable(t));
responses.add(Result.createWithThrowable(t));
}

/**
Expand All @@ -180,7 +181,7 @@ public void onError(Throwable t) {
public void onCompleted() {
logger.log(Level.INFO, "Reading from the S2A is complete.");
doneReading = true;
responses.offer(
responses.add(
Result.createWithThrowable(
new ConnectionClosedException("Reading from the S2A is complete.")));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public final class AccessTokenManager {
private final TokenFetcher tokenFetcher;

/** Creates an {@code AccessTokenManager} based on the environment where the application runs. */
@SuppressWarnings("RethrowReflectiveOperationExceptionAsLinkageError")
public static Optional<AccessTokenManager> create() {
Optional<?> tokenFetcher;
try {
Expand All @@ -38,7 +37,7 @@ public static Optional<AccessTokenManager> create() {
} catch (ClassNotFoundException e) {
tokenFetcher = Optional.empty();
} catch (ReflectiveOperationException e) {
throw new AssertionError(e);
throw new LinkageError(e.getMessage(), e);
}
return tokenFetcher.isPresent()
? Optional.of(new AccessTokenManager((TokenFetcher) tokenFetcher.get()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ public void getChannelResource_mtlsSuccess() throws Exception {

/**
* Creates two {@code Resoure<Channel>}s for the same target address and verifies that they are
* equal.
* distinct.
*/
@Test
public void getChannelResource_twoEqualChannels() {
public void getChannelResource_twoUnEqualChannels() {
Resource<Channel> resource =
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + plaintextServer.getPort(),
Expand All @@ -109,19 +109,19 @@ public void getChannelResource_twoEqualChannels() {
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + plaintextServer.getPort(),
/* s2aChannelCredentials= */ Optional.empty());
assertThat(resource).isEqualTo(resourceTwo);
assertThat(resource).isNotEqualTo(resourceTwo);
}

/** Same as getChannelResource_twoEqualChannels, but use mTLS. */
/** Same as getChannelResource_twoUnEqualChannels, but use mTLS. */
@Test
public void getChannelResource_mtlsTwoEqualChannels() throws Exception {
public void getChannelResource_mtlsTwoUnEqualChannels() throws Exception {
Resource<Channel> resource =
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
Resource<Channel> resourceTwo =
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
assertThat(resource).isEqualTo(resourceTwo);
assertThat(resource).isNotEqualTo(resourceTwo);
}

/**
Expand Down
7 changes: 6 additions & 1 deletion s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.grpc.s2a.handshaker;

import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.logging.Logger;
Expand All @@ -38,7 +39,11 @@ public StreamObserver<SessionReq> setUpSession(StreamObserver<SessionResp> respo
@Override
public void onNext(SessionReq req) {
logger.info("Received a request from client.");
responseObserver.onNext(writer.handleResponse(req));
try {
responseObserver.onNext(writer.handleResponse(req));
} catch (IOException e) {
responseObserver.onError(e);
}
}

@Override
Expand Down
18 changes: 11 additions & 7 deletions s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import io.grpc.benchmarks.Utils;
import io.grpc.s2a.handshaker.ValidatePeerCertificateChainReq.VerificationMode;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -45,9 +48,7 @@ public final class FakeS2AServerTest {
private static final Logger logger = Logger.getLogger(FakeS2AServerTest.class.getName());

private static final ImmutableList<ByteString> FAKE_CERT_DER_CHAIN =
ImmutableList.of(
ByteString.copyFrom(
new byte[] {'f', 'a', 'k', 'e', '-', 'd', 'e', 'r', '-', 'c', 'h', 'a', 'i', 'n'}));
ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII)));
private int port;
private String serverAddress;
private SessionResp response = null;
Expand All @@ -68,7 +69,7 @@ public void tearDown() {

@Test
public void callS2AServerOnce_getTlsConfiguration_returnsValidResult()
throws InterruptedException {
throws InterruptedException, IOException {
ExecutorService executor = Executors.newSingleThreadExecutor();
logger.info("Client connecting to: " + serverAddress);
ManagedChannel channel =
Expand Down Expand Up @@ -122,9 +123,12 @@ public void onCompleted() {}
GetTlsConfigurationResp.newBuilder()
.setClientTlsConfiguration(
GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder()
.addCertificateChain(FakeWriter.LEAF_CERT)
.addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2)
.addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1)
.addCertificateChain(new String(Files.readAllBytes(
FakeWriter.leafCertFile.toPath()), StandardCharsets.UTF_8))
.addCertificateChain(new String(Files.readAllBytes(
FakeWriter.cert1File.toPath()), StandardCharsets.UTF_8))
.addCertificateChain(new String(Files.readAllBytes(
FakeWriter.cert2File.toPath()), StandardCharsets.UTF_8))
.setMinTlsVersion(TLSVersion.TLS_VERSION_1_3)
.setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3)
.addCiphersuites(
Expand Down
Loading