diff --git a/check.go b/check.go index cbad798..349349f 100644 --- a/check.go +++ b/check.go @@ -524,6 +524,25 @@ func checkSSAValue(path callgraph.Path, sources Sources, v ssa.Value, visited va continue } + tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) + if tainted { + return true, src, tv + } + } + } + case *ssa.MakeMap: + refs := value.Referrers() + if refs != nil { + for _, ref := range *refs { + refVal, isVal := ref.(ssa.Value) + if isVal { + tainted, src, tv := checkSSAValue(path, sources, refVal, visited) + if tainted { + return true, src, tv + } + continue + } + tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) if tainted { return true, src, tv @@ -566,6 +585,18 @@ func checkSSAInstruction(path callgraph.Path, sources Sources, i ssa.Instruction return true, src, tv } } + case *ssa.MapUpdate: + // Map update instructions need to be checked for both the map being updated, + // and the key and value being updated. + tainted, src, tv := checkSSAValue(path, sources, instr.Key, visited) + if tainted { + return true, src, tv + } + + tainted, src, tv = checkSSAValue(path, sources, instr.Value, visited) + if tainted { + return true, src, tv + } default: // fmt.Printf("? check SSA instr %s: %[1]T\n", i) return false, "", nil diff --git a/log/injection/injection.go b/log/injection/injection.go index feea5f8..c5a2a55 100644 --- a/log/injection/injection.go +++ b/log/injection/injection.go @@ -43,6 +43,34 @@ var injectableLogFunctions = taint.NewSinks( "(*log.Logger).SetOutput", "(*log.Logger).SetPrefix", "(*log.Logger).Writer", + + // log/slog (structured logging) + // https://pkg.go.dev/log/slog + "log/slog.Debug", + "log/slog.DebugContext", + "log/slog.Error", + "log/slog.ErrorContext", + "log/slog.Info", + "log/slog.InfoContext", + "log/slog.Warn", + "log/slog.WarnContext", + "log/slog.Log", + "log/slog.LogAttrs", + "(*log/slog.Logger).With", + "(*log/slog.Logger).Debug", + "(*log/slog.Logger).DebugContext", + "(*log/slog.Logger).Error", + "(*log/slog.Logger).ErrorContext", + "(*log/slog.Logger).Info", + "(*log/slog.Logger).InfoContext", + "(*log/slog.Logger).Warn", + "(*log/slog.Logger).WarnContext", + "(*log/slog.Logger).Log", + "(*log/slog.Logger).LogAttrs", + "log/slog.NewRecord", + "(*log/slog.Record).Add", + "(*log/slog.Record).AddAttrs", + // TODO: consider adding the following logger packages, // and the ability to configure this list generically. // @@ -85,7 +113,7 @@ func run(pass *analysis.Pass) (interface{}, error) { // program being analyzed before running the analysis. // // This prevents wasting time analyzing programs that don't log. - if !imports(pass, "log") { + if !imports(pass, "log", "log/slog") { return nil, nil } diff --git a/log/injection/injection_test.go b/log/injection/injection_test.go index 45078cf..19362da 100644 --- a/log/injection/injection_test.go +++ b/log/injection/injection_test.go @@ -19,3 +19,19 @@ func TestB(t *testing.T) { func TestC(t *testing.T) { analysistest.Run(t, testdata, Analyzer, "c") } + +func TestD(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "d") +} + +func TestE(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "e") +} + +func TestF(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "f") +} + +func TestG(t *testing.T) { + analysistest.Run(t, testdata, Analyzer, "g") +} diff --git a/log/injection/testdata/src/d/main.go b/log/injection/testdata/src/d/main.go new file mode 100644 index 0000000..217af84 --- /dev/null +++ b/log/injection/testdata/src/d/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "log/slog" + "net/http" +) + +func l(input string) { + slog.Info(input) // want "potential log injection" +} + +func buisness(input string) { + l(input) +} + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + input := r.URL.Query().Get("input") + + f := func() { + buisness(input) + } + + f() + }) + + http.ListenAndServe(":8080", nil) +} diff --git a/log/injection/testdata/src/e/main.go b/log/injection/testdata/src/e/main.go new file mode 100644 index 0000000..f683bf0 --- /dev/null +++ b/log/injection/testdata/src/e/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "log/slog" + "net/http" + "os" +) + +var logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, +})) + +func l(input string) { + logger.InfoContext(context.Background(), "l", "input", input) // want "potential log injection" +} + +func buisness(input string) { + l(input) +} + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + input := r.URL.Query().Get("input") + + f := func() { + buisness(input) + } + + f() + }) + + http.ListenAndServe(":8080", nil) +} diff --git a/log/injection/testdata/src/f/main.go b/log/injection/testdata/src/f/main.go new file mode 100644 index 0000000..d9e35ec --- /dev/null +++ b/log/injection/testdata/src/f/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "log/slog" + "net/http" + "os" +) + +var logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, +})) + +func l(input string) { + logger.InfoContext(context.Background(), "l", "input", map[string]string{"value": input}) // want "potential log injection" +} + +func buisness(input string) { + l(input) +} + +func main() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + input := r.URL.Query().Get("input") + + f := func() { + buisness(input) + } + + f() + }) + + http.ListenAndServe(":8080", nil) +} diff --git a/log/injection/testdata/src/g/main.go b/log/injection/testdata/src/g/main.go new file mode 100644 index 0000000..2da520f --- /dev/null +++ b/log/injection/testdata/src/g/main.go @@ -0,0 +1,36 @@ +package main + +import ( + "context" + "log/slog" + "net/http" + "os" +) + +func l(logger *slog.Logger, input string) { + logger2 := logger.With("input", input).WithGroup("l") // want "potential log injection" + + logger2.InfoContext(context.Background(), "l", "input", []string{input}) // want "potential log injection" +} + +func buisness(logger *slog.Logger, input string) { + l(logger, input) +} + +func main() { + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + input := r.URL.Query().Get("input") + + f := func() { + buisness(logger, input) + } + + f() + }) + + http.ListenAndServe(":8080", nil) +}