From 470dd404144f6143c5f56ffdd532d92c4e6180ea Mon Sep 17 00:00:00 2001 From: Kent 'picat' Gruber Date: Mon, 7 Aug 2023 23:44:54 -0400 Subject: [PATCH] Add XSS analysis (#9) * Add initial implementation of `xss` package and CLI * Better handle method call graph construction and concurrency * Remove unused debug and track interface functions with `*ssa.ChangeInterface` --- README.md | 63 ++++++++++- callgraph/callgraph.go | 214 ++++++++++++++++++++++--------------- check.go | 6 ++ cmd/xss/main.go | 10 ++ xss/testdata/src/a/main.go | 13 +++ xss/testdata/src/b/main.go | 19 ++++ xss/testdata/src/c/main.go | 53 +++++++++ xss/testdata/src/d/main.go | 27 +++++ xss/xss.go | 152 ++++++++++++++++++++++++++ xss/xss_test.go | 25 +++++ 10 files changed, 493 insertions(+), 89 deletions(-) create mode 100644 cmd/xss/main.go create mode 100644 xss/testdata/src/a/main.go create mode 100644 xss/testdata/src/b/main.go create mode 100644 xss/testdata/src/c/main.go create mode 100644 xss/testdata/src/d/main.go create mode 100644 xss/xss.go create mode 100644 xss/xss_test.go diff --git a/README.md b/README.md index 99e6821..94a4bd8 100644 --- a/README.md +++ b/README.md @@ -94,8 +94,65 @@ func run() { func main() { run() } -$ time sqli main.go +$ sqli main.go ./sql/injection/testdata/src/example/main.go:9:10: potential sql injection +``` + +### `logi` + +The `logi` analyzer is a CLI tool that demonstrates usage of the `taint` package to find +potential log injections. + +```console +$ go install github.com/picatz/taint/cmd/logi@latest +``` + +```console +$ cd log/injection/testdata/src/a +$ cat main.go +package main + +import ( + "log" + "net/http" +) + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + log.Println(r.URL.Query().Get("input")) + }) + + http.ListenAndServe(":8080", nil) +} +$ logi main.go +./log/injection/testdata/src/example/main.go:10:14: potential log injection +``` + +### `xss` + +The `xss` analyzer is a CLI tool that demonstrates usage of the `taint` package to find +potential cross-site scripting (XSS) vulnerabilities. + +```console +$ go install github.com/picatz/taint/cmd/xss@latest +``` + +```console +$ cd xss/testdata/src/a +$ cat main.go +package main + +import ( + "net/http" +) -sqli main.go 0.12s user 0.15s system 291% cpu 0.094 total -``` \ No newline at end of file +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.URL.Query().Get("input"))) // want "potential XSS" + }) + + http.ListenAndServe(":8080", nil) +} +$ xss main.go +./xss/testdata/src/example/main.go:9:8: potential XSS +``` diff --git a/callgraph/callgraph.go b/callgraph/callgraph.go index 10dc929..4bf0d68 100644 --- a/callgraph/callgraph.go +++ b/callgraph/callgraph.go @@ -6,14 +6,10 @@ import ( "fmt" "go/token" "go/types" - "log" - "os" + "runtime" "sync" - "sync/atomic" - "time" "golang.org/x/sync/errgroup" - "golang.org/x/sync/semaphore" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) @@ -27,7 +23,6 @@ type Graph struct { sync.RWMutex Root *Node // the distinguished root node Nodes map[*ssa.Function]*Node // all nodes by function - debug bool } // New returns a new Graph with the specified root node. @@ -35,103 +30,27 @@ func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) { g := &Graph{ RWMutex: sync.RWMutex{}, Nodes: make(map[*ssa.Function]*Node), - debug: false, // TODO: make configurable with env variable? } + g.Root = g.CreateNode(root) eg, _ := errgroup.WithContext(context.Background()) - // 500 = 5849.241s - // 10 = 4333.030s - var ( - logger *log.Logger - ops int64 - total int - ) - - if g.debug { - logger = log.New(os.Stderr, "callgraph-debug ", log.LstdFlags) - total = len(srcFns) - } - - s := semaphore.NewWeighted(10) + eg.SetLimit(runtime.NumCPU()) allFns := ssautil.AllFunctions(root.Prog) for _, srcFn := range srcFns { fn := srcFn - err := s.Acquire(context.Background(), 1) - if err != nil { - return nil, fmt.Errorf("failed to aquite semaphore: %w", err) - } eg.Go(func() error { - defer s.Release(1) - - if g.debug { - start := time.Now() - defer func() { - ops := atomic.AddInt64(&ops, 1) - logger.Printf("done processing %v (%v/%v) after %v seconds\n", fn, ops, total, time.Since(start).Seconds()) - }() - } - - err = g.AddFunction(fn, allFns) + err := g.AddFunction(fn, allFns) if err != nil { return fmt.Errorf("failed to add src function %v: %w", fn, err) } for _, block := range fn.DomPreorder() { for _, instr := range block.Instrs { - // debugf("found block instr") - switch instrt := instr.(type) { - case *ssa.Call: - // debugf("found call instr") - var instrCall *ssa.Function - - // Handle the case where the function calls a - // named function (*ssa.Function), and the case - // where the function calls an anonymous - // function (*ssa.MakeClosure). - switch callt := instrt.Call.Value.(type) { - case *ssa.Function: - // debugf("found call instr to function") - instrCall = callt - case *ssa.MakeClosure: - // debugf("found call instr to closure") - switch calltFn := callt.Fn.(type) { - case *ssa.Function: - instrCall = calltFn - } - } - - // If we could not determine the function being - // called, skip this instruction. - if instrCall == nil { - continue - } - - AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall)) - - err := g.AddFunction(instrCall, allFns) - if err != nil { - return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err) - } - - // attempt to link function arguments that are functions - for a := 0; a < len(instrt.Call.Args); a++ { - arg := instrt.Call.Args[a] - switch argt := arg.(type) { - case *ssa.Function: - // TODO: check if edge already exists? - AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt)) - case *ssa.MakeClosure: - switch argtFn := argt.Fn.(type) { - case *ssa.Function: - AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn)) - } - } - } - } + checkBlockInstruction(root, allFns, g, fn, instr) } } return nil @@ -146,6 +65,129 @@ func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) { return g, nil } +func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *Graph, fn *ssa.Function, instr ssa.Instruction) error { + switch instrt := instr.(type) { + case *ssa.Call: + var instrCall *ssa.Function + + // TODO: map more things to instrCall? + + switch callt := instrt.Call.Value.(type) { + case *ssa.Function: + instrCall = callt + + for _, instrtCallArg := range instrt.Call.Args { + switch instrtCallArgt := instrtCallArg.(type) { + case *ssa.ChangeInterface: + // Track type casts through matching interface methods. + // + // # Example + // + // func buffer(r io.Reader) io.Reader { + // return bufio.NewReader(r) + // } + // + // func mirror(w http.ResponseWriter, r *http.Request) { + // _, err := io.Copy(w, buffer(r.Body)) // w is an http.ResponseWriter, convert to io.Writer for io.Copy + // if err != nil { + // panic(err) + // } + // } + // + // io.Copy is called with an io.Writer, but the underlying type is a net/http.ResponseWriter. + // + // n11:net/http.HandleFunc → n1:c.mirror → n5:io.Copy → n6:(io.Writer).Write → n7:(net/http.ResponseWriter).Write + // + switch argtt := instrtCallArgt.Type().Underlying().(type) { + case *types.Interface: + numMethods := argtt.NumMethods() + + for i := 0; i < numMethods; i++ { + method := argtt.Method(i) + + pkg := root.Prog.ImportedPackage(method.Pkg().Path()) + fn := pkg.Prog.NewFunction(method.Name(), method.Type().(*types.Signature), "callgraph") + AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(fn)) + + switch xType := instrtCallArgt.X.Type().(type) { + case *types.Named: + named := xType + + pkg2 := root.Prog.ImportedPackage(named.Obj().Pkg().Path()) + + methodSet := pkg2.Prog.MethodSets.MethodSet(named) + methodSel := methodSet.Lookup(pkg2.Pkg, method.Name()) + + if methodSel == nil { + continue + } + + methodType := methodSel.Type().(*types.Signature) + + fn2 := pkg2.Prog.NewFunction(method.Name(), methodType, "callgraph") + + AddEdge(g.CreateNode(fn), instrt, g.CreateNode(fn2)) + default: + continue + } + } + } + } + } + case *ssa.MakeClosure: + switch calltFn := callt.Fn.(type) { + case *ssa.Function: + instrCall = calltFn + } + case *ssa.Parameter: + // This is likely a method call, so we need to + // get the function from the method receiver which + // is not available directly from the call instruction, + // but rather from the package level function. + + // Skip this instruction if we could not determine + // the function being called. + if !instrt.Call.IsInvoke() || (instrt.Call.Method == nil) { + return nil + } + + // TODO: should we share the resulting function? + pkg := root.Prog.ImportedPackage(instrt.Call.Method.Pkg().Path()) + fn := pkg.Prog.NewFunction(instrt.Call.Method.Name(), instrt.Call.Signature(), "callgraph") + instrCall = fn + } + + // If we could not determine the function being + // called, skip this instruction. + if instrCall == nil { + return nil + } + + AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall)) + + err := g.AddFunction(instrCall, allFns) + if err != nil { + return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err) + } + + // attempt to link function arguments that are functions + for a := 0; a < len(instrt.Call.Args); a++ { + arg := instrt.Call.Args[a] + switch argt := arg.(type) { + case *ssa.Function: + // TODO: check if edge already exists? + AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt)) + case *ssa.MakeClosure: + switch argtFn := argt.Fn.(type) { + case *ssa.Function: + AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn)) + } + } + } + } + return nil +} + // AddFunction analyzes the given target SSA function, adding information to the call graph. // https://cs.opensource.google/go/x/tools/+/master:cmd/guru/callers.go;drc=3e0d083b858b3fdb7d095b5a3deb184aa0a5d35e;bpv=1;bpt=1;l=90 func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool) error { diff --git a/check.go b/check.go index d60c374..cf70518 100644 --- a/check.go +++ b/check.go @@ -437,6 +437,12 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va if tainted { return true, src, tv } + case *ssa.ChangeInterface: + // Check the value being changed into an interface. + tainted, src, tv := checkSSAValue(path, sources, value.X, visited) + if tainted { + return true, src, tv + } case *ssa.Convert: // Check the value being converted. tainted, src, tv := checkSSAValue(path, sources, value.X, visited) diff --git a/cmd/xss/main.go b/cmd/xss/main.go new file mode 100644 index 0000000..e6eb2dc --- /dev/null +++ b/cmd/xss/main.go @@ -0,0 +1,10 @@ +package main + +import ( + "github.com/picatz/taint/xss" + "golang.org/x/tools/go/analysis/singlechecker" +) + +func main() { + singlechecker.Main(xss.Analyzer) +} diff --git a/xss/testdata/src/a/main.go b/xss/testdata/src/a/main.go new file mode 100644 index 0000000..185ab1c --- /dev/null +++ b/xss/testdata/src/a/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "net/http" +) + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(r.URL.Query().Get("input"))) // want "potential XSS" + }) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/testdata/src/b/main.go b/xss/testdata/src/b/main.go new file mode 100644 index 0000000..2b1177d --- /dev/null +++ b/xss/testdata/src/b/main.go @@ -0,0 +1,19 @@ +package main + +import ( + "net/http" +) + +func mirror(w http.ResponseWriter, r *http.Request) { + input := r.URL.Query().Get("input") + + b := []byte(input) + + w.Write(b) // want "potential XSS" +} + +func main() { + http.HandleFunc("/", mirror) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/testdata/src/c/main.go b/xss/testdata/src/c/main.go new file mode 100644 index 0000000..8fe9a37 --- /dev/null +++ b/xss/testdata/src/c/main.go @@ -0,0 +1,53 @@ +package main + +import ( + "bufio" + "io" + "net/http" +) + +func buffer(r io.Reader) io.Reader { + return bufio.NewReader(r) +} + +func mirror(w http.ResponseWriter, r *http.Request) { + _, err := io.Copy(w, buffer(r.Body)) // want "potential XSS" + if err != nil { + panic(err) + } +} + +func mirror2(w http.ResponseWriter, r *http.Request) { + _, err := io.WriteString(w, r.URL.Query().Get("q")) // want "potential XSS" + if err != nil { + panic(err) + } +} + +func mirror3(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(r.URL.Query().Get("q"))) // want "potential XSS" + if err != nil { + panic(err) + } +} + +func mirror4(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + + _, err = w.Write(b) // want "potential XSS" + if err != nil { + panic(err) + } +} + +func main() { + http.HandleFunc("/1", mirror) + http.HandleFunc("/2", mirror2) + http.HandleFunc("/3", mirror3) + http.HandleFunc("/4", mirror4) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/testdata/src/d/main.go b/xss/testdata/src/d/main.go new file mode 100644 index 0000000..918e959 --- /dev/null +++ b/xss/testdata/src/d/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "html" + "io" + "net/http" +) + +func mirrorSafe(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + if err != nil { + panic(err) + } + + str := html.EscapeString(string(b)) + + _, err = w.Write([]byte(str)) // safe + if err != nil { + panic(err) + } +} + +func main() { + http.HandleFunc("/mirror-safe", mirrorSafe) + + http.ListenAndServe(":8080", nil) +} diff --git a/xss/xss.go b/xss/xss.go new file mode 100644 index 0000000..27b5544 --- /dev/null +++ b/xss/xss.go @@ -0,0 +1,152 @@ +package xss + +import ( + "fmt" + "strings" + + "github.com/picatz/taint" + "github.com/picatz/taint/callgraph" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/buildssa" + "golang.org/x/tools/go/ssa" +) + +var userControlledValues = taint.NewSources( + "*net/http.Request", +) + +var injectableFunctions = taint.NewSinks( + // Note: at this time, they *must* be a function or method. + "(net/http.ResponseWriter).Write", + "(net/http.ResponseWriter).WriteHeader", +) + +// Analyzer finds potential XSS issues. +var Analyzer = &analysis.Analyzer{ + Name: "xss", + Doc: "finds potential XSS issues", + Run: run, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, +} + +// imports returns true if the package imports any of the given packages. +func imports(pass *analysis.Pass, pkgs ...string) bool { + var imported bool + for _, imp := range pass.Pkg.Imports() { + for _, pkg := range pkgs { + if strings.HasSuffix(imp.Path(), pkg) { + imported = true + break + } + } + if imported { + break + } + } + return imported +} + +func run(pass *analysis.Pass) (interface{}, error) { + // Require the log package is imported in the + // program being analyzed before running the analysis. + // + // This prevents wasting time analyzing programs that don't log. + if !imports(pass, "net/http") { + return nil, nil + } + + // Get the built SSA IR. + buildSSA := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + + // Identify the main function from the package's SSA IR. + mainFn := buildSSA.Pkg.Func("main") + if mainFn == nil { + return nil, nil + } + + // Construct a callgraph, using the main function as the root, + // constructed of all other functions. This returns a callgraph + // we can use to identify directed paths to logging functions. + cg, err := callgraph.New(mainFn, buildSSA.SrcFuncs...) + if err != nil { + return nil, fmt.Errorf("failed to create new callgraph: %w", err) + } + + // Run taint check for user controlled values (sources) ending + // up in injectable log functions (sinks). + results := taint.Check(cg, userControlledValues, injectableFunctions) + + for _, result := range results { + // Check if html.EscapeString was called on the source value + // before it was passed to the sink. + var escaped bool + for _, edge := range result.Path { + for _, arg := range edge.Site.Common().Args { + if checkIfHTMLEscapeString(arg) { + escaped = true + break + } + } + if escaped { + break + } + } + + if !escaped { + pass.Reportf(result.SinkValue.Pos(), "potential XSS") + } + } + + return nil, nil +} + +// checkIfHTMLEscapeString returns true if the given value uses +// html.EscapeString, calling itself recursively as needed. +func checkIfHTMLEscapeString(value ssa.Value) bool { + switch value := value.(type) { + case *ssa.Call: + return value.Call.Value.String() == "html.EscapeString" + case *ssa.MakeInterface: + return checkIfHTMLEscapeString(value.X) + case *ssa.ChangeInterface: + return checkIfHTMLEscapeString(value.X) + case *ssa.Convert: + return checkIfHTMLEscapeString(value.X) + case *ssa.UnOp: + return checkIfHTMLEscapeString(value.X) + case *ssa.Phi: + for _, edge := range value.Edges { + if checkIfHTMLEscapeString(edge) { + return true + } + } + return false + case *ssa.Alloc: + refs := value.Referrers() + if refs == nil { + return false + } + for _, instr := range *refs { + for _, opr := range instr.Operands(nil) { + if checkIfHTMLEscapeString(*opr) { + return true + } + } + } + case *ssa.FieldAddr: + return checkIfHTMLEscapeString(value.X) + case *ssa.Field: + return checkIfHTMLEscapeString(value.X) + case *ssa.IndexAddr: + return checkIfHTMLEscapeString(value.X) + case *ssa.Index: + return checkIfHTMLEscapeString(value.X) + case *ssa.Lookup: + return checkIfHTMLEscapeString(value.X) + case *ssa.Slice: + return checkIfHTMLEscapeString(value.X) + } + + return false +} diff --git a/xss/xss_test.go b/xss/xss_test.go new file mode 100644 index 0000000..c2ff574 --- /dev/null +++ b/xss/xss_test.go @@ -0,0 +1,25 @@ +package xss + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +var testdata = analysistest.TestData() + +func TestA(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "a") +} + +func TestB(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "b") +} + +func TestC(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "c") +} + +func TestD(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "d") +}