From c5eb1db0bcd068bc2eea8cab8b6ef577fc9a2733 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:16:38 +0100 Subject: [PATCH] Add Rules options to mux to specify custom HTTP rules (#74) --- vanguard.go | 83 +++++++++++++++++++++++++++++++++++++++------ vanguard_test.go | 87 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 158 insertions(+), 12 deletions(-) diff --git a/vanguard.go b/vanguard.go index 13ffcd3..47fa861 100644 --- a/vanguard.go +++ b/vanguard.go @@ -20,10 +20,12 @@ import ( "math" "net/http" "sort" + "strings" "sync" "time" "connectrpc.com/connect" + "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" @@ -275,6 +277,57 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser return nil } +// RegisterRules is the set of HTTP rules that apply to RPC methods via selectors. +// The rules are used in addition to any rules defined in the service's proto +// file. +// +// Services should be registered first. If a given rule doesn't match any +// already-registered method, an error is returned. +// See: https://cloud.google.com/service-infrastructure/docs/service-management/reference/rpc/google.api#google.api.DocumentationRule.FIELDS.string.google.api.DocumentationRule.selector +func (m *Mux) RegisterRules(rules ...*annotations.HttpRule) error { + m.maybeInit() + if len(rules) == 0 { + return nil + } + methodRules := make(map[*methodConfig][]*annotations.HttpRule) + for _, rule := range rules { + var applied bool + selector := rule.GetSelector() + if selector == "" { + return fmt.Errorf("rule missing selector") + } + if i := strings.Index(selector, "*"); i >= 0 { + if i != len(selector)-1 { + return fmt.Errorf("wildcard selector %q must be at the end", rule.GetSelector()) + } + selector = selector[:len(selector)-1] + if len(selector) > 0 && !strings.HasSuffix(selector, ".") { + return fmt.Errorf("wildcard selector %q must be whole component", rule.GetSelector()) + } + } + for _, methodConf := range m.methods { + methodName := string(methodConf.descriptor.FullName()) + if !strings.HasPrefix(methodName, selector) { + continue + } + methodRules[methodConf] = append(methodRules[methodConf], rule) + applied = true + } + if !applied { + return fmt.Errorf("rule %q does not match any methods", rule.GetSelector()) + } + } + for methodConf, rules := range methodRules { + for _, rule := range rules { + if err := m.addRule(rule, methodConf); err != nil { + // TODO: use the multi-error type errors.Join() + return err + } + } + } + return nil +} + // AddCodec adds the given codec implementation. // // By default, the mux already understands "proto", "json", and "text" codecs. The @@ -336,18 +389,26 @@ func (m *Mux) registerMethod(handler http.Handler, methodDesc protoreflect.Metho } if httpRule, ok := getHTTPRuleExtension(methodDesc); ok { - firstTarget, err := m.restRoutes.addRoute(methodConf, httpRule) - if err != nil { - return fmt.Errorf("failed to add REST route for method %s: %w", methodPath, err) + if err := m.addRule(httpRule, methodConf); err != nil { + return err } - methodConf.httpRule = firstTarget - for i, rule := range httpRule.AdditionalBindings { - if len(rule.AdditionalBindings) > 0 { - return fmt.Errorf("nested additional bindings are not supported (method %s)", methodPath) - } - if _, err := m.restRoutes.addRoute(methodConf, rule); err != nil { - return fmt.Errorf("failed to add REST route (add'l binding #%d) for method %s: %w", i+1, methodPath, err) - } + } + return nil +} + +func (m *Mux) addRule(httpRule *annotations.HttpRule, methodConf *methodConfig) error { + methodPath := methodConf.methodPath + firstTarget, err := m.restRoutes.addRoute(methodConf, httpRule) + if err != nil { + return fmt.Errorf("failed to add REST route for method %s: %w", methodPath, err) + } + methodConf.httpRule = firstTarget + for i, rule := range httpRule.AdditionalBindings { + if len(rule.AdditionalBindings) > 0 { + return fmt.Errorf("nested additional bindings are not supported (method %s)", methodPath) + } + if _, err := m.restRoutes.addRoute(methodConf, rule); err != nil { + return fmt.Errorf("failed to add REST route (add'l binding #%d) for method %s: %w", i+1, methodPath, err) } } return nil diff --git a/vanguard_test.go b/vanguard_test.go index 99c72a5..93b26dc 100644 --- a/vanguard_test.go +++ b/vanguard_test.go @@ -22,6 +22,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/http/httputil" "strings" "sync" "sync/atomic" @@ -29,11 +30,12 @@ import ( "time" "connectrpc.com/connect" - "connectrpc.com/vanguard/internal/gen/vanguard/test/v1" + testv1 "connectrpc.com/vanguard/internal/gen/vanguard/test/v1" "connectrpc.com/vanguard/internal/gen/vanguard/test/v1/testv1connect" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" @@ -1570,6 +1572,89 @@ func TestMux_HookOrder(t *testing.T) { } } +func TestRuleSelector(t *testing.T) { + t.Parallel() + + var interceptor testInterceptor + serveMux := http.NewServeMux() + serveMux.Handle(testv1connect.NewLibraryServiceHandler( + testv1connect.UnimplementedLibraryServiceHandler{}, + connect.WithInterceptors(&interceptor), + )) + mux := &Mux{} + assert.NoError(t, mux.RegisterServiceByName(serveMux, testv1connect.LibraryServiceName)) + + assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{ + Selector: "grpc.health.v1.Health.Check", + Pattern: &annotations.HttpRule_Get{ + Get: "/healthz", + }, + }), "rule \"grpc.health.v1.Health.Check\" does not match any methods") + assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{ + Selector: "invalid.*.Get", + Pattern: &annotations.HttpRule_Get{ + Get: "/v1/*", + }, + }), "wildcard selector \"invalid.*.Get\" must be at the end") + assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{ + Selector: "grpc.health.v1.Health.*", + Pattern: &annotations.HttpRule_Get{ + Get: "/healthz", + }, + }), "rule \"grpc.health.v1.Health.*\" does not match any methods") + assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{ + Pattern: &annotations.HttpRule_Get{ + Get: "/v1/*", + }, + }), "rule missing selector") + + assert.NoError(t, mux.RegisterRules(&annotations.HttpRule{ + Selector: "vanguard.test.v1.LibraryService.GetBook", + Pattern: &annotations.HttpRule_Get{ + Get: "/v1/selector/{name=shelves/*/books/*}", + }, + })) + + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/selector/shelves/123/books/456", http.NoBody) + require.NoError(t, err) + req.Header.Set("Message", "hello") + req.Header.Set("Test", t.Name()) // for interceptor + req.Header.Set("Content-Type", "application/json") + rsp := httptest.NewRecorder() + + interceptor.set(t, testStream{ + method: testv1connect.LibraryServiceGetBookProcedure, + reqHeader: http.Header{ + "Message": []string{"hello"}, + }, + rspHeader: http.Header{ + "Message": []string{"world"}, + }, + msgs: []testMsg{ + {in: &testMsgIn{ + msg: &testv1.GetBookRequest{Name: "shelves/123/books/456"}, + }}, + {out: &testMsgOut{ + msg: &testv1.Book{Name: "shelves/123/books/456"}, + }}, + }, + }) + defer interceptor.del(t) + + mux.AsHandler().ServeHTTP(rsp, req) + result := rsp.Result() + defer result.Body.Close() + + dump, err := httputil.DumpResponse(result, true) + require.NoError(t, err) + t.Log(string(dump)) + + assert.Equal(t, http.StatusOK, result.StatusCode) + assert.Equal(t, "application/json", result.Header.Get("Content-Type")) + assert.Equal(t, "world", result.Header.Get("Message")) +} + type testStream struct { method string reqHeader http.Header // expected