From 31b136e24b8b4c36b13771799735d8c7ac160a1c Mon Sep 17 00:00:00 2001 From: Roey Berman Date: Wed, 20 Mar 2024 21:38:58 -0700 Subject: [PATCH] Add support for callback header (#8) Exposed callback headers as defined in this recent change to the API: https://github.com/nexus-rpc/api/pull/4 --- nexus/api.go | 27 ++++++++++++++++++++------- nexus/client.go | 13 +++++++++---- nexus/completion.go | 2 +- nexus/handle.go | 2 +- nexus/options.go | 3 +++ nexus/server.go | 9 +++++---- nexus/start_test.go | 15 +++++++++++++-- 7 files changed, 52 insertions(+), 19 deletions(-) diff --git a/nexus/api.go b/nexus/api.go index c1f573f..671d0b9 100644 --- a/nexus/api.go +++ b/nexus/api.go @@ -114,12 +114,13 @@ func (h Header) Get(k string) string { return h[strings.ToLower(k)] } -func httpHeaderToContentHeader(httpHeader http.Header) Header { +func prefixStrippedHTTPHeaderToNexusHeader(httpHeader http.Header, prefix string) Header { header := Header{} for k, v := range httpHeader { - if strings.HasPrefix(k, "Content-") { + lowerK := strings.ToLower(k) + if strings.HasPrefix(lowerK, prefix) { // Nexus headers can only have single values, ignore multiple values. - header[strings.ToLower(k[8:])] = v[0] + header[lowerK[len(prefix):]] = v[0] } } return header @@ -132,13 +133,25 @@ func addContentHeaderToHTTPHeader(nexusHeader Header, httpHeader http.Header) ht return httpHeader } -func httpHeaderToNexusHeader(httpHeader http.Header) Header { +func addCallbackHeaderToHTTPHeader(nexusHeader Header, httpHeader http.Header) http.Header { + for k, v := range nexusHeader { + httpHeader.Set("Nexus-Callback-"+k, v) + } + return httpHeader +} + +func httpHeaderToNexusHeader(httpHeader http.Header, excludePrefixes ...string) Header { header := Header{} +headerLoop: for k, v := range httpHeader { - if !strings.HasPrefix(k, "Content-") { - // Nexus headers can only have single values, ignore multiple values. - header[strings.ToLower(k)] = v[0] + lowerK := strings.ToLower(k) + for _, prefix := range excludePrefixes { + if strings.HasPrefix(lowerK, prefix) { + continue headerLoop + } } + // Nexus headers can only have single values, ignore multiple values. + header[lowerK] = v[0] } return header } diff --git a/nexus/client.go b/nexus/client.go index be6379e..8299738 100644 --- a/nexus/client.go +++ b/nexus/client.go @@ -193,6 +193,7 @@ func (c *Client) StartOperation(ctx context.Context, operation string, input any request.Header.Set(headerRequestID, options.RequestID) request.Header.Set(headerUserAgent, userAgent) addContentHeaderToHTTPHeader(reader.Header, request.Header) + addCallbackHeaderToHTTPHeader(options.CallbackHeader, request.Header) response, err := c.options.HTTPCaller(request) if err != nil { @@ -205,7 +206,7 @@ func (c *Client) StartOperation(ctx context.Context, operation string, input any serializer: c.options.Serializer, Reader: &Reader{ response.Body, - httpHeaderToContentHeader(response.Header), + prefixStrippedHTTPHeaderToNexusHeader(response.Header, "content-"), }, }, }, nil @@ -259,6 +260,9 @@ type ExecuteOperationOptions struct { // Even though Client.ExecuteOperation waits for operation completion, some applications may want to set this // callback as a fallback mechanism. CallbackURL string + // Optional header fields set by a client that are required to be attached to the callback request when an + // asynchronous operation completes. + CallbackHeader Header // Request ID that may be used by the server handler to dedupe this start request. // By default a v4 UUID will be generated by the client. RequestID string @@ -289,9 +293,10 @@ type ExecuteOperationOptions struct { // free up the underlying connection. func (c *Client) ExecuteOperation(ctx context.Context, operation string, input any, options ExecuteOperationOptions) (*LazyValue, error) { so := StartOperationOptions{ - CallbackURL: options.CallbackURL, - RequestID: options.RequestID, - Header: options.Header, + CallbackURL: options.CallbackURL, + CallbackHeader: options.CallbackHeader, + RequestID: options.RequestID, + Header: options.Header, } result, err := c.StartOperation(ctx, operation, input, so) if err != nil { diff --git a/nexus/completion.go b/nexus/completion.go index 9ef7bc1..0dc5080 100644 --- a/nexus/completion.go +++ b/nexus/completion.go @@ -178,7 +178,7 @@ func (h *completionHTTPHandler) ServeHTTP(writer http.ResponseWriter, request *h serializer: h.options.Serializer, Reader: &Reader{ request.Body, - httpHeaderToContentHeader(request.Header), + prefixStrippedHTTPHeaderToNexusHeader(request.Header, "content-"), }, } default: diff --git a/nexus/handle.go b/nexus/handle.go index ca59f26..d05ef94 100644 --- a/nexus/handle.go +++ b/nexus/handle.go @@ -107,7 +107,7 @@ func (h *OperationHandle[T]) GetResult(ctx context.Context, options GetOperation serializer: h.client.options.Serializer, Reader: &Reader{ response.Body, - httpHeaderToContentHeader(response.Header), + prefixStrippedHTTPHeaderToNexusHeader(response.Header, "content-"), }, } if _, ok := any(result).(*LazyValue); ok { diff --git a/nexus/options.go b/nexus/options.go index 4247214..687a5f9 100644 --- a/nexus/options.go +++ b/nexus/options.go @@ -18,6 +18,9 @@ type StartOperationOptions struct { // // Implement a [CompletionHandler] and expose it as an HTTP handler to handle async completions. CallbackURL string + // Optional header fields set by a client that are required to be attached to the callback request when an + // asynchronous operation completes. + CallbackHeader Header // Request ID that may be used by the server handler to dedupe a start request. // By default a v4 UUID will be generated by the client. RequestID string diff --git a/nexus/server.go b/nexus/server.go index 444d5c3..3162495 100644 --- a/nexus/server.go +++ b/nexus/server.go @@ -276,15 +276,16 @@ func (h *httpHandler) startOperation(writer http.ResponseWriter, request *http.R return } options := StartOperationOptions{ - RequestID: request.Header.Get(headerRequestID), - CallbackURL: request.URL.Query().Get(queryCallbackURL), - Header: httpHeaderToNexusHeader(request.Header), + RequestID: request.Header.Get(headerRequestID), + CallbackURL: request.URL.Query().Get(queryCallbackURL), + CallbackHeader: prefixStrippedHTTPHeaderToNexusHeader(request.Header, "nexus-callback-"), + Header: httpHeaderToNexusHeader(request.Header, "content-", "nexus-callback-"), } value := &LazyValue{ serializer: h.options.Serializer, Reader: &Reader{ request.Body, - httpHeaderToContentHeader(request.Header), + prefixStrippedHTTPHeaderToNexusHeader(request.Header, "content-"), }, } response, err := h.options.Handler.StartOperation(request.Context(), operation, value, options) diff --git a/nexus/start_test.go b/nexus/start_test.go index 6e557fc..d906a43 100644 --- a/nexus/start_test.go +++ b/nexus/start_test.go @@ -25,9 +25,19 @@ func (h *successHandler) StartOperation(ctx context.Context, operation string, i if options.CallbackURL != "http://test/callback" { return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "unexpected callback URL: %s", options.CallbackURL) } + if options.CallbackHeader.Get("callback-test") != "ok" { + return nil, HandlerErrorf( + HandlerErrorTypeBadRequest, + "invalid 'callback-test' callback header: %q", + options.CallbackHeader.Get("callback-test"), + ) + } if options.Header.Get("test") != "ok" { return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid 'test' header: %q", options.Header.Get("test")) } + if options.Header.Get("nexus-callback-callback-test") != "" { + return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "callback header not omitted from options Header") + } if options.Header.Get("User-Agent") != userAgent { return nil, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid 'User-Agent' header: %q", options.Header.Get("User-Agent")) } @@ -42,8 +52,9 @@ func TestSuccess(t *testing.T) { requestBody := []byte{0x00, 0x01} response, err := client.ExecuteOperation(ctx, "i need to/be escaped", requestBody, ExecuteOperationOptions{ - CallbackURL: "http://test/callback", - Header: Header{"test": "ok"}, + CallbackURL: "http://test/callback", + CallbackHeader: Header{"callback-test": "ok"}, + Header: Header{"test": "ok"}, }) require.NoError(t, err) var responseBody []byte