Skip to content

Commit

Permalink
feat(server): support multi-categories filtering (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Jun 5, 2023
1 parent 08b4370 commit 205d264
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 19 deletions.
12 changes: 6 additions & 6 deletions master/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
// parse arguments
recommender := request.PathParameter("recommender")
userId := request.PathParameter("user-id")
category := request.PathParameter("category")
categories := []string{request.PathParameter("category")}
n, err := server.ParseInt(request, "n", m.Config.Server.DefaultN)
if err != nil {
server.BadRequest(response, err)
Expand All @@ -765,13 +765,13 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
var results []string
switch recommender {
case "offline":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendOffline)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendOffline)
case "collaborative":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendCollaborative)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendCollaborative)
case "user_based":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendUserBased)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendUserBased)
case "item_based":
results, err = m.Recommend(ctx, response, userId, category, n, m.RecommendItemBased)
results, err = m.Recommend(ctx, response, userId, categories, n, m.RecommendItemBased)
case "_":
recommenders := []server.Recommender{m.RecommendOffline}
for _, recommender := range m.Config.Recommend.Online.FallbackRecommend {
Expand All @@ -791,7 +791,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon
return
}
}
results, err = m.Recommend(ctx, response, userId, category, n, recommenders...)
results, err = m.Recommend(ctx, response, userId, categories, n, recommenders...)
}
if err != nil {
server.InternalServerError(response, err)
Expand Down
30 changes: 17 additions & 13 deletions server/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ func (s *RestServer) CreateWebService() {
Metadata(restfulspec.KeyOpenAPITags, []string{RecommendationAPITag}).
Param(ws.HeaderParameter("X-API-Key", "API key").DataType("string")).
Param(ws.PathParameter("user-id", "ID of the user to get recommendation").DataType("string")).
Param(ws.QueryParameter("category", "Category of the returned items (support multi-categories filtering)").DataType("string")).
Param(ws.QueryParameter("write-back-type", "Type of write back feedback").DataType("string")).
Param(ws.QueryParameter("write-back-delay", "Timestamp delay of write back feedback (format 0h0m0s)").DataType("string")).
Param(ws.QueryParameter("n", "Number of returned items").DataType("integer")).
Expand Down Expand Up @@ -684,11 +685,11 @@ func (s *RestServer) getCollaborative(request *restful.Request, response *restfu
// 1. If there are recommendations in cache, return cached recommendations.
// 2. If there are historical interactions of the users, return similar items.
// 3. Otherwise, return fallback recommendation (popular/latest).
func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, userId, category string, n int, recommenders ...Recommender) ([]string, error) {
func (s *RestServer) Recommend(ctx context.Context, response *restful.Response, userId string, categories []string, n int, recommenders ...Recommender) ([]string, error) {
initStart := time.Now()

// create context
recommendCtx, err := s.createRecommendContext(ctx, userId, category, n)
recommendCtx, err := s.createRecommendContext(ctx, userId, categories, n)
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -727,7 +728,7 @@ func (s *RestServer) Recommend(ctx context.Context, response *restful.Response,
type recommendContext struct {
context context.Context
userId string
category string
categories []string
userFeedback []data.Feedback
n int
results []string
Expand All @@ -750,7 +751,7 @@ type recommendContext struct {
loadPopularTime time.Duration
}

func (s *RestServer) createRecommendContext(ctx context.Context, userId, category string, n int) (*recommendContext, error) {
func (s *RestServer) createRecommendContext(ctx context.Context, userId string, categories []string, n int) (*recommendContext, error) {
// pull historical feedback
userFeedback, err := s.DataClient.GetUserFeedback(ctx, userId, s.Config.Now())
if err != nil {
Expand All @@ -764,7 +765,7 @@ func (s *RestServer) createRecommendContext(ctx context.Context, userId, categor
}
return &recommendContext{
userId: userId,
category: category,
categories: categories,
n: n,
excludeSet: excludeSet,
userFeedback: userFeedback,
Expand All @@ -777,7 +778,7 @@ type Recommender func(ctx *recommendContext) error
func (s *RestServer) RecommendOffline(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
recommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.OfflineRecommend, ctx.userId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
recommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.OfflineRecommend, ctx.userId, ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -797,7 +798,7 @@ func (s *RestServer) RecommendOffline(ctx *recommendContext) error {
func (s *RestServer) RecommendCollaborative(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
collaborativeRecommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.CollaborativeRecommend, ctx.userId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
collaborativeRecommendation, err := s.CacheClient.SearchDocuments(ctx.context, cache.CollaborativeRecommend, ctx.userId, ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -836,7 +837,7 @@ func (s *RestServer) RecommendUserBased(ctx *recommendContext) error {
if err != nil {
return errors.Trace(err)
}
if ctx.category == "" || funk.ContainsString(item.Categories, ctx.category) {
if funk.Equal(ctx.categories, []string{""}) || funk.Subset(ctx.categories, item.Categories) {
candidates[feedback.ItemId] += user.Score
}
}
Expand Down Expand Up @@ -876,7 +877,7 @@ func (s *RestServer) RecommendItemBased(ctx *recommendContext) error {
candidates := make(map[string]float64)
for _, feedback := range userFeedback {
// load similar items
similarItems, err := s.CacheClient.SearchDocuments(ctx.context, cache.ItemNeighbors, feedback.ItemId, []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
similarItems, err := s.CacheClient.SearchDocuments(ctx.context, cache.ItemNeighbors, feedback.ItemId, ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -906,7 +907,7 @@ func (s *RestServer) RecommendItemBased(ctx *recommendContext) error {
func (s *RestServer) RecommendLatest(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.LatestItems, "", []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.LatestItems, "", ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand All @@ -926,7 +927,7 @@ func (s *RestServer) RecommendLatest(ctx *recommendContext) error {
func (s *RestServer) RecommendPopular(ctx *recommendContext) error {
if len(ctx.results) < ctx.n {
start := time.Now()
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.PopularItems, "", []string{ctx.category}, 0, s.Config.Recommend.CacheSize)
items, err := s.CacheClient.SearchDocuments(ctx.context, cache.PopularItems, "", ctx.categories, 0, s.Config.Recommend.CacheSize)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -955,7 +956,10 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re
BadRequest(response, err)
return
}
category := request.PathParameter("category")
categories := request.QueryParameters("category")
if len(categories) == 0 {
categories = []string{request.PathParameter("category")}
}
offset, err := ParseInt(request, "offset", 0)
if err != nil {
BadRequest(response, err)
Expand Down Expand Up @@ -986,7 +990,7 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re
return
}
}
results, err := s.Recommend(ctx, response, userId, category, offset+n, recommenders...)
results, err := s.Recommend(ctx, response, userId, categories, offset+n, recommenders...)
if err != nil {
InternalServerError(response, err)
return
Expand Down
30 changes: 30 additions & 0 deletions server/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,36 @@ func (suite *ServerTestSuite) TestGetRecommends() {
End()
}

func (suite *ServerTestSuite) TestGetRecommendsWithMultiCategories() {
ctx := context.Background()
t := suite.T()
// insert recommendation
err := suite.CacheClient.AddDocuments(ctx, cache.OfflineRecommend, "0", []cache.Document{
{Id: "1", Score: 1, Categories: []string{""}},
{Id: "2", Score: 2, Categories: []string{"", "2"}},
{Id: "3", Score: 3, Categories: []string{"", "3"}},
{Id: "4", Score: 4, Categories: []string{"", "2"}},
{Id: "5", Score: 5, Categories: []string{"", "5"}},
{Id: "6", Score: 6, Categories: []string{"", "2", "3"}},
{Id: "7", Score: 7, Categories: []string{"", "7"}},
{Id: "8", Score: 8, Categories: []string{"", "2"}},
{Id: "9", Score: 9, Categories: []string{"", "3"}},
})
suite.NoError(err)
apitest.New().
Handler(suite.handler).
Get("/api/recommend/0").
Header("X-API-Key", apiKey).
QueryCollection(map[string][]string{
"n": []string{"3"},
"category": []string{"2", "3"},
}).
Expect(t).
Status(http.StatusOK).
Body(suite.marshal([]string{"6"})).
End()
}

func (suite *ServerTestSuite) TestGetRecommendsWithReplacement() {
ctx := context.Background()
t := suite.T()
Expand Down

0 comments on commit 205d264

Please sign in to comment.