From 205d264f019378a3d337a8c41cb2fc9284ced556 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Mon, 5 Jun 2023 21:35:00 +0800 Subject: [PATCH] feat(server): support multi-categories filtering (#704) --- master/rest.go | 12 ++++++------ server/rest.go | 30 +++++++++++++++++------------- server/rest_test.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 19 deletions(-) diff --git a/master/rest.go b/master/rest.go index d53e5030c..630fbd3f6 100644 --- a/master/rest.go +++ b/master/rest.go @@ -756,7 +756,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon // parse arguments recommender := request.PathParameter("recommender") userId := request.PathParameter("user-id") - category := request.PathParameter("category") + categories := []string{request.PathParameter("category")} n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN) if err != nil { server.BadRequest(response, err) @@ -765,13 +765,13 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon var results []string switch recommender { case "offline": - results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendOffline) + results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendOffline) case "collaborative": - results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendCollaborative) + results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendCollaborative) case "user_based": - results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendUserBased) + results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendUserBased) case "item_based": - results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendItemBased) + results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendItemBased) case "_": recommenders := []server.Recommender{m.RecommendOffline} for _, recommender := range m.Config.Recommend.Online.FallbackRecommend { @@ -791,7 +791,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon return } } - results, err = m.Recommend(ctx, response, userId, category, n, recommenders...) + results, err = m.Recommend(ctx, response, userId, categories, n, recommenders...) } if err != nil { server.InternalServerError(response, err) diff --git a/server/rest.go b/server/rest.go index 3260c0cfa..f0375f85d 100644 --- a/server/rest.go +++ b/server/rest.go @@ -492,6 +492,7 @@ func (s *RestServer) CreateWebService() { Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}). Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")). Param(ws.PathParameter("user-id", "ID of the user to get recommendation").DataType("string")). + Param(ws.QueryParameter("category", "Category of the returned items (support multi-categories filtering)").DataType("string")). Param(ws.QueryParameter("write-back-type", "Type of write back feedback").DataType("string")). Param(ws.QueryParameter("write-back-delay", "Timestamp delay of write back feedback (format 0h0m0s)").DataType("string")). Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")). @@ -684,11 +685,11 @@ func (s *RestServer) getCollaborative(request *restful.Request, response *restfu // 1. If there are recommendations in cache, return cached recommendations. // 2. If there are historical interactions of the users, return similar items. // 3. Otherwise, return fallback recommendation (popular/latest). -func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, userId, category string, n int, recommenders ...Recommender) ([]string, error) { +func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, userId string, categories []string, n int, recommenders ...Recommender) ([]string, error) { initStart := time.Now() // create context - recommendCtx, err := s.createRecommendContext(ctx, userId, category, n) + recommendCtx, err := s.createRecommendContext(ctx, userId, categories, n) if err != nil { return nil, errors.Trace(err) } @@ -727,7 +728,7 @@ func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, type recommendContext struct { context context.Context userId string - category string + categories []string userFeedback []data.Feedback n int results []string @@ -750,7 +751,7 @@ type recommendContext struct { loadPopularTime time.Duration } -func (s *RestServer) createRecommendContext(ctx context.Context, userId, category string, n int) (*recommendContext, error) { +func (s *RestServer) createRecommendContext(ctx context.Context, userId string, categories []string, n int) (*recommendContext, error) { // pull historical feedback userFeedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now()) if err != nil { @@ -764,7 +765,7 @@ func (s *RestServer) createRecommendContext(ctx context.Context, userId, categor } return &recommendContext{ userId: userId, - category: category, + categories: categories, n: n, excludeSet: excludeSet, userFeedback: userFeedback, @@ -777,7 +778,7 @@ type Recommender func(ctx *recommendContext) error func (s *RestServer) RecommendOffline(ctx *recommendContext) error { if len(ctx.results) < ctx.n { start := time.Now() - recommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.OfflineRecommend, ctx.userId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize) + recommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.OfflineRecommend, ctx.userId, ctx.categories, 0, s.Config.Recommend.CacheSize) if err != nil { return errors.Trace(err) } @@ -797,7 +798,7 @@ func (s *RestServer) RecommendOffline(ctx *recommendContext) error { func (s *RestServer) RecommendCollaborative(ctx *recommendContext) error { if len(ctx.results) < ctx.n { start := time.Now() - collaborativeRecommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.CollaborativeRecommend, ctx.userId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize) + collaborativeRecommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.CollaborativeRecommend, ctx.userId, ctx.categories, 0, s.Config.Recommend.CacheSize) if err != nil { return errors.Trace(err) } @@ -836,7 +837,7 @@ func (s *RestServer) RecommendUserBased(ctx *recommendContext) error { if err != nil { return errors.Trace(err) } - if ctx.category == "" || funk.ContainsString(item.Categories, ctx.category) { + if funk.Equal(ctx.categories, []string{""}) || funk.Subset(ctx.categories, item.Categories) { candidates[feedback.ItemId] += user.Score } } @@ -876,7 +877,7 @@ func (s *RestServer) RecommendItemBased(ctx *recommendContext) error { candidates := make(map[string]float64) for _, feedback := range userFeedback { // load similar items - similarItems, err := s.CacheClient.SearchDocuments(ctx.context, cache.ItemNeighbors, feedback.ItemId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize) + similarItems, err := s.CacheClient.SearchDocuments(ctx.context, cache.ItemNeighbors, feedback.ItemId, ctx.categories, 0, s.Config.Recommend.CacheSize) if err != nil { return errors.Trace(err) } @@ -906,7 +907,7 @@ func (s *RestServer) RecommendItemBased(ctx *recommendContext) error { func (s *RestServer) RecommendLatest(ctx *recommendContext) error { if len(ctx.results) < ctx.n { start := time.Now() - items, err := s.CacheClient.SearchDocuments(ctx.context, cache.LatestItems, "", []string{ctx.category}, 0, s.Config.Recommend.CacheSize) + items, err := s.CacheClient.SearchDocuments(ctx.context, cache.LatestItems, "", ctx.categories, 0, s.Config.Recommend.CacheSize) if err != nil { return errors.Trace(err) } @@ -926,7 +927,7 @@ func (s *RestServer) RecommendLatest(ctx *recommendContext) error { func (s *RestServer) RecommendPopular(ctx *recommendContext) error { if len(ctx.results) < ctx.n { start := time.Now() - items, err := s.CacheClient.SearchDocuments(ctx.context, cache.PopularItems, "", []string{ctx.category}, 0, s.Config.Recommend.CacheSize) + items, err := s.CacheClient.SearchDocuments(ctx.context, cache.PopularItems, "", ctx.categories, 0, s.Config.Recommend.CacheSize) if err != nil { return errors.Trace(err) } @@ -955,7 +956,10 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re BadRequest(response, err) return } - category := request.PathParameter("category") + categories := request.QueryParameters("category") + if len(categories) == 0 { + categories = []string{request.PathParameter("category")} + } offset, err := ParseInt(request, "offset", 0) if err != nil { BadRequest(response, err) @@ -986,7 +990,7 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re return } } - results, err := s.Recommend(ctx, response, userId, category, offset+n, recommenders...) + results, err := s.Recommend(ctx, response, userId, categories, offset+n, recommenders...) if err != nil { InternalServerError(response, err) return diff --git a/server/rest_test.go b/server/rest_test.go index 6e8375c6f..a30670e4e 100644 --- a/server/rest_test.go +++ b/server/rest_test.go @@ -1140,6 +1140,36 @@ func (suite *ServerTestSuite) TestGetRecommends() { End() } +func (suite *ServerTestSuite) TestGetRecommendsWithMultiCategories() { + ctx := context.Background() + t := suite.T() + // insert recommendation + err := suite.CacheClient.AddDocuments(ctx, cache.OfflineRecommend, "0", []cache.Document{ + {Id: "1", Score: 1, Categories: []string{""}}, + {Id: "2", Score: 2, Categories: []string{"", "2"}}, + {Id: "3", Score: 3, Categories: []string{"", "3"}}, + {Id: "4", Score: 4, Categories: []string{"", "2"}}, + {Id: "5", Score: 5, Categories: []string{"", "5"}}, + {Id: "6", Score: 6, Categories: []string{"", "2", "3"}}, + {Id: "7", Score: 7, Categories: []string{"", "7"}}, + {Id: "8", Score: 8, Categories: []string{"", "2"}}, + {Id: "9", Score: 9, Categories: []string{"", "3"}}, + }) + suite.NoError(err) + apitest.New(). + Handler(suite.handler). + Get("/api/recommend/0"). + Header("X-API-Key", apiKey). + QueryCollection(map[string][]string{ + "n": []string{"3"}, + "category": []string{"2", "3"}, + }). + Expect(t). + Status(http.StatusOK). + Body(suite.marshal([]string{"6"})). + End() +} + func (suite *ServerTestSuite) TestGetRecommendsWithReplacement() { ctx := context.Background() t := suite.T()