diff --git a/cmd/build.go b/cmd/build.go index 26b14bb0..b8001cb0 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -83,6 +83,7 @@ func runBuild(args []string, wd string) { ModRootPath: gocBuild.ModRootPath, OneMainPackage: true, // it is a go build GlobalCoverVarImportPath: gocBuild.GlobalCoverVarImportPath, + CoverModName: "coverPackageMod", } err = cover.Execute(ci) if err != nil { diff --git a/cmd/cover.go b/cmd/cover.go index 133d39d2..c9104edd 100644 --- a/cmd/cover.go +++ b/cmd/cover.go @@ -53,6 +53,7 @@ func runCover(target string) { Center: center, Singleton: singleton, OneMainPackage: false, + CoverModName: "coverPackageMod", } _ = cover.Execute(ci) } diff --git a/pkg/build/gomodules.go b/pkg/build/gomodules.go index 350bc517..d892fcd3 100644 --- a/pkg/build/gomodules.go +++ b/pkg/build/gomodules.go @@ -78,7 +78,8 @@ func (b *Build) updateGoModFile() (updateFlag bool, newModFile []byte, err error // absolute path no need to rewrite if newVersion == "" && !filepath.IsAbs(newPath) { var absPath string - fullPath := filepath.Join(b.ModRoot, newPath) + //替换原路径为目标路径 + fullPath := filepath.Join(b.TmpDir, newPath) absPath, _ = filepath.Abs(fullPath) // DropReplace & AddReplace will not return error // so no need to check the error diff --git a/pkg/cover/cover.go b/pkg/cover/cover.go index 07455651..77db18ea 100644 --- a/pkg/cover/cover.go +++ b/pkg/cover/cover.go @@ -141,6 +141,7 @@ type CoverInfo struct { AgentPort string Center string Singleton bool + CoverModName string } // Execute inject cover variables for all the .go files in the target folder @@ -154,6 +155,7 @@ func Execute(coverInfo *CoverInfo) error { center := coverInfo.Center singleton := coverInfo.Singleton globalCoverVarImportPath := coverInfo.GlobalCoverVarImportPath + coverPackageMod := coverInfo.CoverModName if coverInfo.IsMod { globalCoverVarImportPath = filepath.Join(coverInfo.ModRootPath, globalCoverVarImportPath) @@ -170,7 +172,7 @@ func Execute(coverInfo *CoverInfo) error { listArgs = append(listArgs, args) } listArgs = append(listArgs, "./...") - pkgs, err := ListPackages(target, strings.Join(listArgs, " "), newGopath) + pkgs, err := ListPackagesInAllModule(target, strings.Join(listArgs, " "), newGopath) if err != nil { log.Errorf("Fail to list all packages, the error: %v", err) return err @@ -183,7 +185,7 @@ func Execute(coverInfo *CoverInfo) error { if pkg.Name == "main" { log.Printf("handle package: %v", pkg.ImportPath) // inject the main package - mainCover, mainDecl := AddCounters(pkg, mode, globalCoverVarImportPath) + mainCover, mainDecl := AddCounters(pkg, mode, coverPackageMod) allDecl += mainDecl // new a testcover for this service tc := TestCover{ @@ -206,7 +208,7 @@ func Execute(coverInfo *CoverInfo) error { //only focus package neither standard Go library nor dependency library if depPkg, ok := pkgs[dep]; ok { - packageCover, depDecl := AddCounters(depPkg, mode, globalCoverVarImportPath) + packageCover, depDecl := AddCounters(depPkg, mode, coverPackageMod) allDecl += depDecl tc.DepsCover = append(tc.DepsCover, packageCover) seen[dep] = packageCover @@ -267,6 +269,63 @@ func ListPackages(dir string, args string, newgopath string) (map[string]*Packag return pkgs, nil } +// ListPackagesInAllModule 递归地列出指定目录及其子模块中的所有Go包 +func ListPackagesInAllModule(dir string, args string, newgopath string) (map[string]*Package, error) { + pkgs := make(map[string]*Package) + + // processModule 处理一个Go模块目录,运行go list并收集包信息 + processModule := func(modDir string) error { + cmd := exec.Command("/bin/bash", "-c", "go list "+args) + cmd.Dir = modDir + if newgopath != "" { + cmd.Env = append(os.Environ(), "GOPATH="+newgopath) + } + + var errbuf bytes.Buffer + cmd.Stderr = &errbuf + + out, err := cmd.Output() + if err != nil { + return fmt.Errorf("%s: %v", errbuf.String(), err) + } + + dec := json.NewDecoder(bytes.NewReader(out)) + for { + var pkg Package + if err := dec.Decode(&pkg); err == io.EOF { + break + } else if err != nil { + return err + } + pkgs[pkg.ImportPath] = &pkg + } + + return nil + } + // 遍历目录查找子模块 + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() && info.Name() == "vendor" { + return filepath.SkipDir // 跳过 vendor 目录 + } + if info.Name() == "go.mod" { + modDir := filepath.Dir(path) + if err := processModule(modDir); err != nil { + return err + } + } + return nil + }) + + if err != nil { + return nil, err + } + + return pkgs, nil +} + // AddCounters is different from official go tool cover // 1. only inject covervar++ into source file // 2. no declarartions for these covervars diff --git a/pkg/cover/instrument.go b/pkg/cover/instrument.go index 22b374a9..9d2df24c 100644 --- a/pkg/cover/instrument.go +++ b/pkg/cover/instrument.go @@ -17,10 +17,15 @@ package cover import ( + "bufio" "fmt" + log "github.com/sirupsen/logrus" + "io/ioutil" "os" + "os/exec" "path" "path/filepath" + "strings" "text/template" ) @@ -61,7 +66,7 @@ import ( "syscall" "testing" - _cover {{.GlobalCoverVarImportPath | printf "%q"}} + _cover "coverPackageMod" ) @@ -511,6 +516,188 @@ func injectGlobalCoverVarFile(ci *CoverInfo, content string) error { return err } _, err = coverFile.WriteString(content) + if err != nil { + log.Errorf("err:", err) + return err + } + + goVersion, err := getModuleGoVersion(filepath.Join(ci.Target, "go.mod")) + if goVersion == "" || err != nil { + goVersion = "1.13" //其他情况默认使用1.13版本 + } + + //将cover.go所在package模块化 + modFilePath := filepath.Join(ci.Target, ci.GlobalCoverVarImportPath, "go.mod") + modContent := fmt.Sprintf("module %s\n\ngo %s\n", extractSuffix(ci.GlobalCoverVarImportPath), goVersion) + modFile, err := os.Create(modFilePath) + if err != nil { + log.Errorf("create modFile err:", err) + return err + } + defer modFile.Close() + + _, err = modFile.WriteString(modContent) + if err != nil { + log.Errorf("modFile write err:", err) + return err + } + + // 更新所有模块的 go.mod 文件以包含对新模块的引用 + err = UpdateAllModuleDependencies(ci.Target, ci.GlobalCoverVarImportPath, ci.CoverModName) + if err != nil { + return err + } + startDir := ci.Target // 从当前目录开始搜索,你可以修改为任何起始路径 + err = filepath.Walk(startDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Skip if it's not a directory + if !info.IsDir() { + return nil + } + + return nil + }) + + if err != nil { + log.Errorf("Error walking the path %q: %v\n", startDir, err) + } return err } + +func extractSuffix(input string) string { + // Find the last index of "/" + lastIndex := strings.LastIndex(input, "/") + if lastIndex == -1 { + return input // 如果没有找到"/",则返回整个字符串 + } + // 返回最后一个"/"之后的子串 + return input[lastIndex+1:] +} + +// UpdateAllModuleDependencies 递归修改所有模块替换依赖 +func UpdateAllModuleDependencies(moduleRootPath, modFileDir, coverPackageMod string) error { + // 获取 cover.go 文件所在的模块名称 + coverModFilePath := filepath.Join(moduleRootPath, modFileDir, "go.mod") + coverModuleName, err := getModuleName(coverModFilePath) + if err != nil { + return err + } + coverLocalPath := filepath.Join(moduleRootPath, modFileDir) + // 遍历根目录下的所有子目录 + return filepath.WalkDir(moduleRootPath, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + + // 忽略非目录项 + if !d.IsDir() { + return nil + } + + // 检查目录下是否有 go.mod 文件 + thisModFilePath := filepath.Join(path, "go.mod") + if _, err := os.Stat(thisModFilePath); err == nil { + // 获取当前模块名称 + currentModuleName, err := getModuleName(thisModFilePath) + if err != nil { + log.Errorf("err:", err) + return err + } + + // 如果当前模块是 cover.go 所在的模块,则跳过 + if currentModuleName == coverModuleName { + return nil + } + + // 添加依赖 + return addLocalDependencyToModFile(thisModFilePath, coverPackageMod, coverLocalPath, coverModuleName) + } + return nil + }) +} + +// addLocalDependencyToModFile 替换依赖 +func addLocalDependencyToModFile(modFilePath, importPath, localPath, rootModuleName string) error { + // 获取当前模块名称 + currentModuleName, err := getModuleName(modFilePath) + if err != nil { + return err + } + + // 如果当前模块是根模块,则跳过 + if currentModuleName == rootModuleName { + return nil + } + + // 添加 replace 指令 + replaceCmd := exec.Command("go", "mod", "edit", fmt.Sprintf("-replace=%s=%s", importPath, localPath), modFilePath) + replaceOutput, err := replaceCmd.CombinedOutput() + if err != nil { + log.Errorf("go mod edit -replace output:\n%s\n", string(replaceOutput)) + return fmt.Errorf("failed to add replace for %s: %w", importPath, err) + } + + // 添加 require 指令 + requireCmd := exec.Command("go", "mod", "edit", fmt.Sprintf("-require=%s@v0.0.0-00010101000000-000000000000", importPath), modFilePath) + requireOutput, err := requireCmd.CombinedOutput() + if err != nil { + log.Infof("go mod edit -require output:\n%s\n", string(requireOutput)) + return fmt.Errorf("failed to add require for %s: %w", importPath, err) + } + + return nil +} + +// getModuleName 获取module名 +func getModuleName(modFilePath string) (string, error) { + // 读取 go.mod 文件内容 + content, err := ioutil.ReadFile(modFilePath) + if err != nil { + return "", err + } + + // 查找 module 语句 + lines := strings.Split(string(content), "\n") + for _, line := range lines { + if strings.HasPrefix(line, "module ") { + // 获取模块名称 + return strings.TrimSpace(strings.TrimPrefix(line, "module ")), nil + } + } + + return "", fmt.Errorf("module directive not found in %s", modFilePath) +} + +// getModuleGoVersion 获取mudule的版本号 +func getModuleGoVersion(modFilePath string) (string, error) { + + // 打开文件 + file, err := os.Open(modFilePath) + if err != nil { + log.Fatalf("Error opening go.mod file: %v", err) + } + defer file.Close() + + // 创建文件的 bufio.Reader + scanner := bufio.NewScanner(file) + + // 读取文件行并查找 Go 版本 + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "go ") { + // 找到 Go 版本行 + version := strings.TrimSpace(strings.TrimPrefix(line, "go")) + return version, nil + } + } + + if err := scanner.Err(); err != nil { + log.Fatalf("Error reading go.mod file: %v", err) + return "", err + } + return "", nil +}