diff --git a/plugin-tester-scala/src/test/scala/example/myapp/helloworld/AkkaHttpClientCancelSpec.scala b/plugin-tester-scala/src/test/scala/example/myapp/helloworld/AkkaHttpClientCancelSpec.scala new file mode 100644 index 000000000..556306e5f --- /dev/null +++ b/plugin-tester-scala/src/test/scala/example/myapp/helloworld/AkkaHttpClientCancelSpec.scala @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2009-2023 Lightbend Inc. + */ + +package example.myapp.helloworld + +import akka.Done +import akka.NotUsed +import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit +import akka.grpc.GrpcClientSettings +import akka.http.scaladsl.Http +import akka.stream.scaladsl.Sink +import akka.stream.scaladsl.Source +import example.myapp.helloworld.grpc.GreeterService +import example.myapp.helloworld.grpc.GreeterServiceClient +import example.myapp.helloworld.grpc.GreeterServiceHandler +import example.myapp.helloworld.grpc.HelloReply +import example.myapp.helloworld.grpc.HelloRequest +import org.scalatest.concurrent.ScalaFutures +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpecLike + +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.DurationInt +import scala.util.Success + +class AkkaHttpClientCancelSpec + extends ScalaTestWithActorTestKit(""" + akka.http.server.enable-http2 = true + akka.grpc.client."*" { + backend = "akka-http" + use-tls = false + } + """) + with AnyWordSpecLike + with Matchers + with ScalaFutures { + + "The Akka HTTP client" should { + + "correctly cancel a server streaming call" in { + val probe = createTestProbe[Any]() + implicit val ec: ExecutionContext = system.executionContext + + val handler = GreeterServiceHandler(new GreeterService { + override def sayHello(in: HelloRequest): Future[HelloReply] = ??? + override def itKeepsTalking(in: Source[HelloRequest, NotUsed]): Future[HelloReply] = ??? + override def streamHellos(in: Source[HelloRequest, NotUsed]): Source[HelloReply, NotUsed] = ??? + + override def itKeepsReplying(in: HelloRequest): Source[HelloReply, NotUsed] = { + Source + .single(HelloReply.defaultInstance) + // keep the stream alive indefinitely + .concat(Source.maybe[HelloReply]) + // tell probe when we start and when we complete + .watchTermination() { (_, termination) => + probe.ref ! "started" + termination.onComplete { t => + probe.ref ! t + } + NotUsed + } + } + + }) + + val binding = + Http().newServerAt("127.0.0.1", 0).bind(handler).futureValue + + val client = + GreeterServiceClient(GrpcClientSettings.connectToServiceAt("127.0.0.1", binding.localAddress.getPort)) + client.itKeepsReplying(HelloRequest.defaultInstance).runWith(Sink.head) + + probe.expectMessage("started") + probe.expectMessage(5.seconds, Success(Done)) + } + + } + +} diff --git a/runtime/src/main/scala/akka/grpc/internal/AkkaHttpClientUtils.scala b/runtime/src/main/scala/akka/grpc/internal/AkkaHttpClientUtils.scala index 073325be1..026c67a67 100644 --- a/runtime/src/main/scala/akka/grpc/internal/AkkaHttpClientUtils.scala +++ b/runtime/src/main/scala/akka/grpc/internal/AkkaHttpClientUtils.scala @@ -173,7 +173,7 @@ object AkkaHttpClientUtils { s"${scheme}://${settings.overrideAuthority.getOrElse(settings.serviceName)}/" + descriptor.getFullMethodName), GrpcEntityHelpers.metadataHeaders(headers.entries), source) - responseToSource(httpRequest.uri, singleRequest(httpRequest), deserializer) + responseToSource(httpRequest.uri, singleRequest(httpRequest), deserializer, streamingResponse) } } } @@ -182,7 +182,11 @@ object AkkaHttpClientUtils { * INTERNAL API */ @InternalApi - def responseToSource[O](requestUri: Uri, response: Future[HttpResponse], deserializer: ProtobufSerializer[O])( + def responseToSource[O]( + requestUri: Uri, + response: Future[HttpResponse], + deserializer: ProtobufSerializer[O], + streamingResponse: Boolean)( implicit ec: ExecutionContext, mat: Materializer): Source[O, Future[GrpcResponseMetadata]] = { Source.lazyFutureSource[O, Future[GrpcResponseMetadata]](() => { @@ -221,14 +225,19 @@ object AkkaHttpClientUtils { response.entity.discardBytes() throw mapToStatusException(requestUri, response, Seq.empty) } - responseData + val baseFlow = responseData // This never adds any data to the stream, but makes sure it fails with the correct error code if applicable .concat( Source .maybe[ByteString] .mapMaterializedValue(promise => promise.completeWith(completionFuture.map(_ => None)))) + val flow = if (streamingResponse) { + baseFlow + } else { // Make sure we continue reading to get the trailing header even if we're no longer interested in the rest of the body - .via(new CancellationBarrierGraphStage) + baseFlow.via(new CancellationBarrierGraphStage) + } + flow .via(reader.dataFrameDecoder) .map(deserializer.deserialize) .mapMaterializedValue(_ => diff --git a/runtime/src/test/scala/akka/grpc/internal/AkkaHttpClientUtilsSpec.scala b/runtime/src/test/scala/akka/grpc/internal/AkkaHttpClientUtilsSpec.scala index 9f9abe5bc..7c0fe3b41 100644 --- a/runtime/src/test/scala/akka/grpc/internal/AkkaHttpClientUtilsSpec.scala +++ b/runtime/src/test/scala/akka/grpc/internal/AkkaHttpClientUtilsSpec.scala @@ -32,7 +32,7 @@ class AkkaHttpClientUtilsSpec extends TestKit(ActorSystem()) with AnyWordSpecLik val requestUri = Uri("https://example.com/GuestExeSample/GrpcHello") val response = Future.successful(HttpResponse(NotFound, entity = Strict(GrpcProtocolNative.contentType, ByteString.empty))) - val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null) + val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null, false) val failure = source.run().failed.futureValue // https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md @@ -43,7 +43,7 @@ class AkkaHttpClientUtilsSpec extends TestKit(ActorSystem()) with AnyWordSpecLik val requestUri = Uri("https://example.com/GuestExeSample/GrpcHello") val response = Future.successful( HttpResponse(OK, List(RawHeader("grpc-status", "9")), Strict(GrpcProtocolNative.contentType, ByteString.empty))) - val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null) + val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null, false) val failure = source.run().failed.futureValue failure.asInstanceOf[StatusRuntimeException].getStatus.getCode should be(Status.Code.FAILED_PRECONDITION)