diff --git a/middleware.go b/middleware.go index f46e526d..c14e2e16 100644 --- a/middleware.go +++ b/middleware.go @@ -147,6 +147,7 @@ func (m *validateRequestMiddleware) Handle(next Handler) Handler { var errors []error if m.QueryRules != nil { opt := &validation.Options{ + Context: r.Context(), // TODO test context Data: r.Query, Rules: m.QueryRules(r).AsRules(), ConvertSingleValueArrays: true, @@ -168,6 +169,7 @@ func (m *validateRequestMiddleware) Handle(next Handler) Handler { } if m.BodyRules != nil { opt := &validation.Options{ + Context: r.Context(), Data: r.Data, Rules: m.BodyRules(r).AsRules(), ConvertSingleValueArrays: !strings.HasPrefix(contentType, "application/json"), diff --git a/middleware_test.go b/middleware_test.go index 29986388..d5c450e4 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -2,6 +2,7 @@ package goyave import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -22,6 +23,8 @@ import ( _ "goyave.dev/goyave/v5/database/dialect/sqlite" ) +type testCtxKey struct{} + func TestMiddlewareHolder(t *testing.T) { m1 := &recoveryMiddleware{} m2 := &languageMiddleware{} @@ -317,10 +320,11 @@ func TestValidateMiddleware(t *testing.T) { return validation.RuleSet{{Path: "param", Rules: validation.List{validation.Required(), &testValidator{ validateFunc: func(v *testValidator, ctx *validation.Context) bool { assert.Equal(t, request, ctx.Extra[validation.ExtraRequest{}]) - assert.NotNil(t, v.DB()) assert.NotNil(t, v.Config()) + assert.NotNil(t, v.DB()) assert.NotNil(t, v.Logger()) assert.NotNil(t, v.Lang()) + assert.Equal(t, "test-value", ctx.Context.Value(testCtxKey{})) return false }, }}}} @@ -396,6 +400,7 @@ func TestValidateMiddleware(t *testing.T) { assert.NotNil(t, v.DB()) assert.NotNil(t, v.Logger()) assert.NotNil(t, v.Lang()) + assert.Equal(t, "test-value", ctx.Context.Value(testCtxKey{})) return false }, }}}} @@ -501,6 +506,7 @@ func TestValidateMiddleware(t *testing.T) { m.Init(server) request := NewRequest(httptest.NewRequest(http.MethodGet, "/test", nil)) + request.WithContext(context.WithValue(request.Context(), testCtxKey{}, "test-value")) request.Lang = server.Lang.GetDefault() request.Query = c.query request.Data = c.data diff --git a/validation/validator.go b/validation/validator.go index 8c031309..c10a2918 100644 --- a/validation/validator.go +++ b/validation/validator.go @@ -1,6 +1,7 @@ package validation import ( + "context" "reflect" "strings" "time" @@ -101,8 +102,10 @@ func (c *component) Logger() *slog.Logger { // Only `Data`, `Rules` and `Language` are mandatory. However, it is recommended // to provide values for all the options in case a `Validator` requires them to function. type Options struct { - Data any - Rules Ruler + // Context defaults to `context.Background()` if not provided. + Context context.Context + Data any + Rules Ruler Now time.Time @@ -150,7 +153,9 @@ type AddedValidationError[T addedValidationErrorConstraint] struct { // Context is a structure unique per `Validator.Validate()` execution containing // all the data required by a validator. type Context struct { - Data any + // Context is never nil. Defaults to `context.Background()`. Readonly. + Context context.Context + Data any // Extra the map of Extra from the validation Options. Extra map[any]any @@ -281,6 +286,9 @@ func Validate(options *Options) (*Errors, []error) { if options.Language == nil { options.Language = lang.Default } + if options.Context == nil { + options.Context = context.Background() + } rules := options.Rules.AsRules() for _, field := range rules { @@ -364,6 +372,7 @@ func (v *validator) validateField(fieldName string, field *Field, walkData any, errorPath := field.getErrorPath(parentPath, c) ctx := &Context{ + Context: v.options.Context, Data: data, Extra: v.options.Extra, Value: value, diff --git a/validation/validator_test.go b/validation/validator_test.go index e33b9e9d..ec2f6ae6 100644 --- a/validation/validator_test.go +++ b/validation/validator_test.go @@ -2,12 +2,14 @@ package validation import ( "bytes" + "context" "fmt" "testing" "time" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gorm.io/gorm" "goyave.dev/goyave/v5/config" "goyave.dev/goyave/v5/lang" @@ -980,3 +982,61 @@ func (v addErrorsValidator) Validate(ctx *Context) bool { } return true } + +func TestValidateWithContext(t *testing.T) { + cases := []struct { + ctx context.Context + expect func(*testing.T, context.Context) bool + desc string + }{ + { + desc: "default_background", + ctx: nil, + expect: func(t *testing.T, ctx context.Context) bool { + return assert.Equal(t, context.Background(), ctx) + }, + }, + { + desc: "custom", + ctx: context.WithValue(context.Background(), testCtxKey{}, "test-value"), + expect: func(t *testing.T, ctx context.Context) bool { + return assert.Equal(t, "test-value", ctx.Value(testCtxKey{})) + }, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + opts := &Options{ + Data: map[string]any{}, + Context: c.ctx, + Rules: RuleSet{ + { + Path: CurrentElement, + Rules: List{&ctxValidator{t: t, expect: c.expect}}, + }, + }, + } + + validationErrors, errs := Validate(opts) + require.Nil(t, errs) + require.Nil(t, validationErrors) + }) + } +} + +type testCtxKey struct{} + +type ctxValidator struct { + BaseValidator + t *testing.T + expect func(*testing.T, context.Context) bool +} + +func (ctxValidator) Name() string { + return "ctxValidator" +} + +func (v ctxValidator) Validate(ctx *Context) bool { + return v.expect(v.t, ctx.Context) +}