diff --git a/exp/zapslog/handler.go b/exp/zapslog/handler.go index 982d9bccd..d9d800951 100644 --- a/exp/zapslog/handler.go +++ b/exp/zapslog/handler.go @@ -32,13 +32,17 @@ import ( "go.uber.org/zap/zapcore" ) +// ContextFieldExtractor can be used to extract a field from a context. +type ContextFieldExtractor func(ctx context.Context) []zapcore.Field + // Handler implements the slog.Handler by writing to a zap Core. type Handler struct { - core zapcore.Core - name string // logger name - addCaller bool - addStackAt slog.Level - callerSkip int + core zapcore.Core + name string // logger name + addCaller bool + addStackAt slog.Level + callerSkip int + contextFieldExtractors []ContextFieldExtractor } // NewHandler builds a [Handler] that writes to the supplied [zapcore.Core] @@ -157,7 +161,14 @@ func (h *Handler) Handle(ctx context.Context, record slog.Record) error { ce.Stack = stacktrace.Take(3 + h.callerSkip) } - fields := make([]zapcore.Field, 0, record.NumAttrs()) + var contextFields []zapcore.Field + for _, extractor := range h.contextFieldExtractors { + contextFields = append(contextFields, extractor(ctx)...) + } + + fields := make([]zapcore.Field, 0, record.NumAttrs()+len(contextFields)) + fields = append(fields, contextFields...) + record.Attrs(func(attr slog.Attr) bool { fields = append(fields, convertAttrToField(attr)) return true diff --git a/exp/zapslog/handler_test.go b/exp/zapslog/handler_test.go index 68339df62..487129a6d 100644 --- a/exp/zapslog/handler_test.go +++ b/exp/zapslog/handler_test.go @@ -23,16 +23,20 @@ package zapslog import ( + "context" "log/slog" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" ) +type testContextKey string + func TestAddCaller(t *testing.T) { t.Parallel() @@ -189,3 +193,19 @@ func TestAttrKinds(t *testing.T) { }, entry.ContextMap()) } + +func TestContextFieldExtractor(t *testing.T) { + key := testContextKey("testkey") + fac, logs := observer.New(zapcore.DebugLevel) + ctx := context.WithValue(context.Background(), key, "testvalue") + + sl := slog.New(NewHandler(fac, WithContextFieldExtractors(func(ctx context.Context) []zapcore.Field { + v := ctx.Value(key).(string) + return []zapcore.Field{zap.String("testkey", v)} + }))) + sl.InfoContext(ctx, "msg") + lines := logs.TakeAll() + + require.Len(t, lines, 1) + require.Equal(t, "testvalue", lines[0].ContextMap()["testkey"]) +} diff --git a/exp/zapslog/options.go b/exp/zapslog/options.go index 0eb5c8c0e..859524213 100644 --- a/exp/zapslog/options.go +++ b/exp/zapslog/options.go @@ -70,3 +70,10 @@ func AddStacktraceAt(lvl slog.Level) Option { log.addStackAt = lvl }) } + +// .WithContextFieldExtractors configures the Logger to extract fields from the context. +func WithContextFieldExtractors(extractors ...ContextFieldExtractor) Option { + return optionFunc(func(log *Handler) { + log.contextFieldExtractors = append(log.contextFieldExtractors, extractors...) + }) +}