diff --git a/.gitignore b/.gitignore index fc95c31..44677df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ cmd/apirunner/apirunner dist/ +.idea/ diff --git a/basicresponse.json b/basicresponse.json index b607c1a..d48b38e 100644 --- a/basicresponse.json +++ b/basicresponse.json @@ -13,6 +13,21 @@ "name": "name" } } + }, + { + "name": "basicStringBody", + "request": { + "method": "POST", + "url": "/user", + "body": "name" + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "id": 1, + "name": "name" + } + } } ] } diff --git a/bodyparser.json b/bodyparser.json new file mode 100644 index 0000000..b1a117d --- /dev/null +++ b/bodyparser.json @@ -0,0 +1,32 @@ +{ + "tests": [ + { + "name": "bodyStringParser", + "request": { + "method": "POST", + "url": "/string", + "body": "Hello, World!" + }, + "expectedResponse": { + "statusCode": 200, + "body": "Hello, World!" + } + }, + { + "name": "bodyJSONParser", + "request": { + "method": "POST", + "url": "/json", + "body": { + "hello": "world" + } + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "hello": "world" + } + } + } + ] +} diff --git a/suite.go b/suite.go index 5e29169..c6c44d5 100644 --- a/suite.go +++ b/suite.go @@ -237,14 +237,22 @@ func (suite TestSuite) executeTest(test TestSpec, extractedFields map[string]int if test.Request.Body == nil { requestBody = bytes.NewBuffer([]byte("{}")) } else { - reqBodyBytes, err := json.Marshal(test.Request.Body) - if err != nil { - testErrors = append(testErrors, fmt.Sprintf("Invalid request body: %v", err)) - return Failed(test.Name, testErrors, time.Since(start)) + var stringBody string + // Marshalling a string directly will escape the string, so we need to handle it separately + if str, ok := test.Request.Body.(string); ok { + stringBody = str + } else { + reqBodyBytes, err := json.Marshal(test.Request.Body) + + if err != nil { + testErrors = append(testErrors, fmt.Sprintf("Invalid request body: %v", err)) + return Failed(test.Name, testErrors, time.Since(start)) + } + stringBody = string(reqBodyBytes) } // Replace any template variables in test's request body with the appropriate value - processedRequestBody, err := templateReplace(string(reqBodyBytes), extractedFields) + processedRequestBody, err := templateReplace(stringBody, extractedFields) if err != nil { testErrors = append(testErrors, err.Error()) return Failed(test.Name, testErrors, time.Since(start)) @@ -348,12 +356,21 @@ func (suite TestSuite) executeTest(test TestSpec, extractedFields map[string]int // Otherwise, deep compare response payload to expected response payload var r interface{} err = json.Unmarshal(body, &r) - if err != nil { - testErrors = append(testErrors, fmt.Sprintf("Error parsing json response from server: %v", err)) - return Failed(test.Name, testErrors, time.Since(start)) - } - switch r.(type) { - case map[string]interface{}: + switch { + case err != nil: + // If JSON unmarshalling fails, compare the response as a plain text string + expectedString, ok := expectedResponse.(string) + if !ok { + testErrors = append(testErrors, fmt.Sprintf("Expected a JSON object, but got a non-JSON response: %s", string(body))) + } else { + processedExpectedBody, err := templateReplace(expectedString, extractedFields) + if err != nil { + testErrors = append(testErrors, fmt.Sprintf("Error comparing actual and expected responses: %v", err)) + } else if string(body) != processedExpectedBody { + testErrors = append(testErrors, fmt.Sprintf("Expected response payload %s but got %s", expectedString, string(body))) + } + } + case isMap(r): differences, err := suite.compareObjects(r.(map[string]interface{}), expectedResponse.(map[string]interface{}), extractedFields, test.Name) if err != nil { testErrors = append(testErrors, fmt.Sprintf("Error comparing actual and expected responses: %v", err)) @@ -362,7 +379,7 @@ func (suite TestSuite) executeTest(test TestSpec, extractedFields map[string]int if len(differences) > 0 { testErrors = append(testErrors, differences...) } - case []interface{}: + case isSlice(r): response := r.([]interface{}) expected := expectedResponse.([]interface{}) if len(response) != len(expected) { @@ -394,6 +411,16 @@ func (suite TestSuite) executeTest(test TestSpec, extractedFields map[string]int return Passed(test.Name, time.Since(start)) } +func isMap(v interface{}) bool { + _, ok := v.(map[string]interface{}) + return ok +} + +func isSlice(v interface{}) bool { + _, ok := v.([]interface{}) + return ok +} + func (suite TestSuite) compareObjects(obj map[string]interface{}, expectedObj map[string]interface{}, extractedFields map[string]interface{}, objPrefix string) ([]string, error) { // Track all new field values from response obj flattenedObj := flatten(obj, objPrefix, 0) diff --git a/suite_test.go b/suite_test.go index dacfcbe..54904ac 100644 --- a/suite_test.go +++ b/suite_test.go @@ -117,3 +117,34 @@ func TestTemplateVars(t *testing.T) { t.Errorf("Expected failure result to contain string: 'missing template value for var: 'test1.userIdWrongVar''") } } + +type EchoRequestHttpClient struct { + StatusCode int +} + +func (c *EchoRequestHttpClient) Do(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: c.StatusCode, + Body: req.Body, + Header: req.Header, + }, nil +} + +func TestParsesBodyCorrectly(t *testing.T) { + mockClient := EchoRequestHttpClient{} + mockClient.StatusCode = 200 + results, _ := ExecuteSuite(RunConfig{ + BaseUrl: "", + CustomHeaders: nil, + HttpClient: &mockClient, + }, "bodyparser.json", true) + + if len(results.Passed) == 0 { + t.Errorf("All tests should have passed.\n") + } + if len(results.Failed) > 0 { + for _, test := range results.Failed { + t.Errorf("Failed test result: [%s]\n", test.Result()) + } + } +} diff --git a/templatevars.json b/templatevars.json index 6b49c89..6246d8a 100644 --- a/templatevars.json +++ b/templatevars.json @@ -28,6 +28,20 @@ "userId": "{{ test1.userId }}" } } + }, + { + "name": "test3", + "request": { + "method": "POST", + "url": "/users", + "body": "{{ test1.userId }}" + }, + "expectedResponse": { + "statusCode": 200, + "body": { + "userId": "{{ test1.userId }}" + } + } } ] }