Skip to content

Commit

Permalink
feat: Allow provider state generator to fall back to the provider sta…
Browse files Browse the repository at this point in the history
…te parameters
  • Loading branch information
rholshausen committed Jul 11, 2024
1 parent ea7037f commit 7550a6d
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ data class ProviderStateGenerator @JvmOverloads constructor (
}

override fun generate(context: MutableMap<String, Any>, exampleValue: Any?): Any? {
logger.debug { "Applying Generator $this" }
logger.debug { "Applying Generator $this with context ${context["providerState"]}" }
return when (val providerState = context["providerState"]) {
is Map<*, *> -> {
val map = providerState as Map<String, Any>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class PactVerificationStateChangeExtension(

val providerStateContext = mutableMapOf<String, Any?>()
providerStates.forEach { state ->
providerStateContext.putAll(state.params)
val stateChangeMethods = findStateChangeMethods(context.requiredTestInstance,
testContext.stateChangeHandlers, state)
if (stateChangeMethods.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ class PactVerificationStateChangeExtensionSpec extends Specification {
}

@State(['Test 2'])
void state3(Map params) {
Map state3(Map params) {
state3Called = params
[a: 100, b: '200']
}
}

Expand Down Expand Up @@ -115,6 +116,30 @@ class PactVerificationStateChangeExtensionSpec extends Specification {
!testInstance.state2TeardownCalled
}

def 'returns any values returned from the state callback'() {
given:
def state = new ProviderState('Test 2', [a: 'A', b: 'B'])

when:
def result = verificationExtension.invokeStateChangeMethods(testContext, pactContext, [state],
StateChangeAction.SETUP)

then:
result == [a: 100, b: '200']
}

def 'falls back to the parameters of the provider state'() {
given:
def state = new ProviderState('Test 2', [a: 'A', c: 'C'])

when:
def result = verificationExtension.invokeStateChangeMethods(testContext, pactContext, [state],
StateChangeAction.SETUP)

then:
result == [a: 100, b: '200', c: 'C']
}

@SuppressWarnings('ClosureAsLastMethodParameter')
def 'marks the test as failed if the provider state callback fails'() {
given:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package au.com.dius.pact.provider.junit5;

import au.com.dius.pact.provider.junitsupport.Provider;
import au.com.dius.pact.provider.junitsupport.State;
import au.com.dius.pact.provider.junitsupport.loader.PactFolder;
import com.github.tomakehurst.wiremock.WireMockServer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ru.lanwen.wiremock.ext.WiremockResolver;
import ru.lanwen.wiremock.ext.WiremockUriResolver;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.Collections;
import java.util.Map;

import static com.github.tomakehurst.wiremock.client.WireMock.*;

@Provider("ProviderStateParametersInjected")
@PactFolder("pacts")
@ExtendWith({
WiremockResolver.class,
WiremockUriResolver.class
})
public class ProviderStateParametersInjectedTest {
private static final Logger LOGGER = LoggerFactory.getLogger(ProviderStateParametersInjectedTest.class);

@TestTemplate
@ExtendWith(PactVerificationInvocationContextProvider.class)
void testTemplate(PactVerificationContext context) {
context.verifyInteraction();
}

@BeforeEach
void before(PactVerificationContext context,
@WiremockResolver.Wiremock WireMockServer server,
@WiremockUriResolver.WiremockUri String uri) throws MalformedURLException {
context.setTarget(HttpTestTarget.fromUrl(new URL(uri)));

server.stubFor(
get(urlPathEqualTo("/api/hello/John"))
.willReturn(aResponse()
.withStatus(200)
.withHeader("content-type", "application/json")
.withBody("{\"name\": \"John\"}")
)
);
}

@State("User exists")
public Map<String, Object> defaultState(Map<String, Object> params) {
LOGGER.debug("Provider state params = " + params);
return Collections.emptyMap();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"consumer": {
"name": "SomeConsumer"
},
"interactions": [
{
"description": "Hello John",
"providerStates": [
{
"name": "User exists",
"params": {
"name": "John"
}
}
],
"request": {
"generators": {
"path": {
"dataType": "STRING",
"expression": "/api/hello/${name}",
"type": "ProviderState"
}
},
"method": "GET",
"path": "/api/hello/James"
},
"response": {
"body": {
"name": "John"
},
"headers": {
"Content-Type": "application/json"
},
"status": 200
}
}
],
"metadata": {
"pact-jvm": {
"version": "4.6.7"
},
"pactSpecification": {
"version": "3.0.0"
}
},
"provider": {
"name": "ProviderStateParametersInjected"
}
}
30 changes: 24 additions & 6 deletions provider/src/main/kotlin/au/com/dius/pact/provider/StateChange.kt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ object DefaultStateChange : StateChange, KLogging() {
} else {
val result = stateChange(verifier, ProviderState(""), provider, consumer, true, providerClient)
logger.debug { "State Change: \"\" -> $result" }
stateChangeResult.mapEither({
result.mapEither({
stateChangeResult.unwrap().plus(it)
}, {
failures[message] = it.message.toString()
Expand All @@ -101,7 +101,7 @@ object DefaultStateChange : StateChange, KLogging() {
return StateChangeResult(stateChangeResult, message)
}

@Suppress("TooGenericExceptionCaught", "ReturnCount")
@Suppress("TooGenericExceptionCaught", "ReturnCount", "ComplexMethod", "LongParameterList")
override fun stateChange(
verifier: IProviderVerifier,
state: ProviderState,
Expand All @@ -112,7 +112,10 @@ object DefaultStateChange : StateChange, KLogging() {
): Result<Map<String, Any?>, Exception> {
verifier.reportStateForInteraction(state.name.toString(), provider, consumer, isSetup)

logger.debug { "stateChangeHandler: consumer.stateChange=${consumer.stateChange}, provider.stateChangeUrl=${provider.stateChangeUrl}" }
logger.debug {
"stateChangeHandler: consumer.stateChange=${consumer.stateChange}, " +
"provider.stateChangeUrl=${provider.stateChangeUrl}"
}
try {
var stateChangeHandler = consumer.stateChange
var stateChangeUsesBody = consumer.stateChangeUsesBody
Expand All @@ -135,12 +138,26 @@ object DefaultStateChange : StateChange, KLogging() {
}
logger.debug { "Invoked state change closure -> $result" }
if (result !is URL) {
return Result.Ok(if (result is Map<*, *>) result as Map<String, Any> else emptyMap())
val map = if (result is Map<*, *>) {
state.params + (result as Map<String, Any?>)
} else {
state.params
}
return Result.Ok(map)
}
stateChangeHandler = result
}
return executeHttpStateChangeRequest(verifier, stateChangeHandler, stateChangeUsesBody, state, provider, isSetup,
providerClient)

val stateChangeResult = executeHttpStateChangeRequest(
verifier, stateChangeHandler, stateChangeUsesBody, state, provider, isSetup,
providerClient
)
return when (stateChangeResult) {
is Result.Ok -> {
Result.Ok(state.params + stateChangeResult.value)
}
is Result.Err -> stateChangeResult
}
} catch (e: Exception) {
verifier.reportStateChangeFailed(state, e, isSetup)
return Result.Err(e)
Expand Down Expand Up @@ -192,6 +209,7 @@ object DefaultStateChange : StateChange, KLogging() {
}
} ?: Result.Ok(emptyMap())
} catch (ex: URISyntaxException) {
logger.error(ex) { "State change request is not valid" }
verifier.reporters.forEach {
it.warnStateChangeIgnoredDueToInvalidUrl(state.name.toString(), provider, isSetup, stateChangeHandler)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import au.com.dius.pact.core.model.Interaction
import au.com.dius.pact.core.model.OptionalBody
import au.com.dius.pact.core.model.ProviderState
import au.com.dius.pact.core.support.Result
import groovy.json.JsonOutput
import org.apache.hc.core5.http.ClassicHttpResponse
import org.apache.hc.core5.http.HttpEntity
import spock.lang.Specification
Expand Down Expand Up @@ -123,6 +124,41 @@ class StateChangeSpec extends Specification {
closureArgs == [state]
}

def 'if the state change is a closure, returns the result from the closure if it is a Map'() {
given:
def value = [
a: 100,
b: '200'
]
consumerMap.stateChange = { arg -> value }

when:
def result = DefaultStateChange.INSTANCE.stateChange(providerVerifier, state, providerInfo, consumer(), true,
mockProviderClient)

then:
result instanceof Result.Ok
result.value == [a: 100, b: '200']
}

def 'if the state change is a closure, falls back to the state change parameters for state change results'() {
given:
def value = [
a: 100,
b: '200'
]
consumerMap.stateChange = { arg -> value }
state = new ProviderState('there is a state', [a: 1, b: 2, c: 'test'])

when:
def result = DefaultStateChange.INSTANCE.stateChange(providerVerifier, state, providerInfo, consumer(), true,
mockProviderClient)

then:
result instanceof Result.Ok
result.value == [a: 100, b: '200', c: 'test']
}

def 'if the state change is a string that is not handled by the other conditions, does nothing'() {
given:
consumerMap.stateChange = 'blah blah blah'
Expand Down Expand Up @@ -157,4 +193,59 @@ class StateChangeSpec extends Specification {
[new URI('http://localhost:2000/hello'), stateTwo, true, true, false]
]
}

def 'returns the result of the state change call if the result can be converted to a Map'() {
given:
consumerMap.stateChange = 'http://localhost:2000/state-change'
def stateResult = JsonOutput.toJson([
a: 100,
b: '200'
])
def entity = [
getContentType: { 'application/json' },
getContentLength: { stateResult.bytes.length as long },
getContent: { new ByteArrayInputStream(stateResult.bytes) }
] as HttpEntity
stateChangeResponse = [
getEntity: { entity },
getCode: { 200 },
close: { }
] as ClassicHttpResponse

when:
def result = DefaultStateChange.INSTANCE.stateChange(providerVerifier, state, providerInfo, consumer(), true,
mockProviderClient)

then:
result instanceof Result.Ok
result.value == [a: 100, b: '200']
}

def 'falls back to the state change parameters for state change results'() {
given:
consumerMap.stateChange = 'http://localhost:2000/state-change'
def stateResult = JsonOutput.toJson([
a: 100,
b: '200'
])
def entity = [
getContentType: { 'application/json' },
getContentLength: { stateResult.bytes.length as long },
getContent: { new ByteArrayInputStream(stateResult.bytes) }
] as HttpEntity
stateChangeResponse = [
getEntity: { entity },
getCode: { 200 },
close: { }
] as ClassicHttpResponse
state = new ProviderState('there is a state', [a: 1, b: 2, c: 'test'])

when:
def result = DefaultStateChange.INSTANCE.stateChange(providerVerifier, state, providerInfo, consumer(), true,
mockProviderClient)

then:
result instanceof Result.Ok
result.value == [a: 100, b: '200', c: 'test']
}
}

0 comments on commit 7550a6d

Please sign in to comment.