Skip to content

Commit

Permalink
Add more tests, utilities, and call graph construction stability (#11)
Browse files Browse the repository at this point in the history
* Add more XSS tests
* Add re-usable SSA walking function
* Simplify call graph construction, making it reliable
  • Loading branch information
picatz authored Nov 21, 2023
1 parent 470dd40 commit 6b31d14
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 112 deletions.
14 changes: 2 additions & 12 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,9 @@ jobs:
- uses: actions/checkout@v3

- name: Setup Go
uses: actions/setup-go@v3
uses: actions/setup-go@v4
with:
go-version: 1.19

- name: Setup Go Cache
uses: actions/cache@v3
with:
path: |
~/.cache/go-build
~/go/pkg/mod
key: ${{ runner.os }}-golang-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-golang-
go-version: '1.21'

- name: Build
run: go build -v ./...
Expand Down
55 changes: 14 additions & 41 deletions callgraph/callgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@ package callgraph

import (
"bytes"
"context"
"fmt"
"go/token"
"go/types"
"runtime"
"sync"

"golang.org/x/sync/errgroup"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
Expand All @@ -20,46 +17,31 @@ import (
// If the call graph is sound, such nodes indicate unreachable
// functions.
type Graph struct {
sync.RWMutex
Root *Node // the distinguished root node
Nodes map[*ssa.Function]*Node // all nodes by function
}

// New returns a new Graph with the specified root node.
func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) {
g := &Graph{
RWMutex: sync.RWMutex{},
Nodes: make(map[*ssa.Function]*Node),
Nodes: make(map[*ssa.Function]*Node),
}

g.Root = g.CreateNode(root)

eg, _ := errgroup.WithContext(context.Background())

eg.SetLimit(runtime.NumCPU())

allFns := ssautil.AllFunctions(root.Prog)

for _, srcFn := range srcFns {
fn := srcFn
eg.Go(func() error {
err := g.AddFunction(fn, allFns)
if err != nil {
return fmt.Errorf("failed to add src function %v: %w", fn, err)
}
err := g.AddFunction(srcFn, allFns)
if err != nil {
return g, fmt.Errorf("failed to add src function %v: %w", srcFn, err)
}

for _, block := range fn.DomPreorder() {
for _, instr := range block.Instrs {
checkBlockInstruction(root, allFns, g, fn, instr)
}
for _, block := range srcFn.DomPreorder() {
for _, instr := range block.Instrs {
checkBlockInstruction(root, allFns, g, srcFn, instr)
}
return nil
})
}

err := eg.Wait()
if err != nil {
return nil, fmt.Errorf("error from errgroup: %w", err)
}
}

return g, nil
Expand All @@ -70,8 +52,6 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g
case *ssa.Call:
var instrCall *ssa.Function

// TODO: map more things to instrCall?

switch callt := instrt.Call.Value.(type) {
case *ssa.Function:
instrCall = callt
Expand Down Expand Up @@ -155,6 +135,9 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g
pkg := root.Prog.ImportedPackage(instrt.Call.Method.Pkg().Path())
fn := pkg.Prog.NewFunction(instrt.Call.Method.Name(), instrt.Call.Signature(), "callgraph")
instrCall = fn
default:
// case *ssa.TypeAssert: ??
// fmt.Printf("unknown call type: %v: %[1]T\n", callt)
}

// If we could not determine the function being
Expand Down Expand Up @@ -199,24 +182,16 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool
recvType = recv.Type()
}

// start := time.Now()

if len(allFns) == 0 {
allFns = ssautil.AllFunctions(target.Prog)
}

// log.Printf("finished loading %d functions for target %v in %v seconds", len(allFns), target, time.Since(start).Seconds())

// Find all direct calls to function,
// or a place where its address is taken.
for progFn := range allFns {
fn := progFn
// debugf("checking prog fn %v", fn)
// log.Printf("strt analyzing %v : blk %d", targetNode, len(fn.Blocks))
var space [32]*ssa.Value // preallocate
blocks := fn.DomPreorder()

for _, block := range blocks {
for _, block := range progFn.DomPreorder() {
for _, instr := range block.Instrs {
// Is this a method (T).f of a concrete type T
// whose runtime type descriptor is address-taken?
Expand All @@ -239,7 +214,7 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool
// Direct call to target?
rands := instr.Operands(space[:0])
if site, ok := instr.(ssa.CallInstruction); ok && site.Common().Value == target {
AddEdge(cg.CreateNode(fn), site, targetNode)
AddEdge(cg.CreateNode(progFn), site, targetNode)
rands = rands[1:] // skip .Value (rands[0])
}

Expand All @@ -258,8 +233,6 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool

// CreateNode returns the Node for fn, creating it if not present.
func (g *Graph) CreateNode(fn *ssa.Function) *Node {
g.Lock()
defer g.Unlock()
n, ok := g.Nodes[fn]
if !ok {
n = &Node{Func: fn, ID: len(g.Nodes)}
Expand Down
48 changes: 48 additions & 0 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ func Check(cg *callgraph.Graph, sources Sources, sinks Sinks) Results {
for sink := range sinks {
sinkPaths := callgraph.PathsSearchCallTo(cg.Root, sink)

// fmt.Println("sink paths:", len(sinkPaths))

for _, sinkPath := range sinkPaths {
// fmt.Println("sink path:", sinkPath)
// Ensure the path isn't empty (which can happen?!).
//
// TODO: ensure returned paths from within searched paths
Expand Down Expand Up @@ -121,6 +124,7 @@ func checkPath(path callgraph.Path, sources Sources) (bool, string, ssa.Value) {
// TODO: when non-function sinks are supported, we will need to handle
// them differently here.
for _, lastCallArg := range lastCallArgs {
// fmt.Println(lastCallArg)
tainted, src, tv := checkSSAValue(path, sources, lastCallArg, visited)
if tainted {
return true, src, tv
Expand Down Expand Up @@ -182,6 +186,26 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va
return true, src, value
}

// Check the parameter's referrers.
refs := value.Referrers()
if refs != nil {
for _, ref := range *refs {
refVal, isVal := ref.(ssa.Value)
if isVal {
tainted, src, tv := checkSSAValue(path, sources, refVal, visited)
if tainted {
return true, src, tv
}
continue
}

tainted, src, tv := checkSSAInstruction(path, sources, ref, visited)
if tainted {
return true, src, tv
}
}
}

// TODO: consider if we can remove the range with a single
// step backwards?
for _, edge := range path {
Expand Down Expand Up @@ -443,6 +467,30 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va
if tainted {
return true, src, tv
}

// Check the value's referrers.
refs := value.X.Referrers()
for _, ref := range *refs {
refVal, isVal := ref.(ssa.Value)
if isVal {
tainted, src, tv := checkSSAValue(path, sources, refVal, visited)
if tainted {
return true, src, tv
}
continue
}

tainted, src, tv := checkSSAInstruction(path, sources, ref, visited)
if tainted {
return true, src, tv
}
}
case *ssa.TypeAssert:
// Check the value being type asserted.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
if tainted {
return true, src, tv
}
case *ssa.Convert:
// Check the value being converted.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
Expand Down
5 changes: 1 addition & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ module github.com/picatz/taint

go 1.19

require (
golang.org/x/sync v0.1.0
golang.org/x/tools v0.4.0
)
require golang.org/x/tools v0.4.0

require (
golang.org/x/mod v0.7.0 // indirect
Expand Down
1 change: 0 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4=
Expand Down
Loading

0 comments on commit 6b31d14

Please sign in to comment.