Skip to content

Commit

Permalink
refactor around drivers and dbs, drivers do the core db work, db is a…
Browse files Browse the repository at this point in the history
… logical separation
  • Loading branch information
jairad26 committed Jan 10, 2025
1 parent 392d287 commit a9842e5
Show file tree
Hide file tree
Showing 13 changed files with 287 additions and 288 deletions.
60 changes: 30 additions & 30 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ import (
"github.com/hypermodeinc/modusdb/api/structreflect"
)

func Create[T any](driver *Driver, object T, ns ...uint64) (uint64, T, error) {
func Create[T any](driver *Driver, object T, dbId ...uint64) (uint64, T, error) {
driver.mutex.Lock()
defer driver.mutex.Unlock()
if len(ns) > 1 {
if len(dbId) > 1 {
return 0, object, fmt.Errorf("only one namespace is allowed")
}
ctx, n, err := getDefaultNamespace(driver, ns...)
ctx, db, err := getDefaultDB(driver, dbId...)
if err != nil {
return 0, object, err
}
Expand All @@ -36,12 +36,12 @@ func Create[T any](driver *Driver, object T, ns ...uint64) (uint64, T, error) {

dms := make([]*dql.Mutation, 0)
sch := &schema.ParsedSchema{}
err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch)
err = generateSetDqlMutationsAndSchema[T](ctx, db, object, gid, &dms, sch)
if err != nil {
return 0, object, err
}

err = n.alterSchemaWithParsed(ctx, sch)
err = driver.alterSchemaWithParsed(ctx, sch)
if err != nil {
return 0, object, err
}
Expand All @@ -51,19 +51,19 @@ func Create[T any](driver *Driver, object T, ns ...uint64) (uint64, T, error) {
return 0, object, err
}

return getByGid[T](ctx, n, gid)
return getByGid[T](ctx, db, gid)
}

func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, error) {
func Upsert[T any](driver *Driver, object T, dbId ...uint64) (uint64, T, bool, error) {

var wasFound bool
driver.mutex.Lock()
defer driver.mutex.Unlock()
if len(ns) > 1 {
if len(dbId) > 1 {
return 0, object, false, fmt.Errorf("only one namespace is allowed")
}

ctx, n, err := getDefaultNamespace(driver, ns...)
ctx, db, err := getDefaultDB(driver, dbId...)
if err != nil {
return 0, object, false, err
}
Expand All @@ -82,18 +82,18 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err

dms := make([]*dql.Mutation, 0)
sch := &schema.ParsedSchema{}
err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch)
err = generateSetDqlMutationsAndSchema[T](ctx, db, object, gid, &dms, sch)
if err != nil {
return 0, object, false, err
}

err = n.alterSchemaWithParsed(ctx, sch)
err = db.driver.alterSchemaWithParsed(ctx, sch)
if err != nil {
return 0, object, false, err
}

if gid != 0 || cf != nil {
gid, err = getExistingObject[T](ctx, n, gid, cf, object)
gid, err = getExistingObject[T](ctx, db, gid, cf, object)
if err != nil && err != apiutils.ErrNoObjFound {
return 0, object, false, err
}
Expand All @@ -108,7 +108,7 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err
}

dms = make([]*dql.Mutation, 0)
err = generateSetDqlMutationsAndSchema[T](ctx, n, object, gid, &dms, sch)
err = generateSetDqlMutationsAndSchema[T](ctx, db, object, gid, &dms, sch)
if err != nil {
return 0, object, false, err
}
Expand All @@ -118,68 +118,68 @@ func Upsert[T any](driver *Driver, object T, ns ...uint64) (uint64, T, bool, err
return 0, object, false, err
}

gid, object, err = getByGid[T](ctx, n, gid)
gid, object, err = getByGid[T](ctx, db, gid)
if err != nil {
return 0, object, false, err
}

return gid, object, wasFound, nil
}

func Get[T any, R UniqueField](driver *Driver, uniqueField R, ns ...uint64) (uint64, T, error) {
func Get[T any, R UniqueField](driver *Driver, uniqueField R, dbId ...uint64) (uint64, T, error) {
driver.mutex.Lock()
defer driver.mutex.Unlock()
var obj T
if len(ns) > 1 {
if len(dbId) > 1 {
return 0, obj, fmt.Errorf("only one namespace is allowed")
}
ctx, n, err := getDefaultNamespace(driver, ns...)
ctx, db, err := getDefaultDB(driver, dbId...)
if err != nil {
return 0, obj, err
}
if uid, ok := any(uniqueField).(uint64); ok {
return getByGid[T](ctx, n, uid)
return getByGid[T](ctx, db, uid)
}

if cf, ok := any(uniqueField).(ConstrainedField); ok {
return getByConstrainedField[T](ctx, n, cf)
return getByConstrainedField[T](ctx, db, cf)
}

return 0, obj, fmt.Errorf("invalid unique field type")
}

func Query[T any](driver *Driver, queryParams QueryParams, ns ...uint64) ([]uint64, []T, error) {
func Query[T any](driver *Driver, queryParams QueryParams, dbId ...uint64) ([]uint64, []T, error) {
driver.mutex.Lock()
defer driver.mutex.Unlock()
if len(ns) > 1 {
if len(dbId) > 1 {
return nil, nil, fmt.Errorf("only one namespace is allowed")
}
ctx, n, err := getDefaultNamespace(driver, ns...)
ctx, db, err := getDefaultDB(driver, dbId...)
if err != nil {
return nil, nil, err
}

return executeQuery[T](ctx, n, queryParams, true)
return executeQuery[T](ctx, db, queryParams, true)
}

func Delete[T any, R UniqueField](driver *Driver, uniqueField R, ns ...uint64) (uint64, T, error) {
func Delete[T any, R UniqueField](driver *Driver, uniqueField R, dbId ...uint64) (uint64, T, error) {
driver.mutex.Lock()
defer driver.mutex.Unlock()
var zeroObj T
if len(ns) > 1 {
if len(dbId) > 1 {
return 0, zeroObj, fmt.Errorf("only one namespace is allowed")
}
ctx, n, err := getDefaultNamespace(driver, ns...)
ctx, db, err := getDefaultDB(driver, dbId...)
if err != nil {
return 0, zeroObj, err
}
if uid, ok := any(uniqueField).(uint64); ok {
uid, obj, err := getByGid[T](ctx, n, uid)
uid, obj, err := getByGid[T](ctx, db, uid)
if err != nil {
return 0, zeroObj, err
}

dms := generateDeleteDqlMutations(n, uid)
dms := generateDeleteDqlMutations(db, uid)

err = applyDqlMutations(ctx, driver, dms)
if err != nil {
Expand All @@ -190,12 +190,12 @@ func Delete[T any, R UniqueField](driver *Driver, uniqueField R, ns ...uint64) (
}

if cf, ok := any(uniqueField).(ConstrainedField); ok {
uid, obj, err := getByConstrainedField[T](ctx, n, cf)
uid, obj, err := getByConstrainedField[T](ctx, db, cf)
if err != nil {
return 0, zeroObj, err
}

dms := generateDeleteDqlMutations(n, uid)
dms := generateDeleteDqlMutations(db, uid)

err = applyDqlMutations(ctx, driver, dms)
if err != nil {
Expand Down
20 changes: 10 additions & 10 deletions api_mutation_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/hypermodeinc/modusdb/api/structreflect"
)

func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object T,
func generateSetDqlMutationsAndSchema[T any](ctx context.Context, d *DB, object T,
gid uint64, dms *[]*dql.Mutation, sch *schema.ParsedSchema) error {
t := reflect.TypeOf(object)
if t.Kind() != reflect.Struct {
Expand All @@ -49,7 +49,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object
if tagMaps.JsonToReverseEdge[jsonName] != "" {
reverseEdgeStr := tagMaps.JsonToReverseEdge[jsonName]
typeName := strings.Split(reverseEdgeStr, ".")[0]
currSchema, err := getSchema(ctx, n)
currSchema, err := getSchema(ctx, d)
if err != nil {
return err
}
Expand All @@ -70,7 +70,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object
}

if !(typeFound && predicateFound) {
if err := mutations.HandleReverseEdge(jsonName, reflectValueType, n.ID(), sch,
if err := mutations.HandleReverseEdge(jsonName, reflectValueType, d.ID(), sch,
reverseEdgeStr); err != nil {
return err
}
Expand All @@ -82,17 +82,17 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object
continue
}

value, err = processStructValue(ctx, value, n)
value, err = processStructValue(ctx, value, d)
if err != nil {
return err
}

value, err = processPointerValue(ctx, value, n)
value, err = processPointerValue(ctx, value, d)
if err != nil {
return err
}

nquad, u, err := mutations.CreateNQuadAndSchema(value, gid, jsonName, t, n.ID())
nquad, u, err := mutations.CreateNQuadAndSchema(value, gid, jsonName, t, d.ID())
if err != nil {
return err
}
Expand All @@ -111,7 +111,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object
}

sch.Types = append(sch.Types, &pb.TypeUpdate{
TypeName: apiutils.AddNamespace(n.ID(), t.Name()),
TypeName: apiutils.AddNamespace(d.ID(), t.Name()),
Fields: sch.Preds,
})

Expand All @@ -120,7 +120,7 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object
return err
}
typeNquad := &api.NQuad{
Namespace: n.ID(),
Namespace: d.ID(),
Subject: fmt.Sprint(gid),
Predicate: "dgraph.type",
ObjectValue: val,
Expand All @@ -134,11 +134,11 @@ func generateSetDqlMutationsAndSchema[T any](ctx context.Context, n *DB, object
return nil
}

func generateDeleteDqlMutations(n *DB, gid uint64) []*dql.Mutation {
func generateDeleteDqlMutations(d *DB, gid uint64) []*dql.Mutation {
return []*dql.Mutation{{
Del: []*api.NQuad{
{
Namespace: n.ID(),
Namespace: d.ID(),
Subject: fmt.Sprint(gid),
Predicate: x.Star,
ObjectValue: &api.Value{
Expand Down
30 changes: 15 additions & 15 deletions api_mutation_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
"github.com/hypermodeinc/modusdb/api/structreflect"
)

func processStructValue(ctx context.Context, value any, n *DB) (any, error) {
func processStructValue(ctx context.Context, value any, db *DB) (any, error) {
if reflect.TypeOf(value).Kind() == reflect.Struct {
value = reflect.ValueOf(value).Interface()
newGid, err := getUidOrMutate(ctx, n.driver, n, value)
newGid, err := getUidOrMutate(ctx, db.driver, db, value)
if err != nil {
return nil, err
}
Expand All @@ -26,19 +26,19 @@ func processStructValue(ctx context.Context, value any, n *DB) (any, error) {
return value, nil
}

func processPointerValue(ctx context.Context, value any, n *DB) (any, error) {
func processPointerValue(ctx context.Context, value any, db *DB) (any, error) {
reflectValueType := reflect.TypeOf(value)
if reflectValueType.Kind() == reflect.Pointer {
reflectValueType = reflectValueType.Elem()
if reflectValueType.Kind() == reflect.Struct {
value = reflect.ValueOf(value).Elem().Interface()
return processStructValue(ctx, value, n)
return processStructValue(ctx, value, db)
}
}
return value, nil
}

func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (uint64, error) {
func getUidOrMutate[T any](ctx context.Context, driver *Driver, db *DB, object T) (uint64, error) {
gid, cfKeyValue, err := structreflect.GetUniqueConstraint[T](object)
if err != nil {
return 0, err
Expand All @@ -50,17 +50,17 @@ func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (ui

dms := make([]*dql.Mutation, 0)
sch := &schema.ParsedSchema{}
err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch)
err = generateSetDqlMutationsAndSchema(ctx, db, object, gid, &dms, sch)
if err != nil {
return 0, err
}

err = n.alterSchemaWithParsed(ctx, sch)
err = driver.alterSchemaWithParsed(ctx, sch)
if err != nil {
return 0, err
}
if gid != 0 || cf != nil {
gid, err = getExistingObject(ctx, n, gid, cf, object)
gid, err = getExistingObject(ctx, db, gid, cf, object)
if err != nil && err != apiutils.ErrNoObjFound {
return 0, err
}
Expand All @@ -69,40 +69,40 @@ func getUidOrMutate[T any](ctx context.Context, db *Driver, n *DB, object T) (ui
}
}

gid, err = db.z.nextUID()
gid, err = driver.z.nextUID()
if err != nil {
return 0, err
}

dms = make([]*dql.Mutation, 0)
err = generateSetDqlMutationsAndSchema(ctx, n, object, gid, &dms, sch)
err = generateSetDqlMutationsAndSchema(ctx, db, object, gid, &dms, sch)
if err != nil {
return 0, err
}

err = applyDqlMutations(ctx, db, dms)
err = applyDqlMutations(ctx, driver, dms)
if err != nil {
return 0, err
}

return gid, nil
}

func applyDqlMutations(ctx context.Context, db *Driver, dms []*dql.Mutation) error {
func applyDqlMutations(ctx context.Context, driver *Driver, dms []*dql.Mutation) error {
edges, err := query.ToDirectedEdges(dms, nil)
if err != nil {
return err
}

if !db.isOpen.Load() {
if !driver.isOpen.Load() {
return ErrClosedDriver
}

startTs, err := db.z.nextTs()
startTs, err := driver.z.nextTs()
if err != nil {
return err
}
commitTs, err := db.z.nextTs()
commitTs, err := driver.z.nextTs()
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit a9842e5

Please sign in to comment.