Skip to content

Commit

Permalink
fix: login
Browse files Browse the repository at this point in the history
  • Loading branch information
Vilsol committed Dec 11, 2023
1 parent f70210a commit ce05441
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 62 deletions.
22 changes: 0 additions & 22 deletions db/postgres/sml_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,8 @@ package postgres

import (
"context"
"strings"

"github.com/satisfactorymodding/smr-api/models"
)

func GetSMLVersions(ctx context.Context, filter *models.SMLVersionFilter) []SMLVersion {
var smlVersions []SMLVersion
query := DBCtx(ctx)

if filter != nil {
query = query.Limit(*filter.Limit).
Offset(*filter.Offset).
Order(string(*filter.OrderBy) + " " + string(*filter.Order))

if filter.Search != nil && *filter.Search != "" {
query = query.Where("to_tsvector(name) @@ to_tsquery(?)", strings.ReplaceAll(*filter.Search, " ", " & "))
}
}

query.Preload("Targets").Find(&smlVersions)

return smlVersions
}

func GetSMLLatestVersions(ctx context.Context) *[]SMLVersion {
var smlVersions []SMLVersion

Expand Down
10 changes: 5 additions & 5 deletions db/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,24 @@ func UserHas(ctx context.Context, role *auth.Role, usr *ent.User) bool {
return exist
}

func UserFromGQLContext(ctx context.Context) (*ent.User, error) {
func UserFromGQLContext(ctx context.Context) (*ent.User, bool, error) {
header := ctx.Value(util.ContextHeader{}).(http.Header)
authorization := header.Get("Authorization")

if authorization == "" {
return nil, errors.New("user not logged in")
return nil, true, errors.New("user not logged in")
}

user, err := From(ctx).UserSession.Query().Where(usersession.Token(authorization)).QueryUser().First(ctx)
if err != nil {
if ent.IsNotFound(err) {
return nil, errors.New("user not logged in")
return nil, true, errors.New("user not logged in")
}

return nil, err
return nil, false, err
}

return user, nil
return user, false, nil
}

func UserCanUploadModVersions(ctx context.Context, user *ent.User, modID string) bool {
Expand Down
30 changes: 15 additions & 15 deletions gql/directive.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Directive struct {
}

func canEditMod(ctx context.Context, _ interface{}, next graphql.Resolver, field string) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -53,7 +53,7 @@ func canEditMod(ctx context.Context, _ interface{}, next graphql.Resolver, field
}

func canEditModCompatibility(ctx context.Context, _ interface{}, next graphql.Resolver, field *string) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -74,7 +74,7 @@ func canEditModCompatibility(ctx context.Context, _ interface{}, next graphql.Re
}

func canEditVersion(ctx context.Context, _ interface{}, next graphql.Resolver, field string) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -91,7 +91,7 @@ func canEditVersion(ctx context.Context, _ interface{}, next graphql.Resolver, f
}

func canEditUser(ctx context.Context, obj interface{}, next graphql.Resolver, field string, object bool) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -115,7 +115,7 @@ func canEditUser(ctx context.Context, obj interface{}, next graphql.Resolver, fi
}

func canEditGuide(ctx context.Context, _ interface{}, next graphql.Resolver, field string) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -137,7 +137,7 @@ func canEditGuide(ctx context.Context, _ interface{}, next graphql.Resolver, fie
}

func isLoggedIn(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -150,8 +150,8 @@ func isLoggedIn(ctx context.Context, _ interface{}, next graphql.Resolver) (inte
}

func isNotLoggedIn(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
if err != nil {
user, nonFatal, err := db.UserFromGQLContext(ctx)
if err != nil && !nonFatal {
return nil, err
}

Expand All @@ -167,7 +167,7 @@ func getArgument(ctx context.Context, key string) interface{} {
}

func canApproveMods(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -180,7 +180,7 @@ func canApproveMods(ctx context.Context, _ interface{}, next graphql.Resolver) (
}

func canApproveVersions(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -193,7 +193,7 @@ func canApproveVersions(ctx context.Context, _ interface{}, next graphql.Resolve
}

func canEditUsers(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -206,7 +206,7 @@ func canEditUsers(ctx context.Context, _ interface{}, next graphql.Resolver) (in
}

func canEditSMLVersions(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -219,7 +219,7 @@ func canEditSMLVersions(ctx context.Context, _ interface{}, next graphql.Resolve
}

func canEditBootstrapVersions(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -232,7 +232,7 @@ func canEditBootstrapVersions(ctx context.Context, _ interface{}, next graphql.R
}

func canEditAnnouncements(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -245,7 +245,7 @@ func canEditAnnouncements(ctx context.Context, _ interface{}, next graphql.Resol
}

func canManageTags(ctx context.Context, _ interface{}, next graphql.Resolver) (interface{}, error) {
user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion gql/resolver_guides.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (r *mutationResolver) CreateGuide(ctx context.Context, g generated.NewGuide
return nil, fmt.Errorf("validation failed: %w", err)
}

user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand Down
23 changes: 12 additions & 11 deletions gql/resolver_mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (r *mutationResolver) CreateMod(ctx context.Context, newMod generated.NewMo
SetINNF(newMod.FullDescription, dbMod.SetFullDescription)
SetINNF(newMod.Hidden, dbMod.SetHidden)

user, err := db.UserFromGQLContext(ctx)
user, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -347,7 +347,7 @@ func (r *queryResolver) GetMod(ctx context.Context, modID string) (*generated.Mo
wrapper, ctx := WrapQueryTrace(ctx, "getMod")
defer wrapper.end()

dbMod, err := db.From(ctx).Mod.Get(ctx, modID)
dbMod, err := db.From(ctx).Mod.Query().Where(mod.ID(modID)).WithTags().First(ctx)
if err != nil {
return nil, err
}
Expand All @@ -367,7 +367,7 @@ func (r *queryResolver) GetModByReference(ctx context.Context, modReference stri
wrapper, ctx := WrapQueryTrace(ctx, "getModByReference")
defer wrapper.end()

dbMod, err := db.From(ctx).Mod.Query().Where(mod.ModReference(modReference)).First(ctx)
dbMod, err := db.From(ctx).Mod.Query().Where(mod.ModReference(modReference)).WithTags().First(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -448,7 +448,7 @@ func (r *getModsResolver) Count(ctx context.Context, _ *generated.GetMods) (int,
return 0, err
}

query := db.From(ctx).Debug().Mod.Query()
query := db.From(ctx).Mod.Query()
query = convertModFilter(query, modFilter, false, unapproved)

result, err := query.Count(ctx)
Expand Down Expand Up @@ -477,7 +477,7 @@ func (r *getMyModsResolver) Mods(ctx context.Context, _ *generated.GetMyMods) ([
modFilter.AddField(field.Name)
}

query := db.From(ctx).Debug().Mod.Query()
query := db.From(ctx).Mod.Query()
query = convertModFilter(query, modFilter, false, unapproved)

result, err := query.All(ctx)
Expand All @@ -500,7 +500,7 @@ func (r *getMyModsResolver) Count(ctx context.Context, _ *generated.GetMyMods) (
return 0, err
}

query := db.From(ctx).Debug().Mod.Query()
query := db.From(ctx).Mod.Query()
query = convertModFilter(query, modFilter, false, unapproved)

result, err := query.Count(ctx)
Expand Down Expand Up @@ -677,7 +677,7 @@ func (r *queryResolver) GetModByIDOrReference(ctx context.Context, modIDOrRefere
wrapper, ctx := WrapQueryTrace(ctx, "getModByIdOrReference")
defer wrapper.end()

m, err := db.From(ctx).Mod.Query().Where(mod.Or(
m, err := db.From(ctx).Mod.Query().WithTags().Where(mod.Or(
mod.ID(modIDOrReference),
mod.ModReference(modIDOrReference),
)).First(ctx)
Expand Down Expand Up @@ -738,6 +738,8 @@ func (r *queryResolver) ResolveModVersions(ctx context.Context, filter []*genera
}

func convertModFilter(query *ent.ModQuery, filter *models.ModFilter, count bool, unapproved bool) *ent.ModQuery {
query = query.WithTags()

if len(filter.Ids) > 0 {
query = query.Where(mod.IDIn(filter.Ids...))
} else if len(filter.References) > 0 {
Expand All @@ -749,10 +751,9 @@ func convertModFilter(query *ent.ModQuery, filter *models.ModFilter, count bool,

if *filter.OrderBy != generated.ModFieldsSearch {
if string(*filter.OrderBy) == "last_version_date" {
query = query.Order(sql.OrderByField(
"case when last_version_date is null then 1 else 0 end, last_version_date",
db.OrderToOrder(filter.Order.String()),
).ToFunc())
query = query.Modify(func(s *sql.Selector) {
s.OrderExpr(sql.ExprP("case when last_version_date is null then 1 else 0 end, last_version_date"))
}).Clone()
} else {
query = query.Order(sql.OrderByField(
filter.OrderBy.String(),
Expand Down
8 changes: 4 additions & 4 deletions gql/resolver_users.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (r *queryResolver) GetMe(ctx context.Context) (*generated.User, error) {
wrapper, ctx := WrapQueryTrace(ctx, "getMe")
defer wrapper.end()

result, err := db.UserFromGQLContext(ctx)
result, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -233,7 +233,7 @@ func (r *userResolver) Groups(ctx context.Context, _ *generated.User) ([]*genera
wrapper, ctx := WrapQueryTrace(ctx, "User.guides")
defer wrapper.end()

u, err := db.UserFromGQLContext(ctx)
u, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand All @@ -259,7 +259,7 @@ func (r *userResolver) Roles(ctx context.Context, _ *generated.User) (*generated
wrapper, ctx := WrapQueryTrace(ctx, "User.guides")
defer wrapper.end()

u, err := db.UserFromGQLContext(ctx)
u, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -357,7 +357,7 @@ func (r *mutationResolver) DiscourseSso(ctx context.Context, sso string, sig str
return nil, fmt.Errorf("failed to decode sso: %w", err)
}

u, err := db.UserFromGQLContext(ctx)
u, _, err := db.UserFromGQLContext(ctx)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion tests/announcements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestAnnouncements(t *testing.T) {

// Run Twice to detect any cache issues
for i := 0; i < 2; i++ {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run("Loop"+strconv.Itoa(i), func(t *testing.T) {
var announcementID string

t.Run("Create", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion tests/guides_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestGuides(t *testing.T) {

// Run Twice to detect any cache issues
for i := 0; i < 2; i++ {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run("Loop"+strconv.Itoa(i), func(t *testing.T) {
var guideID string

t.Run("Create", func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion tests/mod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestMods(t *testing.T) {

// Run Twice to detect any cache issues
for i := 0; i < 2; i++ {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run("Loop"+strconv.Itoa(i), func(t *testing.T) {
var objID string

modReference := "hello" + strconv.Itoa(i)
Expand Down
2 changes: 1 addition & 1 deletion tests/sml_versions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestSMLVersions(t *testing.T) {

// Run Twice to detect any cache issues
for i := 0; i < 2; i++ {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run("Loop"+strconv.Itoa(i), func(t *testing.T) {
var objID string

t.Run("Create", func(t *testing.T) {
Expand Down

0 comments on commit ce05441

Please sign in to comment.