Skip to content

Commit

Permalink
fix: stream leak in akka-http client backend for server stream (#1832)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Johan Andrén <[email protected]>
Co-authored-by: Patrik Nordwall <[email protected]>
  • Loading branch information
3 people authored Sep 20, 2023
1 parent bee8429 commit 77b8e49
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (C) 2009-2023 Lightbend Inc. <https://www.lightbend.com>
*/

package example.myapp.helloworld

import akka.Done
import akka.NotUsed
import akka.actor.testkit.typed.scaladsl.ScalaTestWithActorTestKit
import akka.grpc.GrpcClientSettings
import akka.http.scaladsl.Http
import akka.stream.scaladsl.Sink
import akka.stream.scaladsl.Source
import example.myapp.helloworld.grpc.GreeterService
import example.myapp.helloworld.grpc.GreeterServiceClient
import example.myapp.helloworld.grpc.GreeterServiceHandler
import example.myapp.helloworld.grpc.HelloReply
import example.myapp.helloworld.grpc.HelloRequest
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration.DurationInt
import scala.util.Success

class AkkaHttpClientCancelSpec
extends ScalaTestWithActorTestKit("""
akka.http.server.enable-http2 = true
akka.grpc.client."*" {
backend = "akka-http"
use-tls = false
}
""")
with AnyWordSpecLike
with Matchers
with ScalaFutures {

"The Akka HTTP client" should {

"correctly cancel a server streaming call" in {
val probe = createTestProbe[Any]()
implicit val ec: ExecutionContext = system.executionContext

val handler = GreeterServiceHandler(new GreeterService {
override def sayHello(in: HelloRequest): Future[HelloReply] = ???
override def itKeepsTalking(in: Source[HelloRequest, NotUsed]): Future[HelloReply] = ???
override def streamHellos(in: Source[HelloRequest, NotUsed]): Source[HelloReply, NotUsed] = ???

override def itKeepsReplying(in: HelloRequest): Source[HelloReply, NotUsed] = {
Source
.single(HelloReply.defaultInstance)
// keep the stream alive indefinitely
.concat(Source.maybe[HelloReply])
// tell probe when we start and when we complete
.watchTermination() { (_, termination) =>
probe.ref ! "started"
termination.onComplete { t =>
probe.ref ! t
}
NotUsed
}
}

})

val binding =
Http().newServerAt("127.0.0.1", 0).bind(handler).futureValue

val client =
GreeterServiceClient(GrpcClientSettings.connectToServiceAt("127.0.0.1", binding.localAddress.getPort))
client.itKeepsReplying(HelloRequest.defaultInstance).runWith(Sink.head)

probe.expectMessage("started")
probe.expectMessage(5.seconds, Success(Done))
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ object AkkaHttpClientUtils {
s"${scheme}://${settings.overrideAuthority.getOrElse(settings.serviceName)}/" + descriptor.getFullMethodName),
GrpcEntityHelpers.metadataHeaders(headers.entries),
source)
responseToSource(httpRequest.uri, singleRequest(httpRequest), deserializer)
responseToSource(httpRequest.uri, singleRequest(httpRequest), deserializer, streamingResponse)
}
}
}
Expand All @@ -182,7 +182,11 @@ object AkkaHttpClientUtils {
* INTERNAL API
*/
@InternalApi
def responseToSource[O](requestUri: Uri, response: Future[HttpResponse], deserializer: ProtobufSerializer[O])(
def responseToSource[O](
requestUri: Uri,
response: Future[HttpResponse],
deserializer: ProtobufSerializer[O],
streamingResponse: Boolean)(
implicit ec: ExecutionContext,
mat: Materializer): Source[O, Future[GrpcResponseMetadata]] = {
Source.lazyFutureSource[O, Future[GrpcResponseMetadata]](() => {
Expand Down Expand Up @@ -221,14 +225,19 @@ object AkkaHttpClientUtils {
response.entity.discardBytes()
throw mapToStatusException(requestUri, response, Seq.empty)
}
responseData
val baseFlow = responseData
// This never adds any data to the stream, but makes sure it fails with the correct error code if applicable
.concat(
Source
.maybe[ByteString]
.mapMaterializedValue(promise => promise.completeWith(completionFuture.map(_ => None))))
val flow = if (streamingResponse) {
baseFlow
} else {
// Make sure we continue reading to get the trailing header even if we're no longer interested in the rest of the body
.via(new CancellationBarrierGraphStage)
baseFlow.via(new CancellationBarrierGraphStage)
}
flow
.via(reader.dataFrameDecoder)
.map(deserializer.deserialize)
.mapMaterializedValue(_ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class AkkaHttpClientUtilsSpec extends TestKit(ActorSystem()) with AnyWordSpecLik
val requestUri = Uri("https://example.com/GuestExeSample/GrpcHello")
val response =
Future.successful(HttpResponse(NotFound, entity = Strict(GrpcProtocolNative.contentType, ByteString.empty)))
val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null)
val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null, false)

val failure = source.run().failed.futureValue
// https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md
Expand All @@ -43,7 +43,7 @@ class AkkaHttpClientUtilsSpec extends TestKit(ActorSystem()) with AnyWordSpecLik
val requestUri = Uri("https://example.com/GuestExeSample/GrpcHello")
val response = Future.successful(
HttpResponse(OK, List(RawHeader("grpc-status", "9")), Strict(GrpcProtocolNative.contentType, ByteString.empty)))
val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null)
val source = AkkaHttpClientUtils.responseToSource(requestUri, response, null, false)

val failure = source.run().failed.futureValue
failure.asInstanceOf[StatusRuntimeException].getStatus.getCode should be(Status.Code.FAILED_PRECONDITION)
Expand Down

0 comments on commit 77b8e49

Please sign in to comment.