diff --git a/src/beanmachine/ppl/diagnostics/tools/trace/tool.py b/src/beanmachine/ppl/diagnostics/tools/trace/tool.py index 3460c92bd8..7c9555e0b7 100644 --- a/src/beanmachine/ppl/diagnostics/tools/trace/tool.py +++ b/src/beanmachine/ppl/diagnostics/tools/trace/tool.py @@ -5,54 +5,44 @@ """Trace diagnostic tool for a Bean Machine model.""" -from typing import Optional, TypeVar +from typing import TypeVar from beanmachine.ppl.diagnostics.tools.trace import utils -from beanmachine.ppl.diagnostics.tools.utils.base import Base +from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import ( + DiagnosticToolBaseClass, +) from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples -from bokeh.embed import file_html +from bokeh.models import Model from bokeh.models.callbacks import CustomJS -from bokeh.resources import INLINE T = TypeVar("T", bound="Trace") -class Trace(Base): +class Trace(DiagnosticToolBaseClass): """Trace tool. - Parameters - ---------- - mcs : MonteCarloSamples - Bean Machine model object. - - Attributes - ---------- - data : Dict[str, List[List[float]]] - JSON serialized version of the Bean Machine model. - rv_names : List[str] - The list of random variables string names for the given model. - num_chains : int - The number of chains of the model. - num_draws : int - The number of draws of the model for each chain. - palette : List[str] - A list of color values used for the glyphs in the figures. The colors are - specifically chosen from the Colorblind palette defined in Bokeh. - js : str - The JavaScript callbacks needed to render the Bokeh tool independently from - a Python server. - name : str - The name to use when saving Bokeh JSON to disk. + Args: + mcs (MonteCarloSamples): The return object from running a Bean Machine model. + + Attributes: + data (Dict[str, List[List[float]]]): JSON serializable representation of the + given `mcs` object. + rv_names (List[str]): The list of random variables string names for the given + model. + num_chains (int): The number of chains of the model. + num_draws (int): The number of draws of the model for each chain. + palette (List[str]): A list of color values used for the glyphs in the figures. + The colors are specifically chosen from the Colorblind palette defined in + Bokeh. + tool_js (str):The JavaScript callbacks needed to render the Bokeh tool + independently from a Python server. """ def __init__(self: T, mcs: MonteCarloSamples) -> None: super(Trace, self).__init__(mcs) - def create_document(self: T, name: Optional[str] = None) -> str: - if name is not None: - self.name = name - + def create_document(self: T) -> Model: # Initialize widget values using Python. rv_name = self.rv_names[0] @@ -113,7 +103,7 @@ def create_document(self: T, name: Optional[str] = None) -> str: tooltips, ); }} catch (error) {{ - {self.js} + {self.tool_js} trace.update( rvData, rvName, @@ -159,10 +149,5 @@ def create_document(self: T, name: Optional[str] = None) -> str: widgets["bw_factor_slider"].js_on_change("value", slider_callback) widgets["hdi_slider"].js_on_change("value", slider_callback) - # Create the view of the tool and serialize it into HTML using static resources - # from Bokeh. Embedding the tool in this manner prevents external CDN calls for - # JavaScript resources, and prevents the user from having to know where the - # Bokeh server is. tool_view = utils.create_view(figures=figures, widgets=widgets) - output = file_html(tool_view, resources=INLINE) - return output + return tool_view diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/base.py b/src/beanmachine/ppl/diagnostics/tools/utils/base.py deleted file mode 100644 index 6a4e605d8a..0000000000 --- a/src/beanmachine/ppl/diagnostics/tools/utils/base.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -"""Base class for diagnostic tools of a Bean Machine model.""" - -import re -from abc import ABC, abstractmethod -from typing import Optional, TypeVar - -from beanmachine.ppl.diagnostics.tools import JS_DIST_DIR -from beanmachine.ppl.diagnostics.tools.utils import plotting_utils -from beanmachine.ppl.diagnostics.tools.utils.model_serializers import serialize_bm -from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples -from IPython.display import display, HTML - - -T = TypeVar("T", bound="Base") - - -class Base(ABC): - @abstractmethod - def __init__(self: T, mcs: MonteCarloSamples) -> None: - self.data = serialize_bm(mcs) - self.rv_names = ["Select a random variable..."] + list(self.data.keys()) - self.num_chains = mcs.num_chains - self.num_draws = mcs.get_num_samples() - self.palette = plotting_utils.choose_palette(self.num_chains) - self.js = self.load_js() - - def load_js(self: T) -> str: - name = self.__class__.__name__ - name_tokens = re.findall(r"[A-Z][^A-Z]*", name) - name = "_".join(name_tokens) - path = JS_DIST_DIR.joinpath(f"{name.lower()}.js") - with path.open() as f: - js = f.read() - return js - - def show(self: T, name: Optional[str] = None): - display(HTML(self.create_document(name=name))) - - @abstractmethod - def create_document(self: T) -> str: - ...