Skip to content

Commit

Permalink
Merge pull request #13 from go-labx/refactor_to_trie_router
Browse files Browse the repository at this point in the history
refactor: refactor router to prefix tree
  • Loading branch information
kk0829 authored May 13, 2023
2 parents ebb5763 + 2052034 commit 086fc55
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 310 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## [0.4.1] - May 13, 2023

### Changed

- refactor router to prefix tree

## [0.4.0] - Apr 20, 2023

### Added
Expand Down
56 changes: 24 additions & 32 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,10 @@ func TestGroup_AddRoute(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.AddRoute(http.MethodGet, "/path", handlers)
root := app.router.Roots[http.MethodGet]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodGet, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand Down Expand Up @@ -131,11 +130,10 @@ func TestGroup_Get(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Get("/path", handlers...)
root := app.router.Roots[http.MethodGet]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodGet, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand All @@ -144,11 +142,10 @@ func TestGroup_Post(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Post("/path", handlers...)
root := app.router.Roots[http.MethodPost]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodPost, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand All @@ -157,11 +154,10 @@ func TestGroup_Put(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Put("/path", handlers...)
root := app.router.Roots[http.MethodPut]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodPut, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand All @@ -170,11 +166,10 @@ func TestGroup_Delete(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Delete("/path", handlers...)
root := app.router.Roots[http.MethodDelete]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodDelete, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand All @@ -183,11 +178,10 @@ func TestGroup_Head(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Head("/path", handlers...)
root := app.router.Roots[http.MethodHead]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodHead, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand All @@ -196,11 +190,10 @@ func TestGroup_Patch(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Patch("/path", handlers...)
root := app.router.Roots[http.MethodPatch]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodPatch, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}

Expand All @@ -209,10 +202,9 @@ func TestGroup_Options(t *testing.T) {
group := app.Group("/prefix")
handlers := []HandlerFunc{func(c *Context) {}}
group.Options("/path", handlers...)
root := app.router.Roots[http.MethodOptions]

route := root.Children["prefix"].Children["path"]
if reflect.ValueOf(route.handlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", route.handlers[0], handlers[0])
searchHandlers, _ := app.router.findRoute(http.MethodOptions, "/prefix/path")
if reflect.ValueOf(searchHandlers[0]) != reflect.ValueOf(handlers[0]) {
t.Errorf("Expected handlers to be '%v', but got '%v'", searchHandlers[0], handlers[0])
}
}
5 changes: 3 additions & 2 deletions lightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func NewApp(c ...*Config) *Application {
}

if app.Config.EnableDebug {
app.Get("/__debug__/router-map", func(ctx *Context) {
app.Get("/__debug__/router_map", func(ctx *Context) {
ctx.JSON(200, app.router.Roots)
})
}
Expand Down Expand Up @@ -166,7 +166,8 @@ func (app *Application) ServeHTTP(w http.ResponseWriter, req *http.Request) {
defer ctx.flush()

// Find the matching route and set the handlers and paramsMap in the context
handlers, params := app.router.findRoute(req.Method, req.URL.Path)
handlers, params := app.router.findRoute(ctx.Method, ctx.Path)

// This check is necessary because if no matching route is found and the handlers slice is left empty,
// the middleware chain will not be executed and the client will receive an empty response.
// By appending the 404 handler function to the handlers slice,
Expand Down
191 changes: 100 additions & 91 deletions router.go
Original file line number Diff line number Diff line change
@@ -1,121 +1,130 @@
package lightning

import (
"fmt"
"strings"
)

// trieNode represents a node in the trie data structure used by the router.
type trieNode struct {
Children map[string]*trieNode `json:"children"` // A map of child nodes keyed by their string values
IsEnd bool `json:"isEnd"` // boolean flag indicating whether the node marks the end of a route
handlers []HandlerFunc // `HandlerFunc` functions that handles requests for the node's route
Params map[string]int `json:"params"` // a map of parameter names and their corresponding indices in the route pattern
Wildcard string `json:"wildcard"` // a string representing the name of the wildcard parameter in the route pattern (if any)
}
import "strings"

// router represents the HTTP router.
type router struct {
Roots map[string]*trieNode `json:"roots"`
// node is a struct that represents a node in the trie
type node struct {
// Pattern is the pattern of the node
Pattern string `json:"pattern"`
// Part is the part of the node
Part string `json:"part"`
// IsWild is a boolean that indicates whether the node is a wildcard
IsWild bool `json:"isWild"`
// Children is a slice of pointers to the children of the node
Children []*node `json:"children,omitempty"`
// handlers is a slice of HandlerFuncs that are associated with the node
handlers []HandlerFunc
}

// newTrieNode creates a new instance of the `trieNode` struct with default values.
func newTrieNode() *trieNode {
return &trieNode{
Children: make(map[string]*trieNode),
IsEnd: false,
handlers: make([]HandlerFunc, 0),
Params: make(map[string]int),
Wildcard: "",
// matchChild returns the child node that matches the given part
func (n *node) matchChild(part string) *node {
for _, child := range n.Children {
if child.Part == part || child.IsWild {
return child
}
}
return nil
}

// newRouter creates a new instance of the `router` struct with an empty `roots` map.
func newRouter() *router {
return &router{
Roots: make(map[string]*trieNode),
// matchChildren returns a slice of child nodes that match the given part
func (n *node) matchChildren(part string) []*node {
nodes := make([]*node, 0)
for _, child := range n.Children {
if child.Part == part || child.IsWild {
nodes = append(nodes, child)
}
}
return nodes
}

// addRoute adds a new route to the router.
func (r *router) addRoute(method string, pattern string, handlers []HandlerFunc) {
if !isValidHTTPMethod(method) {
panic(fmt.Sprintf("method `%s` is not a standard HTTP method", method))
// insert inserts a new node into the trie
func (n *node) insert(pattern string, parts []string, height int, handlers []HandlerFunc) {
if len(parts) == height {
n.Pattern = pattern
n.handlers = handlers
return
}
root, ok := r.Roots[method]
if !ok {
root = newTrieNode()
r.Roots[method] = root

part := parts[height]
child := n.matchChild(part)
if child == nil {
child = &node{Part: part, IsWild: part[0] == ':' || part[0] == '*'}
n.Children = append(n.Children, child)
}
child.insert(pattern, parts, height+1, handlers)
}

// search searches the trie for a node that matches the given parts
func (n *node) search(parts []string, height int) *node {
if len(parts) == height || strings.HasPrefix(n.Part, "*") {
if n.Pattern == "" {
return nil
}
return n
}

params := make(map[string]int)
parts := parsePattern(pattern)
for i, part := range parts {
if part[0] == ':' {
// parameter
name := part[1:]
params[name] = i
if root.Children[":"] == nil {
root.Children[":"] = newTrieNode()
}
root = root.Children[":"]
} else if part[0] == '*' {
// wildcard
name := part[1:]
if root.Children["*"] == nil {
root.Children["*"] = newTrieNode()
}
root = root.Children["*"]
root.Wildcard = name
break
} else {
// static
if root.Children[part] == nil {
root.Children[part] = newTrieNode()
}
root = root.Children[part]
part := parts[height]
children := n.matchChildren(part)

for _, child := range children {
result := child.search(parts, height+1)
if result != nil {
return result
}
}

root.IsEnd = true // mark the end of the route
root.handlers = handlers // set the handlers for the route
root.Params = params // set the parameters for the route
return nil
}

// findRoute is used to find the appropriate handler function for a given HTTP request method and URL pattern.
func (r *router) findRoute(method string, pattern string) ([]HandlerFunc, map[string]string) {
root, ok := r.Roots[method]
if !ok {
return nil, nil
// router is a struct that represents a router
type router struct {
// Roots is a map of HTTP methods to the root nodes of the trie
Roots map[string]*node `json:"roots"`
}

// newRouter creates a new router
func newRouter() *router {
return &router{
Roots: make(map[string]*node, 0),
}
params := make(map[string]string)
values := make(map[int]string)
}

// addRoute adds a new route to the router
func (r *router) addRoute(method string, pattern string, handlers []HandlerFunc) {
parts := parsePattern(pattern)
for i, part := range parts {
if root.Children[part] != nil {
root = root.Children[part]
} else if root.Children[":"] != nil {
root = root.Children[":"]
values[i] = part
} else if root.Children["*"] != nil {
root = root.Children["*"]
if root.Wildcard != "" {
params[root.Wildcard] = strings.Join(parts[i:], "/")
}
break
} else {
return nil, nil
}

_, ok := r.Roots[method]
if !ok {
r.Roots[method] = &node{}
}
r.Roots[method].insert(pattern, parts, 0, handlers)
}

// findRoute finds the route that matches the given method and path
func (r *router) findRoute(method string, path string) ([]HandlerFunc, map[string]string) {
searchParts := parsePattern(path)
params := make(map[string]string)
root, ok := r.Roots[method]

if !root.IsEnd {
if !ok {
return nil, nil
}

for name, index := range root.Params {
params[name] = values[index]
n := root.search(searchParts, 0)

if n != nil {
parts := parsePattern(n.Pattern)
for index, part := range parts {
if part[0] == ':' {
params[part[1:]] = searchParts[index]
}
if part[0] == '*' && len(part) > 1 {
params[part[1:]] = strings.Join(searchParts[index:], "/")
break
}
}
return n.handlers, params
}

return root.handlers, params
return nil, nil
}
Loading

0 comments on commit 086fc55

Please sign in to comment.