diff --git a/.gitignore b/.gitignore index 973d4b3..09670d4 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,5 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +app diff --git a/echo_test.go b/echo_test.go index 91b075f..9690998 100644 --- a/echo_test.go +++ b/echo_test.go @@ -93,3 +93,51 @@ func TestEchoServerMiddleware(t *testing.T) { respData = append(respData, 0xa) assert.Equal(t, respData, resp.Bytes()) } +func TestOutgoingRequestEcho(t *testing.T) { + client := &Client{ + config: &Config{}, + } + publishCalled := false + var parentId *string + client.PublishMessage = func(ctx context.Context, payload Payload) error { + if payload.RawURL == "/from-gorilla" { + assert.NotNil(t, payload.ParentID) + parentId = payload.ParentID + } else if payload.URLPath == "/:slug/test" { + assert.Equal(t, *parentId, payload.MsgID) + } + publishCalled = true + return nil + } + router := echo.New() + router.Use(client.EchoMiddleware) + router.POST("/:slug/test", func(c echo.Context) (err error) { + body, err := io.ReadAll(c.Request().Body) + assert.NotEmpty(t, body) + reqData, _ := json.Marshal(exampleData2) + assert.Equal(t, reqData, body) + HTTPClient := http.DefaultClient + HTTPClient.Transport = client.WrapRoundTripper( + c.Request().Context(), HTTPClient.Transport, + WithRedactHeaders([]string{}), + ) + _, _ = HTTPClient.Get("http://localhost:3000/from-gorilla") + + c.JSON(http.StatusAccepted, exampleData) + return + }) + ts := httptest.NewServer(router) + defer ts.Close() + + _, err := req.Post(ts.URL+"/slug-value/test", + req.Param{"param1": "abc", "param2": 123}, + req.Header{ + "Content-Type": "application/json", + "X-API-KEY": "past-3", + }, + req.BodyJSON(exampleData2), + ) + assert.NoError(t, err) + assert.True(t, publishCalled) + +} diff --git a/errors_test.go b/errors_test.go index e510a58..85d593d 100644 --- a/errors_test.go +++ b/errors_test.go @@ -65,13 +65,7 @@ func TestErrorReporting(t *testing.T) { assert.Equal(t, GoOutgoing, payload.SdkType) return nil } - ctx := context.Background() - atHTTPClient := http.DefaultClient - atHTTPClient.Transport = outClient.WrapRoundTripper( - ctx, atHTTPClient.Transport, - WithRedactHeaders([]string{}), - ) - req.SetClient(atHTTPClient) + _, err := req.Post(ts.URL+"/test", req.Param{"param1": "abc", "param2": 123}, req.Header{ @@ -83,7 +77,6 @@ func TestErrorReporting(t *testing.T) { assert.NoError(t, err) assert.True(t, publishCalled) } - func TestGinMiddlewareGETError(t *testing.T) { gin.SetMode(gin.TestMode) client := &Client{ diff --git a/fiber_test.go b/fiber_test.go index a557ac9..cd963da 100644 --- a/fiber_test.go +++ b/fiber_test.go @@ -100,6 +100,53 @@ func TestFiberMiddleware(t *testing.T) { assert.Equal(t, respData, data) } +func TestOutgoingRequestFiber(t *testing.T) { + client := &Client{ + config: &Config{}, + } + publishCalled := false + var parentId *string + client.PublishMessage = func(ctx context.Context, payload Payload) error { + if payload.RawURL == "/from-gorilla" { + assert.NotNil(t, payload.ParentID) + parentId = payload.ParentID + } else if payload.URLPath == "/:slug/test" { + assert.Equal(t, *parentId, payload.MsgID) + } + publishCalled = true + return nil + } + router := fiber.New() + router.Use(client.FiberMiddleware) + router.Post("/:slug/test", func(c *fiber.Ctx) error { + body := c.Request().Body() + assert.NotEmpty(t, body) + reqData, _ := json.Marshal(exampleData2) + assert.Equal(t, reqData, body) + HTTPClient := http.DefaultClient + HTTPClient.Transport = client.WrapRoundTripper( + c.UserContext(), HTTPClient.Transport, + WithRedactHeaders([]string{}), + ) + _, _ = HTTPClient.Get("http://localhost:3000/from-gorilla") + + c.Append("Content-Type", "application/json") + c.Append("X-API-KEY", "applicationKey") + + return c.Status(http.StatusAccepted).JSON(exampleData) + }) + + reqData, _ := json.Marshal(exampleData2) + ts := httptest.NewRequest("POST", "/slug-value/test?param1=abc¶m2=123", bytes.NewReader(reqData)) + ts.Header.Set("Content-Type", "application/json") + ts.Header.Set("X-API-KEY", "past-3") + + _, err := router.Test(ts) + assert.NoError(t, err) + assert.True(t, publishCalled) + +} + // func TestFiberMiddlewareGET(t *testing.T) { // client := &Client{ // config: &Config{}, diff --git a/gin_test.go b/gin_test.go index ff04364..fd8a057 100644 --- a/gin_test.go +++ b/gin_test.go @@ -150,3 +150,42 @@ func TestGinMiddlewareGET(t *testing.T) { assert.True(t, publishCalled) assert.Equal(t, respData, resp.Bytes()) } + +func TestOutgoingRequestGin(t *testing.T) { + gin.SetMode(gin.TestMode) + client := &Client{ + config: &Config{}, + } + var publishCalled bool + router := gin.New() + router.Use(client.GinMiddleware) + var parentId *string + client.PublishMessage = func(ctx context.Context, payload Payload) error { + if payload.RawURL == "/from-gorilla" { + assert.NotNil(t, payload.ParentID) + parentId = payload.ParentID + } else if payload.URLPath == "/:slug/test" { + assert.Equal(t, *parentId, payload.MsgID) + } + publishCalled = true + return nil + } + router.GET("/:slug/test", func(c *gin.Context) { + HTTPClient := http.DefaultClient + HTTPClient.Transport = client.WrapRoundTripper( + c.Request.Context(), HTTPClient.Transport, + WithRedactHeaders([]string{}), + ) + _, _ = HTTPClient.Get("http://localhost:3000/from-gorilla") + + c.JSON(http.StatusAccepted, gin.H{"hello": "world"}) + }) + + ts := httptest.NewServer(router) + defer ts.Close() + + _, err := req.Get(ts.URL + "/slug-value/test") + assert.NoError(t, err) + assert.True(t, publishCalled) + +} diff --git a/native.go b/native.go index 1672584..3a0147d 100644 --- a/native.go +++ b/native.go @@ -18,7 +18,7 @@ import ( func (c *Client) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { msgID := uuid.Must(uuid.NewRandom()) - newCtx := context.WithValue(req.Context(), ErrorListCtxKey, msgID) + newCtx := context.WithValue(req.Context(), CurrentRequestMessageID, msgID) errorList := []ATError{} newCtx = context.WithValue(newCtx, ErrorListCtxKey, &errorList) @@ -60,10 +60,10 @@ func (c *Client) Middleware(next http.Handler) http.Handler { func (c *Client) GorillaMuxMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { msgID := uuid.Must(uuid.NewRandom()) - newCtx := context.WithValue(req.Context(), ErrorListCtxKey, msgID) + newCtx := context.WithValue(req.Context(), CurrentRequestMessageID, msgID) errorList := []ATError{} - newCtx = context.WithValue(req.Context(), ErrorListCtxKey, &errorList) + newCtx = context.WithValue(newCtx, ErrorListCtxKey, &errorList) req = req.WithContext(newCtx) reqBuf, _ := io.ReadAll(req.Body) @@ -110,10 +110,10 @@ func (c *Client) GorillaMuxMiddleware(next http.Handler) http.Handler { func (c *Client) ChiMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { msgID := uuid.Must(uuid.NewRandom()) - newCtx := context.WithValue(req.Context(), ErrorListCtxKey, msgID) + newCtx := context.WithValue(req.Context(), CurrentRequestMessageID, msgID) errorList := []ATError{} - newCtx = context.WithValue(req.Context(), ErrorListCtxKey, &errorList) + newCtx = context.WithValue(newCtx, ErrorListCtxKey, &errorList) req = req.WithContext(newCtx) reqBuf, _ := io.ReadAll(req.Body) diff --git a/outgoing.go b/outgoing.go index f17497b..c67dcc7 100644 --- a/outgoing.go +++ b/outgoing.go @@ -46,7 +46,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (res *http.Response, err er var payload Payload var parentMsgIDPtr *uuid.UUID - parentMsgID, ok := req.Context().Value(CurrentRequestMessageID).(uuid.UUID) + parentMsgID, ok := rt.ctx.Value(CurrentRequestMessageID).(uuid.UUID) if ok { parentMsgIDPtr = &parentMsgID } diff --git a/sdk_test.go b/sdk_test.go index 146af3c..8a5e4ff 100644 --- a/sdk_test.go +++ b/sdk_test.go @@ -86,60 +86,50 @@ func TestOutgoingMiddleware(t *testing.T) { return nil } - handlerFn := func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - assert.NoError(t, err) - assert.NotEmpty(t, body) - - jsonByte, err := json.Marshal(exampleData) - assert.NoError(t, err) - - w.Header().Add("Content-Type", "application/json") - w.Header().Add("X-API-KEY", "applicationKey") - w.WriteHeader(http.StatusAccepted) - w.Write(jsonByte) - } - - ts := httptest.NewServer(client.Middleware(http.HandlerFunc(handlerFn))) - defer ts.Close() - outClient := &Client{ config: &Config{}, } outClient.PublishMessage = func(ctx context.Context, payload Payload) error { - assert.Equal(t, "POST", payload.Method) - assert.Equal(t, "/test", payload.URLPath) + assert.Equal(t, "GET", payload.Method) + assert.Equal(t, "/from-gorilla", payload.URLPath) assert.Equal(t, map[string]string(nil), payload.PathParams) assert.Equal(t, map[string][]string{ "param1": {"abc"}, "param2": {"123"}, }, payload.QueryParams) - assert.Equal(t, map[string][]string{ - "Content-Type": {"application/json"}, - "X-Api-Key": {"past-3"}, - }, payload.RequestHeaders) - assert.Equal(t, "/test?param1=abc¶m2=123", payload.RawURL) - assert.Equal(t, http.StatusAccepted, payload.StatusCode) + assert.Equal(t, "/from-gorilla?param1=abc¶m2=123", payload.RawURL) + assert.Equal(t, http.StatusServiceUnavailable, payload.StatusCode) assert.Greater(t, payload.Duration, 1000*time.Nanosecond) assert.Equal(t, GoOutgoing, payload.SdkType) + assert.NotNil(t, payload.ParentID) - reqData, _ := json.Marshal(exampleData2) - // respData, _ := json.Marshal(exampleDataRedacted) + return nil + } - assert.Equal(t, reqData, payload.RequestBody) - // assert.Equal(t, respData, payload.ResponseBody) + handlerFn := func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.NotEmpty(t, body) + atHTTPClient := http.DefaultClient + atHTTPClient.Transport = outClient.WrapRoundTripper( + r.Context(), atHTTPClient.Transport, + WithRedactHeaders([]string{}), + ) + _, _ = atHTTPClient.Get("http://localhost:3000/from-gorilla?param1=abc¶m2=123") + jsonByte, err := json.Marshal(exampleData) + assert.NoError(t, err) - return nil + w.Header().Add("Content-Type", "application/json") + w.Header().Add("X-API-KEY", "applicationKey") + w.WriteHeader(http.StatusAccepted) + w.Write(jsonByte) } - ctx := context.Background() - atHTTPClient := http.DefaultClient - atHTTPClient.Transport = outClient.WrapRoundTripper( - ctx, atHTTPClient.Transport, - WithRedactHeaders([]string{}), - ) - req.SetClient(atHTTPClient) + + ts := httptest.NewServer(client.Middleware(http.HandlerFunc(handlerFn))) + defer ts.Close() + _, err := req.Post(ts.URL+"/test", req.Param{"param1": "abc", "param2": 123}, req.Header{