Skip to content

Commit

Permalink
Add callgraphutil.WriteDOT (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
picatz authored Jan 5, 2024
1 parent eebdf3f commit e98ee4e
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 0 deletions.
88 changes: 88 additions & 0 deletions callgraphutil/dot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package callgraphutil

import (
"bufio"
"fmt"
"io"
"strings"

"golang.org/x/tools/go/callgraph"
)

// WriteDOT writes the given callgraph.Graph to the given io.Writer in the
// DOT format, which can be used to generate a visual representation of the
// call graph using Graphviz.
func WriteDOT(w io.Writer, g *callgraph.Graph) error {
b := bufio.NewWriter(w)
defer b.Flush()

b.WriteString("digraph callgraph {\n")
b.WriteString("\tgraph [fontname=\"Helvetica\", overlap=false normalize=true];\n")
b.WriteString("\tnode [fontname=\"Helvetica\" shape=box];\n")
b.WriteString("\tedge [fontname=\"Helvetica\"];\n")

edges := []*callgraph.Edge{}

nodesByPkg := map[string][]*callgraph.Node{}

addPkgNode := func(n *callgraph.Node) {
// TODO: fix this so there's not so many "shared" functions?
//
// It is a bit of a hack, but it works for now.
var pkgPath string
if n.Func.Pkg != nil {
pkgPath = n.Func.Pkg.Pkg.Path()
} else {
pkgPath = "shared"
}

// Check if the package already exists.
if _, ok := nodesByPkg[pkgPath]; !ok {
// If not, create it.
nodesByPkg[pkgPath] = []*callgraph.Node{}
}
nodesByPkg[pkgPath] = append(nodesByPkg[pkgPath], n)
}

// Check if root node exists, if so, write it.
if g.Root != nil {
b.WriteString(fmt.Sprintf("\troot = %d;\n", g.Root.ID))
}

// Process nodes and edges.
for _, n := range g.Nodes {
// Add node to map of nodes by package.
addPkgNode(n)

// Add edges
edges = append(edges, n.Out...)
}

// Write nodes by package.
for pkg, nodes := range nodesByPkg {
// Make the pkg name sugraph cluster friendly (remove dots, dashes, and slashes).
clusterName := strings.Replace(pkg, ".", "_", -1)
clusterName = strings.Replace(clusterName, "/", "_", -1)
clusterName = strings.Replace(clusterName, "-", "_", -1)

// NOTE: even if we're using a subgraph cluster, it may not be
// respected by all Graphviz layout engines. For example, the
// "dot" engine will respect the cluster, but the "sfdp" engine
// will not.
b.WriteString(fmt.Sprintf("\tsubgraph cluster_%s {\n", clusterName))
b.WriteString(fmt.Sprintf("\t\tlabel=%q;\n", pkg))
for _, n := range nodes {
b.WriteString(fmt.Sprintf("\t\t%d [label=%q];\n", n.ID, n.Func))
}
b.WriteString("\t}\n")
}

// Write edges.
for _, e := range edges {
b.WriteString(fmt.Sprintf("\t%d -> %d;\n", e.Caller.ID, e.Callee.ID))
}

b.WriteString("}\n")

return nil
}
196 changes: 196 additions & 0 deletions callgraphutil/dot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package callgraphutil_test

import (
"bytes"
"context"
"fmt"
"go/ast"
"go/parser"
"go/token"
"os"
"path/filepath"
"testing"

"github.com/go-git/go-git/v5"
"github.com/picatz/taint/callgraphutil"
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)

func cloneGitHubRepository(ctx context.Context, ownerName, repoName string) (string, string, error) {
// Get the owner and repo part of the URL.
ownerAndRepo := ownerName + "/" + repoName

// Get the directory path.
dir := filepath.Join(os.TempDir(), "taint", "github", ownerAndRepo)

// Check if the directory exists.
_, err := os.Stat(dir)
if err == nil {
// If the directory exists, we'll assume it's a valid repository,
// and return the directory. Open the directory to
repo, err := git.PlainOpen(dir)
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

// Get the repository's HEAD.
head, err := repo.Head()
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

return dir, head.Hash().String(), nil
}

// Clone the repository.
repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{
URL: fmt.Sprintf("https://github.com/%s", ownerAndRepo),
Depth: 1,
Tags: git.NoTags,
SingleBranch: true,
})
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

// Get the repository's HEAD.
head, err := repo.Head()
if err != nil {
return dir, "", fmt.Errorf("%w", err)
}

return dir, head.Hash().String(), nil
}

func loadPackages(ctx context.Context, dir, pattern string) ([]*packages.Package, error) {
loadMode :=
packages.NeedName |
packages.NeedDeps |
packages.NeedFiles |
packages.NeedModule |
packages.NeedTypes |
packages.NeedImports |
packages.NeedSyntax |
packages.NeedTypesInfo
// packages.NeedTypesSizes |
// packages.NeedCompiledGoFiles |
// packages.NeedExportFile |
// packages.NeedEmbedPatterns

// parseMode := parser.ParseComments
parseMode := parser.SkipObjectResolution

// patterns := []string{dir}
patterns := []string{pattern}
// patterns := []string{"all"}

pkgs, err := packages.Load(&packages.Config{
Mode: loadMode,
Context: ctx,
Env: os.Environ(),
Dir: dir,
Tests: false,
ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
return parser.ParseFile(fset, filename, src, parseMode)
},
}, patterns...)
if err != nil {
return nil, err
}

return pkgs, nil

}

func loadSSA(ctx context.Context, pkgs []*packages.Package) (mainFn *ssa.Function, srcFns []*ssa.Function, err error) {
ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug

// Analyze the package.
ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode)

ssaProg.Build()

for _, pkg := range ssaPkgs {
pkg.Build()
}

mainPkgs := ssautil.MainPackages(ssaPkgs)

mainFn = mainPkgs[0].Members["main"].(*ssa.Function)

for _, pkg := range ssaPkgs {
for _, fn := range pkg.Members {
if fn.Object() == nil {
continue
}

if fn.Object().Name() == "_" {
continue
}

pkgFn := pkg.Func(fn.Object().Name())
if pkgFn == nil {
continue
}

var addAnons func(f *ssa.Function)
addAnons = func(f *ssa.Function) {
srcFns = append(srcFns, f)
for _, anon := range f.AnonFuncs {
addAnons(anon)
}
}
addAnons(pkgFn)
}
}

if mainFn == nil {
err = fmt.Errorf("failed to find main function")
return
}

return
}

func loadCallGraph(ctx context.Context, mainFn *ssa.Function, srcFns []*ssa.Function) (*callgraph.Graph, error) {
cg, err := callgraphutil.NewGraph(mainFn, srcFns...)
if err != nil {
return nil, fmt.Errorf("failed to create new callgraph: %w", err)
}

return cg, nil
}

func TestWriteDOT(t *testing.T) {
repo, _, err := cloneGitHubRepository(context.Background(), "picatz", "taint")
if err != nil {
t.Fatal(err)
}

pkgs, err := loadPackages(context.Background(), repo, "./...")
if err != nil {
t.Fatal(err)
}

mainFn, srcFns, err := loadSSA(context.Background(), pkgs)
if err != nil {
t.Fatal(err)
}

cg, err := loadCallGraph(context.Background(), mainFn, srcFns)
if err != nil {
t.Fatal(err)
}

output := &bytes.Buffer{}

err = callgraphutil.WriteDOT(output, cg)
if err != nil {
t.Fatal(err)
}

fmt.Println(output.String())
}

0 comments on commit e98ee4e

Please sign in to comment.