From ae420978dcabc9d11d4f617f67d7c2ff7689007b Mon Sep 17 00:00:00 2001 From: Chris Banes Date: Mon, 27 May 2024 13:52:58 +0100 Subject: [PATCH] Fix saveable state being restored when using reset root navigation events (#1354) This PR migrates `NavigableCircuitContent` to use `SaveableStateHolder`, rather than our hand-rolled `SaveableStateRegistryBackStackRecordLocalProvider`. I don't know the history behind `SaveableStateRegistryBackStackRecordLocalProvider`, but `SaveableStateHolder` is the first party solution for this. It is used by AndroidX Navigation, so we can assume it is well tested. This PR relies on a bunch of `movableContent` fixes added in #1282 (I had to add a similar one in this PR for `CupertinoGestureNavigationDecoration`). Fixes #1342 --- .../BackStackRecordLocalProvider.android.kt | 2 +- .../BackStackRecordLocalProvider.js.kt | 2 +- ...ateRegistryBackStackRecordLocalProvider.kt | 162 ------------------ .../BackStackRecordLocalProvider.ios.kt | 2 +- .../BackStackRecordLocalProvider.jvm.kt | 2 +- .../NavigableCircuitSaveableStateTest.kt | 112 ++++++++++++ .../foundation/NavigableCircuitContent.kt | 17 +- .../CupertinoGestureNavigationDecoration.kt | 6 +- 8 files changed, 132 insertions(+), 173 deletions(-) delete mode 100644 backstack/src/commonMain/kotlin/com/slack/circuit/backstack/SaveableStateRegistryBackStackRecordLocalProvider.kt diff --git a/backstack/src/androidMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.android.kt b/backstack/src/androidMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.android.kt index 4e0489f0a..5f617e5c2 100644 --- a/backstack/src/androidMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.android.kt +++ b/backstack/src/androidMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.android.kt @@ -2,4 +2,4 @@ package com.slack.circuit.backstack internal actual val defaultBackStackRecordLocalProviders: List> = - listOf(SaveableStateRegistryBackStackRecordLocalProvider, ViewModelBackStackRecordLocalProvider) + listOf(ViewModelBackStackRecordLocalProvider) diff --git a/backstack/src/browserMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.js.kt b/backstack/src/browserMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.js.kt index 7ac4d7d2b..dc5bac5de 100644 --- a/backstack/src/browserMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.js.kt +++ b/backstack/src/browserMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.js.kt @@ -2,4 +2,4 @@ package com.slack.circuit.backstack internal actual val defaultBackStackRecordLocalProviders: List> = - listOf(SaveableStateRegistryBackStackRecordLocalProvider) + emptyList() diff --git a/backstack/src/commonMain/kotlin/com/slack/circuit/backstack/SaveableStateRegistryBackStackRecordLocalProvider.kt b/backstack/src/commonMain/kotlin/com/slack/circuit/backstack/SaveableStateRegistryBackStackRecordLocalProvider.kt deleted file mode 100644 index b240d589c..000000000 --- a/backstack/src/commonMain/kotlin/com/slack/circuit/backstack/SaveableStateRegistryBackStackRecordLocalProvider.kt +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (C) 2022 Adam Powell - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.slack.circuit.backstack - -import androidx.compose.runtime.Composable -import androidx.compose.runtime.ProvidedValue -import androidx.compose.runtime.RememberObserver -import androidx.compose.runtime.getValue -import androidx.compose.runtime.mutableStateMapOf -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.remember -import androidx.compose.runtime.saveable.LocalSaveableStateRegistry -import androidx.compose.runtime.saveable.SaveableStateRegistry -import androidx.compose.runtime.saveable.Saver -import androidx.compose.runtime.saveable.mapSaver -import androidx.compose.runtime.saveable.rememberSaveable -import androidx.compose.runtime.setValue -import androidx.compose.runtime.snapshots.SnapshotStateMap -import kotlinx.atomicfu.locks.SynchronizedObject -import kotlinx.atomicfu.locks.synchronized -import kotlinx.collections.immutable.ImmutableList -import kotlinx.collections.immutable.persistentListOf - -/** A [BackStackRecordLocalProvider] that provides a [SaveableStateRegistry] for each record. */ -public object SaveableStateRegistryBackStackRecordLocalProvider : - BackStackRecordLocalProvider { - @Composable - override fun providedValuesFor(record: BackStack.Record): ProvidedValues { - val childRegistry = - rememberSaveable( - record, - saver = BackStackRecordLocalSaveableStateRegistry.Saver, - key = record.key, - ) { - BackStackRecordLocalSaveableStateRegistry(mutableStateMapOf()) - } - // This write depends on childRegistry.parentRegistry being snapshot state backed - childRegistry.parentRegistry = LocalSaveableStateRegistry.current - return remember(childRegistry) { - object : ProvidedValues { - val list = persistentListOf(LocalSaveableStateRegistry provides childRegistry) - - @Composable - override fun provideValues(): ImmutableList> { - remember { RememberObserverImpl(childRegistry) } - return list - } - } - } - } -} - -// Extracted to work around a WASM bug -// https://youtrack.jetbrains.com/issue/KT-66465#focus=Comments-27-9568825.0-0 -private class RememberObserverImpl( - private val childRegistry: BackStackRecordLocalSaveableStateRegistry -) : RememberObserver { - override fun onForgotten() { - childRegistry.saveForContentLeavingComposition() - } - - override fun onRemembered() {} - - override fun onAbandoned() {} -} - -private class BackStackRecordLocalSaveableStateRegistry( - // Note: restored is snapshot-backed because consumeRestored runs in composition - // and must be rolled back if composition does not commit - private val restored: SnapshotStateMap> -) : SaveableStateRegistry { - var parentRegistry: SaveableStateRegistry? by mutableStateOf(null) - - private val valueProviders = mutableMapOf Any?>>() - - private val lock = SynchronizedObject() - - override fun canBeSaved(value: Any): Boolean = parentRegistry?.canBeSaved(value) != false - - override fun consumeRestored(key: String): Any? = - restored.remove(key)?.let { list -> - list.first().also { - if (list.size > 1) { - restored[key] = list.drop(1) - } - } - } - - override fun performSave(): Map> { - val map = restored.toMutableMap() - saveInto(map) - return map - } - - override fun registerProvider( - key: String, - valueProvider: () -> Any?, - ): SaveableStateRegistry.Entry { - require(key.isNotBlank()) { "Registered key is empty or blank" } - synchronized(lock) { valueProviders.getOrPut(key) { mutableListOf() }.add(valueProvider) } - return object : SaveableStateRegistry.Entry { - override fun unregister() { - synchronized(lock) { - val list = valueProviders.remove(key) - list?.remove(valueProvider) - if (!list.isNullOrEmpty()) { - // if there are other providers for this key return list - // back to the map - valueProviders[key] = list - } - } - } - } - } - - fun saveForContentLeavingComposition() { - saveInto(restored) - } - - private fun saveInto(map: MutableMap>) { - synchronized(lock) { - valueProviders.forEach { (key, list) -> - if (list.size == 1) { - val value = list[0].invoke() - if (value != null) { - map[key] = arrayListOf(value) - } - } else { - // nulls hold empty spaces - map[key] = list.map { it() } - } - } - } - } - - companion object { - val Saver = - mapSaver( - save = { value -> value.performSave() }, - restore = { value -> - BackStackRecordLocalSaveableStateRegistry( - mutableStateMapOf>().apply { - @Suppress("UNCHECKED_CAST") putAll(value as Map>) - } - ) - }, - ) - } -} diff --git a/backstack/src/iOSMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.ios.kt b/backstack/src/iOSMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.ios.kt index b8c68b7f0..a70fccf8c 100644 --- a/backstack/src/iOSMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.ios.kt +++ b/backstack/src/iOSMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.ios.kt @@ -6,4 +6,4 @@ package com.slack.circuit.backstack */ internal actual val defaultBackStackRecordLocalProviders: List> - get() = listOf(SaveableStateRegistryBackStackRecordLocalProvider) + get() = emptyList() diff --git a/backstack/src/jvmMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.jvm.kt b/backstack/src/jvmMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.jvm.kt index 7ac4d7d2b..dc5bac5de 100644 --- a/backstack/src/jvmMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.jvm.kt +++ b/backstack/src/jvmMain/kotlin/com/slack/circuit/backstack/BackStackRecordLocalProvider.jvm.kt @@ -2,4 +2,4 @@ package com.slack.circuit.backstack internal actual val defaultBackStackRecordLocalProviders: List> = - listOf(SaveableStateRegistryBackStackRecordLocalProvider) + emptyList() diff --git a/circuit-foundation/src/commonJvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitSaveableStateTest.kt b/circuit-foundation/src/commonJvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitSaveableStateTest.kt index 30703d057..a18479db8 100644 --- a/circuit-foundation/src/commonJvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitSaveableStateTest.kt +++ b/circuit-foundation/src/commonJvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitSaveableStateTest.kt @@ -12,6 +12,8 @@ import com.slack.circuit.internal.test.TestContentTags.TAG_GO_NEXT import com.slack.circuit.internal.test.TestContentTags.TAG_INCREASE_COUNT import com.slack.circuit.internal.test.TestContentTags.TAG_LABEL import com.slack.circuit.internal.test.TestContentTags.TAG_POP +import com.slack.circuit.internal.test.TestContentTags.TAG_RESET_ROOT_ALPHA +import com.slack.circuit.internal.test.TestContentTags.TAG_RESET_ROOT_BETA import com.slack.circuit.internal.test.TestCountPresenter import com.slack.circuit.internal.test.TestScreen import com.slack.circuit.internal.test.createTestCircuit @@ -28,6 +30,14 @@ class NavigableCircuitSaveableStateTest { @Test fun saveableStateScopedToBackstackWithoutKeys() = saveableStateScopedToBackstack(false) + @Test + fun saveableStateScopedToBackstackResetRootsWithKeys() = + saveableStateScopedToBackstackResetRoots(true) + + @Test + fun saveableStateScopedToBackstackResetRootsWithoutKeys() = + saveableStateScopedToBackstackResetRoots(false) + private fun saveableStateScopedToBackstack(useKeys: Boolean) { composeTestRule.run { val circuit = @@ -96,4 +106,106 @@ class NavigableCircuitSaveableStateTest { onNodeWithTag(TAG_COUNT).assertTextEquals("0") } } + + private fun saveableStateScopedToBackstackResetRoots(useKeys: Boolean) { + composeTestRule.run { + val circuit = + createTestCircuit( + useKeys = useKeys, + rememberType = TestCountPresenter.RememberType.Saveable, + saveStateOnRootChange = true, + restoreStateOnRootChange = true, + ) + + setContent { + CircuitCompositionLocals(circuit) { + val backStack = rememberSaveableBackStack(TestScreen.RootAlpha) + val navigator = + rememberCircuitNavigator( + backStack = backStack, + onRootPop = {}, // no-op for tests + ) + NavigableCircuitContent(navigator = navigator, backStack = backStack) + } + } + + // Current: Root Alpha. Navigate to Screen A + onNodeWithTag(TAG_LABEL).assertTextEquals("Root Alpha") + onNodeWithTag(TAG_GO_NEXT).performClick() + + // Current: Screen A. Increase count to 1 + onNodeWithTag(TAG_LABEL).assertTextEquals("A") + onNodeWithTag(TAG_COUNT).assertTextEquals("0") + onNodeWithTag(TAG_INCREASE_COUNT).performClick() + onNodeWithTag(TAG_COUNT).assertTextEquals("1") + + // Navigate to Screen B. Increase count to 1 + onNodeWithTag(TAG_GO_NEXT).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("B") + onNodeWithTag(TAG_COUNT).assertTextEquals("0") + onNodeWithTag(TAG_INCREASE_COUNT).performClick() + onNodeWithTag(TAG_COUNT).assertTextEquals("1") + + // Navigate to Screen C. Increase count to 1 + onNodeWithTag(TAG_GO_NEXT).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("C") + onNodeWithTag(TAG_COUNT).assertTextEquals("0") + onNodeWithTag(TAG_INCREASE_COUNT).performClick() + onNodeWithTag(TAG_COUNT).assertTextEquals("1") + + // Pop to Screen B. Increase count from 1 to 2. + onNodeWithTag(TAG_POP).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("B") + onNodeWithTag(TAG_COUNT).assertTextEquals("1") + onNodeWithTag(TAG_INCREASE_COUNT).performClick() + onNodeWithTag(TAG_COUNT).assertTextEquals("2") + + // Navigate to Screen C. Assert that it's state was not retained + onNodeWithTag(TAG_GO_NEXT).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("C") + onNodeWithTag(TAG_COUNT).assertTextEquals("0") + + // Pop to Screen B. Assert that it's state was retained + onNodeWithTag(TAG_POP).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("B") + onNodeWithTag(TAG_COUNT).assertTextEquals("2") + + // So at this point: + // Active: Root Alpha, Screen A (count: 1), Screen B: (count: 2) + // Retained: empty + + // Let's switch to Root B + onNodeWithTag(TAG_RESET_ROOT_BETA).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("Root Beta") + + // Navigate to Screen A, and increase count to 2 + onNodeWithTag(TAG_GO_NEXT).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("A") + onNodeWithTag(TAG_COUNT).assertTextEquals("0") + onNodeWithTag(TAG_INCREASE_COUNT).performClick() + onNodeWithTag(TAG_INCREASE_COUNT).performClick() + onNodeWithTag(TAG_COUNT).assertTextEquals("2") + + // So at this point: + // Active: Root Beta, Screen A (count: 2) + // Retained: Root Alpha, Screen A (count: 1), Screen B: (count: 2) + + // Let's switch back to Root Alpha + onNodeWithTag(TAG_RESET_ROOT_ALPHA).performClick() + // Root Alpha should now be active. The top record for Root Alpha is Screen B: (count: 2) + onNodeWithTag(TAG_LABEL).assertTextEquals("B") + onNodeWithTag(TAG_COUNT).assertTextEquals("2") + + // Pop to Screen A + onNodeWithTag(TAG_POP).performClick() + onNodeWithTag(TAG_LABEL).assertTextEquals("A") + onNodeWithTag(TAG_COUNT).assertTextEquals("1") + + // Let's switch back to Root B + onNodeWithTag(TAG_RESET_ROOT_BETA).performClick() + // Root Beta should now be active. The top record for Root Beta is Screen A: (count: 2) + onNodeWithTag(TAG_LABEL).assertTextEquals("A") + onNodeWithTag(TAG_COUNT).assertTextEquals("2") + } + } } diff --git a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt index 12c7b9fef..d6ef39c40 100644 --- a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt +++ b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt @@ -26,6 +26,7 @@ import androidx.compose.runtime.getValue import androidx.compose.runtime.movableContentOf import androidx.compose.runtime.remember import androidx.compose.runtime.rememberUpdatedState +import androidx.compose.runtime.saveable.rememberSaveableStateHolder import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import com.slack.circuit.backstack.BackStack @@ -102,17 +103,21 @@ public fun NavigableCircuitContent( val outerKey = "_navigable_registry_${currentCompositeKeyHash.toString(MaxSupportedRadix)}" val outerRegistry = rememberRetained(key = outerKey) { RetainedStateRegistry() } + val saveableStateHolder = rememberSaveableStateHolder() + CompositionLocalProvider(LocalRetainedStateRegistry provides outerRegistry) { decoration.DecoratedContent(activeContentProviders, backStack.size, modifier) { provider -> val record = provider.record - // Remember the `providedValues` lookup because this composition can live longer than - // the record is present in the backstack, if the decoration is animated for example. - val values = remember(record) { providedValues[record] }?.provideValues() - val providedLocals = remember(values) { values?.toTypedArray() ?: emptyArray() } + saveableStateHolder.SaveableStateProvider(record.key) { + // Remember the `providedValues` lookup because this composition can live longer than + // the record is present in the backstack, if the decoration is animated for example. + val values = remember(record) { providedValues[record] }?.provideValues() + val providedLocals = remember(values) { values?.toTypedArray() ?: emptyArray() } - CompositionLocalProvider(LocalBackStack provides backStack, *providedLocals) { - provider.content(record) + CompositionLocalProvider(LocalBackStack provides backStack, *providedLocals) { + provider.content(record) + } } } } diff --git a/circuitx/gesture-navigation/src/commonMain/kotlin/com/slack/circuitx/gesturenavigation/CupertinoGestureNavigationDecoration.kt b/circuitx/gesture-navigation/src/commonMain/kotlin/com/slack/circuitx/gesturenavigation/CupertinoGestureNavigationDecoration.kt index 58281987d..53a6c193c 100644 --- a/circuitx/gesture-navigation/src/commonMain/kotlin/com/slack/circuitx/gesturenavigation/CupertinoGestureNavigationDecoration.kt +++ b/circuitx/gesture-navigation/src/commonMain/kotlin/com/slack/circuitx/gesturenavigation/CupertinoGestureNavigationDecoration.kt @@ -105,7 +105,11 @@ public class CupertinoGestureNavigationDecoration( label = "GestureNavDecoration", ) - if (previous != null && !transition.isStateBeingAnimated { it.record == previous }) { + if ( + previous != null && + !transition.isPending && + !transition.isStateBeingAnimated { it.record == previous } + ) { // We display the 'previous' item in the back stack for when the user performs a gesture // to go back. // We only display it here if the transition is not running. When the transition is