Skip to content

Commit

Permalink
fix(db): Multi-DHTs support and stmt reuse bug (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Apr 9, 2024
1 parent 2acb436 commit 65ab358
Show file tree
Hide file tree
Showing 12 changed files with 357 additions and 100 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.4.1] - 2014-04-09
### Added
- added multi-dht support on `DB` (#31)

### Fixes
- stmt that is in using should not be close in background clean worker (#31)

## [1.4.0] - 2014-04-06
### Added
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ You’ll find the SQLE package useful if you’re not a fan of full-featured ORM
> All examples on https://go.dev/doc/tutorial/database-access can directly work with `sqle.DB` instance.
>
> See full examples on https://github.com/yaitoo/auth
>
### Install SQLE
- install latest commit from `main` branch
```
Expand Down
24 changes: 17 additions & 7 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ import (
"database/sql"
"log"
"sync"
"time"
)

type Context struct {
*sql.DB
sync.Mutex
_ noCopy

index int
stmts map[string]*cachedStmt
stmts map[string]*Stmt
stmtsMutex sync.Mutex

stmtMaxIdleTime time.Duration
Index int
}

func (db *Context) Query(query string, args ...any) (*Rows, error) {
Expand All @@ -32,13 +35,14 @@ func (db *Context) QueryBuilder(ctx context.Context, b *Builder) (*Rows, error)

func (db *Context) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
var rows *sql.Rows
var stmt *sql.Stmt
var stmt *Stmt
var err error
if len(args) > 0 {
stmt, err = db.prepareStmt(ctx, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
stmt.Reuse()
return nil, err
}
}
Expand All @@ -50,7 +54,7 @@ func (db *Context) QueryContext(ctx context.Context, query string, args ...any)
}
}

return &Rows{Rows: rows, query: query}, nil
return &Rows{Rows: rows, stmt: stmt, query: query}, nil
}

func (db *Context) QueryRow(query string, args ...any) *Row {
Expand All @@ -71,28 +75,32 @@ func (db *Context) QueryRowBuilder(ctx context.Context, b *Builder) *Row {

func (db *Context) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
var rows *sql.Rows
var stmt *sql.Stmt
var stmt *Stmt
var err error

if len(args) > 0 {
stmt, err = db.prepareStmt(ctx, query)
if err != nil {
return &Row{
err: err,
stmt: stmt,
query: query,
}
}
rows, err = stmt.QueryContext(ctx, args...)
return &Row{
rows: rows,
err: err,
stmt: stmt,
query: query,

rows: rows,
}
}

rows, err = db.DB.QueryContext(ctx, query, args...)
return &Row{
rows: rows,
stmt: stmt,
err: err,
query: query,
}
Expand All @@ -118,6 +126,8 @@ func (db *Context) ExecContext(ctx context.Context, query string, args ...any) (
return nil, err
}

defer stmt.Reuse()

return stmt.ExecContext(ctx, args...)
}
return db.DB.ExecContext(context.Background(), query, args...)
Expand All @@ -134,7 +144,7 @@ func (db *Context) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error
return nil, err
}

return &Tx{Tx: tx, cachedStmts: make(map[string]*sql.Stmt)}, nil
return &Tx{Tx: tx, stmts: make(map[string]*sql.Stmt)}, nil
}

func (db *Context) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error {
Expand Down
60 changes: 40 additions & 20 deletions context_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,47 +7,67 @@ import (
"time"
)

type cachedStmt struct {
sync.Mutex
stmt *sql.Stmt
type Stmt struct {
*sql.Stmt
mu sync.Mutex
lastUsed time.Time
isUsing bool
}

func (db *Context) prepareStmt(ctx context.Context, query string) (*sql.Stmt, error) {
func (s *Stmt) Reuse() {
s.mu.Lock()
defer s.mu.Unlock()

s.isUsing = false
}

func (db *Context) prepareStmt(ctx context.Context, query string) (*Stmt, error) {
db.stmtsMutex.Lock()
defer db.stmtsMutex.Unlock()
s, ok := db.stmts[query]

if ok {
s.lastUsed = time.Now()
return s.stmt, nil
s.isUsing = true
return s, nil
}

stmt, err := db.DB.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

db.stmts[query] = &cachedStmt{
stmt: stmt,
s = &Stmt{
Stmt: stmt,
lastUsed: time.Now(),
isUsing: true,
}

return stmt, nil
db.stmts[query] = s

return s, nil
}

func (db *Context) closeIdleStmt() {
for {
<-time.After(StmtMaxIdleTime)

db.stmtsMutex.Lock()
lastActive := time.Now().Add(-1 * time.Minute)
for k, v := range db.stmts {
if v.lastUsed.Before(lastActive) {
delete(db.stmts, k)
go v.stmt.Close() //nolint: errcheck
}
func (db *Context) closeStaleStmt() {
db.stmtsMutex.Lock()
defer db.stmtsMutex.Unlock()

lastActive := time.Now().Add(-db.stmtMaxIdleTime)
for k, s := range db.stmts {
s.mu.Lock()
if !s.isUsing && s.lastUsed.Before(lastActive) {
delete(db.stmts, k)
go s.Stmt.Close() //nolint: errcheck
}
db.stmtsMutex.Unlock()
s.mu.Unlock()
}

}

func (db *Context) checkIdleStmt() {
for {
<-time.After(db.stmtMaxIdleTime)

db.closeStaleStmt()
}
}
160 changes: 159 additions & 1 deletion context_stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ func TestStmt(t *testing.T) {

db := Open(d)

db.stmtMaxIdleTime = 1 * time.Second

tests := []struct {
name string
run func(t *testing.T)
Expand Down Expand Up @@ -87,7 +89,7 @@ func TestStmt(t *testing.T) {
},
},
{
name: "stmt_should_work_in_exec",
name: "stmt_should_work",
run: func(t *testing.T) {
for i := 0; i < 100; i++ {

Expand All @@ -110,6 +112,162 @@ func TestStmt(t *testing.T) {
}
},
},
{
name: "stmt_reuse_should_work_in_exec",
run: func(t *testing.T) {
q := "INSERT INTO `rows`(`id`,`status`) VALUES(?, ?)"

result, err := db.Exec(q, 200, 0)
require.NoError(t, err)
affected, err := result.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1), affected)

s, ok := db.stmts[q]
require.True(t, ok)
require.False(t, s.isUsing)

time.Sleep(2 * time.Second)
db.closeStaleStmt()

// stmt should be closed and released
require.False(t, s.isUsing)

s, ok = db.stmts[q]
require.False(t, ok)
require.Nil(t, s)

},
},
{
name: "stmt_reuse_should_work_in_rows_scan",
run: func(t *testing.T) {
var id int
q := "SELECT id, 'rows_scan' as reuse FROM rows WHERE id = ?"
rows, err := db.Query(q, 200)
require.NoError(t, err)

s, ok := db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

time.Sleep(2 * time.Second)
db.closeStaleStmt()

// stmt that is in using should not be closed
s, ok = db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

rows.Scan(&id) // nolint: errcheck
require.False(t, s.isUsing)

db.closeStaleStmt()

// stmt should be closed and released
s, ok = db.stmts[q]
require.False(t, ok)
require.Nil(t, s)
},
},
{
name: "stmt_reuse_should_work_in_rows_bind",
run: func(t *testing.T) {
var r struct {
ID int
}

q := "SELECT id, 'rows_bind' as reuse FROM rows WHERE id = ?"
rows, err := db.Query(q, 200)
require.NoError(t, err)

s, ok := db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

time.Sleep(2 * time.Second)
db.closeStaleStmt()

// stmt that is in using should not be closed
s, ok = db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

rows.Bind(&r) // nolint: errcheck
require.False(t, s.isUsing)

db.closeStaleStmt()

// stmt should be closed and released
s, ok = db.stmts[q]
require.False(t, ok)
require.Nil(t, s)
},
},
{
name: "stmt_reuse_should_work_in_row_scan",
run: func(t *testing.T) {
var id int
q := "SELECT id, 'row_scan' as reuse FROM rows WHERE id = ?"
row := db.QueryRow(q, 200)
require.NoError(t, err)

s, ok := db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

time.Sleep(2 * time.Second)
db.closeStaleStmt()

// stmt that is in using should not be closed
s, ok = db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

row.Scan(&id) // nolint: errcheck
require.False(t, s.isUsing)

db.closeStaleStmt()

// stmt should be closed and released
s, ok = db.stmts[q]
require.False(t, ok)
require.Nil(t, s)
},
},
{
name: "stmt_reuse_should_work_in_row_bind",
run: func(t *testing.T) {
var r struct {
ID int
}
q := "SELECT id, 'row_bind' as reuse FROM rows WHERE id = ?"
row, err := db.Query(q, 200)
require.NoError(t, err)

s, ok := db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

time.Sleep(2 * time.Second)
db.closeStaleStmt()

// stmt that is in using should not be closed
s, ok = db.stmts[q]
require.True(t, ok)
require.True(t, s.isUsing)

row.Bind(&r) // nolint: errcheck
require.False(t, s.isUsing)

db.closeStaleStmt()

// stmt should be closed and released
s, ok = db.stmts[q]
require.False(t, ok)
require.Nil(t, s)
},
},
}

for _, test := range tests {
Expand Down
Loading

0 comments on commit 65ab358

Please sign in to comment.