Skip to content

Commit

Permalink
Fix nil struct decoding. Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 24, 2015
1 parent f014462 commit c138b92
Show file tree
Hide file tree
Showing 10 changed files with 497 additions and 424 deletions.
4 changes: 2 additions & 2 deletions encode.go → append.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func appendIface(dst []byte, srci interface{}) []byte {
case QueryAppender:
return src.AppendQuery(dst)
case driver.Valuer:
return appendDriverValue(dst, src)
return appendDriverValuer(dst, src)
default:
return appendValue(dst, reflect.ValueOf(srci))
}
Expand Down Expand Up @@ -329,7 +329,7 @@ func appendInt64Slice(dst []byte, v []int64) []byte {
return dst
}

func appendDriverValue(dst []byte, v driver.Valuer) []byte {
func appendDriverValuer(dst []byte, v driver.Valuer) []byte {
value, err := v.Value()
if err != nil {
log.Printf("%#v value failed: %s", v, err)
Expand Down
2 changes: 1 addition & 1 deletion encode_value.go → append_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func appendAppenderValue(dst []byte, v reflect.Value) []byte {
}

func appendDriverValuerValue(dst []byte, v reflect.Value) []byte {
return appendDriverValue(dst, v.Interface().(driver.Valuer))
return appendDriverValuer(dst, v.Interface().(driver.Valuer))
}

func isEmptyValue(v reflect.Value) bool {
Expand Down
10 changes: 5 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,20 @@ func (db *DB) conn() (*conn, error) {
return cn, nil
}

func (db *DB) freeConn(cn *conn, e error) error {
if e == nil {
func (db *DB) freeConn(cn *conn, err error) error {
if err == nil {
return db.pool.Put(cn)
}
if cn.br.Buffered() > 0 {
return db.pool.Remove(cn)
}
if pgerr, ok := e.(Error); ok && pgerr.Field('S') != "FATAL" {
if pgerr, ok := err.(Error); ok && pgerr.Field('S') != "FATAL" {
return db.pool.Put(cn)
}
if _, ok := e.(dbError); ok {
if _, ok := err.(dbError); ok {
return db.pool.Put(cn)
}
if neterr, ok := e.(net.Error); ok && neterr.Timeout() {
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
if err := db.cancelRequest(cn.processId, cn.secretKey); err != nil {
log.Printf("pg: cancelRequest failed: %s", err)
}
Expand Down
28 changes: 7 additions & 21 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,19 @@ import (
"strconv"
)

func Decode(dst interface{}, f []byte) error {
if scanner, ok := dst.(sql.Scanner); ok {
return decodeScanner(scanner, f)
}

func Decode(dst interface{}, b []byte) error {
v := reflect.ValueOf(dst)
if !v.IsValid() || v.Kind() != reflect.Ptr {
return decodeError(v)
}
vv := v.Elem()
if !vv.IsValid() {
return decodeError(v)
}
return DecodeValue(vv, f)
}

func decodeError(v reflect.Value) error {
if !v.IsValid() {
return errorf("pg: Decode(nil)")
}
if !v.CanSet() {
return errorf("pg: Decode(nonsettable %s)", v.Type())
if v.Kind() != reflect.Ptr {
return errorf("pg: Decode(nonsettable %T)", dst)
}
if v.Kind() == reflect.Interface {
return errorf("pg: Decode(nil)")
vv := v.Elem()
if !vv.IsValid() {
return errorf("pg: Decode(nonsettable %T)", dst)
}
return errorf("pg: Decode(nil %s)", v.Type())
return DecodeValue(vv, b)
}

func decodeScanner(scanner sql.Scanner, b []byte) error {
Expand Down
160 changes: 84 additions & 76 deletions decode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,76 +17,63 @@ var (

type valueDecoder func(reflect.Value, []byte) error

var valueDecoders = [...]valueDecoder{
reflect.Bool: decodeBoolValue,
reflect.Int: decodeIntValue,
reflect.Int8: decodeIntValue,
reflect.Int16: decodeIntValue,
reflect.Int32: decodeIntValue,
reflect.Int64: decodeIntValue,
reflect.Uint: decodeUintValue,
reflect.Uint8: decodeUintValue,
reflect.Uint16: decodeUintValue,
reflect.Uint32: decodeUintValue,
reflect.Uint64: decodeUintValue,
reflect.Uintptr: nil,
reflect.Float32: decodeFloatValue,
reflect.Float64: decodeFloatValue,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Interface: nil,
reflect.Map: decodeMapValue,
reflect.Ptr: nil,
reflect.Slice: decodeSliceValue,
reflect.String: decodeStringValue,
reflect.Struct: decodeStructValue,
reflect.UnsafePointer: nil,
}

func DecodeValue(dst reflect.Value, f []byte) error {
if !dst.IsValid() {
return decodeError(dst)
}

if f == nil {
return decodeNullValue(dst)
}

kind := dst.Kind()
if kind == reflect.Ptr && dst.IsNil() && dst.CanSet() {
dst.Set(reflect.New(dst.Type().Elem()))
}

if scanner, ok := dst.Interface().(sql.Scanner); ok {
return decodeScanner(scanner, f)
}

if kind == reflect.Interface || kind == reflect.Ptr {
v := dst.Elem()
if !v.IsValid() {
return decodeError(dst)
}
return DecodeValue(v, f)
var valueDecoders []valueDecoder

func init() {
valueDecoders = []valueDecoder{
reflect.Bool: decodeBoolValue,
reflect.Int: decodeIntValue,
reflect.Int8: decodeIntValue,
reflect.Int16: decodeIntValue,
reflect.Int32: decodeIntValue,
reflect.Int64: decodeIntValue,
reflect.Uint: decodeUintValue,
reflect.Uint8: decodeUintValue,
reflect.Uint16: decodeUintValue,
reflect.Uint32: decodeUintValue,
reflect.Uint64: decodeUintValue,
reflect.Uintptr: nil,
reflect.Float32: decodeFloatValue,
reflect.Float64: decodeFloatValue,
reflect.Complex64: nil,
reflect.Complex128: nil,
reflect.Array: nil,
reflect.Chan: nil,
reflect.Func: nil,
reflect.Interface: decodeInterfaceValue,
reflect.Map: decodeMapValue,
reflect.Ptr: decodePtrValue,
reflect.Slice: decodeSliceValue,
reflect.String: decodeStringValue,
reflect.Struct: decodeStructValue,
reflect.UnsafePointer: nil,
}
}

func DecodeValue(v reflect.Value, b []byte) error {
if !v.IsValid() {
return errorf("pg: Decode(nil)")
}

if !dst.CanSet() {
return decodeError(dst)
if b == nil {
return decodeNullValue(v)
}

if dst.Type() == timeType {
return decodeTimeValue(dst, f)
decoder := getDecoder(v.Type())
if decoder != nil {
return decoder(v, b)
}

if decoder := valueDecoders[kind]; decoder != nil {
return decoder(dst, f)
if v.Kind() == reflect.Interface {
return errorf("pg: Decode(nil)")
}
return errorf("pg: unsupported dst: %s", dst.Type())
return errorf("pg: Decode(unsupported %s)", v.Type())
}

func decodeBoolValue(v reflect.Value, b []byte) error {
if !v.CanSet() {
return errorf("pg: Decode(nonsettable %s)", v.Type())
}
v.SetBool(len(b) == 1 && b[0] == 't')
return nil
}
Expand Down Expand Up @@ -136,52 +123,70 @@ func decodeTimeValue(v reflect.Value, b []byte) error {
return nil
}

func decodeSliceValue(dst reflect.Value, f []byte) error {
elemType := dst.Type().Elem()
func decodePtrValue(v reflect.Value, b []byte) error {
if v.IsNil() {
if !v.CanSet() {
return errorf("pg: Decode(nonsettable %s)", v.Type())
}
vv := reflect.New(v.Type().Elem())
v.Set(vv)
}
return DecodeValue(v.Elem(), b)
}

func decodeSliceValue(v reflect.Value, b []byte) error {
elemType := v.Type().Elem()
switch elemType.Kind() {
case reflect.Uint8:
b, err := decodeBytes(f)
bs, err := decodeBytes(b)
if err != nil {
return err
}
dst.SetBytes(b)
v.SetBytes(bs)
return nil
case reflect.String:
s, err := decodeStringSlice(f)
s, err := decodeStringSlice(b)
if err != nil {
return err
}
dst.Set(reflect.ValueOf(s))
v.Set(reflect.ValueOf(s))
return nil
case reflect.Int:
s, err := decodeIntSlice(f)
s, err := decodeIntSlice(b)
if err != nil {
return err
}
dst.Set(reflect.ValueOf(s))
v.Set(reflect.ValueOf(s))
return nil
case reflect.Int64:
s, err := decodeInt64Slice(f)
s, err := decodeInt64Slice(b)
if err != nil {
return err
}
dst.Set(reflect.ValueOf(s))
v.Set(reflect.ValueOf(s))
return nil
}
return errorf("pg: unsupported dst: %s", dst.Type())
return errorf("pg: Decode(unsupported %s)", v.Type())
}

func decodeInterfaceValue(v reflect.Value, b []byte) error {
if v.IsNil() {
return errorf("pg: Decode(nil)")
}
return DecodeValue(v.Elem(), b)
}

func decodeMapValue(dst reflect.Value, f []byte) error {
typ := dst.Type()
func decodeMapValue(v reflect.Value, b []byte) error {
typ := v.Type()
if typ.Key().Kind() == reflect.String && typ.Elem().Kind() == reflect.String {
m, err := decodeStringStringMap(f)
m, err := decodeStringStringMap(b)
if err != nil {
return err
}
dst.Set(reflect.ValueOf(m))
v.Set(reflect.ValueOf(m))
return nil
}
return errorf("pg: unsupported dst: %s", dst.Type())
return errorf("pg: Decode(unsupported %s)", v.Type())
}

func decodeNullValue(v reflect.Value) error {
Expand All @@ -200,6 +205,9 @@ func decodeNullValue(v reflect.Value) error {
}

func decodeScannerValue(v reflect.Value, b []byte) error {
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return decodeScanner(v.Interface().(sql.Scanner), b)
}

Expand Down
12 changes: 6 additions & 6 deletions loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,15 @@ func (set IntSet) LoadColumn(colIdx int, colName string, b []byte) error {
func NewColumnLoader(dst interface{}) (ColumnLoader, error) {
v := reflect.ValueOf(dst)
if !v.IsValid() {
return nil, decodeError(v)
return nil, errorf("pg: Decode(nil)")
}
if v.Kind() != reflect.Ptr {
return nil, decodeError(v)
return nil, errorf("pg: Decode(nonsettable %T)", dst)
}
v = v.Elem()
switch v.Kind() {
vv := v.Elem()
switch vv.Kind() {
case reflect.Struct:
return newStructLoader(v), nil
return newStructLoader(vv), nil
}
return nil, errorf("pg: unsupported dst %s", v.Type())
return nil, errorf("pg: Decode(unsupported %T)", dst)
}
Loading

0 comments on commit c138b92

Please sign in to comment.