Skip to content

Commit

Permalink
POC: use a dirty flag
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Apr 12, 2024
1 parent a6fe198 commit 30f207d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 10 deletions.
93 changes: 85 additions & 8 deletions solara/toestand.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def merge_state(d1: S, **kwargs) -> S:
class ValueBase(Generic[T]):
def __init__(self, merge: Callable = merge_state):
self.merge = merge
self.listeners_changed: Dict[str, Set[Tuple[Callable[[], None], Optional[ContextManager]]]] = defaultdict(set)
self.listeners: Dict[str, Set[Tuple[Callable[[T], None], Optional[ContextManager]]]] = defaultdict(set)
self.listeners2: Dict[str, Set[Tuple[Callable[[T, T], None], Optional[ContextManager]]]] = defaultdict(set)

Expand Down Expand Up @@ -116,6 +117,15 @@ def get(self) -> T:
def _get_scope_key(self):
raise NotImplementedError

def subscribe_changed(self, listener: Callable[[], None], scope: Optional[ContextManager] = None):
scope_id = self._get_scope_key()
self.listeners_changed[scope_id].add((listener, scope))

def cleanup():
self.listeners_changed[scope_id].remove((listener, scope))

return cleanup

def subscribe(self, listener: Callable[[T], None], scope: Optional[ContextManager] = None):
scope_id = self._get_scope_key()
self.listeners[scope_id].add((listener, scope))
Expand All @@ -134,10 +144,32 @@ def cleanup():

return cleanup

def _fire_changed(self):
logger.info("value changed from will fire changed events")
scope_id = self._get_scope_key()
scopes = set()
listeners_changed = self.listeners_changed[scope_id].copy()
if not listeners_changed:
return
for listener_changed, scope in listeners_changed:
if scope is not None:
scopes.add(scope)
stack = contextlib.ExitStack()
with contextlib.ExitStack() as stack:
for scope in scopes:
stack.enter_context(scope)
for listener_changed, scope in listeners_changed:
# TODO: disable getting state
listener_changed()

def fire(self, new: T, old: T):
logger.info("value change from %s to %s, will fire events", old, new)
scope_id = self._get_scope_key()
scopes = set()
# TODO: early return if no listeners
for listener_changed, scope in self.listeners_changed[scope_id].copy():
if scope is not None:
scopes.add(scope)
for listener, scope in self.listeners[scope_id].copy():
if scope is not None:
scopes.add(scope)
Expand All @@ -148,6 +180,17 @@ def fire(self, new: T, old: T):
with contextlib.ExitStack() as stack:
for scope in scopes:
stack.enter_context(scope)
# this is the first phase of the fire, we only notify the listeners that the value has changed
# but not what the value is. This is the phase where all computed values are invalidated
for listener_changed, scope in self.listeners_changed[scope_id].copy():
# during this event handling we should not allow state updates (mayne not even reads) as that would
# trigger new events, and we are in the middle of handling events.
# TODO: disable getting state
listener_changed()
# we still support the old way of listening to changes, but ideally we deprecate this
# as sync event handling is difficult to get right.
# This will be difficult to do without, since depending on a ref to a field should not
# trigger a re-render which currently requires knowing the value of the field
for listener, scope in self.listeners[scope_id].copy():
listener(new)
for listener2, scope in self.listeners2[scope_id].copy():
Expand Down Expand Up @@ -365,6 +408,9 @@ def peek(self) -> S:
"""Return the value without automatically subscribing to listeners."""
return self._storage.peek()

def subscribe_changed(self, listener: Callable[[], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe_changed(listener, scope=scope)

def subscribe(self, listener: Callable[[S], None], scope: Optional[ContextManager] = None):
return self._storage.subscribe(listener, scope=scope)

Expand Down Expand Up @@ -407,19 +453,28 @@ def __init__(self, f: Callable[[], S], key=None):

self.f = f

def on_change(*ignore):
with self._auto_subscriber.value:
self.set(f())
def on_change():
self._dirty.set(True) # settings state should not be allowed
# listeners are attached to the storage
self._storage._fire_changed()
scope_id = self._storage._get_scope_key()
if self._storage.listeners[scope_id] or self._storage.listeners2[scope_id]:
# DeprecationWarning: Using .subscribe and .subscribe_change on a computed value is not supported
warnings.warn("Using .subscribe and .subscribe_change on a computed value is deprecated, use .subscribe_changed", DeprecationWarning)
self._ensure_computed()

import functools

self._auto_subscriber = Singleton(functools.wraps(AutoSubscribeContextManager)(lambda: AutoSubscribeContextManager(on_change)))
# we should have a KernelVar, similar to ContextVar, or threading.local since we don't need/want reactivity
self._dirty = Reactive(False)

@functools.wraps(f)
def factory():
v = self._auto_subscriber.value
with v:
return f()
_auto_subscriber = self._auto_subscriber.value
with _auto_subscriber:
value = f()
return value

super().__init__(KernelStoreFactory(factory, key=key))

Expand All @@ -432,6 +487,18 @@ def cleanup():

solara.lifecycle.on_kernel_start(reset)

def _ensure_computed(self):
if self._dirty.peek():
self._dirty.set(False)
with self._auto_subscriber.value:
self.set(self.f())

def get(self):
self._ensure_computed()
if thread_local.reactive_used is not None:
thread_local.reactive_used.add(self)
return self._storage.get()

def __repr__(self):
value = super().__repr__()
return "<Computed" + value[len("<Reactive") : -1]
Expand Down Expand Up @@ -524,6 +591,12 @@ def __repr__(self):
def lock(self):
return self._root.lock

def subscribe_changed(self, listener: Callable[[], None], scope: Optional[ContextManager] = None):
def on_changed():
listener()

return self._root.subscribe_changed(on_changed, scope=scope)

def subscribe(self, listener: Callable[[T], None], scope: Optional[ContextManager] = None):
def on_change(new, old):
try:
Expand Down Expand Up @@ -702,7 +775,11 @@ def update_subscribers(self, change_handler, scope=None):

for reactive in added:
if reactive not in self.subscribed:
unsubscribe = reactive.subscribe_change(change_handler, scope=scope)
# if we subscribe to subscribe_changed instead, we get too many false positives
# and we would render too often, the main issue is that a ref to a computed
# will say 'i may have changed' via subscribe_changed, but the subscribe_change
# in the field will do an (eager) comparison, avoiding the false positive
unsubscribe = reactive.subscribe_change(lambda a, b: change_handler(), scope=scope)
self.subscribed[reactive] = unsubscribe
for reactive in removed:
unsubscribe = self.subscribed[reactive]
Expand Down Expand Up @@ -732,7 +809,7 @@ def __init__(self, element: solara.Element):
def __enter__(self):
_, set_counter = solara.use_state(0, key="auto_subscribe_force_update_counter")

def force_update(new_value, old_value):
def force_update():
# can we do just x+1 to collapse multiple updates into one?
set_counter(lambda x: x + 1)

Expand Down
8 changes: 6 additions & 2 deletions tests/unit/toestand_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,8 @@ def Main():
assert rc.find(v.Slider).widget.v_model == 2
y.value = "hello"
assert rc.find(v.Slider).widget.label == "hello"
assert len(x._storage.listeners2) == 1
assert len(x._storage.listeners2[kernel_context.id]) == 1
assert len(x._storage.listeners_changed[kernel_context.id]) == 0
# force an extra listener
x.value = 0
# and remove it
Expand All @@ -837,6 +838,7 @@ def Main():
rc.close()
assert not x._storage.listeners[kernel_context.id]
assert not x._storage.listeners2[kernel_context.id]
assert not x._storage.listeners_changed[kernel_context.id]


def test_reactive_auto_subscribe_sub():
Expand All @@ -861,7 +863,7 @@ def Test():
ref.value += 1
assert rc.find(v.Alert).widget.children[0] == "2 bears around here"
assert reactive_used == {ref}
# now check that we didn't listen to the while object, just count changes
# now check that we didn't listen to the whole object, just count changes
renders_before = renders
Ref(bears.fields.type).value = "pink"
assert renders == renders_before
Expand Down Expand Up @@ -1233,6 +1235,8 @@ class Person(BaseModel):
assert person.get().height == 2.0
assert Ref(person.fields.name).get() == "Maria"
assert Ref(person.fields.height).get() == 2.0


@dataclasses.dataclass(frozen=True)
class Profile:
name: str
Expand Down

0 comments on commit 30f207d

Please sign in to comment.