Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix saveable state being restored when using reset root navigation events #1354

Merged
merged 2 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package com.slack.circuit.backstack

internal actual val defaultBackStackRecordLocalProviders:
List<BackStackRecordLocalProvider<BackStack.Record>> =
listOf(SaveableStateRegistryBackStackRecordLocalProvider, ViewModelBackStackRecordLocalProvider)
listOf(ViewModelBackStackRecordLocalProvider)
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package com.slack.circuit.backstack

internal actual val defaultBackStackRecordLocalProviders:
List<BackStackRecordLocalProvider<BackStack.Record>> =
listOf(SaveableStateRegistryBackStackRecordLocalProvider)
emptyList()

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ package com.slack.circuit.backstack
*/
internal actual val defaultBackStackRecordLocalProviders:
List<BackStackRecordLocalProvider<BackStack.Record>>
get() = listOf(SaveableStateRegistryBackStackRecordLocalProvider)
get() = emptyList()
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ package com.slack.circuit.backstack

internal actual val defaultBackStackRecordLocalProviders:
List<BackStackRecordLocalProvider<BackStack.Record>> =
listOf(SaveableStateRegistryBackStackRecordLocalProvider)
emptyList()
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import com.slack.circuit.internal.test.TestContentTags.TAG_GO_NEXT
import com.slack.circuit.internal.test.TestContentTags.TAG_INCREASE_COUNT
import com.slack.circuit.internal.test.TestContentTags.TAG_LABEL
import com.slack.circuit.internal.test.TestContentTags.TAG_POP
import com.slack.circuit.internal.test.TestContentTags.TAG_RESET_ROOT_ALPHA
import com.slack.circuit.internal.test.TestContentTags.TAG_RESET_ROOT_BETA
import com.slack.circuit.internal.test.TestCountPresenter
import com.slack.circuit.internal.test.TestScreen
import com.slack.circuit.internal.test.createTestCircuit
Expand All @@ -28,6 +30,14 @@ class NavigableCircuitSaveableStateTest {

@Test fun saveableStateScopedToBackstackWithoutKeys() = saveableStateScopedToBackstack(false)

@Test
fun saveableStateScopedToBackstackResetRootsWithKeys() =
saveableStateScopedToBackstackResetRoots(true)

@Test
fun saveableStateScopedToBackstackResetRootsWithoutKeys() =
saveableStateScopedToBackstackResetRoots(false)

private fun saveableStateScopedToBackstack(useKeys: Boolean) {
composeTestRule.run {
val circuit =
Expand Down Expand Up @@ -96,4 +106,106 @@ class NavigableCircuitSaveableStateTest {
onNodeWithTag(TAG_COUNT).assertTextEquals("0")
}
}

private fun saveableStateScopedToBackstackResetRoots(useKeys: Boolean) {
composeTestRule.run {
val circuit =
createTestCircuit(
useKeys = useKeys,
rememberType = TestCountPresenter.RememberType.Saveable,
saveStateOnRootChange = true,
restoreStateOnRootChange = true,
)

setContent {
CircuitCompositionLocals(circuit) {
val backStack = rememberSaveableBackStack(TestScreen.RootAlpha)
val navigator =
rememberCircuitNavigator(
backStack = backStack,
onRootPop = {}, // no-op for tests
)
NavigableCircuitContent(navigator = navigator, backStack = backStack)
}
}

// Current: Root Alpha. Navigate to Screen A
onNodeWithTag(TAG_LABEL).assertTextEquals("Root Alpha")
onNodeWithTag(TAG_GO_NEXT).performClick()

// Current: Screen A. Increase count to 1
onNodeWithTag(TAG_LABEL).assertTextEquals("A")
onNodeWithTag(TAG_COUNT).assertTextEquals("0")
onNodeWithTag(TAG_INCREASE_COUNT).performClick()
onNodeWithTag(TAG_COUNT).assertTextEquals("1")

// Navigate to Screen B. Increase count to 1
onNodeWithTag(TAG_GO_NEXT).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("B")
onNodeWithTag(TAG_COUNT).assertTextEquals("0")
onNodeWithTag(TAG_INCREASE_COUNT).performClick()
onNodeWithTag(TAG_COUNT).assertTextEquals("1")

// Navigate to Screen C. Increase count to 1
onNodeWithTag(TAG_GO_NEXT).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("C")
onNodeWithTag(TAG_COUNT).assertTextEquals("0")
onNodeWithTag(TAG_INCREASE_COUNT).performClick()
onNodeWithTag(TAG_COUNT).assertTextEquals("1")

// Pop to Screen B. Increase count from 1 to 2.
onNodeWithTag(TAG_POP).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("B")
onNodeWithTag(TAG_COUNT).assertTextEquals("1")
onNodeWithTag(TAG_INCREASE_COUNT).performClick()
onNodeWithTag(TAG_COUNT).assertTextEquals("2")

// Navigate to Screen C. Assert that it's state was not retained
onNodeWithTag(TAG_GO_NEXT).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("C")
onNodeWithTag(TAG_COUNT).assertTextEquals("0")

// Pop to Screen B. Assert that it's state was retained
onNodeWithTag(TAG_POP).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("B")
onNodeWithTag(TAG_COUNT).assertTextEquals("2")

// So at this point:
// Active: Root Alpha, Screen A (count: 1), Screen B: (count: 2)
// Retained: empty

// Let's switch to Root B
onNodeWithTag(TAG_RESET_ROOT_BETA).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("Root Beta")

// Navigate to Screen A, and increase count to 2
onNodeWithTag(TAG_GO_NEXT).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("A")
onNodeWithTag(TAG_COUNT).assertTextEquals("0")
onNodeWithTag(TAG_INCREASE_COUNT).performClick()
onNodeWithTag(TAG_INCREASE_COUNT).performClick()
onNodeWithTag(TAG_COUNT).assertTextEquals("2")

// So at this point:
// Active: Root Beta, Screen A (count: 2)
// Retained: Root Alpha, Screen A (count: 1), Screen B: (count: 2)

// Let's switch back to Root Alpha
onNodeWithTag(TAG_RESET_ROOT_ALPHA).performClick()
// Root Alpha should now be active. The top record for Root Alpha is Screen B: (count: 2)
onNodeWithTag(TAG_LABEL).assertTextEquals("B")
onNodeWithTag(TAG_COUNT).assertTextEquals("2")

// Pop to Screen A
onNodeWithTag(TAG_POP).performClick()
onNodeWithTag(TAG_LABEL).assertTextEquals("A")
onNodeWithTag(TAG_COUNT).assertTextEquals("1")

// Let's switch back to Root B
onNodeWithTag(TAG_RESET_ROOT_BETA).performClick()
// Root Beta should now be active. The top record for Root Beta is Screen A: (count: 2)
onNodeWithTag(TAG_LABEL).assertTextEquals("A")
onNodeWithTag(TAG_COUNT).assertTextEquals("2")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import androidx.compose.runtime.getValue
import androidx.compose.runtime.movableContentOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberUpdatedState
import androidx.compose.runtime.saveable.rememberSaveableStateHolder
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import com.slack.circuit.backstack.BackStack
Expand Down Expand Up @@ -102,17 +103,21 @@ public fun <R : Record> NavigableCircuitContent(
val outerKey = "_navigable_registry_${currentCompositeKeyHash.toString(MaxSupportedRadix)}"
val outerRegistry = rememberRetained(key = outerKey) { RetainedStateRegistry() }

val saveableStateHolder = rememberSaveableStateHolder()

CompositionLocalProvider(LocalRetainedStateRegistry provides outerRegistry) {
decoration.DecoratedContent(activeContentProviders, backStack.size, modifier) { provider ->
val record = provider.record

// Remember the `providedValues` lookup because this composition can live longer than
// the record is present in the backstack, if the decoration is animated for example.
val values = remember(record) { providedValues[record] }?.provideValues()
val providedLocals = remember(values) { values?.toTypedArray() ?: emptyArray() }
saveableStateHolder.SaveableStateProvider(record.key) {
// Remember the `providedValues` lookup because this composition can live longer than
// the record is present in the backstack, if the decoration is animated for example.
val values = remember(record) { providedValues[record] }?.provideValues()
val providedLocals = remember(values) { values?.toTypedArray() ?: emptyArray() }

CompositionLocalProvider(LocalBackStack provides backStack, *providedLocals) {
provider.content(record)
CompositionLocalProvider(LocalBackStack provides backStack, *providedLocals) {
provider.content(record)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ public class CupertinoGestureNavigationDecoration(
label = "GestureNavDecoration",
)

if (previous != null && !transition.isStateBeingAnimated { it.record == previous }) {
if (
previous != null &&
!transition.isPending &&
!transition.isStateBeingAnimated { it.record == previous }
) {
// We display the 'previous' item in the back stack for when the user performs a gesture
// to go back.
// We only display it here if the transition is not running. When the transition is
Expand Down