diff --git a/callgraphutil/dot.go b/callgraphutil/dot.go new file mode 100644 index 0000000..07b8609 --- /dev/null +++ b/callgraphutil/dot.go @@ -0,0 +1,88 @@ +package callgraphutil + +import ( + "bufio" + "fmt" + "io" + "strings" + + "golang.org/x/tools/go/callgraph" +) + +// WriteDOT writes the given callgraph.Graph to the given io.Writer in the +// DOT format, which can be used to generate a visual representation of the +// call graph using Graphviz. +func WriteDOT(w io.Writer, g *callgraph.Graph) error { + b := bufio.NewWriter(w) + defer b.Flush() + + b.WriteString("digraph callgraph {\n") + b.WriteString("\tgraph [fontname=\"Helvetica\", overlap=false normalize=true];\n") + b.WriteString("\tnode [fontname=\"Helvetica\" shape=box];\n") + b.WriteString("\tedge [fontname=\"Helvetica\"];\n") + + edges := []*callgraph.Edge{} + + nodesByPkg := map[string][]*callgraph.Node{} + + addPkgNode := func(n *callgraph.Node) { + // TODO: fix this so there's not so many "shared" functions? + // + // It is a bit of a hack, but it works for now. + var pkgPath string + if n.Func.Pkg != nil { + pkgPath = n.Func.Pkg.Pkg.Path() + } else { + pkgPath = "shared" + } + + // Check if the package already exists. + if _, ok := nodesByPkg[pkgPath]; !ok { + // If not, create it. + nodesByPkg[pkgPath] = []*callgraph.Node{} + } + nodesByPkg[pkgPath] = append(nodesByPkg[pkgPath], n) + } + + // Check if root node exists, if so, write it. + if g.Root != nil { + b.WriteString(fmt.Sprintf("\troot = %d;\n", g.Root.ID)) + } + + // Process nodes and edges. + for _, n := range g.Nodes { + // Add node to map of nodes by package. + addPkgNode(n) + + // Add edges + edges = append(edges, n.Out...) + } + + // Write nodes by package. + for pkg, nodes := range nodesByPkg { + // Make the pkg name sugraph cluster friendly (remove dots, dashes, and slashes). + clusterName := strings.Replace(pkg, ".", "_", -1) + clusterName = strings.Replace(clusterName, "/", "_", -1) + clusterName = strings.Replace(clusterName, "-", "_", -1) + + // NOTE: even if we're using a subgraph cluster, it may not be + // respected by all Graphviz layout engines. For example, the + // "dot" engine will respect the cluster, but the "sfdp" engine + // will not. + b.WriteString(fmt.Sprintf("\tsubgraph cluster_%s {\n", clusterName)) + b.WriteString(fmt.Sprintf("\t\tlabel=%q;\n", pkg)) + for _, n := range nodes { + b.WriteString(fmt.Sprintf("\t\t%d [label=%q];\n", n.ID, n.Func)) + } + b.WriteString("\t}\n") + } + + // Write edges. + for _, e := range edges { + b.WriteString(fmt.Sprintf("\t%d -> %d;\n", e.Caller.ID, e.Callee.ID)) + } + + b.WriteString("}\n") + + return nil +} diff --git a/callgraphutil/dot_test.go b/callgraphutil/dot_test.go new file mode 100644 index 0000000..8cea70c --- /dev/null +++ b/callgraphutil/dot_test.go @@ -0,0 +1,196 @@ +package callgraphutil_test + +import ( + "bytes" + "context" + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "testing" + + "github.com/go-git/go-git/v5" + "github.com/picatz/taint/callgraphutil" + "golang.org/x/tools/go/callgraph" + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" +) + +func cloneGitHubRepository(ctx context.Context, ownerName, repoName string) (string, string, error) { + // Get the owner and repo part of the URL. + ownerAndRepo := ownerName + "/" + repoName + + // Get the directory path. + dir := filepath.Join(os.TempDir(), "taint", "github", ownerAndRepo) + + // Check if the directory exists. + _, err := os.Stat(dir) + if err == nil { + // If the directory exists, we'll assume it's a valid repository, + // and return the directory. Open the directory to + repo, err := git.PlainOpen(dir) + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + // Get the repository's HEAD. + head, err := repo.Head() + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + return dir, head.Hash().String(), nil + } + + // Clone the repository. + repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{ + URL: fmt.Sprintf("https://github.com/%s", ownerAndRepo), + Depth: 1, + Tags: git.NoTags, + SingleBranch: true, + }) + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + // Get the repository's HEAD. + head, err := repo.Head() + if err != nil { + return dir, "", fmt.Errorf("%w", err) + } + + return dir, head.Hash().String(), nil +} + +func loadPackages(ctx context.Context, dir, pattern string) ([]*packages.Package, error) { + loadMode := + packages.NeedName | + packages.NeedDeps | + packages.NeedFiles | + packages.NeedModule | + packages.NeedTypes | + packages.NeedImports | + packages.NeedSyntax | + packages.NeedTypesInfo + // packages.NeedTypesSizes | + // packages.NeedCompiledGoFiles | + // packages.NeedExportFile | + // packages.NeedEmbedPatterns + + // parseMode := parser.ParseComments + parseMode := parser.SkipObjectResolution + + // patterns := []string{dir} + patterns := []string{pattern} + // patterns := []string{"all"} + + pkgs, err := packages.Load(&packages.Config{ + Mode: loadMode, + Context: ctx, + Env: os.Environ(), + Dir: dir, + Tests: false, + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parseMode) + }, + }, patterns...) + if err != nil { + return nil, err + } + + return pkgs, nil + +} + +func loadSSA(ctx context.Context, pkgs []*packages.Package) (mainFn *ssa.Function, srcFns []*ssa.Function, err error) { + ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug + + // Analyze the package. + ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode) + + ssaProg.Build() + + for _, pkg := range ssaPkgs { + pkg.Build() + } + + mainPkgs := ssautil.MainPackages(ssaPkgs) + + mainFn = mainPkgs[0].Members["main"].(*ssa.Function) + + for _, pkg := range ssaPkgs { + for _, fn := range pkg.Members { + if fn.Object() == nil { + continue + } + + if fn.Object().Name() == "_" { + continue + } + + pkgFn := pkg.Func(fn.Object().Name()) + if pkgFn == nil { + continue + } + + var addAnons func(f *ssa.Function) + addAnons = func(f *ssa.Function) { + srcFns = append(srcFns, f) + for _, anon := range f.AnonFuncs { + addAnons(anon) + } + } + addAnons(pkgFn) + } + } + + if mainFn == nil { + err = fmt.Errorf("failed to find main function") + return + } + + return +} + +func loadCallGraph(ctx context.Context, mainFn *ssa.Function, srcFns []*ssa.Function) (*callgraph.Graph, error) { + cg, err := callgraphutil.NewGraph(mainFn, srcFns...) + if err != nil { + return nil, fmt.Errorf("failed to create new callgraph: %w", err) + } + + return cg, nil +} + +func TestWriteDOT(t *testing.T) { + repo, _, err := cloneGitHubRepository(context.Background(), "picatz", "taint") + if err != nil { + t.Fatal(err) + } + + pkgs, err := loadPackages(context.Background(), repo, "./...") + if err != nil { + t.Fatal(err) + } + + mainFn, srcFns, err := loadSSA(context.Background(), pkgs) + if err != nil { + t.Fatal(err) + } + + cg, err := loadCallGraph(context.Background(), mainFn, srcFns) + if err != nil { + t.Fatal(err) + } + + output := &bytes.Buffer{} + + err = callgraphutil.WriteDOT(output, cg) + if err != nil { + t.Fatal(err) + } + + fmt.Println(output.String()) +}