Skip to content

Commit

Permalink
feat(jzero): improve ivm init logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jaronnie committed Jul 1, 2024
1 parent d260382 commit 6b1c61a
Showing 1 changed file with 52 additions and 26 deletions.
78 changes: 52 additions & 26 deletions internal/ivm/ivminit/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,37 +92,26 @@ func (ivm *IvmInit) astInspect(f *ast.File, oldService, newService, logicMethodN
}

// 添加 import
f.Decls = append([]ast.Decl{&ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
&ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: `"google.golang.org/protobuf/proto"`,
},
},
&ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf(`"%s/internal/logic/%s"`, ivm.jzeroRpc.Module, strings.ToLower(oldService)),
},
},
&ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: fmt.Sprintf(`"%s/internal/pb/%spb"`, ivm.jzeroRpc.Module, strings.ToLower(oldService)),
},
},
},
}}, f.Decls...)
// Track added imports to avoid duplicates
addedImports := make(map[string]bool)

if !hasImport(f, `"google.golang.org/protobuf/proto"`) {
addImport(f, `"google.golang.org/protobuf/proto"`, addedImports)
}

if !hasImport(f, fmt.Sprintf(`"%s/internal/logic/%s"`, ivm.jzeroRpc.Module, strings.ToLower(oldService))) {
addImport(f, fmt.Sprintf(`"%s/internal/logic/%s"`, ivm.jzeroRpc.Module, strings.ToLower(oldService)), addedImports)
}

if !hasImport(f, fmt.Sprintf(`"%s/internal/pb/%spb"`, ivm.jzeroRpc.Module, strings.ToLower(oldService))) {
addImport(f, fmt.Sprintf(`"%s/internal/pb/%spb"`, ivm.jzeroRpc.Module, strings.ToLower(oldService)), addedImports)
}

// 修改函数体逻辑
ast.Inspect(f, func(n ast.Node) bool {
if fn, ok := n.(*ast.FuncDecl); ok && fn.Recv != nil && fn.Name.Name == logicMethodName {
// get fn request type and response type name
var requestTypeName, responseTypeName string
if len(fn.Type.Params.List) > 0 {
// 第一个参数是请求类型
requestField := fn.Type.Params.List[0]
if field, ok := requestField.Names[0].Obj.Decl.(*ast.Field); ok {
if startExpr, ok := field.Type.(*ast.StarExpr); ok {
Expand All @@ -134,7 +123,6 @@ func (ivm *IvmInit) astInspect(f *ast.File, oldService, newService, logicMethodN
}
// 获取响应类型名称
if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 {
// 假设第一个返回值是响应类型
responseField := fn.Type.Results.List[0]
if starExpr, ok := responseField.Type.(*ast.StarExpr); ok {
if selectorExpr, ok := starExpr.X.(*ast.SelectorExpr); ok {
Expand Down Expand Up @@ -343,3 +331,41 @@ func (ivm *IvmInit) astInspect(f *ast.File, oldService, newService, logicMethodN
return true
})
}

// hasImport checks if the given import path is already declared in the file.
func hasImport(f *ast.File, path string) bool {
for _, decl := range f.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.IMPORT {
continue
}
for _, spec := range genDecl.Specs {
importSpec, ok := spec.(*ast.ImportSpec)
if !ok {
continue
}
if importSpec.Path.Value == path {
return true
}
}
}
return false
}

// addImport adds the import declaration to the file if it's not already marked as added.
func addImport(f *ast.File, path string, addedImports map[string]bool) {
if !addedImports[path] {
addedImports[path] = true
f.Decls = append([]ast.Decl{&ast.GenDecl{
Tok: token.IMPORT,
Specs: []ast.Spec{
&ast.ImportSpec{
Path: &ast.BasicLit{
Kind: token.STRING,
Value: path,
},
},
},
}}, f.Decls...)
}
}

0 comments on commit 6b1c61a

Please sign in to comment.