diff --git a/cmd/capslock/capslock.go b/cmd/capslock/capslock.go index 90f04e3..7722dac 100644 --- a/cmd/capslock/capslock.go +++ b/cmd/capslock/capslock.go @@ -13,10 +13,12 @@ package main import ( + "bytes" "flag" "fmt" "log" "os" + "os/exec" "runtime" "runtime/pprof" "strings" @@ -41,6 +43,7 @@ var ( memprofile = flag.String("memprofile", "", "write memory profile to specified file") granularity = flag.String("granularity", "", `the granularity to use for comparisons, either "package" or "function".`) + forceLocalModule = flag.Bool("force_local_module", false, "if the requested packages cannot be loaded in the current workspace, return an error immediately, instead of trying to load them in a temporary module") ) func main() { @@ -99,12 +102,57 @@ func run() error { classifier = analyzer.GetClassifier(*noiseFlag) } - pkgs, err := analyzer.LoadPackages(packageNames, - analyzer.LoadConfig{ - BuildTags: *buildTags, - GOOS: *goos, - GOARCH: *goarch, - }) + loadConfig := analyzer.LoadConfig{ + BuildTags: *buildTags, + GOOS: *goos, + GOARCH: *goarch, + } + pkgs, listFailed, failedPackage, err := loadPackages(packageNames, loadConfig) + if (listFailed || len(pkgs) == 0) && !*forceLocalModule { + // Either: + // - `go list` returned an error for one of the packages, perhaps because + // it is not a dependency of the current workspace; or + // - no packages were loaded, because paths with '...' wildcards matched + // no dependencies of the current workspace. + // + // Here we try again in a temporary module, in which we call `go get` for + // each package. + // + // -force_local_module disables this behavior, and returns an error + // instead. + if listFailed { + fmt.Fprintf(os.Stderr, "Couldn't load package %q in the current module.", failedPackage) + } else { + fmt.Fprintf(os.Stderr, "Found no packages matching %q in the current module.", packageNames) + } + fmt.Fprintf(os.Stderr, " Trying again in a temporary module.\n") + + // Save current working directory. + var wd string + wd, err = os.Getwd() + if err != nil { + return err + } + + // Create a temporary module, switch to it, and `go get` the requested packages. + var remove func() + remove, err = makeTemporaryModule(packageNames) + if remove != nil { + defer remove() + } + if err != nil { + return err + } + + // Try loading the packages again. + pkgs, _, _, err = loadPackages(packageNames, loadConfig) + + // Switch back to the original working directory. + err1 := os.Chdir(wd) + if err == nil && err1 != nil { + return fmt.Errorf("returning to working directory: %w", err1) + } + } if err != nil { return fmt.Errorf("Error loading packages: %w", err) } @@ -143,3 +191,65 @@ func run() error { } return err } + +// loadPackages calls analyzer.LoadPackages to load the specified packages. +// +// If it fails due to a ListError (for example, if one of the packages is not a +// dependency of the current module), the return value listFailed will be true, +// and failedPackage will specify the package that couldn't be loaded. +func loadPackages(packageNames []string, loadConfig analyzer.LoadConfig) (pkgs []*packages.Package, listFailed bool, failedPackage string, err error) { + pkgs, err = analyzer.LoadPackages(packageNames, loadConfig) + for _, p := range pkgs { + for _, e := range p.Errors { + if e.Kind == packages.ListError { + return pkgs, true, p.ID, err + } + } + } + return pkgs, false, "", err +} + +// makeTemporaryModule switches to a new temporary directory, creates a module +// there, and adds the specified packages to that module with `go get`. +// +// It also sets the environment variable GOWORK to "off", to avoid analyses +// being affected by workspaces we did not intend to use. (For example, if +// there's a go.work file in /tmp.) +// +// The caller can call the returned function, if it is non-nil, to remove the +// temporary directory containing the module when it is no longer needed. +func makeTemporaryModule(packageNames []string) (remove func(), err error) { + if err = os.Setenv("GOWORK", "off"); err != nil { + return nil, err + } + tmpdir, err := os.MkdirTemp("", "") + if err != nil { + return nil, fmt.Errorf("creating temporary directory: %w", err) + } + remove = func() { os.RemoveAll(tmpdir) } + if err = os.Chdir(tmpdir); err != nil { + return remove, fmt.Errorf("switching to temporary directory: %w", err) + } + run := func(command string, args ...string) error { + if *verbose >= 2 { + log.Printf("running %q with args %q", command, args) + } + cmd := exec.Command(command, args...) + var stderr bytes.Buffer + cmd.Stderr = &stderr + err := cmd.Run() + if err != nil || *verbose >= 2 { + os.Stderr.Write(stderr.Bytes()) + } + return err + } + if err = run("go", "mod", "init", "capslockmodule"); err != nil { + return remove, fmt.Errorf("creating temporary module: %w", err) + } + for _, p := range packageNames { + if err := run("go", "get", p); err != nil { + return remove, fmt.Errorf("calling `go get %q`: %w", p, err) + } + } + return remove, nil +}