diff --git a/build.sbt b/build.sbt index b285d4e192..abea213307 100644 --- a/build.sbt +++ b/build.sbt @@ -347,6 +347,7 @@ lazy val io = crossProject(JVMPlatform, JSPlatform, NativePlatform) .dependsOn(core % "compile->compile;test->test") .jsSettings( mimaBinaryIssueFilters ++= Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]("fs2.io.net.tls.TLSSocket.forAsync"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("fs2.io.package.stdinUtf8"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("fs2.io.package.stdoutLines"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("fs2.io.package.stdout"), diff --git a/io/js/src/main/scala/fs2/io/internal/facade/events.scala b/io/js/src/main/scala/fs2/io/internal/facade/events.scala index 8cdb1041db..c05fa8d3f9 100644 --- a/io/js/src/main/scala/fs2/io/internal/facade/events.scala +++ b/io/js/src/main/scala/fs2/io/internal/facade/events.scala @@ -34,6 +34,8 @@ import scala.scalajs.js @nowarn212("cat=unused") private[io] trait EventEmitter extends js.Object { + protected[io] def on(eventName: String, listener: js.Function0[Unit]): this.type = js.native + protected[io] def on[E](eventName: String, listener: js.Function1[E, Unit]): this.type = js.native protected[io] def on[E, F](eventName: String, listener: js.Function2[E, F, Unit]): this.type = diff --git a/io/js/src/main/scala/fs2/io/ioplatform.scala b/io/js/src/main/scala/fs2/io/ioplatform.scala index 011f770764..3989185608 100644 --- a/io/js/src/main/scala/fs2/io/ioplatform.scala +++ b/io/js/src/main/scala/fs2/io/ioplatform.scala @@ -31,7 +31,6 @@ import cats.effect.std.Queue import cats.effect.syntax.all._ import cats.syntax.all._ import fs2.concurrent.Channel -import fs2.io.internal.MicrotaskExecutor import fs2.io.internal.facade import java.nio.charset.Charset @@ -58,53 +57,106 @@ private[fs2] trait ioplatform { def suspendReadableAndRead[F[_], R <: Readable]( destroyIfNotEnded: Boolean = true, destroyIfCanceled: Boolean = true - )(thunk: => R)(implicit F: Async[F]): Resource[F, (R, Stream[F, Byte])] = - (for { - dispatcher <- Dispatcher.sequential[F] - channel <- Channel.unbounded[F, Unit].toResource - error <- F.deferred[Throwable].toResource - readableResource = for { - readable <- Resource.makeCase(F.delay(thunk)) { - case (readable, Resource.ExitCase.Succeeded) => + )(thunk: => R)(implicit F: Async[F]): Resource[F, (R, Stream[F, Byte])] = { + + final class Listener { + private[this] var readableCounter = 0 + private[this] var error: Either[Throwable, Boolean] = null + private[this] var ended = false + private[this] var callback: Either[Throwable, Boolean] => Unit = null + + def handleReadable(): Unit = + if (callback eq null) { + readableCounter += 1 + } else { + callback(Right(true)) + callback = null + } + + def handleEnd(): Unit = { + ended = true + if (readableCounter == 0 && (callback ne null)) { + callback(Right(false)) + } + } + + def handleError(e: js.Error): Unit = { + error = Left(js.JavaScriptException(e)) + if (callback ne null) { + callback(error) + } + } + + private[this] def next: F[Boolean] = F.async { cb => + F.delay { + if (error ne null) { + cb(error) + None + } else if (readableCounter > 0) { + cb(Right(true)) + readableCounter -= 1 + None + } else if (ended) { + cb(Right(false)) + None + } else { + callback = cb + Some(F.delay { callback = null }) + } + } + } + + def readableEvents: Stream[F, Unit] = { + def go: Pull[F, Unit, Unit] = + Pull.eval(next).flatMap { continue => + if (continue) + Pull.outUnit >> go + else + Pull.done + } + + go.streamNoScope + } + + } + + Resource + .eval(F.delay(new Listener)) + .flatMap { listener => + Resource + .makeCase { F.delay { - if (!readable.readableEnded & destroyIfNotEnded) - readable.destroy() + val readable = thunk + readable.on("readable", () => listener.handleReadable()) + readable.once("error", listener.handleError(_)) + readable.once("end", () => listener.handleEnd()) + readable } - case (readable, Resource.ExitCase.Errored(_)) => - // tempting, but don't propagate the error! - // that would trigger a unhandled Node.js error that circumvents FS2/CE error channels - F.delay(readable.destroy()) - case (readable, Resource.ExitCase.Canceled) => - if (destroyIfCanceled) + } { + case (readable, Resource.ExitCase.Succeeded) => + F.delay { + if (!readable.readableEnded & destroyIfNotEnded) + readable.destroy() + } + case (readable, Resource.ExitCase.Errored(_)) => + // tempting, but don't propagate the error! + // that would trigger a unhandled Node.js error that circumvents FS2/CE error channels F.delay(readable.destroy()) - else - F.unit - } - _ <- readable.registerListener[F, Any]("readable", dispatcher)(_ => channel.send(()).void) - _ <- readable.registerListener[F, Any]("end", dispatcher)(_ => channel.close.void) - _ <- readable.registerListener[F, Any]("close", dispatcher)(_ => channel.close.void) - _ <- readable.registerListener[F, js.Error]("error", dispatcher) { e => - error.complete(js.JavaScriptException(e)).void - } - } yield readable - // Implementation note: why run on the MicrotaskExecutor? - // In many cases creating a `Readable` starts async side-effects (e.g. negotiating TLS handshake or opening a file handle). - // Furthermore, these side-effects will invoke the listeners we register to the `Readable`. - // Therefore, it is critical that the listeners are registered to the `Readable` _before_ these async side-effects occur: - // in other words, before we next yield (cede) to the event loop. Because an arbitrary effect `F` (particularly `IO`) may cede at any time, - // our only recourse is to run the entire creation/listener registration process on the microtask executor. - readable <- readableResource.evalOn(MicrotaskExecutor) - stream = - (channel.stream - .concurrently(Stream.eval(error.get.flatMap(F.raiseError[Unit]))) >> - Stream - .evalUnChunk( - F.delay( - Option(readable.read()) - .fold(Chunk.empty[Byte])(Chunk.uint8Array) + case (readable, Resource.ExitCase.Canceled) => + if (destroyIfCanceled) + F.delay(readable.destroy()) + else + F.unit + } + .fproduct { readable => + listener.readableEvents.adaptError { case IOException(ex) => ex } >> + Stream.evalUnChunk( + F.delay(Option(readable.read()).fold(Chunk.empty[Byte])(Chunk.uint8Array(_))) ) - )).adaptError { case IOException(ex) => ex } - } yield (readable, stream)).adaptError { case IOException(ex) => ex } + } + } + .adaptError { case IOException(ex) => ex } + } /** `Pipe` that converts a stream of bytes to a stream that will emit a single `Readable`, * that ends whenever the resulting stream terminates. diff --git a/io/js/src/main/scala/fs2/io/net/SocketPlatform.scala b/io/js/src/main/scala/fs2/io/net/SocketPlatform.scala index 4f7ecab9ea..e39734ab15 100644 --- a/io/js/src/main/scala/fs2/io/net/SocketPlatform.scala +++ b/io/js/src/main/scala/fs2/io/net/SocketPlatform.scala @@ -53,7 +53,7 @@ private[net] trait SocketCompanionPlatform { } } - private[net] class AsyncSocket[F[_]]( + private[net] case class AsyncSocket[F[_]]( sock: facade.net.Socket, readStream: SuspendedStream[F, Byte] )(implicit F: Async[F]) diff --git a/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala b/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala index a506e02f58..0ef4bfe933 100644 --- a/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala +++ b/io/js/src/main/scala/fs2/io/net/tls/TLSContextPlatform.scala @@ -62,13 +62,40 @@ private[tls] trait TLSContextCompanionPlatform { self: TLSContext.type => clientMode: Boolean, params: TLSParameters, logger: TLSLogger[F] - ): Resource[F, TLSSocket[F]] = (Dispatcher.sequential[F], Dispatcher.parallel[F]) - .flatMapN { (seqDispatcher, parDispatcher) => - if (clientMode) { - Resource.eval(F.deferred[Either[Throwable, Unit]]).flatMap { handshake => + ): Resource[F, TLSSocket[F]] = { + + final class Listener { + private[this] var value: Either[Throwable, Unit] = null + private[this] var callback: Either[Throwable, Unit] => Unit = null + + def complete(value: Either[Throwable, Unit]): Unit = + if (callback ne null) { + callback(value) + callback = null + } else { + this.value = value + } + + def get: F[Unit] = F.async { cb => + F.delay { + if (value ne null) { + cb(value) + None + } else { + callback = cb + Some(F.delay { callback = null }) + } + } + } + } + + (Dispatcher.parallel[F], Resource.eval(F.delay(new Listener))) + .flatMapN { (parDispatcher, listener) => + if (clientMode) { TLSSocket .forAsync( socket, + clientMode, sock => { val options = params.toTLSConnectOptions(parDispatcher) options.secureContext = context @@ -79,25 +106,21 @@ private[tls] trait TLSContextCompanionPlatform { self: TLSContext.type => val tlsSock = facade.tls.connect(options) tlsSock.once( "secureConnect", - () => seqDispatcher.unsafeRunAndForget(handshake.complete(Either.unit)) + () => listener.complete(Either.unit) ) tlsSock.once[js.Error]( "error", - e => - seqDispatcher.unsafeRunAndForget( - handshake.complete(Left(new js.JavaScriptException(e))) - ) + e => listener.complete(Left(new js.JavaScriptException(e))) ) tlsSock } ) - .evalTap(_ => handshake.get.rethrow) - } - } else { - Resource.eval(F.deferred[Either[Throwable, Unit]]).flatMap { verifyError => + .evalTap(_ => listener.get) + } else { TLSSocket .forAsync( socket, + clientMode, sock => { val options = params.toTLSSocketOptions(parDispatcher) options.secureContext = context @@ -117,24 +140,21 @@ private[tls] trait TLSContextCompanionPlatform { self: TLSContext.type => .map(e => new JavaScriptSSLException(js.JavaScriptException(e))) .toLeft(()) else Either.unit - seqDispatcher.unsafeRunAndForget(verifyError.complete(result)) + listener.complete(result) } ) tlsSock.once[js.Error]( "error", - e => - seqDispatcher.unsafeRunAndForget( - verifyError.complete(Left(new js.JavaScriptException(e))) - ) + e => listener.complete(Left(new js.JavaScriptException(e))) ) tlsSock } ) - .evalTap(_ => verifyError.get.rethrow) + .evalTap(_ => listener.get) } } - } - .adaptError { case IOException(ex) => ex } + .adaptError { case IOException(ex) => ex } + } } def fromSecureContext(context: SecureContext): TLSContext[F] = diff --git a/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala b/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala index 050f9d8353..d54624f3c4 100644 --- a/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala +++ b/io/js/src/main/scala/fs2/io/net/tls/TLSSocketPlatform.scala @@ -38,16 +38,29 @@ private[tls] trait TLSSocketCompanionPlatform { self: TLSSocket.type => private[tls] def forAsync[F[_]]( socket: Socket[F], + clientMode: Boolean, upgrade: fs2.io.Duplex => facade.tls.TLSSocket )(implicit F: Async[F]): Resource[F, TLSSocket[F]] = for { - duplexOut <- mkDuplex(socket.reads) - (duplex, out) = duplexOut - _ <- out.through(socket.writes).compile.drain.background - tlsSockReadable <- suspendReadableAndRead( - destroyIfNotEnded = false, - destroyIfCanceled = false - )(upgrade(duplex)) + tlsSockReadable <- socket match { + case Socket.AsyncSocket(sock, _) if clientMode => + for { + tlsSockReadable <- suspendReadableAndRead( + destroyIfNotEnded = false, + destroyIfCanceled = false + )(upgrade(sock)) + } yield tlsSockReadable + case _ => + for { + duplexOut <- mkDuplex(socket.reads) + (duplex, out) = duplexOut + _ <- out.through(socket.writes).compile.drain.background + tlsSockReadable <- suspendReadableAndRead( + destroyIfNotEnded = false, + destroyIfCanceled = false + )(upgrade(duplex)) + } yield tlsSockReadable + } (tlsSock, readable) = tlsSockReadable readStream <- SuspendedStream(readable) } yield new AsyncTLSSocket(