Skip to content

Commit

Permalink
Add XSS analysis (#9)
Browse files Browse the repository at this point in the history
* Add initial implementation of `xss` package and CLI 
* Better handle method call graph construction and concurrency
* Remove unused debug and track interface functions with `*ssa.ChangeInterface`
  • Loading branch information
picatz authored Aug 8, 2023
1 parent 70998b7 commit 470dd40
Show file tree
Hide file tree
Showing 10 changed files with 493 additions and 89 deletions.
63 changes: 60 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,65 @@ func run() {
func main() {
run()
}
$ time sqli main.go
$ sqli main.go
./sql/injection/testdata/src/example/main.go:9:10: potential sql injection
```

### `logi`

The `logi` analyzer is a CLI tool that demonstrates usage of the `taint` package to find
potential log injections.

```console
$ go install github.com/picatz/taint/cmd/logi@latest
```

```console
$ cd log/injection/testdata/src/a
$ cat main.go
package main

import (
"log"
"net/http"
)

func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
log.Println(r.URL.Query().Get("input"))
})

http.ListenAndServe(":8080", nil)
}
$ logi main.go
./log/injection/testdata/src/example/main.go:10:14: potential log injection
```

### `xss`

The `xss` analyzer is a CLI tool that demonstrates usage of the `taint` package to find
potential cross-site scripting (XSS) vulnerabilities.

```console
$ go install github.com/picatz/taint/cmd/xss@latest
```

```console
$ cd xss/testdata/src/a
$ cat main.go
package main

import (
"net/http"
)

sqli main.go 0.12s user 0.15s system 291% cpu 0.094 total
```
func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.URL.Query().Get("input"))) // want "potential XSS"
})

http.ListenAndServe(":8080", nil)
}
$ xss main.go
./xss/testdata/src/example/main.go:9:8: potential XSS
```
214 changes: 128 additions & 86 deletions callgraph/callgraph.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@ import (
"fmt"
"go/token"
"go/types"
"log"
"os"
"runtime"
"sync"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
Expand All @@ -27,111 +23,34 @@ type Graph struct {
sync.RWMutex
Root *Node // the distinguished root node
Nodes map[*ssa.Function]*Node // all nodes by function
debug bool
}

// New returns a new Graph with the specified root node.
func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) {
g := &Graph{
RWMutex: sync.RWMutex{},
Nodes: make(map[*ssa.Function]*Node),
debug: false, // TODO: make configurable with env variable?
}

g.Root = g.CreateNode(root)

eg, _ := errgroup.WithContext(context.Background())

// 500 = 5849.241s
// 10 = 4333.030s
var (
logger *log.Logger
ops int64
total int
)

if g.debug {
logger = log.New(os.Stderr, "callgraph-debug ", log.LstdFlags)
total = len(srcFns)
}

s := semaphore.NewWeighted(10)
eg.SetLimit(runtime.NumCPU())

allFns := ssautil.AllFunctions(root.Prog)

for _, srcFn := range srcFns {
fn := srcFn
err := s.Acquire(context.Background(), 1)
if err != nil {
return nil, fmt.Errorf("failed to aquite semaphore: %w", err)
}
eg.Go(func() error {
defer s.Release(1)

if g.debug {
start := time.Now()
defer func() {
ops := atomic.AddInt64(&ops, 1)
logger.Printf("done processing %v (%v/%v) after %v seconds\n", fn, ops, total, time.Since(start).Seconds())
}()
}

err = g.AddFunction(fn, allFns)
err := g.AddFunction(fn, allFns)
if err != nil {
return fmt.Errorf("failed to add src function %v: %w", fn, err)
}

for _, block := range fn.DomPreorder() {
for _, instr := range block.Instrs {
// debugf("found block instr")
switch instrt := instr.(type) {
case *ssa.Call:
// debugf("found call instr")
var instrCall *ssa.Function

// Handle the case where the function calls a
// named function (*ssa.Function), and the case
// where the function calls an anonymous
// function (*ssa.MakeClosure).
switch callt := instrt.Call.Value.(type) {
case *ssa.Function:
// debugf("found call instr to function")
instrCall = callt
case *ssa.MakeClosure:
// debugf("found call instr to closure")
switch calltFn := callt.Fn.(type) {
case *ssa.Function:
instrCall = calltFn
}
}

// If we could not determine the function being
// called, skip this instruction.
if instrCall == nil {
continue
}

AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall))

err := g.AddFunction(instrCall, allFns)
if err != nil {
return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err)
}

// attempt to link function arguments that are functions
for a := 0; a < len(instrt.Call.Args); a++ {
arg := instrt.Call.Args[a]
switch argt := arg.(type) {
case *ssa.Function:
// TODO: check if edge already exists?
AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt))
case *ssa.MakeClosure:
switch argtFn := argt.Fn.(type) {
case *ssa.Function:
AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn))
}
}
}
}
checkBlockInstruction(root, allFns, g, fn, instr)
}
}
return nil
Expand All @@ -146,6 +65,129 @@ func New(root *ssa.Function, srcFns ...*ssa.Function) (*Graph, error) {
return g, nil
}

func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *Graph, fn *ssa.Function, instr ssa.Instruction) error {
switch instrt := instr.(type) {
case *ssa.Call:
var instrCall *ssa.Function

// TODO: map more things to instrCall?

switch callt := instrt.Call.Value.(type) {
case *ssa.Function:
instrCall = callt

for _, instrtCallArg := range instrt.Call.Args {
switch instrtCallArgt := instrtCallArg.(type) {
case *ssa.ChangeInterface:
// Track type casts through matching interface methods.
//
// # Example
//
// func buffer(r io.Reader) io.Reader {
// return bufio.NewReader(r)
// }
//
// func mirror(w http.ResponseWriter, r *http.Request) {
// _, err := io.Copy(w, buffer(r.Body)) // w is an http.ResponseWriter, convert to io.Writer for io.Copy
// if err != nil {
// panic(err)
// }
// }
//
// io.Copy is called with an io.Writer, but the underlying type is a net/http.ResponseWriter.
//
// n11:net/http.HandleFunc → n1:c.mirror → n5:io.Copy → n6:(io.Writer).Write → n7:(net/http.ResponseWriter).Write
//
switch argtt := instrtCallArgt.Type().Underlying().(type) {
case *types.Interface:
numMethods := argtt.NumMethods()

for i := 0; i < numMethods; i++ {
method := argtt.Method(i)

pkg := root.Prog.ImportedPackage(method.Pkg().Path())
fn := pkg.Prog.NewFunction(method.Name(), method.Type().(*types.Signature), "callgraph")
AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(fn))

switch xType := instrtCallArgt.X.Type().(type) {
case *types.Named:
named := xType

pkg2 := root.Prog.ImportedPackage(named.Obj().Pkg().Path())

methodSet := pkg2.Prog.MethodSets.MethodSet(named)
methodSel := methodSet.Lookup(pkg2.Pkg, method.Name())

if methodSel == nil {
continue
}

methodType := methodSel.Type().(*types.Signature)

fn2 := pkg2.Prog.NewFunction(method.Name(), methodType, "callgraph")

AddEdge(g.CreateNode(fn), instrt, g.CreateNode(fn2))
default:
continue
}
}
}
}
}
case *ssa.MakeClosure:
switch calltFn := callt.Fn.(type) {
case *ssa.Function:
instrCall = calltFn
}
case *ssa.Parameter:
// This is likely a method call, so we need to
// get the function from the method receiver which
// is not available directly from the call instruction,
// but rather from the package level function.

// Skip this instruction if we could not determine
// the function being called.
if !instrt.Call.IsInvoke() || (instrt.Call.Method == nil) {
return nil
}

// TODO: should we share the resulting function?
pkg := root.Prog.ImportedPackage(instrt.Call.Method.Pkg().Path())
fn := pkg.Prog.NewFunction(instrt.Call.Method.Name(), instrt.Call.Signature(), "callgraph")
instrCall = fn
}

// If we could not determine the function being
// called, skip this instruction.
if instrCall == nil {
return nil
}

AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall))

err := g.AddFunction(instrCall, allFns)
if err != nil {
return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err)
}

// attempt to link function arguments that are functions
for a := 0; a < len(instrt.Call.Args); a++ {
arg := instrt.Call.Args[a]
switch argt := arg.(type) {
case *ssa.Function:
// TODO: check if edge already exists?
AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt))
case *ssa.MakeClosure:
switch argtFn := argt.Fn.(type) {
case *ssa.Function:
AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn))
}
}
}
}
return nil
}

// AddFunction analyzes the given target SSA function, adding information to the call graph.
// https://cs.opensource.google/go/x/tools/+/master:cmd/guru/callers.go;drc=3e0d083b858b3fdb7d095b5a3deb184aa0a5d35e;bpv=1;bpt=1;l=90
func (cg *Graph) AddFunction(target *ssa.Function, allFns map[*ssa.Function]bool) error {
Expand Down
6 changes: 6 additions & 0 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,12 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va
if tainted {
return true, src, tv
}
case *ssa.ChangeInterface:
// Check the value being changed into an interface.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
if tainted {
return true, src, tv
}
case *ssa.Convert:
// Check the value being converted.
tainted, src, tv := checkSSAValue(path, sources, value.X, visited)
Expand Down
10 changes: 10 additions & 0 deletions cmd/xss/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package main

import (
"github.com/picatz/taint/xss"
"golang.org/x/tools/go/analysis/singlechecker"
)

func main() {
singlechecker.Main(xss.Analyzer)
}
13 changes: 13 additions & 0 deletions xss/testdata/src/a/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package main

import (
"net/http"
)

func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(r.URL.Query().Get("input"))) // want "potential XSS"
})

http.ListenAndServe(":8080", nil)
}
19 changes: 19 additions & 0 deletions xss/testdata/src/b/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

import (
"net/http"
)

func mirror(w http.ResponseWriter, r *http.Request) {
input := r.URL.Query().Get("input")

b := []byte(input)

w.Write(b) // want "potential XSS"
}

func main() {
http.HandleFunc("/", mirror)

http.ListenAndServe(":8080", nil)
}
Loading

0 comments on commit 470dd40

Please sign in to comment.