Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Update trace tool to use DiagnosticToolBaseClass
Browse files Browse the repository at this point in the history
  • Loading branch information
horizon-blue committed Oct 18, 2022
1 parent 1db044a commit 51ccc88
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 85 deletions.
63 changes: 24 additions & 39 deletions src/beanmachine/ppl/diagnostics/tools/trace/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,44 @@

"""Trace diagnostic tool for a Bean Machine model."""

from typing import Optional, TypeVar
from typing import TypeVar

from beanmachine.ppl.diagnostics.tools.trace import utils
from beanmachine.ppl.diagnostics.tools.utils.base import Base
from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import (
DiagnosticToolBaseClass,
)
from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples
from bokeh.embed import file_html
from bokeh.models import Model
from bokeh.models.callbacks import CustomJS
from bokeh.resources import INLINE


T = TypeVar("T", bound="Trace")


class Trace(Base):
class Trace(DiagnosticToolBaseClass):
"""Trace tool.
Parameters
----------
mcs : MonteCarloSamples
Bean Machine model object.
Attributes
----------
data : Dict[str, List[List[float]]]
JSON serialized version of the Bean Machine model.
rv_names : List[str]
The list of random variables string names for the given model.
num_chains : int
The number of chains of the model.
num_draws : int
The number of draws of the model for each chain.
palette : List[str]
A list of color values used for the glyphs in the figures. The colors are
specifically chosen from the Colorblind palette defined in Bokeh.
js : str
The JavaScript callbacks needed to render the Bokeh tool independently from
a Python server.
name : str
The name to use when saving Bokeh JSON to disk.
Args:
mcs (MonteCarloSamples): The return object from running a Bean Machine model.
Attributes:
data (Dict[str, List[List[float]]]): JSON serializable representation of the
given `mcs` object.
rv_names (List[str]): The list of random variables string names for the given
model.
num_chains (int): The number of chains of the model.
num_draws (int): The number of draws of the model for each chain.
palette (List[str]): A list of color values used for the glyphs in the figures.
The colors are specifically chosen from the Colorblind palette defined in
Bokeh.
tool_js (str):The JavaScript callbacks needed to render the Bokeh tool
independently from a Python server.
"""

def __init__(self: T, mcs: MonteCarloSamples) -> None:
super(Trace, self).__init__(mcs)

def create_document(self: T, name: Optional[str] = None) -> str:
if name is not None:
self.name = name

def create_document(self: T) -> Model:
# Initialize widget values using Python.
rv_name = self.rv_names[0]

Expand Down Expand Up @@ -113,7 +103,7 @@ def create_document(self: T, name: Optional[str] = None) -> str:
tooltips,
);
}} catch (error) {{
{self.js}
{self.tool_js}
trace.update(
rvData,
rvName,
Expand Down Expand Up @@ -159,10 +149,5 @@ def create_document(self: T, name: Optional[str] = None) -> str:
widgets["bw_factor_slider"].js_on_change("value", slider_callback)
widgets["hdi_slider"].js_on_change("value", slider_callback)

# Create the view of the tool and serialize it into HTML using static resources
# from Bokeh. Embedding the tool in this manner prevents external CDN calls for
# JavaScript resources, and prevents the user from having to know where the
# Bokeh server is.
tool_view = utils.create_view(figures=figures, widgets=widgets)
output = file_html(tool_view, resources=INLINE)
return output
return tool_view
46 changes: 0 additions & 46 deletions src/beanmachine/ppl/diagnostics/tools/utils/base.py

This file was deleted.

0 comments on commit 51ccc88

Please sign in to comment.