Skip to content

Commit

Permalink
fix show subscriptions (#1.0-dev) (#13599)
Browse files Browse the repository at this point in the history
fix show subscriptions

Approved by: @iamlinjunhong, @qingxinhome, @heni02, @nnsgmsone, @ouyuanning, @sukki37
  • Loading branch information
YANGGMM authored Dec 21, 2023
1 parent 9d710a8 commit c758fc0
Show file tree
Hide file tree
Showing 20 changed files with 957 additions and 632 deletions.
5 changes: 4 additions & 1 deletion pkg/frontend/compiler_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,10 @@ func (tcc *TxnCompilerContext) ResolveVariable(varName string, isSystemVar, isGl
}
} else {
_, val, err := tcc.GetSession().GetUserDefinedVar(varName)
return val, err
if val == nil {
return nil, err
}
return val.Value, err
}
}

Expand Down
61 changes: 40 additions & 21 deletions pkg/frontend/mysql_cmd_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,19 @@ var RecordStatement = func(ctx context.Context, ses *Session, proc *process.Proc
if cw != nil {
copy(stmID[:], cw.GetUUID())
statement = cw.GetAst()

ses.ast = statement
text = SubStringFromBegin(envStmt, int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))

execSql := makeExecuteSql(ses, statement)
if len(execSql) != 0 {
bb := strings.Builder{}
bb.WriteString(envStmt)
bb.WriteString(" // ")
bb.WriteString(execSql)
text = SubStringFromBegin(bb.String(), int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))
} else {
text = SubStringFromBegin(envStmt, int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))
}
} else {
stmID = uuid.New()
text = SubStringFromBegin(envStmt, int(ses.GetParameterUnit().SV.LengthOfQueryPrinted))
Expand Down Expand Up @@ -554,7 +565,11 @@ func (mce *MysqlCmdExecutor) handleSelectVariables(ve *tree.VarExpr, cwIndex, cw
if err != nil {
return err
}
row[0] = val
if val != nil {
row[0] = val.Value
} else {
row[0] = nil
}
}

mrs.AddRow(row)
Expand Down Expand Up @@ -671,10 +686,10 @@ func (mce *MysqlCmdExecutor) handleCmdFieldList(requestCtx context.Context, icfl
return err
}

func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree.SetVar) error {
func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree.SetVar, sql string) error {
var err error = nil
var ok bool
setVarFunc := func(system, global bool, name string, value interface{}) error {
setVarFunc := func(system, global bool, name string, value interface{}, sql string) error {
if system {
if global {
err = doCheckRole(ctx, ses)
Expand Down Expand Up @@ -707,7 +722,7 @@ func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree
}
}
} else {
err = ses.SetUserDefinedVar(name, value)
err = ses.SetUserDefinedVar(name, value, sql)
if err != nil {
return err
}
Expand Down Expand Up @@ -737,7 +752,7 @@ func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree
"character_set_client", "character_set_connection", "character_set_results",
}
for _, rb := range replacedBy {
err = setVarFunc(assign.System, assign.Global, rb, value)
err = setVarFunc(assign.System, assign.Global, rb, value, sql)
if err != nil {
return err
}
Expand All @@ -746,7 +761,7 @@ func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree
if !ses.GetTenantInfo().IsSysTenant() {
return moerr.NewInternalError(ses.GetRequestContext(), "only system account can set system variable syspublications")
}
err = setVarFunc(assign.System, assign.Global, name, value)
err = setVarFunc(assign.System, assign.Global, name, value, sql)
if err != nil {
return err
}
Expand All @@ -765,7 +780,7 @@ func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree
cache.invalidate()
}
}
err = setVarFunc(assign.System, assign.Global, name, value)
err = setVarFunc(assign.System, assign.Global, name, value, sql)
if err != nil {
return err
}
Expand All @@ -783,24 +798,24 @@ func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree
cache.invalidate()
}
}
err = setVarFunc(assign.System, assign.Global, name, value)
err = setVarFunc(assign.System, assign.Global, name, value, sql)
if err != nil {
return err
}
} else if name == "runtime_filter_limit_in" {
err = setVarFunc(assign.System, assign.Global, name, value)
err = setVarFunc(assign.System, assign.Global, name, value, sql)
if err != nil {
return err
}
runtime.ProcessLevelRuntime().SetGlobalVariables("runtime_filter_limit_in", value)
} else if name == "runtime_filter_limit_bloom_filter" {
err = setVarFunc(assign.System, assign.Global, name, value)
err = setVarFunc(assign.System, assign.Global, name, value, sql)
if err != nil {
return err
}
runtime.ProcessLevelRuntime().SetGlobalVariables("runtime_filter_limit_bloom_filter", value)
} else {
err = setVarFunc(assign.System, assign.Global, name, value)
err = setVarFunc(assign.System, assign.Global, name, value, sql)
if err != nil {
return err
}
Expand All @@ -812,9 +827,9 @@ func doSetVar(ctx context.Context, mce *MysqlCmdExecutor, ses *Session, sv *tree
/*
handle setvar
*/
func (mce *MysqlCmdExecutor) handleSetVar(ctx context.Context, sv *tree.SetVar) error {
func (mce *MysqlCmdExecutor) handleSetVar(ctx context.Context, sv *tree.SetVar, sql string) error {
ses := mce.GetSession()
err := doSetVar(ctx, mce, ses, sv)
err := doSetVar(ctx, mce, ses, sv, sql)
if err != nil {
return err
}
Expand Down Expand Up @@ -1135,14 +1150,15 @@ func (mce *MysqlCmdExecutor) handleExplainStmt(requestCtx context.Context, stmt
return nil
}

func doPrepareStmt(ctx context.Context, ses *Session, st *tree.PrepareStmt) (*PrepareStmt, error) {
func doPrepareStmt(ctx context.Context, ses *Session, st *tree.PrepareStmt, sql string) (*PrepareStmt, error) {
preparePlan, err := buildPlan(ctx, ses, ses.GetTxnCompileCtx(), st)
if err != nil {
return nil, err
}

prepareStmt := &PrepareStmt{
Name: preparePlan.GetDcl().GetPrepare().GetName(),
Sql: sql,
PreparePlan: preparePlan,
PrepareStmt: st.Stmt,
getFromSendLongData: make(map[int]struct{}),
Expand All @@ -1154,8 +1170,8 @@ func doPrepareStmt(ctx context.Context, ses *Session, st *tree.PrepareStmt) (*Pr
}

// handlePrepareStmt
func (mce *MysqlCmdExecutor) handlePrepareStmt(ctx context.Context, st *tree.PrepareStmt) (*PrepareStmt, error) {
return doPrepareStmt(ctx, mce.GetSession(), st)
func (mce *MysqlCmdExecutor) handlePrepareStmt(ctx context.Context, st *tree.PrepareStmt, sql string) (*PrepareStmt, error) {
return doPrepareStmt(ctx, mce.GetSession(), st, sql)
}

func doPrepareString(ctx context.Context, ses *Session, st *tree.PrepareString) (*PrepareStmt, error) {
Expand All @@ -1174,6 +1190,7 @@ func doPrepareString(ctx context.Context, ses *Session, st *tree.PrepareString)
}
prepareStmt := &PrepareStmt{
Name: preparePlan.GetDcl().GetPrepare().GetName(),
Sql: st.Sql,
PreparePlan: preparePlan,
PrepareStmt: stmts[0],
}
Expand Down Expand Up @@ -2583,6 +2600,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
pu *config.ParameterUnit,
tenant string,
userName string,
sql string,
) (err error) {

var span trace.Span
Expand Down Expand Up @@ -2933,6 +2951,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
err = moerr.NewInternalError(proc.Ctx, "only admin can create subscription")
return
}
st.Sql = sql
case *tree.DropDatabase:
err = inputNameIsInvalid(proc.Ctx, string(st.Name))
if err != nil {
Expand All @@ -2945,7 +2964,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
}
case *tree.PrepareStmt:
selfHandle = true
prepareStmt, err = mce.handlePrepareStmt(requestCtx, st)
prepareStmt, err = mce.handlePrepareStmt(requestCtx, st, sql)
if err != nil {
return err
}
Expand Down Expand Up @@ -2979,7 +2998,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
}
case *tree.SetVar:
selfHandle = true
err = mce.handleSetVar(requestCtx, st)
err = mce.handleSetVar(requestCtx, st, sql)
if err != nil {
return err
}
Expand Down Expand Up @@ -3223,7 +3242,7 @@ func (mce *MysqlCmdExecutor) executeStmt(requestCtx context.Context,
// reset some special stmt for execute statement
switch st := stmt.(type) {
case *tree.SetVar:
err = mce.handleSetVar(requestCtx, st)
err = mce.handleSetVar(requestCtx, st, sql)
if err != nil {
return err
} else {
Expand Down Expand Up @@ -3703,7 +3722,7 @@ func (mce *MysqlCmdExecutor) doComQuery(requestCtx context.Context, input *UserI
ses.proc.UnixTime = proc.UnixTime
}

err = mce.executeStmt(requestCtx, ses, stmt, proc, cw, i, cws, proto, pu, tenant, userNameOnly)
err = mce.executeStmt(requestCtx, ses, stmt, proc, cw, i, cws, proto, pu, tenant, userNameOnly, sqlRecord[i])
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/frontend/mysql_cmd_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func Test_mce_selfhandle(t *testing.T) {
setVar, err := parsers.ParseOne(ctx, dialect.MYSQL, set, 1)
convey.So(err, convey.ShouldBeNil)

err = mce.handleSetVar(ctx, setVar.(*tree.SetVar))
err = mce.handleSetVar(ctx, setVar.(*tree.SetVar), "")
convey.So(err, convey.ShouldBeNil)

req := &Request{
Expand Down Expand Up @@ -871,7 +871,7 @@ func Test_HandlePrepareStmt(t *testing.T) {
}
runTestHandle("handlePrepareStmt", t, func(mce *MysqlCmdExecutor) error {
stmt := stmt.(*tree.PrepareStmt)
_, err := mce.handlePrepareStmt(context.TODO(), stmt)
_, err := mce.handlePrepareStmt(context.TODO(), stmt, "")
return err
})
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/frontend/plsql_interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (interpreter *Interpreter) FlushParam() error {
// save INOUT at session
interpreter.bh.ClearExecResultSet()
// system setvar execution
err := interpreter.ses.SetUserDefinedVar(interpreter.argsMap[k].(*tree.VarExpr).Name, v)
err := interpreter.ses.SetUserDefinedVar(interpreter.argsMap[k].(*tree.VarExpr).Name, v, "")
if err != nil {
return err
}
Expand All @@ -112,7 +112,7 @@ func (interpreter *Interpreter) FlushParam() error {
// save at session
interpreter.bh.ClearExecResultSet()
// system setvar execution
err := interpreter.ses.SetUserDefinedVar(interpreter.argsMap[k].(*tree.VarExpr).Name, v)
err := interpreter.ses.SetUserDefinedVar(interpreter.argsMap[k].(*tree.VarExpr).Name, v, "")
if err != nil {
return err
}
Expand Down Expand Up @@ -262,7 +262,7 @@ func (interpreter *Interpreter) ExecuteSp(stmt tree.Statement, dbName string) (e
return moerr.NewNotSupported(interpreter.ctx, fmt.Sprintf("parameter %s with type INOUT or IN has to have a specified value.", k))
}
// save param to local var scope
(*interpreter.varScope)[len(*interpreter.varScope)-1][strings.ToLower(k)] = value
(*interpreter.varScope)[len(*interpreter.varScope)-1][strings.ToLower(k)] = value.Value
}
} else {
// if param type is INOUT or OUT and the param is not provided with variable expr, raise an error
Expand Down
22 changes: 12 additions & 10 deletions pkg/frontend/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ type Session struct {
sql string

sysVars map[string]interface{}
userDefinedVars map[string]interface{}
userDefinedVars map[string]*UserDefinedVar
gSysVars *GlobalSystemVariables

//the server status
Expand Down Expand Up @@ -543,7 +543,7 @@ func NewSession(proto Protocol, mp *mpool.MPool, pu *config.ParameterUnit,
}
if isNotBackgroundSession {
ses.sysVars = gSysVars.CopySysVarsToSession()
ses.userDefinedVars = make(map[string]interface{})
ses.userDefinedVars = make(map[string]*UserDefinedVar)
ses.prepareStmts = make(map[string]*PrepareStmt)
ses.statsCache = plan2.NewStatsCache()
// For seq init values.
Expand Down Expand Up @@ -1166,11 +1166,13 @@ func (ses *Session) SetPrepareStmt(name string, prepareStmt *PrepareStmt) error
} else {
stmt.Close()
}
isInsertValues, exprList := checkPlanIsInsertValues(ses.proc,
prepareStmt.PreparePlan.GetDcl().GetPrepare().GetPlan())
if isInsertValues {
prepareStmt.proc = ses.proc
prepareStmt.exprList = exprList
if prepareStmt != nil && prepareStmt.PreparePlan != nil {
isInsertValues, exprList := checkPlanIsInsertValues(ses.proc,
prepareStmt.PreparePlan.GetDcl().GetPrepare().GetPlan())
if isInsertValues {
prepareStmt.proc = ses.proc
prepareStmt.exprList = exprList
}
}
ses.prepareStmts[name] = prepareStmt

Expand Down Expand Up @@ -1323,15 +1325,15 @@ func (ses *Session) CopyAllSessionVars() map[string]interface{} {
}

// SetUserDefinedVar sets the user defined variable to the value in session
func (ses *Session) SetUserDefinedVar(name string, value interface{}) error {
func (ses *Session) SetUserDefinedVar(name string, value interface{}, sql string) error {
ses.mu.Lock()
defer ses.mu.Unlock()
ses.userDefinedVars[strings.ToLower(name)] = value
ses.userDefinedVars[strings.ToLower(name)] = &UserDefinedVar{Value: value, Sql: sql}
return nil
}

// GetUserDefinedVar gets value of the user defined variable
func (ses *Session) GetUserDefinedVar(name string) (SystemVariableType, interface{}, error) {
func (ses *Session) GetUserDefinedVar(name string) (SystemVariableType, *UserDefinedVar, error) {
ses.mu.Lock()
defer ses.mu.Unlock()
val, ok := ses.userDefinedVars[strings.ToLower(name)]
Expand Down
2 changes: 1 addition & 1 deletion pkg/frontend/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ func TestVariables(t *testing.T) {
vars := ses.CopyAllSessionVars()
convey.So(len(vars), convey.ShouldNotBeZeroValue)

err := ses.SetUserDefinedVar("abc", 1)
err := ses.SetUserDefinedVar("abc", 1, "")
convey.So(err, convey.ShouldBeNil)

_, _, err = ses.GetUserDefinedVar("abc")
Expand Down
4 changes: 2 additions & 2 deletions pkg/frontend/status_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (pse *PrepareStmtExecutor) ResponseAfterExec(ctx context.Context, ses *Sess

func (pse *PrepareStmtExecutor) ExecuteImpl(ctx context.Context, ses *Session) error {
var err error
pse.prepareStmt, err = doPrepareStmt(ctx, ses, pse.ps)
pse.prepareStmt, err = doPrepareStmt(ctx, ses, pse.ps, "")
if err != nil {
return err
}
Expand Down Expand Up @@ -238,7 +238,7 @@ type SetVarExecutor struct {
}

func (sve *SetVarExecutor) ExecuteImpl(ctx context.Context, ses *Session) error {
return doSetVar(ctx, nil, ses, sve.sv)
return doSetVar(ctx, nil, ses, sve.sv, "")
}

type DeleteExecutor struct {
Expand Down
1 change: 1 addition & 0 deletions pkg/frontend/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func (ec *engineColumnInfo) GetType() types.T {

type PrepareStmt struct {
Name string
Sql string
PreparePlan *plan.Plan
PrepareStmt tree.Statement
ParamTypes []byte
Expand Down
Loading

0 comments on commit c758fc0

Please sign in to comment.