Skip to content

Commit

Permalink
misc: restructure content type encoder modules (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac authored Dec 5, 2024
1 parent 79c10b5 commit 4b22101
Show file tree
Hide file tree
Showing 12 changed files with 1,175 additions and 1,109 deletions.
312 changes: 245 additions & 67 deletions connector/internal/contenttype/multipart.go
Original file line number Diff line number Diff line change
@@ -1,100 +1,278 @@
package contenttype

import (
"encoding/json"
"bytes"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"reflect"
"slices"

"github.com/hasura/ndc-http/ndc-http-schema/schema"
rest "github.com/hasura/ndc-http/ndc-http-schema/schema"
"github.com/hasura/ndc-sdk-go/schema"
"github.com/hasura/ndc-sdk-go/utils"
)

// MultipartWriter extends multipart.Writer with helpers
type MultipartWriter struct {
*multipart.Writer
// MultipartFormEncoder implements a multipart/form encoder.
type MultipartFormEncoder struct {
schema *rest.NDCHttpSchema
paramEncoder *URLParameterEncoder
operation *rest.OperationInfo
arguments map[string]any
}

// NewMultipartWriter creates a MultipartWriter instance
func NewMultipartWriter(w io.Writer) *MultipartWriter {
return &MultipartWriter{multipart.NewWriter(w)}
func NewMultipartFormEncoder(schema *rest.NDCHttpSchema, operation *rest.OperationInfo, arguments map[string]any) *MultipartFormEncoder {
return &MultipartFormEncoder{
schema: schema,
paramEncoder: NewURLParameterEncoder(schema),
operation: operation,
arguments: arguments,
}
}

// WriteDataURI write a file from data URI string
func (w *MultipartWriter) WriteDataURI(name string, value any, headers http.Header) error {
b64, err := utils.DecodeString(value)
if err != nil {
return fmt.Errorf("%s: %w", name, err)
}
dataURI, err := DecodeDataURI(b64)
if err != nil {
return fmt.Errorf("%s: %w", name, err)
// Encode the multipart form.
func (c *MultipartFormEncoder) Encode(bodyData any) (*bytes.Reader, string, error) {
bodyInfo, ok := c.operation.Arguments[rest.BodyKey]
if !ok {
return nil, "", errRequestBodyTypeRequired
}

h := make(textproto.MIMEHeader)
for key, header := range headers {
h[key] = header
}
h.Set("Content-Disposition",
fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
escapeQuotes(name), escapeQuotes(name)))

if dataURI.MediaType == "" {
h.Set("Content-Type", schema.ContentTypeOctetStream)
} else {
h.Set("Content-Type", dataURI.MediaType)
}
buffer := new(bytes.Buffer)
writer := NewMultipartWriter(buffer)

p, err := w.CreatePart(h)
if err != nil {
return fmt.Errorf("%s: %w", name, err)
if err := c.evalMultipartForm(writer, &bodyInfo, reflect.ValueOf(bodyData)); err != nil {
return nil, "", err
}
if err := writer.Close(); err != nil {
return nil, "", err
}

_, err = p.Write([]byte(dataURI.Data))
reader := bytes.NewReader(buffer.Bytes())
buffer.Reset()

return err
return reader, writer.FormDataContentType(), nil
}

// WriteField calls CreateFormField and then writes the given value with json encoding.
func (w *MultipartWriter) WriteJSON(fieldName string, value any, headers http.Header) error {
bs, err := json.Marshal(value)
if err != nil {
return err
func (mfb *MultipartFormEncoder) evalMultipartForm(w *MultipartWriter, bodyInfo *rest.ArgumentInfo, bodyData reflect.Value) error {
bodyData, ok := utils.UnwrapPointerFromReflectValue(bodyData)
if !ok {
return nil
}
switch bodyType := bodyInfo.Type.Interface().(type) {
case *schema.NullableType:
return mfb.evalMultipartForm(w, &rest.ArgumentInfo{
ArgumentInfo: schema.ArgumentInfo{
Type: bodyType.UnderlyingType,
},
HTTP: bodyInfo.HTTP,
}, bodyData)
case *schema.NamedType:
if !ok {
return fmt.Errorf("%s: %w", rest.BodyKey, errArgumentRequired)
}
bodyObject, ok := mfb.schema.ObjectTypes[bodyType.Name]
if !ok {
break
}
kind := bodyData.Kind()
switch kind {
case reflect.Map, reflect.Interface:
bi := bodyData.Interface()
bodyMap, ok := bi.(map[string]any)
if !ok {
return fmt.Errorf("invalid multipart form body, expected object, got %v", bi)
}

h := createFieldMIMEHeader(fieldName, headers)
h.Set(schema.ContentTypeHeader, schema.ContentTypeJSON)
p, err := w.CreatePart(h)
if err != nil {
return err
}
for key, fieldInfo := range bodyObject.Fields {
fieldValue := bodyMap[key]
var enc *rest.EncodingObject
if len(mfb.operation.Request.RequestBody.Encoding) > 0 {
en, ok := mfb.operation.Request.RequestBody.Encoding[key]
if ok {
enc = &en
}
}

_, err = p.Write(bs)
if err := mfb.evalMultipartFieldValueRecursive(w, key, reflect.ValueOf(fieldValue), &fieldInfo, enc); err != nil {
return err
}
}

return err
}
return nil
case reflect.Struct:
reflectType := bodyData.Type()
for fieldIndex := range bodyData.NumField() {
fieldValue := bodyData.Field(fieldIndex)
fieldType := reflectType.Field(fieldIndex)
fieldInfo, ok := bodyObject.Fields[fieldType.Name]
if !ok {
continue
}

var enc *rest.EncodingObject
if len(mfb.operation.Request.RequestBody.Encoding) > 0 {
en, ok := mfb.operation.Request.RequestBody.Encoding[fieldType.Name]
if ok {
enc = &en
}
}

if err := mfb.evalMultipartFieldValueRecursive(w, fieldType.Name, fieldValue, &fieldInfo, enc); err != nil {
return err
}
}

// WriteField calls CreateFormField and then writes the given value.
func (w *MultipartWriter) WriteField(fieldName, value string, headers http.Header) error {
h := createFieldMIMEHeader(fieldName, headers)
p, err := w.CreatePart(h)
if err != nil {
return err
return nil
}
}
_, err = p.Write([]byte(value))

return err
return fmt.Errorf("invalid multipart form body, expected object, got %v", bodyInfo.Type)
}

func createFieldMIMEHeader(fieldName string, headers http.Header) textproto.MIMEHeader {
h := make(textproto.MIMEHeader)
for key, header := range headers {
h[key] = header
func (mfb *MultipartFormEncoder) evalMultipartFieldValueRecursive(w *MultipartWriter, name string, value reflect.Value, fieldInfo *rest.ObjectField, enc *rest.EncodingObject) error {
underlyingValue, notNull := utils.UnwrapPointerFromReflectValue(value)
argTypeT, err := fieldInfo.Type.InterfaceT()
switch argType := argTypeT.(type) {
case *schema.ArrayType:
if !notNull {
return fmt.Errorf("%s: %w", name, errArgumentRequired)
}
if enc != nil && slices.Contains(enc.ContentType, rest.ContentTypeJSON) {
var headers http.Header
var err error
if len(enc.Headers) > 0 {
headers, err = mfb.evalEncodingHeaders(enc.Headers)
if err != nil {
return err
}
}

return w.WriteJSON(name, value.Interface(), headers)
}

if !slices.Contains([]reflect.Kind{reflect.Slice, reflect.Array}, value.Kind()) {
return fmt.Errorf("%s: expected array type, got %v", name, value.Kind())
}

for i := range value.Len() {
elem := value.Index(i)
err := mfb.evalMultipartFieldValueRecursive(w, name+"[]", elem, &rest.ObjectField{
ObjectField: schema.ObjectField{
Type: argType.ElementType,
},
HTTP: fieldInfo.HTTP.Items,
}, enc)
if err != nil {
return err
}
}

return nil
case *schema.NullableType:
if !notNull {
return nil
}

return mfb.evalMultipartFieldValueRecursive(w, name, underlyingValue, &rest.ObjectField{
ObjectField: schema.ObjectField{
Type: argType.UnderlyingType,
},
HTTP: fieldInfo.HTTP,
}, enc)
case *schema.NamedType:
if !notNull {
return fmt.Errorf("%s: %w", name, errArgumentRequired)
}
var headers http.Header
var err error
if enc != nil && len(enc.Headers) > 0 {
headers, err = mfb.evalEncodingHeaders(enc.Headers)
if err != nil {
return err
}
}

if iScalar, ok := mfb.schema.ScalarTypes[argType.Name]; ok {
switch iScalar.Representation.Interface().(type) {
case *schema.TypeRepresentationBytes:
return w.WriteDataURI(name, value.Interface(), headers)
default:
}
}

if enc != nil && slices.Contains(enc.ContentType, rest.ContentTypeJSON) {
return w.WriteJSON(name, value, headers)
}

params, err := mfb.paramEncoder.EncodeParameterValues(fieldInfo, value, []string{})
if err != nil {
return err
}

if len(params) == 0 {
return nil
}

for _, p := range params {
keys := p.Keys()
values := p.Values()
fieldName := name

if len(keys) > 0 {
keys = append([]Key{NewKey(name)}, keys...)
fieldName = keys.String()
}

if len(values) > 1 {
fieldName += "[]"
for _, v := range values {
if err = w.WriteField(fieldName, v, headers); err != nil {
return err
}
}
} else if len(values) == 1 {
if err = w.WriteField(fieldName, values[0], headers); err != nil {
return err
}
}
}

return nil
case *schema.PredicateType:
return fmt.Errorf("%s: predicate type is not supported", name)
default:
return fmt.Errorf("%s: %w", name, err)
}
}

func (mfb *MultipartFormEncoder) evalEncodingHeaders(encHeaders map[string]rest.RequestParameter) (http.Header, error) {
results := http.Header{}
for key, param := range encHeaders {
argumentName := param.ArgumentName
if argumentName == "" {
argumentName = key
}
argumentInfo, ok := mfb.operation.Arguments[argumentName]
if !ok {
continue
}
rawHeaderValue, ok := mfb.arguments[argumentName]
if !ok {
continue
}

headerParams, err := mfb.paramEncoder.EncodeParameterValues(&rest.ObjectField{
ObjectField: schema.ObjectField{
Type: argumentInfo.Type,
},
HTTP: param.Schema,
}, reflect.ValueOf(rawHeaderValue), []string{})
if err != nil {
return nil, err
}

param.Name = key
SetHeaderParameters(&results, &param, headerParams)
}
h.Set("Content-Disposition",
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(fieldName)))

return h
return results, nil
}
Loading

0 comments on commit 4b22101

Please sign in to comment.