From c9be1ea4f85610236a9899956a296061c7ccf3c6 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 21 Dec 2024 09:56:01 +0800 Subject: [PATCH] implement multi-categories filter for all APIs (#902) --- README.md | 1 - go.mod | 2 +- go.sum | 6 ++-- master/rest.go | 85 +++++++++++++++------------------------------ master/rest_test.go | 8 ++--- server/rest.go | 65 +++++++++++++++++++++++----------- 6 files changed, 80 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index c59abf640..d8494385c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,6 @@ [![build](https://github.com/zhenghaoz/gorse/workflows/build/badge.svg)](https://github.com/zhenghaoz/gorse/actions?query=workflow%3Abuild) [![codecov](https://codecov.io/gh/gorse-io/gorse/branch/master/graph/badge.svg)](https://codecov.io/gh/gorse-io/gorse) [![Go Report Card](https://goreportcard.com/badge/github.com/zhenghaoz/gorse)](https://goreportcard.com/report/github.com/zhenghaoz/gorse) -[![GoDoc](https://godoc.org/github.com/zhenghaoz/gorse?status.svg)](https://godoc.org/github.com/zhenghaoz/gorse) [![Discord](https://img.shields.io/discord/830635934210588743)](https://discord.gg/x6gAtNNkAE) [![Twitter Follow](https://img.shields.io/twitter/follow/gorse_io?label=Follow&style=social)](https://twitter.com/gorse_io) [![Gurubase](https://img.shields.io/badge/Gurubase-Ask%20Gorse%20Guru-006BFF)](https://gurubase.io/g/gorse) diff --git a/go.mod b/go.mod index 365988233..8c43d8b3c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77 + github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606 github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 github.com/jellydator/ttlcache/v3 v3.3.0 diff --git a/go.sum b/go.sum index 8f535388a..fe0039020 100644 --- a/go.sum +++ b/go.sum @@ -309,10 +309,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= -github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4 h1:FOUvD2HvTY/8j1/I4j/FlX3LEqKGLWPWQLl6jPtUqQ0= -github.com/gorse-io/dashboard v0.0.0-20241207032532-3b75acd211c4/go.mod h1:LBLzsMv3XVLmpaM/1q8/sGvv2Avj1YxmHBZfXcdqRjU= -github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77 h1:WA5kRl4LNduJuM59vvMoAyBPU+7KZL2ROjE2fPUy6sE= -github.com/gorse-io/dashboard v0.0.0-20241219140402-1035820fbe77/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc= +github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606 h1:5Vh8xik8c905IYFg66ujt7FuuuPtzSW6e2DRzBUYc58= +github.com/gorse-io/dashboard v0.0.0-20241220180536-6acaf5256606/go.mod h1:6h/3EYChEyiynyCMMDsCsDEVBSOPLSo1L/+aHqj9kdc= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg= diff --git a/master/rest.go b/master/rest.go index 4e6d6fb6d..3dd78774e 100644 --- a/master/rest.go +++ b/master/rest.go @@ -22,7 +22,6 @@ import ( "io" "net/http" "os" - "reflect" "sort" "strconv" "strings" @@ -138,7 +137,7 @@ func (m *Master) CreateWebService() { Returns(http.StatusOK, "OK", UserIterator{}). Writes(UserIterator{})) // Get non-personalized recommendation - ws.Route(ws.GET("/non-personalized/{name}").To(m.getNonPersonalized). + ws.Route(ws.GET("/dashboard/non-personalized/{name}").To(m.getNonPersonalized). Doc("Get non-personalized recommendations."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.QueryParameter("category", "Category of returned items.").DataType("string")). @@ -151,6 +150,7 @@ func (m *Master) CreateWebService() { Doc("Get recommendation for user."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). + Param(ws.QueryParameter("category", "category of items").DataType("string")). Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Returns(http.StatusOK, "OK", []data.Item{}). Writes([]data.Item{})) @@ -159,6 +159,7 @@ func (m *Master) CreateWebService() { Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.PathParameter("user-id", "identifier of the user").DataType("string")). Param(ws.PathParameter("recommender", "one of `final`, `collaborative`, `user_based` and `item_based`").DataType("string")). + Param(ws.QueryParameter("category", "category of items").DataType("string")). Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Returns(http.StatusOK, "OK", []data.Item{}). Writes([]data.Item{})) @@ -175,6 +176,7 @@ func (m *Master) CreateWebService() { Doc("get neighbors of a item"). Metadata(restfulspec.KeyOpenAPITags, []string{"recommendation"}). Param(ws.PathParameter("item-id", "identifier of the item").DataType("string")). + Param(ws.QueryParameter("category", "category of items").DataType("string")). Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). Returns(http.StatusOK, "OK", []ScoredItem{}). @@ -743,7 +745,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon // parse arguments recommender := request.PathParameter("recommender") userId := request.PathParameter("user-id") - categories := []string{request.PathParameter("category")} + categories := server.ReadCategories(request) n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN) if err != nil { server.BadRequest(response, err) @@ -844,80 +846,49 @@ type ScoreUser struct { Score float64 } -func (m *Master) searchDocuments(collection, subset, category string, request *restful.Request, response *restful.Response, retType interface{}) { - ctx := context.Background() - if request != nil && request.Request != nil { - ctx = request.Request.Context() - } - var n, offset int +func (m *Master) GetItem(score cache.Score) (any, error) { + var item ScoredItem var err error - // read arguments - if offset, err = server.ParseInt(request, "offset", 0); err != nil { - server.BadRequest(response, err) - return - } - if n, err = server.ParseInt(request, "n", m.Config.Server.DefaultN); err != nil { - server.BadRequest(response, err) - return - } - // Get the popular list - scores, err := m.CacheClient.SearchScores(ctx, collection, subset, []string{category}, offset, m.Config.Recommend.CacheSize) + item.Score = score.Score + item.Item, err = m.DataClient.GetItem(context.Background(), score.Id) if err != nil { - server.InternalServerError(response, err) - return - } - if n > 0 && len(scores) > n { - scores = scores[:n] + return nil, err } - // Send result - switch retType.(type) { - case data.Item: - details := make([]ScoredItem, len(scores)) - for i := range scores { - details[i].Score = scores[i].Score - details[i].Item, err = m.DataClient.GetItem(ctx, scores[i].Id) - if err != nil { - server.InternalServerError(response, err) - return - } - } - server.Ok(response, details) - case data.User: - details := make([]ScoreUser, len(scores)) - for i := range scores { - details[i].Score = scores[i].Score - details[i].User, err = m.DataClient.GetUser(ctx, scores[i].Id) - if err != nil { - server.InternalServerError(response, err) - return - } - } - server.Ok(response, details) - default: - log.ResponseLogger(response).Fatal("unknown return type", zap.Any("ret_type", reflect.TypeOf(retType))) + return item, nil +} + +func (m *Master) GetUser(score cache.Score) (any, error) { + var user ScoreUser + var err error + user.Score = score.Score + user.User, err = m.DataClient.GetUser(context.Background(), score.Id) + if err != nil { + return nil, err } + return user, nil } func (m *Master) getNonPersonalized(request *restful.Request, response *restful.Response) { name := request.PathParameter("name") - category := request.QueryParameter("category") - m.searchDocuments(cache.NonPersonalized, name, category, request, response, data.Item{}) + categories := server.ReadCategories(request) + m.SearchDocuments(cache.NonPersonalized, name, categories, m.GetItem, request, response) } func (m *Master) getItemNeighbors(request *restful.Request, response *restful.Response) { itemId := request.PathParameter("item-id") - m.searchDocuments(cache.ItemNeighbors, itemId, "", request, response, data.Item{}) + categories := server.ReadCategories(request) + m.SearchDocuments(cache.ItemNeighbors, itemId, categories, m.GetItem, request, response) } func (m *Master) getItemCategorizedNeighbors(request *restful.Request, response *restful.Response) { itemId := request.PathParameter("item-id") - category := request.PathParameter("category") - m.searchDocuments(cache.ItemNeighbors, itemId, category, request, response, data.Item{}) + categories := server.ReadCategories(request) + m.SearchDocuments(cache.ItemNeighbors, itemId, categories, m.GetItem, request, response) } func (m *Master) getUserNeighbors(request *restful.Request, response *restful.Response) { userId := request.PathParameter("user-id") - m.searchDocuments(cache.UserNeighbors, userId, "", request, response, data.User{}) + m.SearchDocuments(cache.UserNeighbors, userId, []string{""}, m.GetUser, request, response) } func (m *Master) importExportUsers(response http.ResponseWriter, request *http.Request) { diff --git a/master/rest_test.go b/master/rest_test.go index 2ff94b44b..4d0a03208 100644 --- a/master/rest_test.go +++ b/master/rest_test.go @@ -515,10 +515,10 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) { operators := []ListOperator{ {"Item Neighbors", cache.ItemNeighbors, "0", "", "/api/dashboard/item/0/neighbors"}, {"Item Neighbors in Category", cache.ItemNeighbors, "0", "*", "/api/dashboard/item/0/neighbors/*"}, - {"Latest Items", cache.NonPersonalized, cache.Latest, "", "/api/non-personalized/latest/"}, - {"Popular Items", cache.NonPersonalized, cache.Popular, "", "/api/non-personalized/popular/"}, - {"Latest Items in Category", cache.NonPersonalized, cache.Latest, "*", "/api/non-personalized/latest/"}, - {"Popular Items in Category", cache.NonPersonalized, cache.Popular, "*", "/api/non-personalized/popular/"}, + {"Latest Items", cache.NonPersonalized, cache.Latest, "", "/api/dashboard/non-personalized/latest/"}, + {"Popular Items", cache.NonPersonalized, cache.Popular, "", "/api/dashboard/non-personalized/popular/"}, + {"Latest Items in Category", cache.NonPersonalized, cache.Latest, "*", "/api/dashboard/non-personalized/latest/"}, + {"Popular Items in Category", cache.NonPersonalized, cache.Popular, "*", "/api/dashboard/non-personalized/popular/"}, } for i, operator := range operators { t.Run(operator.Name, func(t *testing.T) { diff --git a/server/rest.go b/server/rest.go index 327e82b7e..45a5eaca3 100644 --- a/server/rest.go +++ b/server/rest.go @@ -436,6 +436,7 @@ func (s *RestServer) CreateWebService() { Doc("Get popular items."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). + Param(ws.QueryParameter("category", "Category of returned items").DataType("string")). Param(ws.QueryParameter("n", "Number of returned recommendations").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned recommendations").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). @@ -456,6 +457,7 @@ func (s *RestServer) CreateWebService() { Doc("Get the latest items."). Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). + Param(ws.QueryParameter("category", "Category of returned items").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). Param(ws.QueryParameter("offset", "Offset of returned items").DataType("integer")). Param(ws.QueryParameter("user-id", "Remove read items of a user").DataType("string")). @@ -585,7 +587,10 @@ func ParseDuration(request *restful.Request, name string) (time.Duration, error) return time.ParseDuration(valueString) } -func (s *RestServer) searchDocuments(collection, subset, category string, isItem bool, request *restful.Request, response *restful.Response) { +func (s *RestServer) SearchDocuments(collection, subset string, categories []string, + iteratee func(item cache.Score) (any, error), + request *restful.Request, response *restful.Response, +) { var ( ctx = request.Request.Context() n int @@ -623,7 +628,7 @@ func (s *RestServer) searchDocuments(collection, subset, category string, isItem } // Get the sorted list - items, err := s.CacheClient.SearchScores(ctx, collection, subset, []string{category}, offset, end) + items, err := s.CacheClient.SearchScores(ctx, collection, subset, categories, offset, end) if err != nil { InternalServerError(response, err) return @@ -644,26 +649,39 @@ func (s *RestServer) searchDocuments(collection, subset, category string, isItem if n > 0 && len(items) > n { items = items[:n] } - Ok(response, items) + if iteratee != nil { + var results []any + for _, item := range items { + result, err := iteratee(item) + if err != nil { + InternalServerError(response, err) + return + } + results = append(results, result) + } + Ok(response, results) + } else { + Ok(response, items) + } } func (s *RestServer) getPopular(request *restful.Request, response *restful.Response) { - category := request.PathParameter("category") - log.ResponseLogger(response).Debug("get category popular items in category", zap.String("category", category)) - s.searchDocuments(cache.NonPersonalized, cache.Popular, category, true, request, response) + categories := ReadCategories(request) + log.ResponseLogger(response).Debug("get category popular items in category", zap.Strings("categories", categories)) + s.SearchDocuments(cache.NonPersonalized, cache.Popular, categories, nil, request, response) } func (s *RestServer) getLatest(request *restful.Request, response *restful.Response) { - category := request.PathParameter("category") - log.ResponseLogger(response).Debug("get category latest items in category", zap.String("category", category)) - s.searchDocuments(cache.NonPersonalized, cache.Latest, category, true, request, response) + categories := ReadCategories(request) + log.ResponseLogger(response).Debug("get category latest items in category", zap.Strings("categories", categories)) + s.SearchDocuments(cache.NonPersonalized, cache.Latest, categories, nil, request, response) } func (s *RestServer) getNonPersonalized(request *restful.Request, response *restful.Response) { name := request.PathParameter("name") - category := request.QueryParameter("category") + categories := ReadCategories(request) log.ResponseLogger(response).Debug("get leaderboard", zap.String("name", name)) - s.searchDocuments(cache.NonPersonalized, name, category, false, request, response) + s.SearchDocuments(cache.NonPersonalized, name, categories, nil, request, response) } // get feedback by item-id with feedback type @@ -701,23 +719,23 @@ func (s *RestServer) getFeedbackByItem(request *restful.Request, response *restf func (s *RestServer) getItemNeighbors(request *restful.Request, response *restful.Response) { // Get item id itemId := request.PathParameter("item-id") - category := request.PathParameter("category") - s.searchDocuments(cache.ItemNeighbors, itemId, category, true, request, response) + categories := ReadCategories(request) + s.SearchDocuments(cache.ItemNeighbors, itemId, categories, nil, request, response) } // getUserNeighbors gets neighbors of a user from database. func (s *RestServer) getUserNeighbors(request *restful.Request, response *restful.Response) { // Get item id userId := request.PathParameter("user-id") - s.searchDocuments(cache.UserNeighbors, userId, "", false, request, response) + s.SearchDocuments(cache.UserNeighbors, userId, []string{""}, nil, request, response) } // getCollaborative gets cached recommended items from database. func (s *RestServer) getCollaborative(request *restful.Request, response *restful.Response) { // Get user id userId := request.PathParameter("user-id") - category := request.PathParameter("category") - s.searchDocuments(cache.OfflineRecommend, userId, category, true, request, response) + categories := ReadCategories(request) + s.SearchDocuments(cache.OfflineRecommend, userId, categories, nil, request, response) } // Recommend items to users. @@ -995,10 +1013,7 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re BadRequest(response, err) return } - categories := request.QueryParameters("category") - if len(categories) == 0 { - categories = []string{request.PathParameter("category")} - } + categories := ReadCategories(request) offset, err := ParseInt(request, "offset", 0) if err != nil { BadRequest(response, err) @@ -1963,3 +1978,13 @@ func withWildCard(categories []string) []string { result = append(result, "") return result } + +func ReadCategories(request *restful.Request) []string { + if pathValue := request.PathParameter("category"); pathValue != "" { + return []string{pathValue} + } else if queryValues := request.QueryParameters("category"); len(queryValues) > 0 { + return queryValues + } else { + return []string{""} + } +}