From e2dc43ce52c59e6fd3e8bb7c3df3056d389e181f Mon Sep 17 00:00:00 2001 From: Alex Obukhov Date: Sun, 20 Sep 2020 08:37:18 +0200 Subject: [PATCH] support for Multipart requests (#13) * Implement file upload types * Publish fields of Upload struct * Rollback accidental change * Remove double encode * Fix parsing slice of files * Add test extract files * Add prepareMultipart test * Fix some misspells reported by goreportcard.com * Fix introspection go fmt --- file.go | 118 ++++++++++++++++++++++++++++++++++++++++++++++ file_test.go | 96 +++++++++++++++++++++++++++++++++++++ introspection.go | 3 +- queryer.go | 20 +++++++- queryerMultiOp.go | 2 +- queryerNetwork.go | 28 ++++++++--- queryer_test.go | 2 +- 7 files changed, 258 insertions(+), 11 deletions(-) create mode 100644 file.go create mode 100644 file_test.go diff --git a/file.go b/file.go new file mode 100644 index 0000000..e6f66c3 --- /dev/null +++ b/file.go @@ -0,0 +1,118 @@ +package graphql + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "strconv" +) + +type File interface { + io.Reader + io.Closer +} +type Upload struct { + File File + FileName string +} + +type UploadMap []struct { + upload Upload + positions []string +} + +func (u *UploadMap) UploadMap() map[string][]string { + var result = make(map[string][]string) + + for idx, attachment := range *u { + result[strconv.Itoa(idx)] = attachment.positions + } + + return result +} + +func (u *UploadMap) NotEmpty() bool { + return len(*u) > 0 +} + +func (u *UploadMap) Add(upload Upload, varName string) { + *u = append(*u, struct { + upload Upload + positions []string + }{ + upload, + []string{fmt.Sprintf("variables.%s", varName)}, + }) +} + +// function extracts attached files and sets respective variables to null +func extractFiles(input *QueryInput) *UploadMap { + uploadMap := &UploadMap{} + + for varName, value := range input.Variables { + switch valueTyped := value.(type) { + case Upload: + uploadMap.Add(valueTyped, varName) + input.Variables[varName] = nil + case []interface{}: + for i, uploadVal := range valueTyped { + if upload, ok := uploadVal.(Upload); ok { + uploadMap.Add(upload, fmt.Sprintf("%s.%d", varName, i)) + } + valueTyped[i] = nil + } + input.Variables[varName] = valueTyped + default: + //noop + } + } + return uploadMap +} + +func prepareMultipart(payload []byte, uploadMap *UploadMap) (body []byte, contentType string, err error) { + var b = bytes.Buffer{} + var fw io.Writer + + w := multipart.NewWriter(&b) + + fw, err = w.CreateFormField("operations") + if err != nil { + return + } + + _, err = fw.Write(payload) + if err != nil { + return + } + + fw, err = w.CreateFormField("map") + if err != nil { + return + } + + err = json.NewEncoder(fw).Encode(uploadMap.UploadMap()) + if err != nil { + return + } + + for index, uploadVariable := range *uploadMap { + fw, err := w.CreateFormFile(strconv.Itoa(index), uploadVariable.upload.FileName) + if err != nil { + return b.Bytes(), w.FormDataContentType(), err + } + + _, err = io.Copy(fw, uploadVariable.upload.File) + if err != nil { + return b.Bytes(), w.FormDataContentType(), err + } + } + + err = w.Close() + if err != nil { + return + } + + return b.Bytes(), w.FormDataContentType(), nil +} diff --git a/file_test.go b/file_test.go new file mode 100644 index 0000000..593e723 --- /dev/null +++ b/file_test.go @@ -0,0 +1,96 @@ +package graphql + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/stretchr/testify/assert" + "io/ioutil" + "strings" + "testing" +) + +func TestExtractFiles(t *testing.T) { + + upload1 := Upload{nil, "file1"} + upload2 := Upload{nil, "file2"} + upload3 := Upload{nil, "file3"} + + input := &QueryInput{ + Variables: map[string]interface{}{ + "stringParam": "hello world", + "someFile": upload1, + "allFiles": []interface{}{ + upload2, + upload3, + }, + "integerParam": 10, + }, + } + + actual := extractFiles(input) + + expected := &UploadMap{} + expected.Add(upload1, "someFile") + expected.Add(upload2, "allFiles.0") + expected.Add(upload3, "allFiles.1") + + assert.Equal(t, expected, actual) +} + +func TestPrepareMultipart(t *testing.T) { + upload1 := Upload{ioutil.NopCloser(bytes.NewBufferString("File1Contents")), "file1"} + upload2 := Upload{ioutil.NopCloser(bytes.NewBufferString("File2Contents")), "file2"} + upload3 := Upload{ioutil.NopCloser(bytes.NewBufferString("File3Contents")), "file3"} + + uploadMap := &UploadMap{} + uploadMap.Add(upload1, "someFile") + uploadMap.Add(upload2, "allFiles.0") + uploadMap.Add(upload3, "allFiles.1") + + payload, _ := json.Marshal(map[string]interface{}{ + "query": "mutation TestFileUpload($someFile: Upload!,$allFiles: [Upload!]!) {upload(file: $someFile) uploadMulti(files: $allFiles)}", + "variables": map[string]interface{}{ + "someFile": nil, + "allFiles": []interface{}{nil, nil}, + }, + "operationName": "TestFileUpload", + }) + + body, contentType, err := prepareMultipart(payload, uploadMap) + + headerParts := strings.Split(contentType, "; boundary=") + rawBody := []string{ + "--%[1]s", + "Content-Disposition: form-data; name=\"operations\"", + "", + "{\"operationName\":\"TestFileUpload\",\"query\":\"mutation TestFileUpload($someFile: Upload!,$allFiles: [Upload!]!) {upload(file: $someFile) uploadMulti(files: $allFiles)}\",\"variables\":{\"allFiles\":[null,null],\"someFile\":null}}", + "--%[1]s", + "Content-Disposition: form-data; name=\"map\"", + "", + "{\"0\":[\"variables.someFile\"],\"1\":[\"variables.allFiles.0\"],\"2\":[\"variables.allFiles.1\"]}\n", + "--%[1]s", + "Content-Disposition: form-data; name=\"0\"; filename=\"file1\"", + "Content-Type: application/octet-stream", + "", + "File1Contents", + "--%[1]s", + "Content-Disposition: form-data; name=\"1\"; filename=\"file2\"", + "Content-Type: application/octet-stream", + "", + "File2Contents", + "--%[1]s", + "Content-Disposition: form-data; name=\"2\"; filename=\"file3\"", + "Content-Type: application/octet-stream", + "", + "File3Contents", + "--%[1]s--", + "", + } + + expected := fmt.Sprintf(strings.Join(rawBody, "\r\n"), headerParts[1]) + + assert.Equal(t, "multipart/form-data", headerParts[0]) + assert.Equal(t, expected, string(body)) + assert.Nil(t, err) +} diff --git a/introspection.go b/introspection.go index 69fbf04..91e6f70 100644 --- a/introspection.go +++ b/introspection.go @@ -8,7 +8,6 @@ import ( "github.com/vektah/gqlparser/v2/ast" ) - // IntrospectRemoteSchema is used to build a RemoteSchema by firing the introspection query // at a remote service and reconstructing the schema object from the response func IntrospectRemoteSchema(url string) (*RemoteSchema, error) { @@ -52,7 +51,7 @@ func IntrospectAPI(queryer Queryer) (*ast.Schema, error) { result := IntrospectionQueryResult{} input := &QueryInput{ - Query: IntrospectionQuery, + Query: IntrospectionQuery, OperationName: "IntrospectionQuery", } diff --git a/queryer.go b/queryer.go index ab481df..755d3c9 100755 --- a/queryer.go +++ b/queryer.go @@ -27,7 +27,7 @@ type QueryInput struct { Variables map[string]interface{} `json:"variables"` } -// String returns a guarenteed unique string that can be used to identify the input +// String returns a guaranteed unique string that can be used to identify the input func (i *QueryInput) String() string { // let's just marshal the input marshaled, err := json.Marshal(i) @@ -121,6 +121,24 @@ func (q *NetworkQueryer) SendQuery(ctx context.Context, payload []byte) ([]byte, acc := req.WithContext(ctx) acc.Header.Set("Content-Type", "application/json") + return q.sendRequest(acc) +} + +// SendMultipart is responsible for sending multipart request to the desingated URL +func (q *NetworkQueryer) SendMultipart(ctx context.Context, payload []byte, contentType string) ([]byte, error) { + // construct the initial request we will send to the client + req, err := http.NewRequest("POST", q.URL, bytes.NewBuffer(payload)) + if err != nil { + return nil, err + } + // add the current context to the request + acc := req.WithContext(ctx) + acc.Header.Set("Content-Type", contentType) + + return q.sendRequest(acc) +} + +func (q *NetworkQueryer) sendRequest(acc *http.Request) ([]byte, error) { // we could have any number of middlewares that we have to go through so for _, mware := range q.Middlewares { err := mware(acc) diff --git a/queryerMultiOp.go b/queryerMultiOp.go index f2d4d9b..4384b10 100755 --- a/queryerMultiOp.go +++ b/queryerMultiOp.go @@ -22,7 +22,7 @@ type MultiOpQueryer struct { loader *dataloader.Loader } -// NewMultiOpQueryer returns a MultiOpQueryer with the provided paramters +// NewMultiOpQueryer returns a MultiOpQueryer with the provided parameters func NewMultiOpQueryer(url string, interval time.Duration, maxBatchSize int) *MultiOpQueryer { queryer := &MultiOpQueryer{ MaxBatchSize: maxBatchSize, diff --git a/queryerNetwork.go b/queryerNetwork.go index d7f3c62..692f930 100644 --- a/queryerNetwork.go +++ b/queryerNetwork.go @@ -3,9 +3,8 @@ package graphql import ( "context" "encoding/json" - "net/http" - "github.com/mitchellh/mapstructure" + "net/http" ) // SingleRequestQueryer sends the query to a url and returns the response @@ -43,6 +42,9 @@ func (q *SingleRequestQueryer) URL() string { // Query sends the query to the designated url and returns the response. func (q *SingleRequestQueryer) Query(ctx context.Context, input *QueryInput, receiver interface{}) error { + // check if query contains attached files + uploadMap := extractFiles(input) + // the payload payload, err := json.Marshal(map[string]interface{}{ "query": input.Query, @@ -53,10 +55,24 @@ func (q *SingleRequestQueryer) Query(ctx context.Context, input *QueryInput, rec return err } - // send that query to the api and write the appropriate response to the receiver - response, err := q.queryer.SendQuery(ctx, payload) - if err != nil { - return err + var response []byte + if uploadMap.NotEmpty() { + body, contentType, err := prepareMultipart(payload, uploadMap) + + responseBody, err := q.queryer.SendMultipart(ctx, body, contentType) + if err != nil { + return err + } + + response = responseBody + } else { + // send that query to the api and write the appropriate response to the receiver + responseBody, err := q.queryer.SendQuery(ctx, payload) + if err != nil { + return err + } + + response = responseBody } result := map[string]interface{}{} diff --git a/queryer_test.go b/queryer_test.go index 0dd363a..7b574ce 100755 --- a/queryer_test.go +++ b/queryer_test.go @@ -392,7 +392,7 @@ func TestQueryerWithMiddlewares(t *testing.T) { return &http.Response{ StatusCode: http.StatusExpectationFailed, // Send response to be tested - Body: ioutil.NopCloser(bytes.NewBufferString("Did not recieve the right header")), + Body: ioutil.NopCloser(bytes.NewBufferString("Did not receive the right header")), // Must be set to non-nil value or it panics Header: make(http.Header), }