Skip to content

Commit

Permalink
perf: ipyvue widget can use a faster less generate state_get
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Mar 10, 2024
1 parent 6e376a3 commit 01dd332
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,57 @@ def patch_ipyreact():
# make this a no-op, we'll create the widget when needed
ipyreact.importmap._update_import_map = lambda: None

def patch_ipyvue_performance():
import functools
from collections.abc import Iterable

@functools.lru_cache(None) # type: ignore
def class_traits(cls, **metadata):
# we cache it for performance reasons
return cls.class_traits(**metadata)

@functools.lru_cache(None) # type: ignore
def to_jsons_meta(cls):
traits = class_traits(cls)
callables = {}
for name, trait in traits.items():
to_json = trait.metadata.get("to_json")
if to_json:
callables[name] = to_json
return callables

def get_state_fast(self, key=None, drop_defaults=False):
cls = type(self)
traits = class_traits(cls, sync=True) # type: ignore
if key is None:
keys = list(traits)
elif isinstance(key, str):
keys = [key]
elif isinstance(key, Iterable):
keys = list(key)
else:
raise ValueError("key must be a string, an iterable of keys, or None")
state = {}
to_jsons = to_jsons_meta(cls) # type: ignore
assert drop_defaults is False
trait_values = self._trait_values
for k in keys:
if k not in trait_values:
value = getattr(self, k)
else:
value = trait_values[k]
if k in to_jsons:
wire_value = to_jsons[k](value, self)
else:
# should we call _trait_to_json?
wire_value = value
state[k] = wire_value
return state

import ipyvue

ipyvue.VueWidget.get_state = get_state_fast


def patch():
global _patched
Expand All @@ -331,6 +382,12 @@ def patch():
else:
patch_ipyreact()

if settings.main.experimental_performance:
# this might be a bit too much
# import traitlets
# traitlets.TraitType._validate = lambda self, trait, value: value

patch_ipyvue_performance()
# the ipyvue.Template module cannot be accessed like ipyvue.Template
# because the import in ipvue overrides it
template_mod = sys.modules["ipyvue.Template"]
Expand Down

0 comments on commit 01dd332

Please sign in to comment.