Skip to content

Commit

Permalink
Fix TLS connections.
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jan 13, 2017
1 parent 19b78dc commit d2bcc88
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 22 deletions.
3 changes: 2 additions & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ func (db *DB) conn() (*pool.Conn, error) {
return nil, err
}

cn.SetReadWriteTimeout(db.opt.ReadTimeout, db.opt.WriteTimeout)

if cn.InitedAt.IsZero() {
if err := db.initConn(cn); err != nil {
_ = db.pool.Remove(cn, err)
Expand All @@ -78,7 +80,6 @@ func (db *DB) conn() (*pool.Conn, error) {
cn.InitedAt = time.Now()
}

cn.SetReadWriteTimeout(db.opt.ReadTimeout, db.opt.WriteTimeout)
return cn, nil
}

Expand Down
17 changes: 12 additions & 5 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pg_test

import (
"bytes"
"crypto/tls"
"database/sql"
"fmt"
"net"
Expand All @@ -26,11 +27,17 @@ func TestGinkgo(t *testing.T) {

func pgOptions() *pg.Options {
return &pg.Options{
User: "postgres",
Database: "postgres",
DialTimeout: 30 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
User: "postgres",
Database: "postgres",

TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},

DialTimeout: 30 * time.Second,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,

PoolSize: 10,
PoolTimeout: 30 * time.Second,
IdleTimeout: 10 * time.Second,
Expand Down
25 changes: 18 additions & 7 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,29 @@ func init() {
}

func connect() *pg.DB {
return pg.Connect(&pg.Options{
User: "postgres",
})
return pg.Connect(pgOptions())
}

func ExampleConnect() {
db := pg.Connect(&pg.Options{
User: "postgres",
User: "postgres",
Password: "",
Database: "postgres",
})
err := db.Close()
fmt.Println(err)
// Output: <nil>

var n int
_, err := db.QueryOne(pg.Scan(&n), "SELECT 1")
if err != nil {
panic(err)
}
fmt.Println(n)

err = db.Close()
if err != nil {
panic(err)
}

// Output: 1
}

func ExampleDB_QueryOne() {
Expand Down
2 changes: 1 addition & 1 deletion internal/pool/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (buf *WriteBuffer) WriteByte(c byte) {

func (buf *WriteBuffer) Flush() error {
_, err := buf.w.Write(buf.Bytes)
buf.Bytes = buf.Bytes[:0]
buf.Reset()
return err
}

Expand Down
1 change: 1 addition & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func TestUnixSocket(t *testing.T) {
opt := pgOptions()
opt.Network = "unix"
opt.Addr = "/var/run/postgresql/.s.PGSQL.5432"
opt.TLSConfig = nil
db := pg.Connect(opt)
defer db.Close()

Expand Down
10 changes: 2 additions & 8 deletions messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func enableSSL(cn *pool.Conn, tlsConf *tls.Config) error {
return err
}

b := make([]byte, 1)
b := cn.Buf[:1]
_, err := io.ReadFull(cn.NetConn, b)
if err != nil {
return err
Expand All @@ -118,13 +118,7 @@ func enableSSL(cn *pool.Conn, tlsConf *tls.Config) error {
return errSSLNotSupported
}

if tlsConf == nil {
tlsConf = &tls.Config{
InsecureSkipVerify: true,
}
}
cn.NetConn = tls.Client(cn.NetConn, tlsConf)

cn.SetNetConn(tls.Client(cn.NetConn, tlsConf))
return nil
}

Expand Down

0 comments on commit d2bcc88

Please sign in to comment.