diff --git a/middleware/signing.go b/middleware/signing.go new file mode 100644 index 0000000..bbe7dbb --- /dev/null +++ b/middleware/signing.go @@ -0,0 +1,58 @@ +package middleware + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" +) + +const HEADER_SIGNATURE = "X-Skill-Signature-256" + +func NewSigning(signingKeys []string) (func(http.Handler) http.Handler, error) { + if len(signingKeys) == 0 { + return nil, fmt.Errorf("no signing keys provided") + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) { + signature := request.Header.Get(HEADER_SIGNATURE) + if signature == "" { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + body, _ := io.ReadAll(request.Body) + // Replace the body with a new reader after reading from the original + request.Body = io.NopCloser(bytes.NewBuffer(body)) + + if !signatureIsValid(signature, signingKeys, body) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, request) + }) + }, nil +} + +func signatureIsValid(signature string, signingKeys []string, body []byte) bool { + for _, key := range signingKeys { + hmac := hmac.New(sha256.New, []byte(key)) + + // compute the HMAC + hmac.Write(body) + dataHmac := hmac.Sum(nil) + + hmacHex := hex.EncodeToString(dataHmac) + + if hmacHex == signature { + return true + } + + } + return false +} diff --git a/middleware/signing_test.go b/middleware/signing_test.go new file mode 100644 index 0000000..235b7b9 --- /dev/null +++ b/middleware/signing_test.go @@ -0,0 +1,49 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestSignatureValidation(t *testing.T) { + var ( + signingKey = "It's a Secret to Everybody" + payload = "Hello, World!" + ) + + if !signatureIsValid("757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17", []string{signingKey}, []byte(payload)) { + t.Error("signature should be valid") + } +} + +func TestSignatureHandler(t *testing.T) { + var ( + signingKey = "It's a Secret to Everybody" + payload = "Hello, World!" + signature = "757107ea0eb2509fc211221cce984b8a37570b6d7586c22c46f4379c8b043e17" + rr = httptest.NewRecorder() + ) + + // Create a signed request + req := httptest.NewRequest("POST", "/", strings.NewReader(payload)) + req.Header.Set(HEADER_SIGNATURE, signature) + + // Create a handler that will validate the signature + signingMiddleware, err := NewSigning([]string{signingKey}) + if err != nil { + t.Error("error creating signing middleware") + } + + // Create and call the handler + handler := signingMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + handler.ServeHTTP(rr, req) + + // Check the status code + if rr.Code != 200 { + t.Errorf("status code should be 200, got: %d", rr.Code) + } +}