From cdcc291f254e6488d6665063e08f2861fe2447b1 Mon Sep 17 00:00:00 2001 From: crazycs Date: Sun, 29 Dec 2024 01:33:42 +0800 Subject: [PATCH] domain: Optimize GetDomain api (#58550) ref pingcap/tidb#56649 --- pkg/domain/domainctx.go | 23 ++++----------------- pkg/domain/domainctx_test.go | 8 +++---- pkg/executor/executor_pkg_test.go | 2 +- pkg/executor/executor_required_rows_test.go | 2 +- pkg/executor/join/joiner_test.go | 2 +- pkg/executor/select_test.go | 2 +- pkg/planner/core/logical_plans_test.go | 2 +- pkg/planner/core/mock.go | 2 +- pkg/session/bootstrap_test.go | 4 ++-- pkg/session/session.go | 11 ++++++++-- pkg/util/context/context.go | 3 +++ pkg/util/mock/context.go | 13 +++++++++++- 12 files changed, 39 insertions(+), 35 deletions(-) diff --git a/pkg/domain/domainctx.go b/pkg/domain/domainctx.go index 55e96cd48d920..18ec7769ee125 100644 --- a/pkg/domain/domainctx.go +++ b/pkg/domain/domainctx.go @@ -18,26 +18,11 @@ import ( contextutil "github.com/pingcap/tidb/pkg/util/context" ) -// domainKeyType is a dummy type to avoid naming collision in context. -type domainKeyType int - -// String defines a Stringer function for debugging and pretty printing. -func (domainKeyType) String() string { - return "domain" -} - -const domainKey domainKeyType = 0 - -// BindDomain binds domain to context. -func BindDomain(ctx contextutil.ValueStoreContext, domain *Domain) { - ctx.SetValue(domainKey, domain) -} - // GetDomain gets domain from context. func GetDomain(ctx contextutil.ValueStoreContext) *Domain { - v, ok := ctx.Value(domainKey).(*Domain) - if !ok { - return nil + v, ok := ctx.GetDomain().(*Domain) + if ok { + return v } - return v + return nil } diff --git a/pkg/domain/domainctx_test.go b/pkg/domain/domainctx_test.go index 0efd628665169..21d1963cbbaef 100644 --- a/pkg/domain/domainctx_test.go +++ b/pkg/domain/domainctx_test.go @@ -23,13 +23,11 @@ import ( func TestDomainCtx(t *testing.T) { ctx := mock.NewContext() - require.NotEqual(t, "", domainKey.String()) - - BindDomain(ctx, nil) + ctx.BindDomain(nil) v := GetDomain(ctx) require.Nil(t, v) - ctx.ClearValue(domainKey) + ctx.BindDomain(&Domain{}) v = GetDomain(ctx) - require.Nil(t, v) + require.NotNil(t, v) } diff --git a/pkg/executor/executor_pkg_test.go b/pkg/executor/executor_pkg_test.go index ec0c15fa7ea57..6582d91c7776a 100644 --- a/pkg/executor/executor_pkg_test.go +++ b/pkg/executor/executor_pkg_test.go @@ -332,7 +332,7 @@ func TestFilterTemporaryTableKeys(t *testing.T) { func TestErrLevelsForResetStmtContext(t *testing.T) { ctx := mock.NewContext() - domain.BindDomain(ctx, &domain.Domain{}) + ctx.BindDomain(&domain.Domain{}) cases := []struct { name string diff --git a/pkg/executor/executor_required_rows_test.go b/pkg/executor/executor_required_rows_test.go index a81dda7e1512a..ec32ea1f6cd7f 100644 --- a/pkg/executor/executor_required_rows_test.go +++ b/pkg/executor/executor_required_rows_test.go @@ -215,7 +215,7 @@ func defaultCtx() sessionctx.Context { ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, ctx.GetSessionVars().MemQuotaQuery) ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1) ctx.GetSessionVars().SnapshotTS = uint64(1) - domain.BindDomain(ctx, domain.NewMockDomain()) + ctx.BindDomain(domain.NewMockDomain()) return ctx } diff --git a/pkg/executor/join/joiner_test.go b/pkg/executor/join/joiner_test.go index 796f1b17ec398..92a3aab004deb 100644 --- a/pkg/executor/join/joiner_test.go +++ b/pkg/executor/join/joiner_test.go @@ -38,7 +38,7 @@ func defaultCtx() sessionctx.Context { ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, ctx.GetSessionVars().MemQuotaQuery) ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1) ctx.GetSessionVars().SnapshotTS = uint64(1) - domain.BindDomain(ctx, domain.NewMockDomain()) + ctx.BindDomain(domain.NewMockDomain()) return ctx } diff --git a/pkg/executor/select_test.go b/pkg/executor/select_test.go index faa342c3b262b..c6dd909bebfe6 100644 --- a/pkg/executor/select_test.go +++ b/pkg/executor/select_test.go @@ -26,7 +26,7 @@ import ( func BenchmarkResetContextOfStmt(b *testing.B) { stmt := &ast.SelectStmt{} ctx := mock.NewContext() - domain.BindDomain(ctx, &domain.Domain{}) + ctx.BindDomain(&domain.Domain{}) for i := 0; i < b.N; i++ { executor.ResetContextOfStmt(ctx, stmt) } diff --git a/pkg/planner/core/logical_plans_test.go b/pkg/planner/core/logical_plans_test.go index b10abfd47fd49..1df812f693159 100644 --- a/pkg/planner/core/logical_plans_test.go +++ b/pkg/planner/core/logical_plans_test.go @@ -118,7 +118,7 @@ func createPlannerSuite() (s *plannerSuite) { if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil { panic(fmt.Sprintf("create mock context panic: %+v", err)) } - domain.BindDomain(ctx, do) + ctx.BindDomain(do) ctx.SetInfoSchema(s.is) s.ctx = ctx s.sctx = ctx diff --git a/pkg/planner/core/mock.go b/pkg/planner/core/mock.go index 8fee1721bc116..122741adc8c6a 100644 --- a/pkg/planner/core/mock.go +++ b/pkg/planner/core/mock.go @@ -429,7 +429,7 @@ func MockContext() *mock.Context { if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil { panic(fmt.Sprintf("create mock context panic: %+v", err)) } - domain.BindDomain(ctx, do) + ctx.BindDomain(do) return ctx } diff --git a/pkg/session/bootstrap_test.go b/pkg/session/bootstrap_test.go index db58b3c5f11b2..7bf76aebf7b31 100644 --- a/pkg/session/bootstrap_test.go +++ b/pkg/session/bootstrap_test.go @@ -178,7 +178,7 @@ func TestBootstrapWithError(t *testing.T) { dom, err := domap.Get(store) require.NoError(t, err) require.NoError(t, dom.Start(ddl.Bootstrap)) - domain.BindDomain(se, dom) + se.dom = dom b, err := checkBootstrapped(se) require.False(t, b) require.NoError(t, err) @@ -2409,7 +2409,7 @@ func TestTiDBUpgradeToVer211(t *testing.T) { require.NoError(t, err) require.Less(t, int64(ver210), ver) - domain.BindDomain(seV210, dom) + seV210.(*session).dom = dom r := MustExecToRecodeSet(t, seV210, "select count(summary) from mysql.tidb_background_subtask_history;") req := r.NewChunk(nil) err = r.Next(ctx, req) diff --git a/pkg/session/session.go b/pkg/session/session.go index 8fe917b358f43..2d173bb2dc2b8 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -182,6 +182,8 @@ type session struct { currentCtx context.Context // only use for runtime.trace, Please NEVER use it. currentPlan base.Plan + // dom is *domain.Domain, use `any` to avoid import cycle. + dom any store kv.Storage sessionPlanCache sessionctx.SessionPlanCache @@ -3789,6 +3791,7 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { return nil, err } s := &session{ + dom: dom, store: store, ddlOwnerManager: dom.DDL().OwnerManager(), client: store.GetClient(), @@ -3808,7 +3811,6 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { s.lockedTables = make(map[int64]model.TableLockTpInfo) s.advisoryLocks = make(map[string]*advisoryLock) - domain.BindDomain(s, dom) // session implements variable.GlobalVarAccessor. Bind it to ctx. s.sessionVars.GlobalVarsAccessor = s s.txn.init() @@ -3852,6 +3854,7 @@ func detachStatsCollector(s *session) *session { // a lock context, which cause we can't call createSession directly. func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) { s := &session{ + dom: dom, store: store, sessionVars: variable.NewSessionVars(nil), client: store.GetClient(), @@ -3864,7 +3867,6 @@ func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er s.tblctx = tblsession.NewMutateContext(s) s.mu.values = make(map[fmt.Stringer]any) s.lockedTables = make(map[int64]model.TableLockTpInfo) - domain.BindDomain(s, dom) // session implements variable.GlobalVarAccessor. Bind it to ctx. s.sessionVars.GlobalVarsAccessor = s s.txn.init() @@ -4674,3 +4676,8 @@ func (s *session) GetCursorTracker() cursor.Tracker { func (s *session) GetCommitWaitGroup() *sync.WaitGroup { return &s.commitWaitGroup } + +// GetDomain get domain from session. +func (s *session) GetDomain() any { + return s.dom +} diff --git a/pkg/util/context/context.go b/pkg/util/context/context.go index 8c265c43cde90..6d0ea1d6f7ca9 100644 --- a/pkg/util/context/context.go +++ b/pkg/util/context/context.go @@ -29,6 +29,9 @@ type ValueStoreContext interface { // ClearValue clears the value associated with this context for key. ClearValue(key fmt.Stringer) + + // GetDomain returns the domain. + GetDomain() any } var contextIDGenerator atomic.Uint64 diff --git a/pkg/util/mock/context.go b/pkg/util/mock/context.go index ba79e993754d7..14ed7fc40e911 100644 --- a/pkg/util/mock/context.go +++ b/pkg/util/mock/context.go @@ -66,7 +66,8 @@ var ( type Context struct { planctx.EmptyPlanContextExtended *sessionexpr.ExprContext - txn wrapTxn // mock global variable + txn wrapTxn // mock global variable + dom any Store kv.Storage // mock global variable ctx context.Context sm util.SessionManager @@ -639,6 +640,16 @@ func (*Context) GetCommitWaitGroup() *sync.WaitGroup { return nil } +// BindDomain bind domain into ctx. +func (c *Context) BindDomain(dom any) { + c.dom = dom +} + +// GetDomain get domain from ctx. +func (c *Context) GetDomain() any { + return c.dom +} + // NewContextDeprecated creates a new mocked sessionctx.Context. // Deprecated: This method is only used for some legacy code. // DO NOT use mock.Context in new production code, and use the real Context instead.