Skip to content

Commit

Permalink
Add support for Go SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjohnsonpint committed Aug 16, 2024
1 parent 6882355 commit 39b4fed
Show file tree
Hide file tree
Showing 43 changed files with 1,602 additions and 372 deletions.
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
"legacymodels",
"Lessable",
"Macbook",
"makeslice",
"malloc",
"manifestdata",
"mapstructure",
"Msgf",
Expand All @@ -79,6 +81,7 @@
"pgxpool",
"pluginmanager",
"promhttp",
"ptrs",
"reindex",
"renameio",
"schemagen",
Expand All @@ -100,6 +103,7 @@
"urlpkg",
"userid",
"usize",
"vals",
"vecs",
"viterin",
"walltime",
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## UNRELEASED

- Runtime support for Hypermode Go SDK [#317](https://github.com/hypermodeAI/runtime/pull/317)
- Improve logger registration code [#335](https://github.com/hypermodeAI/runtime/pull/335)
- Add dgraph host functions [#336](https://github.com/hypermodeAI/runtime/pull/336)

Expand Down
8 changes: 4 additions & 4 deletions collections/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,17 @@ func validateEmbedder(ctx context.Context, embedder string) error {
lti := plugins.GetPlugin(ctx).Language.TypeInfo()

p := fn.Parameters[0]
if !lti.IsArrayType(p.Type) || !lti.IsStringType(lti.GetArraySubtype(p.Type)) {
if !lti.IsListType(p.Type) || !lti.IsStringType(lti.GetListSubtype(p.Type)) {
return errInvalidEmbedderSignature
}

r := fn.Results[0]
if !lti.IsArrayType(r.Type) {
if !lti.IsListType(r.Type) {
return errInvalidEmbedderSignature
}

a := lti.GetArraySubtype(r.Type)
if !lti.IsArrayType(a) || !lti.IsFloatType(lti.GetArraySubtype(a)) {
t := lti.GetListSubtype(r.Type)
if !lti.IsListType(t) || !lti.IsFloatType(lti.GetListSubtype(t)) {
return errInvalidEmbedderSignature
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ require (
github.com/rs/cors v1.11.0
github.com/rs/xid v1.5.0
github.com/rs/zerolog v1.33.0
github.com/spf13/cast v1.7.0
github.com/stretchr/testify v1.9.0
github.com/tetratelabs/wazero v1.8.0
github.com/viterin/vek v0.4.2
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ github.com/dop251/goja v0.0.0-20240806095544-3491d4a58fbe h1:jwFJkgsdelB87ohlXaA
github.com/dop251/goja v0.0.0-20240806095544-3491d4a58fbe/go.mod h1:DF+w/nLMIkvRpyhd/0K+Okbh3fVZBtXLwRtS/ccAa5w=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/getsentry/sentry-go v0.28.1 h1:zzaSm/vHmGllRM6Tpx1492r0YDzauArdBfkJRtY6P5k=
github.com/getsentry/sentry-go v0.28.1/go.mod h1:1fQZ+7l7eeJ3wYi82q5Hg8GqAPgefRq+FP/QhafYVgg=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
Expand Down Expand Up @@ -223,6 +225,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sosodev/duration v1.2.0 h1:pqK/FLSjsAADWY74SyWDCjOcd5l7H8GSnnOGEB9A1Us=
github.com/sosodev/duration v1.2.0/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERAikUR6SDg=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
Expand Down
124 changes: 67 additions & 57 deletions graphql/schemagen/schemagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"context"
"fmt"
"slices"
"sort"
"strings"

"hmruntime/languages"
Expand All @@ -29,15 +30,14 @@ func GetGraphQLSchema(ctx context.Context, md *metadata.Metadata) (*GraphQLSchem
lti := languages.GetLanguageForSDK(md.SDK).TypeInfo()
typeDefs, errors := transformTypes(md.Types, lti)
functions, errs := transformFunctions(md.FnExports, typeDefs, lti)
types := utils.MapValues(typeDefs)
errors = append(errors, errs...)

if len(errors) > 0 {
return nil, fmt.Errorf("failed to generate schema: %+v", errors)
}

functions = filterFunctions(functions)
types = filterTypes(types, functions)
types := filterTypes(utils.MapValues(typeDefs), functions)

buf := bytes.Buffer{}
writeSchema(&buf, functions, types)
Expand All @@ -60,6 +60,10 @@ type TransformError struct {
Error error
}

func (e *TransformError) String() string {
return fmt.Sprintf("source: %+v, error: %v", e.Source, e.Error)
}

func transformTypes(types metadata.TypeMap, lti languages.TypeInfo) (map[string]*TypeDefinition, []*TransformError) {
typeDefs := make(map[string]*TypeDefinition, len(types))
errors := make([]*TransformError, 0)
Expand Down Expand Up @@ -113,7 +117,11 @@ func transformFunctions(functions metadata.FunctionMap, typeDefs map[string]*Typ
errors := make([]*TransformError, 0)

i := 0
for _, f := range functions {
fnNames := utils.MapKeys(functions)
sort.Strings(fnNames)
for _, name := range fnNames {
f := functions[name]

params, err := convertParameters(f.Parameters, lti, typeDefs)
if err != nil {
errors = append(errors, &TransformError{f, err})
Expand Down Expand Up @@ -211,10 +219,10 @@ func writeSchema(buf *bytes.Buffer, functions []*FunctionSignature, typeDefs []*

// sort functions and type definitions
slices.SortFunc(functions, func(a, b *FunctionSignature) int {
return cmp.Compare(a.Name, b.Name)
return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})
slices.SortFunc(typeDefs, func(a, b *TypeDefinition) int {
return cmp.Compare(a.Name, b.Name)
return cmp.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
})

// write query functions
Expand Down Expand Up @@ -324,26 +332,28 @@ func convertResults(results []*metadata.Result, lti languages.TypeInfo, typeDefs
case 1:
// Note: Single result doesn't use the name, even if it's present.
return convertType(results[0].Type, lti, typeDefs, false)
default:
fields := getFieldsFromResults(results)
t := getTypeForFields(fields, typeDefs)
return t, nil
}
}

func getFieldsFromResults(results []*metadata.Result) []*NameTypePair {
fields := make([]*NameTypePair, len(results))
for i, r := range results {
name := r.Name
if name == "" {
name = fmt.Sprintf("field_%d", i+1)
name = fmt.Sprintf("item%d", i+1)
}

typ, err := convertType(r.Type, lti, typeDefs, false)
if err != nil {
return "", err
}

fields[i] = &NameTypePair{
Name: name,
Type: r.Type,
Type: typ,
}
}
return fields

t := getTypeForFields(fields, typeDefs)
return t, nil
}

func getTypeForFields(fields []*NameTypePair, typeDefs map[string]*TypeDefinition) string {
Expand All @@ -369,7 +379,7 @@ func getTypeForFields(fields []*NameTypePair, typeDefs map[string]*TypeDefinitio
// there's no existing type that matches, so create a new one
var name string
for i := 1; ; i++ {
name = fmt.Sprintf("type_%d", i)
name = fmt.Sprintf("_type%d", i)
if _, ok := typeDefs[name]; !ok {
break
}
Expand Down Expand Up @@ -415,48 +425,6 @@ func convertType(typ string, lti languages.TypeInfo, typeDefs map[string]*TypeDe
n = "!"
}

// check for array types
if lti.IsArrayType(typ) {
elem := lti.GetArraySubtype(typ)
t, err := convertType(elem, lti, typeDefs, firstPass)
if err != nil {
return "", err
}
return "[" + t + "]" + n, nil
}

// check for map types
if lti.IsMapType(typ) {
k, v := lti.GetMapSubtypes(typ)
kt, err := convertType(k, lti, typeDefs, firstPass)
if err != nil {
return "", err
}
vt, err := convertType(v, lti, typeDefs, firstPass)
if err != nil {
return "", err
}

// The pair type name will be composed from the key and value types.
// ex: StringStringPair, IntStringPair, StringNullableStringPair, etc.
ktn := utils.If(strings.HasSuffix(kt, "!"), kt[:len(kt)-1], "Nullable"+kt)
vtn := utils.If(strings.HasSuffix(vt, "!"), vt[:len(vt)-1], "Nullable"+vt)
if ktn[0] == '[' {
ktn = ktn[1:len(ktn)-2] + "List"
}
if vtn[0] == '[' {
vtn = vtn[1:len(vtn)-2] + "List"
}
typeName := ktn + vtn + "Pair"

newMapType(typeName, []*NameTypePair{{"key", kt}, {"value", vt}}, typeDefs)

// The map is represented as a list of the pair type.
// The list might be nullable, but the pair type within the list is always non-nullable.
// ex: [StringStringPair!] or [StringStringPair!]!
return "[" + typeName + "!]" + n, nil
}

// convert basic types
// TODO: How do we want to provide GraphQL "ID" scalar types? Maybe they're annotated? or maybe by naming convention?

Expand Down Expand Up @@ -503,6 +471,48 @@ func convertType(typ string, lti languages.TypeInfo, typeDefs map[string]*TypeDe
return newScalar("Timestamp", typeDefs) + n, nil
}

// check for array types
if lti.IsListType(typ) {
elem := lti.GetListSubtype(typ)
t, err := convertType(elem, lti, typeDefs, firstPass)
if err != nil {
return "", err
}
return "[" + t + "]" + n, nil
}

// check for map types
if lti.IsMapType(typ) {
k, v := lti.GetMapSubtypes(typ)
kt, err := convertType(k, lti, typeDefs, firstPass)
if err != nil {
return "", err
}
vt, err := convertType(v, lti, typeDefs, firstPass)
if err != nil {
return "", err
}

// The pair type name will be composed from the key and value types.
// ex: StringStringPair, IntStringPair, StringNullableStringPair, etc.
ktn := utils.If(strings.HasSuffix(kt, "!"), kt[:len(kt)-1], "Nullable"+kt)
vtn := utils.If(strings.HasSuffix(vt, "!"), vt[:len(vt)-1], "Nullable"+vt)
if ktn[0] == '[' {
ktn = ktn[1:len(ktn)-2] + "List"
}
if vtn[0] == '[' {
vtn = vtn[1:len(vtn)-2] + "List"
}
typeName := ktn + vtn + "Pair"

newMapType(typeName, []*NameTypePair{{"key", kt}, {"value", vt}}, typeDefs)

// The map is represented as a list of the pair type.
// The list might be nullable, but the pair type within the list is always non-nullable.
// ex: [StringStringPair!] or [StringStringPair!]!
return "[" + typeName + "!]" + n, nil
}

name := lti.GetNameForType(typ)

// in the first pass, we convert input custom type definitions
Expand Down
23 changes: 4 additions & 19 deletions hostfunctions/hostfunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"hmruntime/languages"
"hmruntime/logger"
"hmruntime/manifestdata"
"hmruntime/plugins"
"hmruntime/sqlclient"
"hmruntime/utils"
"hmruntime/wasmhost"
Expand Down Expand Up @@ -95,14 +94,14 @@ func readParams(ctx context.Context, mod wasm.Module, stack []uint64, params ...
return fmt.Errorf("expected a stack of size %d, but got %d", len(params), len(stack))
}

adapter, err := getWasmAdapter(ctx)
adapter, err := languages.GetWasmAdapter(ctx)
if err != nil {
return err
}

errs := make([]error, 0, len(params))
for i, p := range params {
if err := adapter.DecodeValue(ctx, mod, stack[i], p); err != nil {
if err := adapter.DecodeValue(ctx, stack[i], p); err != nil {
errs = append(errs, err)
}
}
Expand All @@ -124,14 +123,14 @@ func writeResults(ctx context.Context, mod wasm.Module, stack []uint64, results
return fmt.Errorf("not enough stack space to write %d results", len(results))
}

adapter, err := getWasmAdapter(ctx)
adapter, err := languages.GetWasmAdapter(ctx)
if err != nil {
return err
}

errs := make([]error, 0, len(results))
for i, r := range results {
val, err := adapter.EncodeValue(ctx, mod, r)
val, err := adapter.EncodeValue(ctx, r)
if err != nil {
stack[i] = 0
errs = append(errs, err)
Expand All @@ -148,20 +147,6 @@ func writeResults(ctx context.Context, mod wasm.Module, stack []uint64, results

}

func getWasmAdapter(ctx context.Context) (languages.WasmAdapter, error) {
p := plugins.GetPlugin(ctx)
if p == nil {
return nil, errors.New("no plugin found in context")
}

wa := p.Language.WasmAdapter()
if wa == nil {
return nil, errors.New("no wasm adapter found in plugin")
}

return wa, nil
}

// Each message is optional, but if provided, it will be logged at the appropriate time.
type hostFunctionMessages struct {
Starting string
Expand Down
Loading

0 comments on commit 39b4fed

Please sign in to comment.