diff --git a/src/main/kotlin/com/tinder/StateMachine.kt b/src/main/kotlin/com/tinder/StateMachine.kt index 4b785c7..6a316af 100644 --- a/src/main/kotlin/com/tinder/StateMachine.kt +++ b/src/main/kotlin/com/tinder/StateMachine.kt @@ -51,7 +51,14 @@ class StateMachine private construc private fun STATE.getDefinition() = graph.stateDefinitions .filter { it.key.matches(this) } .map { it.value } - .firstOrNull() ?: error("Missing definition for state ${this.javaClass.simpleName}!") + .also { if (it.isEmpty()) error("Missing definition for state ${this.javaClass.simpleName}!") } + .fold(Graph.State()) { acc, state -> + acc.apply { + onEnterListeners.addAll(state.onEnterListeners) + onExitListeners.addAll(state.onExitListeners) + transitions.putAll(state.transitions) + } + } private fun STATE.notifyOnEnter(cause: EVENT) { getDefinition().onEnterListeners.forEach { it(this, cause) } diff --git a/src/test/kotlin/com/tinder/StateMachineTest.kt b/src/test/kotlin/com/tinder/StateMachineTest.kt index 5d6e579..67199a3 100644 --- a/src/test/kotlin/com/tinder/StateMachineTest.kt +++ b/src/test/kotlin/com/tinder/StateMachineTest.kt @@ -719,6 +719,83 @@ internal class StateMachineTest { } } + class WithSplitStateDefinition { + private val firstDefinitionOnEnterListener = mock Unit>() + private val secondDefinitionOnEnterListener = mock Unit>() + private val firstDefinitionOnExitListener = mock Unit>() + private val secondDefinitionOnExitListener = mock Unit>() + + private val stateMachine = StateMachine.create { + initialState(STATE_A) + + state(STATE_A) { + on(EVENT_1) { + transitionTo(STATE_B) + } + + onExit(firstDefinitionOnExitListener) + } + + state(STATE_A) { + on(EVENT_2) { + transitionTo(STATE_B) + } + + onExit(secondDefinitionOnExitListener) + } + + state(STATE_B) { + onEnter(firstDefinitionOnEnterListener) + } + + state(STATE_B) { + onEnter(secondDefinitionOnEnterListener) + } + } + + @Test + fun transition_givenFirstDefinitionEvent_shouldReturnValidTransition() { + // When + val transition = stateMachine.transition(EVENT_1) + + // Then + assertThat(transition).isEqualTo( + StateMachine.Transition.Valid(STATE_A, EVENT_1, STATE_B, null) + ) + } + + @Test + fun transition_givenSecondDefinitionEvent_shouldReturnValidTransition() { + // When + val transition = stateMachine.transition(EVENT_2) + + // Then + assertThat(transition).isEqualTo( + StateMachine.Transition.Valid(STATE_A, EVENT_2, STATE_B, null) + ) + } + + @Test + fun transition_givenValidEvent_shouldTriggerOnExitListeners() { + // When + stateMachine.transition(EVENT_1) + + // Then + then(firstDefinitionOnExitListener).should().invoke(STATE_A, EVENT_1) + then(secondDefinitionOnExitListener).should().invoke(STATE_A, EVENT_1) + } + + @Test + fun transition_givenValidEvent_shouldTriggerOnEnterListeners() { + // When + stateMachine.transition(EVENT_1) + + // Then + then(firstDefinitionOnEnterListener).should().invoke(STATE_B, EVENT_1) + then(secondDefinitionOnEnterListener).should().invoke(STATE_B, EVENT_1) + } + } + private companion object { private const val STATE_A = "a" private const val STATE_B = "b" @@ -733,5 +810,4 @@ internal class StateMachineTest { private const val SIDE_EFFECT_1 = "alpha" } } - }