Skip to content

Commit

Permalink
feat: support complex types imported from src <=> dest
Browse files Browse the repository at this point in the history
  • Loading branch information
khatibomar committed May 11, 2024
1 parent 8ba48eb commit 2d1b4e0
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ func (dest CarDTO) FromCar(src car.Car) CarDTO {
func (dest UserDTO) FromUser(src user.User) UserDTO {
dest.Name = src.Name
dest.Age = src.Age
dest.MetaData = src.MetaData
return dest
}
```
Expand Down
58 changes: 56 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ func (g *Generator) generate(source SourceData, destination DestinationData) err
}
if ok {
// NOTE(khatibomar): I should support convertion between convertible types
if sourceField.Type == destinationField.Type {
isImportedFromSource := strings.TrimPrefix(sourceField.Type, fmt.Sprintf("%s.", getPackage(destination.node))) == destinationField.Type &&
in(joinLinuxPath(g.module, g.pathFromModule, filepath.Dir(destination.path)), getImports(source.node))
isImportedFromDestination := strings.TrimPrefix(destinationField.Type, fmt.Sprintf("%s.", getPackage(source.node))) == sourceField.Type &&
in(joinLinuxPath(g.module, g.pathFromModule, filepath.Dir(source.path)), getImports(destination.node))
if sourceField.Type == destinationField.Type || isImportedFromSource || isImportedFromDestination {
g.Printf("dest.%s = src.%s\n", destinationField.Name, sourceField.Name)
}
}
Expand All @@ -341,6 +345,24 @@ func getPackage(node *ast.File) string {
return node.Name.String()
}

func getImports(node *ast.File) []string {
var res []string
for _, i := range node.Imports {
res = append(res, i.Path.Value[1:len(i.Path.Value)-1])
}

return res
}

func in(target string, list []string) bool {
for _, item := range list {
if item == target {
return true
}
}
return false
}

func getFields(node *ast.File, typeName string) ([]Field, error) {
var fields []Field
var found bool
Expand All @@ -355,7 +377,7 @@ func getFields(node *ast.File, typeName string) ([]Field, error) {
for _, n := range f.Names {
fields = append(fields, Field{
Name: n.Name,
Type: fmt.Sprintf("%s", f.Type),
Type: extractTypeFromExpression(f.Type),
})
}
}
Expand All @@ -373,6 +395,38 @@ func getFields(node *ast.File, typeName string) ([]Field, error) {
return fields, err
}

func extractTypeFromExpression(expr ast.Expr) string {
switch expr := expr.(type) {
case *ast.Ident:
return expr.Name
case *ast.StarExpr:
return "*" + extractTypeFromExpression(expr.X)
case *ast.ArrayType:
return "[]" + extractTypeFromExpression(expr.Elt)
case *ast.MapType:
return "map[" + extractTypeFromExpression(expr.Key) + "]" + extractTypeFromExpression(expr.Value)
case *ast.StructType:
return "struct{}"
case *ast.InterfaceType:
return "interface{}"
case *ast.ChanType:
var dir string
switch expr.Dir {
case ast.SEND:
dir = "chan<- "
case ast.RECV:
dir = "<-chan "
default:
dir = "chan "
}
return dir + extractTypeFromExpression(expr.Value)
case *ast.SelectorExpr:
return extractTypeFromExpression(expr.X) + "." + expr.Sel.Name
default:
return "unknown"
}
}

type Field struct {
Name string
Type string
Expand Down
227 changes: 210 additions & 17 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,39 @@ import (
"github.com/stretchr/testify/assert"
)

func TestGetImports(t *testing.T) {
run := func(src string, expected []string) {
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "main.go", src, 0)
assert.Equal(t, nil, err)
assert.Equal(t, expected, getImports(node))
}

run(`package p
import "tt"
import "/"`, []string{"tt", "/"})
}

func TestGetFields(t *testing.T) {
src := `
run := func(src, typeName, expectedPkg string, expectedFields []Field) {
fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "main.go", src, 0)
assert.Equal(t, nil, err)

fields, err := getFields(node, typeName)
assert.NoError(t, err)
pkgName := getPackage(node)

assert.Equal(t, expectedPkg, pkgName)

for i, f := range fields {
assert.Equal(t, expectedFields[i].Name, f.Name)
assert.Equal(t, expectedFields[i].Type, f.Type)
}
}

t.Run("simple struct", func(t *testing.T) {
src := `
package p
type P struct {
Expand All @@ -24,23 +55,15 @@ func TestGetFields(t *testing.T) {
a int
}
`
run(src, "P", "p", []Field{
{"a", "int"},
{"B", "string"},
})

fset := token.NewFileSet()
node, err := parser.ParseFile(fset, "main.go", src, 0)
assert.Equal(t, nil, err)

fields, err := getFields(node, "P")
assert.NoError(t, err)
pkgName := getPackage(node)

assert.Equal(t, "p", pkgName, "")
assert.Len(t, fields, 2)

assert.Equal(t, fields[0].Name, "a")
assert.Equal(t, fields[0].Type, "int")

assert.Equal(t, fields[1].Name, "B")
assert.Equal(t, fields[1].Type, "string")
run(src, "K", "p", []Field{
{"a", "int"},
})
})
}

func TestGenerate(t *testing.T) {
Expand Down Expand Up @@ -301,6 +324,176 @@ func TestGenerate(t *testing.T) {
},
})
})

t.Run("complex field in dest refereing src", func(t *testing.T) {
srcCode := `
package p
type P struct {
a int
B string
Meta MetaData
}
type MetaData struct{}
`

dstCode := `
package l
import "/bli"
type L struct {
a int
B string
Meta p.MetaData
}
`

expectedOutput := `func (dest L) FromP(src p.P) L {
dest.B = src.B
dest.Meta = src.Meta
return dest
}`

runTest(t, testInput{
srcCode: srcCode,
destCode: dstCode,
sourcePath: "/bli/p.go",
destPath: "/bla/l.go",
srcName: "P",
destName: "L",
expectedOutput: expectedOutput,
expectedGenerateError: nil,
})
})

t.Run("complex field in dest but not refereing src", func(t *testing.T) {
srcCode := `
package p
type P struct {
a int
B string
Meta MetaData
}
type MetaData struct{}
`

dstCode := `
package l
import p "/x"
type L struct {
a int
B string
Meta p.MetaData
}
`

expectedOutput := `func (dest L) FromP(src p.P) L {
dest.B = src.B
return dest
}`

runTest(t, testInput{
srcCode: srcCode,
destCode: dstCode,
sourcePath: "/bli/p.go",
destPath: "/bla/l.go",
srcName: "P",
destName: "L",
expectedOutput: expectedOutput,
expectedGenerateError: nil,
})
})

t.Run("complex field in src refereing dest", func(t *testing.T) {
srcCode := `
package p
import "/bla"
type P struct {
a int
B string
Meta l.MetaData
}
`

dstCode := `
package l
type L struct {
a int
B string
Meta MetaData
}
type MetaData struct{}
`

expectedOutput := `func (dest L) FromP(src p.P) L {
dest.B = src.B
dest.Meta = src.Meta
return dest
}`

runTest(t, testInput{
srcCode: srcCode,
destCode: dstCode,
sourcePath: "/bli/file.go",
destPath: "/bla/file.go",
srcName: "P",
destName: "L",
expectedOutput: expectedOutput,
expectedGenerateError: nil,
})
})

t.Run("complex field in src but not refereing dest", func(t *testing.T) {
srcCode := `
package p
import "/blo"
type P struct {
a int
B string
Meta l.MetaData
}
`

dstCode := `
package l
type L struct {
a int
B string
Meta MetaData
}
type MetaData struct{}
`

expectedOutput := `func (dest L) FromP(src p.P) L {
dest.B = src.B
return dest
}`

runTest(t, testInput{
srcCode: srcCode,
destCode: dstCode,
sourcePath: "/bli/file.go",
destPath: "/bla/file.go",
srcName: "P",
destName: "L",
expectedOutput: expectedOutput,
expectedGenerateError: nil,
})
})
}

func TestGroupMappings(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions testdata/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package user
import "time"

type User struct {
Name string
Age int
MetadData Metadata
Name string
Age int
MetaData Metadata
}

type Metadata struct {
Expand Down

0 comments on commit 2d1b4e0

Please sign in to comment.