diff --git a/build.gradle.kts b/build.gradle.kts index a097d288f..d3b900307 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -106,6 +106,8 @@ allprojects { "**/Remove.kt", "**/Pets.kt", "**/SystemUiController.kt", + "**/RetainedStateHolderTest.kt", + "**/RetainedStateRestorationTester.kt", ) } } diff --git a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt index 2b13d1005..f46a49f44 100644 --- a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt +++ b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/NavigableCircuitContent.kt @@ -48,6 +48,7 @@ import com.slack.circuit.retained.LocalCanRetainChecker import com.slack.circuit.retained.LocalRetainedStateRegistry import com.slack.circuit.retained.RetainedStateRegistry import com.slack.circuit.retained.rememberRetained +import com.slack.circuit.retained.rememberRetainedStateHolder import com.slack.circuit.runtime.InternalCircuitApi import com.slack.circuit.runtime.Navigator import com.slack.circuit.runtime.screen.Screen @@ -66,14 +67,6 @@ public fun NavigableCircuitContent( unavailableRoute: (@Composable (screen: Screen, modifier: Modifier) -> Unit) = circuit.onUnavailableContent, ) { - val activeContentProviders = - buildCircuitContentProviders( - backStack = backStack, - navigator = navigator, - circuit = circuit, - unavailableRoute = unavailableRoute, - ) - if (backStack.isEmpty) return /* @@ -110,10 +103,17 @@ public fun NavigableCircuitContent( */ val outerKey = "_navigable_registry_${currentCompositeKeyHash.toString(MaxSupportedRadix)}" val outerRegistry = rememberRetained(key = outerKey) { RetainedStateRegistry() } - val saveableStateHolder = rememberSaveableStateHolder() CompositionLocalProvider(LocalRetainedStateRegistry provides outerRegistry) { + val activeContentProviders = + buildCircuitContentProviders( + backStack = backStack, + navigator = navigator, + circuit = circuit, + unavailableRoute = unavailableRoute, + ) + decoration.DecoratedContent(activeContentProviders, backStack.size, modifier) { provider -> val record = provider.record @@ -175,6 +175,7 @@ private fun buildCircuitContentProviders( val lastNavigator by rememberUpdatedState(navigator) val lastCircuit by rememberUpdatedState(circuit) val lastUnavailableRoute by rememberUpdatedState(unavailableRoute) + val retainedStateHolder = rememberRetainedStateHolder() fun createRecordContent() = movableContentOf { record -> @@ -190,21 +191,16 @@ private fun buildCircuitContentProviders( // Now provide a new registry to the content for it to store any retained state in, // along with a retain checker which is always true (as upstream registries will // maintain the lifetime), and the other provided values - val recordRetainedStateRegistry = - rememberRetained(key = record.registryKey) { RetainedStateRegistry() } - - CompositionLocalProvider( - LocalRetainedStateRegistry provides recordRetainedStateRegistry, - LocalCanRetainChecker provides CanRetainChecker.Always, - LocalRecordLifecycle provides lifecycle, - ) { - CircuitContent( - screen = record.screen, - navigator = lastNavigator, - circuit = lastCircuit, - unavailableContent = lastUnavailableRoute, - key = record.key, - ) + retainedStateHolder.RetainedStateProvider(record.registryKey) { + CompositionLocalProvider(LocalRecordLifecycle provides lifecycle) { + CircuitContent( + screen = record.screen, + navigator = lastNavigator, + circuit = lastCircuit, + unavailableContent = lastUnavailableRoute, + key = record.key, + ) + } } } } diff --git a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/PausableState.kt b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/PausableState.kt index 410c7ba4a..788f9b09f 100644 --- a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/PausableState.kt +++ b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/PausableState.kt @@ -7,10 +7,6 @@ package com.slack.circuit.foundation import androidx.compose.runtime.Composable import androidx.compose.runtime.Stable import androidx.compose.runtime.remember -import com.slack.circuit.foundation.internal.withCompositionLocalProvider -import com.slack.circuit.retained.LocalRetainedStateRegistry -import com.slack.circuit.retained.RetainedStateRegistry -import com.slack.circuit.retained.rememberRetained import com.slack.circuit.runtime.CircuitUiState import com.slack.circuit.runtime.presenter.Presenter @@ -60,14 +56,13 @@ public fun pausableState( val state = remember(key) { MutableRef(null) } val saveableStateHolder = rememberSaveableStateHolderWithReturn() + val retainedStateHolder = rememberRetainedStateHolderWithReturn() return if (isActive || state.value == null) { - val retainedStateRegistry = rememberRetained(key = key) { RetainedStateRegistry() } - withCompositionLocalProvider(LocalRetainedStateRegistry provides retainedStateRegistry) { - saveableStateHolder.SaveableStateProvider( - key = key ?: "pausable_state", - content = produceState, - ) + val finalKey = key ?: "pausable_state" + saveableStateHolder + .SaveableStateProvider(finalKey) { + retainedStateHolder.RetainedStateProvider(key = finalKey, content = produceState) } .also { // Store the last emitted state diff --git a/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/RetainedStateHolder.kt b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/RetainedStateHolder.kt new file mode 100644 index 000000000..d9c3fee3f --- /dev/null +++ b/circuit-foundation/src/commonMain/kotlin/com/slack/circuit/foundation/RetainedStateHolder.kt @@ -0,0 +1,106 @@ +// Copyright (C) 2024 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package com.slack.circuit.foundation + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.DisposableEffect +import androidx.compose.runtime.key +import androidx.compose.runtime.remember +import com.slack.circuit.foundation.internal.withCompositionLocalProvider +import com.slack.circuit.retained.CanRetainChecker +import com.slack.circuit.retained.LocalCanRetainChecker +import com.slack.circuit.retained.LocalRetainedStateRegistry +import com.slack.circuit.retained.RetainedStateRegistry +import com.slack.circuit.retained.RetainedValueProvider +import com.slack.circuit.retained.rememberRetained + +/** Copy of [RetainedStateHolder] to return content value */ +internal interface RetainedStateHolder { + + @Composable fun RetainedStateProvider(key: String, content: @Composable () -> T): T + + fun removeState(key: String) +} + +/** Creates and remembers the instance of [RetainedStateHolder]. */ +@Composable +internal fun rememberRetainedStateHolderWithReturn(): RetainedStateHolder { + return rememberRetained { RetainedStateHolderImpl() } +} + +private class RetainedStateHolderImpl : RetainedStateHolder, RetainedStateRegistry { + + private val registry = RetainedStateRegistry() + + private val canRetainCheckers = mutableMapOf() + + @Composable + override fun RetainedStateProvider(key: String, content: @Composable (() -> T)): T { + return withCompositionLocalProvider(LocalRetainedStateRegistry provides registry) { + val parentCanRetainChecker = LocalCanRetainChecker.current ?: CanRetainChecker.Always + key(key) { + val entryCanRetainChecker = + remember(parentCanRetainChecker) { EntryCanRetainChecker(parentCanRetainChecker) } + val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() } + withCompositionLocalProvider( + LocalRetainedStateRegistry provides childRegistry, + LocalCanRetainChecker provides CanRetainChecker.Always, + content = content, + ) + .also { + DisposableEffect(Unit) { + canRetainCheckers[key] = entryCanRetainChecker + onDispose { + val retained = childRegistry.saveAll() + if (retained.isNotEmpty() && entryCanRetainChecker.canRetain(registry)) { + registry.saveValue(key) + } + canRetainCheckers -= key + } + } + } + } + } + } + + override fun removeState(key: String) { + val canRetainChecker = canRetainCheckers[key] + if (canRetainChecker != null) { + canRetainChecker.shouldSave = false + } else { + registry.consumeValue(key) + } + } + + override fun consumeValue(key: String): Any? { + return registry.consumeValue(key) + } + + override fun registerValue( + key: String, + valueProvider: RetainedValueProvider, + ): RetainedStateRegistry.Entry { + return registry.registerValue(key, valueProvider) + } + + override fun saveAll(): Map> { + return registry.saveAll() + } + + override fun saveValue(key: String) { + registry.saveValue(key) + } + + override fun forgetUnclaimedValues() { + registry.forgetUnclaimedValues() + } + + private class EntryCanRetainChecker(private val parentChecker: CanRetainChecker) : + CanRetainChecker { + + var shouldSave = true + + override fun canRetain(registry: RetainedStateRegistry): Boolean = + parentChecker.canRetain(registry) && shouldSave + } +} diff --git a/circuit-foundation/src/jvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitConditionalRetainTest.kt b/circuit-foundation/src/jvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitConditionalRetainTest.kt new file mode 100644 index 000000000..090c5bf89 --- /dev/null +++ b/circuit-foundation/src/jvmTest/kotlin/com/slack/circuit/foundation/NavigableCircuitConditionalRetainTest.kt @@ -0,0 +1,361 @@ +// Copyright (C) 2024 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package com.slack.circuit.foundation + +import androidx.compose.foundation.layout.Column +import androidx.compose.material.Button +import androidx.compose.material.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.testTag +import androidx.compose.ui.test.assertTextEquals +import androidx.compose.ui.test.junit4.ComposeContentTestRule +import androidx.compose.ui.test.junit4.createComposeRule +import androidx.compose.ui.test.onNodeWithTag +import androidx.compose.ui.test.performClick +import com.slack.circuit.backstack.rememberSaveableBackStack +import com.slack.circuit.retained.rememberRetained +import com.slack.circuit.runtime.CircuitUiEvent +import com.slack.circuit.runtime.CircuitUiState +import com.slack.circuit.runtime.Navigator +import com.slack.circuit.runtime.presenter.Presenter +import com.slack.circuit.runtime.screen.Screen +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith + +private const val TAG_SHOW_CHILD_BUTTON = "TAG_SHOW_CHILD_BUTTON" +private const val TAG_HIDE_CHILD_BUTTON = "TAG_HIDE_CHILD_BUTTON" +private const val TAG_INC_BUTTON = "TAG_INC_BUTTON" +private const val TAG_GOTO_BUTTON = "TAG_GOTO_BUTTON" +private const val TAG_POP_BUTTON = "TAG_POP_BUTTON" +private const val TAG_CONDITIONAL_RETAINED = "TAG_CONDITIONAL_RETAINED" +private const val TAG_UI_RETAINED = "TAG_UI_RETAINED" +private const val TAG_PRESENTER_RETAINED = "TAG_PRESENTER_RETAINED" +private const val TAG_STATE = "TAG_STATE" + +@RunWith(ComposeUiTestRunner::class) +class NavigableCircuitConditionalRetainTest { + + @get:Rule val composeTestRule = createComposeRule() + + private val dataSource = DataSource() + + private val circuit = + Circuit.Builder() + .addPresenter { _, _, _ -> ScreenAPresenter() } + .addUi { _, modifier -> ScreenAUi(modifier) } + .addPresenter { _, _, _ -> ScreenBPresenter(dataSource) } + .addUi { state, modifier -> ScreenBUi(state, modifier) } + .addPresenter { _, navigator, _ -> ScreenCPresenter(navigator) } + .addUi { state, modifier -> ScreenCUi(state, modifier) } + .addPresenter { _, navigator, _ -> ScreenDPresenter(navigator) } + .addUi { state, modifier -> ScreenDUi(state, modifier) } + .build() + + @Test + fun nestedCircuitContentWithPresentWithLifecycle() { + nestedCircuitContent(presentWithLifecycle = true) + } + + @Test + fun nestedCircuitContentWithoutPresentWithLifecycle() { + nestedCircuitContent(presentWithLifecycle = false) + } + + @Test + fun removedConditionalRetainWithPresentWithLifecycle() { + removedConditionalRetain(presentWithLifecycle = true) + } + + @Test + fun removedConditionalRetainWithoutPresentWithLifecycle() { + removedConditionalRetain(presentWithLifecycle = false) + } + + @Test + fun addedConditionalRetainWithPresentWithLifecycle() { + addedConditionalRetain(presentWithLifecycle = true) + } + + @Test + fun addedConditionalRetainWithoutPresentWithLifecycle() { + addedConditionalRetain(presentWithLifecycle = false) + } + + /** Nested circuit content should not be retained if it is removed */ + private fun nestedCircuitContent(presentWithLifecycle: Boolean) { + composeTestRule.run { + val modifiedCircuit = circuit.newBuilder().presentWithLifecycle(presentWithLifecycle).build() + setUpTestContent(modifiedCircuit, ScreenA) + + onNodeWithTag(TAG_STATE).assertDoesNotExist() + onNodeWithTag(TAG_PRESENTER_RETAINED).assertDoesNotExist() + onNodeWithTag(TAG_UI_RETAINED).assertDoesNotExist() + + dataSource.value = 1 + + // Show child + onNodeWithTag(TAG_SHOW_CHILD_BUTTON).performClick() + + onNodeWithTag(TAG_STATE).assertTextEquals("1") + onNodeWithTag(TAG_UI_RETAINED).assertTextEquals("1") + onNodeWithTag(TAG_PRESENTER_RETAINED).assertTextEquals("1") + + // Hide child + onNodeWithTag(TAG_HIDE_CHILD_BUTTON).performClick() + + onNodeWithTag(TAG_STATE).assertDoesNotExist() + onNodeWithTag(TAG_PRESENTER_RETAINED).assertDoesNotExist() + onNodeWithTag(TAG_UI_RETAINED).assertDoesNotExist() + + dataSource.value = 2 + + // Show child + onNodeWithTag(TAG_SHOW_CHILD_BUTTON).performClick() + + // Retained reset + onNodeWithTag(TAG_STATE).assertTextEquals("2") + onNodeWithTag(TAG_UI_RETAINED).assertTextEquals("2") + onNodeWithTag(TAG_PRESENTER_RETAINED).assertTextEquals("2") + } + } + + /** + * Conditional rememberRetained should not be retained if it is removed no matter current + * RetainedStateRegistry is saved or not. + */ + private fun removedConditionalRetain(presentWithLifecycle: Boolean) { + composeTestRule.run { + val modifiedCircuit = circuit.newBuilder().presentWithLifecycle(presentWithLifecycle).build() + setUpTestContent(modifiedCircuit, ScreenC) + + onNodeWithTag(TAG_STATE).assertDoesNotExist() + onNodeWithTag(TAG_PRESENTER_RETAINED).assertDoesNotExist() + onNodeWithTag(TAG_UI_RETAINED).assertDoesNotExist() + + // Show child + onNodeWithTag(TAG_SHOW_CHILD_BUTTON).performClick() + + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("0") + onNodeWithTag(TAG_INC_BUTTON).performClick() + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("1") + + // Hide child + onNodeWithTag(TAG_HIDE_CHILD_BUTTON).performClick() + + // Navigate other screen and pop for saving ScreenC's state + onNodeWithTag(TAG_GOTO_BUTTON).performClick() + onNodeWithTag(TAG_POP_BUTTON).performClick() + + // Show child + onNodeWithTag(TAG_SHOW_CHILD_BUTTON).performClick() + + // Child's retained state should not be retained + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("0") + } + } + + /** + * Conditional rememberRetained should be retained if it is added and current + * RetainedStateRegistry is saved + */ + private fun addedConditionalRetain(presentWithLifecycle: Boolean) { + composeTestRule.run { + val modifiedCircuit = circuit.newBuilder().presentWithLifecycle(presentWithLifecycle).build() + setUpTestContent(modifiedCircuit, ScreenC) + + onNodeWithTag(TAG_STATE).assertDoesNotExist() + onNodeWithTag(TAG_PRESENTER_RETAINED).assertDoesNotExist() + onNodeWithTag(TAG_UI_RETAINED).assertDoesNotExist() + + // Show child + onNodeWithTag(TAG_SHOW_CHILD_BUTTON).performClick() + + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("0") + onNodeWithTag(TAG_INC_BUTTON).performClick() + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("1") + + // Navigate other screen and pop for saving ScreenC's state + onNodeWithTag(TAG_GOTO_BUTTON).performClick() + onNodeWithTag(TAG_POP_BUTTON).performClick() + + // Child's retained state should be retained + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("1") + + // Hide child + onNodeWithTag(TAG_HIDE_CHILD_BUTTON).performClick() + // Show child + onNodeWithTag(TAG_SHOW_CHILD_BUTTON).performClick() + + // Child's retained state should not be retained + onNodeWithTag(TAG_CONDITIONAL_RETAINED).assertTextEquals("0") + } + } + + private fun ComposeContentTestRule.setUpTestContent(circuit: Circuit, screen: Screen): Navigator { + lateinit var navigator: Navigator + setContent { + CircuitCompositionLocals(circuit) { + val backStack = rememberSaveableBackStack(screen) + navigator = rememberCircuitNavigator(backStack = backStack, onRootPop = {}) + NavigableCircuitContent(navigator = navigator, backStack = backStack) + } + } + return navigator + } + + private data object ScreenA : Screen { + data object State : CircuitUiState + } + + private class ScreenAPresenter : Presenter { + @Composable + override fun present(): ScreenA.State { + return ScreenA.State + } + } + + @Composable + private fun ScreenAUi(modifier: Modifier = Modifier) { + Column(modifier) { + val isChildVisible = remember { mutableStateOf(false) } + Button( + modifier = Modifier.testTag(TAG_SHOW_CHILD_BUTTON), + onClick = { isChildVisible.value = true }, + ) { + Text("show") + } + Button( + modifier = Modifier.testTag(TAG_HIDE_CHILD_BUTTON), + onClick = { isChildVisible.value = false }, + ) { + Text("hide") + } + if (isChildVisible.value) { + CircuitContent(screen = ScreenB) + } + } + } + + private data object ScreenB : Screen { + + data class State(val count: Int, val retainedCount: Int) : CircuitUiState + } + + private class ScreenBPresenter(private val source: DataSource) : Presenter { + + @Composable + override fun present(): ScreenB.State { + val count = source.fetch() + val retained = rememberRetained { count } + return ScreenB.State(count, retained) + } + } + + @Composable + private fun ScreenBUi(state: ScreenB.State, modifier: Modifier = Modifier) { + Column(modifier) { + val retained = rememberRetained { state.count } + Text(text = retained.toString(), modifier = Modifier.testTag(TAG_UI_RETAINED)) + Text(text = state.count.toString(), modifier = Modifier.testTag(TAG_STATE)) + Text( + text = state.retainedCount.toString(), + modifier = Modifier.testTag(TAG_PRESENTER_RETAINED), + ) + } + } + + private data object ScreenC : Screen { + + data class State(val eventSink: (Event) -> Unit) : CircuitUiState + + sealed interface Event : CircuitUiEvent { + data class GoTo(val screen: Screen) : Event + } + } + + private class ScreenCPresenter(private val navigator: Navigator) : Presenter { + @Composable + override fun present(): ScreenC.State { + return ScreenC.State { event -> + when (event) { + is ScreenC.Event.GoTo -> navigator.goTo(event.screen) + } + } + } + } + + @Composable + private fun ScreenCUi(state: ScreenC.State, modifier: Modifier = Modifier) { + Column(modifier) { + Button( + modifier = Modifier.testTag(TAG_GOTO_BUTTON), + onClick = { state.eventSink(ScreenC.Event.GoTo(ScreenD)) }, + ) { + Text("goto") + } + val isVisible = rememberRetained { mutableStateOf(false) } + Button( + modifier = Modifier.testTag(TAG_SHOW_CHILD_BUTTON), + onClick = { isVisible.value = true }, + ) { + Text("show") + } + Button( + modifier = Modifier.testTag(TAG_HIDE_CHILD_BUTTON), + onClick = { isVisible.value = false }, + ) { + Text("hide") + } + if (isVisible.value) { + val count = rememberRetained { mutableStateOf(0) } + Button(modifier = Modifier.testTag(TAG_INC_BUTTON), onClick = { count.value += 1 }) { + Text("inc") + } + Text(modifier = Modifier.testTag(TAG_CONDITIONAL_RETAINED), text = count.value.toString()) + } + } + } + + private data object ScreenD : Screen { + + data class State(val eventSink: (Event) -> Unit) : CircuitUiState + + sealed interface Event : CircuitUiEvent { + data object Pop : Event + } + } + + private class ScreenDPresenter(private val navigator: Navigator) : Presenter { + + @Composable + override fun present(): ScreenD.State { + return ScreenD.State { event -> + when (event) { + is ScreenD.Event.Pop -> navigator.pop() + } + } + } + } + + @Composable + private fun ScreenDUi(state: ScreenD.State, modifier: Modifier = Modifier) { + Column(modifier) { + Button( + onClick = { state.eventSink(ScreenD.Event.Pop) }, + modifier = Modifier.testTag(TAG_POP_BUTTON), + ) { + Text(text = "pop") + } + } + } + + private class DataSource { + var value: Int = 0 + + fun fetch(): Int = value + } +} diff --git a/circuit-retained/api/android/circuit-retained.api b/circuit-retained/api/android/circuit-retained.api index 7870b92d7..5fa6a509a 100644 --- a/circuit-retained/api/android/circuit-retained.api +++ b/circuit-retained/api/android/circuit-retained.api @@ -40,7 +40,7 @@ public final class com/slack/circuit/retained/NoOpRetainedStateRegistry : com/sl public fun consumeValue (Ljava/lang/String;)Ljava/lang/Object; public fun forgetUnclaimedValues ()V public fun registerValue (Ljava/lang/String;Lcom/slack/circuit/retained/RetainedValueProvider;)Lcom/slack/circuit/retained/RetainedStateRegistry$Entry; - public fun saveAll ()V + public fun saveAll ()Ljava/util/Map; public fun saveValue (Ljava/lang/String;)V } @@ -60,11 +60,20 @@ public final class com/slack/circuit/retained/RememberRetainedKt { public static final fun rememberRetainedSaveable ([Ljava/lang/Object;Landroidx/compose/runtime/saveable/Saver;Ljava/lang/String;Lkotlin/jvm/functions/Function0;Landroidx/compose/runtime/Composer;II)Ljava/lang/Object; } +public abstract interface class com/slack/circuit/retained/RetainedStateHolder { + public abstract fun RetainedStateProvider (Ljava/lang/String;Lkotlin/jvm/functions/Function2;Landroidx/compose/runtime/Composer;I)V + public abstract fun removeState (Ljava/lang/String;)V +} + +public final class com/slack/circuit/retained/RetainedStateHolderKt { + public static final fun rememberRetainedStateHolder (Landroidx/compose/runtime/Composer;I)Lcom/slack/circuit/retained/RetainedStateHolder; +} + public abstract interface class com/slack/circuit/retained/RetainedStateRegistry { public abstract fun consumeValue (Ljava/lang/String;)Ljava/lang/Object; public abstract fun forgetUnclaimedValues ()V public abstract fun registerValue (Ljava/lang/String;Lcom/slack/circuit/retained/RetainedValueProvider;)Lcom/slack/circuit/retained/RetainedStateRegistry$Entry; - public abstract fun saveAll ()V + public abstract fun saveAll ()Ljava/util/Map; public abstract fun saveValue (Ljava/lang/String;)V } diff --git a/circuit-retained/api/circuit-retained.klib.api b/circuit-retained/api/circuit-retained.klib.api index d9bb2e987..366dc4bcf 100644 --- a/circuit-retained/api/circuit-retained.klib.api +++ b/circuit-retained/api/circuit-retained.klib.api @@ -25,11 +25,16 @@ abstract interface <#A: kotlin/Any?> com.slack.circuit.retained/RetainedValueHol abstract fun (): #A // com.slack.circuit.retained/RetainedValueHolder.value.|(){}[0] } +abstract interface com.slack.circuit.retained/RetainedStateHolder { // com.slack.circuit.retained/RetainedStateHolder|null[0] + abstract fun RetainedStateProvider(kotlin/String, kotlin/Function2, androidx.compose.runtime/Composer?, kotlin/Int) // com.slack.circuit.retained/RetainedStateHolder.RetainedStateProvider|RetainedStateProvider(kotlin.String;kotlin.Function2;androidx.compose.runtime.Composer?;kotlin.Int){}[0] + abstract fun removeState(kotlin/String) // com.slack.circuit.retained/RetainedStateHolder.removeState|removeState(kotlin.String){}[0] +} + abstract interface com.slack.circuit.retained/RetainedStateRegistry { // com.slack.circuit.retained/RetainedStateRegistry|null[0] abstract fun consumeValue(kotlin/String): kotlin/Any? // com.slack.circuit.retained/RetainedStateRegistry.consumeValue|consumeValue(kotlin.String){}[0] abstract fun forgetUnclaimedValues() // com.slack.circuit.retained/RetainedStateRegistry.forgetUnclaimedValues|forgetUnclaimedValues(){}[0] abstract fun registerValue(kotlin/String, com.slack.circuit.retained/RetainedValueProvider): com.slack.circuit.retained/RetainedStateRegistry.Entry // com.slack.circuit.retained/RetainedStateRegistry.registerValue|registerValue(kotlin.String;com.slack.circuit.retained.RetainedValueProvider){}[0] - abstract fun saveAll() // com.slack.circuit.retained/RetainedStateRegistry.saveAll|saveAll(){}[0] + abstract fun saveAll(): kotlin.collections/Map> // com.slack.circuit.retained/RetainedStateRegistry.saveAll|saveAll(){}[0] abstract fun saveValue(kotlin/String) // com.slack.circuit.retained/RetainedStateRegistry.saveValue|saveValue(kotlin.String){}[0] abstract interface Entry { // com.slack.circuit.retained/RetainedStateRegistry.Entry|null[0] @@ -46,7 +51,7 @@ final object com.slack.circuit.retained/NoOpRetainedStateRegistry : com.slack.ci final fun consumeValue(kotlin/String): kotlin/Any? // com.slack.circuit.retained/NoOpRetainedStateRegistry.consumeValue|consumeValue(kotlin.String){}[0] final fun forgetUnclaimedValues() // com.slack.circuit.retained/NoOpRetainedStateRegistry.forgetUnclaimedValues|forgetUnclaimedValues(){}[0] final fun registerValue(kotlin/String, com.slack.circuit.retained/RetainedValueProvider): com.slack.circuit.retained/RetainedStateRegistry.Entry // com.slack.circuit.retained/NoOpRetainedStateRegistry.registerValue|registerValue(kotlin.String;com.slack.circuit.retained.RetainedValueProvider){}[0] - final fun saveAll() // com.slack.circuit.retained/NoOpRetainedStateRegistry.saveAll|saveAll(){}[0] + final fun saveAll(): kotlin.collections/Map> // com.slack.circuit.retained/NoOpRetainedStateRegistry.saveAll|saveAll(){}[0] final fun saveValue(kotlin/String) // com.slack.circuit.retained/NoOpRetainedStateRegistry.saveValue|saveValue(kotlin.String){}[0] } @@ -78,6 +83,7 @@ final fun com.slack.circuit.retained/com_slack_circuit_retained_RetainableSaveab final fun com.slack.circuit.retained/com_slack_circuit_retained_RetainedStateRegistryImpl$stableprop_getter(): kotlin/Int // com.slack.circuit.retained/com_slack_circuit_retained_RetainedStateRegistryImpl$stableprop_getter|com_slack_circuit_retained_RetainedStateRegistryImpl$stableprop_getter(){}[0] final fun com.slack.circuit.retained/continuityRetainedStateRegistry(kotlin/String?, com.slack.circuit.retained/CanRetainChecker?, androidx.compose.runtime/Composer?, kotlin/Int, kotlin/Int): com.slack.circuit.retained/RetainedStateRegistry // com.slack.circuit.retained/continuityRetainedStateRegistry|continuityRetainedStateRegistry(kotlin.String?;com.slack.circuit.retained.CanRetainChecker?;androidx.compose.runtime.Composer?;kotlin.Int;kotlin.Int){}[0] final fun com.slack.circuit.retained/rememberCanRetainChecker(androidx.compose.runtime/Composer?, kotlin/Int): com.slack.circuit.retained/CanRetainChecker // com.slack.circuit.retained/rememberCanRetainChecker|rememberCanRetainChecker(androidx.compose.runtime.Composer?;kotlin.Int){}[0] +final fun com.slack.circuit.retained/rememberRetainedStateHolder(androidx.compose.runtime/Composer?, kotlin/Int): com.slack.circuit.retained/RetainedStateHolder // com.slack.circuit.retained/rememberRetainedStateHolder|rememberRetainedStateHolder(androidx.compose.runtime.Composer?;kotlin.Int){}[0] // Targets: [native] abstract fun interface com.slack.circuit.retained/RetainedValueProvider : kotlin/Function0 { // com.slack.circuit.retained/RetainedValueProvider|null[0] diff --git a/circuit-retained/api/jvm/circuit-retained.api b/circuit-retained/api/jvm/circuit-retained.api index 384214896..8ac050849 100644 --- a/circuit-retained/api/jvm/circuit-retained.api +++ b/circuit-retained/api/jvm/circuit-retained.api @@ -39,7 +39,7 @@ public final class com/slack/circuit/retained/NoOpRetainedStateRegistry : com/sl public fun consumeValue (Ljava/lang/String;)Ljava/lang/Object; public fun forgetUnclaimedValues ()V public fun registerValue (Ljava/lang/String;Lcom/slack/circuit/retained/RetainedValueProvider;)Lcom/slack/circuit/retained/RetainedStateRegistry$Entry; - public fun saveAll ()V + public fun saveAll ()Ljava/util/Map; public fun saveValue (Ljava/lang/String;)V } @@ -59,11 +59,20 @@ public final class com/slack/circuit/retained/RememberRetainedKt { public static final fun rememberRetainedSaveable ([Ljava/lang/Object;Landroidx/compose/runtime/saveable/Saver;Ljava/lang/String;Lkotlin/jvm/functions/Function0;Landroidx/compose/runtime/Composer;II)Ljava/lang/Object; } +public abstract interface class com/slack/circuit/retained/RetainedStateHolder { + public abstract fun RetainedStateProvider (Ljava/lang/String;Lkotlin/jvm/functions/Function2;Landroidx/compose/runtime/Composer;I)V + public abstract fun removeState (Ljava/lang/String;)V +} + +public final class com/slack/circuit/retained/RetainedStateHolderKt { + public static final fun rememberRetainedStateHolder (Landroidx/compose/runtime/Composer;I)Lcom/slack/circuit/retained/RetainedStateHolder; +} + public abstract interface class com/slack/circuit/retained/RetainedStateRegistry { public abstract fun consumeValue (Ljava/lang/String;)Ljava/lang/Object; public abstract fun forgetUnclaimedValues ()V public abstract fun registerValue (Ljava/lang/String;Lcom/slack/circuit/retained/RetainedValueProvider;)Lcom/slack/circuit/retained/RetainedStateRegistry$Entry; - public abstract fun saveAll ()V + public abstract fun saveAll ()Ljava/util/Map; public abstract fun saveValue (Ljava/lang/String;)V } diff --git a/circuit-retained/src/androidInstrumentedTest/AndroidManifest.xml b/circuit-retained/src/androidInstrumentedTest/AndroidManifest.xml index ada8acbc5..0798a1b1f 100644 --- a/circuit-retained/src/androidInstrumentedTest/AndroidManifest.xml +++ b/circuit-retained/src/androidInstrumentedTest/AndroidManifest.xml @@ -1,6 +1,7 @@ - + + - \ No newline at end of file + diff --git a/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedStateHolderTest.kt b/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedStateHolderTest.kt new file mode 100644 index 000000000..3f5896010 --- /dev/null +++ b/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedStateHolderTest.kt @@ -0,0 +1,324 @@ +/* + * Copyright 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.slack.circuit.retained.android + +import android.annotation.SuppressLint +import android.os.Bundle +import androidx.activity.ComponentActivity +import androidx.compose.runtime.CompositionLocalProvider +import androidx.compose.runtime.MutableState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.test.junit4.createAndroidComposeRule +import com.google.common.truth.Truth.assertThat +import com.slack.circuit.retained.LocalRetainedStateRegistry +import com.slack.circuit.retained.RetainedStateHolder +import com.slack.circuit.retained.RetainedStateRegistry +import com.slack.circuit.retained.rememberRetained +import com.slack.circuit.retained.rememberRetainedStateHolder +import leakcanary.DetectLeaksAfterTestSuccess.Companion.detectLeaksAfterTestSuccessWrapping +import org.junit.Rule +import org.junit.Test +import org.junit.rules.RuleChain + +// TODO adapt for retained more +@SuppressLint("RememberReturnType") +class RetainedStateHolderTest { + + private val composeTestRule = createAndroidComposeRule() + + @get:Rule + val rule = + RuleChain.emptyRuleChain().detectLeaksAfterTestSuccessWrapping(tag = "ActivitiesDestroyed") { + around(composeTestRule) + } + + private val restorationTester = RetainedStateRestorationTester(composeTestRule) + + @Test + fun stateIsRestoredWhenGoBackToScreen1() { + var increment = 0 + var screen by mutableStateOf(Screens.Screen1) + var numberOnScreen1 = -1 + var restorableNumberOnScreen1 = -1 + restorationTester.setContent { + val holder = rememberRetainedStateHolder() + holder.RetainedStateProvider(screen.name) { + if (screen == Screens.Screen1) { + numberOnScreen1 = remember { increment++ } + restorableNumberOnScreen1 = rememberRetained { increment++ } + } else { + // screen 2 + remember { 100 } + } + } + } + + composeTestRule.runOnIdle { + assertThat(numberOnScreen1).isEqualTo(0) + assertThat(restorableNumberOnScreen1).isEqualTo(1) + screen = Screens.Screen2 + } + + // wait for the screen switch to apply + composeTestRule.runOnIdle { + numberOnScreen1 = -1 + restorableNumberOnScreen1 = -1 + // switch back to screen1 + screen = Screens.Screen1 + } + + composeTestRule.runOnIdle { + assertThat(numberOnScreen1).isEqualTo(2) + assertThat(restorableNumberOnScreen1).isEqualTo(1) + } + } + + @Test + fun simpleRestoreOnlyOneScreen() { + var increment = 0 + var number = -1 + var restorableNumber = -1 + restorationTester.setContent { + val holder = rememberRetainedStateHolder() + holder.RetainedStateProvider(Screens.Screen1.name) { + number = remember { increment++ } + restorableNumber = rememberRetained { increment++ } + } + } + + composeTestRule.runOnIdle { + assertThat(number).isEqualTo(0) + assertThat(restorableNumber).isEqualTo(1) + number = -1 + restorableNumber = -1 + } + + restorationTester.emulateRetainedInstanceStateRestore() + + composeTestRule.runOnIdle { + assertThat(number).isEqualTo(2) + assertThat(restorableNumber).isEqualTo(1) + } + } + + @Test + fun switchToScreen2AndRestore() { + var increment = 0 + var screen by mutableStateOf(Screens.Screen1) + var numberOnScreen2 = -1 + var restorableNumberOnScreen2 = -1 + restorationTester.setContent { + val holder = rememberRetainedStateHolder() + holder.RetainedStateProvider(screen.name) { + if (screen == Screens.Screen2) { + numberOnScreen2 = remember { increment++ } + restorableNumberOnScreen2 = rememberRetained { increment++ } + } + } + } + + composeTestRule.runOnIdle { screen = Screens.Screen2 } + + // wait for the screen switch to apply + composeTestRule.runOnIdle { + assertThat(numberOnScreen2).isEqualTo(0) + assertThat(restorableNumberOnScreen2).isEqualTo(1) + numberOnScreen2 = -1 + restorableNumberOnScreen2 = -1 + } + + restorationTester.emulateRetainedInstanceStateRestore() + + composeTestRule.runOnIdle { + assertThat(numberOnScreen2).isEqualTo(2) + assertThat(restorableNumberOnScreen2).isEqualTo(1) + } + } + + @Test + fun stateOfScreen1IsSavedAndRestoredWhileWeAreOnScreen2() { + var increment = 0 + var screen by mutableStateOf(Screens.Screen1) + var numberOnScreen1 = -1 + var restorableNumberOnScreen1 = -1 + restorationTester.setContent { + val holder = rememberRetainedStateHolder() + holder.RetainedStateProvider(screen.name) { + if (screen == Screens.Screen1) { + numberOnScreen1 = remember { increment++ } + restorableNumberOnScreen1 = rememberRetained { increment++ } + } else { + // screen 2 + remember { 100 } + } + } + } + + composeTestRule.runOnIdle { + assertThat(numberOnScreen1).isEqualTo(0) + assertThat(restorableNumberOnScreen1).isEqualTo(1) + screen = Screens.Screen2 + } + + // wait for the screen switch to apply + composeTestRule.runOnIdle { + numberOnScreen1 = -1 + restorableNumberOnScreen1 = -1 + } + + restorationTester.emulateRetainedInstanceStateRestore() + + // switch back to screen1 + composeTestRule.runOnIdle { screen = Screens.Screen1 } + + composeTestRule.runOnIdle { + assertThat(numberOnScreen1).isEqualTo(2) + assertThat(restorableNumberOnScreen1).isEqualTo(1) + } + } + + @Test + fun weCanSkipSavingForCurrentScreen() { + var increment = 0 + var screen by mutableStateOf(Screens.Screen1) + var restorableStateHolder: RetainedStateHolder? = null + var restorableNumberOnScreen1 = -1 + restorationTester.setContent { + val holder = rememberRetainedStateHolder() + restorableStateHolder = holder + holder.RetainedStateProvider(screen.name) { + if (screen == Screens.Screen1) { + restorableNumberOnScreen1 = rememberRetained { increment++ } + } else { + // screen 2 + remember { 100 } + } + } + } + + composeTestRule.runOnIdle { + assertThat(restorableNumberOnScreen1).isEqualTo(0) + restorableNumberOnScreen1 = -1 + restorableStateHolder!!.removeState(Screens.Screen1.name) + screen = Screens.Screen2 + } + + composeTestRule.runOnIdle { + // switch back to screen1 + screen = Screens.Screen1 + } + + composeTestRule.runOnIdle { assertThat(restorableNumberOnScreen1).isEqualTo(1) } + } + + @Test + fun weCanRemoveAlreadySavedState() { + var increment = 0 + var screen by mutableStateOf(Screens.Screen1) + var restorableStateHolder: RetainedStateHolder? = null + var restorableNumberOnScreen1 = -1 + restorationTester.setContent { + val holder = rememberRetainedStateHolder() + restorableStateHolder = holder + holder.RetainedStateProvider(screen.name) { + if (screen == Screens.Screen1) { + restorableNumberOnScreen1 = rememberRetained { increment++ } + } else { + // screen 2 + remember { 100 } + } + } + } + + composeTestRule.runOnIdle { + assertThat(restorableNumberOnScreen1).isEqualTo(0) + restorableNumberOnScreen1 = -1 + screen = Screens.Screen2 + } + + composeTestRule.runOnIdle { + // switch back to screen1 + restorableStateHolder!!.removeState(Screens.Screen1.name) + screen = Screens.Screen1 + } + + composeTestRule.runOnIdle { assertThat(restorableNumberOnScreen1).isEqualTo(1) } + } + + @Test + fun restoringStateOfThePreviousPageAfterCreatingBundle() { + var showFirstPage by mutableStateOf(true) + var firstPageState: MutableState? = null + + composeTestRule.setContent { + val holder = rememberRetainedStateHolder() + holder.RetainedStateProvider(showFirstPage.toString()) { + if (showFirstPage) { + firstPageState = rememberRetained { mutableStateOf(0) } + } + } + } + + composeTestRule.runOnIdle { + assertThat(firstPageState!!.value).isEqualTo(0) + // change the value, so we can assert this change will be restored + firstPageState!!.value = 1 + firstPageState = null + showFirstPage = false + } + + composeTestRule.runOnIdle { + composeTestRule.activity.doFakeSave() + showFirstPage = true + } + + composeTestRule.runOnIdle { assertThat(firstPageState!!.value).isEqualTo(1) } + } + + @Test + fun saveNothingWhenNoRememberRetainedIsUsedInternally() { + var showFirstPage by mutableStateOf(true) + val registry = RetainedStateRegistry(emptyMap()) + + composeTestRule.setContent { + CompositionLocalProvider(LocalRetainedStateRegistry provides registry) { + val holder = rememberRetainedStateHolder() + holder.RetainedStateProvider(showFirstPage.toString()) {} + } + } + + composeTestRule.runOnIdle { showFirstPage = false } + + composeTestRule.runOnIdle { + val savedData = registry.saveAll() + assertThat(savedData).isEqualTo(emptyMap>()) + } + } + + class Activity : ComponentActivity() { + fun doFakeSave() { + onSaveInstanceState(Bundle()) + } + } +} + +enum class Screens { + Screen1, + Screen2, +} diff --git a/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedStateRestorationTester.kt b/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedStateRestorationTester.kt new file mode 100644 index 000000000..99e7ab8cf --- /dev/null +++ b/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedStateRestorationTester.kt @@ -0,0 +1,138 @@ +/* + * Copyright 2020 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.slack.circuit.retained.android + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.CompositionLocalProvider +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.test.junit4.ComposeContentTestRule +import com.slack.circuit.retained.LocalRetainedStateRegistry +import com.slack.circuit.retained.RetainedStateRegistry +import com.slack.circuit.retained.RetainedValueProvider +import com.slack.circuit.retained.rememberRetained + +/** + * Helps to test the retained state restoration for your Composable component. + * + * Instead of calling [ComposeContentTestRule.setContent] you need to use [setContent] on this + * object, then change your state so there is some change to be restored, then execute + * [emulateRetainedInstanceStateRestore] and assert your state is restored properly. + * + * Note that this tests only the restoration of the local state of the composable you passed to + * [setContent] and useful for testing [rememberRetained] integration. It is not testing the + * integration with any other life cycles or Activity callbacks. + */ +// TODO recreate for more realism? Need to save the content function to do that, or call it after +// TODO make this available in test utils? +class RetainedStateRestorationTester(private val composeTestRule: ComposeContentTestRule) { + + private var registry: RestorationRegistry? = null + + /** + * This functions is a direct replacement for [ComposeContentTestRule.setContent] if you are going + * to use [emulateRetainedInstanceStateRestore] in the test. + * + * @see ComposeContentTestRule.setContent + */ + fun setContent(composable: @Composable () -> Unit) { + composeTestRule.setContent { + CompositionLocalProvider( + LocalRetainedStateRegistry provides remember { RetainedStateRegistry() } + ) { + InjectRestorationRegistry { registry -> + this.registry = registry + composable() + } + } + } + } + + /** + * Saves all the state stored via [rememberRetained], disposes current composition, and composes + * again the content passed to [setContent]. Allows to test how your component behaves when the + * state restoration is happening. Note that the state stored via regular state() or remember() + * will be lost. + */ + fun emulateRetainedInstanceStateRestore() { + val registry = checkNotNull(registry) { "setContent should be called first!" } + composeTestRule.runOnIdle { registry.saveStateAndDisposeChildren() } + composeTestRule.runOnIdle { registry.emitChildrenWithRestoredState() } + composeTestRule.runOnIdle { + // we just wait for the children to be emitted + } + } + + @Composable + private fun InjectRestorationRegistry(content: @Composable (RestorationRegistry) -> Unit) { + val original = + requireNotNull(LocalRetainedStateRegistry.current) { + "StateRestorationTester requires composeTestRule.setContent() to provide " + + "a RetainedStateRegistry implementation via LocalRetainedStateRegistry" + } + val restorationRegistry = remember { RestorationRegistry(original) } + CompositionLocalProvider(LocalRetainedStateRegistry provides restorationRegistry) { + if (restorationRegistry.shouldEmitChildren) { + content(restorationRegistry) + } + } + } + + private class RestorationRegistry(private val original: RetainedStateRegistry) : + RetainedStateRegistry { + + var shouldEmitChildren by mutableStateOf(true) + private set + + private var currentRegistry: RetainedStateRegistry = original + private var savedMap: Map> = emptyMap() + + fun saveStateAndDisposeChildren() { + savedMap = currentRegistry.saveAll() + shouldEmitChildren = false + } + + fun emitChildrenWithRestoredState() { + currentRegistry = RetainedStateRegistry(values = savedMap) + shouldEmitChildren = true + } + + override fun consumeValue(key: String): Any? { + return currentRegistry.consumeValue(key) + } + + override fun registerValue( + key: String, + valueProvider: RetainedValueProvider, + ): RetainedStateRegistry.Entry { + return currentRegistry.registerValue(key, valueProvider) + } + + override fun saveAll(): Map> { + return currentRegistry.saveAll() + } + + override fun saveValue(key: String) { + currentRegistry.saveValue(key) + } + + override fun forgetUnclaimedValues() { + currentRegistry.forgetUnclaimedValues() + } + } +} diff --git a/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedTest.kt b/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedTest.kt index ea90be9fe..7007f4383 100644 --- a/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedTest.kt +++ b/circuit-retained/src/androidInstrumentedTest/kotlin/com/slack/circuit/retained/android/RetainedTest.kt @@ -14,7 +14,7 @@ import androidx.compose.runtime.CompositionLocalProvider import androidx.compose.runtime.RememberObserver import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue -import androidx.compose.runtime.key +import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.setValue @@ -23,6 +23,7 @@ import androidx.compose.ui.platform.testTag import androidx.compose.ui.test.assert import androidx.compose.ui.test.assertIsDisplayed import androidx.compose.ui.test.assertTextContains +import androidx.compose.ui.test.assertTextEquals import androidx.compose.ui.test.hasText import androidx.compose.ui.test.junit4.createAndroidComposeRule import androidx.compose.ui.test.onNodeWithTag @@ -41,6 +42,7 @@ import com.slack.circuit.retained.LocalRetainedStateRegistry import com.slack.circuit.retained.RetainedStateRegistry import com.slack.circuit.retained.continuityRetainedStateRegistry import com.slack.circuit.retained.rememberRetained +import com.slack.circuit.retained.rememberRetainedStateHolder import kotlinx.coroutines.flow.MutableStateFlow import leakcanary.DetectLeaksAfterTestSuccess.Companion.detectLeaksAfterTestSuccessWrapping import org.junit.Rule @@ -53,6 +55,7 @@ private const val TAG_RETAINED_2 = "retained2" private const val TAG_RETAINED_3 = "retained3" private const val TAG_BUTTON_SHOW = "btn_show" private const val TAG_BUTTON_HIDE = "btn_hide" +private const val TAG_BUTTON_INC = "btn_inc" class RetainedTest { private val composeTestRule = createAndroidComposeRule() @@ -364,10 +367,10 @@ class RetainedTest { val content = @Composable { - val nestedRegistryLevel1 = rememberRetained { RetainedStateRegistry() } - CompositionLocalProvider(LocalRetainedStateRegistry provides nestedRegistryLevel1) { - val nestedRegistryLevel2 = rememberRetained { RetainedStateRegistry() } - CompositionLocalProvider(LocalRetainedStateRegistry provides nestedRegistryLevel2) { + val holder1 = rememberRetainedStateHolder() + holder1.RetainedStateProvider("registry1") { + val holder2 = rememberRetainedStateHolder() + holder2.RetainedStateProvider("registry2") { @Suppress("UNUSED_VARIABLE") val retainedSubject = rememberRetained { subject } } } @@ -394,6 +397,54 @@ class RetainedTest { assertThat(subject.onForgottenCalled).isEqualTo(1) } + @Test + fun conditionalRetainBeforeSave() { + val registry = RetainedStateRegistry() + val content = @Composable { ConditionalRetainContent(registry) } + setActivityContent(content) + + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertDoesNotExist() + + composeTestRule.onNodeWithTag(TAG_BUTTON_SHOW).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertIsDisplayed() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertTextEquals("0") + + composeTestRule.onNodeWithTag(TAG_BUTTON_INC).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertTextEquals("1") + + composeTestRule.onNodeWithTag(TAG_BUTTON_HIDE).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertDoesNotExist() + + composeTestRule.onNodeWithTag(TAG_BUTTON_SHOW).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertIsDisplayed() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertTextEquals("0") + } + + @Test + fun conditionalRetainAfterSave() { + val registry = RetainedStateRegistry() + val content = @Composable { ConditionalRetainContent(registry) } + setActivityContent(content) + + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertDoesNotExist() + + composeTestRule.onNodeWithTag(TAG_BUTTON_SHOW).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertIsDisplayed() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertTextEquals("0") + + composeTestRule.onNodeWithTag(TAG_BUTTON_INC).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertTextEquals("1") + + composeTestRule.onNodeWithTag(TAG_BUTTON_HIDE).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertDoesNotExist() + + registry.saveAll() + + composeTestRule.onNodeWithTag(TAG_BUTTON_SHOW).performClick() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertIsDisplayed() + composeTestRule.onNodeWithTag(TAG_RETAINED_1).assertTextEquals("0") + } + private fun nestedRegistriesWithPopAndPush(useKeys: Boolean) { val content = @Composable { NestedRetainWithPushAndPop(useKeys = useKeys) } setActivityContent(content) @@ -584,10 +635,8 @@ private fun NestedRetains(useKeys: Boolean) { label = {}, ) - val nestedRegistryLevel1 = rememberRetained { RetainedStateRegistry() } - CompositionLocalProvider(LocalRetainedStateRegistry provides nestedRegistryLevel1) { - NestedRetainLevel1(useKeys) - } + val nestedStateHolderLevel1 = rememberRetainedStateHolder() + nestedStateHolderLevel1.RetainedStateProvider("registryLevel1") { NestedRetainLevel1(useKeys) } } } @@ -603,10 +652,8 @@ private fun NestedRetainLevel1(useKeys: Boolean) { label = {}, ) - val nestedRegistry = rememberRetained { RetainedStateRegistry() } - CompositionLocalProvider(LocalRetainedStateRegistry provides nestedRegistry) { - NestedRetainLevel2(useKeys) - } + val nestedStateHolderLevel2 = rememberRetainedStateHolder() + nestedStateHolderLevel2.RetainedStateProvider("registryLevel2") { NestedRetainLevel2(useKeys) } } @Composable @@ -651,16 +698,11 @@ private fun NestedRetainWithPushAndPop(useKeys: Boolean) { Text(text = "Show child") } + val retainedStateHolder = rememberRetainedStateHolder() // Keep the retained state registry around even if showNestedContent becomes false CompositionLocalProvider(LocalCanRetainChecker provides CanRetainChecker.Always) { if (showNestedContent.value) { - val nestedRegistry = rememberRetained { RetainedStateRegistry() } - CompositionLocalProvider( - LocalRetainedStateRegistry provides nestedRegistry, - LocalCanRetainChecker provides CanRetainChecker.Always, - ) { - NestedRetainLevel1(useKeys) - } + retainedStateHolder.RetainedStateProvider("registry") { NestedRetainLevel1(useKeys) } } } } @@ -696,15 +738,10 @@ private fun NestedRetainWithPushAndPopAndCannotRetain(useKeys: Boolean) { } // Keep the retained state registry around even if showNestedContent becomes false - CompositionLocalProvider(LocalCanRetainChecker provides CanRetainChecker.Always) { + val holder = rememberRetainedStateHolder() + CompositionLocalProvider(LocalCanRetainChecker provides { false }) { if (showNestedContent.value) { - val nestedRegistry = rememberRetained { RetainedStateRegistry() } - CompositionLocalProvider( - LocalRetainedStateRegistry provides nestedRegistry, - LocalCanRetainChecker provides { false }, - ) { - NestedRetainLevel1(useKeys) - } + holder.RetainedStateProvider("registry") { NestedRetainLevel1(useKeys) } } } } @@ -729,3 +766,25 @@ private fun InputsContent(input: String) { ) } } + +@Composable +private fun ConditionalRetainContent(registry: RetainedStateRegistry) { + CompositionLocalProvider(LocalRetainedStateRegistry provides registry) { + var showContent by remember { mutableStateOf(false) } + Column { + Button(modifier = Modifier.testTag(TAG_BUTTON_HIDE), onClick = { showContent = false }) { + Text(text = "Hide content") + } + Button(modifier = Modifier.testTag(TAG_BUTTON_SHOW), onClick = { showContent = true }) { + Text(text = "Show content") + } + if (showContent) { + var count by rememberRetained { mutableIntStateOf(0) } + Button(modifier = Modifier.testTag(TAG_BUTTON_INC), onClick = { count += 1 }) { + Text(text = "Increment") + } + Text(modifier = Modifier.testTag(TAG_RETAINED_1), text = count.toString()) + } + } + } +} diff --git a/circuit-retained/src/androidMain/kotlin/com/slack/circuit/retained/AndroidContinuity.kt b/circuit-retained/src/androidMain/kotlin/com/slack/circuit/retained/AndroidContinuity.kt index 8e99abad0..f354c8d9e 100644 --- a/circuit-retained/src/androidMain/kotlin/com/slack/circuit/retained/AndroidContinuity.kt +++ b/circuit-retained/src/androidMain/kotlin/com/slack/circuit/retained/AndroidContinuity.kt @@ -5,11 +5,10 @@ package com.slack.circuit.retained import androidx.annotation.VisibleForTesting import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect -import androidx.compose.runtime.RememberObserver -import androidx.compose.runtime.remember import androidx.compose.runtime.withFrameNanos import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.compose.LifecycleStartEffect import androidx.lifecycle.viewmodel.CreationExtras import androidx.lifecycle.viewmodel.compose.viewModel @@ -27,8 +26,8 @@ internal class ContinuityViewModel : ViewModel(), RetainedStateRegistry { return delegate.registerValue(key, valueProvider) } - override fun saveAll() { - delegate.saveAll() + override fun saveAll(): Map> { + return delegate.saveAll() } override fun saveValue(key: String) { @@ -40,7 +39,7 @@ internal class ContinuityViewModel : ViewModel(), RetainedStateRegistry { } override fun onCleared() { - delegate.retained.clear() + delegate.forgetUnclaimedValues() delegate.valueProviders.clear() } @@ -86,20 +85,10 @@ public fun continuityRetainedStateRegistry( @Suppress("ComposeViewModelInjection") val vm = viewModel(key = key, factory = factory) - remember(vm, canRetainChecker) { - object : RememberObserver { - override fun onAbandoned() = saveIfRetainable() - - override fun onForgotten() = saveIfRetainable() - - override fun onRemembered() { - // Do nothing - } - - fun saveIfRetainable() { - if (canRetainChecker.canRetain(vm)) { - vm.saveAll() - } + LifecycleStartEffect(vm) { + onStopOrDispose { + if (canRetainChecker.canRetain(vm)) { + vm.saveAll() } } } diff --git a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/NoOpRetainedStateRegistry.kt b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/NoOpRetainedStateRegistry.kt index 1adf45a74..f3f655bef 100644 --- a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/NoOpRetainedStateRegistry.kt +++ b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/NoOpRetainedStateRegistry.kt @@ -11,7 +11,7 @@ public object NoOpRetainedStateRegistry : RetainedStateRegistry { valueProvider: RetainedValueProvider, ): RetainedStateRegistry.Entry = NoOpEntry - override fun saveAll() {} + override fun saveAll(): Map> = emptyMap() override fun saveValue(key: String) {} diff --git a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RememberRetained.kt b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RememberRetained.kt index 431dbf90f..762f1f1ad 100644 --- a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RememberRetained.kt +++ b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RememberRetained.kt @@ -328,33 +328,6 @@ private class RetainableSaveableHolder( return registry == null || registry.canBeSaved(value) } - fun saveIfRetainable() { - val v = value ?: return - val reg = retainedStateRegistry ?: return - - if (!canRetainChecker.canRetain(reg)) { - retainedStateEntry?.unregister() - when (v) { - // If value is a RememberObserver, we notify that it has been forgotten. - is RememberObserver -> v.onForgotten() - // Or if its a registry, we need to tell it to clear, which will forward the 'forgotten' - // call onto its values - is RetainedStateRegistry -> { - // First we saveAll, which flattens down the value providers to our retained list - v.saveAll() - // Now we drop all retained values - v.forgetUnclaimedValues() - } - } - } else if (v is RetainedStateRegistry) { - // If the value is a RetainedStateRegistry, we need to take care to retain it. - // First we tell it to saveAll, to retain it's values. Then we need to tell the host - // registry to retain the child registry. - v.saveAll() - reg.saveValue(key) - } - } - override fun onRemembered() { registerRetained() registerSaveable() @@ -367,13 +340,24 @@ private class RetainableSaveableHolder( } override fun onForgotten() { - saveIfRetainable() - saveableStateEntry?.unregister() + release() } override fun onAbandoned() { - saveIfRetainable() + release() + } + + private fun release() { + val v = value + val reg = retainedStateRegistry + if (reg != null && !canRetainChecker.canRetain(reg)) { + when (v) { + is RememberObserver -> v.onForgotten() + is RetainedStateRegistry -> v.forgetUnclaimedValues() + } + } saveableStateEntry?.unregister() + retainedStateEntry?.unregister() } fun getValueIfInputsAreEqual(inputs: Array): T? { diff --git a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateHolder.kt b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateHolder.kt new file mode 100644 index 000000000..5c829a790 --- /dev/null +++ b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateHolder.kt @@ -0,0 +1,109 @@ +// Copyright (C) 2024 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package com.slack.circuit.retained + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.CompositionLocalProvider +import androidx.compose.runtime.DisposableEffect +import androidx.compose.runtime.ReusableContent +import androidx.compose.runtime.remember + +/** + * A holder that provides a unique retainedStateRegistry for each subtree and retains all preserved + * values. Each [RetainedStateProvider] maintains a unique retainedStateRegistry for each key, + * allowing it to save and restore states. + */ +public interface RetainedStateHolder { + + /** + * Provides a [RetainedStateRegistry] for the child [content] based on the specified [key]. Before + * the provided registry is disposed, it calls [RetainedStateRegistry.saveValue] on the holder's + * registry to save the current value, allowing it to be restored on the next visit with the same + * key. + */ + @Composable public fun RetainedStateProvider(key: String, content: @Composable () -> Unit) + + /** Removes the retained state associated with the passed [key]. */ + public fun removeState(key: String) +} + +/** Creates and remembers the instance of [RetainedStateHolder]. */ +@Composable +public fun rememberRetainedStateHolder(): RetainedStateHolder { + return rememberRetained { RetainedStateHolderImpl() } +} + +private class RetainedStateHolderImpl : RetainedStateHolder, RetainedStateRegistry { + + private val registry: RetainedStateRegistry = RetainedStateRegistry() + + private val canRetainCheckers = mutableMapOf() + + @Composable + override fun RetainedStateProvider(key: String, content: @Composable (() -> Unit)) { + CompositionLocalProvider(LocalRetainedStateRegistry provides registry) { + val parentCanRetainChecker = LocalCanRetainChecker.current ?: CanRetainChecker.Always + ReusableContent(key) { + val entryCanRetainChecker = + remember(parentCanRetainChecker) { EntryCanRetainChecker(parentCanRetainChecker) } + val childRegistry = rememberRetained(key = key) { RetainedStateRegistry() } + CompositionLocalProvider( + LocalRetainedStateRegistry provides childRegistry, + LocalCanRetainChecker provides CanRetainChecker.Always, + content = content, + ) + DisposableEffect(Unit) { + canRetainCheckers[key] = entryCanRetainChecker + onDispose { + val retained = childRegistry.saveAll() + if (retained.isNotEmpty() && entryCanRetainChecker.canRetain(registry)) { + registry.saveValue(key) + } + canRetainCheckers -= key + } + } + } + } + } + + override fun removeState(key: String) { + val canRetainChecker = canRetainCheckers[key] + if (canRetainChecker != null) { + canRetainChecker.shouldSave = false + } else { + registry.consumeValue(key) + } + } + + override fun consumeValue(key: String): Any? { + return registry.consumeValue(key) + } + + override fun registerValue( + key: String, + valueProvider: RetainedValueProvider, + ): RetainedStateRegistry.Entry { + return registry.registerValue(key, valueProvider) + } + + override fun saveAll(): Map> { + return registry.saveAll() + } + + override fun saveValue(key: String) { + registry.saveValue(key) + } + + override fun forgetUnclaimedValues() { + registry.forgetUnclaimedValues() + } + + private class EntryCanRetainChecker(private val parentChecker: CanRetainChecker) : + CanRetainChecker { + + var shouldSave = true + + override fun canRetain(registry: RetainedStateRegistry): Boolean = + parentChecker.canRetain(registry) && shouldSave + } +} diff --git a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateRegistry.kt b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateRegistry.kt index 15d1a8569..25fa5e0cc 100644 --- a/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateRegistry.kt +++ b/circuit-retained/src/commonMain/kotlin/com/slack/circuit/retained/RetainedStateRegistry.kt @@ -41,7 +41,7 @@ public interface RetainedStateRegistry { * Executes all the registered value providers and combines these values into a map. We have a * list of values for each key as it is allowed to have multiple providers for the same key. */ - public fun saveAll() + public fun saveAll(): Map> /** Executes the value providers registered with the given [key], and saves them for retrieval. */ public fun saveValue(key: String) @@ -109,9 +109,19 @@ internal class RetainedStateRegistryImpl(retained: MutableMap } } - override fun saveAll() { - val values = - valueProviders.mapValues { (_, list) -> + override fun saveAll(): Map> { + fun save(value: Any?): Boolean { + return when (value) { + // If we get a RetainedHolder value, need to unwrap and call again + is RetainedValueHolder<*> -> save(value.value) + // Dispatch the call to nested registries + is RetainedStateRegistry -> value.saveAll().isNotEmpty() + else -> true + } + } + + valueProviders + .mapValues { (_, list) -> // If we have multiple providers we should store null values as well to preserve // the order in which providers were registered. Say there were two providers. // the first provider returned null(nothing to save) and the second one returned @@ -119,13 +129,16 @@ internal class RetainedStateRegistryImpl(retained: MutableMap // same as to have nothing to restore) and the second one restore "1". list.map(RetainedValueProvider::invoke) } - - if (values.isNotEmpty()) { - // Store the values in our retained map - retained.putAll(values) - } + .forEach { (key, value) -> + val filtered = value.filter { save(it) } + if (filtered.isNotEmpty()) { + // Store the values in our retained map + retained[key] = filtered + } + } // Clear the value providers now that we've stored the values valueProviders.clear() + return retained } override fun saveValue(key: String) { diff --git a/circuitx/effects/src/androidUnitTest/kotlin/com/slack/circuitx/effects/RememberImpressionNavigatorTest.kt b/circuitx/effects/src/androidUnitTest/kotlin/com/slack/circuitx/effects/RememberImpressionNavigatorTest.kt index 6e9d23c8d..5aff78c73 100644 --- a/circuitx/effects/src/androidUnitTest/kotlin/com/slack/circuitx/effects/RememberImpressionNavigatorTest.kt +++ b/circuitx/effects/src/androidUnitTest/kotlin/com/slack/circuitx/effects/RememberImpressionNavigatorTest.kt @@ -177,8 +177,8 @@ class RememberImpressionNavigatorTest { } private fun ComposeContentTestRule.recreate() { - composed.value = false registry.saveAll() + composed.value = false waitForIdle() composed.value = true waitForIdle() diff --git a/circuitx/effects/src/commonTest/kotlin/com/slack/circuitx/effects/ImpressionEffectTest.kt b/circuitx/effects/src/commonTest/kotlin/com/slack/circuitx/effects/ImpressionEffectTest.kt index 59da3be2c..84fbdeab6 100644 --- a/circuitx/effects/src/commonTest/kotlin/com/slack/circuitx/effects/ImpressionEffectTest.kt +++ b/circuitx/effects/src/commonTest/kotlin/com/slack/circuitx/effects/ImpressionEffectTest.kt @@ -171,8 +171,8 @@ internal class ImpressionEffectTestSharedImpl : ImpressionEffectTestShared { /** Simulate a retained leaving and joining of the composition. */ private fun recreate() { - composed.value = false registry.saveAll() + composed.value = false composed.value = true } } diff --git a/samples/star/src/commonMain/kotlin/com/slack/circuit/star/home/HomeScreen.kt b/samples/star/src/commonMain/kotlin/com/slack/circuit/star/home/HomeScreen.kt index b5d7e6b0b..86e138558 100644 --- a/samples/star/src/commonMain/kotlin/com/slack/circuit/star/home/HomeScreen.kt +++ b/samples/star/src/commonMain/kotlin/com/slack/circuit/star/home/HomeScreen.kt @@ -21,6 +21,7 @@ import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.saveable.rememberSaveable +import androidx.compose.runtime.saveable.rememberSaveableStateHolder import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier import androidx.compose.ui.graphics.Color @@ -31,6 +32,7 @@ import com.slack.circuit.foundation.CircuitContent import com.slack.circuit.foundation.NavEvent import com.slack.circuit.foundation.onNavEvent import com.slack.circuit.retained.rememberRetained +import com.slack.circuit.retained.rememberRetainedStateHolder import com.slack.circuit.runtime.CircuitUiEvent import com.slack.circuit.runtime.CircuitUiState import com.slack.circuit.runtime.Navigator @@ -127,12 +129,19 @@ fun HomeContent(state: HomeScreen.State, modifier: Modifier = Modifier) = }, ) { paddingValues -> contentComposed = true - val screen = state.navItems[state.selectedIndex].screen - CircuitContent( - screen, - modifier = Modifier.padding(paddingValues), - onNavEvent = { event -> state.eventSink(ChildNav(event)) }, - ) + val saveableStateHolder = rememberSaveableStateHolder() + val retainedStateHolder = rememberRetainedStateHolder() + val currentScreen = state.navItems[state.selectedIndex].screen + saveableStateHolder.SaveableStateProvider(currentScreen) { + retainedStateHolder.RetainedStateProvider(state.selectedIndex.toString()) { + CircuitContent( + currentScreen, + modifier = Modifier.padding(paddingValues), + onNavEvent = { event -> state.eventSink(ChildNav(event)) }, + ) + } + } + contentComposed = true } Platform.ReportDrawnWhen { contentComposed } }