diff --git a/examples/server/spring-server/src/test/kotlin/com/expediagroup/graphql/examples/server/spring/query/APQQueryIT.kt b/examples/server/spring-server/src/test/kotlin/com/expediagroup/graphql/examples/server/spring/query/APQQueryIT.kt new file mode 100644 index 0000000000..847f0784d0 --- /dev/null +++ b/examples/server/spring-server/src/test/kotlin/com/expediagroup/graphql/examples/server/spring/query/APQQueryIT.kt @@ -0,0 +1,83 @@ +/* + * Copyright 2022 Expedia, Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.expediagroup.graphql.examples.server.spring.query + +import com.expediagroup.graphql.examples.server.spring.GRAPHQL_MEDIA_TYPE +import com.expediagroup.graphql.examples.server.spring.verifyData +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.http.MediaType.APPLICATION_JSON +import org.springframework.test.web.reactive.server.WebTestClient + +@SpringBootTest( + properties = ["graphql.automaticPersistedQueries.enabled=true"] +) +@AutoConfigureWebTestClient +@TestInstance(PER_CLASS) +class APQQueryIT(@Autowired private val testClient: WebTestClient) { + + @Test + fun `verify GET persisted query with hash only followed by POST with hash`() { + val query = "simpleDeprecatedQuery" + + testClient.get() + .uri { builder -> + builder.path("/graphql") + .queryParam("extensions", "{extension}") + .build("""{"persistedQuery":{"version":1,"sha256Hash":"aee64e0a941589ff06b717d4930405f3eafb089e687bef6ece5719ea6a4e7f35"}}""") + } + .exchange() + .expectBody().json( + """ + { + errors: [ + { + message: "PersistedQueryNotFound" + } + ] + } + """.trimIndent() + ) + + val expectedData = "false" + + testClient.post() + .uri { builder -> + builder.path("/graphql") + .queryParam("extensions", "{extension}") + .build("""{"persistedQuery":{"version":1,"sha256Hash":"aee64e0a941589ff06b717d4930405f3eafb089e687bef6ece5719ea6a4e7f35"}}""") + } + .accept(APPLICATION_JSON) + .contentType(GRAPHQL_MEDIA_TYPE) + .bodyValue("query { $query }") + .exchange() + .verifyData(query, expectedData) + + testClient.get() + .uri { builder -> + builder.path("/graphql") + .queryParam("extensions", "{extension}") + .build("""{"persistedQuery":{"version":1,"sha256Hash":"aee64e0a941589ff06b717d4930405f3eafb089e687bef6ece5719ea6a4e7f35"}}""") + } + .exchange() + .verifyData(query, expectedData) + } +} diff --git a/servers/graphql-kotlin-ktor-server/src/main/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParser.kt b/servers/graphql-kotlin-ktor-server/src/main/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParser.kt index 006889e11d..7c055d1bdd 100644 --- a/servers/graphql-kotlin-ktor-server/src/main/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParser.kt +++ b/servers/graphql-kotlin-ktor-server/src/main/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParser.kt @@ -30,6 +30,8 @@ import java.io.IOException internal const val REQUEST_PARAM_QUERY = "query" internal const val REQUEST_PARAM_OPERATION_NAME = "operationName" internal const val REQUEST_PARAM_VARIABLES = "variables" +internal const val REQUEST_PARAM_EXTENSIONS = "extensions" +internal const val REQUEST_PARAM_PERSISTED_QUERY = "persistedQuery" /** * GraphQL Ktor [ApplicationRequest] parser. @@ -46,8 +48,12 @@ open class KtorGraphQLRequestParser( else -> null } - private fun parseGetRequest(request: ApplicationRequest): GraphQLServerRequest? { - val query = request.queryParameters[REQUEST_PARAM_QUERY] ?: throw IllegalStateException("Invalid HTTP request - GET request has to specify query parameter") + private fun parseGetRequest(request: ApplicationRequest): GraphQLServerRequest { + val extensions = request.queryParameters[REQUEST_PARAM_EXTENSIONS] + val query = request.queryParameters[REQUEST_PARAM_QUERY] ?: "" + check(query.isNotEmpty() || extensions?.contains(REQUEST_PARAM_PERSISTED_QUERY) == true) { + "Invalid HTTP request - GET request has to specify either query parameter or persisted query extension" + } if (query.startsWith("mutation ") || query.startsWith("subscription ")) { throw UnsupportedOperationException("Invalid GraphQL operation - only queries are supported for GET requests") } @@ -56,7 +62,15 @@ open class KtorGraphQLRequestParser( val graphQLVariables: Map? = variables?.let { mapper.readValue(it, mapTypeReference) } - return GraphQLRequest(query = query, operationName = operationName, variables = graphQLVariables) + val extensionsMap: Map? = request.queryParameters[REQUEST_PARAM_EXTENSIONS]?.let { + mapper.readValue(it, mapTypeReference) + } + return GraphQLRequest( + query = query, + operationName = operationName, + variables = graphQLVariables, + extensions = extensionsMap + ) } private suspend fun parsePostRequest(request: ApplicationRequest): GraphQLServerRequest? = try { diff --git a/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/GraphQLPluginTest.kt b/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/GraphQLPluginTest.kt index 510bb689ae..77602c42f4 100644 --- a/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/GraphQLPluginTest.kt +++ b/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/GraphQLPluginTest.kt @@ -134,6 +134,16 @@ class GraphQLPluginTest { } } + @Test + fun `server should return Method Not Allowed for Mutation GET requests with persisted query`() { + testApplication { + val response = client.get("/graphql") { + parameter("query", "mutation { foo }") + } + assertEquals(HttpStatusCode.MethodNotAllowed, response.status) + } + } + @Test fun `server should return Bad Request for invalid GET requests`() { testApplication { diff --git a/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParserTest.kt b/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParserTest.kt index 27b3dca7dc..5e6a6c86b1 100644 --- a/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParserTest.kt +++ b/servers/graphql-kotlin-ktor-server/src/test/kotlin/com/expediagroup/graphql/server/ktor/KtorGraphQLRequestParserTest.kt @@ -35,6 +35,7 @@ class KtorGraphQLRequestParserTest { fun `parseRequest should throw IllegalStateException if request method is GET without query`() = runTest { val request = mockk(relaxed = true) { every { queryParameters[REQUEST_PARAM_QUERY] } returns null + every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null every { local.method } returns HttpMethod.Get } assertFailsWith { @@ -60,6 +61,7 @@ class KtorGraphQLRequestParserTest { every { queryParameters[REQUEST_PARAM_QUERY] } returns "{ foo }" every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns null every { queryParameters[REQUEST_PARAM_VARIABLES] } returns null + every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null every { local.method } returns HttpMethod.Get } val graphQLRequest = parser.parseRequest(serverRequest) @@ -76,6 +78,7 @@ class KtorGraphQLRequestParserTest { every { queryParameters[REQUEST_PARAM_QUERY] } returns "query MyFoo { foo }" every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns "MyFoo" every { queryParameters[REQUEST_PARAM_VARIABLES] } returns """{"a":1}""" + every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null every { local.method } returns HttpMethod.Get } val graphQLRequest = parser.parseRequest(serverRequest) @@ -86,6 +89,27 @@ class KtorGraphQLRequestParserTest { assertEquals(1, graphQLRequest.variables?.get("a")) } + @Test + fun `parseRequest should return request if method is GET with hash only`() = runTest { + val serverRequest = mockk(relaxed = true) { + every { queryParameters[REQUEST_PARAM_QUERY] } returns null + every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns "MyFoo" + every { queryParameters[REQUEST_PARAM_VARIABLES] } returns """{"a":1}""" + every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns """{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}""" + every { local.method } returns HttpMethod.Get + } + val graphQLRequest = parser.parseRequest(serverRequest) + assertNotNull(graphQLRequest) + assertTrue(graphQLRequest is GraphQLRequest) + assertEquals("", graphQLRequest.query) + assertEquals("MyFoo", graphQLRequest.operationName) + assertEquals(1, graphQLRequest.variables?.get("a")) + assertEquals( + mapOf("version" to 1, "sha256Hash" to "some-hash"), + graphQLRequest.extensions?.get("persistedQuery") + ) + } + @Test fun `parseRequest should return request if method is POST`() = runTest { val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1)) diff --git a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParser.kt b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParser.kt index 89db7a121c..a89c4f8036 100644 --- a/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParser.kt +++ b/servers/graphql-kotlin-spring-server/src/main/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParser.kt @@ -30,10 +30,13 @@ import org.springframework.web.reactive.function.server.ServerRequest import org.springframework.web.reactive.function.server.awaitBody import org.springframework.web.reactive.function.server.bodyToMono import org.springframework.web.server.ResponseStatusException +import kotlin.jvm.optionals.getOrNull internal const val REQUEST_PARAM_QUERY = "query" internal const val REQUEST_PARAM_OPERATION_NAME = "operationName" internal const val REQUEST_PARAM_VARIABLES = "variables" +internal const val REQUEST_PARAM_EXTENSIONS = "extensions" +internal const val REQUEST_PARAM_PERSISTED_QUERY = "persistedQuery" internal val graphQLMediaType = MediaType("application", "graphql") open class SpringGraphQLRequestParser( @@ -43,20 +46,33 @@ open class SpringGraphQLRequestParser( private val mapTypeReference: MapType = TypeFactory.defaultInstance().constructMapType(HashMap::class.java, String::class.java, Any::class.java) override suspend fun parseRequest(request: ServerRequest): GraphQLServerRequest? = when { - request.queryParam(REQUEST_PARAM_QUERY).isPresent -> { getRequestFromGet(request) } - request.method().equals(HttpMethod.POST) -> { getRequestFromPost(request) } + request.isGetPersistedQuery() || request.hasQueryParam() -> { getRequestFromGet(request) } + request.method() == HttpMethod.POST -> getRequestFromPost(request) else -> null } + private fun ServerRequest.hasQueryParam() = queryParam(REQUEST_PARAM_QUERY).isPresent + + private fun ServerRequest.isGetPersistedQuery() = + method() == HttpMethod.GET && queryParam(REQUEST_PARAM_EXTENSIONS).getOrNull()?.contains(REQUEST_PARAM_PERSISTED_QUERY) == true + private fun getRequestFromGet(serverRequest: ServerRequest): GraphQLServerRequest { - val query = serverRequest.queryParam(REQUEST_PARAM_QUERY).get() + val query = serverRequest.queryParam(REQUEST_PARAM_QUERY).orElse("") val operationName: String? = serverRequest.queryParam(REQUEST_PARAM_OPERATION_NAME).orElseGet { null } val variables: String? = serverRequest.queryParam(REQUEST_PARAM_VARIABLES).orElseGet { null } val graphQLVariables: Map? = variables?.let { objectMapper.readValue(it, mapTypeReference) } + val extensions: Map? = serverRequest.queryParam(REQUEST_PARAM_EXTENSIONS).takeIf { it.isPresent }?.get()?.let { + objectMapper.readValue(it, mapTypeReference) + } - return GraphQLRequest(query = query, operationName = operationName, variables = graphQLVariables) + return GraphQLRequest( + query = query, + operationName = operationName, + variables = graphQLVariables, + extensions = extensions + ) } private suspend fun getRequestFromPost(serverRequest: ServerRequest): GraphQLServerRequest? { diff --git a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParserTest.kt b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParserTest.kt index 87a16f3fd3..f427746058 100644 --- a/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParserTest.kt +++ b/servers/graphql-kotlin-spring-server/src/test/kotlin/com/expediagroup/graphql/server/spring/execution/SpringGraphQLRequestParserTest.kt @@ -22,7 +22,6 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import io.mockk.every import io.mockk.mockk import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.test.runBlockingTest import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.springframework.http.HttpHeaders @@ -44,18 +43,20 @@ class SpringGraphQLRequestParserTest { private val parser = SpringGraphQLRequestParser(objectMapper) @Test - fun `parseRequest should return null if request method is not valid`() = runBlockingTest { + fun `parseRequest should return null if request method is not valid`() = runTest { val request = mockk(relaxed = true) { every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty() + every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty() every { method() } returns HttpMethod.PUT } assertNull(parser.parseRequest(request)) } @Test - fun `parseRequest should return null if request method is GET without query`() = runBlockingTest { + fun `parseRequest should return null if request method is GET without query`() = runTest { val request = mockk(relaxed = true) { every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty() + every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty() every { method() } returns HttpMethod.GET } assertNull(parser.parseRequest(request)) @@ -65,6 +66,7 @@ class SpringGraphQLRequestParserTest { fun `parseRequest should return request if method is GET with simple query`() = runTest { val serverRequest = mockk(relaxed = true) { every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.of("{ foo }") + every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty() every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.empty() every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.empty() every { method() } returns HttpMethod.GET @@ -82,6 +84,7 @@ class SpringGraphQLRequestParserTest { val serverRequest = mockk(relaxed = true) { every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.of("query MyFoo { foo }") every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.of("MyFoo") + every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty() every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.of("""{ "a": 1 }""") every { method() } returns HttpMethod.GET } @@ -93,6 +96,54 @@ class SpringGraphQLRequestParserTest { assertEquals(1, graphQLRequest.variables?.get("a")) } + @Test + fun `parseRequest should return request if method is GET with hash only`() = runTest { + val serverRequest = mockk(relaxed = true) { + every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty() + every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.of("""{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}""") + every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.empty() + every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.empty() + every { method() } returns HttpMethod.GET + } + val graphQLRequest = parser.parseRequest(serverRequest) + assertNotNull(graphQLRequest) + assertTrue(graphQLRequest is GraphQLRequest) + assertEquals("", graphQLRequest.query) + assertNull(graphQLRequest.operationName) + assertNull(graphQLRequest.variables) + assertEquals( + mapOf("version" to 1, "sha256Hash" to "some-hash"), + graphQLRequest.extensions?.get("persistedQuery") + ) + } + + @Test + fun `parseRequest should return request if method is POST with content-type json and persisted query extension`() = runTest { + val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1)) + val serverRequest = MockServerRequest.builder() + .method(HttpMethod.POST) + .queryParam( + REQUEST_PARAM_EXTENSIONS, + """ + { + "persistedQuery": { + "version": 1, + "sha256Hash": "some-hash" + } + } + """.trimIndent() + ) + .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .body(Mono.justOrEmpty(mockRequest)) + + val graphQLRequest = parser.parseRequest(serverRequest) + assertNotNull(graphQLRequest) + assertTrue(graphQLRequest is GraphQLRequest) + assertEquals("query MyFoo { foo }", graphQLRequest.query) + assertEquals("MyFoo", graphQLRequest.operationName) + assertEquals(1, graphQLRequest.variables?.get("a")) + } + @Test fun `parseRequest should return request if method is POST with no content-type`() = runTest { val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1))