Skip to content

Commit

Permalink
feat: support recursive types (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jan 13, 2022
1 parent 4346e2a commit b51e911
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 13 deletions.
28 changes: 26 additions & 2 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func (r *Reader) ReadVal(schema Schema, obj interface{}) {
r.ReportError("ReadVal", "can only unmarshal into pointer")
return
}
if typ.Kind() == reflect.Ptr {
ptrType := typ.(*reflect2.UnsafePtrType)
typ = ptrType.Elem()
}

decoder = r.cfg.DecoderOf(schema, typ)
}
Expand Down Expand Up @@ -79,9 +83,12 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder {
return decoder
}

ptrType := typ.(*reflect2.UnsafePtrType)
decoder = decoderOfType(c, schema, ptrType.Elem())
pe := &placeholderEncoder{}
c.addDecoderToCache(schema.Fingerprint(), rtype, pe)

decoder = decoderOfType(c, schema, typ)
c.addDecoderToCache(schema.Fingerprint(), rtype, decoder)
pe.dec = decoder
return decoder
}

Expand Down Expand Up @@ -139,14 +146,31 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder {
return encoder
}

pe := &placeholderEncoder{}
c.addEncoderToCache(schema.Fingerprint(), rtype, pe)

encoder = encoderOfType(c, schema, typ)
if typ.LikePtr() {
encoder = &onePtrEncoder{encoder}
}
c.addEncoderToCache(schema.Fingerprint(), rtype, encoder)
pe.enc = encoder
return encoder
}

type placeholderEncoder struct {
enc ValEncoder
dec ValDecoder
}

func (e *placeholderEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
e.enc.Encode(noescape(ptr), w)
}

func (e *placeholderEncoder) Decode(ptr unsafe.Pointer, r *Reader) {
e.dec.Decode(noescape(ptr), r)
}

type onePtrEncoder struct {
enc ValEncoder
}
Expand Down
4 changes: 2 additions & 2 deletions codec_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) V
func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
arr := schema.(*ArraySchema)
sliceType := typ.(*reflect2.UnsafeSliceType)
decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem())
decoder := cfg.DecoderOf(arr.Items(), sliceType.Elem())

return &arrayDecoder{typ: sliceType, decoder: decoder}
}
Expand Down Expand Up @@ -67,7 +67,7 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
arr := schema.(*ArraySchema)
sliceType := typ.(*reflect2.UnsafeSliceType)
encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem())
encoder := cfg.EncoderOf(arr.Items(), sliceType.Elem())

return &arrayEncoder{
blockLength: cfg.getBlockLength(),
Expand Down
4 changes: 2 additions & 2 deletions codec_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) Val
func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder {
m := schema.(*MapSchema)
mapType := typ.(*reflect2.UnsafeMapType)
decoder := decoderOfType(cfg, m.Values(), mapType.Elem())
decoder := cfg.DecoderOf(m.Values(), mapType.Elem())

return &mapDecoder{
mapType: mapType,
Expand Down Expand Up @@ -72,7 +72,7 @@ func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder {
m := schema.(*MapSchema)
mapType := typ.(*reflect2.UnsafeMapType)
encoder := encoderOfType(cfg, m.Values(), mapType.Elem())
encoder := cfg.EncoderOf(m.Values(), mapType.Elem())

return &mapEncoder{
blockLength: cfg.getBlockLength(),
Expand Down
6 changes: 3 additions & 3 deletions codec_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValD
_, typeIdx := union.Indices()
ptrType := typ.(*reflect2.UnsafePtrType)
elemType := ptrType.Elem()
decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType)
decoder := cfg.DecoderOf(union.Types()[typeIdx], elemType)

return &unionPtrDecoder{
schema: union,
Expand Down Expand Up @@ -195,7 +195,7 @@ func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValE
union := schema.(*UnionSchema)
nullIdx, typeIdx := union.Indices()
ptrType := typ.(*reflect2.UnsafePtrType)
encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem())
encoder := cfg.EncoderOf(union.Types()[typeIdx], ptrType.Elem())

return &unionPtrEncoder{
schema: union,
Expand Down Expand Up @@ -230,7 +230,7 @@ func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) ValDecoder {
for i, schema := range union.Types() {
name := unionResolutionName(schema)
if typ, err := cfg.resolver.Type(name); err == nil {
decoder := decoderOfType(cfg, schema, typ)
decoder := cfg.DecoderOf(schema, typ)
decoders[i] = decoder
types[i] = typ
continue
Expand Down
19 changes: 19 additions & 0 deletions decoder_array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ func TestDecoder_ArraySliceOfStruct(t *testing.T) {
assert.Equal(t, []TestRecord{{A: 27, B: "foo"}, {A: 27, B: "foo"}}, got)
}

func TestDecoder_ArrayRecursiveStruct(t *testing.T) {
defer ConfigTeardown()

type record struct {
A int `avro:"a"`
B []record `avro:"b"`
}

data := []byte{0x2, 0x3, 0x8, 0x4, 0x0, 0x6, 0x0, 0x0}
schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"array", "items": "test"}}]}`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got record
err := dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, record{A: 1, B: []record{{A: 2}, {A: 3}}}, got)
}

func TestDecoder_ArraySliceError(t *testing.T) {
defer ConfigTeardown()

Expand Down
19 changes: 19 additions & 0 deletions decoder_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ func TestDecoder_MapMapOfStruct(t *testing.T) {
assert.Equal(t, map[string]TestRecord{"foo": {A: 27, B: "foo"}}, got)
}

func TestDecoder_MapOfRecursiveStruct(t *testing.T) {
defer ConfigTeardown()

type record struct {
A int `avro:"a"`
B map[string]record `avro:"b"`
}

data := []byte{0x02, 0x01, 0x0c, 0x06, 0x66, 0x6f, 0x6f, 0x04, 0x0, 0x0}
schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"map", "values": "test"}}]}`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got record
err := dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, record{A: 1, B: map[string]record{"foo": {A: 2, B: map[string]record{}}}}, got)
}

func TestDecoder_MapMapError(t *testing.T) {
defer ConfigTeardown()

Expand Down
77 changes: 73 additions & 4 deletions decoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/hamba/avro"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDecoder_UnionInvalidType(t *testing.T) {
Expand Down Expand Up @@ -236,6 +237,38 @@ func TestDecoder_UnionPtrNotNullable(t *testing.T) {
assert.Error(t, err)
}

func TestDecoder_UnionPtrRecursiveType(t *testing.T) {
defer ConfigTeardown()

type record struct {
A int `avro:"a"`
B *record `avro:"b"`
}

data := []byte{0x02, 0x02, 0x04, 0x0}
schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "int"},
{"name": "b", "type": [null, "test"]}
]
}`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got record
err := dec.Decode(&got)

require.NoError(t, err)
want := record{
A: 1,
B: &record{
A: 2,
},
}
assert.Equal(t, want, got)
}

func TestDecoder_UnionInterface(t *testing.T) {
defer ConfigTeardown()

Expand Down Expand Up @@ -401,18 +434,54 @@ func TestDecoder_UnionInterfaceRecord(t *testing.T) {

data := []byte{0x02, 0x36, 0x06, 0x66, 0x6F, 0x6F}
schema := `["int", {"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}]`
dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))
dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
require.NoError(t, err)

var got interface{}
err := dec.Decode(&got)
err = dec.Decode(&got)

assert.NoError(t, err)
assert.IsType(t, &TestRecord{}, got)
require.NoError(t, err)
require.IsType(t, &TestRecord{}, got)
rec := got.(*TestRecord)
assert.Equal(t, int64(27), rec.A)
assert.Equal(t, "foo", rec.B)
}

func TestDecoder_UnionInterfaceRecursiveType(t *testing.T) {
defer ConfigTeardown()

type record struct {
A int `avro:"a"`
B interface{} `avro:"b"`
}

data := []byte{0x02, 0x02, 0x04, 0x0}
schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "int"},
{"name": "b", "type": [null, "test"]}
]
}`
avro.Register("test", record{})

dec, _ := avro.NewDecoder(schema, bytes.NewReader(data))

var got record
err := dec.Decode(&got)

require.NoError(t, err)
require.IsType(t, record{}, got)
want := record{
A: 1,
B: record{
A: 2,
},
}
assert.Equal(t, want, got)
}

func TestDecoder_UnionInterfaceRecordNotReused(t *testing.T) {
defer ConfigTeardown()

Expand Down
20 changes: 20 additions & 0 deletions encoder_array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ func TestEncoder_ArrayOfStruct(t *testing.T) {
assert.Equal(t, []byte{0x03, 0x14, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x0}, buf.Bytes())
}

func TestEncoder_ArrayRecursiveStruct(t *testing.T) {
defer ConfigTeardown()

type record struct {
A int `avro:"a"`
B []record `avro:"b"`
}

schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"array", "items": "test"}}]}`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

rec := record{A: 1, B: []record{{A: 2}, {A: 3}}}
err = enc.Encode(rec)

assert.NoError(t, err)
assert.Equal(t, []byte{0x2, 0x3, 0x8, 0x4, 0x0, 0x6, 0x0, 0x0}, buf.Bytes())
}

func TestEncoder_ArrayError(t *testing.T) {
defer ConfigTeardown()

Expand Down
20 changes: 20 additions & 0 deletions encoder_map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,26 @@ func TestEncoder_MapOfStruct(t *testing.T) {
assert.Equal(t, []byte{0x01, 0x12, 0x06, 0x66, 0x6F, 0x6F, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x0}, buf.Bytes())
}

func TestEncoder_MapOfRecursiveStruct(t *testing.T) {
defer ConfigTeardown()

type record struct {
A int `avro:"a"`
B map[string]record `avro:"b"`
}

schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"map", "values": "test"}}]}`
buf := bytes.NewBuffer([]byte{})
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

rec := record{A: 1, B: map[string]record{"foo": {A: 2}}}
err = enc.Encode(rec)

assert.NoError(t, err)
assert.Equal(t, []byte{0x02, 0x01, 0x0c, 0x06, 0x66, 0x6f, 0x6f, 0x04, 0x0, 0x0}, buf.Bytes())
}

func TestEncoder_MapInvalidKeyType(t *testing.T) {
defer ConfigTeardown()

Expand Down
Loading

0 comments on commit b51e911

Please sign in to comment.