Skip to content

Commit

Permalink
Update ServerInboundHandler.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
Saturn225 authored Oct 2, 2024
1 parent b609cd9 commit e3eefc8
Showing 1 changed file with 30 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,29 @@ private[zio] final case class ServerInboundHandler(
)
releaseRequest()
} else {
val req = makeZioRequest(ctx, jReq)
if (!validateHostHeader(req)) {
attemptFastWrite(ctx, req.method, Response.status(Status.BadRequest))
releaseRequest()
} else {

val exit = handler(req)
if (attemptImmediateWrite(ctx, req.method, exit)) {
try {
val req = makeZioRequest(ctx, jReq)
if (!validateHostHeader(req)) {
attemptFastWrite(ctx, req.method, Response.status(Status.BadRequest))
releaseRequest()
} else {
writeResponse(ctx, runtime, exit, req)(releaseRequest)

val exit = handler(req)
if (attemptImmediateWrite(ctx, req.method, exit)) {
releaseRequest()
} else {
writeResponse(ctx, runtime, exit, req)(releaseRequest)

}
}
} catch {
case _: IllegalArgumentException =>
attemptFastWrite(
ctx,
Conversions.methodFromNetty(jReq.method()),
Response.status(Status.BadRequest),
)
releaseRequest()
}
}
} finally {
Expand All @@ -116,16 +126,16 @@ private[zio] final case class ServerInboundHandler(
}

private def validateHostHeader(req: Request): Boolean = {
req.headers.get("Host") match {
case Some(host) =>
val parts = host.split(":")
val hostname = parts(0)
val isValidHost = validateHostname(hostname)
val isValidPort = parts.length == 1 || (parts.length == 2 && parts(1).forall(_.isDigit))
val isValid = isValidHost && isValidPort
isValid
case None =>
false
val host = req.headers.get("Host").getOrElse(null)
if (host != null) {
val parts = host.split(":")
val hostname = parts(0)
val isValidHost = validateHostname(hostname)
val isValidPort = parts.length == 1 || (parts.length == 2 && parts(1).forall(_.isDigit))
val isValid = isValidHost && isValidPort
isValid
} else {
false
}
}

Expand All @@ -143,6 +153,7 @@ private[zio] final case class ServerInboundHandler(

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit =
cause match {

case ioe: IOException if {
val msg = ioe.getMessage
(msg ne null) && msg.contains("Connection reset")
Expand Down

0 comments on commit e3eefc8

Please sign in to comment.