diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 6ab8acf..354daf7 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -16,19 +16,9 @@ jobs: - uses: actions/checkout@v3 - name: Setup Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v4 with: - go-version: 1.19 - - - name: Setup Go Cache - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-golang-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-golang- + go-version: '1.21' - name: Build run: go build -v ./... diff --git a/callgraph/callgraph.go b/callgraph/callgraph.go index 4bf0d68..d54781e 100644 --- a/callgraph/callgraph.go +++ b/callgraph/callgraph.go @@ -2,14 +2,11 @@ package callgraph import ( "bytes" - "context" "fmt" "go/token" "go/types" - "runtime" "sync" - "golang.org/x/sync/errgroup" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) @@ -20,7 +17,6 @@ import ( // If the call graph is sound, such nodes indicate unreachable // functions. type Graph struct { - sync.RWMutex Root *Node // the distinguished root node Nodes map[*ssa.Function]*Node // all nodes by function } @@ -28,38 +24,24 @@ type Graph struct { // New returns a new Graph with the specified root node. func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) { g := &Graph{ - RWMutex: sync.RWMutex{}, - Nodes: make(map[*ssa.Function]*Node), + Nodes: make(map[*ssa.Function]*Node), } g.Root = g.CreateNode(root) - eg, _ := errgroup.WithContext(context.Background()) - - eg.SetLimit(runtime.NumCPU()) - allFns := ssautil.AllFunctions(root.Prog) for _, srcFn := range srcFns { - fn := srcFn - eg.Go(func() error { - err := g.AddFunction(fn, allFns) - if err != nil { - return fmt.Errorf("failed to add src function %v: %w", fn, err) - } + err := g.AddFunction(srcFn, allFns) + if err != nil { + return g, fmt.Errorf("failed to add src function %v: %w", srcFn, err) + } - for _, block := range fn.DomPreorder() { - for _, instr := range block.Instrs { - checkBlockInstruction(root, allFns, g, fn, instr) - } + for _, block := range srcFn.DomPreorder() { + for _, instr := range block.Instrs { + checkBlockInstruction(root, allFns, g, srcFn, instr) } - return nil - }) - } - - err := eg.Wait() - if err != nil { - return nil, fmt.Errorf("error from errgroup: %w", err) + } } return g, nil @@ -70,8 +52,6 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g case *ssa.Call: var instrCall *ssa.Function - // TODO: map more things to instrCall? - switch callt := instrt.Call.Value.(type) { case *ssa.Function: instrCall = callt @@ -155,6 +135,9 @@ func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g pkg := root.Prog.ImportedPackage(instrt.Call.Method.Pkg().Path()) fn := pkg.Prog.NewFunction(instrt.Call.Method.Name(), instrt.Call.Signature(), "callgraph") instrCall = fn + default: + // case *ssa.TypeAssert: ?? + // fmt.Printf("unknown call type: %v: %[1]T\n", callt) } // If we could not determine the function being @@ -199,24 +182,16 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool recvType = recv.Type() } - // start := time.Now() - if len(allFns) == 0 { allFns = ssautil.AllFunctions(target.Prog) } - // log.Printf("finished loading %d functions for target %v in %v seconds", len(allFns), target, time.Since(start).Seconds()) - // Find all direct calls to function, // or a place where its address is taken. for progFn := range allFns { - fn := progFn - // debugf("checking prog fn %v", fn) - // log.Printf("strt analyzing %v : blk %d", targetNode, len(fn.Blocks)) var space [32]*ssa.Value // preallocate - blocks := fn.DomPreorder() - for _, block := range blocks { + for _, block := range progFn.DomPreorder() { for _, instr := range block.Instrs { // Is this a method (T).f of a concrete type T // whose runtime type descriptor is address-taken? @@ -239,7 +214,7 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool // Direct call to target? rands := instr.Operands(space[:0]) if site, ok := instr.(ssa.CallInstruction); ok && site.Common().Value == target { - AddEdge(cg.CreateNode(fn), site, targetNode) + AddEdge(cg.CreateNode(progFn), site, targetNode) rands = rands[1:] // skip .Value (rands[0]) } @@ -258,8 +233,6 @@ func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool // CreateNode returns the Node for fn, creating it if not present. func (g *Graph) CreateNode(fn *ssa.Function) *Node { - g.Lock() - defer g.Unlock() n, ok := g.Nodes[fn] if !ok { n = &Node{Func: fn, ID: len(g.Nodes)} diff --git a/check.go b/check.go index cf70518..60919b3 100644 --- a/check.go +++ b/check.go @@ -67,7 +67,10 @@ func Check(cg *callgraph.Graph, sources Sources, sinks Sinks) Results { for sink := range sinks { sinkPaths := callgraph.PathsSearchCallTo(cg.Root, sink) + // fmt.Println("sink paths:", len(sinkPaths)) + for _, sinkPath := range sinkPaths { + // fmt.Println("sink path:", sinkPath) // Ensure the path isn't empty (which can happen?!). // // TODO: ensure returned paths from within searched paths @@ -121,6 +124,7 @@ func checkPath(path callgraph.Path, sources Sources) (bool, string, ssa.Value) { // TODO: when non-function sinks are supported, we will need to handle // them differently here. for _, lastCallArg := range lastCallArgs { + // fmt.Println(lastCallArg) tainted, src, tv := checkSSAValue(path, sources, lastCallArg, visited) if tainted { return true, src, tv @@ -182,6 +186,26 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va return true, src, value } + // Check the parameter's referrers. + refs := value.Referrers() + if refs != nil { + for _, ref := range *refs { + refVal, isVal := ref.(ssa.Value) + if isVal { + tainted, src, tv := checkSSAValue(path, sources, refVal, visited) + if tainted { + return true, src, tv + } + continue + } + + tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) + if tainted { + return true, src, tv + } + } + } + // TODO: consider if we can remove the range with a single // step backwards? for _, edge := range path { @@ -443,6 +467,30 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va if tainted { return true, src, tv } + + // Check the value's referrers. + refs := value.X.Referrers() + for _, ref := range *refs { + refVal, isVal := ref.(ssa.Value) + if isVal { + tainted, src, tv := checkSSAValue(path, sources, refVal, visited) + if tainted { + return true, src, tv + } + continue + } + + tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) + if tainted { + return true, src, tv + } + } + case *ssa.TypeAssert: + // Check the value being type asserted. + tainted, src, tv := checkSSAValue(path, sources, value.X, visited) + if tainted { + return true, src, tv + } case *ssa.Convert: // Check the value being converted. tainted, src, tv := checkSSAValue(path, sources, value.X, visited) diff --git a/go.mod b/go.mod index b98f2d5..e6ee40b 100644 --- a/go.mod +++ b/go.mod @@ -2,10 +2,7 @@ module github.com/picatz/taint go 1.19 -require ( - golang.org/x/sync v0.1.0 - golang.org/x/tools v0.4.0 -) +require golang.org/x/tools v0.4.0 require ( golang.org/x/mod v0.7.0 // indirect diff --git a/go.sum b/go.sum index 65b3000..a20f737 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,6 @@ golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4= diff --git a/walk_ssa.go b/walk_ssa.go new file mode 100644 index 0000000..f294b65 --- /dev/null +++ b/walk_ssa.go @@ -0,0 +1,160 @@ +package taint + +import ( + "fmt" + + "golang.org/x/tools/go/ssa" +) + +var ErrStopWalk = fmt.Errorf("taint: stop walk") + +// WalkSSA walks the SSA IR recursively with a visitor function that +// can be used to inspect each node in the graph. The visitor function +// should return an error if it wants to stop the walk. +func WalkSSA(v ssa.Value, visit func(v ssa.Value) error) error { + visited := make(valueSet) + + return walkSSA(v, visit, visited) +} + +func walkSSA(v ssa.Value, visit func(v ssa.Value) error, visited valueSet) error { + if visited == nil { + visited = make(valueSet) + } + + if visited.includes(v) { + return nil + } + + visited.add(v) + + // fmt.Printf("walk SSA: %s: %[1]T\n", v) + + if err := visit(v); err != nil { + return err + } + + switch v := v.(type) { + case *ssa.Call: + // Check the operands of the call instruction. + for _, opr := range v.Operands(nil) { + if err := walkSSA(*opr, visit, visited); err != nil { + return err + } + } + + // Check the arguments of the call instruction. + for _, arg := range v.Common().Args { + if err := walkSSA(arg, visit, visited); err != nil { + return err + } + } + + // Check the function being called. + if err := walkSSA(v.Call.Value, visit, visited); err != nil { + return err + } + + // Check the return value of the call instruction. + if v.Common().IsInvoke() { + if err := walkSSA(v.Common().Value, visit, visited); err != nil { + return err + } + } + + // Check the return value of the call instruction. + if err := walkSSA(v.Common().Value, visit, visited); err != nil { + return err + } + case *ssa.ChangeInterface: + if err := walkSSA(v.X, visit, visited); err != nil { + return err + } + case *ssa.Convert: + if err := walkSSA(v.X, visit, visited); err != nil { + return err + } + case *ssa.MakeInterface: + if err := walkSSA(v.X, visit, visited); err != nil { + return err + } + case *ssa.Phi: + for _, edge := range v.Edges { + if err := walkSSA(edge, visit, visited); err != nil { + return err + } + } + case *ssa.UnOp: + if err := walkSSA(v.X, visit, visited); err != nil { + return err + } + case *ssa.Function: + for _, block := range v.Blocks { + for _, instr := range block.Instrs { + for _, opr := range instr.Operands(nil) { + if err := walkSSA(*opr, visit, visited); err != nil { + return err + } + } + } + } + default: + // fmt.Printf("? walk SSA %s: %[1]T\n", v) + } + + refs := v.Referrers() + if refs == nil { + return nil + } + + for _, instr := range *refs { + switch instr := instr.(type) { + case *ssa.Store: + // Store instructions need to be checked for both the value being stored, + // and the address being stored to. + if err := walkSSA(instr.Val, visit, visited); err != nil { + return err + } + + if err := walkSSA(instr.Addr, visit, visited); err != nil { + return err + } + case *ssa.Call: + // Check the operands of the call instruction. + for _, opr := range instr.Operands(nil) { + if err := walkSSA(*opr, visit, visited); err != nil { + return err + } + } + + // Check the arguments of the call instruction. + for _, arg := range instr.Common().Args { + if err := walkSSA(arg, visit, visited); err != nil { + return err + } + } + + // Check the function being called. + if err := walkSSA(instr.Call.Value, visit, visited); err != nil { + return err + } + + // Check the return value of the call instruction. + if instr.Common().IsInvoke() { + if err := walkSSA(instr.Common().Value, visit, visited); err != nil { + return err + } + } + + // Check the return value of the call instruction. + if err := walkSSA(instr.Common().Value, visit, visited); err != nil { + return err + } + default: + // fmt.Printf("? check SSA instr %s: %[1]T\n", i) + continue + } + } + + return nil +} diff --git a/xss/testdata/src/e/main.go b/xss/testdata/src/e/main.go new file mode 100644 index 0000000..d989279 --- /dev/null +++ b/xss/testdata/src/e/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "io" + "net/http" +) + +// this will panic if run, because the given *http.Request is not an io.Reader +// but it's fine for testing, because we don't actually run the code. +func echo(w io.Writer, r any) { + ior := r.(io.Reader) + + b, err := io.ReadAll(ior) + if err != nil { + panic(err) + } + + w.Write(b) +} + +func handler(w http.ResponseWriter, r *http.Request) { + echo(w, r) // want "potential XSS" +} + +func main() { + http.HandleFunc("/mirror-safe", handler) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/testdata/src/f/main.go b/xss/testdata/src/f/main.go new file mode 100644 index 0000000..fb108f3 --- /dev/null +++ b/xss/testdata/src/f/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "io" + "net/http" +) + +func echo(w io.Writer, r any) { + ior := r.(io.Reader) + + b, err := io.ReadAll(ior) + if err != nil { + panic(err) + } + + w.Write(b) +} + +func handler(w http.ResponseWriter, r *http.Request) { + echo(w, r.Body) // want "potential XSS" +} + +func main() { + http.HandleFunc("/mirror-safe", handler) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/testdata/src/g/main.go b/xss/testdata/src/g/main.go new file mode 100644 index 0000000..9233e0e --- /dev/null +++ b/xss/testdata/src/g/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "bufio" + "html" + "io" + "net/http" +) + +func echoSafe(w io.Writer, r any) { + ior := r.(io.Reader) + + b, err := io.ReadAll(ior) + if err != nil { + panic(err) + } + + es := html.EscapeString(string(b)) + + w.Write([]byte(es)) +} + +func echoUnsafe(w io.Writer, r any) { + ior := r.(io.Reader) + + b, err := io.ReadAll(ior) + if err != nil { + panic(err) + } + + w.Write(b) +} + +func handler(w http.ResponseWriter, r *http.Request) { + b := bufio.NewWriterSize(w, 4096) + defer b.Flush() + + switch r.URL.Path { + case "/mirror-safe": + echoSafe(w, r.Body) + case "/mirror-unsafe": + echoUnsafe(w, r.Body) // want "potential XSS" + } +} + +func main() { + http.HandleFunc("/", handler) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/xss.go b/xss/xss.go index 27b5544..06fbfc2 100644 --- a/xss/xss.go +++ b/xss/xss.go @@ -73,6 +73,8 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, fmt.Errorf("failed to create new callgraph: %w", err) } + // fmt.Println(cg) + // Run taint check for user controlled values (sources) ending // up in injectable log functions (sinks). results := taint.Check(cg, userControlledValues, injectableFunctions) @@ -83,10 +85,17 @@ func run(pass *analysis.Pass) (interface{}, error) { var escaped bool for _, edge := range result.Path { for _, arg := range edge.Site.Common().Args { - if checkIfHTMLEscapeString(arg) { - escaped = true - break - } + taint.WalkSSA(arg, func(v ssa.Value) error { + call, ok := v.(*ssa.Call) + if !ok { + return nil + } + if call.Call.Value.String() == "html.EscapeString" { + escaped = true + return taint.ErrStopWalk + } + return nil + }) } if escaped { break @@ -100,53 +109,3 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } - -// checkIfHTMLEscapeString returns true if the given value uses -// html.EscapeString, calling itself recursively as needed. -func checkIfHTMLEscapeString(value ssa.Value) bool { - switch value := value.(type) { - case *ssa.Call: - return value.Call.Value.String() == "html.EscapeString" - case *ssa.MakeInterface: - return checkIfHTMLEscapeString(value.X) - case *ssa.ChangeInterface: - return checkIfHTMLEscapeString(value.X) - case *ssa.Convert: - return checkIfHTMLEscapeString(value.X) - case *ssa.UnOp: - return checkIfHTMLEscapeString(value.X) - case *ssa.Phi: - for _, edge := range value.Edges { - if checkIfHTMLEscapeString(edge) { - return true - } - } - return false - case *ssa.Alloc: - refs := value.Referrers() - if refs == nil { - return false - } - for _, instr := range *refs { - for _, opr := range instr.Operands(nil) { - if checkIfHTMLEscapeString(*opr) { - return true - } - } - } - case *ssa.FieldAddr: - return checkIfHTMLEscapeString(value.X) - case *ssa.Field: - return checkIfHTMLEscapeString(value.X) - case *ssa.IndexAddr: - return checkIfHTMLEscapeString(value.X) - case *ssa.Index: - return checkIfHTMLEscapeString(value.X) - case *ssa.Lookup: - return checkIfHTMLEscapeString(value.X) - case *ssa.Slice: - return checkIfHTMLEscapeString(value.X) - } - - return false -} diff --git a/xss/xss_test.go b/xss/xss_test.go index c2ff574..12a9f3e 100644 --- a/xss/xss_test.go +++ b/xss/xss_test.go @@ -23,3 +23,15 @@ func TestC(t *testing.T) { func TestD(t *testing.T) { analysistest.Run(t, testdata, Analyzer, "d") } + +func TestE(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "e") +} + +func TestF(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "f") +} + +func TestG(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "g") +}