From a9842e586908a6b5d8cd418b02f0d5e2ae935db1 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Thu, 9 Jan 2025 22:14:14 -0800 Subject: [PATCH] refactor around drivers and dbs, drivers do the core db work, db is a logical separation --- api.go | 60 +++++++------- api_mutation_gen.go | 20 ++--- api_mutation_helpers.go | 30 +++---- api_query_execution.go | 40 +++++----- api_types.go | 20 ++--- config.go | 4 +- db.go | 163 +------------------------------------- driver.go | 170 ++++++++++++++++++++++++++++++++++++++-- driver_test.go | 18 ++--- live.go | 16 ++-- live_benchmark_test.go | 2 +- live_test.go | 18 ++--- vector_test.go | 14 ++-- 13 files changed, 287 insertions(+), 288 deletions(-) diff --git a/api.go b/api.go index ce04050..5d936c2 100644 --- a/api.go +++ b/api.go @@ -18,13 +18,13 @@ import ( "github.com/hypermodeinc/modusdb/api/structreflect" ) -func Create[T any](driver *Driver, object T, ns ...uint64) (uint64, T, error) { +func Create[T any](driver *Driver, object T, dbId ...uint64) (uint64, T, error) { driver.mutex.Lock() defer driver.mutex.Unlock() - if len(ns) > 1 { + if len(dbId) > 1 { return 0, object, fmt.Errorf("only one namespace is allowed") } - ctx, n, err := getDefaultNamespace(driver, ns...) + ctx, db, err := getDefaultDB(driver, dbId...) if err != nil { return 0, object, err } @@ -36,12 +36,12 @@ func Create[T any](driver *Driver, object T, ns ...uint64) (uint64, T, error) { dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, db, object, gid, &dms, sch) if err != nil { return 0, object, err } - err = n.alterSchemaWithParsed(ctx, sch) + err = driver.alterSchemaWithParsed(ctx, sch) if err != nil { return 0, object, err } @@ -51,19 +51,19 @@ func Create[T any](driver *Driver, object T, ns ...uint64) (uint64, T, error) { return 0, object, err } - return getByGid[T](ctx, n, gid) + return getByGid[T](ctx, db, gid) } -func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, error) { +func Upsert[T any](driver *Driver, object T, dbId ...uint64) (uint64, T, bool, error) { var wasFound bool driver.mutex.Lock() defer driver.mutex.Unlock() - if len(ns) > 1 { + if len(dbId) > 1 { return 0, object, false, fmt.Errorf("only one namespace is allowed") } - ctx, n, err := getDefaultNamespace(driver, ns...) + ctx, db, err := getDefaultDB(driver, dbId...) if err != nil { return 0, object, false, err } @@ -82,18 +82,18 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, db, object, gid, &dms, sch) if err != nil { return 0, object, false, err } - err = n.alterSchemaWithParsed(ctx, sch) + err = db.driver.alterSchemaWithParsed(ctx, sch) if err != nil { return 0, object, false, err } if gid != 0 || cf != nil { - gid, err = getExistingObject[T](ctx, n, gid, cf, object) + gid, err = getExistingObject[T](ctx, db, gid, cf, object) if err != nil && err != apiutils.ErrNoObjFound { return 0, object, false, err } @@ -108,7 +108,7 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err } dms = make([]*dql.Mutation, 0) - err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema[T](ctx, db, object, gid, &dms, sch) if err != nil { return 0, object, false, err } @@ -118,7 +118,7 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err return 0, object, false, err } - gid, object, err = getByGid[T](ctx, n, gid) + gid, object, err = getByGid[T](ctx, db, gid) if err != nil { return 0, object, false, err } @@ -126,60 +126,60 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err return gid, object, wasFound, nil } -func Get[T any, R UniqueField](driver *Driver, uniqueField R, ns ...uint64) (uint64, T, error) { +func Get[T any, R UniqueField](driver *Driver, uniqueField R, dbId ...uint64) (uint64, T, error) { driver.mutex.Lock() defer driver.mutex.Unlock() var obj T - if len(ns) > 1 { + if len(dbId) > 1 { return 0, obj, fmt.Errorf("only one namespace is allowed") } - ctx, n, err := getDefaultNamespace(driver, ns...) + ctx, db, err := getDefaultDB(driver, dbId...) if err != nil { return 0, obj, err } if uid, ok := any(uniqueField).(uint64); ok { - return getByGid[T](ctx, n, uid) + return getByGid[T](ctx, db, uid) } if cf, ok := any(uniqueField).(ConstrainedField); ok { - return getByConstrainedField[T](ctx, n, cf) + return getByConstrainedField[T](ctx, db, cf) } return 0, obj, fmt.Errorf("invalid unique field type") } -func Query[T any](driver *Driver, queryParams QueryParams, ns ...uint64) ([]uint64, []T, error) { +func Query[T any](driver *Driver, queryParams QueryParams, dbId ...uint64) ([]uint64, []T, error) { driver.mutex.Lock() defer driver.mutex.Unlock() - if len(ns) > 1 { + if len(dbId) > 1 { return nil, nil, fmt.Errorf("only one namespace is allowed") } - ctx, n, err := getDefaultNamespace(driver, ns...) + ctx, db, err := getDefaultDB(driver, dbId...) if err != nil { return nil, nil, err } - return executeQuery[T](ctx, n, queryParams, true) + return executeQuery[T](ctx, db, queryParams, true) } -func Delete[T any, R UniqueField](driver *Driver, uniqueField R, ns ...uint64) (uint64, T, error) { +func Delete[T any, R UniqueField](driver *Driver, uniqueField R, dbId ...uint64) (uint64, T, error) { driver.mutex.Lock() defer driver.mutex.Unlock() var zeroObj T - if len(ns) > 1 { + if len(dbId) > 1 { return 0, zeroObj, fmt.Errorf("only one namespace is allowed") } - ctx, n, err := getDefaultNamespace(driver, ns...) + ctx, db, err := getDefaultDB(driver, dbId...) if err != nil { return 0, zeroObj, err } if uid, ok := any(uniqueField).(uint64); ok { - uid, obj, err := getByGid[T](ctx, n, uid) + uid, obj, err := getByGid[T](ctx, db, uid) if err != nil { return 0, zeroObj, err } - dms := generateDeleteDqlMutations(n, uid) + dms := generateDeleteDqlMutations(db, uid) err = applyDqlMutations(ctx, driver, dms) if err != nil { @@ -190,12 +190,12 @@ func Delete[T any, R UniqueField](driver *Driver, uniqueField R, ns ...uint64) ( } if cf, ok := any(uniqueField).(ConstrainedField); ok { - uid, obj, err := getByConstrainedField[T](ctx, n, cf) + uid, obj, err := getByConstrainedField[T](ctx, db, cf) if err != nil { return 0, zeroObj, err } - dms := generateDeleteDqlMutations(n, uid) + dms := generateDeleteDqlMutations(db, uid) err = applyDqlMutations(ctx, driver, dms) if err != nil { diff --git a/api_mutation_gen.go b/api_mutation_gen.go index c2a304b..1fe4c45 100644 --- a/api_mutation_gen.go +++ b/api_mutation_gen.go @@ -26,7 +26,7 @@ import ( "github.com/hypermodeinc/modusdb/api/structreflect" ) -func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object T, +func generateSetDqlMutationsAndSchema[T any](ctx context.Context, d *DB, object T, gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error { t := reflect.TypeOf(object) if t.Kind() != reflect.Struct { @@ -49,7 +49,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object if tagMaps.JsonToReverseEdge[jsonName] != "" { reverseEdgeStr := tagMaps.JsonToReverseEdge[jsonName] typeName := strings.Split(reverseEdgeStr, ".")[0] - currSchema, err := getSchema(ctx, n) + currSchema, err := getSchema(ctx, d) if err != nil { return err } @@ -70,7 +70,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object } if !(typeFound && predicateFound) { - if err := mutations.HandleReverseEdge(jsonName, reflectValueType, n.ID(), sch, + if err := mutations.HandleReverseEdge(jsonName, reflectValueType, d.ID(), sch, reverseEdgeStr); err != nil { return err } @@ -82,17 +82,17 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object continue } - value, err = processStructValue(ctx, value, n) + value, err = processStructValue(ctx, value, d) if err != nil { return err } - value, err = processPointerValue(ctx, value, n) + value, err = processPointerValue(ctx, value, d) if err != nil { return err } - nquad, u, err := mutations.CreateNQuadAndSchema(value, gid, jsonName, t, n.ID()) + nquad, u, err := mutations.CreateNQuadAndSchema(value, gid, jsonName, t, d.ID()) if err != nil { return err } @@ -111,7 +111,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object } sch.Types = append(sch.Types, &pb.TypeUpdate{ - TypeName: apiutils.AddNamespace(n.ID(), t.Name()), + TypeName: apiutils.AddNamespace(d.ID(), t.Name()), Fields: sch.Preds, }) @@ -120,7 +120,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object return err } typeNquad := &api.NQuad{ - Namespace: n.ID(), + Namespace: d.ID(), Subject: fmt.Sprint(gid), Predicate: "dgraph.type", ObjectValue: val, @@ -134,11 +134,11 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object return nil } -func generateDeleteDqlMutations(n *DB, gid uint64) []*dql.Mutation { +func generateDeleteDqlMutations(d *DB, gid uint64) []*dql.Mutation { return []*dql.Mutation{{ Del: []*api.NQuad{ { - Namespace: n.ID(), + Namespace: d.ID(), Subject: fmt.Sprint(gid), Predicate: x.Star, ObjectValue: &api.Value{ diff --git a/api_mutation_helpers.go b/api_mutation_helpers.go index bcb1bf6..4a1f30b 100644 --- a/api_mutation_helpers.go +++ b/api_mutation_helpers.go @@ -14,10 +14,10 @@ import ( "github.com/hypermodeinc/modusdb/api/structreflect" ) -func processStructValue(ctx context.Context, value any, n *DB) (any, error) { +func processStructValue(ctx context.Context, value any, db *DB) (any, error) { if reflect.TypeOf(value).Kind() == reflect.Struct { value = reflect.ValueOf(value).Interface() - newGid, err := getUidOrMutate(ctx, n.driver, n, value) + newGid, err := getUidOrMutate(ctx, db.driver, db, value) if err != nil { return nil, err } @@ -26,19 +26,19 @@ func processStructValue(ctx context.Context, value any, n *DB) (any, error) { return value, nil } -func processPointerValue(ctx context.Context, value any, n *DB) (any, error) { +func processPointerValue(ctx context.Context, value any, db *DB) (any, error) { reflectValueType := reflect.TypeOf(value) if reflectValueType.Kind() == reflect.Pointer { reflectValueType = reflectValueType.Elem() if reflectValueType.Kind() == reflect.Struct { value = reflect.ValueOf(value).Elem().Interface() - return processStructValue(ctx, value, n) + return processStructValue(ctx, value, db) } } return value, nil } -func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (uint64, error) { +func getUidOrMutate[T any](ctx context.Context, driver *Driver, db *DB, object T) (uint64, error) { gid, cfKeyValue, err := structreflect.GetUniqueConstraint[T](object) if err != nil { return 0, err @@ -50,17 +50,17 @@ func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (ui dms := make([]*dql.Mutation, 0) sch := &schema.ParsedSchema{} - err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema(ctx, db, object, gid, &dms, sch) if err != nil { return 0, err } - err = n.alterSchemaWithParsed(ctx, sch) + err = driver.alterSchemaWithParsed(ctx, sch) if err != nil { return 0, err } if gid != 0 || cf != nil { - gid, err = getExistingObject(ctx, n, gid, cf, object) + gid, err = getExistingObject(ctx, db, gid, cf, object) if err != nil && err != apiutils.ErrNoObjFound { return 0, err } @@ -69,18 +69,18 @@ func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (ui } } - gid, err = db.z.nextUID() + gid, err = driver.z.nextUID() if err != nil { return 0, err } dms = make([]*dql.Mutation, 0) - err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch) + err = generateSetDqlMutationsAndSchema(ctx, db, object, gid, &dms, sch) if err != nil { return 0, err } - err = applyDqlMutations(ctx, db, dms) + err = applyDqlMutations(ctx, driver, dms) if err != nil { return 0, err } @@ -88,21 +88,21 @@ func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (ui return gid, nil } -func applyDqlMutations(ctx context.Context, db *Driver, dms []*dql.Mutation) error { +func applyDqlMutations(ctx context.Context, driver *Driver, dms []*dql.Mutation) error { edges, err := query.ToDirectedEdges(dms, nil) if err != nil { return err } - if !db.isOpen.Load() { + if !driver.isOpen.Load() { return ErrClosedDriver } - startTs, err := db.z.nextTs() + startTs, err := driver.z.nextTs() if err != nil { return err } - commitTs, err := db.z.nextTs() + commitTs, err := driver.z.nextTs() if err != nil { return err } diff --git a/api_query_execution.go b/api_query_execution.go index 2cd8ffb..421b648 100644 --- a/api_query_execution.go +++ b/api_query_execution.go @@ -20,34 +20,34 @@ import ( "github.com/hypermodeinc/modusdb/api/structreflect" ) -func getByGid[T any](ctx context.Context, n *DB, gid uint64) (uint64, T, error) { - return executeGet[T](ctx, n, gid) +func getByGid[T any](ctx context.Context, db *DB, gid uint64) (uint64, T, error) { + return executeGet[T](ctx, db, gid) } -func getByGidWithObject[T any](ctx context.Context, n *DB, gid uint64, obj T) (uint64, T, error) { - return executeGetWithObject[T](ctx, n, obj, false, gid) +func getByGidWithObject[T any](ctx context.Context, db *DB, gid uint64, obj T) (uint64, T, error) { + return executeGetWithObject[T](ctx, db, obj, false, gid) } -func getByConstrainedField[T any](ctx context.Context, n *DB, cf ConstrainedField) (uint64, T, error) { - return executeGet[T](ctx, n, cf) +func getByConstrainedField[T any](ctx context.Context, db *DB, cf ConstrainedField) (uint64, T, error) { + return executeGet[T](ctx, db, cf) } -func getByConstrainedFieldWithObject[T any](ctx context.Context, n *DB, +func getByConstrainedFieldWithObject[T any](ctx context.Context, db *DB, cf ConstrainedField, obj T) (uint64, T, error) { - return executeGetWithObject[T](ctx, n, obj, false, cf) + return executeGetWithObject[T](ctx, db, obj, false, cf) } -func executeGet[T any, R UniqueField](ctx context.Context, n *DB, args ...R) (uint64, T, error) { +func executeGet[T any, R UniqueField](ctx context.Context, db *DB, args ...R) (uint64, T, error) { var obj T if len(args) != 1 { - return 0, obj, fmt.Errorf("expected 1 argument, got %d", len(args)) + return 0, obj, fmt.Errorf("expected 1 argument, got %db", len(args)) } - return executeGetWithObject(ctx, n, obj, true, args...) + return executeGetWithObject(ctx, db, obj, true, args...) } -func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *DB, +func executeGetWithObject[T any, R UniqueField](ctx context.Context, db *DB, obj T, withReverse bool, args ...R) (uint64, T, error) { t := reflect.TypeOf(obj) @@ -79,7 +79,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *DB, return 0, obj, fmt.Errorf("constraint not defined for field %s", cf.Key) } - resp, err := n.queryWithLock(ctx, query) + resp, err := db.driver.queryWithLock(ctx, db, query) if err != nil { return 0, obj, err } @@ -107,7 +107,7 @@ func executeGetWithObject[T any, R UniqueField](ctx context.Context, n *DB, return structreflect.ConvertDynamicToTyped[T](result.Obj[0], t) } -func executeQuery[T any](ctx context.Context, n *DB, queryParams QueryParams, +func executeQuery[T any](ctx context.Context, db *DB, queryParams QueryParams, withReverse bool) ([]uint64, []T, error) { var obj T t := reflect.TypeOf(obj) @@ -143,7 +143,7 @@ func executeQuery[T any](ctx context.Context, n *DB, queryParams QueryParams, query := querygen.FormatObjsQuery(t.Name(), filterQueryFunc, paginationAndSorting, readFromQuery) - resp, err := n.queryWithLock(ctx, query) + resp, err := db.driver.queryWithLock(ctx, db, query) if err != nil { return nil, nil, err } @@ -186,13 +186,13 @@ func executeQuery[T any](ctx context.Context, n *DB, queryParams QueryParams, return gids, objs, nil } -func getExistingObject[T any](ctx context.Context, n *DB, gid uint64, cf *ConstrainedField, +func getExistingObject[T any](ctx context.Context, db *DB, gid uint64, cf *ConstrainedField, object T) (uint64, error) { var err error if gid != 0 { - gid, _, err = getByGidWithObject[T](ctx, n, gid, object) + gid, _, err = getByGidWithObject[T](ctx, db, gid, object) } else if cf != nil { - gid, _, err = getByConstrainedFieldWithObject[T](ctx, n, *cf, object) + gid, _, err = getByConstrainedFieldWithObject[T](ctx, db, *cf, object) } if err != nil { return 0, err @@ -200,8 +200,8 @@ func getExistingObject[T any](ctx context.Context, n *DB, gid uint64, cf *Constr return gid, nil } -func getSchema(ctx context.Context, n *DB) (*querygen.SchemaResponse, error) { - resp, err := n.queryWithLock(ctx, querygen.SchemaQuery) +func getSchema(ctx context.Context, db *DB) (*querygen.SchemaResponse, error) { + resp, err := db.driver.queryWithLock(ctx, db, querygen.SchemaQuery) if err != nil { return nil, err } diff --git a/api_types.go b/api_types.go index 010dd79..37aae5d 100644 --- a/api_types.go +++ b/api_types.go @@ -75,32 +75,32 @@ type VectorPredicate struct { type ModusDbOption func(*modusDbOptions) type modusDbOptions struct { - namespace uint64 + db uint64 } -func WithNamespace(namespace uint64) ModusDbOption { +func WithDB(db uint64) ModusDbOption { return func(o *modusDbOptions) { - o.namespace = namespace + o.db = db } } -func getDefaultNamespace(db *Driver, ns ...uint64) (context.Context, *DB, error) { +func getDefaultDB(driver *Driver, dbId ...uint64) (context.Context, *DB, error) { dbOpts := &modusDbOptions{ - namespace: db.db0.ID(), + db: driver.db0.ID(), } - for _, ns := range ns { - WithNamespace(ns)(dbOpts) + for _, db := range dbId { + WithDB(db)(dbOpts) } - n, err := db.getDBWithLock(dbOpts.namespace) + d, err := driver.getDBWithLock(dbOpts.db) if err != nil { return nil, nil, err } ctx := context.Background() - ctx = x.AttachNamespace(ctx, n.ID()) + ctx = x.AttachNamespace(ctx, d.ID()) - return ctx, n, nil + return ctx, d, nil } func filterToQueryFunc(typeName string, f Filter) querygen.QueryFunc { diff --git a/config.go b/config.go index 76d3d1b..0eac913 100644 --- a/config.go +++ b/config.go @@ -20,8 +20,8 @@ func NewDefaultConfig(dir string) Config { return Config{dataDir: dir, limitNormalizeNode: 10000} } -func (cc Config) WithLimitNormalizeNode(n int) Config { - cc.limitNormalizeNode = n +func (cc Config) WithLimitNormalizeNode(d int) Config { + cc.limitNormalizeNode = d return cc } diff --git a/db.go b/db.go index a50b58d..6d18cb6 100644 --- a/db.go +++ b/db.go @@ -11,17 +11,8 @@ package modusdb import ( "context" - "fmt" - "strconv" "github.com/dgraph-io/dgo/v240/protos/api" - "github.com/dgraph-io/dgraph/v24/dql" - "github.com/dgraph-io/dgraph/v24/edgraph" - "github.com/dgraph-io/dgraph/v24/protos/pb" - "github.com/dgraph-io/dgraph/v24/query" - "github.com/dgraph-io/dgraph/v24/schema" - "github.com/dgraph-io/dgraph/v24/worker" - "github.com/dgraph-io/dgraph/v24/x" ) // DB is one of the namespaces in modusDB. @@ -36,164 +27,18 @@ func (db *DB) ID() uint64 { // DropData drops all the data in the modusDB instance. func (db *DB) DropData(ctx context.Context) error { - db.driver.mutex.Lock() - defer db.driver.mutex.Unlock() - - if !db.driver.isOpen.Load() { - return ErrClosedDriver - } - - p := &pb.Proposal{Mutations: &pb.Mutations{ - GroupId: 1, - DropOp: pb.Mutations_DATA, - DropValue: strconv.FormatUint(db.ID(), 10), - }} - - if err := worker.ApplyMutations(ctx, p); err != nil { - return fmt.Errorf("error applying mutation: %w", err) - } - - // TODO: insert drop record - // TODO: should we reset back the timestamp as well? - return nil + return db.driver.dropData(ctx, db) } func (db *DB) AlterSchema(ctx context.Context, sch string) error { - db.driver.mutex.Lock() - defer db.driver.mutex.Unlock() - - if !db.driver.isOpen.Load() { - return ErrClosedDriver - } - - sc, err := schema.ParseWithNamespace(sch, db.ID()) - if err != nil { - return fmt.Errorf("error parsing schema: %w", err) - } - return db.alterSchemaWithParsed(ctx, sc) -} - -func (db *DB) alterSchemaWithParsed(ctx context.Context, sc *schema.ParsedSchema) error { - for _, pred := range sc.Preds { - worker.InitTablet(pred.Predicate) - } - - startTs, err := db.driver.z.nextTs() - if err != nil { - return err - } - - p := &pb.Proposal{Mutations: &pb.Mutations{ - GroupId: 1, - StartTs: startTs, - Schema: sc.Preds, - Types: sc.Types, - }} - if err := worker.ApplyMutations(ctx, p); err != nil { - return fmt.Errorf("error applying mutation: %w", err) - } - return nil + return db.driver.alterSchema(ctx, db, sch) } func (db *DB) Mutate(ctx context.Context, ms []*api.Mutation) (map[string]uint64, error) { - if len(ms) == 0 { - return nil, nil - } - - db.driver.mutex.Lock() - defer db.driver.mutex.Unlock() - dms := make([]*dql.Mutation, 0, len(ms)) - for _, mu := range ms { - dm, err := edgraph.ParseMutationObject(mu, false) - if err != nil { - return nil, fmt.Errorf("error parsing mutation: %w", err) - } - dms = append(dms, dm) - } - newUids, err := query.ExtractBlankUIDs(ctx, dms) - if err != nil { - return nil, err - } - if len(newUids) > 0 { - num := &pb.Num{Val: uint64(len(newUids)), Type: pb.Num_UID} - res, err := db.driver.z.nextUIDs(num) - if err != nil { - return nil, err - } - - curId := res.StartId - for k := range newUids { - x.AssertTruef(curId != 0 && curId <= res.EndId, "not enough uids generated") - newUids[k] = curId - curId++ - } - } - - return db.mutateWithDqlMutation(ctx, dms, newUids) -} - -func (db *DB) mutateWithDqlMutation(ctx context.Context, dms []*dql.Mutation, - newUids map[string]uint64) (map[string]uint64, error) { - edges, err := query.ToDirectedEdges(dms, newUids) - if err != nil { - return nil, err - } - ctx = x.AttachNamespace(ctx, db.ID()) - - if !db.driver.isOpen.Load() { - return nil, ErrClosedDriver - } - - startTs, err := db.driver.z.nextTs() - if err != nil { - return nil, err - } - commitTs, err := db.driver.z.nextTs() - if err != nil { - return nil, err - } - - m := &pb.Mutations{ - GroupId: 1, - StartTs: startTs, - Edges: edges, - } - m.Edges, err = query.ExpandEdges(ctx, m) - if err != nil { - return nil, fmt.Errorf("error expanding edges: %w", err) - } - - for _, edge := range m.Edges { - worker.InitTablet(edge.Attr) - } - - p := &pb.Proposal{Mutations: m, StartTs: startTs} - if err := worker.ApplyMutations(ctx, p); err != nil { - return nil, err - } - - return newUids, worker.ApplyCommited(ctx, &pb.OracleDelta{ - Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, - }) + return db.driver.mutate(ctx, db, ms) } // Query performs query or mutation or upsert on the given modusDB instance. func (db *DB) Query(ctx context.Context, query string) (*api.Response, error) { - db.driver.mutex.RLock() - defer db.driver.mutex.RUnlock() - - return db.queryWithLock(ctx, query) -} - -func (db *DB) queryWithLock(ctx context.Context, query string) (*api.Response, error) { - if !db.driver.isOpen.Load() { - return nil, ErrClosedDriver - } - - ctx = x.AttachNamespace(ctx, db.ID()) - return (&edgraph.Server{}).QueryNoAuth(ctx, &api.Request{ - ReadOnly: true, - Query: query, - StartTs: db.driver.z.readTs(), - }) + return db.driver.query(ctx, db, query) } diff --git a/driver.go b/driver.go index 6a29a01..a1e437f 100644 --- a/driver.go +++ b/driver.go @@ -14,14 +14,17 @@ import ( "errors" "fmt" "path" + "strconv" "sync" "sync/atomic" "github.com/dgraph-io/badger/v4" "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/dgraph-io/dgraph/v24/dql" "github.com/dgraph-io/dgraph/v24/edgraph" "github.com/dgraph-io/dgraph/v24/posting" "github.com/dgraph-io/dgraph/v24/protos/pb" + "github.com/dgraph-io/dgraph/v24/query" "github.com/dgraph-io/dgraph/v24/schema" "github.com/dgraph-io/dgraph/v24/worker" "github.com/dgraph-io/dgraph/v24/x" @@ -139,6 +142,10 @@ func (driver *Driver) getDBWithLock(dbID uint64) (*DB, error) { return &DB{id: dbID, driver: driver}, nil } +func (driver *Driver) GetDefaultDB() *DB { + return driver.db0 +} + // DropAll drops all the data and schema in the modusDB instance. func (driver *Driver) DropAll(ctx context.Context) error { driver.mutex.Lock() @@ -163,20 +170,167 @@ func (driver *Driver) DropAll(ctx context.Context) error { return nil } -func (driver *Driver) DropData(ctx context.Context) error { - return driver.db0.DropData(ctx) +func (driver *Driver) dropData(ctx context.Context, db *DB) error { + driver.mutex.Lock() + defer driver.mutex.Unlock() + + if !driver.isOpen.Load() { + return ErrClosedDriver + } + + p := &pb.Proposal{Mutations: &pb.Mutations{ + GroupId: 1, + DropOp: pb.Mutations_DATA, + DropValue: strconv.FormatUint(db.ID(), 10), + }} + + if err := worker.ApplyMutations(ctx, p); err != nil { + return fmt.Errorf("error applying mutation: %w", err) + } + + // TODO: insert drop record + // TODO: should we reset back the timestamp as well? + return nil +} + +func (driver *Driver) alterSchema(ctx context.Context, db *DB, sch string) error { + driver.mutex.Lock() + defer driver.mutex.Unlock() + + if !driver.isOpen.Load() { + return ErrClosedDriver + } + + sc, err := schema.ParseWithNamespace(sch, db.ID()) + if err != nil { + return fmt.Errorf("error parsing schema: %w", err) + } + return driver.alterSchemaWithParsed(ctx, sc) +} + +func (driver *Driver) alterSchemaWithParsed(ctx context.Context, sc *schema.ParsedSchema) error { + for _, pred := range sc.Preds { + worker.InitTablet(pred.Predicate) + } + + startTs, err := driver.z.nextTs() + if err != nil { + return err + } + + p := &pb.Proposal{Mutations: &pb.Mutations{ + GroupId: 1, + StartTs: startTs, + Schema: sc.Preds, + Types: sc.Types, + }} + if err := worker.ApplyMutations(ctx, p); err != nil { + return fmt.Errorf("error applying mutation: %w", err) + } + return nil +} + +func (driver *Driver) query(ctx context.Context, db *DB, q string) (*api.Response, error) { + driver.mutex.RLock() + defer driver.mutex.RUnlock() + + return driver.queryWithLock(ctx, db, q) } -func (driver *Driver) AlterSchema(ctx context.Context, sch string) error { - return driver.db0.AlterSchema(ctx, sch) +func (driver *Driver) queryWithLock(ctx context.Context, db *DB, q string) (*api.Response, error) { + if !driver.isOpen.Load() { + return nil, ErrClosedDriver + } + + ctx = x.AttachNamespace(ctx, db.ID()) + return (&edgraph.Server{}).QueryNoAuth(ctx, &api.Request{ + ReadOnly: true, + Query: q, + StartTs: driver.z.readTs(), + }) } -func (driver *Driver) Query(ctx context.Context, q string) (*api.Response, error) { - return driver.db0.Query(ctx, q) +func (driver *Driver) mutate(ctx context.Context, db *DB, ms []*api.Mutation) (map[string]uint64, error) { + if len(ms) == 0 { + return nil, nil + } + + driver.mutex.Lock() + defer driver.mutex.Unlock() + dms := make([]*dql.Mutation, 0, len(ms)) + for _, mu := range ms { + dm, err := edgraph.ParseMutationObject(mu, false) + if err != nil { + return nil, fmt.Errorf("error parsing mutation: %w", err) + } + dms = append(dms, dm) + } + newUids, err := query.ExtractBlankUIDs(ctx, dms) + if err != nil { + return nil, err + } + if len(newUids) > 0 { + num := &pb.Num{Val: uint64(len(newUids)), Type: pb.Num_UID} + res, err := driver.z.nextUIDs(num) + if err != nil { + return nil, err + } + + curId := res.StartId + for k := range newUids { + x.AssertTruef(curId != 0 && curId <= res.EndId, "not enough uids generated") + newUids[k] = curId + curId++ + } + } + + return driver.mutateWithDqlMutation(ctx, db, dms, newUids) } -func (driver *Driver) Mutate(ctx context.Context, ms []*api.Mutation) (map[string]uint64, error) { - return driver.db0.Mutate(ctx, ms) +func (driver *Driver) mutateWithDqlMutation(ctx context.Context, db *DB, dms []*dql.Mutation, + newUids map[string]uint64) (map[string]uint64, error) { + edges, err := query.ToDirectedEdges(dms, newUids) + if err != nil { + return nil, fmt.Errorf("error converting to directed edges: %w", err) + } + ctx = x.AttachNamespace(ctx, db.ID()) + + if !driver.isOpen.Load() { + return nil, ErrClosedDriver + } + + startTs, err := driver.z.nextTs() + if err != nil { + return nil, err + } + commitTs, err := driver.z.nextTs() + if err != nil { + return nil, err + } + + m := &pb.Mutations{ + GroupId: 1, + StartTs: startTs, + Edges: edges, + } + + m.Edges, err = query.ExpandEdges(ctx, m) + if err != nil { + return nil, fmt.Errorf("error expanding edges: %w", err) + } + + for _, edge := range m.Edges { + worker.InitTablet(edge.Attr) + } + + p := &pb.Proposal{Mutations: m, StartTs: startTs} + if err := worker.ApplyMutations(ctx, p); err != nil { + return nil, err + } + + return newUids, worker.ApplyCommited(ctx, &pb.OracleDelta{ + Txns: []*pb.TxnStatus{{StartTs: startTs, CommitTs: commitTs}}, + }) } func (driver *Driver) Load(ctx context.Context, schemaPath, dataPath string) error { diff --git a/driver_test.go b/driver_test.go index 24d1a82..26dcfaa 100644 --- a/driver_test.go +++ b/driver_test.go @@ -30,9 +30,9 @@ func TestRestart(t *testing.T) { defer func() { driver.Close() }() require.NoError(t, driver.DropAll(context.Background())) - require.NoError(t, driver.AlterSchema(context.Background(), "name: string @index(term) .")) + require.NoError(t, driver.GetDefaultDB().AlterSchema(context.Background(), "name: string @index(term) .")) - _, err = driver.Mutate(context.Background(), []*api.Mutation{ + _, err = driver.GetDefaultDB().Mutate(context.Background(), []*api.Mutation{ { Set: []*api.NQuad{ { @@ -51,14 +51,14 @@ func TestRestart(t *testing.T) { name } }` - qresp, err := driver.Query(context.Background(), query) + qresp, err := driver.GetDefaultDB().Query(context.Background(), query) require.NoError(t, err) require.JSONEq(t, `{"me":[{"name":"A"}]}`, string(qresp.GetJson())) driver.Close() driver, err = modusdb.NewDriver(modusdb.NewDefaultConfig(dataDir)) require.NoError(t, err) - qresp, err = driver.Query(context.Background(), query) + qresp, err = driver.GetDefaultDB().Query(context.Background(), query) require.NoError(t, err) require.JSONEq(t, `{"me":[{"name":"A"}]}`, string(qresp.GetJson())) @@ -71,7 +71,7 @@ func TestSchemaQuery(t *testing.T) { defer driver.Close() require.NoError(t, driver.DropAll(context.Background())) - require.NoError(t, driver.AlterSchema(context.Background(), ` + require.NoError(t, driver.GetDefaultDB().AlterSchema(context.Background(), ` name: string @index(exact) . age: int . married: bool . @@ -79,7 +79,7 @@ func TestSchemaQuery(t *testing.T) { dob: datetime . `)) - resp, err := driver.Query(context.Background(), `schema(pred: [name, age]) {type}`) + resp, err := driver.GetDefaultDB().Query(context.Background(), `schema(pred: [name, age]) {type}`) require.NoError(t, err) require.JSONEq(t, @@ -100,10 +100,10 @@ func TestBasicVector(t *testing.T) { defer driver.Close() require.NoError(t, driver.DropAll(context.Background())) - require.NoError(t, driver.AlterSchema(context.Background(), + require.NoError(t, driver.GetDefaultDB().AlterSchema(context.Background(), `project_description_v: float32vector @index(hnsw(exponent: "5", metric: "euclidean")) .`)) - uids, err := driver.Mutate(context.Background(), []*api.Mutation{{ + uids, err := driver.GetDefaultDB().Mutate(context.Background(), []*api.Mutation{{ Set: []*api.NQuad{{ Subject: "_:vector", Predicate: "project_description_v", @@ -119,7 +119,7 @@ func TestBasicVector(t *testing.T) { t.Fatalf("Expected non-zero uid") } - resp, err := driver.Query(context.Background(), fmt.Sprintf(`query { + resp, err := driver.GetDefaultDB().Query(context.Background(), fmt.Sprintf(`query { q (func: uid(%v)) { project_description_v } diff --git a/live.go b/live.go index 1d4f6fb..2bfea22 100644 --- a/live.go +++ b/live.go @@ -34,29 +34,29 @@ const ( ) type liveLoader struct { - n *DB + d *DB blankNodes map[string]string mutex sync.RWMutex } -func (n *DB) Load(ctx context.Context, schemaPath, dataPath string) error { +func (d *DB) Load(ctx context.Context, schemaPath, dataPath string) error { schemaData, err := os.ReadFile(schemaPath) if err != nil { return fmt.Errorf("error reading schema file [%v]: %w", schemaPath, err) } - if err := n.AlterSchema(ctx, string(schemaData)); err != nil { + if err := d.AlterSchema(ctx, string(schemaData)); err != nil { return fmt.Errorf("error altering schema: %w", err) } - if err := n.LoadData(ctx, dataPath); err != nil { + if err := d.LoadData(ctx, dataPath); err != nil { return fmt.Errorf("error loading data: %w", err) } return nil } // TODO: Add support for CSV file -func (n *DB) LoadData(inCtx context.Context, dataDir string) error { +func (d *DB) LoadData(inCtx context.Context, dataDir string) error { fs := filestore.NewFileStore(dataDir) files := fs.FindDataFiles(dataDir, []string{".rdf", ".rdf.gz", ".json", ".json.gz"}) if len(files) == 0 { @@ -94,7 +94,7 @@ func (n *DB) LoadData(inCtx context.Context, dataDir string) error { if !ok { return nil } - uids, err := n.Mutate(rootCtx, []*api.Mutation{nqs}) + uids, err := d.Mutate(rootCtx, []*api.Mutation{nqs}) if err != nil { return fmt.Errorf("error applying mutations: %w", err) } @@ -104,7 +104,7 @@ func (n *DB) LoadData(inCtx context.Context, dataDir string) error { } }) - ll := &liveLoader{n: n, blankNodes: make(map[string]string)} + ll := &liveLoader{d: d, blankNodes: make(map[string]string)} for _, datafile := range files { procG.Go(func() error { return ll.processFile(procCtx, fs, datafile, nqch) @@ -246,7 +246,7 @@ func (l *liveLoader) uid(ns uint64, val string) (string, error) { return uid, nil } - asUID, err := l.n.driver.LeaseUIDs(1) + asUID, err := l.d.driver.LeaseUIDs(1) if err != nil { return "", fmt.Errorf("error allocating UID: %w", err) } diff --git a/live_benchmark_test.go b/live_benchmark_test.go index f504a58..036862d 100644 --- a/live_benchmark_test.go +++ b/live_benchmark_test.go @@ -112,7 +112,7 @@ func BenchmarkDatabaseOperations(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - resp, err := driver.Query(context.Background(), query) + resp, err := driver.GetDefaultDB().Query(context.Background(), query) require.NoError(b, err) require.JSONEq(b, expected, string(resp.Json)) } diff --git a/live_test.go b/live_test.go index 9f99373..62a2731 100644 --- a/live_test.go +++ b/live_test.go @@ -49,16 +49,16 @@ const ( func TestLiveLoaderSmall(t *testing.T) { - db, err := modusdb.NewDriver(modusdb.NewDefaultConfig(t.TempDir())) + driver, err := modusdb.NewDriver(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) - defer db.Close() + defer driver.Close() dataFolder := t.TempDir() schemaFile := filepath.Join(dataFolder, "data.schema") dataFile := filepath.Join(dataFolder, "data.rdf") require.NoError(t, os.WriteFile(schemaFile, []byte(DbSchema), 0600)) require.NoError(t, os.WriteFile(dataFile, []byte(SmallData), 0600)) - require.NoError(t, db.Load(context.Background(), schemaFile, dataFile)) + require.NoError(t, driver.Load(context.Background(), schemaFile, dataFile)) const query = `{ caro(func: allofterms(name@en, "Marc Caro")) { @@ -84,15 +84,15 @@ func TestLiveLoaderSmall(t *testing.T) { ] }` - resp, err := db.Query(context.Background(), query) + resp, err := driver.GetDefaultDB().Query(context.Background(), query) require.NoError(t, err) require.JSONEq(t, expected, string(resp.Json)) } func TestLiveLoader1Million(t *testing.T) { - db, err := modusdb.NewDriver(modusdb.NewDefaultConfig(t.TempDir())) + driver, err := modusdb.NewDriver(modusdb.NewDefaultConfig(t.TempDir())) require.NoError(t, err) - defer db.Close() + defer driver.Close() baseDir := t.TempDir() schResp, err := grab.Get(baseDir, oneMillionSchema) @@ -100,12 +100,12 @@ func TestLiveLoader1Million(t *testing.T) { dataResp, err := grab.Get(baseDir, oneMillionRDF) require.NoError(t, err) - require.NoError(t, db.DropAll(context.Background())) - require.NoError(t, db.Load(context.Background(), schResp.Filename, dataResp.Filename)) + require.NoError(t, driver.DropAll(context.Background())) + require.NoError(t, driver.Load(context.Background(), schResp.Filename, dataResp.Filename)) for _, tt := range common.OneMillionTCs { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - resp, err := db.Query(ctx, tt.Query) + resp, err := driver.GetDefaultDB().Query(ctx, tt.Query) cancel() if ctx.Err() == context.DeadlineExceeded { diff --git a/vector_test.go b/vector_test.go index 041bb1a..f944fa1 100644 --- a/vector_test.go +++ b/vector_test.go @@ -34,7 +34,7 @@ func TestVectorDelete(t *testing.T) { defer driver.Close() require.NoError(t, driver.DropAll(context.Background())) - require.NoError(t, driver.AlterSchema(context.Background(), + require.NoError(t, driver.GetDefaultDB().AlterSchema(context.Background(), fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidean"))) // insert random vectors @@ -42,7 +42,7 @@ func TestVectorDelete(t *testing.T) { require.NoError(t, err) //nolint:gosec rdf, vectors := dgraphapi.GenerateRandomVectors(int(assignIDs.StartId)-10, int(assignIDs.EndId)-10, 10, "vtest") - _, err = driver.Mutate(context.Background(), []*api.Mutation{{SetNquads: []byte(rdf)}}) + _, err = driver.GetDefaultDB().Mutate(context.Background(), []*api.Mutation{{SetNquads: []byte(rdf)}}) require.NoError(t, err) // check the count of the vectors inserted @@ -51,7 +51,7 @@ func TestVectorDelete(t *testing.T) { count(uid) } }` - resp, err := driver.Query(context.Background(), q1) + resp, err := driver.GetDefaultDB().Query(context.Background(), q1) require.NoError(t, err) require.JSONEq(t, fmt.Sprintf(`{"vector":[{"count":%d}]}`, numVectors), string(resp.Json)) @@ -68,7 +68,7 @@ func TestVectorDelete(t *testing.T) { triples := strings.Split(rdf, "\n") deleteTriple := func(idx int) string { - _, err := driver.Mutate(context.Background(), []*api.Mutation{{ + _, err := driver.GetDefaultDB().Mutate(context.Background(), []*api.Mutation{{ DelNquads: []byte(triples[idx]), }}) require.NoError(t, err) @@ -80,7 +80,7 @@ func TestVectorDelete(t *testing.T) { } }`, uid[1:len(uid)-1]) - res, err := driver.Query(context.Background(), q2) + res, err := driver.GetDefaultDB().Query(context.Background(), q2) require.NoError(t, err) require.JSONEq(t, `{"vector":[]}`, string(res.Json)) return triples[idx] @@ -105,8 +105,8 @@ func TestVectorDelete(t *testing.T) { _ = queryVectors(t, driver, fmt.Sprintf(q3, strings.Split(triple, `"`)[1])) } -func queryVectors(t *testing.T, db *modusdb.Driver, query string) [][]float32 { - resp, err := db.Query(context.Background(), query) +func queryVectors(t *testing.T, driver *modusdb.Driver, query string) [][]float32 { + resp, err := driver.GetDefaultDB().Query(context.Background(), query) require.NoError(t, err) var data struct {