diff --git a/callgraph/doc.go b/callgraph/doc.go deleted file mode 100644 index 6f960bd..0000000 --- a/callgraph/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package callgraph implements an extended version of golang.org/x/tools/go/callgraph -// with additional optimization attempts and functionality. -package callgraph diff --git a/callgraphutil/aliases.go b/callgraphutil/aliases.go new file mode 100644 index 0000000..617a380 --- /dev/null +++ b/callgraphutil/aliases.go @@ -0,0 +1,9 @@ +package callgraphutil + +import "golang.org/x/tools/go/callgraph" + +// Nodes is a handy alias for a slice of callgraph.Nodes. +type Nodes = []*callgraph.Node + +// Edges is a handy alias for a slice of callgraph.Edges. +type Edges = []*callgraph.Edge diff --git a/callgraphutil/calls.go b/callgraphutil/calls.go new file mode 100644 index 0000000..2ca788f --- /dev/null +++ b/callgraphutil/calls.go @@ -0,0 +1,35 @@ +package callgraphutil + +import "golang.org/x/tools/go/callgraph" + +// CalleesOf returns nodes that are called by the caller node. +func CalleesOf(caller *callgraph.Node) Nodes { + calleesMap := make(map[*callgraph.Node]bool) + for _, e := range caller.Out { + calleesMap[e.Callee] = true + } + + // Convert map to slice. + calleesSlice := make([]*callgraph.Node, 0, len(calleesMap)) + for callee := range calleesMap { + calleesSlice = append(calleesSlice, callee) + } + + return calleesSlice +} + +// CallersOf returns nodes that call the callee node. +func CallersOf(callee *callgraph.Node) Nodes { + uniqCallers := make(map[*callgraph.Node]bool) + for _, e := range callee.In { + uniqCallers[e.Caller] = true + } + + // Convert map to slice. + callersSlice := make(Nodes, 0, len(uniqCallers)) + for caller := range uniqCallers { + callersSlice = append(callersSlice, caller) + } + + return callersSlice +} diff --git a/callgraphutil/doc.go b/callgraphutil/doc.go new file mode 100644 index 0000000..6989af2 --- /dev/null +++ b/callgraphutil/doc.go @@ -0,0 +1,3 @@ +// Package callgraphutil implements utilities for golang.org/x/tools/go/callgraph +// including path searching, graph construction, printing, and more. +package callgraphutil diff --git a/callgraph/callgraph.go b/callgraphutil/graph.go similarity index 50% rename from callgraph/callgraph.go rename to callgraphutil/graph.go index 8818821..ef33232 100644 --- a/callgraph/callgraph.go +++ b/callgraphutil/graph.go @@ -1,34 +1,47 @@ -package callgraph +package callgraphutil import ( "bytes" "fmt" - "go/token" "go/types" + "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) -// func debug(f string, args ...interface{}) { -// fmt.Printf(f, args...) -// fmt.Printf("\033[1000D") -// } +// GraphString returns a string representation of the call graph, +// which is a sequence of nodes separated by newlines, with the +// callees of each node indented by a tab. +func GraphString(g *callgraph.Graph) string { + var buf bytes.Buffer -// A Graph represents a call graph. -// -// A graph may contain nodes that are not reachable from the root. -// If the call graph is sound, such nodes indicate unreachable -// functions. -type Graph struct { - Root *Node // the distinguished root node - Nodes map[*ssa.Function]*Node // all nodes by function + for _, n := range g.Nodes { + fmt.Fprintf(&buf, "%s\n", n) + for _, e := range n.Out { + fmt.Fprintf(&buf, "\t→ %s\n", e.Callee) + } + fmt.Fprintf(&buf, "\n") + } + + return buf.String() } -// New returns a new Graph with the specified root node. -func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) { - g := &Graph{ - Nodes: make(map[*ssa.Function]*Node), +// NewGraph returns a new Graph with the specified root node. +// +// Typically, the root node is the main function of the program, and the +// srcFns are the source functions that are of interest to the caller. But, the root +// node can be any function, and the srcFns can be any set of functions. +// +// This algorithm attempts to add all source functions reachable from the root node +// by traversing the SSA IR and adding edges to the graph; it handles calls +// to functions, methods, closures, and interfaces. It may miss some complex +// edges today, such as stucts containing function fields accessed via slice or map +// indexing. This is a known limitation, but something we hope to improve in the near future. +// https://github.com/picatz/taint/issues/23 +func NewGraph(root *ssa.Function, srcFns ...*ssa.Function) (*callgraph.Graph, error) { + g := &callgraph.Graph{ + Nodes: make(map[*ssa.Function]*callgraph.Node), } g.Root = g.CreateNode(root) @@ -38,7 +51,7 @@ func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) { for _, srcFn := range srcFns { // debug("adding src function %d/%d: %v\n", i+1, len(srcFns), srcFn) - err := g.AddFunction(srcFn, allFns) + err := AddFunction(g, srcFn, allFns) if err != nil { return g, fmt.Errorf("failed to add src function %v: %w", srcFn, err) } @@ -53,7 +66,10 @@ func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) { return g, nil } -func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *Graph, fn *ssa.Function, instr ssa.Instruction) error { +// checkBlockInstruction checks the given instruction for any function calls, adding +// edges to the call graph as needed and recursively adding any new functions to the graph +// that are discovered during the process (typically via interface methods). +func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *callgraph.Graph, fn *ssa.Function, instr ssa.Instruction) error { // debug("\tcheckBlockInstruction: %v\n", instr) switch instrt := instr.(type) { case *ssa.Call: @@ -108,7 +124,7 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g fn = pkg.Prog.NewFunction(method.Name(), method.Type().(*types.Signature), "callgraph") } - AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(fn)) + callgraph.AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(fn)) switch xType := instrtCallArgt.X.Type().(type) { case *types.Named: @@ -130,7 +146,7 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g fn2 = pkg2.Prog.NewFunction(method.Name(), methodType, "callgraph") } - AddEdge(g.CreateNode(fn), instrt, g.CreateNode(fn2)) + callgraph.AddEdge(g.CreateNode(fn), instrt, g.CreateNode(fn2)) default: continue } @@ -181,9 +197,9 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g return nil } - AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall)) + callgraph.AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall)) - err := g.AddFunction(instrCall, allFns) + err := AddFunction(g, instrCall, allFns) if err != nil { return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err) } @@ -194,21 +210,39 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g switch argt := arg.(type) { case *ssa.Function: // TODO: check if edge already exists? - AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt)) + callgraph.AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt)) case *ssa.MakeClosure: switch argtFn := argt.Fn.(type) { case *ssa.Function: - AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn)) + callgraph.AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn)) } } } } + + // Delete duplicate edges that may have been added, which is a responsibility of the caller + // when using the callgraph.AddEdge function directly. + for _, n := range g.Nodes { + // debug("checking node %v\n", n) + for i := 0; i < len(n.Out); i++ { + for j := i + 1; j < len(n.Out); j++ { + if n.Out[i].Callee == n.Out[j].Callee { + // debug("deleting duplicate edge %v\n", n.Out[j]) + n.Out = append(n.Out[:j], n.Out[j+1:]...) + j-- + } + } + } + } + return nil } // AddFunction analyzes the given target SSA function, adding information to the call graph. +// +// Based on the implementation of golang.org/x/tools/cmd/guru/callers.go: // https://cs.opensource.google/go/x/tools/+/master:cmd/guru/callers.go;drc=3e0d083b858b3fdb7d095b5a3deb184aa0a5d35e;bpv=1;bpt=1;l=90 -func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool) error { +func AddFunction(cg *callgraph.Graph, target *ssa.Function, allFns map[*ssa.Function]bool) error { // debug("\tAddFunction: %v (all funcs %d)\n", target, len(allFns)) // First check if we have already processed this function. @@ -256,7 +290,7 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool // Direct call to target? rands := instr.Operands(space[:0]) if site, ok := instr.(ssa.CallInstruction); ok && site.Common().Value == target { - AddEdge(cg.CreateNode(progFn), site, targetNode) + callgraph.AddEdge(cg.CreateNode(progFn), site, targetNode) rands = rands[1:] // skip .Value (rands[0]) } @@ -272,280 +306,3 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool return nil } - -// CreateNode returns the Node for fn, creating it if not present. -func (g *Graph) CreateNode(fn *ssa.Function) *Node { - // debug("\tCreateNode: %v\n", fn) - - n, ok := g.Nodes[fn] - if !ok { - n = &Node{Func: fn, ID: len(g.Nodes)} - g.Nodes[fn] = n - return n - } - return n -} - -func (g *Graph) String() string { - var buf bytes.Buffer - - for _, n := range g.Nodes { - fmt.Fprintf(&buf, "%s\n", n) - for _, e := range n.Out { - fmt.Fprintf(&buf, "\t→ %s\n", e.Callee) - } - fmt.Fprintf(&buf, "\n") - } - return buf.String() -} - -// A Node represents a node in a call graph. -type Node struct { - Func *ssa.Function // the function this node represents - ID int // 0-based sequence number - In []*Edge // unordered set of incoming call edges (n.In[*].Callee == n) - Out []*Edge // unordered set of outgoing call edges (n.Out[*].Caller == n) -} - -func (n *Node) String() string { - return fmt.Sprintf("n%d:%s", n.ID, n.Func) -} - -// A Edge represents an edge in the call graph. -// -// Site is nil for edges originating in synthetic or intrinsic -// functions, e.g. reflect.Call or the root of the call graph. -type Edge struct { - Caller *Node - Site ssa.CallInstruction - Callee *Node -} - -func (e Edge) String() string { - return fmt.Sprintf("%s → %s", e.Caller, e.Callee) -} - -// Description returns a human-readable description of the edge. -func (e Edge) Description() string { - var prefix string - switch e.Site.(type) { - case nil: - return "synthetic call" - case *ssa.Go: - prefix = "concurrent " - case *ssa.Defer: - prefix = "deferred " - } - return prefix + e.Site.Common().Description() -} - -func (e Edge) Pos() token.Pos { - if e.Site == nil { - return token.NoPos - } - return e.Site.Pos() -} - -// AddEdge adds the edge (caller, site, callee) to the call graph. -func AddEdge(caller *Node, site ssa.CallInstruction, callee *Node) { - // debug("\tAddEdge(%v): %v → %v\n", site, caller, callee) - - e := &Edge{caller, site, callee} - - var existingCalleeEdge bool - - for _, in := range callee.In { - if in.String() == e.String() { - existingCalleeEdge = true - break - } - } - - if !existingCalleeEdge { - callee.In = append(callee.In, e) - } - - var existingCallerEdge bool - - for _, out := range caller.Out { - if out.String() == e.String() { - existingCallerEdge = true - break - } - } - - if !existingCallerEdge { - caller.Out = append(caller.Out, e) - } -} - -// GraphVisitEdges visits all the edges in graph g in depth-first order. -// The edge function is called for each edge in postorder. If it -// returns non-nil, visitation stops and GraphVisitEdges returns that -// value. -func (g *Graph) VisitEdges(edge func(*Edge) error) error { - seen := make(map[*Node]bool) - var visit func(n *Node) error - visit = func(n *Node) error { - if !seen[n] { - seen[n] = true - for _, e := range n.Out { - if err := visit(e.Callee); err != nil { - return err - } - if err := edge(e); err != nil { - return err - } - } - } - return nil - } - for _, n := range g.Nodes { - if err := visit(n); err != nil { - return err - } - } - return nil -} - -type Path []*Edge - -func (p Path) Empty() bool { - return len(p) == 0 -} - -func (p Path) First() *Edge { - if len(p) == 0 { - return nil - } - return p[0] -} - -func (p Path) Last() *Edge { - if len(p) == 0 { - return nil - } - return p[len(p)-1] -} - -// String returns a string representation of the path which -// is a sequence of edges separated by " → ". -// -// Intended to be used while debugging. -func (p Path) String() string { - var buf bytes.Buffer - for i, e := range p { - if i == 0 { - buf.WriteString(e.Caller.String()) - } - - buf.WriteString(" → ") - - buf.WriteString(e.Callee.String()) - } - return buf.String() -} - -type Paths []Path - -func PathSearch(start *Node, isMatch func(*Node) bool) Path { - stack := make(Path, 0, 32) - seen := make(map[*Node]bool) - var search func(n *Node) Path - search = func(n *Node) Path { - if !seen[n] { - // debug("searching: %v\n", n) - seen[n] = true - if isMatch(n) { - return stack - } - for _, e := range n.Out { - stack = append(stack, e) // push - if found := search(e.Callee); found != nil { - return found - } - stack = stack[:len(stack)-1] // pop - } - } - return nil - } - return search(start) -} - -func PathsSearch(start *Node, isMatch func(*Node) bool) Paths { - paths := Paths{} - - stack := make(Path, 0, 32) - seen := make(map[*Node]bool) - var search func(n *Node) - search = func(n *Node) { - // debug("searching: %v\n", n) - if !seen[n] { - seen[n] = true - if isMatch(n) { - paths = append(paths, stack) - - stack = make(Path, 0, 32) - seen = make(map[*Node]bool) - return - } - for _, e := range n.Out { - // debug("\tout: %v\n", e) - stack = append(stack, e) // push - search(e.Callee) - if len(stack) == 0 { - continue - } - stack = stack[:len(stack)-1] // pop - } - } - } - search(start) - - return paths -} - -func CalleesOf(caller *Node) map[*Node]bool { - callees := make(map[*Node]bool) - for _, e := range caller.Out { - callees[e.Callee] = true - } - return callees -} - -func CallersOf(callee *Node) map[*Node]bool { - callers := make(map[*Node]bool) - for _, e := range callee.In { - callers[e.Caller] = true - } - return callers -} - -func PathSearchCallTo(start *Node, fn string) Path { - return PathSearch(start, func(n *Node) bool { - fnStr := n.Func.String() - return fnStr == fn - }) -} - -func InstructionsFor(root *Node, v ssa.Value) (si ssa.Instruction) { - PathsSearch(root, func(n *Node) bool { - for _, b := range root.Func.Blocks { - for _, instr := range b.Instrs { - if instr.Pos() == v.Pos() { - si = instr - return true - } - } - } - return false - }) - return -} - -func PathsSearchCallTo(start *Node, fn string) Paths { - return PathsSearch(start, func(n *Node) bool { - fnStr := n.Func.String() - return fnStr == fn - }) -} diff --git a/callgraphutil/graph_vulncheck.go b/callgraphutil/graph_vulncheck.go new file mode 100644 index 0000000..4014c87 --- /dev/null +++ b/callgraphutil/graph_vulncheck.go @@ -0,0 +1,102 @@ +package callgraphutil + +import ( + "context" + + "golang.org/x/tools/go/callgraph" + "golang.org/x/tools/go/callgraph/cha" + "golang.org/x/tools/go/callgraph/vta" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" +) + +// NewVulncheckCallGraph builds a call graph of prog based on VTA analysis, +// straight from the govulncheck project. This is used to demonstrate the +// difference between the call graph built by this package's algorithm and +// govulncheck's algorithm (based on CHA and VTA analysis). +// +// This method is based on the following: +// https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/utils.go#L63 +func NewVulncheckCallGraph(ctx context.Context, prog *ssa.Program, entries []*ssa.Function) (*callgraph.Graph, error) { + entrySlice := make(map[*ssa.Function]bool) + for _, e := range entries { + entrySlice[e] = true + } + + if err := ctx.Err(); err != nil { // cancelled? + return nil, err + } + initial := cha.CallGraph(prog) + allFuncs := ssautil.AllFunctions(prog) + + fslice := forwardSlice(entrySlice, initial) + // Keep only actually linked functions. + pruneSet(fslice, allFuncs) + + if err := ctx.Err(); err != nil { // cancelled? + return nil, err + } + vtaCg := vta.CallGraph(fslice, initial) + + // Repeat the process once more, this time using + // the produced VTA call graph as the base graph. + fslice = forwardSlice(entrySlice, vtaCg) + pruneSet(fslice, allFuncs) + + if err := ctx.Err(); err != nil { // cancelled? + return nil, err + } + cg := vta.CallGraph(fslice, vtaCg) + cg.DeleteSyntheticNodes() + + return cg, nil +} + +// forwardSlice computes the transitive closure of functions forward reachable +// via calls in cg or referred to in an instruction starting from `sources`. +// +// https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/slicing.go#L14 +func forwardSlice(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool { + seen := make(map[*ssa.Function]bool) + var visit func(f *ssa.Function) + visit = func(f *ssa.Function) { + if seen[f] { + return + } + seen[f] = true + + if n := cg.Nodes[f]; n != nil { + for _, e := range n.Out { + if e.Site != nil { + visit(e.Callee.Func) + } + } + } + + var buf [10]*ssa.Value // avoid alloc in common case + for _, b := range f.Blocks { + for _, instr := range b.Instrs { + for _, op := range instr.Operands(buf[:0]) { + if fn, ok := (*op).(*ssa.Function); ok { + visit(fn) + } + } + } + } + } + for source := range sources { + visit(source) + } + return seen +} + +// pruneSet removes functions in `set` that are in `toPrune`. +// +// https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/slicing.go#L49 +func pruneSet(set, toPrune map[*ssa.Function]bool) { + for f := range set { + if !toPrune[f] { + delete(set, f) + } + } +} diff --git a/callgraphutil/path.go b/callgraphutil/path.go new file mode 100644 index 0000000..a7453bd --- /dev/null +++ b/callgraphutil/path.go @@ -0,0 +1,196 @@ +package callgraphutil + +import ( + "bytes" + + "golang.org/x/tools/go/callgraph" +) + +// Path is a sequence of callgraph.Edges, where each edge +// represents a call from a caller to a callee, making up +// a "chain" of calls, e.g.: main → foo → bar → baz. +type Path []*callgraph.Edge + +// Empty returns true if the path is empty, false otherwise. +func (p Path) Empty() bool { + return len(p) == 0 +} + +// First returns the first edge in the path, or nil if the path is empty. +func (p Path) First() *callgraph.Edge { + if len(p) == 0 { + return nil + } + return p[0] +} + +// Last returns the last edge in the path, or nil if the path is empty. +func (p Path) Last() *callgraph.Edge { + if len(p) == 0 { + return nil + } + return p[len(p)-1] +} + +// String returns a string representation of the path which +// is a sequence of edges separated by " → ". +// +// Intended to be used while debugging. +func (p Path) String() string { + var buf bytes.Buffer + for i, e := range p { + if i == 0 { + buf.WriteString(e.Caller.String()) + } + + buf.WriteString(" → ") + + buf.WriteString(e.Callee.String()) + } + return buf.String() +} + +// Paths is a collection of paths, which may be logically grouped +// together, e.g.: all paths from main to foo, or all paths from +// main to bar. +type Paths []Path + +// Shortest returns the shortest path in the collection of paths. +// +// If there are no paths, this returns nil. If there are multiple +// paths of the same length, this returns the first path found. +func (p Paths) Shortest() Path { + if len(p) == 0 { + return nil + } + + shortest := p[0] + for _, path := range p { + if len(path) < len(shortest) { + shortest = path + } + } + + return shortest +} + +// Longest returns the longest path in the collection of paths. +// +// If there are no paths, this returns nil. If there are multiple +// paths of the same length, the first path found is returned. +func (p Paths) Longest() Path { + if len(p) == 0 { + return nil + } + + longest := p[0] + for _, path := range p { + if len(path) > len(longest) { + longest = path + } + } + + return longest +} + +// PathSearch returns the first path found from the start node +// to a node that matches the isMatch function. This is a depth +// first search, so it will return the first path found, which +// may not be the shortest path. +// +// To find all paths, use PathsSearch, which returns a collection +// of paths. +func PathSearch(start *callgraph.Node, isMatch func(*callgraph.Node) bool) Path { + var ( + stack = make(Path, 0, 32) + seen = make(map[*callgraph.Node]bool) + + search func(n *callgraph.Node) Path + ) + + search = func(n *callgraph.Node) Path { + if !seen[n] { + // debug("searching: %v\n", n) + seen[n] = true + if isMatch(n) { + return stack + } + for _, e := range n.Out { + stack = append(stack, e) // push + if found := search(e.Callee); found != nil { + return found + } + stack = stack[:len(stack)-1] // pop + } + } + return nil + } + return search(start) +} + +// PathsSearch returns all paths found from the start node +// to a node that matches the isMatch function. Under the hood, +// this is a depth first search. +// +// To find the first path (which may not be the shortest), use PathSearch. +func PathsSearch(start *callgraph.Node, isMatch func(*callgraph.Node) bool) Paths { + var ( + paths = Paths{} + + stack = make(Path, 0, 32) + seen = make(map[*callgraph.Node]bool) + + search func(n *callgraph.Node) + ) + + search = func(n *callgraph.Node) { + if n == nil { + return + } + + // debug("searching: %v\n", n) + if !seen[n] { + seen[n] = true + if isMatch(n) { + paths = append(paths, stack) + + stack = make(Path, 0, 32) + seen = make(map[*callgraph.Node]bool) + return + } + for _, e := range n.Out { + // debug("\tout: %v\n", e) + stack = append(stack, e) // push + search(e.Callee) + if len(stack) == 0 { + continue + } + stack = stack[:len(stack)-1] // pop + } + } + } + search(start) + + return paths +} + +// PathSearchCallTo returns the first path found from the start node +// to a node that matches the function name. +func PathSearchCallTo(start *callgraph.Node, fn string) Path { + return PathSearch(start, func(n *callgraph.Node) bool { + fnStr := n.Func.String() + return fnStr == fn + }) +} + +// PathsSearchCallTo returns the paths that call the given function name, +// which uses SSA function name syntax, e.g.: "(*database/sql.DB).Query". +func PathsSearchCallTo(start *callgraph.Node, fn string) Paths { + return PathsSearch(start, func(n *callgraph.Node) bool { + if n == nil || n.Func == nil { + return false + } + fnStr := n.Func.String() + return fnStr == fn + }) +} diff --git a/callgraphutil/ssa.go b/callgraphutil/ssa.go new file mode 100644 index 0000000..6c1f40a --- /dev/null +++ b/callgraphutil/ssa.go @@ -0,0 +1,23 @@ +package callgraphutil + +import ( + "golang.org/x/tools/go/callgraph" + "golang.org/x/tools/go/ssa" +) + +// InstructionsFor returns the ssa.Instruction for the given ssa.Value using +// the given node as the root of the call graph that is searched. +func InstructionsFor(root *callgraph.Node, v ssa.Value) (si ssa.Instruction) { + PathsSearch(root, func(n *callgraph.Node) bool { + for _, b := range root.Func.Blocks { + for _, instr := range b.Instrs { + if instr.Pos() == v.Pos() { + si = instr + return true + } + } + } + return false + }) + return +} diff --git a/check.go b/check.go index f1a9dea..8a2fbea 100644 --- a/check.go +++ b/check.go @@ -1,8 +1,10 @@ package taint import ( - "github.com/picatz/taint/callgraph" + "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" + + "github.com/picatz/taint/callgraphutil" ) // Result is an individual finding from a taint check. @@ -13,7 +15,7 @@ import ( type Result struct { // Path is the specific path within a callgraph // where the source founds its way into a sink. - Path callgraph.Path + Path callgraphutil.Path // Source type information. SourceType string @@ -65,7 +67,7 @@ func Check(cg *callgraph.Graph, sources Sources, sinks Sinks) Results { // within the callgraph that those sinks can end up as // the final node path (the "sink path"). for sink := range sinks { - sinkPaths := callgraph.PathsSearchCallTo(cg.Root, sink) + sinkPaths := callgraphutil.PathsSearchCallTo(cg.Root, sink) // fmt.Println("sink paths:", len(sinkPaths)) @@ -106,7 +108,7 @@ func Check(cg *callgraph.Graph, sources Sources, sinks Sinks) Results { // checkPath implements taint analysis that can be used to identify if the given // callgraph path contains information from taintable sources (typically user input). -func checkPath(path callgraph.Path, sources Sources) (bool, string, ssa.Value) { +func checkPath(path callgraphutil.Path, sources Sources) (bool, string, ssa.Value) { // Ensure the path isn't empty (which can happen?!). if path.Empty() { return false, "", nil @@ -134,7 +136,7 @@ func checkPath(path callgraph.Path, sources Sources) (bool, string, ssa.Value) { // calls itself (or checkSSAInstruction) as nessecary. // // It returns true if the given SSA value is tained by any of the given sources. -func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited valueSet) (bool, string, ssa.Value) { +func checkSSAValue(path callgraphutil.Path, sources Sources, v ssa.Value, visited valueSet) (bool, string, ssa.Value) { // First, check if this value has already been visited. // // If so, we can assume it is safe. @@ -550,7 +552,7 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va // checkSSAInstruction is used internally by checkSSAValue when it needs to traverse // SSA instructions, like the contents of a calling function. -func checkSSAInstruction(path callgraph.Path, sources Sources, i ssa.Instruction, visited valueSet) (bool, string, ssa.Value) { +func checkSSAInstruction(path callgraphutil.Path, sources Sources, i ssa.Instruction, visited valueSet) (bool, string, ssa.Value) { // fmt.Printf("! check SSA instr %s: %[1]T\n", i) switch instr := i.(type) { diff --git a/cmd/taint/main.go b/cmd/taint/main.go index 751eebf..5dfc7f2 100644 --- a/cmd/taint/main.go +++ b/cmd/taint/main.go @@ -18,8 +18,9 @@ import ( "github.com/charmbracelet/lipgloss" "github.com/picatz/taint" - "github.com/picatz/taint/callgraph" + "github.com/picatz/taint/callgraphutil" "golang.org/x/term" + "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" @@ -369,7 +370,7 @@ var builtinCommandLoad = &command{ return nil } - cg, err = callgraph.New(mainFn, srcFns...) + cg, err = callgraphutil.NewGraph(mainFn, srcFns...) if err != nil { bt.WriteString(err.Error() + "\n") bt.Flush() @@ -427,7 +428,7 @@ var builtinCommandCG = &command{ return nil } - cgStr := strings.ReplaceAll(cg.String(), "→", styleFaint.Render("→")) + cgStr := strings.ReplaceAll(callgraphutil.GraphString(cg), "→", styleFaint.Render("→")) bt.WriteString(cgStr) bt.Flush() @@ -527,7 +528,7 @@ var builtinCommandsCallpath = &command{ fn := args[0] - paths := callgraph.PathsSearchCallTo(cg.Root, fn) + paths := callgraphutil.PathsSearchCallTo(cg.Root, fn) if len(paths) == 0 { bt.WriteString("no calls to " + fn + "\n") diff --git a/cmd/taint/main_test.go b/cmd/taint/main_test.go index 9ad27cf..2f18f43 100644 --- a/cmd/taint/main_test.go +++ b/cmd/taint/main_test.go @@ -8,7 +8,7 @@ import ( "os" "testing" - "github.com/picatz/taint/callgraph" + "github.com/picatz/taint/callgraphutil" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" @@ -97,7 +97,7 @@ func TestLoadAndSearch(t *testing.T) { t.Fatal("main function not found") } - cg, err := callgraph.New(mainFn, srcFns...) + cg, err := callgraphutil.NewGraph(mainFn, srcFns...) if err != nil { t.Fatal(err) } @@ -112,7 +112,7 @@ func TestLoadAndSearch(t *testing.T) { // t.Log(path) - paths := callgraph.PathsSearchCallTo(cg.Root, "(*database/sql.DB).Query") + paths := callgraphutil.PathsSearchCallTo(cg.Root, "(*database/sql.DB).Query") if len(paths) == 0 { t.Fatal("no paths found") diff --git a/log/injection/injection.go b/log/injection/injection.go index c5a2a55..1feac09 100644 --- a/log/injection/injection.go +++ b/log/injection/injection.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/picatz/taint" - "github.com/picatz/taint/callgraph" + "github.com/picatz/taint/callgraphutil" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" @@ -129,7 +129,7 @@ func run(pass *analysis.Pass) (interface{}, error) { // Construct a callgraph, using the main function as the root, // constructed of all other functions. This returns a callgraph // we can use to identify directed paths to logging functions. - cg, err := callgraph.New(mainFn, buildSSA.SrcFuncs...) + cg, err := callgraphutil.NewGraph(mainFn, buildSSA.SrcFuncs...) if err != nil { return nil, fmt.Errorf("failed to create new callgraph: %w", err) } diff --git a/sql/injection/injection.go b/sql/injection/injection.go index 5822d54..d1979a1 100644 --- a/sql/injection/injection.go +++ b/sql/injection/injection.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/picatz/taint" - "github.com/picatz/taint/callgraph" + "github.com/picatz/taint/callgraphutil" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" @@ -119,11 +119,42 @@ func run(pass *analysis.Pass) (interface{}, error) { // Construct a callgraph, using the main function as the root, // constructed of all other functions. This returns a callgraph // we can use to identify directed paths to SQL queries. - cg, err := callgraph.New(mainFn, buildSSA.SrcFuncs...) + cg, err := callgraphutil.NewGraph(mainFn, buildSSA.SrcFuncs...) if err != nil { return nil, fmt.Errorf("failed to create new callgraph: %w", err) } + // If you'd like to compare the callgraph constructed by the + // callgraphutil package to the one constructed by others + // (e.g. pointer analysis, rta, cha, static, etc), uncomment the + // following lines and compare the output. + // + // Today, I believe the callgraphutil package is the most + // accurate, but I'd love to be proven wrong. + + // Note: this actually panis for testcase b + // ptares, err := pointer.Analyze(&pointer.Config{ + // Mains: []*ssa.Package{buildSSA.Pkg}, + // BuildCallGraph: true, + // }) + // if err != nil { + // return nil, fmt.Errorf("failed to create new callgraph using pointer analysis: %w", err) + // } + // cg := ptares.CallGraph + + // cg := rta.Analyze([]*ssa.Function{mainFn}, true).CallGraph + // cg := cha.CallGraph(buildSSA.Pkg.Prog) + // cg := static.CallGraph(buildSSA.Pkg.Prog) + + // https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/utils.go#L61 + // cg, err := callgraphutil.NewVulncheckCallGraph(context.Background(), buildSSA.Pkg.Prog, buildSSA.SrcFuncs) + // if err != nil { + // return nil, err + // } + // cg.Root = cg.CreateNode(mainFn) + + // fmt.Println(callgraphutil.CallGraphString(cg)) + // Run taint check for user controlled values (sources) ending // up in injectable SQL methods (sinks). results := taint.Check(cg, userControlledValues, injectableSQLMethods) diff --git a/xss/xss.go b/xss/xss.go index 06fbfc2..358c3cd 100644 --- a/xss/xss.go +++ b/xss/xss.go @@ -5,7 +5,7 @@ import ( "strings" "github.com/picatz/taint" - "github.com/picatz/taint/callgraph" + "github.com/picatz/taint/callgraphutil" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" @@ -68,7 +68,7 @@ func run(pass *analysis.Pass) (interface{}, error) { // Construct a callgraph, using the main function as the root, // constructed of all other functions. This returns a callgraph // we can use to identify directed paths to logging functions. - cg, err := callgraph.New(mainFn, buildSSA.SrcFuncs...) + cg, err := callgraphutil.NewGraph(mainFn, buildSSA.SrcFuncs...) if err != nil { return nil, fmt.Errorf("failed to create new callgraph: %w", err) }