From 42cd3e8fd841322d6f92f45d0315d537dd53ed80 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 25 Apr 2024 09:48:53 -0700 Subject: [PATCH 01/27] add s2a java client. --- MODULE.bazel | 2 + repositories.bzl | 3 + s2a/BUILD.bazel | 194 +++++++++ s2a/build.gradle | 153 +++++++ .../grpc/s2a/handshaker/S2AServiceGrpc.java | 285 +++++++++++++ .../grpc/s2a/MtlsToS2AChannelCredentials.java | 96 +++++ .../io/grpc/s2a/S2AChannelCredentials.java | 132 ++++++ .../io/grpc/s2a/channel/S2AChannelPool.java | 43 ++ .../grpc/s2a/channel/S2AGrpcChannelPool.java | 112 +++++ .../channel/S2AHandshakerServiceChannel.java | 195 +++++++++ .../ConnectionIsClosedException.java | 27 ++ .../GetAuthenticationMechanisms.java | 60 +++ .../io/grpc/s2a/handshaker/ProtoUtil.java | 72 ++++ .../handshaker/S2AConnectionException.java | 25 ++ .../io/grpc/s2a/handshaker/S2AIdentity.java | 62 +++ .../s2a/handshaker/S2APrivateKeyMethod.java | 143 +++++++ .../S2AProtocolNegotiatorFactory.java | 194 +++++++++ .../java/io/grpc/s2a/handshaker/S2AStub.java | 225 ++++++++++ .../grpc/s2a/handshaker/S2ATrustManager.java | 152 +++++++ .../s2a/handshaker/SslContextFactory.java | 179 ++++++++ .../tokenmanager/AccessTokenManager.java | 61 +++ .../tokenmanager/SingleTokenFetcher.java | 64 +++ .../handshaker/tokenmanager/TokenFetcher.java | 28 ++ s2a/src/main/proto/grpc/gcp/common.proto | 79 ++++ s2a/src/main/proto/grpc/gcp/s2a.proto | 369 +++++++++++++++++ s2a/src/main/proto/grpc/gcp/s2a_context.proto | 61 +++ .../s2a/MtlsToS2AChannelCredentialsTest.java | 135 ++++++ .../grpc/s2a/S2AChannelCredentialsTest.java | 112 +++++ .../s2a/channel/S2AGrpcChannelPoolTest.java | 125 ++++++ .../S2AHandshakerServiceChannelTest.java | 390 ++++++++++++++++++ .../io/grpc/s2a/handshaker/FakeS2AServer.java | 55 +++ .../s2a/handshaker/FakeS2AServerTest.java | 265 ++++++++++++ .../io/grpc/s2a/handshaker/FakeWriter.java | 347 ++++++++++++++++ .../GetAuthenticationMechanismsTest.java | 64 +++ .../grpc/s2a/handshaker/IntegrationTest.java | 322 +++++++++++++++ .../io/grpc/s2a/handshaker/ProtoUtilTest.java | 95 +++++ .../handshaker/S2APrivateKeyMethodTest.java | 308 ++++++++++++++ .../S2AProtocolNegotiatorFactoryTest.java | 267 ++++++++++++ .../io/grpc/s2a/handshaker/S2AStubTest.java | 260 ++++++++++++ .../s2a/handshaker/S2ATrustManagerTest.java | 262 ++++++++++++ .../s2a/handshaker/SslContextFactoryTest.java | 173 ++++++++ .../SingleTokenAccessTokenManagerTest.java | 74 ++++ s2a/src/test/resources/README.md | 31 ++ s2a/src/test/resources/client.csr | 16 + s2a/src/test/resources/client_cert.pem | 18 + s2a/src/test/resources/client_key.pem | 28 ++ s2a/src/test/resources/config.cnf | 17 + s2a/src/test/resources/root_cert.pem | 22 + s2a/src/test/resources/root_key.pem | 30 ++ s2a/src/test/resources/server.csr | 16 + s2a/src/test/resources/server_cert.pem | 20 + s2a/src/test/resources/server_key.pem | 28 ++ settings.gradle | 2 + 53 files changed, 6498 insertions(+) create mode 100644 s2a/BUILD.bazel create mode 100644 s2a/build.gradle create mode 100644 s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java create mode 100644 s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java create mode 100644 s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java create mode 100644 s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java create mode 100644 s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java create mode 100644 s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java create mode 100644 s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java create mode 100644 s2a/src/main/proto/grpc/gcp/common.proto create mode 100644 s2a/src/main/proto/grpc/gcp/s2a.proto create mode 100644 s2a/src/main/proto/grpc/gcp/s2a_context.proto create mode 100644 s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java create mode 100644 s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java create mode 100644 s2a/src/test/resources/README.md create mode 100644 s2a/src/test/resources/client.csr create mode 100644 s2a/src/test/resources/client_cert.pem create mode 100644 s2a/src/test/resources/client_key.pem create mode 100644 s2a/src/test/resources/config.cnf create mode 100644 s2a/src/test/resources/root_cert.pem create mode 100644 s2a/src/test/resources/root_key.pem create mode 100644 s2a/src/test/resources/server.csr create mode 100644 s2a/src/test/resources/server_cert.pem create mode 100644 s2a/src/test/resources/server_key.pem diff --git a/MODULE.bazel b/MODULE.bazel index 1d79e362e11..78e6ccb70f2 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -41,7 +41,9 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "io.perfmark:perfmark-api:0.26.0", "junit:junit:4.13.2", "org.apache.tomcat:annotations-api:6.0.53", + "org.checkerframework:checker-qual:3.12.0", "org.codehaus.mojo:animal-sniffer-annotations:1.23", + "org.jcommander:jcommander:1.83", ] # GRPC_DEPS_END diff --git a/repositories.bzl b/repositories.bzl index 1f422d3380f..7ed5141fec3 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -45,7 +45,9 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "io.perfmark:perfmark-api:0.26.0", "junit:junit:4.13.2", "org.apache.tomcat:annotations-api:6.0.53", + "org.checkerframework:checker-qual:3.12.0", "org.codehaus.mojo:animal-sniffer-annotations:1.23", + "org.jcommander:jcommander:1.83", ] # GRPC_DEPS_END @@ -80,6 +82,7 @@ IO_GRPC_GRPC_JAVA_OVERRIDE_TARGETS = { "io.grpc:grpc-rls": "@io_grpc_grpc_java//rls", "io.grpc:grpc-services": "@io_grpc_grpc_java//services:services_maven", "io.grpc:grpc-stub": "@io_grpc_grpc_java//stub", + "io.grpc:grpc-s2a": "@io_grpc_grpc_java//s2a", "io.grpc:grpc-testing": "@io_grpc_grpc_java//testing", "io.grpc:grpc-xds": "@io_grpc_grpc_java//xds:xds_maven", "io.grpc:grpc-util": "@io_grpc_grpc_java//util", diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel new file mode 100644 index 00000000000..0041ad52be6 --- /dev/null +++ b/s2a/BUILD.bazel @@ -0,0 +1,194 @@ +load("@rules_proto//proto:defs.bzl", "proto_library") +load("//:java_grpc_library.bzl", "java_grpc_library") +load("@rules_jvm_external//:defs.bzl", "artifact") + +java_library( + name = "s2a_channel_pool", + srcs = glob([ + "src/main/java/io/grpc/s2a/channel/*.java", + ]), + deps = [ + "//api", + "//core", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + artifact("io.netty:netty-common"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2a_identity", + srcs = ["src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java"], + deps = [ + ":common_java_proto", + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + ], +) + +java_library( + name = "token_fetcher", + srcs = ["src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java"], + deps = [ + ":s2a_identity", + ], +) + +java_library( + name = "access_token_manager", + srcs = [ + "src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java", + ], + deps = [ + ":s2a_identity", + ":token_fetcher", + artifact("com.google.code.findbugs:jsr305"), + ], +) + +java_library( + name = "single_token_fetcher", + srcs = [ + "src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java", + ], + deps = [ + ":s2a_identity", + ":token_fetcher", + artifact("org.jcommander:jcommander"), + ], +) + +java_library( + name = "s2a_handshaker", + srcs = [ + "src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java", + "src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java", + "src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java", + "src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java", + "src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java", + "src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java", + "src/main/java/io/grpc/s2a/handshaker/S2AStub.java", + "src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java", + "src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java", + ], + deps = [ + ":access_token_manager", + ":common_java_proto", + ":s2a_channel_pool", + ":s2a_identity", + ":s2a_java_proto", + ":s2a_java_grpc_proto", + ":single_token_fetcher", + "//api", + "//core:internal", + "//netty", + "//stub", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + "@com_google_protobuf//:protobuf_java", + artifact("io.netty:netty-common"), + artifact("io.netty:netty-handler"), + artifact("io.netty:netty-transport"), + ], +) + +java_library( + name = "s2av2_credentials", + srcs = ["src/main/java/io/grpc/s2a/S2AChannelCredentials.java"], + visibility = ["//visibility:public"], + deps = [ + ":s2a_channel_pool", + ":s2a_handshaker", + ":s2a_identity", + "//api", + "//core:internal", + "//netty", + artifact("com.google.code.findbugs:jsr305"), + artifact("com.google.errorprone:error_prone_annotations"), + artifact("com.google.guava:guava"), + artifact("org.checkerframework:checker-qual"), + ], +) + +java_library( + name = "mtls_to_s2av2_credentials", + srcs = ["src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java"], + visibility = ["//visibility:public"], + deps = [ + ":s2a_channel_pool", + ":s2av2_credentials", + "//api", + "//util", + artifact("com.google.guava:guava"), + ], +) + +# bazel only accepts proto import with absolute path. +genrule( + name = "protobuf_imports", + srcs = glob(["src/main/proto/grpc/gcp/*.proto"]), + outs = [ + "protobuf_out/grpc/gcp/s2a.proto", + "protobuf_out/grpc/gcp/s2a_context.proto", + "protobuf_out/grpc/gcp/common.proto", + ], + cmd = "for fname in $(SRCS); do " + + "sed 's,import \",import \"s2a/protobuf_out/,g' $$fname > " + + "$(@D)/protobuf_out/grpc/gcp/$$(basename $$fname); done", +) + +proto_library( + name = "common_proto", + srcs = [ + "protobuf_out/grpc/gcp/common.proto", + ], +) + +proto_library( + name = "s2a_context_proto", + srcs = [ + "protobuf_out/grpc/gcp/s2a_context.proto", + ], + deps = [ + ":common_proto", + ], +) + +proto_library( + name = "s2a_proto", + srcs = [ + "protobuf_out/grpc/gcp/s2a.proto", + ], + deps = [ + ":common_proto", + ":s2a_context_proto", + ], +) + +java_proto_library( + name = "s2a_java_proto", + deps = [":s2a_proto"], +) + +java_proto_library( + name = "s2a_context_java_proto", + deps = [":s2a_context_proto"], +) + +java_proto_library( + name = "common_java_proto", + deps = [":common_proto"], +) + +java_grpc_library( + name = "s2a_java_grpc_proto", + srcs = [":s2a_proto"], + deps = [":s2a_java_proto"], +) diff --git a/s2a/build.gradle b/s2a/build.gradle new file mode 100644 index 00000000000..054039571d8 --- /dev/null +++ b/s2a/build.gradle @@ -0,0 +1,153 @@ +buildscript { + dependencies { + classpath 'com.google.gradle:osdetector-gradle-plugin:1.4.0' + } +} + +plugins { + id "java-library" + id "maven-publish" + + id "com.github.johnrengelman.shadow" + id "com.google.protobuf" + id "ru.vyarus.animalsniffer" +} + +description = "gRPC: S2A" + +apply plugin: "com.google.osdetector" + +dependencies { + + api project(':grpc-api') + implementation project(':grpc-stub'), + project(':grpc-protobuf'), + project(':grpc-core'), + libraries.protobuf.java, + libraries.conscrypt, + libraries.guava.jre // JRE required by protobuf-java-util from grpclb + compileOnly 'org.jcommander:jcommander:1.83' + def nettyDependency = implementation project(':grpc-netty') + compileOnly libraries.javax.annotation + + shadow configurations.implementation.getDependencies().minus(nettyDependency) + shadow project(path: ':grpc-netty-shaded', configuration: 'shadow') + + testImplementation project(':grpc-benchmarks'), + project(':grpc-testing'), + project(':grpc-testing-proto'), + testFixtures(project(':grpc-core')), + libraries.guava, + libraries.junit, + libraries.mockito.core, + libraries.truth, + libraries.conscrypt, + libraries.netty.transport.epoll + + testImplementation 'org.jcommander:jcommander:1.83' + testImplementation 'com.google.truth:truth:1.4.2' + testImplementation 'com.google.truth.extensions:truth-proto-extension:1.4.2' + testImplementation libraries.guava.testlib + + testRuntimeOnly libraries.netty.tcnative, + libraries.netty.tcnative.classes + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "linux-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-x86_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "osx-aarch_64" + } + } + testRuntimeOnly (libraries.netty.tcnative) { + artifact { + classifier = "windows-x86_64" + } + } + testRuntimeOnly (libraries.netty.transport.epoll) { + artifact { + classifier = "linux-x86_64" + } + } + + signature libraries.signature.java +} + +tasks.named("compileJava") { + dependsOn(tasks.named("generateProto")) + //dependsOn(tasks.named("syncGeneratedSourcesmain")) +} + + +tasks.named("sourcesJar") { + dependsOn(tasks.named("generateProto")) + //dependsOn(tasks.named("syncGeneratedSourcesmain")) +} + +sourceSets { + main { + //java.srcDirs += "src/generated/main/java" + //java.srcDirs += "src/generated/main/grpc" + } +} +//println sourceSets.main.java.srcDirs +//println sourceSets.test.resources.srcDirs + +configureProtoCompilation() + +tasks.named("javadoc").configure { + exclude 'io/grpc/s2a/**' +} + +tasks.named("jar").configure { + // Must use a different archiveClassifier to avoid conflicting with shadowJar + archiveClassifier = 'original' + manifest { + attributes('Automatic-Module-Name': 'io.grpc.s2a') + } +} + +// We want to use grpc-netty-shaded instead of grpc-netty. But we also want our +// source to work with Bazel, so we rewrite the code as part of the build. +tasks.named("shadowJar").configure { + archiveClassifier = null + dependencies { + exclude(dependency {true}) + } + relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty' + relocate 'io.netty', 'io.grpc.netty.shaded.io.netty' +} + +publishing { + publications { + maven(MavenPublication) { + // We want this to throw an exception if it isn't working + def originalJar = artifacts.find { dep -> dep.classifier == 'original'} + artifacts.remove(originalJar) + + pom.withXml { + def dependenciesNode = new Node(null, 'dependencies') + project.configurations.shadow.allDependencies.each { dep -> + def dependencyNode = dependenciesNode.appendNode('dependency') + dependencyNode.appendNode('groupId', dep.group) + dependencyNode.appendNode('artifactId', dep.name) + dependencyNode.appendNode('version', dep.version) + dependencyNode.appendNode('scope', 'compile') + } + asNode().dependencies[0].replaceNode(dependenciesNode) + } + } + } +} diff --git a/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java b/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java new file mode 100644 index 00000000000..fd6b991c039 --- /dev/null +++ b/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java @@ -0,0 +1,285 @@ +package io.grpc.s2a.handshaker; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + */ +@javax.annotation.Generated( + value = "by gRPC proto compiler", + comments = "Source: grpc/gcp/s2a.proto") +@io.grpc.stub.annotations.GrpcGenerated +public final class S2AServiceGrpc { + + private S2AServiceGrpc() {} + + public static final java.lang.String SERVICE_NAME = "grpc.gcp.S2AService"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getSetUpSessionMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "SetUpSession", + requestType = io.grpc.s2a.handshaker.SessionReq.class, + responseType = io.grpc.s2a.handshaker.SessionResp.class, + methodType = io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + public static io.grpc.MethodDescriptor getSetUpSessionMethod() { + io.grpc.MethodDescriptor getSetUpSessionMethod; + if ((getSetUpSessionMethod = S2AServiceGrpc.getSetUpSessionMethod) == null) { + synchronized (S2AServiceGrpc.class) { + if ((getSetUpSessionMethod = S2AServiceGrpc.getSetUpSessionMethod) == null) { + S2AServiceGrpc.getSetUpSessionMethod = getSetUpSessionMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "SetUpSession")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.s2a.handshaker.SessionReq.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.s2a.handshaker.SessionResp.getDefaultInstance())) + .setSchemaDescriptor(new S2AServiceMethodDescriptorSupplier("SetUpSession")) + .build(); + } + } + } + return getSetUpSessionMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static S2AServiceStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public S2AServiceStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceStub(channel, callOptions); + } + }; + return S2AServiceStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static S2AServiceBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public S2AServiceBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceBlockingStub(channel, callOptions); + } + }; + return S2AServiceBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static S2AServiceFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public S2AServiceFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceFutureStub(channel, callOptions); + } + }; + return S2AServiceFutureStub.newStub(factory, channel); + } + + /** + */ + public interface AsyncService { + + /** + *
+     * SetUpSession is a bidirectional stream used by applications to offload
+     * operations from the TLS handshake.
+     * 
+ */ + default io.grpc.stub.StreamObserver setUpSession( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(getSetUpSessionMethod(), responseObserver); + } + } + + /** + * Base class for the server implementation of the service S2AService. + */ + public static abstract class S2AServiceImplBase + implements io.grpc.BindableService, AsyncService { + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return S2AServiceGrpc.bindService(this); + } + } + + /** + * A stub to allow clients to do asynchronous rpc calls to service S2AService. + */ + public static final class S2AServiceStub + extends io.grpc.stub.AbstractAsyncStub { + private S2AServiceStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected S2AServiceStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceStub(channel, callOptions); + } + + /** + *
+     * SetUpSession is a bidirectional stream used by applications to offload
+     * operations from the TLS handshake.
+     * 
+ */ + public io.grpc.stub.StreamObserver setUpSession( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ClientCalls.asyncBidiStreamingCall( + getChannel().newCall(getSetUpSessionMethod(), getCallOptions()), responseObserver); + } + } + + /** + * A stub to allow clients to do synchronous rpc calls to service S2AService. + */ + public static final class S2AServiceBlockingStub + extends io.grpc.stub.AbstractBlockingStub { + private S2AServiceBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected S2AServiceBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceBlockingStub(channel, callOptions); + } + } + + /** + * A stub to allow clients to do ListenableFuture-style rpc calls to service S2AService. + */ + public static final class S2AServiceFutureStub + extends io.grpc.stub.AbstractFutureStub { + private S2AServiceFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected S2AServiceFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new S2AServiceFutureStub(channel, callOptions); + } + } + + private static final int METHODID_SET_UP_SESSION = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final AsyncService serviceImpl; + private final int methodId; + + MethodHandlers(AsyncService serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_SET_UP_SESSION: + return (io.grpc.stub.StreamObserver) serviceImpl.setUpSession( + (io.grpc.stub.StreamObserver) responseObserver); + default: + throw new AssertionError(); + } + } + } + + public static final io.grpc.ServerServiceDefinition bindService(AsyncService service) { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getSetUpSessionMethod(), + io.grpc.stub.ServerCalls.asyncBidiStreamingCall( + new MethodHandlers< + io.grpc.s2a.handshaker.SessionReq, + io.grpc.s2a.handshaker.SessionResp>( + service, METHODID_SET_UP_SESSION))) + .build(); + } + + private static abstract class S2AServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + S2AServiceBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.grpc.s2a.handshaker.S2AProto.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("S2AService"); + } + } + + private static final class S2AServiceFileDescriptorSupplier + extends S2AServiceBaseDescriptorSupplier { + S2AServiceFileDescriptorSupplier() {} + } + + private static final class S2AServiceMethodDescriptorSupplier + extends S2AServiceBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final java.lang.String methodName; + + S2AServiceMethodDescriptorSupplier(java.lang.String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (S2AServiceGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new S2AServiceFileDescriptorSupplier()) + .addMethod(getSetUpSessionMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java new file mode 100644 index 00000000000..b2aee6db49e --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; + +import io.grpc.ChannelCredentials; +import io.grpc.TlsChannelCredentials; +import io.grpc.util.AdvancedTlsX509KeyManager; +import io.grpc.util.AdvancedTlsX509TrustManager; +import java.io.File; +import java.io.IOException; +import java.security.GeneralSecurityException; + +/** + * Configures an {@code S2AChannelCredentials.Builder} instance with credentials used to establish a + * connection with the S2A to support talking to the S2A over mTLS. + */ +public final class MtlsToS2AChannelCredentials { + /** + * Creates a {@code S2AChannelCredentials.Builder} builder, that talks to the S2A over mTLS. + * + * @param s2aAddress the address of the S2A server used to secure the connection. + * @param privateKeyPath the path to the private key PEM to use for authenticating to the S2A. + * @param certChainPath the path to the cert chain PEM to use for authenticating to the S2A. + * @param trustBundlePath the path to the trust bundle PEM. + * @return a {@code MtlsToS2AChannelCredentials.Builder} instance. + */ + public static Builder createBuilder( + String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) { + checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + checkArgument(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty."); + checkArgument(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty."); + checkArgument(!isNullOrEmpty(trustBundlePath), "trustBundlePath must not be null or empty."); + return new Builder(s2aAddress, privateKeyPath, certChainPath, trustBundlePath); + } + + /** Builds an {@code MtlsToS2AChannelCredentials} instance. */ + public static final class Builder { + private final String s2aAddress; + private final String privateKeyPath; + private final String certChainPath; + private final String trustBundlePath; + + Builder( + String s2aAddress, String privateKeyPath, String certChainPath, String trustBundlePath) { + this.s2aAddress = s2aAddress; + this.privateKeyPath = privateKeyPath; + this.certChainPath = certChainPath; + this.trustBundlePath = trustBundlePath; + } + + public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IOException { + checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + checkState(!isNullOrEmpty(privateKeyPath), "privateKeyPath must not be null or empty."); + checkState(!isNullOrEmpty(certChainPath), "certChainPath must not be null or empty."); + checkState(!isNullOrEmpty(trustBundlePath), "trustBundlePath must not be null or empty."); + File privateKeyFile = new File(privateKeyPath); + File certChainFile = new File(certChainPath); + File trustBundleFile = new File(trustBundlePath); + + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + keyManager.updateIdentityCredentialsFromFile(privateKeyFile, certChainFile); + + AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); + trustManager.updateTrustCredentialsFromFile(trustBundleFile); + + ChannelCredentials channelToS2ACredentials = + TlsChannelCredentials.newBuilder() + .keyManager(keyManager) + .trustManager(trustManager) + .build(); + + return S2AChannelCredentials.createBuilder(s2aAddress) + .setS2AChannelCredentials(channelToS2ACredentials); + } + } + + private MtlsToS2AChannelCredentials() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java new file mode 100644 index 00000000000..4ad05b4541a --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Strings.isNullOrEmpty; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.netty.InternalNettyChannelCredentials; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Configures gRPC to use S2A for transport security when establishing a secure channel. Only for + * use on the client side of a gRPC connection. + */ +@NotThreadSafe +public final class S2AChannelCredentials { + /** + * Creates a channel credentials builder for establishing an S2A-secured connection. + * + * @param s2aAddress the address of the S2A server used to secure the connection. + * @return a {@code S2AChannelCredentials.Builder} instance. + */ + public static Builder createBuilder(String s2aAddress) { + checkArgument(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + return new Builder(s2aAddress); + } + + /** Builds an {@code S2AChannelCredentials} instance. */ + public static final class Builder { + private final String s2aAddress; + private ObjectPool s2aChannelPool; + private Optional s2aChannelCredentials; + private @Nullable S2AIdentity localIdentity = null; + + Builder(String s2aAddress) { + this.s2aAddress = s2aAddress; + this.s2aChannelPool = null; + this.s2aChannelCredentials = Optional.empty(); + } + + /** + * Sets the local identity of the client in the form of a SPIFFE ID. The client may set at most + * 1 local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalSpiffeId(String localSpiffeId) { + checkNotNull(localSpiffeId); + localIdentity = S2AIdentity.fromSpiffeId(localSpiffeId); + return this; + } + + /** + * Sets the local identity of the client in the form of a hostname. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalHostname(String localHostname) { + checkNotNull(localHostname); + localIdentity = S2AIdentity.fromHostname(localHostname); + return this; + } + + /** + * Sets the local identity of the client in the form of a UID. The client may set at most 1 + * local identity. If no local identity is specified, then the S2A chooses a default local + * identity, if one exists. + */ + @CanIgnoreReturnValue + public Builder setLocalUid(String localUid) { + checkNotNull(localUid); + localIdentity = S2AIdentity.fromUid(localUid); + return this; + } + + /** Sets the credentials to be used when connecting to the S2A. */ + @CanIgnoreReturnValue + public Builder setS2AChannelCredentials(ChannelCredentials s2aChannelCredentials) { + this.s2aChannelCredentials = Optional.of(s2aChannelCredentials); + return this; + } + + public ChannelCredentials build() { + checkState(!isNullOrEmpty(s2aAddress), "S2A address must not be null or empty."); + ObjectPool s2aChannelPool = + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials)); + checkNotNull(s2aChannelPool, "s2aChannelPool"); + this.s2aChannelPool = s2aChannelPool; + return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory()); + } + + InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() { + if (localIdentity == null) { + return S2AProtocolNegotiatorFactory.createClientFactory(Optional.empty(), s2aChannelPool); + } else { + return S2AProtocolNegotiatorFactory.createClientFactory( + Optional.of(localIdentity), s2aChannelPool); + } + } + } + + private S2AChannelCredentials() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java new file mode 100644 index 00000000000..e0501e91c66 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java @@ -0,0 +1,43 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.grpc.Channel; +import javax.annotation.concurrent.ThreadSafe; + +/** Manages a channel pool to be used for communication with the S2A. */ +@ThreadSafe +public interface S2AChannelPool extends AutoCloseable { + /** + * Retrieves an open channel to the S2A from the channel pool. + * + *

If no channel is available, blocks until a channel can be retrieved from the channel pool. + */ + @CanIgnoreReturnValue + Channel getChannel(); + + /** Returns a channel to the channel pool. */ + void returnChannel(Channel channel); + + /** + * Returns all channels to the channel pool and closes the pool so that no new channels can be + * retrieved from the pool. + */ + @Override + void close(); +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java new file mode 100644 index 00000000000..1d1de28e64e --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java @@ -0,0 +1,112 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.Channel; +import io.grpc.internal.ObjectPool; +import javax.annotation.concurrent.ThreadSafe; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Manages a gRPC channel pool and a cached gRPC channel to be used for communication with the S2A. + */ +@ThreadSafe +public final class S2AGrpcChannelPool implements S2AChannelPool { + private static final int MAX_NUMBER_USERS_OF_CACHED_CHANNEL = 100000; + private final ObjectPool channelPool; + + @GuardedBy("this") + private @Nullable Channel cachedChannel; + + @GuardedBy("this") + private int numberOfUsersOfCachedChannel = 0; + + private enum State { + OPEN, + CLOSED, + } + + ; + + @GuardedBy("this") + private State state = State.OPEN; + + public static S2AChannelPool create(ObjectPool channelPool) { + checkNotNull(channelPool, "Channel pool should not be null."); + return new S2AGrpcChannelPool(channelPool); + } + + private S2AGrpcChannelPool(ObjectPool channelPool) { + this.channelPool = channelPool; + } + + /** + * Retrieves a channel from {@code channelPool} if {@code channel} is null, and returns {@code + * channel} otherwise. + * + * @return a {@link Channel} obtained from the channel pool. + */ + @Override + public synchronized Channel getChannel() { + checkState(state.equals(State.OPEN), "Channel pool is not open."); + checkState( + numberOfUsersOfCachedChannel >= 0, + "Number of users of cached channel must be non-negative."); + checkState( + numberOfUsersOfCachedChannel < MAX_NUMBER_USERS_OF_CACHED_CHANNEL, + "Max number of channels have been retrieved from the channel pool."); + if (cachedChannel == null) { + cachedChannel = channelPool.getObject(); + } + numberOfUsersOfCachedChannel += 1; + return cachedChannel; + } + + /** + * Returns {@code channel} to {@code channelPool}. + * + *

The caller must ensure that {@code channel} was retrieved from this channel pool. + */ + @Override + public synchronized void returnChannel(Channel channel) { + checkState(state.equals(State.OPEN), "Channel pool is not open."); + checkArgument( + cachedChannel != null && numberOfUsersOfCachedChannel > 0 && cachedChannel.equals(channel), + "Cannot return the channel to channel pool because the channel was not obtained from" + + " channel pool."); + numberOfUsersOfCachedChannel -= 1; + if (numberOfUsersOfCachedChannel == 0) { + channelPool.returnObject(channel); + cachedChannel = null; + } + } + + @Override + public synchronized void close() { + state = State.CLOSED; + numberOfUsersOfCachedChannel = 0; + if (cachedChannel != null) { + channelPool.returnObject(cachedChannel); + cachedChannel = null; + } + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java new file mode 100644 index 00000000000..75ec7347bb5 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AHandshakerServiceChannel.java @@ -0,0 +1,195 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Maps; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.concurrent.DefaultThreadFactory; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.ConcurrentMap; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Provides APIs for managing gRPC channels to S2A servers. Each channel is local and plaintext. If + * credentials are provided, they are used to secure the channel. + * + *

This is done as follows: for each S2A server, provides an implementation of gRPC's {@link + * SharedResourceHolder.Resource} interface called a {@code Resource}. A {@code + * Resource} is a factory for creating gRPC channels to the S2A server at a given address, + * and a channel must be returned to the {@code Resource} when it is no longer needed. + * + *

Typical usage pattern is below: + * + *

{@code
+ * Resource resource = S2AHandshakerServiceChannel.getChannelResource("localhost:1234",
+ * creds);
+ * Channel channel = resource.create();
+ * // Send an RPC over the channel to the S2A server running at localhost:1234.
+ * resource.close(channel);
+ * }
+ */ +@ThreadSafe +public final class S2AHandshakerServiceChannel { + private static final ConcurrentMap> SHARED_RESOURCE_CHANNELS = + Maps.newConcurrentMap(); + private static final Duration DELEGATE_TERMINATION_TIMEOUT = Duration.ofSeconds(2); + private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); + + /** + * Returns a {@link SharedResourceHolder.Resource} instance for managing channels to an S2A server + * running at {@code s2aAddress}. + * + * @param s2aAddress the address of the S2A, typically in the format {@code host:port}. + * @param s2aChannelCredentials the credentials to use when establishing a connection to the S2A. + * @return a {@link ChannelResource} instance that manages a {@link Channel} to the S2A server + * running at {@code s2aAddress}. + */ + public static Resource getChannelResource( + String s2aAddress, Optional s2aChannelCredentials) { + checkNotNull(s2aAddress); + return SHARED_RESOURCE_CHANNELS.computeIfAbsent( + s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials)); + } + + /** + * Defines how to create and destroy a {@link Channel} instance that uses shared resources. A + * channel created by {@code ChannelResource} is a plaintext, local channel to the service running + * at {@code targetAddress}. + */ + private static class ChannelResource implements Resource { + private final String targetAddress; + private final Optional channelCredentials; + + public ChannelResource(String targetAddress, Optional channelCredentials) { + this.targetAddress = targetAddress; + this.channelCredentials = channelCredentials; + } + + /** + * Creates a {@code EventLoopHoldingChannel} instance to the service running at {@code + * targetAddress}. This channel uses a dedicated thread pool for its {@code EventLoopGroup} + * instance to avoid blocking. + */ + @Override + public Channel create() { + EventLoopGroup eventLoopGroup = + new NioEventLoopGroup(1, new DefaultThreadFactory("S2A channel pool", true)); + ManagedChannel channel = null; + if (channelCredentials.isPresent()) { + // Create a secure channel. + channel = + NettyChannelBuilder.forTarget(targetAddress, channelCredentials.get()) + .channelType(NioSocketChannel.class) + .directExecutor() + .eventLoopGroup(eventLoopGroup) + .build(); + } else { + // Create a plaintext channel. + channel = + NettyChannelBuilder.forTarget(targetAddress) + .channelType(NioSocketChannel.class) + .directExecutor() + .eventLoopGroup(eventLoopGroup) + .usePlaintext() + .build(); + } + return EventLoopHoldingChannel.create(channel, eventLoopGroup); + } + + /** Destroys a {@code EventLoopHoldingChannel} instance. */ + @Override + public void close(Channel instanceChannel) { + checkNotNull(instanceChannel); + EventLoopHoldingChannel channel = (EventLoopHoldingChannel) instanceChannel; + channel.close(); + } + + @Override + public String toString() { + return "grpc-s2a-channel"; + } + } + + /** + * Manages a channel using a {@link ManagedChannel} instance that belong to the {@code + * EventLoopGroup} thread pool. + */ + @VisibleForTesting + static class EventLoopHoldingChannel extends Channel { + private final ManagedChannel delegate; + private final EventLoopGroup eventLoopGroup; + + static EventLoopHoldingChannel create(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + checkNotNull(delegate); + checkNotNull(eventLoopGroup); + return new EventLoopHoldingChannel(delegate, eventLoopGroup); + } + + private EventLoopHoldingChannel(ManagedChannel delegate, EventLoopGroup eventLoopGroup) { + this.delegate = delegate; + this.eventLoopGroup = eventLoopGroup; + } + + /** + * Returns the address of the service to which the {@code delegate} channel connects, which is + * typically of the form {@code host:port}. + */ + @Override + public String authority() { + return delegate.authority(); + } + + /** Creates a {@link ClientCall} that invokes the operations in {@link MethodDescriptor}. */ + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions options) { + return delegate.newCall(methodDescriptor, options); + } + + @SuppressWarnings("FutureReturnValueIgnored") + public void close() { + delegate.shutdownNow(); + boolean isDelegateTerminated; + try { + isDelegateTerminated = + delegate.awaitTermination(DELEGATE_TERMINATION_TIMEOUT.getSeconds(), SECONDS); + } catch (InterruptedException e) { + isDelegateTerminated = false; + } + long quietPeriodSeconds = isDelegateTerminated ? 0 : 1; + eventLoopGroup.shutdownGracefully( + quietPeriodSeconds, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); + } + } + + private S2AHandshakerServiceChannel() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java new file mode 100644 index 00000000000..1f9b2d5a23a --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import java.io.IOException; + +/** Indicates that a connection has been closed. */ +@SuppressWarnings("serial") // This class is never serialized. +final class ConnectionIsClosedException extends IOException { + public ConnectionIsClosedException(String errorMessage) { + super(errorMessage); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java new file mode 100644 index 00000000000..3b17a5ed322 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java @@ -0,0 +1,60 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import com.google.errorprone.annotations.Immutable; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.tokenmanager.AccessTokenManager; +import java.util.Optional; + +/** Retrieves the authentication mechanism for a given local identity. */ +@Immutable +final class GetAuthenticationMechanisms { + private static final Optional TOKEN_MANAGER = AccessTokenManager.create(); + + /** + * Retrieves the authentication mechanism for a given local identity. + * + * @param localIdentity the identity for which to fetch a token. + * @return an {@link AuthenticationMechanism} for the given local identity. + */ + static Optional getAuthMechanism(Optional localIdentity) { + Optional authMechanism = Optional.empty(); + if (!TOKEN_MANAGER.isPresent()) { + return Optional.empty(); + } + AccessTokenManager manager = TOKEN_MANAGER.get(); + // If no identity is provided, fetch the default access token and DO NOT attach an identity + // to the request. + if (!localIdentity.isPresent()) { + authMechanism = + Optional.of( + AuthenticationMechanism.newBuilder().setToken(manager.getDefaultToken()).build()); + } else { + // Fetch an access token for the provided identity. + authMechanism = + Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(localIdentity.get().identity()) + .setToken(manager.getToken(localIdentity.get())) + .build()); + } + return authMechanism; + } + + private GetAuthenticationMechanisms() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java new file mode 100644 index 00000000000..34cc4bbe737 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java @@ -0,0 +1,72 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +/** Converts proto messages to Netty strings. */ +final class ProtoUtil { + /** + * Converts {@link Ciphersuite} to its {@link String} representation. + * + * @param ciphersuite the {@link Ciphersuite} to be converted. + * @return a {@link String} representing the ciphersuite. + * @throws AssertionError if the {@link Ciphersuite} is not one of the supported ciphersuites. + */ + static String convertCiphersuite(Ciphersuite ciphersuite) { + switch (ciphersuite) { + case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: + return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"; + case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: + return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"; + case CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256: + return "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"; + case CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256: + return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; + case CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384: + return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + case CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: + return "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"; + default: + throw new AssertionError( + String.format("Ciphersuite %d is not supported.", ciphersuite.getNumber())); + } + } + + /** + * Converts a {@link TLSVersion} object to its {@link String} representation. + * + * @param tlsVersion the {@link TLSVersion} object to be converted. + * @return a {@link String} representation of the TLS version. + * @throws AssertionError if the {@code tlsVersion} is not one of the supported TLS versions. + */ + static String convertTlsProtocolVersion(TLSVersion tlsVersion) { + switch (tlsVersion) { + case TLS_VERSION_1_3: + return "TLSv1.3"; + case TLS_VERSION_1_2: + return "TLSv1.2"; + case TLS_VERSION_1_1: + return "TLSv1.1"; + case TLS_VERSION_1_0: + return "TLSv1"; + default: + throw new AssertionError( + String.format("TLS version %d is not supported.", tlsVersion.getNumber())); + } + } + + private ProtoUtil() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java new file mode 100644 index 00000000000..d976308ad22 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +/** Exception that denotes a runtime error that was encountered when talking to the S2A server. */ +@SuppressWarnings("serial") // This class is never serialized. +public class S2AConnectionException extends RuntimeException { + S2AConnectionException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java new file mode 100644 index 00000000000..30957acd521 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.errorprone.annotations.ThreadSafe; + +/** + * Stores an identity in such a way that it can be sent to the S2A handshaker service. The identity + * may be formatted as a SPIFFE ID or as a hostname. + */ +@ThreadSafe +public final class S2AIdentity { + private final Identity identity; + + /** Returns an {@link S2AIdentity} instance with SPIFFE ID set to {@code spiffeId}. */ + public static S2AIdentity fromSpiffeId(String spiffeId) { + checkNotNull(spiffeId); + return new S2AIdentity(Identity.newBuilder().setSpiffeId(spiffeId).build()); + } + + /** Returns an {@link S2AIdentity} instance with hostname set to {@code hostname}. */ + public static S2AIdentity fromHostname(String hostname) { + checkNotNull(hostname); + return new S2AIdentity(Identity.newBuilder().setHostname(hostname).build()); + } + + /** Returns an {@link S2AIdentity} instance with UID set to {@code uid}. */ + public static S2AIdentity fromUid(String uid) { + checkNotNull(uid); + return new S2AIdentity(Identity.newBuilder().setUid(uid).build()); + } + + /** Returns an {@link S2AIdentity} instance with {@code identity} set. */ + public static S2AIdentity fromIdentity(Identity identity) { + return new S2AIdentity(identity == null ? Identity.getDefaultInstance() : identity); + } + + private S2AIdentity(Identity identity) { + this.identity = identity; + } + + /** Returns the proto {@link Identity} representation of this identity instance. */ + public Identity identity() { + return identity; + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java new file mode 100644 index 00000000000..fb4908d99fc --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java @@ -0,0 +1,143 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.ByteString; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import java.io.IOException; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.SSLEngine; + +/** + * Handles requests on signing bytes with a private key designated by {@code stub}. + * + *

This is done by sending the to-be-signed bytes to an S2A server (designated by {@code stub}) + * and read the signature from the server. + * + *

OpenSSL libraries must be appropriately initialized before using this class. One possible way + * to initialize OpenSSL library is to call {@code + * GrpcSslContexts.configure(SslContextBuilder.forClient());}. + */ +@NotThreadSafe +final class S2APrivateKeyMethod implements OpenSslPrivateKeyMethod { + private final S2AStub stub; + private final Optional localIdentity; + private static final ImmutableMap + OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP = + ImmutableMap.of( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512, + SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + + public static S2APrivateKeyMethod create(S2AStub stub, Optional localIdentity) { + checkNotNull(stub); + return new S2APrivateKeyMethod(stub, localIdentity); + } + + private S2APrivateKeyMethod(S2AStub stub, Optional localIdentity) { + this.stub = stub; + this.localIdentity = localIdentity; + } + + /** + * Converts the signature algorithm to an enum understood by S2A. + * + * @param signatureAlgorithm the int representation of the signature algorithm define by {@code + * OpenSslPrivateKeyMethod}. + * @return the signature algorithm enum defined by S2A proto. + * @throws UnsupportedOperationException if the algorithm is not supported by S2A. + */ + @VisibleForTesting + static SignatureAlgorithm convertOpenSslSignAlgToS2ASignAlg(int signatureAlgorithm) { + SignatureAlgorithm sig = OPENSSL_TO_S2A_SIGNATURE_ALGORITHM_MAP.get(signatureAlgorithm); + if (sig == null) { + throw new UnsupportedOperationException( + String.format("Signature Algorithm %d is not supported.", signatureAlgorithm)); + } + return sig; + } + + /** + * Signs the input bytes by sending the request to the S2A srever. + * + * @param engine not used. + * @param signatureAlgorithm the {@link OpenSslPrivateKeyMethod}'s signature algorithm + * representation + * @param input the bytes to be signed. + * @return the signature of the {@code input}. + * @throws IOException if the connection to the S2A server is corrupted. + * @throws InterruptedException if the connection to the S2A server is interrupted. + * @throws S2AConnectionException if the response from the S2A server does not contain valid data. + */ + @Override + public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) + throws IOException, InterruptedException { + checkArgument(input.length > 0, "No bytes to sign."); + SignatureAlgorithm s2aSignatureAlgorithm = + convertOpenSslSignAlgToS2ASignAlg(signatureAlgorithm); + SessionReq.Builder reqBuilder = + SessionReq.newBuilder() + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(s2aSignatureAlgorithm) + .setRawBytes(ByteString.copyFrom(input))); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().identity()); + } + + SessionResp resp = stub.send(reqBuilder.build()); + + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: \"%s\".", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.hasOffloadPrivateKeyOperationResp()) { + throw new S2AConnectionException("No valid response received from S2A."); + } + return resp.getOffloadPrivateKeyOperationResp().getOutBytes().toByteArray(); + } + + @Override + public byte[] decrypt(SSLEngine engine, byte[] input) { + throw new UnsupportedOperationException("decrypt is not supported."); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java new file mode 100644 index 00000000000..7f00e198fae --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -0,0 +1,194 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.net.HostAndPort; +import com.google.errorprone.annotations.ThreadSafe; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.internal.ObjectPool; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler; +import io.grpc.s2a.channel.S2AChannelPool; +import io.grpc.s2a.channel.S2AGrpcChannelPool; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.ssl.SslContext; +import io.netty.util.AsciiString; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.util.Optional; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Factory for performing negotiation of a secure channel using the S2A. */ +@ThreadSafe +public final class S2AProtocolNegotiatorFactory { + @VisibleForTesting static final int DEFAULT_PORT = 443; + private static final AsciiString SCHEME = AsciiString.of("https"); + + /** + * Creates a {@code S2AProtocolNegotiatorFactory} configured for a client to establish secure + * connections using the S2A. + * + * @param localIdentity the identity of the client; if none is provided, the S2A will use the + * client's default identity. + * @param s2aChannelPool a pool of shared channels that can be used to connect to the S2A. + * @return a factory for creating a client-side protocol negotiator. + */ + public static InternalProtocolNegotiator.ClientFactory createClientFactory( + Optional localIdentity, ObjectPool s2aChannelPool) { + checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); + checkNotNull(localIdentity, "Local identity should not be null on the client side."); + S2AChannelPool channelPool = S2AGrpcChannelPool.create(s2aChannelPool); + return new S2AClientProtocolNegotiatorFactory(localIdentity, channelPool); + } + + static final class S2AClientProtocolNegotiatorFactory + implements InternalProtocolNegotiator.ClientFactory { + private final Optional localIdentity; + private final S2AChannelPool channelPool; + + S2AClientProtocolNegotiatorFactory( + Optional localIdentity, S2AChannelPool channelPool) { + this.localIdentity = localIdentity; + this.channelPool = channelPool; + } + + @Override + public ProtocolNegotiator newNegotiator() { + return S2AProtocolNegotiator.createForClient(channelPool, localIdentity); + } + + @Override + public int getDefaultPort() { + return DEFAULT_PORT; + } + } + + /** Negotiates the TLS handshake using S2A. */ + @VisibleForTesting + static final class S2AProtocolNegotiator implements ProtocolNegotiator { + + private final S2AChannelPool channelPool; + private final Optional localIdentity; + + static S2AProtocolNegotiator createForClient( + S2AChannelPool channelPool, Optional localIdentity) { + checkNotNull(channelPool, "Channel pool should not be null."); + checkNotNull(localIdentity, "Local identity should not be null on the client side."); + return new S2AProtocolNegotiator(channelPool, localIdentity); + } + + @VisibleForTesting + static @Nullable String getHostNameFromAuthority(@Nullable String authority) { + if (authority == null) { + return null; + } + return HostAndPort.fromString(authority).getHost(); + } + + private S2AProtocolNegotiator(S2AChannelPool channelPool, Optional localIdentity) { + this.channelPool = channelPool; + this.localIdentity = localIdentity; + } + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + checkNotNull(grpcHandler, "grpcHandler should not be null."); + String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); + checkNotNull(hostname, "hostname should not be null."); + return new S2AProtocolNegotiationHandler( + InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler), + grpcHandler.getNegotiationLogger(), + channelPool, + localIdentity, + hostname, + grpcHandler); + } + + @Override + public void close() { + channelPool.close(); + } + } + + private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { + private final S2AChannelPool channelPool; + private final Optional localIdentity; + private final String hostname; + private InternalProtocolNegotiator.ProtocolNegotiator negotiator; + private final GrpcHttp2ConnectionHandler grpcHandler; + + private S2AProtocolNegotiationHandler( + ChannelHandler next, + ChannelLogger negotiationLogger, + S2AChannelPool channelPool, + Optional localIdentity, + String hostname, + GrpcHttp2ConnectionHandler grpcHandler) { + super(next, negotiationLogger); + this.channelPool = channelPool; + this.localIdentity = localIdentity; + this.hostname = hostname; + this.grpcHandler = grpcHandler; + } + + @Override + protected void handlerAdded0(ChannelHandlerContext ctx) throws GeneralSecurityException { + SslContext sslContext; + try { + // Establish a stream to S2A server. + Channel ch = channelPool.getChannel(); + S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch); + S2AStub s2aStub = S2AStub.newInstance(stub); + sslContext = SslContextFactory.createForClient(s2aStub, hostname, localIdentity); + } catch (InterruptedException + | IOException + | IllegalArgumentException + | UnrecoverableKeyException + | CertificateException + | NoSuchAlgorithmException + | KeyStoreException e) { + // GeneralSecurityException is intentionally not caught, and rather propagated. This is done + // because throwing a GeneralSecurityException in this context indicates that we encountered + // a retryable error. + throw new IllegalArgumentException( + "Something went wrong during the initialization of SslContext.", e); + } + negotiator = InternalProtocolNegotiators.tls(sslContext); + ctx.pipeline().addBefore(ctx.name(), /* name= */ null, negotiator.newHandler(grpcHandler)); + } + } + + private S2AProtocolNegotiatorFactory() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java new file mode 100644 index 00000000000..aa2502cd4fa --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java @@ -0,0 +1,225 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Verify.verify; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Optional; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.concurrent.NotThreadSafe; + +/** Reads and writes messages to and from the S2A. */ +@NotThreadSafe +class S2AStub implements AutoCloseable { + private static final Logger logger = Logger.getLogger(S2AStub.class.getName()); + private static final long HANDSHAKE_RPC_DEADLINE_SECS = 20; + private final StreamObserver reader = new Reader(); + private final BlockingQueue responses = new ArrayBlockingQueue<>(10); + private S2AServiceGrpc.S2AServiceStub serviceStub; + private StreamObserver writer; + private boolean doneReading = false; + private boolean doneWriting = false; + + static S2AStub newInstance(S2AServiceGrpc.S2AServiceStub serviceStub) { + checkNotNull(serviceStub); + return new S2AStub(serviceStub); + } + + @VisibleForTesting + static S2AStub newInstanceForTesting(StreamObserver writer) { + checkNotNull(writer); + return new S2AStub(writer); + } + + private S2AStub(S2AServiceGrpc.S2AServiceStub serviceStub) { + this.serviceStub = serviceStub; + } + + private S2AStub(StreamObserver writer) { + this.writer = writer; + } + + @VisibleForTesting + StreamObserver getReader() { + return reader; + } + + @VisibleForTesting + BlockingQueue getResponses() { + return responses; + } + + /** + * Sends a request and returns the response. Caller must wait until this method executes prior to + * calling it again. If this method throws {@code ConnectionIsClosedException}, then it should not + * be called again, and both {@code reader} and {@code writer} are closed. + * + * @param req the {@code SessionReq} message to be sent to the S2A server. + * @return the {@code SessionResp} message received from the S2A server. + * @throws ConnectionIsClosedException if {@code reader} or {@code writer} calls their {@code + * onCompleted} method. + * @throws IOException if an unexpected response is received, or if the {@code reader} or {@code + * writer} calls their {@code onError} method. + */ + public SessionResp send(SessionReq req) throws IOException, InterruptedException { + if (doneWriting && doneReading) { + logger.log(Level.INFO, "Stream to the S2A is closed."); + throw new ConnectionIsClosedException("Stream to the S2A is closed."); + } + createWriterIfNull(); + if (!responses.isEmpty()) { + IOException exception = null; + SessionResp resp = null; + try { + resp = responses.take().getResultOrThrow(); + } catch (IOException e) { + exception = e; + } + responses.clear(); + if (exception != null) { + logger.log( + Level.WARNING, + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable. " + + exception.getMessage()); + throw new IOException( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable." + + exception.getMessage()); + } + return resp; + } + try { + writer.onNext(req); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Error occurred while writing to the S2A.", e); + writer.onError(e); + responses.offer(Result.createWithThrowable(e)); + } + try { + return responses.take().getResultOrThrow(); + } catch (ConnectionIsClosedException e) { + // A ConnectionIsClosedException is thrown by getResultOrThrow when reader calls its + // onCompleted method. The close method is called to also close the writer, and then the + // ConnectionIsClosedException is re-thrown in order to indicate to the caller that send + // should not be called again. + close(); + throw e; + } + } + + @Override + public void close() { + if (doneWriting && doneReading) { + return; + } + verify(!doneWriting); + doneReading = true; + doneWriting = true; + if (writer != null) { + writer.onCompleted(); + } + } + + /** Create a new writer if the writer is null. */ + private void createWriterIfNull() { + if (writer == null) { + writer = + serviceStub.withDeadlineAfter(HANDSHAKE_RPC_DEADLINE_SECS, SECONDS).setUpSession(reader); + } + } + + private class Reader implements StreamObserver { + /** + * Places a {@code SessionResp} message in the {@code responses} queue, or an {@code + * IOException} if reading is complete. + * + * @param resp the {@code SessionResp} message received from the S2A handshaker module. + */ + @Override + public void onNext(SessionResp resp) { + verify(!doneReading); + responses.offer(Result.createWithResponse(resp)); + } + + /** + * Places a {@code Throwable} in the {@code responses} queue. + * + * @param t the {@code Throwable} caught when reading the stream to the S2A handshaker module. + */ + @Override + public void onError(Throwable t) { + logger.log(Level.WARNING, "Error occurred while reading from the S2A.", t); + responses.offer(Result.createWithThrowable(t)); + } + + /** + * Sets {@code doneReading} to true, and places a {@code ConnectionIsClosedException} in the + * {@code responses} queue. + */ + @Override + public void onCompleted() { + logger.log(Level.INFO, "Reading from the S2A is complete."); + doneReading = true; + responses.offer( + Result.createWithThrowable( + new ConnectionIsClosedException("Reading from the S2A is complete."))); + } + } + + private static final class Result { + private final Optional response; + private final Optional throwable; + + static Result createWithResponse(SessionResp response) { + return new Result(Optional.of(response), Optional.empty()); + } + + static Result createWithThrowable(Throwable throwable) { + return new Result(Optional.empty(), Optional.of(throwable)); + } + + private Result(Optional response, Optional throwable) { + checkArgument(response.isPresent() != throwable.isPresent()); + this.response = response; + this.throwable = throwable; + } + + /** Throws {@code throwable} if present, and returns {@code response} otherwise. */ + SessionResp getResultOrThrow() throws IOException { + if (throwable.isPresent()) { + if (throwable.get() instanceof ConnectionIsClosedException) { + ConnectionIsClosedException exception = (ConnectionIsClosedException) throwable.get(); + throw exception; + } else { + throw new IOException(throwable.get()); + } + } + verify(response.isPresent()); + return response.get(); + } + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java new file mode 100644 index 00000000000..014fcf4c4f8 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java @@ -0,0 +1,152 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.ValidatePeerCertificateChainReq.VerificationMode; +import java.io.IOException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.annotation.concurrent.NotThreadSafe; +import javax.net.ssl.X509TrustManager; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** Offloads verification of the peer certificate chain to S2A. */ +@NotThreadSafe +final class S2ATrustManager implements X509TrustManager { + private final Optional localIdentity; + private final S2AStub stub; + private final String hostname; + + static S2ATrustManager createForClient( + S2AStub stub, String hostname, Optional localIdentity) { + checkNotNull(stub); + checkNotNull(hostname); + return new S2ATrustManager(stub, hostname, localIdentity); + } + + private S2ATrustManager(S2AStub stub, String hostname, Optional localIdentity) { + this.stub = stub; + this.hostname = hostname; + this.localIdentity = localIdentity; + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ true); + } + + /** + * Validates the given certificate chain provided by the peer. + * + * @param chain the peer certificate chain + * @param authType the authentication type based on the client certificate + * @throws IllegalArgumentException if null or zero-length chain is passed in for the chain + * parameter. + * @throws CertificateException if the certificate chain is not trusted by this TrustManager. + */ + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + checkPeerTrusted(chain, /* isCheckingClientCertificateChain= */ false); + } + + /** + * Returns null because the accepted issuers are held in S2A and this class receives decision made + * from S2A on the fly about which to use to verify a given chain. + * + * @return null. + */ + @Override + public X509Certificate @Nullable [] getAcceptedIssuers() { + return null; + } + + private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientCertificateChain) + throws CertificateException { + checkNotNull(chain); + checkArgument(chain.length > 0, "Certificate chain has zero certificates."); + + ValidatePeerCertificateChainReq.Builder validatePeerCertificateChainReq = + ValidatePeerCertificateChainReq.newBuilder().setMode(VerificationMode.UNSPECIFIED); + if (isCheckingClientCertificateChain) { + validatePeerCertificateChainReq.setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain))); + } else { + validatePeerCertificateChainReq.setServerPeer( + ValidatePeerCertificateChainReq.ServerPeer.newBuilder() + .addAllCertificateChain(certificateChainToDerChain(chain)) + .setServerHostname(hostname)); + } + + SessionReq.Builder reqBuilder = + SessionReq.newBuilder().setValidatePeerCertificateChainReq(validatePeerCertificateChainReq); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().identity()); + } + + SessionResp resp; + try { + resp = stub.send(reqBuilder.build()); + } catch (IOException | InterruptedException e) { + throw new CertificateException("Failed to send request to S2A.", e); + } + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new CertificateException( + String.format( + "Error occurred in response from S2A, error code: %d, error message: %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + + if (!resp.hasValidatePeerCertificateChainResp()) { + throw new CertificateException("No valid response received from S2A."); + } + + ValidatePeerCertificateChainResp validationResult = resp.getValidatePeerCertificateChainResp(); + if (validationResult.getValidationResult() + != ValidatePeerCertificateChainResp.ValidationResult.SUCCESS) { + throw new CertificateException(validationResult.getValidationDetails()); + } + } + + private static ImmutableList certificateChainToDerChain(X509Certificate[] chain) + throws CertificateEncodingException { + ImmutableList.Builder derChain = ImmutableList.builder(); + for (X509Certificate certificate : chain) { + derChain.add(ByteString.copyFrom(certificate.getEncoded())); + } + return derChain.build(); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java new file mode 100644 index 00000000000..bfa45146625 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java @@ -0,0 +1,179 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.base.Preconditions.checkNotNull; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableList; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslContextOption; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.OpenSslX509KeyManagerFactory; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLSessionContext; + +/** Creates {@link SslContext} objects with TLS configurations from S2A server. */ +final class SslContextFactory { + + /** + * Creates {@link SslContext} objects for client with TLS configurations from S2A server. + * + * @param stub the {@link S2AStub} to talk to the S2A server. + * @param targetName the {@link String} of the server that this client makes connection to. + * @param localIdentity the {@link S2AIdentity} that should be used when talking to S2A server. + * Will use default identity if empty. + * @return a {@link SslContext} object. + * @throws NullPointerException if either {@code stub} or {@code targetName} is null. + * @throws IOException if an unexpected response from S2A server is received. + * @throws InterruptedException if {@code stub} is closed. + */ + static SslContext createForClient( + S2AStub stub, String targetName, Optional localIdentity) + throws IOException, + InterruptedException, + CertificateException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException, + GeneralSecurityException { + checkNotNull(stub, "stub should not be null."); + checkNotNull(targetName, "targetName should not be null on client side."); + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration; + try { + clientTlsConfiguration = getClientTlsConfigurationFromS2A(stub, localIdentity); + } catch (IOException | InterruptedException e) { + throw new GeneralSecurityException("Failed to get client TLS configuration from S2A.", e); + } + + // Use the default value for timeout. + // Use the smallest possible value for cache size. + // The Provider is by default OPENSSL. No need to manually set it. + SslContextBuilder sslContextBuilder = + GrpcSslContexts.configure(SslContextBuilder.forClient()) + .sessionCacheSize(1) + .sessionTimeout(0); + + configureSslContextWithClientTlsConfiguration(clientTlsConfiguration, sslContextBuilder); + sslContextBuilder.trustManager( + S2ATrustManager.createForClient(stub, targetName, localIdentity)); + sslContextBuilder.option( + OpenSslContextOption.PRIVATE_KEY_METHOD, S2APrivateKeyMethod.create(stub, localIdentity)); + + SslContext sslContext = sslContextBuilder.build(); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + openSslSessionContext.setSessionCacheEnabled(false); + } + + return sslContext; + } + + private static GetTlsConfigurationResp.ClientTlsConfiguration getClientTlsConfigurationFromS2A( + S2AStub stub, Optional localIdentity) throws IOException, InterruptedException { + checkNotNull(stub, "stub should not be null."); + SessionReq.Builder reqBuilder = SessionReq.newBuilder(); + if (localIdentity.isPresent()) { + reqBuilder.setLocalIdentity(localIdentity.get().identity()); + } + Optional authMechanism = + GetAuthenticationMechanisms.getAuthMechanism(localIdentity); + if (authMechanism.isPresent()) { + reqBuilder.addAuthenticationMechanisms(authMechanism.get()); + } + SessionResp resp = + stub.send( + reqBuilder + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + if (resp.hasStatus() && resp.getStatus().getCode() != 0) { + throw new S2AConnectionException( + String.format( + "response from S2A server has ean error %d with error message %s.", + resp.getStatus().getCode(), resp.getStatus().getDetails())); + } + if (!resp.getGetTlsConfigurationResp().hasClientTlsConfiguration()) { + throw new S2AConnectionException( + "Response from S2A server does NOT contain ClientTlsConfiguration."); + } + return resp.getGetTlsConfigurationResp().getClientTlsConfiguration(); + } + + private static void configureSslContextWithClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration, + SslContextBuilder sslContextBuilder) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + sslContextBuilder.keyManager(createKeylessManager(clientTlsConfiguration)); + sslContextBuilder.protocols( + ProtoUtil.convertTlsProtocolVersion(clientTlsConfiguration.getMinTlsVersion()), + ProtoUtil.convertTlsProtocolVersion(clientTlsConfiguration.getMaxTlsVersion())); + ImmutableList.Builder ciphersuites = ImmutableList.builder(); + for (int i = 0; i < clientTlsConfiguration.getCiphersuitesCount(); ++i) { + ciphersuites.add(ProtoUtil.convertCiphersuite(clientTlsConfiguration.getCiphersuites(i))); + } + sslContextBuilder.ciphers(ciphersuites.build()); + } + + private static KeyManager createKeylessManager( + GetTlsConfigurationResp.ClientTlsConfiguration clientTlsConfiguration) + throws CertificateException, + IOException, + KeyStoreException, + NoSuchAlgorithmException, + UnrecoverableKeyException { + X509Certificate[] certificates = + new X509Certificate[clientTlsConfiguration.getCertificateChainCount()]; + for (int i = 0; i < clientTlsConfiguration.getCertificateChainCount(); ++i) { + certificates[i] = convertStringToX509Cert(clientTlsConfiguration.getCertificateChain(i)); + } + KeyManager[] keyManagers = + OpenSslX509KeyManagerFactory.newKeyless(certificates).getKeyManagers(); + if (keyManagers == null || keyManagers.length == 0) { + throw new IllegalStateException("No key managers created."); + } + return keyManagers[0]; + } + + private static X509Certificate convertStringToX509Cert(String certificate) + throws CertificateException { + return (X509Certificate) + CertificateFactory.getInstance("X509") + .generateCertificate(new ByteArrayInputStream(certificate.getBytes(UTF_8))); + } + + private SslContextFactory() {} +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java new file mode 100644 index 00000000000..94549d11c87 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/AccessTokenManager.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import io.grpc.s2a.handshaker.S2AIdentity; +import java.lang.reflect.Method; +import java.util.Optional; +import javax.annotation.concurrent.ThreadSafe; + +/** Manages access tokens for authenticating to the S2A. */ +@ThreadSafe +public final class AccessTokenManager { + private final TokenFetcher tokenFetcher; + + /** Creates an {@code AccessTokenManager} based on the environment where the application runs. */ + @SuppressWarnings("RethrowReflectiveOperationExceptionAsLinkageError") + public static Optional create() { + Optional tokenFetcher; + try { + Class singleTokenFetcherClass = + Class.forName("io.grpc.s2a.handshaker.tokenmanager.SingleTokenFetcher"); + Method createTokenFetcher = singleTokenFetcherClass.getMethod("create"); + tokenFetcher = (Optional) createTokenFetcher.invoke(null); + } catch (ClassNotFoundException e) { + tokenFetcher = Optional.empty(); + } catch (ReflectiveOperationException e) { + throw new AssertionError(e); + } + return tokenFetcher.isPresent() + ? Optional.of(new AccessTokenManager((TokenFetcher) tokenFetcher.get())) + : Optional.empty(); + } + + private AccessTokenManager(TokenFetcher tokenFetcher) { + this.tokenFetcher = tokenFetcher; + } + + /** Returns an access token when no identity is specified. */ + public String getDefaultToken() { + return tokenFetcher.getDefaultToken(); + } + + /** Returns an access token for the given identity. */ + public String getToken(S2AIdentity identity) { + return tokenFetcher.getToken(identity); + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java new file mode 100644 index 00000000000..3b2bd051e84 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import com.beust.jcommander.Parameter; +import com.beust.jcommander.Parameters; +import io.grpc.s2a.handshaker.S2AIdentity; +import java.util.Optional; + +/** Fetches a single access token via an environment variable. */ +public final class SingleTokenFetcher implements TokenFetcher { + private static final String ENVIRONMENT_VARIABLE = "S2A_ACCESS_TOKEN"; + + /** Set an access token via a flag. */ + @Parameters(separators = "=") + public static class Flags { + @Parameter( + names = "--s2a_access_token", + description = "The access token used to authenticate to S2A.") + private static String accessToken = System.getenv(ENVIRONMENT_VARIABLE); + + public synchronized void reset() { + accessToken = null; + } + } + + private final String token; + + /** + * Creates a {@code SingleTokenFetcher} from {@code ENVIRONMENT_VARIABLE}, and returns an empty + * {@code Optional} instance if the token could not be fetched. + */ + public static Optional create() { + return Optional.ofNullable(Flags.accessToken).map(SingleTokenFetcher::new); + } + + private SingleTokenFetcher(String token) { + this.token = token; + } + + @Override + public String getDefaultToken() { + return token; + } + + @Override + public String getToken(S2AIdentity identity) { + return token; + } +} \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java new file mode 100644 index 00000000000..9eeddaad844 --- /dev/null +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/TokenFetcher.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import io.grpc.s2a.handshaker.S2AIdentity; + +/** Fetches tokens used to authenticate to S2A. */ +interface TokenFetcher { + /** Returns an access token when no identity is specified. */ + String getDefaultToken(); + + /** Returns an access token for the given identity. */ + String getToken(S2AIdentity identity); +} \ No newline at end of file diff --git a/s2a/src/main/proto/grpc/gcp/common.proto b/s2a/src/main/proto/grpc/gcp/common.proto new file mode 100644 index 00000000000..7c105c2ce05 --- /dev/null +++ b/s2a/src/main/proto/grpc/gcp/common.proto @@ -0,0 +1,79 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package grpc.gcp; + +option java_multiple_files = true; +option java_outer_classname = "CommonProto"; +option java_package = "io.grpc.s2a.handshaker"; + +// The TLS 1.0-1.2 ciphersuites that the application can negotiate when using +// S2A. +enum Ciphersuite { + CIPHERSUITE_UNSPECIFIED = 0; + CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 1; + CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 2; + CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 3; + CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 4; + CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 5; + CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 6; +} + +// The TLS versions supported by S2A's handshaker module. +enum TLSVersion { + TLS_VERSION_UNSPECIFIED = 0; + TLS_VERSION_1_0 = 1; + TLS_VERSION_1_1 = 2; + TLS_VERSION_1_2 = 3; + TLS_VERSION_1_3 = 4; +} + +// The side in the TLS connection. +enum ConnectionSide { + CONNECTION_SIDE_UNSPECIFIED = 0; + CONNECTION_SIDE_CLIENT = 1; + CONNECTION_SIDE_SERVER = 2; +} + +// The ALPN protocols that the application can negotiate during a TLS handshake. +enum AlpnProtocol { + ALPN_PROTOCOL_UNSPECIFIED = 0; + ALPN_PROTOCOL_GRPC = 1; + ALPN_PROTOCOL_HTTP2 = 2; + ALPN_PROTOCOL_HTTP1_1 = 3; +} + +message Identity { + oneof identity_oneof { + // The SPIFFE ID of a connection endpoint. + string spiffe_id = 1; + + // The hostname of a connection endpoint. + string hostname = 2; + + // The UID of a connection endpoint. + string uid = 4; + + // The username of a connection endpoint. + string username = 5; + + // The GCP ID of a connection endpoint. + string gcp_id = 6; + } + + // Additional identity-specific attributes. + map attributes = 3; +} diff --git a/s2a/src/main/proto/grpc/gcp/s2a.proto b/s2a/src/main/proto/grpc/gcp/s2a.proto new file mode 100644 index 00000000000..1a05b546ebb --- /dev/null +++ b/s2a/src/main/proto/grpc/gcp/s2a.proto @@ -0,0 +1,369 @@ +// Copyright 2024 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/gcp/s2a/s2a.proto + +syntax = "proto3"; + +package grpc.gcp; + +import "grpc/gcp/common.proto"; +import "grpc/gcp/s2a_context.proto"; + +option java_multiple_files = true; +option java_outer_classname = "S2AProto"; +option java_package = "io.grpc.s2a.handshaker"; + +enum SignatureAlgorithm { + S2A_SSL_SIGN_UNSPECIFIED = 0; + // RSA Public-Key Cryptography Standards #1. + S2A_SSL_SIGN_RSA_PKCS1_SHA256 = 1; + S2A_SSL_SIGN_RSA_PKCS1_SHA384 = 2; + S2A_SSL_SIGN_RSA_PKCS1_SHA512 = 3; + // ECDSA. + S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256 = 4; + S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384 = 5; + S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512 = 6; + // RSA Probabilistic Signature Scheme. + S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256 = 7; + S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384 = 8; + S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512 = 9; + // ED25519. + S2A_SSL_SIGN_ED25519 = 10; +} + +message AlpnPolicy { + // If true, the application MUST perform ALPN negotiation. + bool enable_alpn_negotiation = 1; + + // The ordered list of ALPN protocols that specify how the application SHOULD + // negotiate ALPN during the TLS handshake. + // + // The application MAY ignore any ALPN protocols in this list that are not + // supported by the application. + repeated AlpnProtocol alpn_protocols = 2; +} + +message AuthenticationMechanism { + // Applications may specify an identity associated to an authentication + // mechanism. Otherwise, S2A assumes that the authentication mechanism is + // associated with the default identity. If the default identity cannot be + // determined, the request is rejected. + Identity identity = 1; + + oneof mechanism_oneof { + // A token that the application uses to authenticate itself to S2A. + string token = 2; + } +} + +message Status { + // The status code that is specific to the application and the implementation + // of S2A, e.g., gRPC status code. + uint32 code = 1; + + // The status details. + string details = 2; +} + +message GetTlsConfigurationReq { + // The role of the application in the TLS connection. + ConnectionSide connection_side = 1; + + // The server name indication (SNI) extension, which MAY be populated when a + // server is offloading to S2A. The SNI is used to determine the server + // identity if the local identity in the request is empty. + string sni = 2; +} + +message GetTlsConfigurationResp { + // Next ID: 8 + message ClientTlsConfiguration { + reserved 4, 5; + + // The certificate chain that the client MUST use for the TLS handshake. + // It's a list of PEM-encoded certificates, ordered from leaf to root, + // excluding the root. + repeated string certificate_chain = 1; + + // The minimum TLS version number that the client MUST use for the TLS + // handshake. If this field is not provided, the client MUST use the default + // minimum version of the client's TLS library. + TLSVersion min_tls_version = 2; + + // The maximum TLS version number that the client MUST use for the TLS + // handshake. If this field is not provided, the client MUST use the default + // maximum version of the client's TLS library. + TLSVersion max_tls_version = 3; + + // The ordered list of TLS 1.0-1.2 ciphersuites that the client MAY offer to + // negotiate in the TLS handshake. + repeated Ciphersuite ciphersuites = 6; + + // The policy that dictates how the client negotiates ALPN during the TLS + // handshake. + AlpnPolicy alpn_policy = 7; + } + + // Next ID: 12 + message ServerTlsConfiguration { + reserved 4, 5; + + enum RequestClientCertificate { + UNSPECIFIED = 0; + DONT_REQUEST_CLIENT_CERTIFICATE = 1; + REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY = 2; + REQUEST_CLIENT_CERTIFICATE_AND_VERIFY = 3; + REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY = 4; + REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY = 5; + } + + // The certificate chain that the server MUST use for the TLS handshake. + // It's a list of PEM-encoded certificates, ordered from leaf to root, + // excluding the root. + repeated string certificate_chain = 1; + + // The minimum TLS version number that the server MUST use for the TLS + // handshake. If this field is not provided, the server MUST use the default + // minimum version of the server's TLS library. + TLSVersion min_tls_version = 2; + + // The maximum TLS version number that the server MUST use for the TLS + // handshake. If this field is not provided, the server MUST use the default + // maximum version of the server's TLS library. + TLSVersion max_tls_version = 3; + + // The ordered list of TLS 1.0-1.2 ciphersuites that the server MAY offer to + // negotiate in the TLS handshake. + repeated Ciphersuite ciphersuites = 10; + + // Whether to enable TLS resumption. + bool tls_resumption_enabled = 6; + + // Whether the server MUST request a client certificate (i.e. to negotiate + // TLS vs. mTLS). + RequestClientCertificate request_client_certificate = 7; + + // Returns the maximum number of extra bytes that + // |OffloadResumptionKeyOperation| can add to the number of unencrypted + // bytes to form the encrypted bytes. + uint32 max_overhead_of_ticket_aead = 9; + + // The policy that dictates how the server negotiates ALPN during the TLS + // handshake. + AlpnPolicy alpn_policy = 11; + } + + oneof tls_configuration { + ClientTlsConfiguration client_tls_configuration = 1; + ServerTlsConfiguration server_tls_configuration = 2; + } +} + +message OffloadPrivateKeyOperationReq { + enum PrivateKeyOperation { + UNSPECIFIED = 0; + // When performing a TLS 1.2 or 1.3 handshake, the (partial) transcript of + // the TLS handshake must be signed to prove possession of the private key. + // + // See https://www.rfc-editor.org/rfc/rfc8446.html#section-4.4.3. + SIGN = 1; + // When performing a TLS 1.2 handshake using an RSA algorithm, the key + // exchange algorithm involves the client generating a premaster secret, + // encrypting it using the server's public key, and sending this encrypted + // blob to the server in a ClientKeyExchange message. + // + // See https://www.rfc-editor.org/rfc/rfc4346#section-7.4.7.1. + DECRYPT = 2; + } + + // The operation the private key is used for. + PrivateKeyOperation operation = 1; + + // The signature algorithm to be used for signing operations. + SignatureAlgorithm signature_algorithm = 2; + + // The input bytes to be signed or decrypted. + oneof in_bytes { + // Raw bytes to be hashed and signed, or decrypted. + bytes raw_bytes = 4; + // A SHA256 hash to be signed. Must be 32 bytes. + bytes sha256_digest = 5; + // A SHA384 hash to be signed. Must be 48 bytes. + bytes sha384_digest = 6; + // A SHA512 hash to be signed. Must be 64 bytes. + bytes sha512_digest = 7; + } +} + +message OffloadPrivateKeyOperationResp { + // The signed or decrypted output bytes. + bytes out_bytes = 1; +} + +message OffloadResumptionKeyOperationReq { + enum ResumptionKeyOperation { + UNSPECIFIED = 0; + ENCRYPT = 1; + DECRYPT = 2; + } + + // The operation the resumption key is used for. + ResumptionKeyOperation operation = 1; + + // The bytes to be encrypted or decrypted. + bytes in_bytes = 2; +} + +message OffloadResumptionKeyOperationResp { + // The encrypted or decrypted bytes. + bytes out_bytes = 1; +} + +message ValidatePeerCertificateChainReq { + enum VerificationMode { + // The default verification mode supported by S2A. + UNSPECIFIED = 0; + // The SPIFFE verification mode selects the set of trusted certificates to + // use for path building based on the SPIFFE trust domain in the peer's leaf + // certificate. + SPIFFE = 1; + // The connect-to-Google verification mode uses the trust bundle for + // connecting to Google, e.g. *.mtls.googleapis.com endpoints. + CONNECT_TO_GOOGLE = 2; + } + + message ClientPeer { + // The certificate chain to be verified. The chain MUST be a list of + // DER-encoded certificates, ordered from leaf to root, excluding the root. + repeated bytes certificate_chain = 1; + } + + message ServerPeer { + // The certificate chain to be verified. The chain MUST be a list of + // DER-encoded certificates, ordered from leaf to root, excluding the root. + repeated bytes certificate_chain = 1; + + // The expected hostname of the server. + string server_hostname = 2; + + // The UnrestrictedClientPolicy specified by the user. + bytes serialized_unrestricted_client_policy = 3; + } + + // The verification mode that S2A MUST use to validate the peer certificate + // chain. + VerificationMode mode = 1; + + oneof peer_oneof { + ClientPeer client_peer = 2; + ServerPeer server_peer = 3; + } +} + +message ValidatePeerCertificateChainResp { + enum ValidationResult { + UNSPECIFIED = 0; + SUCCESS = 1; + FAILURE = 2; + } + + // The result of validating the peer certificate chain. + ValidationResult validation_result = 1; + + // The validation details. This field is only populated when the validation + // result is NOT SUCCESS. + string validation_details = 2; + + // The S2A context contains information from the peer certificate chain. + // + // The S2A context MAY be populated even if validation of the peer certificate + // chain fails. + S2AContext context = 3; +} + +message SessionReq { + // The identity corresponding to the TLS configurations that MUST be used for + // the TLS handshake. + // + // If a managed identity already exists, the local identity and authentication + // mechanisms are ignored. If a managed identity doesn't exist and the local + // identity is not populated, S2A will try to deduce the managed identity to + // use from the SNI extension. If that also fails, S2A uses the default + // identity (if one exists). + Identity local_identity = 1; + + // The authentication mechanisms that the application wishes to use to + // authenticate to S2A, ordered by preference. S2A will always use the first + // authentication mechanism that matches the managed identity. + repeated AuthenticationMechanism authentication_mechanisms = 2; + + oneof req_oneof { + // Requests the certificate chain and TLS configuration corresponding to the + // local identity, which the application MUST use to negotiate the TLS + // handshake. + GetTlsConfigurationReq get_tls_configuration_req = 3; + + // Signs or decrypts the input bytes using a private key corresponding to + // the local identity in the request. + // + // WARNING: More than one OffloadPrivateKeyOperationReq may be sent to the + // S2Av2 by a server during a TLS 1.2 handshake. + OffloadPrivateKeyOperationReq offload_private_key_operation_req = 4; + + // Encrypts or decrypts the input bytes using a resumption key corresponding + // to the local identity in the request. + OffloadResumptionKeyOperationReq offload_resumption_key_operation_req = 5; + + // Verifies the peer's certificate chain using + // (a) trust bundles corresponding to the local identity in the request, and + // (b) the verification mode in the request. + ValidatePeerCertificateChainReq validate_peer_certificate_chain_req = 6; + } +} + +message SessionResp { + // Status of the session response. + // + // The status field is populated so that if an error occurs when making an + // individual request, then communication with the S2A may continue. If an + // error is returned directly (e.g. at the gRPC layer), then it may result + // that the bidirectional stream being closed. + Status status = 1; + + oneof resp_oneof { + // Contains the certificate chain and TLS configurations corresponding to + // the local identity. + GetTlsConfigurationResp get_tls_configuration_resp = 2; + + // Contains the signed or encrypted output bytes using the private key + // corresponding to the local identity. + OffloadPrivateKeyOperationResp offload_private_key_operation_resp = 3; + + // Contains the encrypted or decrypted output bytes using the resumption key + // corresponding to the local identity. + OffloadResumptionKeyOperationResp offload_resumption_key_operation_resp = 4; + + // Contains the validation result, peer identity and fingerprints of peer + // certificates. + ValidatePeerCertificateChainResp validate_peer_certificate_chain_resp = 5; + } +} + +service S2AService { + // SetUpSession is a bidirectional stream used by applications to offload + // operations from the TLS handshake. + rpc SetUpSession(stream SessionReq) returns (stream SessionResp) {} +} diff --git a/s2a/src/main/proto/grpc/gcp/s2a_context.proto b/s2a/src/main/proto/grpc/gcp/s2a_context.proto new file mode 100644 index 00000000000..5ad264bf875 --- /dev/null +++ b/s2a/src/main/proto/grpc/gcp/s2a_context.proto @@ -0,0 +1,61 @@ +// Copyright 2024 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/gcp/s2a/s2a_context.proto +syntax = "proto3"; + +package grpc.gcp; + +import "grpc/gcp/common.proto"; + +option java_multiple_files = true; +option java_outer_classname = "S2AContextProto"; +option java_package = "io.grpc.s2a.handshaker"; + +message S2AContext { + // The SPIFFE ID from the peer leaf certificate, if present. + // + // This field is only populated if the leaf certificate is a valid SPIFFE + // SVID; in particular, there is a unique URI SAN and this URI SAN is a valid + // SPIFFE ID. + string leaf_cert_spiffe_id = 1; + + // The URIs that are present in the SubjectAltName extension of the peer leaf + // certificate. + // + // Note that the extracted URIs are not validated and may not be properly + // formatted. + repeated string leaf_cert_uris = 2; + + // The DNSNames that are present in the SubjectAltName extension of the peer + // leaf certificate. + repeated string leaf_cert_dnsnames = 3; + + // The (ordered) list of fingerprints in the certificate chain used to verify + // the given leaf certificate. The order MUST be from leaf certificate + // fingerprint to root certificate fingerprint. + // + // A fingerprint is the base-64 encoding of the SHA256 hash of the + // DER-encoding of a certificate. The list MAY be populated even if the peer + // certificate chain was NOT validated successfully. + repeated string peer_certificate_chain_fingerprints = 4; + + // The local identity used during session setup. + Identity local_identity = 5; + + // The SHA256 hash of the DER-encoding of the local leaf certificate used in + // the handshake. + bytes local_leaf_cert_fingerprint = 6; +} diff --git a/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java new file mode 100644 index 00000000000..5ccc522292e --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/MtlsToS2AChannelCredentialsTest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class MtlsToS2AChannelCredentialsTest { + @Test + public void createBuilder_nullAddress_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ null, + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_nullPrivateKeyPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ null, + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_nullCertChainPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ null, + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_nullTrustBundlePath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ null)); + } + + @Test + public void createBuilder_emptyAddress_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_emptyPrivateKeyPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_emptyCertChainPath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "", + /* trustBundlePath= */ "src/test/resources/root_cert.pem")); + } + + @Test + public void createBuilder_emptyTrustBundlePath_throwsException() throws Exception { + assertThrows( + IllegalArgumentException.class, + () -> + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "")); + } + + @Test + public void build_s2AChannelCredentials_success() throws Exception { + assertThat( + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ "s2a_address", + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem") + .build()) + .isInstanceOf(S2AChannelCredentials.Builder.class); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java new file mode 100644 index 00000000000..a6133ed0af8 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/S2AChannelCredentialsTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.ChannelCredentials; +import io.grpc.TlsChannelCredentials; +import java.io.File; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@code S2AChannelCredentials}. */ +@RunWith(JUnit4.class) +public final class S2AChannelCredentialsTest { + @Test + public void createBuilder_nullArgument_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder(null)); + } + + @Test + public void createBuilder_emptyAddress_throwsException() throws Exception { + assertThrows(IllegalArgumentException.class, () -> S2AChannelCredentials.createBuilder("")); + } + + @Test + public void setLocalSpiffeId_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalSpiffeId(null)); + } + + @Test + public void setLocalHostname_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalHostname(null)); + } + + @Test + public void setLocalUid_nullArgument_throwsException() throws Exception { + assertThrows( + NullPointerException.class, + () -> S2AChannelCredentials.createBuilder("s2a_address").setLocalUid(null)); + } + + @Test + public void build_withLocalSpiffeId_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.createBuilder("s2a_address") + .setLocalSpiffeId("spiffe://test") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalHostname_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.createBuilder("s2a_address") + .setLocalHostname("local_hostname") + .build()) + .isNotNull(); + } + + @Test + public void build_withLocalUid_succeeds() throws Exception { + assertThat(S2AChannelCredentials.createBuilder("s2a_address").setLocalUid("local_uid").build()) + .isNotNull(); + } + + @Test + public void build_withNoLocalIdentity_succeeds() throws Exception { + assertThat(S2AChannelCredentials.createBuilder("s2a_address").build()) + .isNotNull(); + } + + @Test + public void build_withTlsChannelCredentials_succeeds() throws Exception { + assertThat( + S2AChannelCredentials.createBuilder("s2a_address") + .setLocalSpiffeId("spiffe://test") + .setS2AChannelCredentials(getTlsChannelCredentials()) + .build()) + .isNotNull(); + } + + private static ChannelCredentials getTlsChannelCredentials() throws Exception { + File clientCert = new File("src/test/resources/client_cert.pem"); + File clientKey = new File("src/test/resources/client_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + return TlsChannelCredentials.newBuilder() + .keyManager(clientCert, clientKey) + .trustManager(rootCert) + .build(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java new file mode 100644 index 00000000000..13eccac682d --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java @@ -0,0 +1,125 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; + +import io.grpc.Channel; +import io.grpc.internal.ObjectPool; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AGrpcChannelPool}. */ +@RunWith(JUnit4.class) +public final class S2AGrpcChannelPoolTest { + @Test + public void getChannel_success() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); + + Channel channel = s2aChannelPool.getChannel(); + + assertThat(channel).isNotNull(); + assertThat(fakeChannelPool.isChannelCached()).isTrue(); + assertThat(s2aChannelPool.getChannel()).isEqualTo(channel); + } + + @Test + public void returnChannel_success() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); + + s2aChannelPool.returnChannel(s2aChannelPool.getChannel()); + + assertThat(fakeChannelPool.isChannelCached()).isFalse(); + } + + @Test + public void returnChannel_channelStillCachedBecauseMultipleChannelsRetrieved() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); + + s2aChannelPool.getChannel(); + s2aChannelPool.returnChannel(s2aChannelPool.getChannel()); + + assertThat(fakeChannelPool.isChannelCached()).isTrue(); + } + + @Test + public void returnChannel_failureBecauseChannelWasNotFromPool() throws Exception { + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(new FakeChannelPool()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> s2aChannelPool.returnChannel(mock(Channel.class))); + assertThat(expected) + .hasMessageThat() + .isEqualTo( + "Cannot return the channel to channel pool because the channel was not obtained from" + + " channel pool."); + } + + @Test + public void close_success() throws Exception { + FakeChannelPool fakeChannelPool = new FakeChannelPool(); + try (S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool)) { + s2aChannelPool.getChannel(); + } + + assertThat(fakeChannelPool.isChannelCached()).isFalse(); + } + + @Test + public void close_poolIsUnusable() throws Exception { + S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(new FakeChannelPool()); + s2aChannelPool.close(); + + IllegalStateException expected = + assertThrows(IllegalStateException.class, s2aChannelPool::getChannel); + + assertThat(expected).hasMessageThat().isEqualTo("Channel pool is not open."); + } + + private static class FakeChannelPool implements ObjectPool { + private final Channel mockChannel = mock(Channel.class); + private @Nullable Channel cachedChannel = null; + + @Override + public Channel getObject() { + if (cachedChannel == null) { + cachedChannel = mockChannel; + } + return cachedChannel; + } + + @Override + public Channel returnObject(Object object) { + assertThat(object).isSameInstanceAs(mockChannel); + cachedChannel = null; + return null; + } + + public boolean isChannelCached() { + return (cachedChannel != null); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java new file mode 100644 index 00000000000..57288be1b6f --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AHandshakerServiceChannelTest.java @@ -0,0 +1,390 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.channel; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ChannelCredentials; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel.EventLoopHoldingChannel; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.channel.EventLoopGroup; +import java.io.File; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AHandshakerServiceChannel}. */ +@RunWith(JUnit4.class) +public final class S2AHandshakerServiceChannelTest { + @ClassRule public static final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10); + private final EventLoopGroup mockEventLoopGroup = mock(EventLoopGroup.class); + private Server mtlsServer; + private Server plaintextServer; + + @Before + public void setUp() throws Exception { + mtlsServer = createMtlsServer(); + plaintextServer = createPlaintextServer(); + mtlsServer.start(); + plaintextServer.start(); + } + + /** + * Creates a {@code Resource} and verifies that it produces a {@code ChannelResource} + * instance by using its {@code toString()} method. + */ + @Test + public void getChannelResource_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** Same as getChannelResource_success, but use mTLS. */ + @Test + public void getChannelResource_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource.toString()).isEqualTo("grpc-s2a-channel"); + } + + /** + * Creates two {@code Resoure}s for the same target address and verifies that they are + * equal. + */ + @Test + public void getChannelResource_twoEqualChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + assertThat(resource).isEqualTo(resourceTwo); + } + + /** Same as getChannelResource_twoEqualChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoEqualChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + assertThat(resource).isEqualTo(resourceTwo); + } + + /** + * Creates two {@code Resoure}s for different target addresses and verifies that they are + * distinct. + */ + @Test + public void getChannelResource_twoDistinctChannels() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + Utils.pickUnusedPort(), /* s2aChannelCredentials= */ Optional.empty()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** Same as getChannelResource_twoDistinctChannels, but use mTLS. */ + @Test + public void getChannelResource_mtlsTwoDistinctChannels() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Resource resourceTwo = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + Utils.pickUnusedPort(), getTlsChannelCredentials()); + assertThat(resourceTwo).isNotEqualTo(resource); + } + + /** + * Uses a {@code Resource} to create a channel, closes the channel, and verifies that the + * channel is closed by attempting to make a simple RPC. + */ + @Test + public void close_success() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** Same as close_success, but use mTLS. */ + @Test + public void close_mtlsSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channel = resource.create(); + resource.close(channel); + StatusRuntimeException expected = + assertThrows( + StatusRuntimeException.class, + () -> + SimpleServiceGrpc.newBlockingStub(channel) + .unaryRpc(SimpleRequest.getDefaultInstance())); + assertThat(expected).hasMessageThat().isEqualTo("UNAVAILABLE: Channel shutdown invoked"); + } + + /** + * Verifies that an {@code EventLoopHoldingChannel}'s {@code newCall} method can be used to + * perform a simple RPC. + */ + @Test + public void newCall_performSimpleRpcSuccess() { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Channel channel = resource.create(); + assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + } + + /** Same as newCall_performSimpleRpcSuccess, but use mTLS. */ + @Test + public void newCall_mtlsPerformSimpleRpcSuccess() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channel = resource.create(); + assertThat(channel).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + } + + /** Creates a {@code EventLoopHoldingChannel} instance and verifies its authority. */ + @Test + public void authority_success() throws Exception { + ManagedChannel channel = new FakeManagedChannel(true); + EventLoopHoldingChannel eventLoopHoldingChannel = + EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + assertThat(eventLoopHoldingChannel.authority()).isEqualTo("FakeManagedChannel"); + } + + /** + * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} terminates + * successfully. + */ + @Test + public void close_withDelegateTerminatedSuccess() throws Exception { + ManagedChannel channel = new FakeManagedChannel(true); + EventLoopHoldingChannel eventLoopHoldingChannel = + EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + eventLoopHoldingChannel.close(); + assertThat(channel.isShutdown()).isTrue(); + verify(mockEventLoopGroup, times(1)) + .shutdownGracefully(0, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); + } + + /** + * Creates and closes a {@code EventLoopHoldingChannel} when its {@code ManagedChannel} does not + * terminate successfully. + */ + @Test + public void close_withDelegateTerminatedFailure() throws Exception { + ManagedChannel channel = new FakeManagedChannel(false); + EventLoopHoldingChannel eventLoopHoldingChannel = + EventLoopHoldingChannel.create(channel, mockEventLoopGroup); + eventLoopHoldingChannel.close(); + assertThat(channel.isShutdown()).isTrue(); + verify(mockEventLoopGroup, times(1)) + .shutdownGracefully(1, CHANNEL_SHUTDOWN_TIMEOUT.getSeconds(), SECONDS); + } + + /** + * Creates and closes a {@code EventLoopHoldingChannel}, creates a new channel from the same + * resource, and verifies that this second channel is useable. + */ + @Test + public void create_succeedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + plaintextServer.getPort(), + /* s2aChannelCredentials= */ Optional.empty()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + /** Same as create_succeedsAfterCloseIsCalledOnce, but use mTLS. */ + @Test + public void create_mtlsSucceedsAfterCloseIsCalledOnce() throws Exception { + Resource resource = + S2AHandshakerServiceChannel.getChannelResource( + "localhost:" + mtlsServer.getPort(), getTlsChannelCredentials()); + Channel channelOne = resource.create(); + resource.close(channelOne); + + Channel channelTwo = resource.create(); + assertThat(channelTwo).isInstanceOf(EventLoopHoldingChannel.class); + assertThat( + SimpleServiceGrpc.newBlockingStub(channelTwo) + .unaryRpc(SimpleRequest.getDefaultInstance())) + .isEqualToDefaultInstance(); + resource.close(channelTwo); + } + + private static Server createMtlsServer() throws Exception { + SimpleServiceImpl service = new SimpleServiceImpl(); + File serverCert = new File("src/test/resources/server_cert.pem"); + File serverKey = new File("src/test/resources/server_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + ServerCredentials creds = + TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + return grpcCleanup.register( + NettyServerBuilder.forPort(Utils.pickUnusedPort(), creds).addService(service).build()); + } + + private static Server createPlaintextServer() { + SimpleServiceImpl service = new SimpleServiceImpl(); + return grpcCleanup.register( + ServerBuilder.forPort(Utils.pickUnusedPort()).addService(service).build()); + } + + private static Optional getTlsChannelCredentials() throws Exception { + File clientCert = new File("src/test/resources/client_cert.pem"); + File clientKey = new File("src/test/resources/client_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + return Optional.of( + TlsChannelCredentials.newBuilder() + .keyManager(clientCert, clientKey) + .trustManager(rootCert) + .build()); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver streamObserver) { + streamObserver.onNext(SimpleResponse.getDefaultInstance()); + streamObserver.onCompleted(); + } + } + + private static class FakeManagedChannel extends ManagedChannel { + private final boolean isDelegateTerminatedSuccess; + private boolean isShutdown = false; + + FakeManagedChannel(boolean isDelegateTerminatedSuccess) { + this.isDelegateTerminatedSuccess = isDelegateTerminatedSuccess; + } + + @Override + public String authority() { + return "FakeManagedChannel"; + } + + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions options) { + throw new UnsupportedOperationException("This method should not be called."); + } + + @Override + public ManagedChannel shutdown() { + throw new UnsupportedOperationException("This method should not be called."); + } + + @Override + public boolean isShutdown() { + return isShutdown; + } + + @Override + public boolean isTerminated() { + throw new UnsupportedOperationException("This method should not be called."); + } + + @Override + public ManagedChannel shutdownNow() { + isShutdown = true; + return null; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + if (isDelegateTerminatedSuccess) { + return true; + } + throw new InterruptedException("Await termination was interrupted."); + } + } +} diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java new file mode 100644 index 00000000000..66f636ada22 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import io.grpc.stub.StreamObserver; +import java.security.NoSuchAlgorithmException; +import java.security.spec.InvalidKeySpecException; +import java.util.logging.Logger; + +/** A fake S2Av2 server that should be used for testing only. */ +public final class FakeS2AServer extends S2AServiceGrpc.S2AServiceImplBase { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + + private final FakeWriter writer; + + public FakeS2AServer() throws InvalidKeySpecException, NoSuchAlgorithmException { + this.writer = new FakeWriter(); + this.writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS).initializePrivateKey(); + } + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + logger.info("Received a request from client."); + responseObserver.onNext(writer.handleResponse(req)); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java new file mode 100644 index 00000000000..e200d119867 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java @@ -0,0 +1,265 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.benchmarks.Utils; +import io.grpc.s2a.handshaker.ValidatePeerCertificateChainReq.VerificationMode; +import io.grpc.stub.StreamObserver; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link FakeS2AServer}. */ +@RunWith(JUnit4.class) +public final class FakeS2AServerTest { + private static final Logger logger = Logger.getLogger(FakeS2AServerTest.class.getName()); + + private static final ImmutableList FAKE_CERT_DER_CHAIN = + ImmutableList.of( + ByteString.copyFrom( + new byte[] {'f', 'a', 'k', 'e', '-', 'd', 'e', 'r', '-', 'c', 'h', 'a', 'i', 'n'})); + private int port; + private String serverAddress; + private SessionResp response = null; + private Server fakeS2AServer; + + @Before + public void setUp() throws Exception { + port = Utils.pickUnusedPort(); + fakeS2AServer = ServerBuilder.forPort(port).addService(new FakeS2AServer()).build(); + fakeS2AServer.start(); + serverAddress = String.format("localhost:%d", port); + } + + @After + public void tearDown() { + fakeS2AServer.shutdown(); + } + + @Test + public void callS2AServerOnce_getTlsConfiguration_returnsValidResult() + throws InterruptedException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + @Override + public void onNext(SessionResp resp) { + response = resp; + } + + @Override + public void onError(Throwable t) { + throw new RuntimeException(t); + } + + @Override + public void onCompleted() {} + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(FakeWriter.LEAF_CERT) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult() + throws InterruptedException { + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + @Override + public void onNext(SessionResp resp) { + response = resp; + } + + @Override + public void onError(Throwable t) { + throw new RuntimeException(t); + } + + @Override + public void onCompleted() {} + }); + try { + requestObserver.onNext( + SessionReq.newBuilder() + .setValidatePeerCertificateChainReq( + ValidatePeerCertificateChainReq.newBuilder() + .setMode(VerificationMode.UNSPECIFIED) + .setClientPeer( + ValidatePeerCertificateChainReq.ClientPeer.newBuilder() + .addAllCertificateChain(FAKE_CERT_DER_CHAIN))) + .build()); + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + + SessionResp expected = + SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void callS2AServerRepeatedly_returnsValidResult() throws InterruptedException { + final int numberOfRequests = 10; + ExecutorService executor = Executors.newSingleThreadExecutor(); + logger.info("Client connecting to: " + serverAddress); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + + try { + S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel); + CountDownLatch finishLatch = new CountDownLatch(1); + StreamObserver requestObserver = + asyncStub.setUpSession( + new StreamObserver() { + private int expectedNumberOfReplies = numberOfRequests; + + @Override + public void onNext(SessionResp reply) { + System.out.println("Received a message from the S2AService service."); + expectedNumberOfReplies -= 1; + } + + @Override + public void onError(Throwable t) { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(t); + } + } + + @Override + public void onCompleted() { + finishLatch.countDown(); + if (expectedNumberOfReplies != 0) { + throw new RuntimeException(); + } + } + }); + try { + for (int i = 0; i < numberOfRequests; i++) { + requestObserver.onNext(SessionReq.getDefaultInstance()); + } + } catch (RuntimeException e) { + // Cancel the RPC. + requestObserver.onError(e); + throw e; + } + // Mark the end of requests. + requestObserver.onCompleted(); + // Wait for receiving to happen. + if (!finishLatch.await(10, SECONDS)) { + throw new RuntimeException(); + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + } + +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java new file mode 100644 index 00000000000..505a0cf4a3a --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java @@ -0,0 +1,347 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static io.grpc.s2a.handshaker.TLSVersion.TLS_VERSION_1_3; + +import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.protobuf.ByteString; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.security.KeyFactory; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.Signature; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.Base64; + +/** A fake Writer Class to mock the behavior of S2A server. */ +final class FakeWriter implements StreamObserver { + /** Fake behavior of S2A service. */ + enum Behavior { + OK_STATUS, + EMPTY_RESPONSE, + ERROR_STATUS, + ERROR_RESPONSE, + COMPLETE_STATUS + } + + enum VerificationResult { + UNSPECIFIED, + SUCCESS, + FAILURE + } + + public static final String LEAF_CERT = + "-----BEGIN CERTIFICATE-----\n" + + "MIICkDCCAjagAwIBAgIUSAtcrPhNNs1zxv51lIfGOVtkw6QwCgYIKoZIzj0EAwIw\n" + + "QTEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxFDAS\n" + + "BgorBgEEAdZ5AggBDAQyMDIyMCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIz\n" + + "NjA0WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC\n" + + "AAQGFlJpLxJMh4HuUm0DKjnUF7larH3tJvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jz\n" + + "U98eDRXG5f4VjnX98DDHE4Ido4IBODCCATQwDgYDVR0PAQH/BAQDAgeAMCAGA1Ud\n" + + "JQEB/wQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMIGxBgNV\n" + + "HREBAf8EgaYwgaOGSnNwaWZmZTovL3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJp\n" + + "dHktcmVhbG0ucHJvZC5nb29nbGUuY29tL3JvbGUvbGVhZi1yb2xlgjNzaWduZXIt\n" + + "cm9sZS5jb250ZXh0LnNlY3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2eCIGZx\n" + + "ZG4tb2YtdGhlLW5vZGUucHJvZC5nb29nbGUuY29tMB0GA1UdDgQWBBSWSd5Fw6dI\n" + + "TGpt0m1Uxwf0iKqebzAfBgNVHSMEGDAWgBRm5agVVdpWfRZKM7u6OMuzHhqPcDAK\n" + + "BggqhkjOPQQDAgNIADBFAiB0sjRPSYy2eFq8Y0vQ8QN4AZ2NMajskvxnlifu7O4U\n" + + "RwIhANTh5Fkyx2nMYFfyl+W45dY8ODTw3HnlZ4b51hTAdkWl\n" + + "-----END CERTIFICATE-----"; + public static final String INTERMEDIATE_CERT_2 = + "-----BEGIN CERTIFICATE-----\n" + + "MIICQjCCAeigAwIBAgIUKxXRDlnWXefNV5lj5CwhDuXEq7MwCgYIKoZIzj0EAwIw\n" + + "OzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxDjAM\n" + + "BgNVBAMMBTEyMzQ1MCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIzNjA0WjBB\n" + + "MRcwFQYDVQQKDA5zZWN1cml0eS1yZWFsbTEQMA4GA1UECwwHY29udGV4dDEUMBIG\n" + + "CisGAQQB1nkCCAEMBDIwMjIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAT/Zu7x\n" + + "UYVyg+T/vg2H+y4I6t36Kc4qxD0eqqZjRLYBVKkUQHxBqc14t0DpoROMYQCNd4DF\n" + + "pcxv/9m6DaJbRk6Ao4HBMIG+MA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAG\n" + + "AQH/AgEBMFgGA1UdHgEB/wROMEygSjA1gjNzaWduZXItcm9sZS5jb250ZXh0LnNl\n" + + "Y3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2cwEYIPcHJvZC5nb29nbGUuY29t\n" + + "MB0GA1UdDgQWBBRm5agVVdpWfRZKM7u6OMuzHhqPcDAfBgNVHSMEGDAWgBQcjNAh\n" + + "SCHTj+BW8KrzSSLo2ASEgjAKBggqhkjOPQQDAgNIADBFAiEA6KyGd9VxXDZceMZG\n" + + "IsbC40rtunFjLYI0mjZw9RcRWx8CIHCIiIHxafnDaCi+VB99NZfzAdu37g6pJptB\n" + + "gjIY71MO\n" + + "-----END CERTIFICATE-----"; + public static final String INTERMEDIATE_CERT_1 = + "-----BEGIN CERTIFICATE-----\n" + + "MIICODCCAd6gAwIBAgIUXtZECORWRSKnS9rRTJYkiALUXswwCgYIKoZIzj0EAwIw\n" + + "NzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xDTALBgNVBAsMBHJvb3QxDTALBgNV\n" + + "BAMMBDEyMzQwIBcNMjMwNzE0MjIzNjA0WhgPMjA1MDExMjkyMjM2MDRaMDsxFzAV\n" + + "BgNVBAoMDnNlY3VyaXR5LXJlYWxtMRAwDgYDVQQLDAdjb250ZXh0MQ4wDAYDVQQD\n" + + "DAUxMjM0NTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABAycVTZrjockbpD59f1a\n" + + "4l1SNL7nSyXz66Guz4eDveQqLmaMBg7vpACfO4CtiAGnolHEffuRtSkdM434m5En\n" + + "bXCjgcEwgb4wDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQIwWAYD\n" + + "VR0eAQH/BE4wTKBKMDWCM3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJpdHktcmVh\n" + + "bG0ucHJvZC5zcGlmZmUuZ29vZzARgg9wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYE\n" + + "FByM0CFIIdOP4FbwqvNJIujYBISCMB8GA1UdIwQYMBaAFMX+vebuj/lYfYEC23IA\n" + + "8HoIW0HsMAoGCCqGSM49BAMCA0gAMEUCIQCfxeXEBd7UPmeImT16SseCRu/6cHxl\n" + + "kTDsq9sKZ+eXBAIgA+oViAVOUhUQO1/6Mjlczg8NmMy2vNtG4V/7g9dMMVU=\n" + + "-----END CERTIFICATE-----"; + + private static final String PRIVATE_KEY = + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgqA2U0ld1OOHLMXWf" + + "uyN4GSaqhhudEIaKkll3rdIq0M+hRANCAAQGFlJpLxJMh4HuUm0DKjnUF7larH3t" + + "JvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jzU98eDRXG5f4VjnX98DDHE4Id"; + private static final ImmutableMap + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER = + ImmutableMap.of( + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256, + "SHA256withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384, + "SHA384withECDSA", + SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512, + "SHA512withECDSA"); + + private boolean fakeWriterClosed = false; + private Behavior behavior = Behavior.OK_STATUS; + private StreamObserver reader; + private VerificationResult verificationResult = VerificationResult.UNSPECIFIED; + private String failureReason; + private PrivateKey privateKey; + + @CanIgnoreReturnValue + FakeWriter setReader(StreamObserver reader) { + this.reader = reader; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setBehavior(Behavior behavior) { + this.behavior = behavior; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setVerificationResult(VerificationResult verificationResult) { + this.verificationResult = verificationResult; + return this; + } + + @CanIgnoreReturnValue + FakeWriter setFailureReason(String failureReason) { + this.failureReason = failureReason; + return this; + } + + @CanIgnoreReturnValue + FakeWriter initializePrivateKey() throws InvalidKeySpecException, NoSuchAlgorithmException { + privateKey = + KeyFactory.getInstance("EC") + .generatePrivate(new PKCS8EncodedKeySpec(Base64.getDecoder().decode(PRIVATE_KEY))); + return this; + } + + @CanIgnoreReturnValue + FakeWriter resetPrivateKey() { + privateKey = null; + return this; + } + + void sendUnexpectedResponse() { + reader.onNext(SessionResp.getDefaultInstance()); + } + + void sendIoError() { + reader.onError(new IOException("Intended ERROR from FakeWriter.")); + } + + void sendGetTlsConfigResp() { + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(LEAF_CERT) + .addCertificateChain(INTERMEDIATE_CERT_2) + .addCertificateChain(INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build()); + } + + boolean isFakeWriterClosed() { + return fakeWriterClosed; + } + + @Override + public void onNext(SessionReq sessionReq) { + switch (behavior) { + case OK_STATUS: + reader.onNext(handleResponse(sessionReq)); + break; + case EMPTY_RESPONSE: + reader.onNext(SessionResp.getDefaultInstance()); + break; + case ERROR_STATUS: + reader.onNext( + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(1) + .setDetails("Intended ERROR Status from FakeWriter.")) + .build()); + break; + case ERROR_RESPONSE: + reader.onError(new S2AConnectionException("Intended ERROR from FakeWriter.")); + break; + case COMPLETE_STATUS: + reader.onCompleted(); + break; + default: + reader.onNext(handleResponse(sessionReq)); + } + } + + SessionResp handleResponse(SessionReq sessionReq) { + if (sessionReq.hasGetTlsConfigurationReq()) { + return handleGetTlsConfigurationReq(sessionReq.getGetTlsConfigurationReq()); + } + + if (sessionReq.hasValidatePeerCertificateChainReq()) { + return handleValidatePeerCertificateChainReq(sessionReq.getValidatePeerCertificateChainReq()); + } + + if (sessionReq.hasOffloadPrivateKeyOperationReq()) { + return handleOffloadPrivateKeyOperationReq(sessionReq.getOffloadPrivateKeyOperationReq()); + } + + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(255).setDetails("No supported operation designated.")) + .build(); + } + + private SessionResp handleGetTlsConfigurationReq(GetTlsConfigurationReq req) { + if (!req.getConnectionSide().equals(ConnectionSide.CONNECTION_SIDE_CLIENT)) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + } + return SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(LEAF_CERT) + .addCertificateChain(INTERMEDIATE_CERT_2) + .addCertificateChain(INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + } + + private SessionResp handleValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (verifyValidatePeerCertificateChainReq(req) + && verificationResult == VerificationResult.SUCCESS) { + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS)) + .build(); + } + return SessionResp.newBuilder() + .setValidatePeerCertificateChainResp( + ValidatePeerCertificateChainResp.newBuilder() + .setValidationResult( + verificationResult == VerificationResult.FAILURE + ? ValidatePeerCertificateChainResp.ValidationResult.FAILURE + : ValidatePeerCertificateChainResp.ValidationResult.UNSPECIFIED) + .setValidationDetails(failureReason)) + .build(); + } + + private boolean verifyValidatePeerCertificateChainReq(ValidatePeerCertificateChainReq req) { + if (req.getMode() != ValidatePeerCertificateChainReq.VerificationMode.UNSPECIFIED) { + return false; + } + if (req.getClientPeer().getCertificateChainCount() > 0) { + return true; + } + if (req.getServerPeer().getCertificateChainCount() > 0 + && !req.getServerPeer().getServerHostname().isEmpty()) { + return true; + } + return false; + } + + private SessionResp handleOffloadPrivateKeyOperationReq(OffloadPrivateKeyOperationReq req) { + if (privateKey == null) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails("No Private Key available.")) + .build(); + } + String signatureIdentifier = + ALGORITHM_TO_SIGNATURE_INSTANCE_IDENTIFIER.get(req.getSignatureAlgorithm()); + if (signatureIdentifier == null) { + return SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("Only ECDSA key algorithms are supported.")) + .build(); + } + + byte[] signature; + try { + Signature sig = Signature.getInstance(signatureIdentifier); + sig.initSign(privateKey); + sig.update(req.getRawBytes().toByteArray()); + signature = sig.sign(); + } catch (Exception e) { + return SessionResp.newBuilder() + .setStatus(Status.newBuilder().setCode(255).setDetails(e.getMessage())) + .build(); + } + + return SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder().setOutBytes(ByteString.copyFrom(signature))) + .build(); + } + + @Override + public void onError(Throwable t) { + throw new UnsupportedOperationException("onError is not supported by FakeWriter."); + } + + @Override + public void onCompleted() { + fakeWriterClosed = true; + reader.onCompleted(); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java new file mode 100644 index 00000000000..aea279ed8c5 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import com.beust.jcommander.JCommander; +import com.google.common.truth.Expect; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.tokenmanager.SingleTokenFetcher; +import java.util.Optional; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link GetAuthenticationMechanisms}. */ +@RunWith(JUnit4.class) +public final class GetAuthenticationMechanismsTest { + @Rule public final Expect expect = Expect.create(); + private static final String TOKEN = "access_token"; + private static final String[] SET_TOKEN = {"--s2a_access_token", TOKEN}; + private static final SingleTokenFetcher.Flags FLAGS = new SingleTokenFetcher.Flags(); + + @BeforeClass + public static void setUpClass() { + // Set the token that the client will use to authenticate to the S2A. + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + } + + @Test + public void getAuthMechanisms_emptyIdentity_success() { + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.empty())) + .isEqualTo( + Optional.of(AuthenticationMechanism.newBuilder().setToken("access_token").build())); + } + + @Test + public void getAuthMechanisms_nonEmptyIdentity_success() { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + expect + .that(GetAuthenticationMechanisms.getAuthMechanism(Optional.of(fakeIdentity))) + .isEqualTo( + Optional.of( + AuthenticationMechanism.newBuilder() + .setIdentity(fakeIdentity.identity()) + .setToken("access_token") + .build())); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java new file mode 100644 index 00000000000..f9de9765527 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java @@ -0,0 +1,322 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.SECONDS; + +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.benchmarks.Utils; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.s2a.MtlsToS2AChannelCredentials; +import io.grpc.s2a.S2AChannelCredentials; +import io.grpc.s2a.handshaker.FakeS2AServer; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import java.io.ByteArrayInputStream; +import java.io.File; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSessionContext; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class IntegrationTest { + private static final Logger logger = Logger.getLogger(FakeS2AServer.class.getName()); + + private static final String CERT_CHAIN = + "-----BEGIN CERTIFICATE-----\n" + + "MIICkDCCAjagAwIBAgIUSAtcrPhNNs1zxv51lIfGOVtkw6QwCgYIKoZIzj0EAwIw\n" + + "QTEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxFDAS\n" + + "BgorBgEEAdZ5AggBDAQyMDIyMCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIz\n" + + "NjA0WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC\n" + + "AAQGFlJpLxJMh4HuUm0DKjnUF7larH3tJvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jz\n" + + "U98eDRXG5f4VjnX98DDHE4Ido4IBODCCATQwDgYDVR0PAQH/BAQDAgeAMCAGA1Ud\n" + + "JQEB/wQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMIGxBgNV\n" + + "HREBAf8EgaYwgaOGSnNwaWZmZTovL3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJp\n" + + "dHktcmVhbG0ucHJvZC5nb29nbGUuY29tL3JvbGUvbGVhZi1yb2xlgjNzaWduZXIt\n" + + "cm9sZS5jb250ZXh0LnNlY3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2eCIGZx\n" + + "ZG4tb2YtdGhlLW5vZGUucHJvZC5nb29nbGUuY29tMB0GA1UdDgQWBBSWSd5Fw6dI\n" + + "TGpt0m1Uxwf0iKqebzAfBgNVHSMEGDAWgBRm5agVVdpWfRZKM7u6OMuzHhqPcDAK\n" + + "BggqhkjOPQQDAgNIADBFAiB0sjRPSYy2eFq8Y0vQ8QN4AZ2NMajskvxnlifu7O4U\n" + + "RwIhANTh5Fkyx2nMYFfyl+W45dY8ODTw3HnlZ4b51hTAdkWl\n" + + "-----END CERTIFICATE-----\n" + + "-----BEGIN CERTIFICATE-----\n" + + "MIICQjCCAeigAwIBAgIUKxXRDlnWXefNV5lj5CwhDuXEq7MwCgYIKoZIzj0EAwIw\n" + + "OzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xEDAOBgNVBAsMB2NvbnRleHQxDjAM\n" + + "BgNVBAMMBTEyMzQ1MCAXDTIzMDcxNDIyMzYwNFoYDzIwNTAxMTI5MjIzNjA0WjBB\n" + + "MRcwFQYDVQQKDA5zZWN1cml0eS1yZWFsbTEQMA4GA1UECwwHY29udGV4dDEUMBIG\n" + + "CisGAQQB1nkCCAEMBDIwMjIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAT/Zu7x\n" + + "UYVyg+T/vg2H+y4I6t36Kc4qxD0eqqZjRLYBVKkUQHxBqc14t0DpoROMYQCNd4DF\n" + + "pcxv/9m6DaJbRk6Ao4HBMIG+MA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAG\n" + + "AQH/AgEBMFgGA1UdHgEB/wROMEygSjA1gjNzaWduZXItcm9sZS5jb250ZXh0LnNl\n" + + "Y3VyaXR5LXJlYWxtLnByb2Quc3BpZmZlLmdvb2cwEYIPcHJvZC5nb29nbGUuY29t\n" + + "MB0GA1UdDgQWBBRm5agVVdpWfRZKM7u6OMuzHhqPcDAfBgNVHSMEGDAWgBQcjNAh\n" + + "SCHTj+BW8KrzSSLo2ASEgjAKBggqhkjOPQQDAgNIADBFAiEA6KyGd9VxXDZceMZG\n" + + "IsbC40rtunFjLYI0mjZw9RcRWx8CIHCIiIHxafnDaCi+VB99NZfzAdu37g6pJptB\n" + + "gjIY71MO\n" + + "-----END CERTIFICATE-----\n" + + "-----BEGIN CERTIFICATE-----\n" + + "MIICODCCAd6gAwIBAgIUXtZECORWRSKnS9rRTJYkiALUXswwCgYIKoZIzj0EAwIw\n" + + "NzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xDTALBgNVBAsMBHJvb3QxDTALBgNV\n" + + "BAMMBDEyMzQwIBcNMjMwNzE0MjIzNjA0WhgPMjA1MDExMjkyMjM2MDRaMDsxFzAV\n" + + "BgNVBAoMDnNlY3VyaXR5LXJlYWxtMRAwDgYDVQQLDAdjb250ZXh0MQ4wDAYDVQQD\n" + + "DAUxMjM0NTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABAycVTZrjockbpD59f1a\n" + + "4l1SNL7nSyXz66Guz4eDveQqLmaMBg7vpACfO4CtiAGnolHEffuRtSkdM434m5En\n" + + "bXCjgcEwgb4wDgYDVR0PAQH/BAQDAgEGMBIGA1UdEwEB/wQIMAYBAf8CAQIwWAYD\n" + + "VR0eAQH/BE4wTKBKMDWCM3NpZ25lci1yb2xlLmNvbnRleHQuc2VjdXJpdHktcmVh\n" + + "bG0ucHJvZC5zcGlmZmUuZ29vZzARgg9wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYE\n" + + "FByM0CFIIdOP4FbwqvNJIujYBISCMB8GA1UdIwQYMBaAFMX+vebuj/lYfYEC23IA\n" + + "8HoIW0HsMAoGCCqGSM49BAMCA0gAMEUCIQCfxeXEBd7UPmeImT16SseCRu/6cHxl\n" + + "kTDsq9sKZ+eXBAIgA+oViAVOUhUQO1/6Mjlczg8NmMy2vNtG4V/7g9dMMVU=\n" + + "-----END CERTIFICATE-----"; + private static final String ROOT_PEM = + "-----BEGIN CERTIFICATE-----\n" + + "MIIBtTCCAVqgAwIBAgIUbAe+8OocndQXRBCElLBxBSdfdV8wCgYIKoZIzj0EAwIw\n" + + "NzEXMBUGA1UECgwOc2VjdXJpdHktcmVhbG0xDTALBgNVBAsMBHJvb3QxDTALBgNV\n" + + "BAMMBDEyMzQwIBcNMjMwNzE0MjIzNjA0WhgPMjA1MDExMjkyMjM2MDRaMDcxFzAV\n" + + "BgNVBAoMDnNlY3VyaXR5LXJlYWxtMQ0wCwYDVQQLDARyb290MQ0wCwYDVQQDDAQx\n" + + "MjM0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEaMY2tBW5r1t0+vhayz0ZoGMF\n" + + "boX/ZmmCmIh0iTWg4madvwNOh74CMVVvDUlXZcuVqZ3vVIX/a7PTFVqUwQlKW6NC\n" + + "MEAwDgYDVR0PAQH/BAQDAgGGMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFMX+\n" + + "vebuj/lYfYEC23IA8HoIW0HsMAoGCCqGSM49BAMCA0kAMEYCIQDETd27nsUTXKWY\n" + + "CiOno78O09gK95NoTkPU5e2chJYMqAIhALYFAyh7PU5xgFQsN9hiqgsHUc5/pmBG\n" + + "BGjJ1iz8rWGJ\n" + + "-----END CERTIFICATE-----"; + private static final String PRIVATE_KEY = + "-----BEGIN PRIVATE KEY-----\n" + + "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgqA2U0ld1OOHLMXWf\n" + + "uyN4GSaqhhudEIaKkll3rdIq0M+hRANCAAQGFlJpLxJMh4HuUm0DKjnUF7larH3t\n" + + "JvroQ12xpk+pPKQepn4ILoq9lZ8Xd3jzU98eDRXG5f4VjnX98DDHE4Id\n" + + "-----END PRIVATE KEY-----"; + + private String s2aAddress; + private int s2aPort; + private Server s2aServer; + private String s2aDelayAddress; + private int s2aDelayPort; + private Server s2aDelayServer; + private String mtlsS2AAddress; + private int mtlsS2APort; + private Server mtlsS2AServer; + private int serverPort; + private String serverAddress; + private Server server; + + @BeforeClass + public static void setUpClass() { + System.setProperty("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", "false"); + } + + @Before + public void setUp() throws Exception { + s2aPort = Utils.pickUnusedPort(); + s2aAddress = "localhost:" + s2aPort; + s2aServer = ServerBuilder.forPort(s2aPort).addService(new FakeS2AServer()).build(); + logger.info("S2A service listening on localhost:" + s2aPort); + s2aServer.start(); + + mtlsS2APort = Utils.pickUnusedPort(); + mtlsS2AAddress = "localhost:" + mtlsS2APort; + File s2aCert = new File("src/test/resources/server_cert.pem"); + File s2aKey = new File("src/test/resources/server_key.pem"); + File rootCert = new File("src/test/resources/root_cert.pem"); + ServerCredentials s2aCreds = + TlsServerCredentials.newBuilder() + .keyManager(s2aCert, s2aKey) + .trustManager(rootCert) + .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) + .build(); + mtlsS2AServer = + NettyServerBuilder.forPort(mtlsS2APort, s2aCreds).addService(new FakeS2AServer()).build(); + logger.info("mTLS S2A service listening on localhost:" + mtlsS2APort); + mtlsS2AServer.start(); + + s2aDelayPort = Utils.pickUnusedPort(); + s2aDelayAddress = "localhost:" + s2aDelayPort; + s2aDelayServer = ServerBuilder.forPort(s2aDelayPort).addService(new FakeS2AServer()).build(); + + serverPort = Utils.pickUnusedPort(); + serverAddress = "localhost:" + serverPort; + server = + NettyServerBuilder.forPort(serverPort) + .addService(new SimpleServiceImpl()) + .sslContext(buildSslContext()) + .build(); + logger.info("Simple Service listening on localhost:" + serverPort); + server.start(); + } + + @After + public void tearDown() throws Exception { + server.shutdown(); + s2aServer.shutdown(); + s2aDelayServer.shutdown(); + mtlsS2AServer.shutdown(); + } + + @Test + public void clientCommunicateUsingS2ACredentials_succeeds() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = + S2AChannelCredentials.createBuilder(s2aAddress).setLocalSpiffeId("test-spiffe-id").build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + + assertThat(doUnaryRpc(executor, channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentialsNoLocalIdentity_succeeds() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aAddress).build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + + assertThat(doUnaryRpc(executor, channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingMtlsToS2ACredentials_succeeds() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = + MtlsToS2AChannelCredentials.createBuilder( + /* s2aAddress= */ mtlsS2AAddress, + /* privateKeyPath= */ "src/test/resources/client_key.pem", + /* certChainPath= */ "src/test/resources/client_cert.pem", + /* trustBundlePath= */ "src/test/resources/root_cert.pem") + .build() + .setLocalSpiffeId("test-spiffe-id") + .build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + + assertThat(doUnaryRpc(executor, channel)).isTrue(); + } + + @Test + public void clientCommunicateUsingS2ACredentials_s2AdelayStart_succeeds() throws Exception { + DoUnaryRpc doUnaryRpc = new DoUnaryRpc(); + doUnaryRpc.start(); + Thread.sleep(2000); + s2aDelayServer.start(); + doUnaryRpc.join(); + } + + private class DoUnaryRpc extends Thread { + @Override + public void run() { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ChannelCredentials credentials = S2AChannelCredentials.createBuilder(s2aDelayAddress).build(); + ManagedChannel channel = + Grpc.newChannelBuilder(serverAddress, credentials).executor(executor).build(); + boolean result = false; + try { + result = doUnaryRpc(executor, channel); + } catch (InterruptedException e) { + logger.log(Level.SEVERE, "Failed to do unary rpc", e); + result = false; + } + assertThat(result).isTrue(); + } + } + + public static boolean doUnaryRpc(ExecutorService executor, ManagedChannel channel) + throws InterruptedException { + try { + SimpleServiceGrpc.SimpleServiceBlockingStub stub = + SimpleServiceGrpc.newBlockingStub(channel).withWaitForReady(); + SimpleResponse resp = stub.unaryRpc(SimpleRequest.newBuilder() + .setRequestMessage("S2A team") + .build()); + if (!resp.getResponseMessage().equals("Hello, S2A team!")) { + logger.info( + "Received unexpected message from the Simple Service: " + resp.getResponseMessage()); + throw new RuntimeException(); + } else { + System.out.println( + "We received this message from the Simple Service: " + resp.getResponseMessage()); + return true; + } + } finally { + channel.shutdown(); + channel.awaitTermination(1, SECONDS); + executor.shutdown(); + executor.awaitTermination(1, SECONDS); + } + } + + private static SslContext buildSslContext() throws SSLException { + SslContextBuilder sslServerContextBuilder = + SslContextBuilder.forServer( + new ByteArrayInputStream(CERT_CHAIN.getBytes(UTF_8)), + new ByteArrayInputStream(PRIVATE_KEY.getBytes(UTF_8))); + SslContext sslServerContext = + GrpcSslContexts.configure(sslServerContextBuilder, SslProvider.OPENSSL) + .protocols("TLSv1.3", "TLSv1.2") + .trustManager(new ByteArrayInputStream(ROOT_PEM.getBytes(UTF_8))) + .clientAuth(ClientAuth.REQUIRE) + .build(); + + // Enable TLS resumption. This requires using the OpenSSL provider, since the JDK provider does + // not allow a server to send session tickets. + SSLSessionContext sslSessionContext = sslServerContext.sessionContext(); + if (!(sslSessionContext instanceof OpenSslSessionContext)) { + throw new SSLException("sslSessionContext does not use OpenSSL."); + } + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + // Calling {@code setTicketKeys} without specifying any keys means that the SSL libraries will + // handle the generation of the resumption master secret. + openSslSessionContext.setTicketKeys(); + + return sslServerContext; + } + + public static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver observer) { + observer.onNext( + SimpleResponse.newBuilder() + .setResponseMessage("Hello, " + request.getRequestMessage() + "!") + .build()); + observer.onCompleted(); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java new file mode 100644 index 00000000000..0191398b6b7 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ProtoUtil}. */ +@RunWith(JUnit4.class) +public final class ProtoUtilTest { + @Rule public final Expect expect = Expect.create(); + + @Test + public void convertCiphersuite_success() { + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)) + .isEqualTo("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"); + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384)) + .isEqualTo("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"); + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)) + .isEqualTo("TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); + expect + .that( + ProtoUtil.convertCiphersuite(Ciphersuite.CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256)) + .isEqualTo("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + expect + .that( + ProtoUtil.convertCiphersuite(Ciphersuite.CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384)) + .isEqualTo("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + expect + .that( + ProtoUtil.convertCiphersuite( + Ciphersuite.CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)) + .isEqualTo("TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"); + } + + @Test + public void convertCiphersuite_withUnspecifiedCiphersuite_fails() { + AssertionError expected = + assertThrows( + AssertionError.class, + () -> ProtoUtil.convertCiphersuite(Ciphersuite.CIPHERSUITE_UNSPECIFIED)); + expect.that(expected).hasMessageThat().isEqualTo("Ciphersuite 0 is not supported."); + } + + @Test + public void convertTlsProtocolVersion_success() { + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_3)) + .isEqualTo("TLSv1.3"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_2)) + .isEqualTo("TLSv1.2"); + expect + .that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_1)) + .isEqualTo("TLSv1.1"); + expect.that(ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_1_0)).isEqualTo("TLSv1"); + } + + @Test + public void convertTlsProtocolVersion_withUnknownTlsVersion_fails() { + AssertionError expected = + assertThrows( + AssertionError.class, + () -> ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_UNSPECIFIED)); + expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java new file mode 100644 index 00000000000..4024e8a6e36 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.truth.Expect; +import com.google.protobuf.ByteString; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslPrivateKeyMethod; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.ByteArrayInputStream; +import java.security.PublicKey; +import java.security.Signature; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2APrivateKeyMethodTest { + @Rule public final Expect expect = Expect.create(); + private static final byte[] DATA_TO_SIGN = "random bytes for signing.".getBytes(UTF_8); + + private S2AStub stub; + private FakeWriter writer; + private S2APrivateKeyMethod keyMethod; + + private static PublicKey extractPublicKeyFromPem(String pem) throws Exception { + X509Certificate cert = + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(pem.getBytes(UTF_8))); + return cert.getPublicKey(); + } + + private static boolean verifySignature( + byte[] dataToSign, byte[] signature, String signatureAlgorithm) throws Exception { + Signature sig = Signature.getInstance(signatureAlgorithm); + sig.initVerify(extractPublicKeyFromPem(FakeWriter.LEAF_CERT)); + sig.update(dataToSign); + return sig.verify(signature); + } + + @Before + public void setUp() { + // This is line is to ensure that JNI correctly links the necessary objects. Without this, we + // get `java.lang.UnsatisfiedLinkError` on + // `io.netty.internal.tcnative.NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha1()` + GrpcSslContexts.configure(SslContextBuilder.forClient()); + + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + keyMethod = S2APrivateKeyMethod.create(stub, /* localIdentity= */ Optional.empty()); + } + + @Test + public void signatureAlgorithmConversion_success() { + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PKCS1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP384R1_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP521R1_SHA512); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA256); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA384); + expect + .that( + S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg( + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512)) + .isEqualTo(SignatureAlgorithm.S2A_SSL_SIGN_RSA_PSS_RSAE_SHA512); + } + + @Test + public void signatureAlgorithmConversion_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> S2APrivateKeyMethod.convertOpenSslSignAlgToS2ASignAlg(-1)); + + assertThat(e).hasMessageThat().contains("Signature Algorithm -1 is not supported."); + } + + @Test + public void createOnNullStub_returnsNullPointerException() { + assertThrows( + NullPointerException.class, + () -> S2APrivateKeyMethod.create(/* stub= */ null, /* localIdentity= */ Optional.empty())); + } + + @Test + public void decrypt_unsupportedOperation() { + UnsupportedOperationException e = + assertThrows( + UnsupportedOperationException.class, + () -> keyMethod.decrypt(/* engine= */ null, DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("decrypt is not supported."); + } + + @Test + public void fakelocalIdentity_signWithSha256_success() throws Exception { + S2AIdentity fakeIdentity = S2AIdentity.fromSpiffeId("fake-spiffe-id"); + S2AStub mockStub = mock(S2AStub.class); + OpenSslPrivateKeyMethod keyMethodWithFakeIdentity = + S2APrivateKeyMethod.create(mockStub, Optional.of(fakeIdentity)); + SessionReq req = + SessionReq.newBuilder() + .setLocalIdentity(fakeIdentity.identity()) + .setOffloadPrivateKeyOperationReq( + OffloadPrivateKeyOperationReq.newBuilder() + .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) + .setSignatureAlgorithm(SignatureAlgorithm.S2A_SSL_SIGN_ECDSA_SECP256R1_SHA256) + .setRawBytes(ByteString.copyFrom(DATA_TO_SIGN))) + .build(); + byte[] expectedOutbytes = "fake out bytes".getBytes(UTF_8); + when(mockStub.send(req)) + .thenReturn( + SessionResp.newBuilder() + .setOffloadPrivateKeyOperationResp( + OffloadPrivateKeyOperationResp.newBuilder() + .setOutBytes(ByteString.copyFrom(expectedOutbytes))) + .build()); + + byte[] signature = + keyMethodWithFakeIdentity.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + verify(mockStub).send(req); + assertThat(signature).isEqualTo(expectedOutbytes); + } + + @Test + public void signWithSha256_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA256withECDSA")).isTrue(); + } + + @Test + public void signWithSha384_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA384withECDSA")).isTrue(); + } + + @Test + public void signWithSha512_success() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + byte[] signature = + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512, + DATA_TO_SIGN); + + assertThat(signature).isNotEmpty(); + assertThat(verifySignature(DATA_TO_SIGN, signature, "SHA512withECDSA")).isTrue(); + } + + @Test + public void sign_noKeyAvailable() throws Exception { + writer.resetPrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"No Private Key" + + " available.\"."); + } + + @Test + public void sign_algorithmNotSupported() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.OK_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 255, error message: \"Only ECDSA key" + + " algorithms are supported.\"."); + } + + @Test + public void sign_getsErrorResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e) + .hasMessageThat() + .contains( + "Error occurred in response from S2A, error code: 1, error message: \"Intended ERROR" + + " Status from FakeWriter.\"."); + } + + @Test + public void sign_getsEmptyResponse() throws Exception { + writer.initializePrivateKey().setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException e = + assertThrows( + S2AConnectionException.class, + () -> + keyMethod.sign( + /* engine= */ null, + OpenSslPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256, + DATA_TO_SIGN)); + + assertThat(e).hasMessageThat().contains("No valid response received from S2A."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java new file mode 100644 index 00000000000..82db6d4a144 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import com.google.common.testing.NullPointerTester; +import com.google.common.testing.NullPointerTester.Visibility; +import io.grpc.Channel; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.benchmarks.Utils; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.s2a.channel.S2AChannelPool; +import io.grpc.s2a.channel.S2AGrpcChannelPool; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.grpc.s2a.handshaker.S2AProtocolNegotiatorFactory.S2AProtocolNegotiator; +import io.grpc.stub.StreamObserver; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import io.netty.util.AsciiString; +import java.util.Optional; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AProtocolNegotiatorFactory}. */ +@RunWith(JUnit4.class) +public class S2AProtocolNegotiatorFactoryTest { + private static final S2AIdentity LOCAL_IDENTITY = S2AIdentity.fromSpiffeId("local identity"); + private final ChannelHandlerContext mockChannelHandlerContext = mock(ChannelHandlerContext.class); + private GrpcHttp2ConnectionHandler fakeConnectionHandler; + private String authority; + private int port; + private Server fakeS2AServer; + private ObjectPool channelPool; + + @Before + public void setUp() throws Exception { + port = Utils.pickUnusedPort(); + fakeS2AServer = ServerBuilder.forPort(port).addService(new S2AServiceImpl()).build(); + fakeS2AServer.start(); + channelPool = new FakeChannelPool(); + authority = "localhost:" + port; + fakeConnectionHandler = FakeConnectionHandler.create(authority); + } + + @After + public void tearDown() { + fakeS2AServer.shutdown(); + } + + @Test + public void createProtocolNegotiatorFactory_nullArgument() throws Exception { + NullPointerTester tester = new NullPointerTester().setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiatorFactory.class, Visibility.PUBLIC); + } + + @Test + public void createProtocolNegotiator_nullArgument() throws Exception { + S2AChannelPool pool = + S2AGrpcChannelPool.create( + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + "localhost:8080", /* s2aChannelCredentials= */ Optional.empty()))); + + NullPointerTester tester = + new NullPointerTester() + .setDefault(S2AChannelPool.class, pool) + .setDefault(Optional.class, Optional.empty()); + + tester.testStaticMethods(S2AProtocolNegotiator.class, Visibility.PACKAGE); + } + + @Test + public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + + assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnNull_returnsNull() throws Exception { + assertThat(S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority(null)) + .isNull(); + } + + @Test + public void s2aProtocolNegotiator_getHostNameOnValidAuthority_returnsValidHostname() + throws Exception { + assertThat( + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.getHostNameFromAuthority( + "hostname:80")) + .isEqualTo("hostname"); + } + + @Test + public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + assertThat(clientNegotiator).isInstanceOf(S2AProtocolNegotiator.class); + assertThat(clientNegotiator.scheme()).isEqualTo(AsciiString.of("https")); + } + + @Test + public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() + throws Exception { + InternalProtocolNegotiator.ClientFactory clientFactory = + S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); + + clientNegotiator.close(); + + assertThat(((FakeChannelPool) channelPool).isChannelCached()).isFalse(); + } + + @Test + public void createChannelHandler_addHandlerToMockContext() throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + ManagedChannel channel = + Grpc.newChannelBuilder(authority, InsecureChannelCredentials.create()) + .executor(executor) + .build(); + FakeS2AChannelPool fakeChannelPool = new FakeS2AChannelPool(channel); + ProtocolNegotiator clientNegotiator = + S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( + fakeChannelPool, Optional.of(LOCAL_IDENTITY)); + + ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); + + ((ChannelDuplexHandler) channelHandler).userEventTriggered(mockChannelHandlerContext, "event"); + verify(mockChannelHandlerContext).fireUserEventTriggered("event"); + } + + /** A {@link S2AChannelPool} that returns the given channel. */ + private static class FakeS2AChannelPool implements S2AChannelPool { + private final Channel channel; + + FakeS2AChannelPool(Channel channel) { + this.channel = channel; + } + + @Override + public Channel getChannel() { + return channel; + } + + @Override + public void returnChannel(Channel channel) {} + + @Override + public void close() {} + } + + /** A {@code GrpcHttp2ConnectionHandler} that does nothing. */ + private static class FakeConnectionHandler extends GrpcHttp2ConnectionHandler { + private static final Http2ConnectionDecoder DECODER = mock(Http2ConnectionDecoder.class); + private static final Http2ConnectionEncoder ENCODER = mock(Http2ConnectionEncoder.class); + private static final Http2Settings SETTINGS = new Http2Settings(); + private final String authority; + + static FakeConnectionHandler create(String authority) { + return new FakeConnectionHandler(null, DECODER, ENCODER, SETTINGS, authority); + } + + private FakeConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings, + String authority) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); + this.authority = authority; + } + + @Override + public String getAuthority() { + return authority; + } + } + + /** An S2A server that handles GetTlsConfiguration request. */ + private static class S2AServiceImpl extends S2AServiceGrpc.S2AServiceImplBase { + static final FakeWriter writer = new FakeWriter(); + + @Override + public StreamObserver setUpSession(StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SessionReq req) { + responseObserver.onNext(writer.handleResponse(req)); + } + + @Override + public void onError(Throwable t) {} + + @Override + public void onCompleted() {} + }; + } + } + + private static class FakeChannelPool implements ObjectPool { + private final Channel mockChannel = mock(Channel.class); + private @Nullable Channel cachedChannel = null; + + @Override + public Channel getObject() { + if (cachedChannel == null) { + cachedChannel = mockChannel; + } + return cachedChannel; + } + + @Override + public Channel returnObject(Object object) { + assertThat(object).isSameInstanceAs(mockChannel); + cachedChannel = null; + return null; + } + + public boolean isChannelCached() { + return (cachedChannel != null); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java new file mode 100644 index 00000000000..a2b0e673313 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java @@ -0,0 +1,260 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.extensions.proto.ProtoTruth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import io.grpc.internal.SharedResourcePool; +import io.grpc.s2a.channel.S2AChannelPool; +import io.grpc.s2a.channel.S2AGrpcChannelPool; +import io.grpc.s2a.channel.S2AHandshakerServiceChannel; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.Optional; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link S2AStub}. */ +@RunWith(JUnit4.class) +public class S2AStubTest { + @Rule public final Expect expect = Expect.create(); + private static final String S2A_ADDRESS = "localhost:8080"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void send_receiveOkStatus() throws Exception { + S2AChannelPool channelPool = + S2AGrpcChannelPool.create( + SharedResourcePool.forResource( + S2AHandshakerServiceChannel.getChannelResource( + S2A_ADDRESS, /* s2aChannelCredentials= */ Optional.empty()))); + S2AServiceGrpc.S2AServiceStub serviceStub = S2AServiceGrpc.newStub(channelPool.getChannel()); + S2AStub newStub = S2AStub.newInstance(serviceStub); + + IOException expected = + assertThrows(IOException.class, () -> newStub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("UNAVAILABLE"); + } + + @Test + public void send_clientTlsConfiguration_receiveOkStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_CLIENT)) + .build(); + + SessionResp resp = stub.send(req); + + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(FakeWriter.LEAF_CERT) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(resp).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void send_serverTlsConfiguration_receiveErrorStatus() throws Exception { + SessionReq req = + SessionReq.newBuilder() + .setGetTlsConfigurationReq( + GetTlsConfigurationReq.newBuilder() + .setConnectionSide(ConnectionSide.CONNECTION_SIDE_SERVER)) + .build(); + + SessionResp resp = stub.send(req); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder() + .setCode(255) + .setDetails("No TLS configuration for the server side.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + SessionResp resp = stub.send(SessionReq.getDefaultInstance()); + + SessionResp expected = + SessionResp.newBuilder() + .setStatus( + Status.newBuilder().setCode(1).setDetails("Intended ERROR Status from FakeWriter.")) + .build(); + assertThat(resp).isEqualTo(expected); + } + + @Test + public void send_receiveErrorResponse() throws InterruptedException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(RuntimeException.class); + expect.that(expected).hasMessageThat().contains("Intended ERROR from FakeWriter."); + } + + @Test + public void send_receiveCompleteStatus() throws Exception { + writer.setBehavior(FakeWriter.Behavior.COMPLETE_STATUS); + + ConnectionIsClosedException expected = + assertThrows( + ConnectionIsClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Reading from the S2A is complete."); + } + + @Test + public void send_receiveUnexpectedResponse() throws Exception { + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + } + + @Test + public void send_receiveManyUnexpectedResponse_expectResponsesEmpty() throws Exception { + writer.sendIoError(); + writer.sendIoError(); + writer.sendIoError(); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected) + .hasMessageThat() + .contains( + "Received an unexpected response from a host at the S2A's address. The S2A might be" + + " unavailable."); + + assertThat(stub.getResponses()).isEmpty(); + } + + @Test + public void send_receiveDelayedResponse() throws Exception { + writer.sendGetTlsConfigResp(); + SessionResp resp = stub.send(SessionReq.getDefaultInstance()); + SessionResp expected = + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(FakeWriter.LEAF_CERT) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2) + .addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLSVersion.TLS_VERSION_1_3) + .setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384) + .addCiphersuites( + Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256))) + .build(); + assertThat(resp).ignoringRepeatedFieldOrder().isEqualTo(expected); + } + + @Test + public void send_afterEarlyClose_receivesClosedException() throws InterruptedException { + stub.close(); + expect.that(writer.isFakeWriterClosed()).isTrue(); + + ConnectionIsClosedException expected = + assertThrows( + ConnectionIsClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + assertThat(expected).hasMessageThat().contains("Stream to the S2A is closed."); + } + + @Test + public void send_failToWrite() throws Exception { + FailWriter failWriter = new FailWriter(); + stub = S2AStub.newInstanceForTesting(failWriter); + + IOException expected = + assertThrows(IOException.class, () -> stub.send(SessionReq.getDefaultInstance())); + + expect.that(expected).hasCauseThat().isInstanceOf(S2AConnectionException.class); + expect + .that(expected) + .hasCauseThat() + .hasMessageThat() + .isEqualTo("Could not send request to S2A."); + } + + /** Fails whenever a write is attempted. */ + private static class FailWriter implements StreamObserver { + @Override + public void onNext(SessionReq req) { + assertThat(req).isNotNull(); + throw new S2AConnectionException("Could not send request to S2A."); + } + + @Override + public void onError(Throwable t) { + assertThat(t).isInstanceOf(S2AConnectionException.class); + } + + @Override + public void onCompleted() { + throw new UnsupportedOperationException(); + } + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java new file mode 100644 index 00000000000..384e1aba5cc --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2ATrustManagerTest.java @@ -0,0 +1,262 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import io.grpc.s2a.handshaker.S2AIdentity; +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Base64; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class S2ATrustManagerTest { + private S2AStub stub; + private FakeWriter writer; + private static final String FAKE_HOSTNAME = "Fake-Hostname"; + private static final String CLIENT_CERT_PEM = + "MIICKjCCAc+gAwIBAgIUC2GShcVO+5Zkml+7VO3OQ+B2c7EwCgYIKoZIzj0EAwIw" + + "HzEdMBsGA1UEAwwUcm9vdGNlcnQuZXhhbXBsZS5jb20wIBcNMjMwMTI2MTk0OTUx" + + "WhgPMjA1MDA2MTMxOTQ5NTFaMB8xHTAbBgNVBAMMFGxlYWZjZXJ0LmV4YW1wbGUu" + + "Y29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEeciYZgFAZjxyzTrklCRIWpad" + + "8wkyCZQzJSf0IfNn9NKtfzL2V/blteULO0o9Da8e2Avaj+XCKfFTc7salMo/waOB" + + "5jCB4zAOBgNVHQ8BAf8EBAMCB4AwIAYDVR0lAQH/BBYwFAYIKwYBBQUHAwIGCCsG" + + "AQUFBwMBMAwGA1UdEwEB/wQCMAAwYQYDVR0RBFowWIYic3BpZmZlOi8vZm9vLnBy" + + "b2QuZ29vZ2xlLmNvbS9wMS9wMoIUZm9vLnByb2Quc3BpZmZlLmdvb2eCHG1hY2hp" + + "bmUtbmFtZS5wcm9kLmdvb2dsZS5jb20wHQYDVR0OBBYEFETY6Cu/aW924nfvUrOs" + + "yXCC1hrpMB8GA1UdIwQYMBaAFJLkXGlTYKISiGd+K/Ijh4IOEpHBMAoGCCqGSM49" + + "BAMCA0kAMEYCIQCZDW472c1/4jEOHES/88X7NTqsYnLtIpTjp5PZ62z3sAIhAN1J" + + "vxvbxt9ySdFO+cW7oLBEkCwUicBhxJi5VfQeQypT"; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_withNullStub_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + /* stub= */ null, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void createForClient_withNullHostname_throwsError() { + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + S2ATrustManager.createForClient( + stub, /* hostname= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isNull(); + } + + @Test + public void getAcceptedIssuers_returnsExpectedNullResult() { + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + assertThat(trustManager.getAcceptedIssuers()).isNull(); + } + + @Test + public void checkClientTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkClientTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkServerTrusted_withEmptyCertificateChain_throwsException() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + IllegalArgumentException expected = + assertThrows( + IllegalArgumentException.class, + () -> trustManager.checkServerTrusted(new X509Certificate[] {}, /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Certificate chain has zero certificates."); + } + + @Test + public void checkClientTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkClientTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_getsSuccessResponse() throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkServerTrusted_withLocalIdentity_getsSuccessResponse() + throws CertificateException { + writer.setVerificationResult(FakeWriter.VerificationResult.SUCCESS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient( + stub, FAKE_HOSTNAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + // Expect no exception. + trustManager.checkServerTrusted(getCerts(), /* authType= */ ""); + } + + @Test + public void checkClientTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkClientTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkClientTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkClientTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureResponse() throws CertificateException { + writer + .setVerificationResult(FakeWriter.VerificationResult.FAILURE) + .setFailureReason("Intended failure."); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Intended failure."); + } + + @Test + public void checkServerTrusted_getsIntendedFailureStatusInResponse() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().contains("Error occurred in response from S2A"); + } + + @Test + public void checkServerTrusted_getsIntendedFailureFromServer() throws CertificateException { + writer.setBehavior(FakeWriter.Behavior.ERROR_RESPONSE); + S2ATrustManager trustManager = + S2ATrustManager.createForClient(stub, FAKE_HOSTNAME, /* localIdentity= */ Optional.empty()); + + CertificateException expected = + assertThrows( + CertificateException.class, + () -> trustManager.checkServerTrusted(getCerts(), /* authType= */ "")); + + assertThat(expected).hasMessageThat().isEqualTo("Failed to send request to S2A."); + } + + private X509Certificate[] getCerts() throws CertificateException { + byte[] decoded = Base64.getDecoder().decode(CLIENT_CERT_PEM); + return new X509Certificate[] { + (X509Certificate) + CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(decoded)) + }; + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java new file mode 100644 index 00000000000..c33fd820e4c --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.truth.Expect; +import io.grpc.s2a.handshaker.S2AIdentity; +import io.netty.handler.ssl.OpenSslSessionContext; +import io.netty.handler.ssl.SslContext; +import java.security.GeneralSecurityException; +import java.util.Optional; +import javax.net.ssl.SSLSessionContext; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link SslContextFactory}. */ +@RunWith(JUnit4.class) +public final class SslContextFactoryTest { + @Rule public final Expect expect = Expect.create(); + private static final String FAKE_TARGET_NAME = "fake_target_name"; + private S2AStub stub; + private FakeWriter writer; + + @Before + public void setUp() { + writer = new FakeWriter(); + stub = S2AStub.newInstanceForTesting(writer); + writer.setReader(stub.getReader()); + } + + @Test + public void createForClient_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty()); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + expect + .that(sslContext.cipherSuites()) + .containsExactly( + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_withLocalIdentity_returnsValidSslContext() throws Exception { + SslContext sslContext = + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, Optional.of(S2AIdentity.fromSpiffeId("fake-spiffe-id"))); + + expect.that(sslContext).isNotNull(); + expect.that(sslContext.sessionCacheSize()).isEqualTo(1); + expect.that(sslContext.sessionTimeout()).isEqualTo(300); + expect.that(sslContext.isClient()).isTrue(); + expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); + expect + .that(sslContext.cipherSuites()) + .containsExactly( + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); + SSLSessionContext sslSessionContext = sslContext.sessionContext(); + if (sslSessionContext instanceof OpenSslSessionContext) { + OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; + expect.that(openSslSessionContext.isSessionCacheEnabled()).isFalse(); + } + } + + @Test + public void createForClient_returnsEmptyResponse_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.EMPTY_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Response from S2A server does NOT contain ClientTlsConfiguration."); + } + + @Test + public void createForClient_returnsErrorStatus_error() throws Exception { + writer.setBehavior(FakeWriter.Behavior.ERROR_STATUS); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().contains("Intended ERROR Status from FakeWriter."); + } + + @Test + public void createForClient_getsErrorFromServer_throwsError() throws Exception { + writer.sendIoError(); + + GeneralSecurityException expected = + assertThrows( + GeneralSecurityException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Failed to get client TLS configuration from S2A."); + } + + @Test + public void createForClient_nullStub_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + /* stub= */ null, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected).hasMessageThat().isEqualTo("stub should not be null."); + } + + @Test + public void createForClient_nullTargetName_throwsError() throws Exception { + writer.sendUnexpectedResponse(); + + NullPointerException expected = + assertThrows( + NullPointerException.class, + () -> + SslContextFactory.createForClient( + stub, /* targetName= */ null, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .isEqualTo("targetName should not be null on client side."); + } +} \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java new file mode 100644 index 00000000000..806e412b784 --- /dev/null +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.s2a.handshaker.tokenmanager; + +import static com.google.common.truth.Truth.assertThat; + +import com.beust.jcommander.JCommander; +import io.grpc.s2a.handshaker.S2AIdentity; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class SingleTokenAccessTokenManagerTest { + private static final S2AIdentity IDENTITY = S2AIdentity.fromSpiffeId("spiffe_id"); + private static final String TOKEN = "token"; + private static final String[] SET_TOKEN = {"--s2a_access_token", TOKEN}; + private static final SingleTokenFetcher.Flags FLAGS = new SingleTokenFetcher.Flags(); + + @Before + public void setUp() { + FLAGS.reset(); + } + + @Test + public void getDefaultToken_success() throws Exception { + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getDefaultToken()).isEqualTo(TOKEN); + } + + @Test + public void getToken_success() throws Exception { + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void getToken_noEnvironmentVariable() throws Exception { + assertThat(SingleTokenFetcher.create()).isEmpty(); + } + + @Test + public void create_success() throws Exception { + JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + Optional manager = AccessTokenManager.create(); + assertThat(manager).isPresent(); + assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); + } + + @Test + public void create_noEnvironmentVariable() throws Exception { + assertThat(AccessTokenManager.create()).isEmpty(); + } +} \ No newline at end of file diff --git a/s2a/src/test/resources/README.md b/s2a/src/test/resources/README.md new file mode 100644 index 00000000000..00901015444 --- /dev/null +++ b/s2a/src/test/resources/README.md @@ -0,0 +1,31 @@ +# Generating certificates and keys for testing mTLS-S2A + +Content from: https://github.com/google/s2a-go/blob/main/testdata/README.md + +Create root CA + +``` +openssl req -x509 -sha256 -days 7305 -newkey rsa:2048 -keyout root_key.pem -out +root_cert.pem +``` + +Generate private keys for server and client + +``` +openssl genrsa -out server_key.pem 2048 +openssl genrsa -out client_key.pem 2048 +``` + +Generate CSRs for server and client + +``` +openssl req -key server_key.pem -new -out server.csr -config config.cnf +openssl req -key client_key.pem -new -out client.csr -config config.cnf +``` + +Sign CSRs for server and client + +``` +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in server.csr -out server_cert.pem -days 7305 -extfile config.cnf -extensions req_ext +openssl x509 -req -CA root_cert.pem -CAkey root_key.pem -in client.csr -out client_cert.pem -days 7305 +``` \ No newline at end of file diff --git a/s2a/src/test/resources/client.csr b/s2a/src/test/resources/client.csr new file mode 100644 index 00000000000..664f5a4cf86 --- /dev/null +++ b/s2a/src/test/resources/client.csr @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIChzCCAW8CAQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAoSS3KtFgiXX4vAUNscFGIB/r2OOMgiZMKHz72dN0 +5kSxwdpQxpMIhwEoe0lhHNfOiuE7/r6VbGG9RGGIcQcoSonc3InPRfpnzfj9KohJ +i8pYkLL9EwElAEl9sWnvVKTza8jTApDP2Z/fntBEsWAMsLPpuRZT6tgN1sXe4vNG +4wufJSxuImyCVAx1fkZjRkYEKOtm1osnEDng4R0WXZ6S+q5lYzYPk1wxgbjdZu2U +fWxP6V63SphV0NFXTx0E401j2h258cIqTVj8lRX6dfl9gO0d43Rd+hSU7R4iXGEw +arixuH9g5H745AFf9H52twHPcNP9cEKBljBpSV5z3MvTkQIDAQABoC4wLAYJKoZI +hvcNAQkOMR8wHTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAAAMA0GCSqGSIb3 +DQEBCwUAA4IBAQCQHim3aIpGJs5u6JhEA07Rwm8YKyVALDEklhsHILlFhdNr2uV7 +S+3bHV79mDGjxNWvFcgK5h5ENkT60tXbhbie1gYmFT0RMCYHDsL09NGTh8G9Bbdl +UKeA9DMhRSYzE7Ks3Lo1dJvX7OAEI0qV77dGpQknufYpmHiBXuqtB9I0SpYi1c4O +9IUn/NY0yiYFPsIEsVRz/1dK97wazusLnijaMwNNhUc9bJwTyujhlr+b8ioPyADG +e+GDF97d0nQ8806DOJF4GTRKwaXD+R5zN5t4ULhZ7ERqLNeE9EnWRe4CvSGvBoNA +hIVeYaLd761Z9ZKvOnsgCr8qvMDilDFY6OfB +-----END CERTIFICATE REQUEST----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_cert.pem b/s2a/src/test/resources/client_cert.pem new file mode 100644 index 00000000000..b72f6991c91 --- /dev/null +++ b/s2a/src/test/resources/client_cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC9DCCAdwCFB+cDXee2sIHjdlBhdNpTo+G2XAjMA0GCSqGSIb3DQEBCwUAMFkx +CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzEw +MTcyMzA5MDNaFw00MzEwMTcyMzA5MDNaMBQxEjAQBgNVBAMMCWxvY2FsaG9zdDCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKEktyrRYIl1+LwFDbHBRiAf +69jjjIImTCh8+9nTdOZEscHaUMaTCIcBKHtJYRzXzorhO/6+lWxhvURhiHEHKEqJ +3NyJz0X6Z834/SqISYvKWJCy/RMBJQBJfbFp71Sk82vI0wKQz9mf357QRLFgDLCz +6bkWU+rYDdbF3uLzRuMLnyUsbiJsglQMdX5GY0ZGBCjrZtaLJxA54OEdFl2ekvqu +ZWM2D5NcMYG43WbtlH1sT+let0qYVdDRV08dBONNY9odufHCKk1Y/JUV+nX5fYDt +HeN0XfoUlO0eIlxhMGq4sbh/YOR++OQBX/R+drcBz3DT/XBCgZYwaUlec9zL05EC +AwEAATANBgkqhkiG9w0BAQsFAAOCAQEARorc1t2OJnwm1lxhf2KpTpNvNOI9FJak +iSHz/MxhMdu4BG/dQHkKkWoVC6W2Kaimx4OImBwRlGEmGf4P0bXOLSTOumk2k1np +ZUbw7Z2cJzvBmT2BLoHRXcBvbFIBW5DJUSHR37eXEKP57BeD+Og4/3XhNzehSpTX +DRd2Ix/D39JjYA462nqPHQP8HDMf6+0BFmvf9ZRYmFucccYQRCUCKDqb8+wGf9W6 +tKNRE6qPG2jpAQ9qkgO7XuucbLvpywt5xj+yDRbOIq43l40mHaz4lRp697oaxjP8 +HSVcMydW3cluoW3AVInNIaqbM1dr6931MllK62DKipFtmCycq/56XA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/client_key.pem b/s2a/src/test/resources/client_key.pem new file mode 100644 index 00000000000..dd3e2ff78f1 --- /dev/null +++ b/s2a/src/test/resources/client_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQChJLcq0WCJdfi8 +BQ2xwUYgH+vY44yCJkwofPvZ03TmRLHB2lDGkwiHASh7SWEc186K4Tv+vpVsYb1E +YYhxByhKidzcic9F+mfN+P0qiEmLyliQsv0TASUASX2xae9UpPNryNMCkM/Zn9+e +0ESxYAyws+m5FlPq2A3Wxd7i80bjC58lLG4ibIJUDHV+RmNGRgQo62bWiycQOeDh +HRZdnpL6rmVjNg+TXDGBuN1m7ZR9bE/pXrdKmFXQ0VdPHQTjTWPaHbnxwipNWPyV +Ffp1+X2A7R3jdF36FJTtHiJcYTBquLG4f2DkfvjkAV/0fna3Ac9w0/1wQoGWMGlJ +XnPcy9ORAgMBAAECggEALAUqoGDIHWUDyOEch5WDwZzWwc4PgTJTFbBm4G96fLkB +UjKAZG6gIrk3RM6b39Q4UQoMaJ/Jk+zzVi3Kpw3MfOhCVGC1JamtF8BP8IGAjdZ9 +8TFkHv/uCrEIzCFjRt00vhoDQq0qiom4/dppGYdikBbl3zDxRbM1vJkbNSY+FCGW +dA0uJ5XdMLR6lPeB5odqjUggnfUgPCOLdV/F+HkSM9NP1bzmHLiKznzwFsfat139 +7LdzJwNN5IX4Io6cxsxNlrX/NNvPkKdGv07Z6FYxWROyKCunjh48xFcQg0ltoRuq +R9P8/LwS8GYrcc1uC/uBc0e6VgM9D9fsvh+8SQtf3QKBgQDXX+z2GnsFoEs7xv9U +qN0HEX4jOkihZvFu43layUmeCeE8wlEctJ0TsM5Bd7FMoUG6e5/btwhsAIYW89Xn +l/R8OzxR6Kh952Dce4DAULuIeopiw7ASJwTZtO9lWhxw0hjM1hxXTG+xxOqQvsRX +c+d+vtvdIqyJ4ELfzg9kUtkdpwKBgQC/ig3cmej7dQdRAMn0YAwgwhuLkCqVFh4y +WIlqyPPejKf8RXubqgtaSYx/T7apP87SMMSfSLaUdrYAGjST6k+tG5cmwutPIbw/ +osL7U3hcIhjX3hfHgI69Ojcpplbd5yqTxZHpxIs6iAQCEqNuasLXIDMouqNhGF1D +YssD6qxcBwKBgQCdZqWvVrsB6ZwSG+UO4jpmqAofhMD/9FQOToCqMOF0dpP966WL +7RO/CEA06FzTPCblOuQhlyq4g8l7jMiPcSZkhIYY9oftO+Q2Pqxh4J6tp6DrfUh4 +e7u3v9wVnj2a1nD5gqFDy8D1kow7LLAhmbtdje7xNh4SxasaFWZ6U3IJkQKBgGS1 +F5i3q9IatCAZBBZjMb0/kfANevYsTPA3sPjec6q91c1EUzuDarisFx0RMn9Gt124 +mokNWEIzMHpZTO/AsOfZq92LeuF+YVYsI8y1FIGMw/csJOCWbXZ812gkt2OxGafc +p118I6BAx6q3VgrGQ2+M1JlDmIeCofa+SPPkPX+dAoGBAJrOgEJ+oyEaX/YR1g+f +33pWoPQbRCG7T4+Y0oetCCWIcMg1/IUvGUCGmRDxj5dMqB+a0vJtviQN9rjpSuNS +0EVw79AJkIjHhi6KDOfAuyBvzGhxpqxGufnQ2GU0QL65NxQfd290xkxikN0ZGtuB +SDgZoJxMOGYwf8EX5i9h27Db +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/config.cnf b/s2a/src/test/resources/config.cnf new file mode 100644 index 00000000000..38d9a9ccdb0 --- /dev/null +++ b/s2a/src/test/resources/config.cnf @@ -0,0 +1,17 @@ +[req] +distinguished_name = req_distinguished_name +req_extensions = req_ext + +[req_distinguished_name] +countryName = Country Name (2 letter code) +stateOrProvinceName = State or Province Name (full name) +localityName = Locality Name (eg, city) +organizationalUnitName = Organizational Unit Name (eg, section) +commonName = Common Name (eg, your name or your server\'s hostname) +emailAddress = Email Address + +[req_ext] +subjectAltName = @alt_names + +[alt_names] +IP.1 = :: \ No newline at end of file diff --git a/s2a/src/test/resources/root_cert.pem b/s2a/src/test/resources/root_cert.pem new file mode 100644 index 00000000000..737e601691c --- /dev/null +++ b/s2a/src/test/resources/root_cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUb7RsINwsFgKf0Q0RuzfOgp48j6UwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIzMTAxNzIzMDczOFoXDTQzMTAxNzIzMDczOFowWTELMAkGA1UEBhMCQVUxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAkIFnQLuhzYnm3rvmi/U7zMgEP2Tqgb3VC00frSXEV6olZcLgyC9g +0DAGdt9l9lP90DQTG5KCOtoW2BTqM/aaVpR0OaDFOCy90FIj6YyZLZ9w2PQxQcxS +GQHyEvWszTkNxeDyG1mPTj+Go8JLKqdvLg/9GUgPg6stxyAZwYhyUTGuEM4bv0sn +b3vmHRmIGJ/w6aLtd7nK8LkNHa3WVrbvRGHrzdMHfpzF/M/5fAk8GfRYugo39knf +VLKGyQCXNI8Y1iHGEmPqQZIFPTjBL6caIlbEV0VHlxoSOGB6JVxcllxAEvd6abqX +RJVJPQzzGfEnMNYp9SiZQ9bvDRUsUkWyYwIDAQABo1MwUTAdBgNVHQ4EFgQUAZMN +F9JAGHbA3jGOeu6bWFvSdWkwHwYDVR0jBBgwFoAUAZMNF9JAGHbA3jGOeu6bWFvS +dWkwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAicBli36ISJFu +lrJqOHVqTeNP6go0I35VGnP44nEEP5cBvRD3XntBFEk5D3mSNNOGt+2ncxom8VR9 +FsLuTfHAipXePJI6MSxFuBPea8V/YPBs3npk5f1FRvJ5vEgtzFvBjsKmp1dS9hH0 +KUWtWcsAkO2Anc/LVc0xxSidL8NjzYoEFqiki0TNNwCJjmd9XwnBLHW38sEb/pgy +KTyRpOyG3Zg2UDjBHiXPBrmIvVFLB6+LrPNvfr1k4HjIgVY539ZXUvVMDKytMrDY +h63EMDn4kkPpxXlufgWGybjN5D51OylyWBZLe+L1DQyWEg0Vd7GwPzb6p7bmI7MP +pooqbgbDpQ== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/root_key.pem b/s2a/src/test/resources/root_key.pem new file mode 100644 index 00000000000..aae992426d7 --- /dev/null +++ b/s2a/src/test/resources/root_key.pem @@ -0,0 +1,30 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFHDBOBgkqhkiG9w0BBQ0wQTApBgkqhkiG9w0BBQwwHAQInmQVkXP3TFcCAggA +MAwGCCqGSIb3DQIJBQAwFAYIKoZIhvcNAwcECGeCAVH1pefxBIIEyD3Nj1Dy19oy +fogU+z8YBLXuSCx8s3zncYPF9nYlegGSSo0ace/WxfPu8AEPus1P2MxlxfcCQ1A+ +5+vMihtEpgpTg9R4RlLAWs45jz4AduGiwqW05+W5zgDn6g7p7HIL0+M5FxKRkAW0 +KEH4Jy8Vc1XQxkhOm1Q4NLI8PT94rcBDE9Od03sdrW/hQgaOFz5AWOlT5jF1uUOz +glF1RQQxfJygTB6qlPTC3BAaiAnWij3NOg5L5vvUhjLa7iOZOhRQBRkf4YtHsM+2 +rFy8Z7MeHOvrqFf8LXosNy3JreQW036rLGR0Xh5myATkNrEwA8df37AgLUmwqyfz +hjZefPW77LgMAXlaN8s345AGikOX8yQKEFzPV/Nag32p6t4oiRRcUUfdB4wzKi6T +mzZ6lKcGR3qqL4V6lJSV3I2fmgkYZnUwymolyu+1+CVYDLuE53TBi5dRXwgOghi7 +npw7PqqQCian8yxHF9c1rYukD0ov0/y8ratjOu9XoJG2/wWQJNvDkAyc3mSJf+3y +6Wtu1qhLszU8pZOGW0fK6bGyHSp+wkoah/vRzB0+yFjvuMIG6py2ZDQeqhqS3ZV2 +nZHHjj0tZ45Wbdf4k17ujEK34pFXluPH//zADnd6ym2W0t6x+jtqR5tYu3poORQg +jFgpudkn2RUSq8N/gIiHDwblYBxU2dmyzEVudv1zNgVSHyetGLxsFoNB7Prn89rJ +u24a/xtuCyC2pshWo3KiL74hkkCsC8rLbEAAbADheb35b+Ca3JnMwgyUHbHL6Hqf +EiVIgm14lB/1uz651X58Boo6tDFkgrxEtGDUIZm8yk2n0tGflp7BtYbMCw+7gqhb +XN4hlhFDcCJm8peXcyCtGajOnBuNO9JJDNYor6QjptaIpQBFb7/0rc7kyO12BIUv +F9mrCHF18Hd/9AtUO93+tyDAnL64Jqq9tUv8dOVtIfbcHXZSYHf24l0XAiKByb8y +9NQLUZkIuF4aUZVHV8ZBDdHNqjzqVglKQlGHdw1XBexSal5pC9HvknOmWBgl0aza +flzeTRPX7TPrMJDE5lgSy58czGpvZzhFYwOp6cwpfjNsiqdzD78Zs0xsRbNg519s +d+cLmbiU3plWCoYCuDb68eZRRzT+o41+QJG2PoMCpzPw5wMLl6HuW7HXMRFpZKJc +tPKpeTIzb8hjhA+TwVIVpTPHvvQehtTUQD2mRujdvNM6PF8tnuC3F3sB3PTjeeJg +uzfEfs3BynRTIj/gX6y87gzwsrwWIEN6U0cCbQ6J1EcgdQCiH8vbhIgfd4DkLgLN +Kkif+fI/HgBOqaiwSw3sHmWgB6PllVQOKH6qAiejTHR/UUvJTPvgKJFLunmBiF12 +N1bRge1sSXE1eLKVdi+dP1j0o6RxhaRrbX7ie3y/wYHwCJnb8h08DEprgCqoswFs +SuNKmvlibBHAsnOdhyCTOd9I5n8XzAUUp6mT+C5WDfl7qfYvh6IHFlSrhZ9aS9b6 +RY873cnphKbqU5d7Cr8Ufx4b4SgS+hEnuP8y5IToLQ3BONGQH2lu7nmd89wjW0uo +IMRXybwf/5FnKhEy8Aw+pD6AxiXC3DZVTKl3SHmjkYBDvNElsJVgygVTKgbOa1Z+ +ovIK/D7QV7Nv3uVortH8XA== +-----END ENCRYPTED PRIVATE KEY----- \ No newline at end of file diff --git a/s2a/src/test/resources/server.csr b/s2a/src/test/resources/server.csr new file mode 100644 index 00000000000..1657b191133 --- /dev/null +++ b/s2a/src/test/resources/server.csr @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE REQUEST----- +MIIChzCCAW8CAQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAlPThqu8tfJ4hQKRiUw/vNPfo2L2LQU8NlrRL7rvV +71E345LGK1h/hM3MHp5VgEvaaIibb0hSNv/TYz3HVCQyNuPlcmkHZTJ9mB0icilU +rYWdM0LPIg46iThmIQVhMiNfpMKQLDLQ7o3Jktjm32OxnQdtYSV+7NFnw8/0pB4j +iaiBYfZIMeGzEJIOFG8GSNJG0pfCI71DyLRonIcb2XzfeDPHeWSF7lbIoMGAuKIE +2mXpwHmAjTMJzIShSgLqCvmbz7wR3ZeVMknXcgcqMmagGphy8SjizIWC5KRbrnRq +F22Ouxdat6scIevRXGp5nYawFYdpK9qo+82gEouVX3dtSQIDAQABoC4wLAYJKoZI +hvcNAQkOMR8wHTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAAAMA0GCSqGSIb3 +DQEBCwUAA4IBAQB2qU354OlNVunhZhiOFNwabovxLcgKoQz+GtJ2EzsMEza+NPvV +dttPxXzqL/U+gDghvGzSYGuh2yMfTTPO+XtZKpvMUmIWonN5jItbFwSTaWcoE8Qs +zFZokRuFJ9dy017u642mpdf6neUzjbfCjWs8+3jyFzWlkrMF3RlSTxPuksWjhXsX +dxxLNu8YWcsYRB3fODHqrlBNuDn+9kb9z8to+yq76MA0HtdDkjd/dfgghiTDJhqm +IcwhBXufwQUrOP4YiuiwM0mo7Xlhw65gnSmRcwR9ha98SV2zG5kiRYE+m+94bDbd +kGBRfhpQSzh1w09cVzmLgzkfxRShEB+bb9Ss +-----END CERTIFICATE REQUEST----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_cert.pem b/s2a/src/test/resources/server_cert.pem new file mode 100644 index 00000000000..10a98cf5c21 --- /dev/null +++ b/s2a/src/test/resources/server_cert.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDWjCCAkKgAwIBAgIUMZkgD5gtoa39H9jdI/ijVkyxC/swDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIzMTAxNzIzMDg1M1oXDTQzMTAxNzIzMDg1M1owFDESMBAGA1UEAwwJbG9jYWxo +b3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAlPThqu8tfJ4hQKRi +Uw/vNPfo2L2LQU8NlrRL7rvV71E345LGK1h/hM3MHp5VgEvaaIibb0hSNv/TYz3H +VCQyNuPlcmkHZTJ9mB0icilUrYWdM0LPIg46iThmIQVhMiNfpMKQLDLQ7o3Jktjm +32OxnQdtYSV+7NFnw8/0pB4jiaiBYfZIMeGzEJIOFG8GSNJG0pfCI71DyLRonIcb +2XzfeDPHeWSF7lbIoMGAuKIE2mXpwHmAjTMJzIShSgLqCvmbz7wR3ZeVMknXcgcq +MmagGphy8SjizIWC5KRbrnRqF22Ouxdat6scIevRXGp5nYawFYdpK9qo+82gEouV +X3dtSQIDAQABo18wXTAbBgNVHREEFDAShxAAAAAAAAAAAAAAAAAAAAAAMB0GA1Ud +DgQWBBTKJU+NK7Q6ZPccSigRCMBCBgjkaDAfBgNVHSMEGDAWgBQBkw0X0kAYdsDe +MY567ptYW9J1aTANBgkqhkiG9w0BAQsFAAOCAQEAXuCs6MGVoND8TaJ6qaDmqtpy +wKEW2hsGclI9yv5cMS0XCVTkmKYnIoijtqv6Pdh8PfhIx5oJqJC8Ml16w4Iou4+6 +kKF0DdzdQyiM0OlNCgLYPiR4rh0ZCAFFCvOsDum1g+b9JTFZGooK4TMd9thwms4D +SqpP5v1NWf/ZLH5TYnp2CkPzBxDlnMJZphuWtPHL+78TbgQuQaKu2nMLBGBJqtFi +HDOGxckgZuwBsy0c+aC/ZwaV7FdMP42kxUZduCEx8+BDSGwPoEpz6pwVIkjiyYAm +3O8FUeEPzYzwpkANIbbEIDWV6FVH9IahKRRkE+bL3BqoQkv8SMciEA5zWsPrbA== +-----END CERTIFICATE----- \ No newline at end of file diff --git a/s2a/src/test/resources/server_key.pem b/s2a/src/test/resources/server_key.pem new file mode 100644 index 00000000000..44f087dee94 --- /dev/null +++ b/s2a/src/test/resources/server_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCU9OGq7y18niFA +pGJTD+809+jYvYtBTw2WtEvuu9XvUTfjksYrWH+EzcwenlWAS9poiJtvSFI2/9Nj +PcdUJDI24+VyaQdlMn2YHSJyKVSthZ0zQs8iDjqJOGYhBWEyI1+kwpAsMtDujcmS +2ObfY7GdB21hJX7s0WfDz/SkHiOJqIFh9kgx4bMQkg4UbwZI0kbSl8IjvUPItGic +hxvZfN94M8d5ZIXuVsigwYC4ogTaZenAeYCNMwnMhKFKAuoK+ZvPvBHdl5UySddy +ByoyZqAamHLxKOLMhYLkpFuudGoXbY67F1q3qxwh69FcanmdhrAVh2kr2qj7zaAS +i5Vfd21JAgMBAAECggEACTBuN4hXywdKT92UP0GNZTwh/jT7QUUqNnDa+lhWI1Rk +WUK1vPjRrRSxEfZ8mdSUHbzHsf7JK6FungGyqUsuWdqHTh6SmTibLOYnONm54paK +kx38/0HXdJ2pF0Jos5ohDV3/XOqpnv3aQJfm7kMNMv3BTqvsf5mPiDHtCq7dTGGj +rGiLc0zirKZq79C6YSB1UMB01BsDl2ScflK8b3osT18uYx/BOdjLT4yZWQsU/nbB +OeF+ziWTTUAVjodGeTf+NYG7cFN/9N9PdSnAwuw8Nche3xZKbHTh2I578Zd4bsDX +H+hoMN862nzOXEvD6KyLB8xDdnEZ+p+njeDROJVmgQKBgQDQhzQEl/co1LYc5IDO +mynhCOtKJeRWBLhYEPIuaSY3qF+lrOWzqyOUNppWDx+HeKOq70X1Q+ETeSXtbaL1 +qHBkNcApQ2lStcpkR9whcVbr9NIWC8y8UQxyerEK3x3l0bZ99dfJ/z6lbxdS7prc +Hhxy6pUj8Q8AgpTZA8HfQUF1EQKBgQC23ek24kTVvWeWX2C/82H1Yfia6ITL7WHz +3aEJaZaO5JD3KmOSZgY88Ob3pkDTRYjFZND5zSB7PnM68gpo/OEDla6ZYtfwBWCX +q4QhFtv2obehobmDk+URVfvlOcBikoEP1i8oy7WdZ5CgC4gNKkkD15l68W+g5IIG +2ZOA97yUuQKBgDAzoI2TRxmUGciR9UhMy6Bt/F12ZtKPYsFQoXqi6aeh7wIP9kTS +wXWoLYLJGiOpekOv7X7lQujKbz7zweCBIAG5/wJKx9TLms4VYkgEt+/w9oMMFTZO +kc8Al14I9xNBp6p0In5Z1vRMupp79yX8e90AZpsZRLt8c8W6PZ1Kq0PRAoGBAKmD +7LzD46t/eJccs0M9CoG94Ac5pGCmHTdDLBTdnIO5vehhkwwTJ5U2e+T2aQFwY+kY +G+B1FrconQj3dk78nFoGV2Q5DJOjaHcwt7s0xZNLNj7O/HnMj3wSiP9lGcJGrP1R +P0ZCEIlph9fU2LnbiPPW2J/vT9uF+EMBTosvG9GBAoGAEVaDLLXOHj+oh1i6YY7s +0qokN2CdeKY4gG7iKjuDFb0r/l6R9uFvpUwJMhLEkF5SPQMyrzKFdnTpw3n/jnRa +AWG6GoV+D7LES+lHP5TXKKijbnHJdFjW8PtfDXHCJ6uGG91vH0TMMp1LqhcvGfTv +lcNGXkk6gUNSecxBC1uJfKE= +-----END PRIVATE KEY----- \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index ae6e395e7a1..a550789aaca 100644 --- a/settings.gradle +++ b/settings.gradle @@ -64,6 +64,7 @@ include ":grpc-benchmarks" include ":grpc-services" include ":grpc-servlet" include ":grpc-servlet-jakarta" +include ":grpc-s2a" include ":grpc-xds" include ":grpc-bom" include ":grpc-rls" @@ -98,6 +99,7 @@ project(':grpc-benchmarks').projectDir = "$rootDir/benchmarks" as File project(':grpc-services').projectDir = "$rootDir/services" as File project(':grpc-servlet').projectDir = "$rootDir/servlet" as File project(':grpc-servlet-jakarta').projectDir = "$rootDir/servlet/jakarta" as File +project(':grpc-s2a').projectDir = "$rootDir/s2a" as File project(':grpc-xds').projectDir = "$rootDir/xds" as File project(':grpc-bom').projectDir = "$rootDir/bom" as File project(':grpc-rls').projectDir = "$rootDir/rls" as File From 739ee23e03350104b0d60a4d1dbb14a6026b67af Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 13:20:33 -0700 Subject: [PATCH 02/27] update to use gRPC Authors with copyright. --- s2a/src/main/proto/grpc/gcp/common.proto | 7 +++++-- s2a/src/main/proto/grpc/gcp/s2a_context.proto | 1 + 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/s2a/src/main/proto/grpc/gcp/common.proto b/s2a/src/main/proto/grpc/gcp/common.proto index 7c105c2ce05..1fffc5e1aeb 100644 --- a/s2a/src/main/proto/grpc/gcp/common.proto +++ b/s2a/src/main/proto/grpc/gcp/common.proto @@ -1,10 +1,10 @@ -// Copyright 2022 Google LLC +// Copyright 2024 The gRPC Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// https://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +// The canonical version of this proto can be found at +// https://github.com/grpc/grpc-proto/blob/master/grpc/gcp/s2a/common.proto + syntax = "proto3"; package grpc.gcp; diff --git a/s2a/src/main/proto/grpc/gcp/s2a_context.proto b/s2a/src/main/proto/grpc/gcp/s2a_context.proto index 5ad264bf875..8d49471c0d3 100644 --- a/s2a/src/main/proto/grpc/gcp/s2a_context.proto +++ b/s2a/src/main/proto/grpc/gcp/s2a_context.proto @@ -14,6 +14,7 @@ // The canonical version of this proto can be found at // https://github.com/grpc/grpc-proto/blob/master/grpc/gcp/s2a/s2a_context.proto + syntax = "proto3"; package grpc.gcp; From 72630d8d23ef6d1e8c384449977f85d97e26cd6e Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 13:29:25 -0700 Subject: [PATCH 03/27] S2AChannelPool returnChannel --> returnToPool name change. --- .../java/io/grpc/s2a/channel/S2AChannelPool.java | 2 +- .../java/io/grpc/s2a/channel/S2AGrpcChannelPool.java | 2 +- .../io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java | 12 ++++++------ .../handshaker/S2AProtocolNegotiatorFactoryTest.java | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java index e0501e91c66..9c849442dca 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java @@ -32,7 +32,7 @@ public interface S2AChannelPool extends AutoCloseable { Channel getChannel(); /** Returns a channel to the channel pool. */ - void returnChannel(Channel channel); + void returnToPool(Channel channel); /** * Returns all channels to the channel pool and closes the pool so that no new channels can be diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java index 1d1de28e64e..63d2d6fd736 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java @@ -87,7 +87,7 @@ public synchronized Channel getChannel() { *

The caller must ensure that {@code channel} was retrieved from this channel pool. */ @Override - public synchronized void returnChannel(Channel channel) { + public synchronized void returnToPool(Channel channel) { checkState(state.equals(State.OPEN), "Channel pool is not open."); checkArgument( cachedChannel != null && numberOfUsersOfCachedChannel > 0 && cachedChannel.equals(channel), diff --git a/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java b/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java index 13eccac682d..260129f8f56 100644 --- a/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java +++ b/s2a/src/test/java/io/grpc/s2a/channel/S2AGrpcChannelPoolTest.java @@ -43,34 +43,34 @@ public void getChannel_success() throws Exception { } @Test - public void returnChannel_success() throws Exception { + public void returnToPool_success() throws Exception { FakeChannelPool fakeChannelPool = new FakeChannelPool(); S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); - s2aChannelPool.returnChannel(s2aChannelPool.getChannel()); + s2aChannelPool.returnToPool(s2aChannelPool.getChannel()); assertThat(fakeChannelPool.isChannelCached()).isFalse(); } @Test - public void returnChannel_channelStillCachedBecauseMultipleChannelsRetrieved() throws Exception { + public void returnToPool_channelStillCachedBecauseMultipleChannelsRetrieved() throws Exception { FakeChannelPool fakeChannelPool = new FakeChannelPool(); S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(fakeChannelPool); s2aChannelPool.getChannel(); - s2aChannelPool.returnChannel(s2aChannelPool.getChannel()); + s2aChannelPool.returnToPool(s2aChannelPool.getChannel()); assertThat(fakeChannelPool.isChannelCached()).isTrue(); } @Test - public void returnChannel_failureBecauseChannelWasNotFromPool() throws Exception { + public void returnToPool_failureBecauseChannelWasNotFromPool() throws Exception { S2AChannelPool s2aChannelPool = S2AGrpcChannelPool.create(new FakeChannelPool()); IllegalArgumentException expected = assertThrows( IllegalArgumentException.class, - () -> s2aChannelPool.returnChannel(mock(Channel.class))); + () -> s2aChannelPool.returnToPool(mock(Channel.class))); assertThat(expected) .hasMessageThat() .isEqualTo( diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java index 82db6d4a144..ed7c71bb539 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -187,7 +187,7 @@ public Channel getChannel() { } @Override - public void returnChannel(Channel channel) {} + public void returnToPool(Channel channel) {} @Override public void close() {} From 46691df78e14e059b579af454b4f965243791283 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 13:31:57 -0700 Subject: [PATCH 04/27] add s2a to sync-protos script. --- buildscripts/sync-protos.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buildscripts/sync-protos.sh b/buildscripts/sync-protos.sh index 5f01be2e5c9..628b1688d4c 100755 --- a/buildscripts/sync-protos.sh +++ b/buildscripts/sync-protos.sh @@ -8,7 +8,7 @@ curl -Ls https://github.com/grpc/grpc-proto/archive/master.tar.gz | tar xz -C "$ base="$tmpdir/grpc-proto-master" # Copy protos in 'src/main/proto' from grpc-proto for these projects -for project in alts grpclb services rls interop-testing; do +for project in alts grpclb services s2a rls interop-testing; do while read -r proto; do [ -f "$base/$proto" ] && cp "$base/$proto" "$project/src/main/proto/$proto" echo "$proto" From 44fe552660ba1aa921c073e8605d99dc2361785c Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 14:00:28 -0700 Subject: [PATCH 05/27] S2AGrpcChannelPool remove unnecessary state check. --- s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java index 63d2d6fd736..4794cd9ee49 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AGrpcChannelPool.java @@ -68,9 +68,6 @@ private S2AGrpcChannelPool(ObjectPool channelPool) { @Override public synchronized Channel getChannel() { checkState(state.equals(State.OPEN), "Channel pool is not open."); - checkState( - numberOfUsersOfCachedChannel >= 0, - "Number of users of cached channel must be non-negative."); checkState( numberOfUsersOfCachedChannel < MAX_NUMBER_USERS_OF_CACHED_CHANNEL, "Max number of channels have been retrieved from the channel pool."); From 2dd1c7e7befd4d1eb35e29ac5a8307d30a9aed6e Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 14:25:31 -0700 Subject: [PATCH 06/27] update proto package to grpc.gcp.s2a. --- s2a/BUILD.bazel | 16 ++++++++-------- .../io/grpc/s2a/handshaker/S2AServiceGrpc.java | 4 ++-- .../main/proto/grpc/gcp/{ => s2a}/common.proto | 2 +- s2a/src/main/proto/grpc/gcp/{ => s2a}/s2a.proto | 6 +++--- .../proto/grpc/gcp/{ => s2a}/s2a_context.proto | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) rename s2a/src/main/proto/grpc/gcp/{ => s2a}/common.proto (99%) rename s2a/src/main/proto/grpc/gcp/{ => s2a}/s2a.proto (99%) rename s2a/src/main/proto/grpc/gcp/{ => s2a}/s2a_context.proto (97%) diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel index 0041ad52be6..1e519549727 100644 --- a/s2a/BUILD.bazel +++ b/s2a/BUILD.bazel @@ -133,28 +133,28 @@ java_library( # bazel only accepts proto import with absolute path. genrule( name = "protobuf_imports", - srcs = glob(["src/main/proto/grpc/gcp/*.proto"]), + srcs = glob(["src/main/proto/grpc/gcp/s2a/*.proto"]), outs = [ - "protobuf_out/grpc/gcp/s2a.proto", - "protobuf_out/grpc/gcp/s2a_context.proto", - "protobuf_out/grpc/gcp/common.proto", + "protobuf_out/grpc/gcp/s2a/s2a.proto", + "protobuf_out/grpc/gcp/s2a/s2a_context.proto", + "protobuf_out/grpc/gcp/s2a/common.proto", ], cmd = "for fname in $(SRCS); do " + "sed 's,import \",import \"s2a/protobuf_out/,g' $$fname > " + - "$(@D)/protobuf_out/grpc/gcp/$$(basename $$fname); done", + "$(@D)/protobuf_out/grpc/gcp/s2a/$$(basename $$fname); done", ) proto_library( name = "common_proto", srcs = [ - "protobuf_out/grpc/gcp/common.proto", + "protobuf_out/grpc/gcp/s2a/common.proto", ], ) proto_library( name = "s2a_context_proto", srcs = [ - "protobuf_out/grpc/gcp/s2a_context.proto", + "protobuf_out/grpc/gcp/s2a/s2a_context.proto", ], deps = [ ":common_proto", @@ -164,7 +164,7 @@ proto_library( proto_library( name = "s2a_proto", srcs = [ - "protobuf_out/grpc/gcp/s2a.proto", + "protobuf_out/grpc/gcp/s2a/s2a.proto", ], deps = [ ":common_proto", diff --git a/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java b/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java index fd6b991c039..b365954b189 100644 --- a/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java +++ b/s2a/src/generated/main/grpc/io/grpc/s2a/handshaker/S2AServiceGrpc.java @@ -6,13 +6,13 @@ */ @javax.annotation.Generated( value = "by gRPC proto compiler", - comments = "Source: grpc/gcp/s2a.proto") + comments = "Source: grpc/gcp/s2a/s2a.proto") @io.grpc.stub.annotations.GrpcGenerated public final class S2AServiceGrpc { private S2AServiceGrpc() {} - public static final java.lang.String SERVICE_NAME = "grpc.gcp.S2AService"; + public static final java.lang.String SERVICE_NAME = "grpc.gcp.s2a.S2AService"; // Static method descriptors that strictly reflect the proto. private static volatile io.grpc.MethodDescriptor Date: Mon, 29 Apr 2024 14:55:23 -0700 Subject: [PATCH 07/27] ConnectionIsClosedException --> ConnectionClosedException. --- s2a/BUILD.bazel | 2 +- ...on.java => ConnectionClosedException.java} | 4 ++-- .../java/io/grpc/s2a/handshaker/S2AStub.java | 20 +++++++++---------- .../io/grpc/s2a/handshaker/S2AStubTest.java | 8 ++++---- 4 files changed, 17 insertions(+), 17 deletions(-) rename s2a/src/main/java/io/grpc/s2a/handshaker/{ConnectionIsClosedException.java => ConnectionClosedException.java} (86%) diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel index 1e519549727..3e987e40228 100644 --- a/s2a/BUILD.bazel +++ b/s2a/BUILD.bazel @@ -66,7 +66,7 @@ java_library( java_library( name = "s2a_handshaker", srcs = [ - "src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java", + "src/main/java/io/grpc/s2a/handshaker/ConnectionClosedException.java", "src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java", "src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java", "src/main/java/io/grpc/s2a/handshaker/S2AConnectionException.java", diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionClosedException.java similarity index 86% rename from s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java rename to s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionClosedException.java index 1f9b2d5a23a..1a7f86bda91 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionIsClosedException.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ConnectionClosedException.java @@ -20,8 +20,8 @@ /** Indicates that a connection has been closed. */ @SuppressWarnings("serial") // This class is never serialized. -final class ConnectionIsClosedException extends IOException { - public ConnectionIsClosedException(String errorMessage) { +final class ConnectionClosedException extends IOException { + public ConnectionClosedException(String errorMessage) { super(errorMessage); } } \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java index aa2502cd4fa..0c6be56971d 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java @@ -74,12 +74,12 @@ BlockingQueue getResponses() { /** * Sends a request and returns the response. Caller must wait until this method executes prior to - * calling it again. If this method throws {@code ConnectionIsClosedException}, then it should not + * calling it again. If this method throws {@code ConnectionClosedException}, then it should not * be called again, and both {@code reader} and {@code writer} are closed. * * @param req the {@code SessionReq} message to be sent to the S2A server. * @return the {@code SessionResp} message received from the S2A server. - * @throws ConnectionIsClosedException if {@code reader} or {@code writer} calls their {@code + * @throws ConnectionClosedException if {@code reader} or {@code writer} calls their {@code * onCompleted} method. * @throws IOException if an unexpected response is received, or if the {@code reader} or {@code * writer} calls their {@code onError} method. @@ -87,7 +87,7 @@ BlockingQueue getResponses() { public SessionResp send(SessionReq req) throws IOException, InterruptedException { if (doneWriting && doneReading) { logger.log(Level.INFO, "Stream to the S2A is closed."); - throw new ConnectionIsClosedException("Stream to the S2A is closed."); + throw new ConnectionClosedException("Stream to the S2A is closed."); } createWriterIfNull(); if (!responses.isEmpty()) { @@ -121,10 +121,10 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException } try { return responses.take().getResultOrThrow(); - } catch (ConnectionIsClosedException e) { - // A ConnectionIsClosedException is thrown by getResultOrThrow when reader calls its + } catch (ConnectionClosedException e) { + // A ConnectionClosedException is thrown by getResultOrThrow when reader calls its // onCompleted method. The close method is called to also close the writer, and then the - // ConnectionIsClosedException is re-thrown in order to indicate to the caller that send + // ConnectionClosedException is re-thrown in order to indicate to the caller that send // should not be called again. close(); throw e; @@ -177,7 +177,7 @@ public void onError(Throwable t) { } /** - * Sets {@code doneReading} to true, and places a {@code ConnectionIsClosedException} in the + * Sets {@code doneReading} to true, and places a {@code ConnectionClosedException} in the * {@code responses} queue. */ @Override @@ -186,7 +186,7 @@ public void onCompleted() { doneReading = true; responses.offer( Result.createWithThrowable( - new ConnectionIsClosedException("Reading from the S2A is complete."))); + new ConnectionClosedException("Reading from the S2A is complete."))); } } @@ -211,8 +211,8 @@ private Result(Optional response, Optional throwable) { /** Throws {@code throwable} if present, and returns {@code response} otherwise. */ SessionResp getResultOrThrow() throws IOException { if (throwable.isPresent()) { - if (throwable.get() instanceof ConnectionIsClosedException) { - ConnectionIsClosedException exception = (ConnectionIsClosedException) throwable.get(); + if (throwable.get() instanceof ConnectionClosedException) { + ConnectionClosedException exception = (ConnectionClosedException) throwable.get(); throw exception; } else { throw new IOException(throwable.get()); diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java index a2b0e673313..6413874a99c 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java @@ -147,9 +147,9 @@ public void send_receiveErrorResponse() throws InterruptedException { public void send_receiveCompleteStatus() throws Exception { writer.setBehavior(FakeWriter.Behavior.COMPLETE_STATUS); - ConnectionIsClosedException expected = + ConnectionClosedException expected = assertThrows( - ConnectionIsClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + ConnectionClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); assertThat(expected).hasMessageThat().contains("Reading from the S2A is complete."); } @@ -216,9 +216,9 @@ public void send_afterEarlyClose_receivesClosedException() throws InterruptedExc stub.close(); expect.that(writer.isFakeWriterClosed()).isTrue(); - ConnectionIsClosedException expected = + ConnectionClosedException expected = assertThrows( - ConnectionIsClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); + ConnectionClosedException.class, () -> stub.send(SessionReq.getDefaultInstance())); assertThat(expected).hasMessageThat().contains("Stream to the S2A is closed."); } From f94cc100cbaff7ee0151ef923851018a715da987 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 15:30:40 -0700 Subject: [PATCH 08/27] identity() --> getIdentity(). --- .../io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java | 2 +- s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java | 2 +- .../main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java | 2 +- s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java | 2 +- s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java | 2 +- .../io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java | 2 +- .../java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java index 3b17a5ed322..b5fdbe76f36 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java @@ -49,7 +49,7 @@ static Optional getAuthMechanism(Optional authMechanism = Optional.of( AuthenticationMechanism.newBuilder() - .setIdentity(localIdentity.get().identity()) + .setIdentity(localIdentity.get().getIdentity()) .setToken(manager.getToken(localIdentity.get())) .build()); } diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java index 30957acd521..c4fed7377ac 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AIdentity.java @@ -56,7 +56,7 @@ private S2AIdentity(Identity identity) { } /** Returns the proto {@link Identity} representation of this identity instance. */ - public Identity identity() { + public Identity getIdentity() { return identity; } } \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java index fb4908d99fc..fb6d5761355 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2APrivateKeyMethod.java @@ -119,7 +119,7 @@ public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) .setSignatureAlgorithm(s2aSignatureAlgorithm) .setRawBytes(ByteString.copyFrom(input))); if (localIdentity.isPresent()) { - reqBuilder.setLocalIdentity(localIdentity.get().identity()); + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); } SessionResp resp = stub.send(reqBuilder.build()); diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java index 014fcf4c4f8..fb113bb29cc 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2ATrustManager.java @@ -114,7 +114,7 @@ private void checkPeerTrusted(X509Certificate[] chain, boolean isCheckingClientC SessionReq.Builder reqBuilder = SessionReq.newBuilder().setValidatePeerCertificateChainReq(validatePeerCertificateChainReq); if (localIdentity.isPresent()) { - reqBuilder.setLocalIdentity(localIdentity.get().identity()); + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); } SessionResp resp; diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java index bfa45146625..c8c7cdd3e04 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java @@ -102,7 +102,7 @@ private static GetTlsConfigurationResp.ClientTlsConfiguration getClientTlsConfig checkNotNull(stub, "stub should not be null."); SessionReq.Builder reqBuilder = SessionReq.newBuilder(); if (localIdentity.isPresent()) { - reqBuilder.setLocalIdentity(localIdentity.get().identity()); + reqBuilder.setLocalIdentity(localIdentity.get().getIdentity()); } Optional authMechanism = GetAuthenticationMechanisms.getAuthMechanism(localIdentity); diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java index aea279ed8c5..8070976a844 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java @@ -57,7 +57,7 @@ public void getAuthMechanisms_nonEmptyIdentity_success() { .isEqualTo( Optional.of( AuthenticationMechanism.newBuilder() - .setIdentity(fakeIdentity.identity()) + .setIdentity(fakeIdentity.getIdentity()) .setToken("access_token") .build())); } diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java index 4024e8a6e36..8252aa245d7 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2APrivateKeyMethodTest.java @@ -163,7 +163,7 @@ public void fakelocalIdentity_signWithSha256_success() throws Exception { S2APrivateKeyMethod.create(mockStub, Optional.of(fakeIdentity)); SessionReq req = SessionReq.newBuilder() - .setLocalIdentity(fakeIdentity.identity()) + .setLocalIdentity(fakeIdentity.getIdentity()) .setOffloadPrivateKeyOperationReq( OffloadPrivateKeyOperationReq.newBuilder() .setOperation(OffloadPrivateKeyOperationReq.PrivateKeyOperation.SIGN) From 217a3e417fef5c2a3261b6fcd82efec7fa5b6ffa Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 29 Apr 2024 16:49:13 -0700 Subject: [PATCH 09/27] annotate S2AChannelCredentials.Builder with NotThreadSafe. --- s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java index 4ad05b4541a..ca06cdf4689 100644 --- a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -39,7 +39,6 @@ * Configures gRPC to use S2A for transport security when establishing a secure channel. Only for * use on the client side of a gRPC connection. */ -@NotThreadSafe public final class S2AChannelCredentials { /** * Creates a channel credentials builder for establishing an S2A-secured connection. @@ -53,6 +52,7 @@ public static Builder createBuilder(String s2aAddress) { } /** Builds an {@code S2AChannelCredentials} instance. */ + @NotThreadSafe public static final class Builder { private final String s2aAddress; private ObjectPool s2aChannelPool; From b35f145be2cb13c6ef58159c3deefc55816efb64 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Tue, 30 Apr 2024 09:29:33 -0700 Subject: [PATCH 10/27] add values entered when generating certs. --- s2a/src/test/resources/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/s2a/src/test/resources/README.md b/s2a/src/test/resources/README.md index 00901015444..726b921a615 100644 --- a/s2a/src/test/resources/README.md +++ b/s2a/src/test/resources/README.md @@ -16,7 +16,8 @@ openssl genrsa -out server_key.pem 2048 openssl genrsa -out client_key.pem 2048 ``` -Generate CSRs for server and client +Generate CSRs for server and client (set Common Name to localhost, leave all +other fields blank) ``` openssl req -key server_key.pem -new -out server.csr -config config.cnf From f47c560980124babdd5af4fa365762559bfbc64c Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Tue, 30 Apr 2024 10:28:50 -0700 Subject: [PATCH 11/27] remove JCommander dependency. --- MODULE.bazel | 1 - repositories.bzl | 1 - s2a/BUILD.bazel | 2 +- s2a/build.gradle | 2 -- .../tokenmanager/SingleTokenFetcher.java | 25 +++++++------------ .../GetAuthenticationMechanismsTest.java | 5 +--- .../SingleTokenAccessTokenManagerTest.java | 11 +++----- 7 files changed, 15 insertions(+), 32 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 78e6ccb70f2..fc64139a460 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -43,7 +43,6 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "org.apache.tomcat:annotations-api:6.0.53", "org.checkerframework:checker-qual:3.12.0", "org.codehaus.mojo:animal-sniffer-annotations:1.23", - "org.jcommander:jcommander:1.83", ] # GRPC_DEPS_END diff --git a/repositories.bzl b/repositories.bzl index 7ed5141fec3..2abb19977a5 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -47,7 +47,6 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "org.apache.tomcat:annotations-api:6.0.53", "org.checkerframework:checker-qual:3.12.0", "org.codehaus.mojo:animal-sniffer-annotations:1.23", - "org.jcommander:jcommander:1.83", ] # GRPC_DEPS_END diff --git a/s2a/BUILD.bazel b/s2a/BUILD.bazel index 3e987e40228..5aeaedbe358 100644 --- a/s2a/BUILD.bazel +++ b/s2a/BUILD.bazel @@ -59,7 +59,7 @@ java_library( deps = [ ":s2a_identity", ":token_fetcher", - artifact("org.jcommander:jcommander"), + artifact("com.google.guava:guava"), ], ) diff --git a/s2a/build.gradle b/s2a/build.gradle index 054039571d8..403ac93552f 100644 --- a/s2a/build.gradle +++ b/s2a/build.gradle @@ -26,7 +26,6 @@ dependencies { libraries.protobuf.java, libraries.conscrypt, libraries.guava.jre // JRE required by protobuf-java-util from grpclb - compileOnly 'org.jcommander:jcommander:1.83' def nettyDependency = implementation project(':grpc-netty') compileOnly libraries.javax.annotation @@ -44,7 +43,6 @@ dependencies { libraries.conscrypt, libraries.netty.transport.epoll - testImplementation 'org.jcommander:jcommander:1.83' testImplementation 'com.google.truth:truth:1.4.2' testImplementation 'com.google.truth.extensions:truth-proto-extension:1.4.2' testImplementation libraries.guava.testlib diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java index 3b2bd051e84..c3dffd2b715 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenFetcher.java @@ -16,27 +16,15 @@ package io.grpc.s2a.handshaker.tokenmanager; -import com.beust.jcommander.Parameter; -import com.beust.jcommander.Parameters; +import com.google.common.annotations.VisibleForTesting; import io.grpc.s2a.handshaker.S2AIdentity; import java.util.Optional; /** Fetches a single access token via an environment variable. */ +@SuppressWarnings("NonFinalStaticField") public final class SingleTokenFetcher implements TokenFetcher { private static final String ENVIRONMENT_VARIABLE = "S2A_ACCESS_TOKEN"; - - /** Set an access token via a flag. */ - @Parameters(separators = "=") - public static class Flags { - @Parameter( - names = "--s2a_access_token", - description = "The access token used to authenticate to S2A.") - private static String accessToken = System.getenv(ENVIRONMENT_VARIABLE); - - public synchronized void reset() { - accessToken = null; - } - } + private static String accessToken = System.getenv(ENVIRONMENT_VARIABLE); private final String token; @@ -45,7 +33,12 @@ public synchronized void reset() { * {@code Optional} instance if the token could not be fetched. */ public static Optional create() { - return Optional.ofNullable(Flags.accessToken).map(SingleTokenFetcher::new); + return Optional.ofNullable(accessToken).map(SingleTokenFetcher::new); + } + + @VisibleForTesting + public static void setAccessToken(String token) { + accessToken = token; } private SingleTokenFetcher(String token) { diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java index 8070976a844..884e1ec88eb 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/GetAuthenticationMechanismsTest.java @@ -16,7 +16,6 @@ package io.grpc.s2a.handshaker; -import com.beust.jcommander.JCommander; import com.google.common.truth.Expect; import io.grpc.s2a.handshaker.S2AIdentity; import io.grpc.s2a.handshaker.tokenmanager.SingleTokenFetcher; @@ -32,13 +31,11 @@ public final class GetAuthenticationMechanismsTest { @Rule public final Expect expect = Expect.create(); private static final String TOKEN = "access_token"; - private static final String[] SET_TOKEN = {"--s2a_access_token", TOKEN}; - private static final SingleTokenFetcher.Flags FLAGS = new SingleTokenFetcher.Flags(); @BeforeClass public static void setUpClass() { // Set the token that the client will use to authenticate to the S2A. - JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + SingleTokenFetcher.setAccessToken(TOKEN); } @Test diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java index 806e412b784..80adba07f20 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/tokenmanager/SingleTokenAccessTokenManagerTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; -import com.beust.jcommander.JCommander; import io.grpc.s2a.handshaker.S2AIdentity; import java.util.Optional; import org.junit.Before; @@ -30,17 +29,15 @@ public final class SingleTokenAccessTokenManagerTest { private static final S2AIdentity IDENTITY = S2AIdentity.fromSpiffeId("spiffe_id"); private static final String TOKEN = "token"; - private static final String[] SET_TOKEN = {"--s2a_access_token", TOKEN}; - private static final SingleTokenFetcher.Flags FLAGS = new SingleTokenFetcher.Flags(); @Before public void setUp() { - FLAGS.reset(); + SingleTokenFetcher.setAccessToken(null); } @Test public void getDefaultToken_success() throws Exception { - JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + SingleTokenFetcher.setAccessToken(TOKEN); Optional manager = AccessTokenManager.create(); assertThat(manager).isPresent(); assertThat(manager.get().getDefaultToken()).isEqualTo(TOKEN); @@ -48,7 +45,7 @@ public void getDefaultToken_success() throws Exception { @Test public void getToken_success() throws Exception { - JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + SingleTokenFetcher.setAccessToken(TOKEN); Optional manager = AccessTokenManager.create(); assertThat(manager).isPresent(); assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); @@ -61,7 +58,7 @@ public void getToken_noEnvironmentVariable() throws Exception { @Test public void create_success() throws Exception { - JCommander.newBuilder().addObject(FLAGS).build().parse(SET_TOKEN); + SingleTokenFetcher.setAccessToken(TOKEN); Optional manager = AccessTokenManager.create(); assertThat(manager).isPresent(); assertThat(manager.get().getToken(IDENTITY)).isEqualTo(TOKEN); From 1d41d1037b47e36abc335d5e2019afd1679d165b Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 9 Sep 2024 10:59:48 -0700 Subject: [PATCH 12/27] Migrate away from deprecated functions. --- .../main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java index b2aee6db49e..56f612502bf 100644 --- a/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/MtlsToS2AChannelCredentials.java @@ -76,10 +76,10 @@ public S2AChannelCredentials.Builder build() throws GeneralSecurityException, IO File trustBundleFile = new File(trustBundlePath); AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); - keyManager.updateIdentityCredentialsFromFile(privateKeyFile, certChainFile); + keyManager.updateIdentityCredentials(certChainFile, privateKeyFile); AdvancedTlsX509TrustManager trustManager = AdvancedTlsX509TrustManager.newBuilder().build(); - trustManager.updateTrustCredentialsFromFile(trustBundleFile); + trustManager.updateTrustCredentials(trustBundleFile); ChannelCredentials channelToS2ACredentials = TlsChannelCredentials.newBuilder() From 35084afb3ab3615f949470c44af857e6496276fc Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 9 Sep 2024 11:01:04 -0700 Subject: [PATCH 13/27] Remove logging before errors thrown in S2AStub. --- s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java index 0c6be56971d..79f42dd590b 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java @@ -100,11 +100,6 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException } responses.clear(); if (exception != null) { - logger.log( - Level.WARNING, - "Received an unexpected response from a host at the S2A's address. The S2A might be" - + " unavailable. " - + exception.getMessage()); throw new IOException( "Received an unexpected response from a host at the S2A's address. The S2A might be" + " unavailable." @@ -115,7 +110,6 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException try { writer.onNext(req); } catch (RuntimeException e) { - logger.log(Level.WARNING, "Error occurred while writing to the S2A.", e); writer.onError(e); responses.offer(Result.createWithThrowable(e)); } @@ -172,7 +166,6 @@ public void onNext(SessionResp resp) { */ @Override public void onError(Throwable t) { - logger.log(Level.WARNING, "Error occurred while reading from the S2A.", t); responses.offer(Result.createWithThrowable(t)); } From da330cd1f7c495ddff53becdba8450b38ad83750 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 9 Sep 2024 13:01:12 -0700 Subject: [PATCH 14/27] Build set of TLS versions from S2Av2's GetTlsConfigResp. --- .../io/grpc/s2a/handshaker/ProtoUtil.java | 24 +++++++++++++ .../s2a/handshaker/SslContextFactory.java | 15 ++++---- .../io/grpc/s2a/handshaker/FakeWriter.java | 18 +++++++++- .../io/grpc/s2a/handshaker/ProtoUtilTest.java | 36 +++++++++++++++++++ .../s2a/handshaker/SslContextFactoryTest.java | 28 ++++++++------- 5 files changed, 100 insertions(+), 21 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java index 34cc4bbe737..ccf3138d953 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java @@ -16,6 +16,8 @@ package io.grpc.s2a.handshaker; +import com.google.common.collect.ImmutableSet; + /** Converts proto messages to Netty strings. */ final class ProtoUtil { /** @@ -68,5 +70,27 @@ static String convertTlsProtocolVersion(TLSVersion tlsVersion) { } } + /** + * Builds a set of strings representing all {@link TLSVersion}s between {@code minTlsVersion} and + * {@code maxTlsVersion}. + */ + static ImmutableSet buildTlsProtocolVersionSet( + TLSVersion minTlsVersion, TLSVersion maxTlsVersion) { + ImmutableSet.Builder tlsVersions = ImmutableSet.builder(); + for (TLSVersion tlsVersion : TLSVersion.values()) { + int versionNumber; + try { + versionNumber = tlsVersion.getNumber(); + } catch (IllegalArgumentException e) { + continue; + } + if (versionNumber < minTlsVersion.getNumber() || versionNumber > maxTlsVersion.getNumber()) { + continue; + } + tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); + } + return tlsVersions.build(); + } + private ProtoUtil() {} } \ No newline at end of file diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java index c8c7cdd3e04..1ac5887ebc4 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/SslContextFactory.java @@ -19,7 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static java.nio.charset.StandardCharsets.UTF_8; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.grpc.netty.GrpcSslContexts; import io.grpc.s2a.handshaker.S2AIdentity; import io.netty.handler.ssl.OpenSslContextOption; @@ -138,14 +138,13 @@ private static void configureSslContextWithClientTlsConfiguration( NoSuchAlgorithmException, UnrecoverableKeyException { sslContextBuilder.keyManager(createKeylessManager(clientTlsConfiguration)); - sslContextBuilder.protocols( - ProtoUtil.convertTlsProtocolVersion(clientTlsConfiguration.getMinTlsVersion()), - ProtoUtil.convertTlsProtocolVersion(clientTlsConfiguration.getMaxTlsVersion())); - ImmutableList.Builder ciphersuites = ImmutableList.builder(); - for (int i = 0; i < clientTlsConfiguration.getCiphersuitesCount(); ++i) { - ciphersuites.add(ProtoUtil.convertCiphersuite(clientTlsConfiguration.getCiphersuites(i))); + ImmutableSet tlsVersions = + ProtoUtil.buildTlsProtocolVersionSet( + clientTlsConfiguration.getMinTlsVersion(), clientTlsConfiguration.getMaxTlsVersion()); + if (tlsVersions.isEmpty()) { + throw new S2AConnectionException("Set of TLS versions received from S2A server is empty."); } - sslContextBuilder.ciphers(ciphersuites.build()); + sslContextBuilder.protocols(tlsVersions); } private static KeyManager createKeylessManager( diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java index 505a0cf4a3a..45961b81b7b 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/FakeWriter.java @@ -16,6 +16,7 @@ package io.grpc.s2a.handshaker; +import static io.grpc.s2a.handshaker.TLSVersion.TLS_VERSION_1_2; import static io.grpc.s2a.handshaker.TLSVersion.TLS_VERSION_1_3; import com.google.common.collect.ImmutableMap; @@ -39,7 +40,8 @@ enum Behavior { EMPTY_RESPONSE, ERROR_STATUS, ERROR_RESPONSE, - COMPLETE_STATUS + COMPLETE_STATUS, + BAD_TLS_VERSION_RESPONSE, } enum VerificationResult { @@ -213,6 +215,20 @@ public void onNext(SessionReq sessionReq) { case COMPLETE_STATUS: reader.onCompleted(); break; + case BAD_TLS_VERSION_RESPONSE: + reader.onNext( + SessionResp.newBuilder() + .setGetTlsConfigurationResp( + GetTlsConfigurationResp.newBuilder() + .setClientTlsConfiguration( + GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder() + .addCertificateChain(LEAF_CERT) + .addCertificateChain(INTERMEDIATE_CERT_2) + .addCertificateChain(INTERMEDIATE_CERT_1) + .setMinTlsVersion(TLS_VERSION_1_3) + .setMaxTlsVersion(TLS_VERSION_1_2))) + .build()); + break; default: reader.onNext(handleResponse(sessionReq)); } diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java index 0191398b6b7..6d134b43f7a 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/ProtoUtilTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableSet; import com.google.common.truth.Expect; import org.junit.Rule; import org.junit.Test; @@ -92,4 +93,39 @@ public void convertTlsProtocolVersion_withUnknownTlsVersion_fails() { () -> ProtoUtil.convertTlsProtocolVersion(TLSVersion.TLS_VERSION_UNSPECIFIED)); expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); } + + @Test + public void buildTlsProtocolVersionSet_success() { + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_0, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1", "TLSv1.1", "TLSv1.2", "TLSv1.3")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_2, TLSVersion.TLS_VERSION_1_2)) + .isEqualTo(ImmutableSet.of("TLSv1.2")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_3, TLSVersion.TLS_VERSION_1_3)) + .isEqualTo(ImmutableSet.of("TLSv1.3")); + expect + .that( + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_1_3, TLSVersion.TLS_VERSION_1_2)) + .isEmpty(); + } + + @Test + public void buildTlsProtocolVersionSet_failure() { + AssertionError expected = + assertThrows( + AssertionError.class, + () -> + ProtoUtil.buildTlsProtocolVersionSet( + TLSVersion.TLS_VERSION_UNSPECIFIED, TLSVersion.TLS_VERSION_1_3)); + expect.that(expected).hasMessageThat().isEqualTo("TLS version 0 is not supported."); + } } \ No newline at end of file diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java index c33fd820e4c..a2a66a7b563 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/SslContextFactoryTest.java @@ -58,12 +58,6 @@ public void createForClient_returnsValidSslContext() throws Exception { expect.that(sslContext.sessionTimeout()).isEqualTo(300); expect.that(sslContext.isClient()).isTrue(); expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); - expect - .that(sslContext.cipherSuites()) - .containsExactly( - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); SSLSessionContext sslSessionContext = sslContext.sessionContext(); if (sslSessionContext instanceof OpenSslSessionContext) { OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; @@ -82,12 +76,6 @@ public void createForClient_withLocalIdentity_returnsValidSslContext() throws Ex expect.that(sslContext.sessionTimeout()).isEqualTo(300); expect.that(sslContext.isClient()).isTrue(); expect.that(sslContext.applicationProtocolNegotiator().protocols()).containsExactly("h2"); - expect - .that(sslContext.cipherSuites()) - .containsExactly( - "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", - "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", - "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"); SSLSessionContext sslSessionContext = sslContext.sessionContext(); if (sslSessionContext instanceof OpenSslSessionContext) { OpenSslSessionContext openSslSessionContext = (OpenSslSessionContext) sslSessionContext; @@ -141,6 +129,22 @@ public void createForClient_getsErrorFromServer_throwsError() throws Exception { .contains("Failed to get client TLS configuration from S2A."); } + @Test + public void createForClient_getsBadTlsVersionsFromServer_throwsError() throws Exception { + writer.setBehavior(FakeWriter.Behavior.BAD_TLS_VERSION_RESPONSE); + + S2AConnectionException expected = + assertThrows( + S2AConnectionException.class, + () -> + SslContextFactory.createForClient( + stub, FAKE_TARGET_NAME, /* localIdentity= */ Optional.empty())); + + assertThat(expected) + .hasMessageThat() + .contains("Set of TLS versions received from S2A server is empty."); + } + @Test public void createForClient_nullStub_throwsError() throws Exception { writer.sendUnexpectedResponse(); From 655f0bd893629ac46484209e8164733f8fc58adf Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 9 Sep 2024 13:02:29 -0700 Subject: [PATCH 15/27] S2AStub uses withWaitForReady. --- s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java | 5 ++++- s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java index 79f42dd590b..8249ca59d09 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java @@ -142,7 +142,10 @@ public void close() { private void createWriterIfNull() { if (writer == null) { writer = - serviceStub.withDeadlineAfter(HANDSHAKE_RPC_DEADLINE_SECS, SECONDS).setUpSession(reader); + serviceStub + .withWaitForReady() + .withDeadlineAfter(HANDSHAKE_RPC_DEADLINE_SECS, SECONDS) + .setUpSession(reader); } } diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java index 6413874a99c..bb90be12b6a 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AStubTest.java @@ -62,7 +62,7 @@ public void send_receiveOkStatus() throws Exception { IOException expected = assertThrows(IOException.class, () -> newStub.send(SessionReq.getDefaultInstance())); - assertThat(expected).hasMessageThat().contains("UNAVAILABLE"); + assertThat(expected).hasMessageThat().contains("DEADLINE_EXCEEDED"); } @Test From 3184cdcfddda29ebcfe45b1011c882df1489f5b2 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 9 Sep 2024 13:12:04 -0700 Subject: [PATCH 16/27] Don't block on SslContext creation in Java S2A client. --- .../S2AProtocolNegotiatorFactory.java | 135 ++++++++++++------ .../S2AProtocolNegotiatorFactoryTest.java | 17 +++ 2 files changed, 110 insertions(+), 42 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index 7f00e198fae..f453b4903b2 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -20,9 +20,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HostAndPort; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.ThreadSafe; import io.grpc.Channel; -import io.grpc.ChannelLogger; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiator; @@ -33,16 +37,15 @@ import io.grpc.s2a.channel.S2AGrpcChannelPool; import io.grpc.s2a.handshaker.S2AIdentity; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; -import java.io.IOException; -import java.security.GeneralSecurityException; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.UnrecoverableKeyException; -import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; +import java.util.concurrent.Executors; import org.checkerframework.checker.nullness.qual.Nullable; /** Factory for performing negotiation of a secure channel using the S2A. */ @@ -96,6 +99,8 @@ static final class S2AProtocolNegotiator implements ProtocolNegotiator { private final S2AChannelPool channelPool; private final Optional localIdentity; + private final ListeningExecutorService service = + MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); static S2AProtocolNegotiator createForClient( S2AChannelPool channelPool, Optional localIdentity) { @@ -128,65 +133,111 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); checkNotNull(hostname, "hostname should not be null."); return new S2AProtocolNegotiationHandler( - InternalProtocolNegotiators.grpcNegotiationHandler(grpcHandler), - grpcHandler.getNegotiationLogger(), - channelPool, - localIdentity, - hostname, - grpcHandler); + grpcHandler, channelPool, localIdentity, hostname, service); } @Override public void close() { + service.shutdown(); channelPool.close(); } } + @VisibleForTesting + static class BufferReadsHandler extends ChannelInboundHandlerAdapter { + private final List reads = new ArrayList<>(); + private boolean readComplete; + + public List getReads() { + return reads; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + reads.add(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + readComplete = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + for (Object msg : reads) { + super.channelRead(ctx, msg); + } + if (readComplete) { + super.channelReadComplete(ctx); + } + } + } + private static final class S2AProtocolNegotiationHandler extends ProtocolNegotiationHandler { private final S2AChannelPool channelPool; private final Optional localIdentity; private final String hostname; - private InternalProtocolNegotiator.ProtocolNegotiator negotiator; private final GrpcHttp2ConnectionHandler grpcHandler; + private final ListeningExecutorService service; private S2AProtocolNegotiationHandler( - ChannelHandler next, - ChannelLogger negotiationLogger, + GrpcHttp2ConnectionHandler grpcHandler, S2AChannelPool channelPool, Optional localIdentity, String hostname, - GrpcHttp2ConnectionHandler grpcHandler) { - super(next, negotiationLogger); + ListeningExecutorService service) { + super( + // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' + // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior + // here and then manually add 'next' when we call fireProtocolNegotiationEvent() + new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ctx.pipeline().remove(this); + } + }, + grpcHandler.getNegotiationLogger()); + this.grpcHandler = grpcHandler; this.channelPool = channelPool; this.localIdentity = localIdentity; this.hostname = hostname; - this.grpcHandler = grpcHandler; + checkNotNull(service, "service should not be null."); + this.service = service; } @Override - protected void handlerAdded0(ChannelHandlerContext ctx) throws GeneralSecurityException { - SslContext sslContext; - try { - // Establish a stream to S2A server. - Channel ch = channelPool.getChannel(); - S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch); - S2AStub s2aStub = S2AStub.newInstance(stub); - sslContext = SslContextFactory.createForClient(s2aStub, hostname, localIdentity); - } catch (InterruptedException - | IOException - | IllegalArgumentException - | UnrecoverableKeyException - | CertificateException - | NoSuchAlgorithmException - | KeyStoreException e) { - // GeneralSecurityException is intentionally not caught, and rather propagated. This is done - // because throwing a GeneralSecurityException in this context indicates that we encountered - // a retryable error. - throw new IllegalArgumentException( - "Something went wrong during the initialization of SslContext.", e); - } - negotiator = InternalProtocolNegotiators.tls(sslContext); - ctx.pipeline().addBefore(ctx.name(), /* name= */ null, negotiator.newHandler(grpcHandler)); + protected void handlerAdded0(ChannelHandlerContext ctx) { + // Buffer all reads until the TLS Handler is added. + BufferReadsHandler bufferReads = new BufferReadsHandler(); + ctx.pipeline().addBefore(ctx.name(), /* name= */ null, bufferReads); + + Channel ch = channelPool.getChannel(); + S2AServiceGrpc.S2AServiceStub stub = S2AServiceGrpc.newStub(ch); + S2AStub s2aStub = S2AStub.newInstance(stub); + + ListenableFuture sslContextFuture = + service.submit(() -> SslContextFactory.createForClient(s2aStub, hostname, localIdentity)); + Futures.addCallback( + sslContextFuture, + new FutureCallback() { + @Override + public void onSuccess(SslContext sslContext) { + ChannelHandler handler = + InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + + // Remove the bufferReads handler and delegate the rest of the handshake to the TLS + // handler. + ctx.pipeline().addAfter(ctx.name(), /* name= */ null, handler); + fireProtocolNegotiationEvent(ctx); + ctx.pipeline().remove(bufferReads); + } + + @Override + public void onFailure(Throwable t) { + ctx.fireExceptionCaught(t); + } + }, + service); } } diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java index ed7c71bb539..7328840735c 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -45,6 +45,7 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http2.Http2ConnectionDecoder; import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2Settings; @@ -85,6 +86,22 @@ public void tearDown() { fakeS2AServer.shutdown(); } + @Test + public void handlerRemoved_success() throws Exception { + S2AProtocolNegotiatorFactory.BufferReadsHandler handler1 = + new S2AProtocolNegotiatorFactory.BufferReadsHandler(); + S2AProtocolNegotiatorFactory.BufferReadsHandler handler2 = + new S2AProtocolNegotiatorFactory.BufferReadsHandler(); + EmbeddedChannel channel = new EmbeddedChannel(handler1, handler2); + channel.writeInbound("message1"); + channel.writeInbound("message2"); + channel.writeInbound("message3"); + assertThat(handler1.getReads()).hasSize(3); + assertThat(handler2.getReads()).isEmpty(); + channel.pipeline().remove(handler1); + assertThat(handler2.getReads()).hasSize(3); + } + @Test public void createProtocolNegotiatorFactory_nullArgument() throws Exception { NullPointerTester tester = new NullPointerTester().setDefault(Optional.class, Optional.empty()); From f96d3956c34af47f5cf214a79930726471d43ed7 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Mon, 9 Sep 2024 14:17:38 -0700 Subject: [PATCH 17/27] use javax.annotation.Nullable in S2AProtocolNegotiatorFactory. --- .../io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index f453b4903b2..e88e7c4c9f9 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -46,7 +46,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.Executors; -import org.checkerframework.checker.nullness.qual.Nullable; +import javax.annotation.Nullable; /** Factory for performing negotiation of a secure channel using the S2A. */ @ThreadSafe From 38b0a3a2c254785f375d4c4e9407f6405bf7df08 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 09:23:43 -0700 Subject: [PATCH 18/27] getChannel() doesn't block. --- s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java index 9c849442dca..e5caf5e69bd 100644 --- a/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java +++ b/s2a/src/main/java/io/grpc/s2a/channel/S2AChannelPool.java @@ -26,7 +26,7 @@ public interface S2AChannelPool extends AutoCloseable { /** * Retrieves an open channel to the S2A from the channel pool. * - *

If no channel is available, blocks until a channel can be retrieved from the channel pool. + * @throws IllegalStateException if no channel is available. */ @CanIgnoreReturnValue Channel getChannel(); From 12586b11508fd5fca7058e8e414428b93d7e45f2 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 09:33:04 -0700 Subject: [PATCH 19/27] Remove unnecessary local variable in getAuthMechanism. --- .../grpc/s2a/handshaker/GetAuthenticationMechanisms.java | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java index b5fdbe76f36..56d74a9b766 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/GetAuthenticationMechanisms.java @@ -33,7 +33,6 @@ final class GetAuthenticationMechanisms { * @return an {@link AuthenticationMechanism} for the given local identity. */ static Optional getAuthMechanism(Optional localIdentity) { - Optional authMechanism = Optional.empty(); if (!TOKEN_MANAGER.isPresent()) { return Optional.empty(); } @@ -41,19 +40,16 @@ static Optional getAuthMechanism(Optional // If no identity is provided, fetch the default access token and DO NOT attach an identity // to the request. if (!localIdentity.isPresent()) { - authMechanism = - Optional.of( + return Optional.of( AuthenticationMechanism.newBuilder().setToken(manager.getDefaultToken()).build()); } else { // Fetch an access token for the provided identity. - authMechanism = - Optional.of( + return Optional.of( AuthenticationMechanism.newBuilder() .setIdentity(localIdentity.get().getIdentity()) .setToken(manager.getToken(localIdentity.get())) .build()); } - return authMechanism; } private GetAuthenticationMechanisms() {} From 08f83421b3ece4a7e44a84637971451f141f5f9e Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 09:45:50 -0700 Subject: [PATCH 20/27] Invert if statement in ProtoUtil to improve readability. --- s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java index ccf3138d953..59e3931d9e6 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java @@ -84,10 +84,10 @@ static ImmutableSet buildTlsProtocolVersionSet( } catch (IllegalArgumentException e) { continue; } - if (versionNumber < minTlsVersion.getNumber() || versionNumber > maxTlsVersion.getNumber()) { - continue; + if (versionNumber >= minTlsVersion.getNumber() + && versionNumber <= maxTlsVersion.getNumber()) { + tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); } - tlsVersions.add(convertTlsProtocolVersion(tlsVersion)); } return tlsVersions.build(); } From 7f267121aa662379a1c137de27bf0f73c8b76cda Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 09:49:26 -0700 Subject: [PATCH 21/27] Check localIdentity is null before setting it. --- s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java index ca06cdf4689..73072a83f91 100644 --- a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -73,6 +73,7 @@ public static final class Builder { @CanIgnoreReturnValue public Builder setLocalSpiffeId(String localSpiffeId) { checkNotNull(localSpiffeId); + checkArgument(localIdentity == null, "localIdentity is already set."); localIdentity = S2AIdentity.fromSpiffeId(localSpiffeId); return this; } @@ -85,6 +86,7 @@ public Builder setLocalSpiffeId(String localSpiffeId) { @CanIgnoreReturnValue public Builder setLocalHostname(String localHostname) { checkNotNull(localHostname); + checkArgument(localIdentity == null, "localIdentity is already set."); localIdentity = S2AIdentity.fromHostname(localHostname); return this; } @@ -97,6 +99,7 @@ public Builder setLocalHostname(String localHostname) { @CanIgnoreReturnValue public Builder setLocalUid(String localUid) { checkNotNull(localUid); + checkArgument(localIdentity == null, "localIdentity is already set."); localIdentity = S2AIdentity.fromUid(localUid); return this; } From 19583a7b136989d6c1d4b88be814cbe2e3e15cc4 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 10:32:29 -0700 Subject: [PATCH 22/27] Check hostname not null or empty. --- .../io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index e88e7c4c9f9..7bab6729320 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -16,7 +16,9 @@ package io.grpc.s2a.handshaker; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Strings.isNullOrEmpty; import com.google.common.annotations.VisibleForTesting; import com.google.common.net.HostAndPort; @@ -131,7 +133,7 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { checkNotNull(grpcHandler, "grpcHandler should not be null."); String hostname = getHostNameFromAuthority(grpcHandler.getAuthority()); - checkNotNull(hostname, "hostname should not be null."); + checkArgument(!isNullOrEmpty(hostname), "hostname should not be null or empty."); return new S2AProtocolNegotiationHandler( grpcHandler, channelPool, localIdentity, hostname, service); } From 3198eec7ca420d098d0c37d793774bf84e72d2d0 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 10:36:00 -0700 Subject: [PATCH 23/27] Change channelRead argument ctx to unused. --- .../io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index 7bab6729320..e156cbed0a8 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -155,7 +155,7 @@ public List getReads() { } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { + public void channelRead(ChannelHandlerContext unused, Object msg) { reads.add(msg); } From 0e059e4d1a3a4477abe53f7d9fd9f411a9395990 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 10:39:19 -0700 Subject: [PATCH 24/27] Remove unnecessary waitForReady() in IntegrationTest. --- s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java index f9de9765527..04dd9472bf9 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java @@ -262,7 +262,7 @@ public static boolean doUnaryRpc(ExecutorService executor, ManagedChannel channe throws InterruptedException { try { SimpleServiceGrpc.SimpleServiceBlockingStub stub = - SimpleServiceGrpc.newBlockingStub(channel).withWaitForReady(); + SimpleServiceGrpc.newBlockingStub(channel); SimpleResponse resp = stub.unaryRpc(SimpleRequest.newBuilder() .setRequestMessage("S2A team") .build()); From b021a2128ec2dc92fa403c190d4aeaa60c259a78 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Thu, 12 Sep 2024 15:56:21 -0700 Subject: [PATCH 25/27] Push down the creation of Optional until S2AProtocolNegotiator. --- .../java/io/grpc/s2a/S2AChannelCredentials.java | 7 +------ .../handshaker/S2AProtocolNegotiatorFactory.java | 16 +++++++++------- .../S2AProtocolNegotiatorFactoryTest.java | 8 ++++---- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java index 73072a83f91..8a5f1f51350 100644 --- a/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java +++ b/s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java @@ -122,12 +122,7 @@ public ChannelCredentials build() { } InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() { - if (localIdentity == null) { - return S2AProtocolNegotiatorFactory.createClientFactory(Optional.empty(), s2aChannelPool); - } else { - return S2AProtocolNegotiatorFactory.createClientFactory( - Optional.of(localIdentity), s2aChannelPool); - } + return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool); } } diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index e156cbed0a8..b8ada3e0340 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -66,20 +66,19 @@ public final class S2AProtocolNegotiatorFactory { * @return a factory for creating a client-side protocol negotiator. */ public static InternalProtocolNegotiator.ClientFactory createClientFactory( - Optional localIdentity, ObjectPool s2aChannelPool) { + @Nullable S2AIdentity localIdentity, ObjectPool s2aChannelPool) { checkNotNull(s2aChannelPool, "S2A channel pool should not be null."); - checkNotNull(localIdentity, "Local identity should not be null on the client side."); S2AChannelPool channelPool = S2AGrpcChannelPool.create(s2aChannelPool); return new S2AClientProtocolNegotiatorFactory(localIdentity, channelPool); } static final class S2AClientProtocolNegotiatorFactory implements InternalProtocolNegotiator.ClientFactory { - private final Optional localIdentity; + private final @Nullable S2AIdentity localIdentity; private final S2AChannelPool channelPool; S2AClientProtocolNegotiatorFactory( - Optional localIdentity, S2AChannelPool channelPool) { + @Nullable S2AIdentity localIdentity, S2AChannelPool channelPool) { this.localIdentity = localIdentity; this.channelPool = channelPool; } @@ -105,10 +104,13 @@ static final class S2AProtocolNegotiator implements ProtocolNegotiator { MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1)); static S2AProtocolNegotiator createForClient( - S2AChannelPool channelPool, Optional localIdentity) { + S2AChannelPool channelPool, @Nullable S2AIdentity localIdentity) { checkNotNull(channelPool, "Channel pool should not be null."); - checkNotNull(localIdentity, "Local identity should not be null on the client side."); - return new S2AProtocolNegotiator(channelPool, localIdentity); + if (localIdentity == null) { + return new S2AProtocolNegotiator(channelPool, Optional.empty()); + } else { + return new S2AProtocolNegotiator(channelPool, Optional.of(localIdentity)); + } } @VisibleForTesting diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java index 7328840735c..f130e52aac7 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactoryTest.java @@ -128,7 +128,7 @@ public void createProtocolNegotiator_nullArgument() throws Exception { @Test public void createProtocolNegotiatorFactory_getsDefaultPort_succeeds() throws Exception { InternalProtocolNegotiator.ClientFactory clientFactory = - S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); assertThat(clientFactory.getDefaultPort()).isEqualTo(S2AProtocolNegotiatorFactory.DEFAULT_PORT); } @@ -152,7 +152,7 @@ public void s2aProtocolNegotiator_getHostNameOnValidAuthority_returnsValidHostna public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClientSide_succeeds() throws Exception { InternalProtocolNegotiator.ClientFactory clientFactory = - S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); @@ -164,7 +164,7 @@ public void createProtocolNegotiatorFactory_buildsAnS2AProtocolNegotiatorOnClien public void closeProtocolNegotiator_verifyProtocolNegotiatorIsClosedOnClientSide() throws Exception { InternalProtocolNegotiator.ClientFactory clientFactory = - S2AProtocolNegotiatorFactory.createClientFactory(Optional.of(LOCAL_IDENTITY), channelPool); + S2AProtocolNegotiatorFactory.createClientFactory(LOCAL_IDENTITY, channelPool); ProtocolNegotiator clientNegotiator = clientFactory.newNegotiator(); clientNegotiator.close(); @@ -182,7 +182,7 @@ public void createChannelHandler_addHandlerToMockContext() throws Exception { FakeS2AChannelPool fakeChannelPool = new FakeS2AChannelPool(channel); ProtocolNegotiator clientNegotiator = S2AProtocolNegotiatorFactory.S2AProtocolNegotiator.createForClient( - fakeChannelPool, Optional.of(LOCAL_IDENTITY)); + fakeChannelPool, LOCAL_IDENTITY); ChannelHandler channelHandler = clientNegotiator.newHandler(fakeConnectionHandler); From 752627a8e5b383f9bb06c5296bd1691a5987f1da Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Fri, 13 Sep 2024 08:31:51 -0700 Subject: [PATCH 26/27] Wait for servers to be terminated in tearDown in IntegrationTest.java. --- .../java/io/grpc/s2a/handshaker/IntegrationTest.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java index 04dd9472bf9..859771a4afa 100644 --- a/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java +++ b/s2a/src/test/java/io/grpc/s2a/handshaker/IntegrationTest.java @@ -52,7 +52,6 @@ import javax.net.ssl.SSLSessionContext; import org.junit.After; import org.junit.Before; -import org.junit.BeforeClass; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -140,11 +139,6 @@ public final class IntegrationTest { private String serverAddress; private Server server; - @BeforeClass - public static void setUpClass() { - System.setProperty("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", "false"); - } - @Before public void setUp() throws Exception { s2aPort = Utils.pickUnusedPort(); @@ -186,9 +180,13 @@ public void setUp() throws Exception { @After public void tearDown() throws Exception { + server.awaitTermination(10, SECONDS); server.shutdown(); + s2aServer.awaitTermination(10, SECONDS); s2aServer.shutdown(); + s2aDelayServer.awaitTermination(10, SECONDS); s2aDelayServer.shutdown(); + mtlsS2AServer.awaitTermination(10, SECONDS); mtlsS2AServer.shutdown(); } From a8cacb0f14d1235adb47be598b8b49b5096ba778 Mon Sep 17 00:00:00 2001 From: Riya Mehta Date: Fri, 13 Sep 2024 10:06:55 -0700 Subject: [PATCH 27/27] mark unused ctx in channelReadComplete. --- .../io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java index b8ada3e0340..25d1e325ea8 100644 --- a/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/handshaker/S2AProtocolNegotiatorFactory.java @@ -162,7 +162,7 @@ public void channelRead(ChannelHandlerContext unused, Object msg) { } @Override - public void channelReadComplete(ChannelHandlerContext ctx) { + public void channelReadComplete(ChannelHandlerContext unused) { readComplete = true; }