diff --git a/versionware/handler.go b/versionware/handler.go new file mode 100644 index 00000000..4c100967 --- /dev/null +++ b/versionware/handler.go @@ -0,0 +1,105 @@ +// Package versionware provides routing and middleware for building versioned +// HTTP services. +package versionware + +import ( + "fmt" + "net/http" + "sort" + + "github.com/snyk/vervet" +) + +const ( + // HeaderSnykVersionRequested is a response header acknowledging the API + // version that was requested. + HeaderSnykVersionRequested = "snyk-version-requested" + + // HeaderSnykVersionServed is a response header indicating the actual API + // version that was matched and served the response. + HeaderSnykVersionServed = "snyk-version-served" +) + +// Handler is a multiplexing http.Handler that dispatches requests based on the +// version query parameter according to vervet's API version matching rules. +type Handler struct { + handlers []http.Handler + versions vervet.VersionSlice + errFunc VersionErrorHandler +} + +// VersionErrorHandler defines a function which handles versioning error +// responses in requests. +type VersionErrorHandler func(w http.ResponseWriter, r *http.Request, status int, err error) + +// VersionHandler expresses a pairing of Version and http.Handler. +type VersionHandler struct { + Version vervet.Version + Handler http.Handler +} + +// NewHandler returns a new Handler instance, which handles versioned requests +// with the matching version handler. +func NewHandler(vhs ...VersionHandler) *Handler { + h := &Handler{ + handlers: make([]http.Handler, len(vhs)), + versions: make([]vervet.Version, len(vhs)), + errFunc: defaultErrorHandler, + } + handlerVersions := map[string]http.Handler{} + for i := range vhs { + v := vhs[i].Version + h.versions[i] = v + handlerVersions[v.String()] = vhs[i].Handler + } + sort.Sort(h.versions) + for i := range h.versions { + h.handlers[i] = handlerVersions[h.versions[i].String()] + } + return h +} + +func defaultErrorHandler(w http.ResponseWriter, r *http.Request, status int, err error) { + http.Error(w, err.Error(), status) +} + +// HandleErrors changes the default error handler to the provided function. It +// may be used to control the format of versioning error responses. +func (h *Handler) HandleErrors(errFunc VersionErrorHandler) { + h.errFunc = errFunc +} + +// Resolve returns the resolved version and its associated http.Handler for the +// requested version. +func (h *Handler) Resolve(requested vervet.Version) (*vervet.Version, http.Handler, error) { + resolvedIndex, err := h.versions.ResolveIndex(requested) + if err != nil { + return nil, nil, err + } + resolved := h.versions[resolvedIndex] + return &resolved, h.handlers[resolvedIndex], nil +} + +// ServeHTTP implements http.Handler with the handler matching the version +// query parameter on the request. If no matching version is found, responds +// 404. +func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + versionParam := req.URL.Query().Get("version") + if versionParam == "" { + h.errFunc(w, req, http.StatusBadRequest, fmt.Errorf("missing required query parameter 'version'")) + return + } + requested, err := vervet.ParseVersion(versionParam) + if err != nil { + h.errFunc(w, req, http.StatusBadRequest, err) + return + } + resolved, handler, err := h.Resolve(*requested) + if err != nil { + h.errFunc(w, req, http.StatusNotFound, err) + return + } + w.Header().Set(HeaderSnykVersionRequested, requested.String()) + w.Header().Set(HeaderSnykVersionServed, resolved.String()) + handler.ServeHTTP(w, req) +} diff --git a/versionware/handler_test.go b/versionware/handler_test.go new file mode 100644 index 00000000..40d50af3 --- /dev/null +++ b/versionware/handler_test.go @@ -0,0 +1,114 @@ +package versionware_test + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + qt "github.com/frankban/quicktest" + + "github.com/snyk/vervet" + "github.com/snyk/vervet/versionware" +) + +func ExampleHandler() { + h := versionware.NewHandler([]versionware.VersionHandler{{ + Version: vervet.MustParseVersion("2021-10-01"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte("oct")); err != nil { + panic(err) + } + }), + }, { + Version: vervet.MustParseVersion("2021-11-01"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte("nov")); err != nil { + panic(err) + } + }), + }, { + Version: vervet.MustParseVersion("2021-09-01"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte("sept")); err != nil { + panic(err) + } + }), + }}...) + + s := httptest.NewServer(h) + defer s.Close() + + resp, err := s.Client().Get(s.URL + "?version=2021-10-31") + if err != nil { + panic(err) + } + defer resp.Body.Close() + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + panic(err) + } + + fmt.Print(string(respBody)) + // Output: oct +} + +func TestHandler(t *testing.T) { + c := qt.New(t) + h := versionware.NewHandler([]versionware.VersionHandler{{ + Version: vervet.MustParseVersion("2021-10-01"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("oct")) + c.Assert(err, qt.IsNil) + }), + }, { + Version: vervet.MustParseVersion("2021-11-01"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("nov")) + c.Assert(err, qt.IsNil) + }), + }, { + Version: vervet.MustParseVersion("2021-09-01"), + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("sept")) + c.Assert(err, qt.IsNil) + }), + }}...) + s := httptest.NewServer(h) + c.Cleanup(s.Close) + + tests := []struct { + requested, resolved string + contents string + status int + }{{ + "2021-08-31", "", "no matching version\n", 404, + }, { + "bad wolf", "", "400 Bad Request", 400, + }, { + "", "", "missing required query parameter 'version'\n", 400, + }, { + "2021-09-16", "2021-09-01", "sept", 200, + }, { + "2021-10-01", "2021-10-01", "oct", 200, + }, { + "2021-10-31", "2021-10-01", "oct", 200, + }, { + "2021-11-05", "2021-11-01", "nov", 200, + }, { + "2023-02-05", "2021-11-01", "nov", 200, + }} + for i, test := range tests { + c.Logf("test#%d: requested %s resolved %s", i, test.requested, test.resolved) + req, err := http.NewRequest("GET", s.URL+"?version="+test.requested, nil) + c.Assert(err, qt.IsNil) + resp, err := s.Client().Do(req) + c.Assert(err, qt.IsNil) + defer resp.Body.Close() + c.Assert(resp.StatusCode, qt.Equals, test.status) + contents, err := ioutil.ReadAll(resp.Body) + c.Assert(err, qt.IsNil) + c.Assert(string(contents), qt.Equals, test.contents) + } +}