diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 13efaaac..244fe175 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -136,11 +136,11 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - run: mkdir -p examples/target target .js/target site/target servlet/target .jvm/target .native/target project/target + run: mkdir -p benchmarks/target examples/target target .js/target site/target servlet/target .jvm/target .native/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - run: tar cf targets.tar examples/target target .js/target site/target servlet/target .jvm/target .native/target project/target + run: tar cf targets.tar benchmarks/target examples/target target .js/target site/target servlet/target .jvm/target .native/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') diff --git a/benchmarks/src/main/scala/org/http4s/servlet/ServletIoBenchmarks.scala b/benchmarks/src/main/scala/org/http4s/servlet/ServletIoBenchmarks.scala new file mode 100644 index 00000000..bdb9152d --- /dev/null +++ b/benchmarks/src/main/scala/org/http4s/servlet/ServletIoBenchmarks.scala @@ -0,0 +1,183 @@ +package org.http4s.servlet + +import cats.effect.IO +import cats.effect.std.Dispatcher +import cats.effect.unsafe.implicits.global + +import org.openjdk.jmh.annotations._ +import org.http4s.servlet.NonBlockingServletIo + +import java.io.ByteArrayInputStream +import java.util.concurrent.TimeUnit +import javax.servlet.{ServletInputStream, ReadListener} +import javax.servlet.http.HttpServletRequest +import scala.util.Random + +/** To do comparative benchmarks between versions: + * + * benchmarks/run-benchmark AsyncBenchmark + * + * This will generate results in `benchmarks/results`. + * + * Or to run the benchmark from within sbt: + * + * Jmh / run -i 10 -wi 10 -f 2 -t 1 cats.effect.benchmarks.AsyncBenchmark + * + * Which means "10 iterations", "10 warm-up iterations", "2 forks", "1 thread". Please note that + * benchmarks should be usually executed at least in 10 iterations (as a rule of thumb), but + * more is better. + */ +@State(Scope.Thread) +@BenchmarkMode(Array(Mode.Throughput)) +@OutputTimeUnit(TimeUnit.SECONDS) +class ServletIoBenchmarks { + + @Param(Array("100000")) + var size: Int = _ + + @Param(Array("1000")) + var iters: Int = _ + + def servletRequest: HttpServletRequest = new HttpServletRequestStub( + new TestServletInputStream(Random.nextBytes(size)) + ) + + @Benchmark + def reader() = { + val req = servletRequest + val servletIo = NonBlockingServletIo[IO](4096) + + def loop(i: Int): IO[Unit] = + if (i == iters) IO.unit else servletIo.reader(req).compile.drain >> loop(i + 1) + + loop(0).unsafeRunSync() + } + + @Benchmark + def requestBody() = { + val req = servletRequest + val servletIo = NonBlockingServletIo[IO](4096) + + def loop(i: Int): IO[Unit] = + if (i == iters) IO.unit + else Dispatcher.sequential[IO].use { dispatcher => + servletIo.requestBody(req, dispatcher).compile.drain + } >> loop(i + 1) + + loop(0).unsafeRunSync() + } + + class TestServletInputStream(body: Array[Byte]) extends ServletInputStream { + private var readListener: ReadListener = null + private val in = new ByteArrayInputStream(body) + + override def isReady: Boolean = true + + override def isFinished: Boolean = in.available() == 0 + + override def setReadListener(readListener: ReadListener): Unit = { + this.readListener = readListener + readListener.onDataAvailable() + } + + override def read(): Int = { + val result = in.read() + if (in.available() == 0) + readListener.onAllDataRead() + result + } + + override def read(buf: Array[Byte]) = { + val result = in.read(buf) + if (in.available() == 0) + readListener.onAllDataRead() + result + } + + override def read(buf: Array[Byte], off: Int, len: Int) = { + val result = in.read(buf, off, len) + if (in.available() == 0) + readListener.onAllDataRead() + result + } + } + + case class HttpServletRequestStub( + inputStream: ServletInputStream + ) extends HttpServletRequest { + def getInputStream(): ServletInputStream = inputStream + + def authenticate(x$1: javax.servlet.http.HttpServletResponse): Boolean = ??? + def changeSessionId(): String = ??? + def getAuthType(): String = ??? + def getContextPath(): String = ??? + def getCookies(): Array[javax.servlet.http.Cookie] = ??? + def getDateHeader(x$1: String): Long = ??? + def getHeader(x$1: String): String = ??? + def getHeaderNames(): java.util.Enumeration[String] = ??? + def getHeaders(x$1: String): java.util.Enumeration[String] = ??? + def getIntHeader(x$1: String): Int = ??? + def getMethod(): String = ??? + def getPart(x$1: String): javax.servlet.http.Part = ??? + def getParts(): java.util.Collection[javax.servlet.http.Part] = ??? + def getPathInfo(): String = ??? + def getPathTranslated(): String = ??? + def getQueryString(): String = ??? + def getRemoteUser(): String = ??? + def getRequestURI(): String = ??? + def getRequestURL(): StringBuffer = ??? + def getRequestedSessionId(): String = ??? + def getServletPath(): String = ??? + def getSession(): javax.servlet.http.HttpSession = ??? + def getSession(x$1: Boolean): javax.servlet.http.HttpSession = ??? + def getUserPrincipal(): java.security.Principal = ??? + def isRequestedSessionIdFromCookie(): Boolean = ??? + def isRequestedSessionIdFromURL(): Boolean = ??? + def isRequestedSessionIdFromUrl(): Boolean = ??? + def isRequestedSessionIdValid(): Boolean = ??? + def isUserInRole(x$1: String): Boolean = ??? + def login(x$1: String, x$2: String): Unit = ??? + def logout(): Unit = ??? + def upgrade[T <: javax.servlet.http.HttpUpgradeHandler](x$1: Class[T]): T = ??? + def getAsyncContext(): javax.servlet.AsyncContext = ??? + def getAttribute(x$1: String): Object = ??? + def getAttributeNames(): java.util.Enumeration[String] = ??? + def getCharacterEncoding(): String = ??? + def getContentLength(): Int = ??? + def getContentLengthLong(): Long = ??? + def getContentType(): String = ??? + def getDispatcherType(): javax.servlet.DispatcherType = ??? + def getLocalAddr(): String = ??? + def getLocalName(): String = ??? + def getLocalPort(): Int = ??? + def getLocale(): java.util.Locale = ??? + def getLocales(): java.util.Enumeration[java.util.Locale] = ??? + def getParameter(x$1: String): String = ??? + def getParameterMap(): java.util.Map[String, Array[String]] = ??? + def getParameterNames(): java.util.Enumeration[String] = ??? + def getParameterValues(x$1: String): Array[String] = ??? + def getProtocol(): String = ??? + def getReader(): java.io.BufferedReader = ??? + def getRealPath(x$1: String): String = ??? + def getRemoteAddr(): String = ??? + def getRemoteHost(): String = ??? + def getRemotePort(): Int = ??? + def getRequestDispatcher(x$1: String): javax.servlet.RequestDispatcher = ??? + def getScheme(): String = ??? + def getServerName(): String = ??? + def getServerPort(): Int = ??? + def getServletContext(): javax.servlet.ServletContext = ??? + def isAsyncStarted(): Boolean = ??? + def isAsyncSupported(): Boolean = ??? + def isSecure(): Boolean = ??? + def removeAttribute(x$1: String): Unit = ??? + def setAttribute(x$1: String, x$2: Object): Unit = ??? + def setCharacterEncoding(x$1: String): Unit = ??? + def startAsync( + x$1: javax.servlet.ServletRequest, + x$2: javax.servlet.ServletResponse, + ): javax.servlet.AsyncContext = ??? + def startAsync(): javax.servlet.AsyncContext = ??? + } + +} diff --git a/build.sbt b/build.sbt index 76410da4..80d31df9 100644 --- a/build.sbt +++ b/build.sbt @@ -43,6 +43,7 @@ lazy val servlet = project "org.eclipse.jetty" % "jetty-servlet" % jettyVersion % Test, "org.http4s" %% "http4s-dsl" % http4sVersion % Test, "org.http4s" %% "http4s-server" % http4sVersion, + "org.typelevel" %% "cats-effect" % "3.4.5", "org.typelevel" %% "munit-cats-effect-3" % munitCatsEffectVersion % Test, ), ) @@ -64,3 +65,16 @@ lazy val examples = project .dependsOn(servlet) lazy val docs = project.in(file("site")).enablePlugins(TypelevelSitePlugin) + +lazy val benchmarks = project + .in(file("benchmarks")) + .dependsOn(servlet) + .settings( + name := "servlet-benchmarks", + libraryDependencies ++= Seq( + "javax.servlet" % "javax.servlet-api" % servletApiVersion, + ), + javaOptions ++= Seq( + "-Dcats.effect.tracing.mode=none", + "-Dcats.effect.tracing.exceptions.enhanced=false")) + .enablePlugins(NoPublishPlugin, JmhPlugin) diff --git a/project/plugins.sbt b/project/plugins.sbt index ab1a29f0..d8df1722 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,2 +1,3 @@ addSbtPlugin("com.earldouglas" % "xsbt-web-plugin" % "4.2.4") addSbtPlugin("org.http4s" % "sbt-http4s-org" % "0.14.9") +addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.3") diff --git a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala index 6d179e35..19375c5d 100644 --- a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala +++ b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala @@ -213,14 +213,11 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl servletRequest: HttpServletRequest, dispatcher: Dispatcher[F], ): Stream[F, Byte] = { - sealed trait Read - final case class Bytes(chunk: Chunk[Byte]) extends Read - case object End extends Read - final case class Error(t: Throwable) extends Read + case object End Stream.eval(F.delay(servletRequest.getInputStream)).flatMap { in => - Stream.eval(Queue.bounded[F, Read](4)).flatMap { q => - val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener { + Stream.eval(Queue.bounded[F, Any](4)).flatMap { q => + val readBody = Stream.eval(F.delay(in.setReadListener(new ReadListener { var buf: Array[Byte] = _ unsafeReplaceBuffer() @@ -238,10 +235,11 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl F.delay(in.read(buf)).flatMap { case len if len == chunkSize => // We used the whole buffer. Replace it new before next read. - q.offer(Bytes(Chunk.array(buf))) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady - case len if len >= 0 => + q.offer(Chunk.array(buf)) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady + case len if len > 0 => // Got a partial chunk. Copy it, and reuse the current buffer. - q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady + q.offer(Chunk.array(Arrays.copyOf(buf, len))) >> loopIfReady + case len if len == 0 => loopIfReady case _ => F.unit } @@ -253,7 +251,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl unsafeRunAndForget(q.offer(End)) def onError(t: Throwable): Unit = - unsafeRunAndForget(q.offer(Error(t))) + unsafeRunAndForget(q.offer(t)) def unsafeRunAndForget[A](fa: F[A]): Unit = dispatcher.unsafeRunAndForget( @@ -263,12 +261,12 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl def pullBody: Pull[F, Byte, Unit] = Pull.eval(q.take).flatMap { - case Bytes(chunk) => Pull.output(chunk) >> pullBody + case chunk: Chunk[Byte] @unchecked => Pull.output(chunk) >> pullBody case End => Pull.done - case Error(t) => Pull.raiseError[F](t) + case t: Throwable => Pull.raiseError[F](t) } - pullBody.stream.concurrently(readBody) + readBody.flatMap(_ => pullBody.stream) } } }