Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Route with literal path segment makes route with variable path segment unavailable #3040

Merged
merged 12 commits into from
Aug 27, 2024
80 changes: 76 additions & 4 deletions zio-http/jvm/src/test/scala/zio/http/RoutePatternSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@ package zio.http

import java.util.UUID

import scala.collection.Seq

import zio.Chunk
import zio.test._

import zio.http.internal.HttpGen
import zio.http.{int => _, uuid => _, _}
import zio.http.{int => _, uuid => _}

object RoutePatternSpec extends ZIOHttpSpec {
import zio.http.Method
Expand Down Expand Up @@ -143,6 +140,81 @@ object RoutePatternSpec extends ZIOHttpSpec {
tree.get(Method.GET, p2).contains(2),
)
},
suite("collisions properly resolved")(
test("simple collision between literal and text segment i3036") {
val routes: Chunk[RoutePattern[_]] =
Chunk(Method.GET / "users" / "param1" / "fixed", Method.GET / "users" / string("param") / "dynamic")

var tree: Tree[Int] = RoutePattern.Tree.empty
routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) =>
tree = tree.add(routePattern, idx)
}

assertTrue(
tree.get(Method.GET, Path("/users/param1/fixed")).contains(1),
tree.get(Method.GET, Path("/users/param1/dynamic")).contains(2),
)
},
test("two collisions between literal and text segment") {
val routes: Chunk[RoutePattern[_]] = Chunk(
Method.GET / "users" / "param1" / "literal1" / "p1" / "tail1",
Method.GET / "users" / "param1" / "literal1" / string("p2") / "tail2",
Method.GET / "users" / string("param") / "literal1" / "p1" / "tail3",
Method.GET / "users" / string("param") / "literal1" / string("p2") / "tail4",
)

var tree: Tree[Int] = RoutePattern.Tree.empty
routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) =>
tree = tree.add(routePattern, idx)
}

assertTrue(
tree.get(Method.GET, Path("/users/param1/literal1/p1/tail1")).contains(1),
tree.get(Method.GET, Path("/users/param1/literal1/p1/tail2")).contains(2),
tree.get(Method.GET, Path("/users/param1/literal1/p1/tail3")).contains(3),
tree.get(Method.GET, Path("/users/param1/literal1/p1/tail4")).contains(4),
)
},
test("collision where distinguish is by literal and int segment") {
val routes: Chunk[RoutePattern[_]] = Chunk(
Method.GET / "users" / "param1" / int("id"),
Method.GET / "users" / string("param") / "dynamic",
)

var tree: Tree[Int] = RoutePattern.Tree.empty
routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) =>
tree = tree.add(routePattern, idx)
}

val r1 = tree.get(Method.GET, Path("/users/param1/155"))
val r2 = tree.get(Method.GET, Path("/users/param1/dynamic"))

assertTrue(
r1.contains(1),
r2.contains(2),
)
},
test("collision where distinguish is by two not literal segments") {
val uuid1 = new UUID(10, 10)
val routes: Chunk[RoutePattern[_]] = Chunk(
Method.GET / "users" / "param1" / int("id"),
Method.GET / "users" / string("param") / uuid("dynamic"),
)

var tree: Tree[Int] = RoutePattern.Tree.empty
routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) =>
tree = tree.add(routePattern, idx)
}

val r2 = tree.get(Method.GET, Path(s"/users/param1/$uuid1"))
val r1 = tree.get(Method.GET, Path("/users/param1/155"))

assertTrue(
r1.contains(1),
r2.contains(2),
)
},
),
test("on conflict, first one wins") {
var tree: Tree[Int] = RoutePattern.Tree.empty

Expand Down
81 changes: 61 additions & 20 deletions zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -762,34 +762,53 @@ object PathCodec {
private[http] final case class SegmentSubtree[+A](
literals: ListMap[String, SegmentSubtree[A]],
others: ListMap[SegmentCodec[_], SegmentSubtree[A]],
literalsWithCollisions: Set[String],
value: Chunk[A],
) {
self =>
def ++[A1 >: A](that: SegmentSubtree[A1]): SegmentSubtree[A1] =
def ++[A1 >: A](that: SegmentSubtree[A1]): SegmentSubtree[A1] = {
val newLiterals = mergeMaps(self.literals, that.literals)(_ ++ _)
val newOthers = mergeMaps(self.others, that.others)(_ ++ _)
val newLiteralCollisions = mergeLiteralCollisions(
self.literalsWithCollisions ++ that.literalsWithCollisions,
newLiterals.keySet,
newOthers.keys,
)
SegmentSubtree(
mergeMaps(self.literals, that.literals)(_ ++ _),
mergeMaps(self.others, that.others)(_ ++ _),
newLiterals,
newOthers,
newLiteralCollisions,
self.value ++ that.value,
)
}

def add[A1 >: A](segments: Iterable[SegmentCodec[_]], value: A1): SegmentSubtree[A1] =
self ++ SegmentSubtree.single(segments, value)

def get(path: Path): Chunk[A] =
get(path, 0)

private def get(path: Path, from: Int): Chunk[A] = {
private def get(path: Path, from: Int, skipLiteralsFor: Set[Int] = Set.empty): Chunk[A] = {
val segments = path.segments
val nSegments = segments.length
var subtree = self
var result = subtree.value
var i = from

var trySkipLiteralIdx: List[Int] = Nil

while (i < nSegments) {
val segment = segments(i)

if (subtree.literals.contains(segment)) {
// Fast path, jump down the tree:
// Fast path, jump down the tree:
if (!skipLiteralsFor.contains(i) && subtree.literals.contains(segment)) {

// this subtree segment have conflict with others
// will try others if result was empty
if (subtree.literalsWithCollisions.contains(segment)) {
trySkipLiteralIdx = i +: trySkipLiteralIdx
}

subtree = subtree.literals(segment)

result = subtree.value
Expand Down Expand Up @@ -863,13 +882,22 @@ object PathCodec {
}
}

result
if (trySkipLiteralIdx.nonEmpty && result.isEmpty) {
trySkipLiteralIdx = trySkipLiteralIdx.reverse
while (trySkipLiteralIdx.nonEmpty && result.isEmpty) {
val skipIdx = trySkipLiteralIdx.head
trySkipLiteralIdx = trySkipLiteralIdx.tail
result = get(path, from, skipLiteralsFor + skipIdx)
}
result
} else result
}

def map[B](f: A => B): SegmentSubtree[B] =
SegmentSubtree(
literals.map { case (k, v) => k -> v.map(f) },
ListMap(others.toSeq.map { case (k, v) => k -> v.map(f) }: _*),
literalsWithCollisions,
value.map(f),
)

Expand All @@ -883,24 +911,25 @@ object PathCodec {
object SegmentSubtree {
def single[A](segments: Iterable[SegmentCodec[_]], value: A): SegmentSubtree[A] =
segments.collect { case x if x.nonEmpty => x }
.foldRight[SegmentSubtree[A]](SegmentSubtree(ListMap(), ListMap(), Chunk(value))) { case (segment, subtree) =>
val literals =
segment match {
case SegmentCodec.Literal(value) => ListMap(value -> subtree)
case _ => ListMap.empty[String, SegmentSubtree[A]]
}
.foldRight[SegmentSubtree[A]](SegmentSubtree(ListMap(), ListMap(), Set.empty, Chunk(value))) {
case (segment, subtree) =>
val literals =
segment match {
case SegmentCodec.Literal(value) => ListMap(value -> subtree)
case _ => ListMap.empty[String, SegmentSubtree[A]]
}

val others =
ListMap[SegmentCodec[_], SegmentSubtree[A]]((segment match {
case SegmentCodec.Literal(_) => Chunk.empty
case _ => Chunk((segment, subtree))
}): _*)
val others =
ListMap[SegmentCodec[_], SegmentSubtree[A]]((segment match {
case SegmentCodec.Literal(_) => Chunk.empty
case _ => Chunk((segment, subtree))
}): _*)

SegmentSubtree(literals, others, Chunk.empty)
SegmentSubtree(literals, others, Set.empty, Chunk.empty)
}

val empty: SegmentSubtree[Nothing] =
SegmentSubtree(ListMap(), ListMap(), Chunk.empty)
SegmentSubtree(ListMap(), ListMap(), Set.empty, Chunk.empty)
}

private def mergeMaps[A, B](left: ListMap[A, B], right: ListMap[A, B])(f: (B, B) => B): ListMap[A, B] =
Expand All @@ -910,4 +939,16 @@ object PathCodec {
case Some(v0) => acc.updated(k, f(v0, v))
}
}

private def mergeLiteralCollisions(
currentCollisions: Set[String],
literals: Set[String],
others: Iterable[SegmentCodec[_]],
): Set[String] = {
currentCollisions ++ literals.filter { literal =>
!currentCollisions.contains(literal) && others.exists { o =>
o.inSegmentUntil(literal, 0) != -1
}
}
}
}
Loading