From bfddceab172ebcd136bc0679bac11f151cf118d3 Mon Sep 17 00:00:00 2001 From: Divi Date: Thu, 13 Jun 2024 14:29:18 +0530 Subject: [PATCH] introduce query variables in update query (#12) * save * Revert "save" This reverts commit 73357c4b0cc460ee0da3f855113c222cf8627007. * query vars * types * types * fix test * fix test * build * update update signature --- client.go | 5 +- cmd/eywagen/eywatest/eywa_fields.go | 65 ++++++++++++++++++ cmd/eywagen/eywatest/eywa_test.go | 16 +++-- cmd/eywagen/eywatest/eywatest.go | 6 +- cmd/eywagen/main.go | 96 +++++++++++++++++++++++++- eywa.go | 23 ++++++- gql_types.go | 100 ++++++++++++++++++++++++++++ marshal_gql.go | 4 +- query_vars.go | 37 ++++++++++ update.go | 24 +++++-- 10 files changed, 356 insertions(+), 20 deletions(-) create mode 100644 gql_types.go create mode 100644 query_vars.go diff --git a/client.go b/client.go index a7152ba..595781c 100644 --- a/client.go +++ b/client.go @@ -40,9 +40,10 @@ func NewClient(gqlEndpoint string, opt *ClientOpts) *Client { return c } -func (c *Client) do(q string) (*bytes.Buffer, error) { +func (c *Client) do(q Queryable) (*bytes.Buffer, error) { reqObj := graphqlRequest{ - Query: q, + Query: q.Query(), + Variables: q.Variables(), } var reqBytes bytes.Buffer diff --git a/cmd/eywagen/eywatest/eywa_fields.go b/cmd/eywagen/eywatest/eywa_fields.go index 3374f04..216c595 100644 --- a/cmd/eywagen/eywatest/eywa_fields.go +++ b/cmd/eywagen/eywatest/eywa_fields.go @@ -16,6 +16,13 @@ func testTable_NameField(val string) eywa.ModelField[testTable] { Value: val, } } + +func testTable_NameVar(val string) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "name", + Value: eywa.QueryVar("testTable_Name", eywa.StringVar[string](val)), + } +} const testTable_Age eywa.ModelFieldName[testTable] = "age" func testTable_AgeField(val *int) eywa.ModelField[testTable] { @@ -24,6 +31,13 @@ func testTable_AgeField(val *int) eywa.ModelField[testTable] { Value: val, } } + +func testTable_AgeVar(val *int) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "age", + Value: eywa.QueryVar("testTable_Age", eywa.NullableIntVar[*int](val)), + } +} const testTable_ID eywa.ModelFieldName[testTable] = "id" func testTable_IDField(val int) eywa.ModelField[testTable] { @@ -32,6 +46,28 @@ func testTable_IDField(val int) eywa.ModelField[testTable] { Value: val, } } + +func testTable_IDVar(val int) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "id", + Value: eywa.QueryVar("testTable_ID", eywa.IntVar[int](val)), + } +} +const testTable_iD eywa.ModelFieldName[testTable] = "idd" + +func testTable_iDField(val int32) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "idd", + Value: val, + } +} + +func testTable_iDVar(val int32) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "idd", + Value: eywa.QueryVar("testTable_iD", eywa.IntVar[int32](val)), + } +} const testTable_custom eywa.ModelFieldName[testTable] = "custom" func testTable_customField(val *customType) eywa.ModelField[testTable] { @@ -41,6 +77,13 @@ func testTable_customField(val *customType) eywa.ModelField[testTable] { } } +func testTable_customVar[T interface{eywa.JSONValue | eywa.JSONBValue;eywa.TypedValue}](val *customType) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "custom", + Value: eywa.QueryVar("testTable_custom", T{val}), + } +} + func testTable_testTable2(subField eywa.ModelFieldName[testTable2], subFields ...eywa.ModelFieldName[testTable2]) string { buf := bytes.NewBuffer([]byte("testTable2 {")) buf.WriteString(string(subField)) @@ -60,6 +103,28 @@ func testTable_JsonBColField(val jsonbcol) eywa.ModelField[testTable] { } } +func testTable_JsonBColVar[T interface{eywa.JSONValue | eywa.JSONBValue;eywa.TypedValue}](val jsonbcol) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "jsonb_col", + Value: eywa.QueryVar("testTable_JsonBCol", T{val}), + } +} +const testTable_RR eywa.ModelFieldName[testTable] = "r" + +func testTable_RRField(val R) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "r", + Value: val, + } +} + +func testTable_RRVar(val R) eywa.ModelField[testTable] { + return eywa.ModelField[testTable]{ + Name: "r", + Value: eywa.QueryVar("testTable_RR", eywa.StringVar[R](val)), + } +} + const testTable2_ID eywa.ModelFieldName[testTable2] = "id" func testTable2_IDField(val uuid.UUID) eywa.ModelField[testTable2] { diff --git a/cmd/eywagen/eywatest/eywa_test.go b/cmd/eywagen/eywatest/eywa_test.go index cfc9872..481e06b 100644 --- a/cmd/eywagen/eywatest/eywa_test.go +++ b/cmd/eywagen/eywatest/eywa_test.go @@ -42,7 +42,7 @@ func TestUpdateQuery(t *testing.T) { eywa.Eq[testTable](testTable_IDField(3)), ).Set( testTable_NameField("updatetest"), - testTable_JsonBColField(jsonbcol{ + testTable_JsonBColVar[eywa.JSONBValue](jsonbcol{ StrField: "abcd", IntField: 2, BoolField: false, @@ -53,15 +53,23 @@ func TestUpdateQuery(t *testing.T) { testTable_ID, ) - expected := `mutation update_test_table { -update_test_table(where: {id: {_eq: 3}}, _set: {name: "updatetest", jsonb_col: "{\"str_field\":\"abcd\",\"int_field\":2,\"bool_field\":false,\"arr_field\":[1,2,3]}"}) { + expected := `mutation update_test_table($testTable_JsonBCol: jsonb) { +update_test_table(where: {id: {_eq: 3}}, _set: {name: "updatetest", jsonb_col: $testTable_JsonBCol}) { returning { id name } } }` - if assert.Equal(t, expected, q.Query()) { + expectedVars := map[string]interface{}{ + "testTable_JsonBCol": jsonbcol{ + StrField: "abcd", + IntField: 2, + BoolField: false, + ArrField: []int{1, 2, 3}, + }, + } + if assert.Equal(t, expected, q.Query()) && assert.Equal(t, expectedVars, q.Variables()) { accessKey := os.Getenv("TEST_HGE_ACCESS_KEY") c := eywa.NewClient("https://aware-cowbird-80.hasura.app/v1/graphql", &eywa.ClientOpts{ Headers: map[string]string{ diff --git a/cmd/eywagen/eywatest/eywatest.go b/cmd/eywagen/eywatest/eywatest.go index fd064fc..f06bac7 100644 --- a/cmd/eywagen/eywatest/eywatest.go +++ b/cmd/eywagen/eywatest/eywatest.go @@ -2,16 +2,20 @@ package eywatest import "github.com/google/uuid" -//go:generate eywagen -types testTable,testTable2 +//go:generate ../eywagen -types testTable,testTable2 type testTable struct { Name string `json:"name"` Age *int `json:"age"` ID int `json:"id,omitempty"` + iD int32 `json:"idd,omitempty"` custom *customType `json:"custom"` testTable2 *testTable2 `json:"testTable2"` JsonBCol jsonbcol `json:"jsonb_col"` + RR R `json:"r"` } +type R string + func (t testTable) ModelName() string { return "test_table" } diff --git a/cmd/eywagen/main.go b/cmd/eywagen/main.go index a62a15b..97bec3b 100644 --- a/cmd/eywagen/main.go +++ b/cmd/eywagen/main.go @@ -35,6 +35,23 @@ func %sField(val %s) eywa.ModelField[%s] { } } ` + modelScalarVarFunc = ` +func %sVar(val %s) eywa.ModelField[%s] { + return eywa.ModelField[%s]{ + Name: "%s", + Value: eywa.QueryVar("%s", %s[%s](val)), + } +} +` + modelVarFunc = ` +func %sVar[T interface{%s;eywa.TypedValue}](val %s) eywa.ModelField[%s] { + return eywa.ModelField[%s]{ + Name: "%s", + Value: eywa.QueryVar("%s", T{val}), + } +} +` + modelRelationshipNameFunc = ` func %s(subField eywa.ModelFieldName[%s], subFields ...eywa.ModelFieldName[%s]) string { buf := bytes.NewBuffer([]byte("%s {")) @@ -143,8 +160,9 @@ func parseType(typeName string, pkg *types.Package, contents *fileContent) { if fieldTypeNameFull[0] == '*' { fieldTypeName = fieldTypeNameFull[1:] } + fieldScalarGqlType := gqlType(fieldType.Underlying().String()) - // *struct -> struct, *[] -> [] + // *struct -> struct, *[] -> [], *int -> int, etc if ptr, ok := fieldType.(*types.Pointer); ok { fieldType = ptr.Elem() } @@ -154,15 +172,19 @@ func parseType(typeName string, pkg *types.Package, contents *fileContent) { } else if array, ok := fieldType.(*types.Array); ok { fieldType = array.Elem() } - // struct -> *struct + var fieldGqlType string if _, ok := fieldType.Underlying().(*types.Struct); ok { fieldType = types.NewPointer(fieldType) + fieldGqlType = "eywa.JSONValue | eywa.JSONBValue" + } else if _, ok := fieldType.Underlying().(*types.Map); ok { + fieldGqlType = "eywa.JSONValue | eywa.JSONBValue" } switch fieldType := fieldType.(type) { case *types.Pointer: - if types.NewMethodSet(fieldType).Lookup(pkg, "ModelName") != nil { + fieldMethodSet := types.NewMethodSet(fieldType) + if m := fieldMethodSet.Lookup(pkg, "ModelName"); m != nil && m.Type().String() == "func() string" { contents.importsMap["bytes"] = true contents.content.WriteString(fmt.Sprintf( modelRelationshipNameFunc, @@ -187,6 +209,30 @@ func parseType(typeName string, pkg *types.Package, contents *fileContent) { typeName, fieldName, )) + if fieldScalarGqlType != "" { + contents.content.WriteString(fmt.Sprintf( + modelScalarVarFunc, + fmt.Sprintf("%s_%s", typeName, field.Name()), + fieldTypeNameFull, + typeName, + typeName, + fieldName, + fmt.Sprintf("%s_%s", typeName, field.Name()), + fmt.Sprintf("eywa.%s", fieldScalarGqlType), + fieldTypeNameFull, + )) + } else if fieldGqlType != "" { + contents.content.WriteString(fmt.Sprintf( + modelVarFunc, + fmt.Sprintf("%s_%s", typeName, field.Name()), + fieldGqlType, + fieldTypeNameFull, + typeName, + typeName, + fieldName, + fmt.Sprintf("%s_%s", typeName, field.Name()), + )) + } } default: contents.content.WriteString(fmt.Sprintf( @@ -203,6 +249,30 @@ func parseType(typeName string, pkg *types.Package, contents *fileContent) { typeName, fieldName, )) + if fieldScalarGqlType != "" { + contents.content.WriteString(fmt.Sprintf( + modelScalarVarFunc, + fmt.Sprintf("%s_%s", typeName, field.Name()), + fieldTypeNameFull, + typeName, + typeName, + fieldName, + fmt.Sprintf("%s_%s", typeName, field.Name()), + fmt.Sprintf("eywa.%sVar", fieldScalarGqlType), + fieldTypeNameFull, + )) + } else if fieldGqlType != "" { + contents.content.WriteString(fmt.Sprintf( + modelVarFunc, + fmt.Sprintf("%s_%s", typeName, field.Name()), + fieldGqlType, + fieldTypeNameFull, + typeName, + typeName, + fieldName, + fmt.Sprintf("%s_%s", typeName, field.Name()), + )) + } } } for _, t := range recurseParse { @@ -252,3 +322,23 @@ func parseFieldTypeName(name, rootPkgPath string) (sourcePkgPath, typeName strin } return matches[2], fmt.Sprintf("%s%s.%s", matches[1], matches[3], matches[4]) } + +var gqlTypes = map[string]string{ + "bool": "Boolean", + "*bool": "NullableBoolean", + "int": "Int", + "*int": "NullableInt", + "float": "Float", + "*float": "NullableFloat", + "string": "String", + "*string": "NullableString", +} + +func gqlType(fieldType string) string { + for k, v := range gqlTypes { + if strings.HasPrefix(fieldType, k) { + return v + } + } + return "" +} diff --git a/eywa.go b/eywa.go index f2209d4..fd00e5f 100644 --- a/eywa.go +++ b/eywa.go @@ -53,7 +53,7 @@ func (f RawField) GetName() string { return f.Name } func (f RawField) GetValue() string { - if val, ok := f.Value.(gqlMarshaller); ok { + if val, ok := f.Value.(gqlMarshaler); ok { return val.marshalGQL() } val, _ := json.Marshal(f.Value) @@ -66,6 +66,9 @@ func (f RawField) GetValue() string { } return string(val) } +func (f RawField) GetRawValue() interface{} { + return f.Value +} type ModelField[M Model] struct { Name string @@ -76,7 +79,11 @@ func (f ModelField[M]) GetName() string { return f.Name } func (f ModelField[M]) GetValue() string { - if val, ok := f.Value.(gqlMarshaller); ok { + if var_, ok := f.Value.(queryVar); ok { + return fmt.Sprintf("$%s", var_.name) + } + + if val, ok := f.Value.(gqlMarshaler); ok { return val.marshalGQL() } @@ -90,11 +97,15 @@ func (f ModelField[M]) GetValue() string { } return string(val) } +func (f ModelField[M]) GetRawValue() interface{} { + return f.Value +} type Field[M Model] interface { RawField | ModelField[M] GetName() string GetValue() string + GetRawValue() interface{} } type fieldArr[M Model, F Field[M]] []F @@ -118,10 +129,12 @@ func (fs fieldArr[M, MF]) marshalGQL() string { type Queryable interface { Query() string + Variables() map[string]interface{} } type QuerySkeleton[M Model, FN FieldName[M], F Field[M]] struct { ModelName string + queryVars queryVarArr // fields ModelFieldArr[M, FN, F] queryArgs[M, FN, F] } @@ -195,8 +208,12 @@ func (sq GetQuery[M, FN, F]) Query() string { ) } +func (sq GetQuery[M, FN, F]) Variables() map[string]interface{} { + return nil +} + func (sq GetQuery[M, FN, F]) Exec(client *Client) ([]M, error) { - respBytes, err := client.do(sq.Query()) + respBytes, err := client.do(sq) if err != nil { return nil, err } diff --git a/gql_types.go b/gql_types.go new file mode 100644 index 0000000..601bcf7 --- /dev/null +++ b/gql_types.go @@ -0,0 +1,100 @@ +package eywa + +//type Type interface { +// Type() string +//} +// +//type Boolean interface { +// ~bool +//} +//type NullableBoolean interface { +// ~*bool +//} +//type Int interface { +// ~int +//} +//type NullableInt interface { +// ~*int +//} +//type Float interface { +// ~float32 | ~float64 +//} +//type NullableFloat interface { +// ~*float32 | ~*float64 +//} +//type String interface { +// ~string +//} +//type NullableBoolean interface { +// ~*string +//} + +type TypedValue interface { + Type() string + Value() interface{} +} + +type scalarValue struct { + name string + value interface{} +} + +func (tv scalarValue) Type() string { + return tv.name +} +func (tv scalarValue) Value() interface{} { + return tv.value +} + +func BooleanVar[T ~bool](val T) TypedValue { + return scalarValue{"Boolean!", val} +} +func NullableBooleanVar[T ~*bool](val T) TypedValue { + return scalarValue{"Boolean", val} +} +func IntVar[T ~int | ~int8 | ~int16 | ~int32 | ~int64](val T) TypedValue { + return scalarValue{"Int!", val} +} +func NullableIntVar[T ~*int | ~*int8 | ~*int16 | ~*int32 | ~*int64](val T) TypedValue { + return scalarValue{"Int", val} +} +func FloatVar[T ~float64 | ~float32](val T) TypedValue { + return scalarValue{"Float!", val} +} +func NullableFloat[T ~*float64 | ~*float32](val T) TypedValue { + return scalarValue{"Float", val} +} +func StringVar[T ~string](val T) TypedValue { + return scalarValue{"String!", val} +} +func NullableStringVar[T ~*string](val T) TypedValue { + return scalarValue{"String", val} +} +func JSONVar(val interface{}) TypedValue { + return JSONValue{val} +} +func JSONBVar(val interface{}) TypedValue { + return JSONBValue{val} +} + +type JSONValue struct { + Val interface{} +} + +func (jv JSONValue) Type() string { + return "json" +} +func (jv JSONValue) Value() interface{} { + return jv.Val +} + +type JSONBValue struct { + Val interface{} +} + +func (jv JSONBValue) Type() string { + return "jsonb" +} +func (jv JSONBValue) Value() interface{} { + return jv.Val +} diff --git a/marshal_gql.go b/marshal_gql.go index 31f83c8..c6f9842 100644 --- a/marshal_gql.go +++ b/marshal_gql.go @@ -1,6 +1,6 @@ package eywa -type gqlMarshaller interface { +type gqlMarshaler interface { marshalGQL() string } @@ -11,6 +11,6 @@ func (he HasuraEnum) marshalGQL() string { return string(he) } -func x(q gqlMarshaller) string { +func x(q gqlMarshaler) string { return "abcd" } diff --git a/query_vars.go b/query_vars.go new file mode 100644 index 0000000..0a85b98 --- /dev/null +++ b/query_vars.go @@ -0,0 +1,37 @@ +package eywa + +import ( + "bytes" + "fmt" +) + +type queryVar struct { + name string + value TypedValue +} + +func (v queryVar) marshalGQL() string { + return fmt.Sprintf("$%s: %s", v.name, v.value.Type()) +} + +type queryVarArr []queryVar + +func (vs queryVarArr) marshalGQL() string { + if len(vs) == 0 { + return "" + } + buf := bytes.NewBufferString("(") + for i, v := range vs { + if i > 0 { + buf.WriteString(", ") + } + buf.WriteString(v.marshalGQL()) + } + buf.WriteString(")") + return buf.String() + +} + +func QueryVar(name string, value TypedValue) queryVar { + return queryVar{name, value} +} diff --git a/update.go b/update.go index afbf07d..2692bc0 100644 --- a/update.go +++ b/update.go @@ -20,6 +20,11 @@ type UpdateQueryBuilder[M Model, FN FieldName[M], F Field[M]] struct { func (uq UpdateQueryBuilder[M, FN, F]) Set(fields ...F) UpdateQueryBuilder[M, FN, F] { uq.set = &set[M, F]{fieldArr[M, F](fields)} + for _, f := range fields { + if var_, ok := f.GetRawValue().(queryVar); ok { + uq.queryVars = append(uq.queryVars, var_) + } + } return uq } @@ -50,7 +55,7 @@ type UpdateQuery[M Model, FN FieldName[M], F Field[M]] struct { fields []FN } -func (uq *UpdateQuery[M, FN, F]) marshalGQL() string { +func (uq UpdateQuery[M, FN, F]) marshalGQL() string { return fmt.Sprintf( "%s {\nreturning {\n%s\n}\n}", uq.uq.marshalGQL(), @@ -58,16 +63,25 @@ func (uq *UpdateQuery[M, FN, F]) marshalGQL() string { ) } -func (uq *UpdateQuery[M, FN, F]) Query() string { +func (uq UpdateQuery[M, FN, F]) Query() string { return fmt.Sprintf( - "mutation update_%s {\n%s\n}", + "mutation update_%s%s {\n%s\n}", uq.uq.ModelName, + uq.uq.queryVars.marshalGQL(), uq.marshalGQL(), ) } -func (uq *UpdateQuery[M, FN, F]) Exec(client *Client) ([]M, error) { - respBytes, err := client.do(uq.Query()) +func (uq UpdateQuery[M, FN, F]) Variables() map[string]interface{} { + vars := map[string]interface{}{} + for _, var_ := range uq.uq.queryVars { + vars[var_.name] = var_.value.Value() + } + return vars +} + +func (uq UpdateQuery[M, FN, F]) Exec(client *Client) ([]M, error) { + respBytes, err := client.do(uq) if err != nil { return nil, err }