Skip to content

Commit

Permalink
domain: Optimize GetDomain api (#58550)
Browse files Browse the repository at this point in the history
ref #56649
  • Loading branch information
crazycs520 authored Dec 28, 2024
1 parent be6396c commit cdcc291
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 35 deletions.
23 changes: 4 additions & 19 deletions pkg/domain/domainctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,11 @@ import (
contextutil "github.com/pingcap/tidb/pkg/util/context"
)

// domainKeyType is a dummy type to avoid naming collision in context.
type domainKeyType int

// String defines a Stringer function for debugging and pretty printing.
func (domainKeyType) String() string {
return "domain"
}

const domainKey domainKeyType = 0

// BindDomain binds domain to context.
func BindDomain(ctx contextutil.ValueStoreContext, domain *Domain) {
ctx.SetValue(domainKey, domain)
}

// GetDomain gets domain from context.
func GetDomain(ctx contextutil.ValueStoreContext) *Domain {
v, ok := ctx.Value(domainKey).(*Domain)
if !ok {
return nil
v, ok := ctx.GetDomain().(*Domain)
if ok {
return v
}
return v
return nil
}
8 changes: 3 additions & 5 deletions pkg/domain/domainctx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ import (

func TestDomainCtx(t *testing.T) {
ctx := mock.NewContext()
require.NotEqual(t, "", domainKey.String())

BindDomain(ctx, nil)
ctx.BindDomain(nil)
v := GetDomain(ctx)
require.Nil(t, v)

ctx.ClearValue(domainKey)
ctx.BindDomain(&Domain{})
v = GetDomain(ctx)
require.Nil(t, v)
require.NotNil(t, v)
}
2 changes: 1 addition & 1 deletion pkg/executor/executor_pkg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func TestFilterTemporaryTableKeys(t *testing.T) {

func TestErrLevelsForResetStmtContext(t *testing.T) {
ctx := mock.NewContext()
domain.BindDomain(ctx, &domain.Domain{})
ctx.BindDomain(&domain.Domain{})

cases := []struct {
name string
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func defaultCtx() sessionctx.Context {
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, ctx.GetSessionVars().MemQuotaQuery)
ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1)
ctx.GetSessionVars().SnapshotTS = uint64(1)
domain.BindDomain(ctx, domain.NewMockDomain())
ctx.BindDomain(domain.NewMockDomain())
return ctx
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/join/joiner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func defaultCtx() sessionctx.Context {
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(-1, ctx.GetSessionVars().MemQuotaQuery)
ctx.GetSessionVars().StmtCtx.DiskTracker = disk.NewTracker(-1, -1)
ctx.GetSessionVars().SnapshotTS = uint64(1)
domain.BindDomain(ctx, domain.NewMockDomain())
ctx.BindDomain(domain.NewMockDomain())
return ctx
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
func BenchmarkResetContextOfStmt(b *testing.B) {
stmt := &ast.SelectStmt{}
ctx := mock.NewContext()
domain.BindDomain(ctx, &domain.Domain{})
ctx.BindDomain(&domain.Domain{})
for i := 0; i < b.N; i++ {
executor.ResetContextOfStmt(ctx, stmt)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/logical_plans_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func createPlannerSuite() (s *plannerSuite) {
if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil {
panic(fmt.Sprintf("create mock context panic: %+v", err))
}
domain.BindDomain(ctx, do)
ctx.BindDomain(do)
ctx.SetInfoSchema(s.is)
s.ctx = ctx
s.sctx = ctx
Expand Down
2 changes: 1 addition & 1 deletion pkg/planner/core/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func MockContext() *mock.Context {
if err := do.CreateStatsHandle(ctx, initStatsCtx); err != nil {
panic(fmt.Sprintf("create mock context panic: %+v", err))
}
domain.BindDomain(ctx, do)
ctx.BindDomain(do)
return ctx
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/session/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestBootstrapWithError(t *testing.T) {
dom, err := domap.Get(store)
require.NoError(t, err)
require.NoError(t, dom.Start(ddl.Bootstrap))
domain.BindDomain(se, dom)
se.dom = dom
b, err := checkBootstrapped(se)
require.False(t, b)
require.NoError(t, err)
Expand Down Expand Up @@ -2409,7 +2409,7 @@ func TestTiDBUpgradeToVer211(t *testing.T) {
require.NoError(t, err)
require.Less(t, int64(ver210), ver)

domain.BindDomain(seV210, dom)
seV210.(*session).dom = dom
r := MustExecToRecodeSet(t, seV210, "select count(summary) from mysql.tidb_background_subtask_history;")
req := r.NewChunk(nil)
err = r.Next(ctx, req)
Expand Down
11 changes: 9 additions & 2 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ type session struct {
currentCtx context.Context // only use for runtime.trace, Please NEVER use it.
currentPlan base.Plan

// dom is *domain.Domain, use `any` to avoid import cycle.
dom any
store kv.Storage

sessionPlanCache sessionctx.SessionPlanCache
Expand Down Expand Up @@ -3789,6 +3791,7 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) {
return nil, err
}
s := &session{
dom: dom,
store: store,
ddlOwnerManager: dom.DDL().OwnerManager(),
client: store.GetClient(),
Expand All @@ -3808,7 +3811,6 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) {
s.lockedTables = make(map[int64]model.TableLockTpInfo)
s.advisoryLocks = make(map[string]*advisoryLock)

domain.BindDomain(s, dom)
// session implements variable.GlobalVarAccessor. Bind it to ctx.
s.sessionVars.GlobalVarsAccessor = s
s.txn.init()
Expand Down Expand Up @@ -3852,6 +3854,7 @@ func detachStatsCollector(s *session) *session {
// a lock context, which cause we can't call createSession directly.
func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) {
s := &session{
dom: dom,
store: store,
sessionVars: variable.NewSessionVars(nil),
client: store.GetClient(),
Expand All @@ -3864,7 +3867,6 @@ func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er
s.tblctx = tblsession.NewMutateContext(s)
s.mu.values = make(map[fmt.Stringer]any)
s.lockedTables = make(map[int64]model.TableLockTpInfo)
domain.BindDomain(s, dom)
// session implements variable.GlobalVarAccessor. Bind it to ctx.
s.sessionVars.GlobalVarsAccessor = s
s.txn.init()
Expand Down Expand Up @@ -4674,3 +4676,8 @@ func (s *session) GetCursorTracker() cursor.Tracker {
func (s *session) GetCommitWaitGroup() *sync.WaitGroup {
return &s.commitWaitGroup
}

// GetDomain get domain from session.
func (s *session) GetDomain() any {
return s.dom
}
3 changes: 3 additions & 0 deletions pkg/util/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type ValueStoreContext interface {

// ClearValue clears the value associated with this context for key.
ClearValue(key fmt.Stringer)

// GetDomain returns the domain.
GetDomain() any
}

var contextIDGenerator atomic.Uint64
Expand Down
13 changes: 12 additions & 1 deletion pkg/util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ var (
type Context struct {
planctx.EmptyPlanContextExtended
*sessionexpr.ExprContext
txn wrapTxn // mock global variable
txn wrapTxn // mock global variable
dom any
Store kv.Storage // mock global variable
ctx context.Context
sm util.SessionManager
Expand Down Expand Up @@ -639,6 +640,16 @@ func (*Context) GetCommitWaitGroup() *sync.WaitGroup {
return nil
}

// BindDomain bind domain into ctx.
func (c *Context) BindDomain(dom any) {
c.dom = dom
}

// GetDomain get domain from ctx.
func (c *Context) GetDomain() any {
return c.dom
}

// NewContextDeprecated creates a new mocked sessionctx.Context.
// Deprecated: This method is only used for some legacy code.
// DO NOT use mock.Context in new production code, and use the real Context instead.
Expand Down

0 comments on commit cdcc291

Please sign in to comment.