Skip to content

Commit

Permalink
feat: implemented persisted queries for GET methods with only SHA-256…
Browse files Browse the repository at this point in the history
… hash of query string (#2067)

### 📝 Description
According to the GraphQL APQ flow description, GET requests containing
only SHA-256 hash of the query should be checked in cache and respond
with PERSISTED_QUERY_NOT_FOUND error if request is not cached.
Both Ktor and Spring server implementations didn't handle this first
query without a query param.
I tried to implement the change without breaking existing behaviours, as
a query param is expected to take precedence over post body, for
example, as in one of the tests in RouteConfigurationIT.


### 🔗 Related Issues
#2065
  • Loading branch information
malaquf authored Jan 5, 2025
1 parent 681c70d commit f1a8603
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2022 Expedia, Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.expediagroup.graphql.examples.server.spring.query

import com.expediagroup.graphql.examples.server.spring.GRAPHQL_MEDIA_TYPE
import com.expediagroup.graphql.examples.server.spring.verifyData
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType.APPLICATION_JSON
import org.springframework.test.web.reactive.server.WebTestClient

@SpringBootTest(
properties = ["graphql.automaticPersistedQueries.enabled=true"]
)
@AutoConfigureWebTestClient
@TestInstance(PER_CLASS)
class APQQueryIT(@Autowired private val testClient: WebTestClient) {

@Test
fun `verify GET persisted query with hash only followed by POST with hash`() {
val query = "simpleDeprecatedQuery"

testClient.get()
.uri { builder ->
builder.path("/graphql")
.queryParam("extensions", "{extension}")
.build("""{"persistedQuery":{"version":1,"sha256Hash":"aee64e0a941589ff06b717d4930405f3eafb089e687bef6ece5719ea6a4e7f35"}}""")
}
.exchange()
.expectBody().json(
"""
{
errors: [
{
message: "PersistedQueryNotFound"
}
]
}
""".trimIndent()
)

val expectedData = "false"

testClient.post()
.uri { builder ->
builder.path("/graphql")
.queryParam("extensions", "{extension}")
.build("""{"persistedQuery":{"version":1,"sha256Hash":"aee64e0a941589ff06b717d4930405f3eafb089e687bef6ece5719ea6a4e7f35"}}""")
}
.accept(APPLICATION_JSON)
.contentType(GRAPHQL_MEDIA_TYPE)
.bodyValue("query { $query }")
.exchange()
.verifyData(query, expectedData)

testClient.get()
.uri { builder ->
builder.path("/graphql")
.queryParam("extensions", "{extension}")
.build("""{"persistedQuery":{"version":1,"sha256Hash":"aee64e0a941589ff06b717d4930405f3eafb089e687bef6ece5719ea6a4e7f35"}}""")
}
.exchange()
.verifyData(query, expectedData)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import java.io.IOException
internal const val REQUEST_PARAM_QUERY = "query"
internal const val REQUEST_PARAM_OPERATION_NAME = "operationName"
internal const val REQUEST_PARAM_VARIABLES = "variables"
internal const val REQUEST_PARAM_EXTENSIONS = "extensions"
internal const val REQUEST_PARAM_PERSISTED_QUERY = "persistedQuery"

/**
* GraphQL Ktor [ApplicationRequest] parser.
Expand All @@ -46,8 +48,12 @@ open class KtorGraphQLRequestParser(
else -> null
}

private fun parseGetRequest(request: ApplicationRequest): GraphQLServerRequest? {
val query = request.queryParameters[REQUEST_PARAM_QUERY] ?: throw IllegalStateException("Invalid HTTP request - GET request has to specify query parameter")
private fun parseGetRequest(request: ApplicationRequest): GraphQLServerRequest {
val extensions = request.queryParameters[REQUEST_PARAM_EXTENSIONS]
val query = request.queryParameters[REQUEST_PARAM_QUERY] ?: ""
check(query.isNotEmpty() || extensions?.contains(REQUEST_PARAM_PERSISTED_QUERY) == true) {
"Invalid HTTP request - GET request has to specify either query parameter or persisted query extension"
}
if (query.startsWith("mutation ") || query.startsWith("subscription ")) {
throw UnsupportedOperationException("Invalid GraphQL operation - only queries are supported for GET requests")
}
Expand All @@ -56,7 +62,15 @@ open class KtorGraphQLRequestParser(
val graphQLVariables: Map<String, Any>? = variables?.let {
mapper.readValue(it, mapTypeReference)
}
return GraphQLRequest(query = query, operationName = operationName, variables = graphQLVariables)
val extensionsMap: Map<String, Any>? = request.queryParameters[REQUEST_PARAM_EXTENSIONS]?.let {
mapper.readValue(it, mapTypeReference)
}
return GraphQLRequest(
query = query,
operationName = operationName,
variables = graphQLVariables,
extensions = extensionsMap
)
}

private suspend fun parsePostRequest(request: ApplicationRequest): GraphQLServerRequest? = try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,16 @@ class GraphQLPluginTest {
}
}

@Test
fun `server should return Method Not Allowed for Mutation GET requests with persisted query`() {
testApplication {
val response = client.get("/graphql") {
parameter("query", "mutation { foo }")
}
assertEquals(HttpStatusCode.MethodNotAllowed, response.status)
}
}

@Test
fun `server should return Bad Request for invalid GET requests`() {
testApplication {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class KtorGraphQLRequestParserTest {
fun `parseRequest should throw IllegalStateException if request method is GET without query`() = runTest {
val request = mockk<ApplicationRequest>(relaxed = true) {
every { queryParameters[REQUEST_PARAM_QUERY] } returns null
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null
every { local.method } returns HttpMethod.Get
}
assertFailsWith<IllegalStateException> {
Expand All @@ -60,6 +61,7 @@ class KtorGraphQLRequestParserTest {
every { queryParameters[REQUEST_PARAM_QUERY] } returns "{ foo }"
every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns null
every { queryParameters[REQUEST_PARAM_VARIABLES] } returns null
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null
every { local.method } returns HttpMethod.Get
}
val graphQLRequest = parser.parseRequest(serverRequest)
Expand All @@ -76,6 +78,7 @@ class KtorGraphQLRequestParserTest {
every { queryParameters[REQUEST_PARAM_QUERY] } returns "query MyFoo { foo }"
every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns "MyFoo"
every { queryParameters[REQUEST_PARAM_VARIABLES] } returns """{"a":1}"""
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null
every { local.method } returns HttpMethod.Get
}
val graphQLRequest = parser.parseRequest(serverRequest)
Expand All @@ -86,6 +89,27 @@ class KtorGraphQLRequestParserTest {
assertEquals(1, graphQLRequest.variables?.get("a"))
}

@Test
fun `parseRequest should return request if method is GET with hash only`() = runTest {
val serverRequest = mockk<ApplicationRequest>(relaxed = true) {
every { queryParameters[REQUEST_PARAM_QUERY] } returns null
every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns "MyFoo"
every { queryParameters[REQUEST_PARAM_VARIABLES] } returns """{"a":1}"""
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns """{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}"""
every { local.method } returns HttpMethod.Get
}
val graphQLRequest = parser.parseRequest(serverRequest)
assertNotNull(graphQLRequest)
assertTrue(graphQLRequest is GraphQLRequest)
assertEquals("", graphQLRequest.query)
assertEquals("MyFoo", graphQLRequest.operationName)
assertEquals(1, graphQLRequest.variables?.get("a"))
assertEquals(
mapOf("version" to 1, "sha256Hash" to "some-hash"),
graphQLRequest.extensions?.get("persistedQuery")
)
}

@Test
fun `parseRequest should return request if method is POST`() = runTest {
val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ import org.springframework.web.reactive.function.server.ServerRequest
import org.springframework.web.reactive.function.server.awaitBody
import org.springframework.web.reactive.function.server.bodyToMono
import org.springframework.web.server.ResponseStatusException
import kotlin.jvm.optionals.getOrNull

internal const val REQUEST_PARAM_QUERY = "query"
internal const val REQUEST_PARAM_OPERATION_NAME = "operationName"
internal const val REQUEST_PARAM_VARIABLES = "variables"
internal const val REQUEST_PARAM_EXTENSIONS = "extensions"
internal const val REQUEST_PARAM_PERSISTED_QUERY = "persistedQuery"
internal val graphQLMediaType = MediaType("application", "graphql")

open class SpringGraphQLRequestParser(
Expand All @@ -43,20 +46,33 @@ open class SpringGraphQLRequestParser(
private val mapTypeReference: MapType = TypeFactory.defaultInstance().constructMapType(HashMap::class.java, String::class.java, Any::class.java)

override suspend fun parseRequest(request: ServerRequest): GraphQLServerRequest? = when {
request.queryParam(REQUEST_PARAM_QUERY).isPresent -> { getRequestFromGet(request) }
request.method().equals(HttpMethod.POST) -> { getRequestFromPost(request) }
request.isGetPersistedQuery() || request.hasQueryParam() -> { getRequestFromGet(request) }
request.method() == HttpMethod.POST -> getRequestFromPost(request)
else -> null
}

private fun ServerRequest.hasQueryParam() = queryParam(REQUEST_PARAM_QUERY).isPresent

private fun ServerRequest.isGetPersistedQuery() =
method() == HttpMethod.GET && queryParam(REQUEST_PARAM_EXTENSIONS).getOrNull()?.contains(REQUEST_PARAM_PERSISTED_QUERY) == true

private fun getRequestFromGet(serverRequest: ServerRequest): GraphQLServerRequest {
val query = serverRequest.queryParam(REQUEST_PARAM_QUERY).get()
val query = serverRequest.queryParam(REQUEST_PARAM_QUERY).orElse("")
val operationName: String? = serverRequest.queryParam(REQUEST_PARAM_OPERATION_NAME).orElseGet { null }
val variables: String? = serverRequest.queryParam(REQUEST_PARAM_VARIABLES).orElseGet { null }
val graphQLVariables: Map<String, Any>? = variables?.let {
objectMapper.readValue(it, mapTypeReference)
}
val extensions: Map<String, Any>? = serverRequest.queryParam(REQUEST_PARAM_EXTENSIONS).takeIf { it.isPresent }?.get()?.let {
objectMapper.readValue(it, mapTypeReference)
}

return GraphQLRequest(query = query, operationName = operationName, variables = graphQLVariables)
return GraphQLRequest(
query = query,
operationName = operationName,
variables = graphQLVariables,
extensions = extensions
)
}

private suspend fun getRequestFromPost(serverRequest: ServerRequest): GraphQLServerRequest? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runBlockingTest
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.springframework.http.HttpHeaders
Expand All @@ -44,18 +43,20 @@ class SpringGraphQLRequestParserTest {
private val parser = SpringGraphQLRequestParser(objectMapper)

@Test
fun `parseRequest should return null if request method is not valid`() = runBlockingTest {
fun `parseRequest should return null if request method is not valid`() = runTest {
val request = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { method() } returns HttpMethod.PUT
}
assertNull(parser.parseRequest(request))
}

@Test
fun `parseRequest should return null if request method is GET without query`() = runBlockingTest {
fun `parseRequest should return null if request method is GET without query`() = runTest {
val request = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { method() } returns HttpMethod.GET
}
assertNull(parser.parseRequest(request))
Expand All @@ -65,6 +66,7 @@ class SpringGraphQLRequestParserTest {
fun `parseRequest should return request if method is GET with simple query`() = runTest {
val serverRequest = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.of("{ foo }")
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.empty()
every { method() } returns HttpMethod.GET
Expand All @@ -82,6 +84,7 @@ class SpringGraphQLRequestParserTest {
val serverRequest = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.of("query MyFoo { foo }")
every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.of("MyFoo")
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.of("""{ "a": 1 }""")
every { method() } returns HttpMethod.GET
}
Expand All @@ -93,6 +96,54 @@ class SpringGraphQLRequestParserTest {
assertEquals(1, graphQLRequest.variables?.get("a"))
}

@Test
fun `parseRequest should return request if method is GET with hash only`() = runTest {
val serverRequest = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.of("""{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}""")
every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.empty()
every { method() } returns HttpMethod.GET
}
val graphQLRequest = parser.parseRequest(serverRequest)
assertNotNull(graphQLRequest)
assertTrue(graphQLRequest is GraphQLRequest)
assertEquals("", graphQLRequest.query)
assertNull(graphQLRequest.operationName)
assertNull(graphQLRequest.variables)
assertEquals(
mapOf("version" to 1, "sha256Hash" to "some-hash"),
graphQLRequest.extensions?.get("persistedQuery")
)
}

@Test
fun `parseRequest should return request if method is POST with content-type json and persisted query extension`() = runTest {
val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1))
val serverRequest = MockServerRequest.builder()
.method(HttpMethod.POST)
.queryParam(
REQUEST_PARAM_EXTENSIONS,
"""
{
"persistedQuery": {
"version": 1,
"sha256Hash": "some-hash"
}
}
""".trimIndent()
)
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.body(Mono.justOrEmpty(mockRequest))

val graphQLRequest = parser.parseRequest(serverRequest)
assertNotNull(graphQLRequest)
assertTrue(graphQLRequest is GraphQLRequest)
assertEquals("query MyFoo { foo }", graphQLRequest.query)
assertEquals("MyFoo", graphQLRequest.operationName)
assertEquals(1, graphQLRequest.variables?.get("a"))
}

@Test
fun `parseRequest should return request if method is POST with no content-type`() = runTest {
val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1))
Expand Down

0 comments on commit f1a8603

Please sign in to comment.