From 62e78a36b88153a26a337eb8c7f610a189dc2f22 Mon Sep 17 00:00:00 2001 From: Saturn225 <101260782+Saturn225@users.noreply.github.com> Date: Wed, 2 Oct 2024 08:40:36 +0530 Subject: [PATCH] feat(conformance): add review comments --- .../netty/server/ServerInboundHandler.scala | 21 +++++---- .../test/scala/zio/http/ConformanceSpec.scala | 2 +- .../src/main/scala/zio/http/Routes.scala | 46 ++++++++++--------- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala index d87261e25b..37a63707ae 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/server/ServerInboundHandler.scala @@ -116,16 +116,16 @@ private[zio] final case class ServerInboundHandler( } private def validateHostHeader(req: Request): Boolean = { - req.headers.get("Host") match { - case Some(host) => - val parts = host.split(":") - val hostname = parts(0) - val isValidHost = validateHostname(hostname) - val isValidPort = parts.length == 1 || (parts.length == 2 && parts(1).forall(_.isDigit)) - val isValid = isValidHost && isValidPort - isValid - case None => - false + val host = req.headers.get("Host").getOrElse(null) + if (host != null) { + val parts = host.split(":") + val hostname = parts(0) + val isValidHost = validateHostname(hostname) + val isValidPort = parts.length == 1 || (parts.length == 2 && parts(1).forall(_.isDigit)) + val isValid = isValidHost && isValidPort + isValid + } else { + false } } @@ -143,6 +143,7 @@ private[zio] final case class ServerInboundHandler( override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = cause match { + case ioe: IOException if { val msg = ioe.getMessage (msg ne null) && msg.contains("Connection reset") diff --git a/zio-http/jvm/src/test/scala/zio/http/ConformanceSpec.scala b/zio-http/jvm/src/test/scala/zio/http/ConformanceSpec.scala index fa80d17dd9..f2ae2abd30 100644 --- a/zio-http/jvm/src/test/scala/zio/http/ConformanceSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/ConformanceSpec.scala @@ -1,7 +1,7 @@ package zio.http -import java.time.format.DateTimeFormatter import java.time.ZonedDateTime +import java.time.format.DateTimeFormatter import zio._ import zio.test.Assertion._ diff --git a/zio-http/shared/src/main/scala/zio/http/Routes.scala b/zio-http/shared/src/main/scala/zio/http/Routes.scala index 642044ca0e..8695ebd2ea 100644 --- a/zio-http/shared/src/main/scala/zio/http/Routes.scala +++ b/zio-http/shared/src/main/scala/zio/http/Routes.scala @@ -249,32 +249,36 @@ final case class Routes[-Env, +Err](routes: Chunk[zio.http.Route[Env, Err]]) { s Handler .fromFunctionHandler[Request] { req => val chunk = tree.get(req.method, req.path) - val allowedMethods = tree.getAllMethods(req.path) + def allowedMethods = tree.getAllMethods(req.path) req.method match { case Method.CUSTOM(_) => Handler.notImplemented case _ => - chunk.length match { - case 0 => - if (allowedMethods.nonEmpty) { - val allowHeader = Header.Allow(NonEmptyChunk.fromIterableOption(allowedMethods).get) - Handler.methodNotAllowed.addHeader(allowHeader) - } else { - Handler.notFound - } - case 1 => chunk(0) - case n => // TODO: Support precomputed fallback among all chunk elements - var acc = chunk(0) - var i = 1 - while (i < n) { - val h = chunk(i) - acc = acc.catchAll { response => - if (response.status == Status.NotFound) h - else Handler.fail(response) + if (chunk.isEmpty) { + if (allowedMethods.isEmpty) { + // If no methods are allowed for the path, return 404 Not Found + Handler.notFound + } else { + // If there are allowed methods for the path but none match the request method, return 405 Method Not Allowed + val allowHeader = Header.Allow(NonEmptyChunk.fromIterableOption(allowedMethods).get) + Handler.methodNotAllowed.addHeader(allowHeader) + } + } else { + chunk.length match { + case 1 => chunk(0) + case n => // TODO: Support precomputed fallback among all chunk elements + var acc = chunk(0) + var i = 1 + while (i < n) { + val h = chunk(i) + acc = acc.catchAll { response => + if (response.status == Status.NotFound) h + else Handler.fail(response) + } + i += 1 } - i += 1 - } - acc + acc + } } } }