Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: computed should be glitch free #593

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 85 additions & 8 deletions solara/toestand.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def merge_state(d1: S, **kwargs) -> S:
class ValueBase(Generic[T]):
def __init__(self, merge: Callable = merge_state):
self.merge = merge
self.listeners_changed: Dict[str, Set[Tuple[Callable[[], None], Optional[ContextManager]]]] = defaultdict(set)
self.listeners: Dict[str, Set[Tuple[Callable[[T], None], Optional[ContextManager]]]] = defaultdict(set)
self.listeners2: Dict[str, Set[Tuple[Callable[[T, T], None], Optional[ContextManager]]]] = defaultdict(set)

Expand Down Expand Up @@ -116,6 +117,15 @@ def get(self) -> T:
def _get_scope_key(self):
raise NotImplementedError

def subscribe_changed(self, listener: Callable[[], None], scope: Optional[ContextManager] = None):
scope_id = self._get_scope_key()
self.listeners_changed[scope_id].add((listener, scope))

def cleanup():
self.listeners_changed[scope_id].remove((listener, scope))

return cleanup

def subscribe(self, listener: Callable[[T], None], scope: Optional[ContextManager] = None):
scope_id = self._get_scope_key()
self.listeners[scope_id].add((listener, scope))
Expand All @@ -134,10 +144,32 @@ def cleanup():

return cleanup

def _fire_changed(self):
logger.info("value changed from will fire changed events")
scope_id = self._get_scope_key()
scopes = set()
listeners_changed = self.listeners_changed[scope_id].copy()
if not listeners_changed:
return
for listener_changed, scope in listeners_changed:
if scope is not None:
scopes.add(scope)
stack = contextlib.ExitStack()
with contextlib.ExitStack() as stack:
for scope in scopes:
stack.enter_context(scope)
for listener_changed, scope in listeners_changed:
# TODO: disable getting state
listener_changed()

def fire(self, new: T, old: T):
logger.info("value change from %s to %s, will fire events", old, new)
scope_id = self._get_scope_key()
scopes = set()
# TODO: early return if no listeners
for listener_changed, scope in self.listeners_changed[scope_id].copy():
if scope is not None:
scopes.add(scope)
for listener, scope in self.listeners[scope_id].copy():
if scope is not None:
scopes.add(scope)
Expand All @@ -148,6 +180,17 @@ def fire(self, new: T, old: T):
with contextlib.ExitStack() as stack:
for scope in scopes:
stack.enter_context(scope)
# this is the first phase of the fire, we only notify the listeners that the value has changed
# but not what the value is. This is the phase where all computed values are invalidated
for listener_changed, scope in self.listeners_changed[scope_id].copy():
# during this event handling we should not allow state updates (mayne not even reads) as that would
# trigger new events, and we are in the middle of handling events.
# TODO: disable getting state
listener_changed()
# we still support the old way of listening to changes, but ideally we deprecate this
# as sync event handling is difficult to get right.
# This will be difficult to do without, since depending on a ref to a field should not
# trigger a re-render which currently requires knowing the value of the field
for listener, scope in self.listeners[scope_id].copy():
listener(new)
for listener2, scope in self.listeners2[scope_id].copy():
Expand Down Expand Up @@ -365,6 +408,9 @@ def peek(self) -> S:
"""Return the value without automatically subscribing to listeners."""
return self._storage.peek()

def subscribe_changed(self, listener: Callable[[], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe_changed(listener, scope=scope)

def subscribe(self, listener: Callable[[S], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe(listener, scope=scope)

Expand Down Expand Up @@ -407,19 +453,28 @@ def __init__(self, f: Callable[[], S], key=None):

self.f = f

def on_change(*ignore):
with self._auto_subscriber.value:
self.set(f())
def on_change():
self._dirty.set(True) # settings state should not be allowed
# listeners are attached to the storage
self._storage._fire_changed()
scope_id = self._storage._get_scope_key()
if self._storage.listeners[scope_id] or self._storage.listeners2[scope_id]:
# DeprecationWarning: Using .subscribe and .subscribe_change on a computed value is not supported
warnings.warn("Using .subscribe and .subscribe_change on a computed value is deprecated, use .subscribe_changed", DeprecationWarning)
self._ensure_computed()

import functools

self._auto_subscriber = Singleton(functools.wraps(AutoSubscribeContextManager)(lambda: AutoSubscribeContextManager(on_change)))
# we should have a KernelVar, similar to ContextVar, or threading.local since we don't need/want reactivity
self._dirty = Reactive(False)

@functools.wraps(f)
def factory():
v = self._auto_subscriber.value
with v:
return f()
_auto_subscriber = self._auto_subscriber.value
with _auto_subscriber:
value = f()
return value

super().__init__(KernelStoreFactory(factory, key=key))

Expand All @@ -432,6 +487,18 @@ def cleanup():

solara.lifecycle.on_kernel_start(reset)

def _ensure_computed(self):
if self._dirty.peek():
self._dirty.set(False)
with self._auto_subscriber.value:
self.set(self.f())

def get(self):
self._ensure_computed()
if thread_local.reactive_used is not None:
thread_local.reactive_used.add(self)
return self._storage.get()

def __repr__(self):
value = super().__repr__()
return "<Computed" + value[len("<Reactive") : -1]
Expand Down Expand Up @@ -524,6 +591,12 @@ def __repr__(self):
def lock(self):
return self._root.lock

def subscribe_changed(self, listener: Callable[[], None], scope: Optional[ContextManager] = None):
def on_changed():
listener()

return self._root.subscribe_changed(on_changed, scope=scope)

def subscribe(self, listener: Callable[[T], None], scope: Optional[ContextManager] = None):
def on_change(new, old):
try:
Expand Down Expand Up @@ -702,7 +775,11 @@ def update_subscribers(self, change_handler, scope=None):

for reactive in added:
if reactive not in self.subscribed:
unsubscribe = reactive.subscribe_change(change_handler, scope=scope)
# if we subscribe to subscribe_changed instead, we get too many false positives
# and we would render too often, the main issue is that a ref to a computed
# will say 'i may have changed' via subscribe_changed, but the subscribe_change
# in the field will do an (eager) comparison, avoiding the false positive
unsubscribe = reactive.subscribe_change(lambda a, b: change_handler(), scope=scope)
self.subscribed[reactive] = unsubscribe
for reactive in removed:
unsubscribe = self.subscribed[reactive]
Expand Down Expand Up @@ -732,7 +809,7 @@ def __init__(self, element: solara.Element):
def __enter__(self):
_, set_counter = solara.use_state(0, key="auto_subscribe_force_update_counter")

def force_update(new_value, old_value):
def force_update():
# can we do just x+1 to collapse multiple updates into one?
set_counter(lambda x: x + 1)

Expand Down
64 changes: 62 additions & 2 deletions tests/unit/toestand_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,8 @@ def Main():
assert rc.find(v.Slider).widget.v_model == 2
y.value = "hello"
assert rc.find(v.Slider).widget.label == "hello"
assert len(x._storage.listeners2) == 1
assert len(x._storage.listeners2[kernel_context.id]) == 1
assert len(x._storage.listeners_changed[kernel_context.id]) == 0
# force an extra listener
x.value = 0
# and remove it
Expand All @@ -837,6 +838,7 @@ def Main():
rc.close()
assert not x._storage.listeners[kernel_context.id]
assert not x._storage.listeners2[kernel_context.id]
assert not x._storage.listeners_changed[kernel_context.id]


def test_reactive_auto_subscribe_sub():
Expand All @@ -861,7 +863,7 @@ def Test():
ref.value += 1
assert rc.find(v.Alert).widget.children[0] == "2 bears around here"
assert reactive_used == {ref}
# now check that we didn't listen to the while object, just count changes
# now check that we didn't listen to the whole object, just count changes
renders_before = renders
Ref(bears.fields.type).value = "pink"
assert renders == renders_before
Expand Down Expand Up @@ -1233,3 +1235,61 @@ class Person(BaseModel):
assert person.get().height == 2.0
assert Ref(person.fields.name).get() == "Maria"
assert Ref(person.fields.height).get() == 2.0


@dataclasses.dataclass(frozen=True)
class Profile:
name: str
surname: str


def test_computed_glitch_invalid_state_without_error():
from solara.toestand import Computed

profile = Reactive(Profile(name="John", surname="Doe"))

name = Computed(lambda: profile.value.name)
surname = Computed(lambda: profile.value.surname)

computed_initials = []

def compute_initials():
initials = name.value[0] + surname.value[0]
computed_initials.append(initials)
return initials

initials = Computed(compute_initials)

assert name.value == "John"
assert surname.value == "Doe"
assert initials.value == "JD"
assert computed_initials == ["JD"]

profile.value = Profile(name="Rosa", surname="Breddels")

assert name.value == "Rosa"
assert surname.value == "Breddels"
assert initials.value == "RB"
assert computed_initials == ["JD", "RB"]


@dataclasses.dataclass(frozen=True)
class CountrySelection:
countries: List[str]
selected: str


def test_computed_glitch_invalid_state_with_error():
from solara.toestand import Computed

country_selection = Reactive(CountrySelection(countries=["Netherlands", "Belgium", "Germany"], selected="Germany"))

countries = Computed(lambda: country_selection.value.countries)
selected = Computed(lambda: country_selection.value.selected)
selected_index = Computed(lambda: countries.value.index(selected.value))

assert selected_index.value == 2

country_selection.value = CountrySelection(countries=["China", "Japan"], selected="Japan")

assert selected_index.value == 1
Loading