Skip to content

Commit

Permalink
Add Rules options to mux to specify custom HTTP rules (#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane authored Sep 25, 2023
1 parent 9eb3970 commit c5eb1db
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 12 deletions.
83 changes: 72 additions & 11 deletions vanguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"math"
"net/http"
"sort"
"strings"
"sync"
"time"

"connectrpc.com/connect"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
Expand Down Expand Up @@ -275,6 +277,57 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser
return nil
}

// RegisterRules is the set of HTTP rules that apply to RPC methods via selectors.
// The rules are used in addition to any rules defined in the service's proto
// file.
//
// Services should be registered first. If a given rule doesn't match any
// already-registered method, an error is returned.
// See: https://cloud.google.com/service-infrastructure/docs/service-management/reference/rpc/google.api#google.api.DocumentationRule.FIELDS.string.google.api.DocumentationRule.selector
func (m *Mux) RegisterRules(rules ...*annotations.HttpRule) error {
m.maybeInit()
if len(rules) == 0 {
return nil
}
methodRules := make(map[*methodConfig][]*annotations.HttpRule)
for _, rule := range rules {
var applied bool
selector := rule.GetSelector()
if selector == "" {
return fmt.Errorf("rule missing selector")
}
if i := strings.Index(selector, "*"); i >= 0 {
if i != len(selector)-1 {
return fmt.Errorf("wildcard selector %q must be at the end", rule.GetSelector())
}
selector = selector[:len(selector)-1]
if len(selector) > 0 && !strings.HasSuffix(selector, ".") {
return fmt.Errorf("wildcard selector %q must be whole component", rule.GetSelector())
}
}
for _, methodConf := range m.methods {
methodName := string(methodConf.descriptor.FullName())
if !strings.HasPrefix(methodName, selector) {
continue
}
methodRules[methodConf] = append(methodRules[methodConf], rule)
applied = true
}
if !applied {
return fmt.Errorf("rule %q does not match any methods", rule.GetSelector())
}
}
for methodConf, rules := range methodRules {
for _, rule := range rules {
if err := m.addRule(rule, methodConf); err != nil {
// TODO: use the multi-error type errors.Join()
return err
}
}
}
return nil
}

// AddCodec adds the given codec implementation.
//
// By default, the mux already understands "proto", "json", and "text" codecs. The
Expand Down Expand Up @@ -336,18 +389,26 @@ func (m *Mux) registerMethod(handler http.Handler, methodDesc protoreflect.Metho
}

if httpRule, ok := getHTTPRuleExtension(methodDesc); ok {
firstTarget, err := m.restRoutes.addRoute(methodConf, httpRule)
if err != nil {
return fmt.Errorf("failed to add REST route for method %s: %w", methodPath, err)
if err := m.addRule(httpRule, methodConf); err != nil {
return err
}
methodConf.httpRule = firstTarget
for i, rule := range httpRule.AdditionalBindings {
if len(rule.AdditionalBindings) > 0 {
return fmt.Errorf("nested additional bindings are not supported (method %s)", methodPath)
}
if _, err := m.restRoutes.addRoute(methodConf, rule); err != nil {
return fmt.Errorf("failed to add REST route (add'l binding #%d) for method %s: %w", i+1, methodPath, err)
}
}
return nil
}

func (m *Mux) addRule(httpRule *annotations.HttpRule, methodConf *methodConfig) error {
methodPath := methodConf.methodPath
firstTarget, err := m.restRoutes.addRoute(methodConf, httpRule)
if err != nil {
return fmt.Errorf("failed to add REST route for method %s: %w", methodPath, err)
}
methodConf.httpRule = firstTarget
for i, rule := range httpRule.AdditionalBindings {
if len(rule.AdditionalBindings) > 0 {
return fmt.Errorf("nested additional bindings are not supported (method %s)", methodPath)
}
if _, err := m.restRoutes.addRoute(methodConf, rule); err != nil {
return fmt.Errorf("failed to add REST route (add'l binding #%d) for method %s: %w", i+1, methodPath, err)
}
}
return nil
Expand Down
87 changes: 86 additions & 1 deletion vanguard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"connectrpc.com/connect"
"connectrpc.com/vanguard/internal/gen/vanguard/test/v1"
testv1 "connectrpc.com/vanguard/internal/gen/vanguard/test/v1"
"connectrpc.com/vanguard/internal/gen/vanguard/test/v1/testv1connect"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/api/annotations"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
Expand Down Expand Up @@ -1570,6 +1572,89 @@ func TestMux_HookOrder(t *testing.T) {
}
}

func TestRuleSelector(t *testing.T) {
t.Parallel()

var interceptor testInterceptor
serveMux := http.NewServeMux()
serveMux.Handle(testv1connect.NewLibraryServiceHandler(
testv1connect.UnimplementedLibraryServiceHandler{},
connect.WithInterceptors(&interceptor),
))
mux := &Mux{}
assert.NoError(t, mux.RegisterServiceByName(serveMux, testv1connect.LibraryServiceName))

assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{
Selector: "grpc.health.v1.Health.Check",
Pattern: &annotations.HttpRule_Get{
Get: "/healthz",
},
}), "rule \"grpc.health.v1.Health.Check\" does not match any methods")
assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{
Selector: "invalid.*.Get",
Pattern: &annotations.HttpRule_Get{
Get: "/v1/*",
},
}), "wildcard selector \"invalid.*.Get\" must be at the end")
assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{
Selector: "grpc.health.v1.Health.*",
Pattern: &annotations.HttpRule_Get{
Get: "/healthz",
},
}), "rule \"grpc.health.v1.Health.*\" does not match any methods")
assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{
Pattern: &annotations.HttpRule_Get{
Get: "/v1/*",
},
}), "rule missing selector")

assert.NoError(t, mux.RegisterRules(&annotations.HttpRule{
Selector: "vanguard.test.v1.LibraryService.GetBook",
Pattern: &annotations.HttpRule_Get{
Get: "/v1/selector/{name=shelves/*/books/*}",
},
}))

ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/selector/shelves/123/books/456", http.NoBody)
require.NoError(t, err)
req.Header.Set("Message", "hello")
req.Header.Set("Test", t.Name()) // for interceptor
req.Header.Set("Content-Type", "application/json")
rsp := httptest.NewRecorder()

interceptor.set(t, testStream{
method: testv1connect.LibraryServiceGetBookProcedure,
reqHeader: http.Header{
"Message": []string{"hello"},
},
rspHeader: http.Header{
"Message": []string{"world"},
},
msgs: []testMsg{
{in: &testMsgIn{
msg: &testv1.GetBookRequest{Name: "shelves/123/books/456"},
}},
{out: &testMsgOut{
msg: &testv1.Book{Name: "shelves/123/books/456"},
}},
},
})
defer interceptor.del(t)

mux.AsHandler().ServeHTTP(rsp, req)
result := rsp.Result()
defer result.Body.Close()

dump, err := httputil.DumpResponse(result, true)
require.NoError(t, err)
t.Log(string(dump))

assert.Equal(t, http.StatusOK, result.StatusCode)
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
assert.Equal(t, "world", result.Header.Get("Message"))
}

type testStream struct {
method string
reqHeader http.Header // expected
Expand Down

0 comments on commit c5eb1db

Please sign in to comment.