Skip to content

Commit

Permalink
feat: support embedded structs (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jun 15, 2021
1 parent 52dd086 commit 592567f
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 23 deletions.
120 changes: 100 additions & 20 deletions codec_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func decoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec
continue
}

dec := decoderOfType(cfg, field.Type(), sf.Field.Type())
dec := decoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type())
fields = append(fields, &structFieldDecoder{
field: sf.Field,
decoder: dec,
Expand All @@ -90,7 +90,7 @@ func (d *structDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
}

type structFieldDecoder struct {
field *reflect2.UnsafeStructField
field []*reflect2.UnsafeStructField
decoder ValDecoder
}

Expand All @@ -101,11 +101,29 @@ func (d *structFieldDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
return
}

fieldPtr := d.field.UnsafeGet(ptr)
fieldPtr := ptr
for i, f := range d.field {
fieldPtr = f.UnsafeGet(fieldPtr)

if i == len(d.field)-1 {
break
}

if f.Type().Kind() == reflect.Ptr {
if *((*unsafe.Pointer)(ptr)) == nil {
newPtr := f.Type().UnsafeNew()
*((*unsafe.Pointer)(fieldPtr)) = newPtr
}

fieldPtr = *((*unsafe.Pointer)(fieldPtr))
}
}
d.decoder.Decode(fieldPtr, r)

if r.Error != nil && r.Error != io.EOF {
r.Error = fmt.Errorf("%s: %s", d.field.Name(), r.Error.Error())
for _, f := range d.field {
r.Error = fmt.Errorf("%s: %s", f.Name(), r.Error.Error())
}
}
}

Expand All @@ -120,7 +138,7 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc
if sf == nil {
if !field.HasDefault() {
// In all other cases, this is a required field
return &errorEncoder{err: fmt.Errorf("avro: record %s is missing required field %s", rec.FullName(), field.Name())}
return &errorEncoder{err: fmt.Errorf("avro: record %s is missing required field %q", rec.FullName(), field.Name())}
}

def := field.Default()
Expand Down Expand Up @@ -151,7 +169,7 @@ func encoderOfStruct(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValEnc

fields = append(fields, &structFieldEncoder{
field: sf.Field,
encoder: encoderOfType(cfg, field.Type(), sf.Field.Type()),
encoder: encoderOfType(cfg, field.Type(), sf.Field[len(sf.Field)-1].Type()),
})
}

Expand All @@ -170,7 +188,7 @@ func (e *structEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
}

type structFieldEncoder struct {
field *reflect2.UnsafeStructField
field []*reflect2.UnsafeStructField
defaultPtr unsafe.Pointer
encoder ValEncoder
}
Expand All @@ -182,11 +200,29 @@ func (e *structFieldEncoder) Encode(ptr unsafe.Pointer, w *Writer) {
return
}

fieldPtr := e.field.UnsafeGet(ptr)
fieldPtr := ptr
for i, f := range e.field {
fieldPtr = f.UnsafeGet(fieldPtr)

if i == len(e.field)-1 {
break
}

if f.Type().Kind() == reflect.Ptr {
if *((*unsafe.Pointer)(ptr)) == nil {
w.Error = fmt.Errorf("embedded field %q is nil", f.Name())
return
}

fieldPtr = *((*unsafe.Pointer)(fieldPtr))
}
}
e.encoder.Encode(fieldPtr, w)

if w.Error != nil && w.Error != io.EOF {
w.Error = fmt.Errorf("%s: %s", e.field.Name(), w.Error.Error())
for _, f := range e.field {
w.Error = fmt.Errorf("%s: %s", f.Name(), w.Error.Error())
}
}
}

Expand Down Expand Up @@ -345,24 +381,68 @@ func (sf structFields) Get(name string) *structField {
}

type structField struct {
Field *reflect2.UnsafeStructField
Name string
Field []*reflect2.UnsafeStructField

anon *reflect2.UnsafeStructType
}

func describeStruct(tagKey string, typ reflect2.Type) *structDescriptor {
structType := typ.(*reflect2.UnsafeStructType)
fields := structFields{}
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i).(*reflect2.UnsafeStructField)
fieldName := field.Name()
if tag, ok := field.Tag().Lookup(tagKey); ok {
fieldName = tag
}

fields = append(fields, &structField{
Field: field,
Name: fieldName,
})
var curr []structField
next := []structField{{anon: structType}}

visited := map[uintptr]bool{}

for len(next) > 0 {
curr, next = next, curr[:0]

for _, f := range curr {
rtype := f.anon.RType()
if visited[f.anon.RType()] {
continue
}
visited[rtype] = true

for i := 0; i < f.anon.NumField(); i++ {
field := f.anon.Field(i).(*reflect2.UnsafeStructField)
isUnexported := field.PkgPath() != ""

chain := make([]*reflect2.UnsafeStructField, len(f.Field)+1)
copy(chain, f.Field)
chain[len(f.Field)] = field

if field.Anonymous() {
t := field.Type()
if t.Kind() == reflect.Ptr {
t = t.(*reflect2.UnsafePtrType).Elem()
}
if t.Kind() != reflect.Struct {
continue
}

next = append(next, structField{Field: chain, anon: t.(*reflect2.UnsafeStructType)})
continue
}

// Ignore unexported fields.
if isUnexported {
continue
}

fieldName := field.Name()
if tag, ok := field.Tag().Lookup(tagKey); ok {
fieldName = tag
}

fields = append(fields, &structField{
Name: fieldName,
Field: chain,
})
}
}
}

return &structDescriptor{
Expand Down
44 changes: 44 additions & 0 deletions decoder_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,50 @@ func TestDecoder_RecordStructInvalidData(t *testing.T) {
assert.Error(t, err)
}

func TestDecoder_RecordEmbeddedStruct(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}
schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
assert.NoError(t, err)

var got TestEmbeddedRecord
err = dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, TestEmbeddedRecord{TestEmbed: TestEmbed{A: 27}, B: "foo"}, got)
}

func TestDecoder_RecordEmbeddedPtrStruct(t *testing.T) {
defer ConfigTeardown()

data := []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}
schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
dec, err := avro.NewDecoder(schema, bytes.NewReader(data))
assert.NoError(t, err)

var got TestEmbeddedPtrRecord
err = dec.Decode(&got)

assert.NoError(t, err)
assert.Equal(t, TestEmbeddedPtrRecord{TestEmbed: &TestEmbed{A: 27}, B: "foo"}, got)
}

func TestDecoder_RecordMap(t *testing.T) {
defer ConfigTeardown()

Expand Down
107 changes: 107 additions & 0 deletions encoder_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,113 @@ func TestEncoder_RecordStructFieldError(t *testing.T) {
assert.Error(t, err)
}

func TestEncoder_RecordEmbeddedStruct(t *testing.T) {
defer ConfigTeardown()

schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
obj := TestEmbeddedRecord{TestEmbed: TestEmbed{A: 27}, B: "foo"}
buf := &bytes.Buffer{}
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

err = enc.Encode(obj)

assert.NoError(t, err)
assert.Equal(t, []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, buf.Bytes())
}

func TestEncoder_RecordEmbeddedPtrStruct(t *testing.T) {
defer ConfigTeardown()

schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
obj := TestEmbeddedPtrRecord{TestEmbed: &TestEmbed{A: 27}, B: "foo"}
buf := &bytes.Buffer{}
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

err = enc.Encode(obj)

assert.NoError(t, err)
assert.Equal(t, []byte{0x36, 0x06, 0x66, 0x6f, 0x6f}, buf.Bytes())
}

func TestEncoder_RecordEmbeddedPtrStructNull(t *testing.T) {
defer ConfigTeardown()

schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
obj := TestEmbeddedPtrRecord{B: "foo"}
buf := &bytes.Buffer{}
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

err = enc.Encode(obj)

assert.Error(t, err)
}

func TestEncoder_RecordEmbeddedIntStruct(t *testing.T) {
defer ConfigTeardown()

schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
obj := TestEmbeddedIntRecord{TestEmbedInt: 27, B: "foo"}
buf := &bytes.Buffer{}
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

err = enc.Encode(obj)

assert.Error(t, err)
}

func TestEncoder_RecordUnexportedStruct(t *testing.T) {
defer ConfigTeardown()

schema := `{
"type": "record",
"name": "test",
"fields" : [
{"name": "a", "type": "long"},
{"name": "b", "type": "string"}
]
}`
obj := TestUnexportedRecord{A: 27, b: "foo"}
buf := &bytes.Buffer{}
enc, err := avro.NewEncoder(schema, buf)
assert.NoError(t, err)

err = enc.Encode(obj)

assert.Error(t, err)
}

func TestEncoder_RecordMap(t *testing.T) {
defer ConfigTeardown()

Expand Down
2 changes: 1 addition & 1 deletion encoder_union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestEncoder_UnionMapWithDuration(t *testing.T) {
assert.NoError(t, err)

m := map[string]interface{}{
"int.time-millis": 123456789*time.Millisecond,
"int.time-millis": 123456789 * time.Millisecond,
}
err = enc.Encode(m)

Expand Down
3 changes: 1 addition & 2 deletions registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ func (c *Client) GetSchema(id int) (avro.Schema, error) {
}

var payload schemaPayload
err := c.request(http.MethodGet, "/schemas/ids/"+strconv.Itoa(id), nil, &payload)
if err != nil {
if err := c.request(http.MethodGet, "/schemas/ids/"+strconv.Itoa(id), nil, &payload); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit 592567f

Please sign in to comment.