diff --git a/zio-http/jvm/src/test/scala/zio/http/RoutePatternSpec.scala b/zio-http/jvm/src/test/scala/zio/http/RoutePatternSpec.scala index fa9f4cacd2..c59f853344 100644 --- a/zio-http/jvm/src/test/scala/zio/http/RoutePatternSpec.scala +++ b/zio-http/jvm/src/test/scala/zio/http/RoutePatternSpec.scala @@ -18,13 +18,10 @@ package zio.http import java.util.UUID -import scala.collection.Seq - import zio.Chunk import zio.test._ -import zio.http.internal.HttpGen -import zio.http.{int => _, uuid => _, _} +import zio.http.{int => _, uuid => _} object RoutePatternSpec extends ZIOHttpSpec { import zio.http.Method @@ -143,6 +140,81 @@ object RoutePatternSpec extends ZIOHttpSpec { tree.get(Method.GET, p2).contains(2), ) }, + suite("collisions properly resolved")( + test("simple collision between literal and text segment i3036") { + val routes: Chunk[RoutePattern[_]] = + Chunk(Method.GET / "users" / "param1" / "fixed", Method.GET / "users" / string("param") / "dynamic") + + var tree: Tree[Int] = RoutePattern.Tree.empty + routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) => + tree = tree.add(routePattern, idx) + } + + assertTrue( + tree.get(Method.GET, Path("/users/param1/fixed")).contains(1), + tree.get(Method.GET, Path("/users/param1/dynamic")).contains(2), + ) + }, + test("two collisions between literal and text segment") { + val routes: Chunk[RoutePattern[_]] = Chunk( + Method.GET / "users" / "param1" / "literal1" / "p1" / "tail1", + Method.GET / "users" / "param1" / "literal1" / string("p2") / "tail2", + Method.GET / "users" / string("param") / "literal1" / "p1" / "tail3", + Method.GET / "users" / string("param") / "literal1" / string("p2") / "tail4", + ) + + var tree: Tree[Int] = RoutePattern.Tree.empty + routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) => + tree = tree.add(routePattern, idx) + } + + assertTrue( + tree.get(Method.GET, Path("/users/param1/literal1/p1/tail1")).contains(1), + tree.get(Method.GET, Path("/users/param1/literal1/p1/tail2")).contains(2), + tree.get(Method.GET, Path("/users/param1/literal1/p1/tail3")).contains(3), + tree.get(Method.GET, Path("/users/param1/literal1/p1/tail4")).contains(4), + ) + }, + test("collision where distinguish is by literal and int segment") { + val routes: Chunk[RoutePattern[_]] = Chunk( + Method.GET / "users" / "param1" / int("id"), + Method.GET / "users" / string("param") / "dynamic", + ) + + var tree: Tree[Int] = RoutePattern.Tree.empty + routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) => + tree = tree.add(routePattern, idx) + } + + val r1 = tree.get(Method.GET, Path("/users/param1/155")) + val r2 = tree.get(Method.GET, Path("/users/param1/dynamic")) + + assertTrue( + r1.contains(1), + r2.contains(2), + ) + }, + test("collision where distinguish is by two not literal segments") { + val uuid1 = new UUID(10, 10) + val routes: Chunk[RoutePattern[_]] = Chunk( + Method.GET / "users" / "param1" / int("id"), + Method.GET / "users" / string("param") / uuid("dynamic"), + ) + + var tree: Tree[Int] = RoutePattern.Tree.empty + routes.zipWithIndexFrom(1).foreach { case (routePattern, idx) => + tree = tree.add(routePattern, idx) + } + + val r2 = tree.get(Method.GET, Path(s"/users/param1/$uuid1")) + val r1 = tree.get(Method.GET, Path("/users/param1/155")) + + assertTrue( + r1.contains(1), + r2.contains(2), + ) + }, + ), test("on conflict, first one wins") { var tree: Tree[Int] = RoutePattern.Tree.empty diff --git a/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala b/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala index d16f8f1d0e..ef4fe6d436 100644 --- a/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala +++ b/zio-http/shared/src/main/scala/zio/http/codec/PathCodec.scala @@ -762,15 +762,25 @@ object PathCodec { private[http] final case class SegmentSubtree[+A]( literals: ListMap[String, SegmentSubtree[A]], others: ListMap[SegmentCodec[_], SegmentSubtree[A]], + literalsWithCollisions: Set[String], value: Chunk[A], ) { self => - def ++[A1 >: A](that: SegmentSubtree[A1]): SegmentSubtree[A1] = + def ++[A1 >: A](that: SegmentSubtree[A1]): SegmentSubtree[A1] = { + val newLiterals = mergeMaps(self.literals, that.literals)(_ ++ _) + val newOthers = mergeMaps(self.others, that.others)(_ ++ _) + val newLiteralCollisions = mergeLiteralCollisions( + self.literalsWithCollisions ++ that.literalsWithCollisions, + newLiterals.keySet, + newOthers.keys, + ) SegmentSubtree( - mergeMaps(self.literals, that.literals)(_ ++ _), - mergeMaps(self.others, that.others)(_ ++ _), + newLiterals, + newOthers, + newLiteralCollisions, self.value ++ that.value, ) + } def add[A1 >: A](segments: Iterable[SegmentCodec[_]], value: A1): SegmentSubtree[A1] = self ++ SegmentSubtree.single(segments, value) @@ -778,18 +788,27 @@ object PathCodec { def get(path: Path): Chunk[A] = get(path, 0) - private def get(path: Path, from: Int): Chunk[A] = { + private def get(path: Path, from: Int, skipLiteralsFor: Set[Int] = Set.empty): Chunk[A] = { val segments = path.segments val nSegments = segments.length var subtree = self var result = subtree.value var i = from + var trySkipLiteralIdx: List[Int] = Nil + while (i < nSegments) { val segment = segments(i) - if (subtree.literals.contains(segment)) { - // Fast path, jump down the tree: + // Fast path, jump down the tree: + if (!skipLiteralsFor.contains(i) && subtree.literals.contains(segment)) { + + // this subtree segment have conflict with others + // will try others if result was empty + if (subtree.literalsWithCollisions.contains(segment)) { + trySkipLiteralIdx = i +: trySkipLiteralIdx + } + subtree = subtree.literals(segment) result = subtree.value @@ -863,13 +882,22 @@ object PathCodec { } } - result + if (trySkipLiteralIdx.nonEmpty && result.isEmpty) { + trySkipLiteralIdx = trySkipLiteralIdx.reverse + while (trySkipLiteralIdx.nonEmpty && result.isEmpty) { + val skipIdx = trySkipLiteralIdx.head + trySkipLiteralIdx = trySkipLiteralIdx.tail + result = get(path, from, skipLiteralsFor + skipIdx) + } + result + } else result } def map[B](f: A => B): SegmentSubtree[B] = SegmentSubtree( literals.map { case (k, v) => k -> v.map(f) }, ListMap(others.toSeq.map { case (k, v) => k -> v.map(f) }: _*), + literalsWithCollisions, value.map(f), ) @@ -883,24 +911,25 @@ object PathCodec { object SegmentSubtree { def single[A](segments: Iterable[SegmentCodec[_]], value: A): SegmentSubtree[A] = segments.collect { case x if x.nonEmpty => x } - .foldRight[SegmentSubtree[A]](SegmentSubtree(ListMap(), ListMap(), Chunk(value))) { case (segment, subtree) => - val literals = - segment match { - case SegmentCodec.Literal(value) => ListMap(value -> subtree) - case _ => ListMap.empty[String, SegmentSubtree[A]] - } + .foldRight[SegmentSubtree[A]](SegmentSubtree(ListMap(), ListMap(), Set.empty, Chunk(value))) { + case (segment, subtree) => + val literals = + segment match { + case SegmentCodec.Literal(value) => ListMap(value -> subtree) + case _ => ListMap.empty[String, SegmentSubtree[A]] + } - val others = - ListMap[SegmentCodec[_], SegmentSubtree[A]]((segment match { - case SegmentCodec.Literal(_) => Chunk.empty - case _ => Chunk((segment, subtree)) - }): _*) + val others = + ListMap[SegmentCodec[_], SegmentSubtree[A]]((segment match { + case SegmentCodec.Literal(_) => Chunk.empty + case _ => Chunk((segment, subtree)) + }): _*) - SegmentSubtree(literals, others, Chunk.empty) + SegmentSubtree(literals, others, Set.empty, Chunk.empty) } val empty: SegmentSubtree[Nothing] = - SegmentSubtree(ListMap(), ListMap(), Chunk.empty) + SegmentSubtree(ListMap(), ListMap(), Set.empty, Chunk.empty) } private def mergeMaps[A, B](left: ListMap[A, B], right: ListMap[A, B])(f: (B, B) => B): ListMap[A, B] = @@ -910,4 +939,16 @@ object PathCodec { case Some(v0) => acc.updated(k, f(v0, v)) } } + + private def mergeLiteralCollisions( + currentCollisions: Set[String], + literals: Set[String], + others: Iterable[SegmentCodec[_]], + ): Set[String] = { + currentCollisions ++ literals.filter { literal => + !currentCollisions.contains(literal) && others.exists { o => + o.inSegmentUntil(literal, 0) != -1 + } + } + } }