Skip to content

Commit

Permalink
feat: scoped requests
Browse files Browse the repository at this point in the history
  • Loading branch information
jgranstrom committed Dec 4, 2024
1 parent 1f8ef1f commit 87329b7
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 16 deletions.
13 changes: 13 additions & 0 deletions zio-http-testkit/src/main/scala/zio/http/TestServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ final case class TestServer(driver: Driver, bindPort: Int) extends Server {
),
)

override def installScoped[R](routes: Routes[R with Scope, Response])(implicit
trace: zio.Trace,
tag: EnvironmentTag[R],
): URIO[R, Unit] =
ZIO
.environment[R]
.flatMap(
driver.addAppScoped(
routes,
_,
),
)

override def port: UIO[Int] = ZIO.succeed(bindPort)
}

Expand Down
37 changes: 27 additions & 10 deletions zio-http/jvm/src/main/scala/zio/http/netty/server/NettyDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,38 @@ private[zio] final case class NettyDriver(
} yield StartResult(port, serverInboundHandler.inFlightRequests)

def addApp[R](newApp: Routes[R, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit] =
addAppImpl(asScoped = false, newApp, env)

def addAppScoped[R](newApp: Routes[R with Scope, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit] =
addAppImpl(asScoped = true, newApp, env)

override def createClientDriver()(implicit trace: Trace): ZIO[Scope, Throwable, ClientDriver] =
for {
channelFactory <- ChannelFactories.Client.live.build
.provideSomeEnvironment[Scope](_ ++ ZEnvironment[ChannelType.Config](nettyConfig))
nettyRuntime <- NettyRuntime.live.build
} yield NettyClientDriver(channelFactory.get, eventLoopGroups.worker, nettyRuntime.get)

override def toString: String = s"NettyDriver($serverConfig)"

private def addAppImpl[E, R <: E](asScoped: Boolean, newApp: Routes[R, Response], env: ZEnvironment[E])(implicit
trace: Trace,
): UIO[Unit] =
ZIO.fiberId.map { fiberId =>
var loop = true
while (loop) {
val oldAppAndRt = appRef.get()
val (oldApp, oldRt) = oldAppAndRt
val updatedApp = (oldApp ++ newApp).asInstanceOf[Routes[Any, Response]]
val updatedApp = oldApp.fold(
oldUnscoped => {
if (asScoped) {
Right((oldUnscoped ++ newApp).asInstanceOf[Routes[Scope, Response]])
} else {
Left((oldUnscoped ++ newApp).asInstanceOf[Routes[Any, Response]])
}
},
oldScoped => Right((oldScoped ++ newApp).asInstanceOf[Routes[Scope, Response]]),
)
val updatedEnv = oldRt.environment.unionAll(env)
// Update the fiberRefs with the new environment to avoid doing this every time we run / fork a fiber
val updatedFibRefs = oldRt.fiberRefs.updatedAs(fiberId)(FiberRef.currentEnvironment, updatedEnv)
Expand All @@ -78,15 +104,6 @@ private[zio] final case class NettyDriver(
}
serverInboundHandler.refreshApp()
}

override def createClientDriver()(implicit trace: Trace): ZIO[Scope, Throwable, ClientDriver] =
for {
channelFactory <- ChannelFactories.Client.live.build
.provideSomeEnvironment[Scope](_ ++ ZEnvironment[ChannelType.Config](nettyConfig))
nettyRuntime <- NettyRuntime.live.build
} yield NettyClientDriver(channelFactory.get, eventLoopGroups.worker, nettyRuntime.get)

override def toString: String = s"NettyDriver($serverConfig)"
}

object NettyDriver {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ private[zio] final case class ServerInboundHandler(

implicit private val unsafe: Unsafe = Unsafe.unsafe

private var handler: Handler[Any, Nothing, Request, Response] = _
private var runtime: NettyRuntime = _
private var handle: Request => ZIO[Any, Nothing, Response] = _
private var runtime: NettyRuntime = _

val inFlightRequests: LongAdder = new LongAdder()
private val readClientCert = config.sslConfig.exists(_.includeClientCert)
Expand All @@ -58,7 +58,15 @@ private[zio] final case class ServerInboundHandler(
def refreshApp(): Unit = {
val pair = appRef.get()

this.handler = pair._1.toHandler
this.handle = pair._1 match {
case Left(unscopedHandler) =>
val handler = unscopedHandler.toHandler
handler.apply
case Right(scopedHandler) =>
val handler = scopedHandler.toHandler
(req: Request) => ZIO.scoped(handler(req))
}

this.runtime = new NettyRuntime(pair._2)
}

Expand Down Expand Up @@ -88,7 +96,7 @@ private[zio] final case class ServerInboundHandler(
releaseRequest()
} else {
val req = makeZioRequest(ctx, jReq)
val exit = handler(req)
val exit = handle(req)
if (attemptImmediateWrite(ctx, req.method, exit)) {
releaseRequest()
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ import java.util.concurrent.atomic.AtomicReference // scalafix:ok;
import zio.stacktracer.TracingImplicits.disableAutoTrace

package object server {
private[server] type RoutesRef = AtomicReference[(Routes[Any, Response], Runtime[Any])]
private[server] type RoutesRef =
AtomicReference[(Either[Routes[Any, Response], Routes[Scope, Response]], Runtime[Any])]

private[server] object AppRef {
val empty: UIO[RoutesRef] = {
implicit val trace: Trace = Trace.empty
// Environment will be populated when we `install` the app
ZIO.runtime[Any].map(rt => new AtomicReference((Routes.empty, rt.mapEnvironment(_ => ZEnvironment.empty))))
ZIO.runtime[Any].map(rt => new AtomicReference((Left(Routes.empty), rt.mapEnvironment(_ => ZEnvironment.empty))))
}
}

Expand Down
1 change: 1 addition & 0 deletions zio-http/shared/src/main/scala/zio/http/Driver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ trait Driver {
def start(implicit trace: Trace): RIO[Scope, StartResult]

def addApp[R](newRoutes: Routes[R, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit]
def addAppScoped[R](newRoutes: Routes[R with Scope, Response], env: ZEnvironment[R])(implicit trace: Trace): UIO[Unit]

def createClientDriver()(implicit trace: Trace): ZIO[Scope, Throwable, ClientDriver]
}
Expand Down
40 changes: 40 additions & 0 deletions zio-http/shared/src/main/scala/zio/http/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ trait Server {
*/
def install[R](routes: Routes[R, Response])(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R, Unit]

/**
* Installs the given HTTP application into the server, providing a Scope for
* each request.
*/
def installScoped[R](
routes: Routes[R with Scope, Response],
)(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R, Unit]

/**
* The port on which the server is listening.
*
Expand Down Expand Up @@ -443,19 +451,41 @@ object Server extends ServerPlatformSpecific {
ZIO.never
}

def serveScoped[R](
routes: Routes[R with Scope, Response],
)(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Nothing] = {
ZIO.logInfo("Starting the server...") *>
ZIO.serviceWithZIO[Server](_.installScoped[R](routes)) *>
ZIO.logInfo("Server started") *>
ZIO.never
}

def serve[R](
route: Route[R, Response],
routes: Route[R, Response]*,
)(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Nothing] = {
serve(Routes(route, routes: _*))
}

def serveScoped[R](
route: Route[R with Scope, Response],
routes: Route[R with Scope, Response]*,
)(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Nothing] = {
serveScoped[R](Routes(route, routes: _*))
}

def install[R](
routes: Routes[R, Response],
)(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Int] = {
ZIO.serviceWithZIO[Server](_.install[R](routes)) *> ZIO.serviceWithZIO[Server](_.port)
}

def installScoped[R](
routes: Routes[R with Scope, Response],
)(implicit trace: Trace, tag: EnvironmentTag[R]): URIO[R with Server, Int] = {
ZIO.serviceWithZIO[Server](_.installScoped[R](routes)) *> ZIO.serviceWithZIO[Server](_.port)
}

private[http] val base: ZLayer[Driver & Config, Throwable, Server] = {
implicit val trace: Trace = Trace.empty
ZLayer.scoped {
Expand Down Expand Up @@ -533,6 +563,16 @@ object Server extends ServerPlatformSpecific {
_ <- ZIO.environment[R].flatMap(env => driver.addApp(routes, env.prune[R]))
} yield ()

override def installScoped[R](routes: Routes[R with Scope, Response])(implicit
trace: Trace,
tag: EnvironmentTag[R],
): URIO[R, Unit] =
for {
_ <- initialInstall.succeed(())
_ <- serverStarted.await.orDie
_ <- ZIO.environment[R].flatMap(env => driver.addAppScoped(routes, env.prune[R]))
} yield ()

override def port: UIO[Int] = serverStarted.await.orDie

}
Expand Down

0 comments on commit 87329b7

Please sign in to comment.