diff --git a/codec.go b/codec.go index cb83035a..3553088e 100644 --- a/codec.go +++ b/codec.go @@ -46,6 +46,10 @@ func (r *Reader) ReadVal(schema Schema, obj interface{}) { r.ReportError("ReadVal", "can only unmarshal into pointer") return } + if typ.Kind() == reflect.Ptr { + ptrType := typ.(*reflect2.UnsafePtrType) + typ = ptrType.Elem() + } decoder = r.cfg.DecoderOf(schema, typ) } @@ -79,9 +83,12 @@ func (c *frozenConfig) DecoderOf(schema Schema, typ reflect2.Type) ValDecoder { return decoder } - ptrType := typ.(*reflect2.UnsafePtrType) - decoder = decoderOfType(c, schema, ptrType.Elem()) + pe := &placeholderEncoder{} + c.addDecoderToCache(schema.Fingerprint(), rtype, pe) + + decoder = decoderOfType(c, schema, typ) c.addDecoderToCache(schema.Fingerprint(), rtype, decoder) + pe.dec = decoder return decoder } @@ -139,14 +146,31 @@ func (c *frozenConfig) EncoderOf(schema Schema, typ reflect2.Type) ValEncoder { return encoder } + pe := &placeholderEncoder{} + c.addEncoderToCache(schema.Fingerprint(), rtype, pe) + encoder = encoderOfType(c, schema, typ) if typ.LikePtr() { encoder = &onePtrEncoder{encoder} } c.addEncoderToCache(schema.Fingerprint(), rtype, encoder) + pe.enc = encoder return encoder } +type placeholderEncoder struct { + enc ValEncoder + dec ValDecoder +} + +func (e *placeholderEncoder) Encode(ptr unsafe.Pointer, w *Writer) { + e.enc.Encode(noescape(ptr), w) +} + +func (e *placeholderEncoder) Decode(ptr unsafe.Pointer, r *Reader) { + e.dec.Decode(noescape(ptr), r) +} + type onePtrEncoder struct { enc ValEncoder } diff --git a/codec_array.go b/codec_array.go index fe35e9d9..a8d05c77 100644 --- a/codec_array.go +++ b/codec_array.go @@ -29,7 +29,7 @@ func createEncoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) V func decoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { arr := schema.(*ArraySchema) sliceType := typ.(*reflect2.UnsafeSliceType) - decoder := decoderOfType(cfg, arr.Items(), sliceType.Elem()) + decoder := cfg.DecoderOf(arr.Items(), sliceType.Elem()) return &arrayDecoder{typ: sliceType, decoder: decoder} } @@ -67,7 +67,7 @@ func (d *arrayDecoder) Decode(ptr unsafe.Pointer, r *Reader) { func encoderOfArray(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { arr := schema.(*ArraySchema) sliceType := typ.(*reflect2.UnsafeSliceType) - encoder := encoderOfType(cfg, arr.Items(), sliceType.Elem()) + encoder := cfg.EncoderOf(arr.Items(), sliceType.Elem()) return &arrayEncoder{ blockLength: cfg.getBlockLength(), diff --git a/codec_map.go b/codec_map.go index 5efdb1d3..655cda9e 100644 --- a/codec_map.go +++ b/codec_map.go @@ -29,7 +29,7 @@ func createEncoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) Val func decoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecoder { m := schema.(*MapSchema) mapType := typ.(*reflect2.UnsafeMapType) - decoder := decoderOfType(cfg, m.Values(), mapType.Elem()) + decoder := cfg.DecoderOf(m.Values(), mapType.Elem()) return &mapDecoder{ mapType: mapType, @@ -72,7 +72,7 @@ func (d *mapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { func encoderOfMap(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEncoder { m := schema.(*MapSchema) mapType := typ.(*reflect2.UnsafeMapType) - encoder := encoderOfType(cfg, m.Values(), mapType.Elem()) + encoder := cfg.EncoderOf(m.Values(), mapType.Elem()) return &mapEncoder{ blockLength: cfg.getBlockLength(), diff --git a/codec_union.go b/codec_union.go index 9cb4031b..f9fe0561 100644 --- a/codec_union.go +++ b/codec_union.go @@ -153,7 +153,7 @@ func decoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValD _, typeIdx := union.Indices() ptrType := typ.(*reflect2.UnsafePtrType) elemType := ptrType.Elem() - decoder := decoderOfType(cfg, union.Types()[typeIdx], elemType) + decoder := cfg.DecoderOf(union.Types()[typeIdx], elemType) return &unionPtrDecoder{ schema: union, @@ -195,7 +195,7 @@ func encoderOfPtrUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValE union := schema.(*UnionSchema) nullIdx, typeIdx := union.Indices() ptrType := typ.(*reflect2.UnsafePtrType) - encoder := encoderOfType(cfg, union.Types()[typeIdx], ptrType.Elem()) + encoder := cfg.EncoderOf(union.Types()[typeIdx], ptrType.Elem()) return &unionPtrEncoder{ schema: union, @@ -230,7 +230,7 @@ func decoderOfResolvedUnion(cfg *frozenConfig, schema Schema) ValDecoder { for i, schema := range union.Types() { name := unionResolutionName(schema) if typ, err := cfg.resolver.Type(name); err == nil { - decoder := decoderOfType(cfg, schema, typ) + decoder := cfg.DecoderOf(schema, typ) decoders[i] = decoder types[i] = typ continue diff --git a/decoder_array_test.go b/decoder_array_test.go index 8def4998..99d51dda 100644 --- a/decoder_array_test.go +++ b/decoder_array_test.go @@ -50,6 +50,25 @@ func TestDecoder_ArraySliceOfStruct(t *testing.T) { assert.Equal(t, []TestRecord{{A: 27, B: "foo"}, {A: 27, B: "foo"}}, got) } +func TestDecoder_ArrayRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B []record `avro:"b"` + } + + data := []byte{0x2, 0x3, 0x8, 0x4, 0x0, 0x6, 0x0, 0x0} + schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"array", "items": "test"}}]}` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + assert.NoError(t, err) + assert.Equal(t, record{A: 1, B: []record{{A: 2}, {A: 3}}}, got) +} + func TestDecoder_ArraySliceError(t *testing.T) { defer ConfigTeardown() diff --git a/decoder_map_test.go b/decoder_map_test.go index db865cad..0c9faea7 100644 --- a/decoder_map_test.go +++ b/decoder_map_test.go @@ -50,6 +50,25 @@ func TestDecoder_MapMapOfStruct(t *testing.T) { assert.Equal(t, map[string]TestRecord{"foo": {A: 27, B: "foo"}}, got) } +func TestDecoder_MapOfRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B map[string]record `avro:"b"` + } + + data := []byte{0x02, 0x01, 0x0c, 0x06, 0x66, 0x6f, 0x6f, 0x04, 0x0, 0x0} + schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"map", "values": "test"}}]}` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + assert.NoError(t, err) + assert.Equal(t, record{A: 1, B: map[string]record{"foo": {A: 2, B: map[string]record{}}}}, got) +} + func TestDecoder_MapMapError(t *testing.T) { defer ConfigTeardown() diff --git a/decoder_union_test.go b/decoder_union_test.go index d1421b05..464d12e2 100644 --- a/decoder_union_test.go +++ b/decoder_union_test.go @@ -8,6 +8,7 @@ import ( "github.com/hamba/avro" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDecoder_UnionInvalidType(t *testing.T) { @@ -236,6 +237,38 @@ func TestDecoder_UnionPtrNotNullable(t *testing.T) { assert.Error(t, err) } +func TestDecoder_UnionPtrRecursiveType(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B *record `avro:"b"` + } + + data := []byte{0x02, 0x02, 0x04, 0x0} + schema := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "int"}, + {"name": "b", "type": [null, "test"]} + ] +}` + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + require.NoError(t, err) + want := record{ + A: 1, + B: &record{ + A: 2, + }, + } + assert.Equal(t, want, got) +} + func TestDecoder_UnionInterface(t *testing.T) { defer ConfigTeardown() @@ -401,18 +434,54 @@ func TestDecoder_UnionInterfaceRecord(t *testing.T) { data := []byte{0x02, 0x36, 0x06, 0x66, 0x6F, 0x6F} schema := `["int", {"type": "record", "name": "test", "fields" : [{"name": "a", "type": "long"}, {"name": "b", "type": "string"}]}]` - dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + dec, err := avro.NewDecoder(schema, bytes.NewReader(data)) + require.NoError(t, err) var got interface{} - err := dec.Decode(&got) + err = dec.Decode(&got) - assert.NoError(t, err) - assert.IsType(t, &TestRecord{}, got) + require.NoError(t, err) + require.IsType(t, &TestRecord{}, got) rec := got.(*TestRecord) assert.Equal(t, int64(27), rec.A) assert.Equal(t, "foo", rec.B) } +func TestDecoder_UnionInterfaceRecursiveType(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B interface{} `avro:"b"` + } + + data := []byte{0x02, 0x02, 0x04, 0x0} + schema := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "int"}, + {"name": "b", "type": [null, "test"]} + ] +}` + avro.Register("test", record{}) + + dec, _ := avro.NewDecoder(schema, bytes.NewReader(data)) + + var got record + err := dec.Decode(&got) + + require.NoError(t, err) + require.IsType(t, record{}, got) + want := record{ + A: 1, + B: record{ + A: 2, + }, + } + assert.Equal(t, want, got) +} + func TestDecoder_UnionInterfaceRecordNotReused(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_array_test.go b/encoder_array_test.go index c99b2696..64f0576c 100644 --- a/encoder_array_test.go +++ b/encoder_array_test.go @@ -63,6 +63,26 @@ func TestEncoder_ArrayOfStruct(t *testing.T) { assert.Equal(t, []byte{0x03, 0x14, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x0}, buf.Bytes()) } +func TestEncoder_ArrayRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B []record `avro:"b"` + } + + schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"array", "items": "test"}}]}` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{A: 1, B: []record{{A: 2}, {A: 3}}} + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x2, 0x3, 0x8, 0x4, 0x0, 0x6, 0x0, 0x0}, buf.Bytes()) +} + func TestEncoder_ArrayError(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_map_test.go b/encoder_map_test.go index 1d411a16..8e5de238 100644 --- a/encoder_map_test.go +++ b/encoder_map_test.go @@ -63,6 +63,26 @@ func TestEncoder_MapOfStruct(t *testing.T) { assert.Equal(t, []byte{0x01, 0x12, 0x06, 0x66, 0x6F, 0x6F, 0x36, 0x06, 0x66, 0x6f, 0x6f, 0x0}, buf.Bytes()) } +func TestEncoder_MapOfRecursiveStruct(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B map[string]record `avro:"b"` + } + + schema := `{"type": "record", "name": "test", "fields" : [{"name": "a", "type": "int"}, {"name": "b", "type": {"type":"map", "values": "test"}}]}` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{A: 1, B: map[string]record{"foo": {A: 2}}} + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x02, 0x01, 0x0c, 0x06, 0x66, 0x6f, 0x6f, 0x04, 0x0, 0x0}, buf.Bytes()) +} + func TestEncoder_MapInvalidKeyType(t *testing.T) { defer ConfigTeardown() diff --git a/encoder_union_test.go b/encoder_union_test.go index 8b89f397..943a4d61 100644 --- a/encoder_union_test.go +++ b/encoder_union_test.go @@ -238,6 +238,38 @@ func TestEncoder_UnionPtrNotNullable(t *testing.T) { assert.Error(t, err) } +func TestEncoder_UnionPtrRecursiveType(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B *record `avro:"b"` + } + + schema := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "int"}, + {"name": "b", "type": [null, "test"]} + ] +}` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{ + A: 1, + B: &record{ + A: 2, + }, + } + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x02, 0x02, 0x04, 0x0}, buf.Bytes()) +} + func TestEncoder_UnionInterface(t *testing.T) { defer ConfigTeardown() @@ -366,6 +398,38 @@ func TestEncoder_UnionInterfaceNamed(t *testing.T) { assert.Equal(t, []byte{0x02, 0x02}, buf.Bytes()) } +func TestEncoder_UnionInterfaceRecursiveType(t *testing.T) { + defer ConfigTeardown() + + type record struct { + A int `avro:"a"` + B interface{} `avro:"b"` + } + + schema := `{ + "type": "record", + "name": "test", + "fields" : [ + {"name": "a", "type": "int"}, + {"name": "b", "type": [null, "test"]} + ] +}` + buf := bytes.NewBuffer([]byte{}) + enc, err := avro.NewEncoder(schema, buf) + assert.NoError(t, err) + + rec := record{ + A: 1, + B: &record{ + A: 2, + }, + } + err = enc.Encode(rec) + + assert.NoError(t, err) + assert.Equal(t, []byte{0x02, 0x02, 0x04, 0x0}, buf.Bytes()) +} + func TestEncoder_UnionInterfaceWithTime(t *testing.T) { defer ConfigTeardown()