Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
Rebase etc
Browse files Browse the repository at this point in the history
- Rebased against `main`, and fixed merged conflicts.
- Updated copyright requirement
- Updated how models are serialized.
- Update style based on other pushes to the branch
- Added `__future__` annotations
- updated a few docstrings that needed an update
- updated yarn lock based on changes with the marginal1d tool and the
  trace tool.
- removed the 3rd party histogram requirement, since it is easy to
  calculate a histogram.
- other fixes associated with merging etc.
  • Loading branch information
ndmlny-qs committed Oct 19, 2022
1 parent c509ad4 commit 0bc5cc0
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 132 deletions.
2 changes: 1 addition & 1 deletion src/beanmachine/ppl/diagnostics/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
# flake8: noqa

"""Visual diagnostic tools for Bean Machine models."""

import sys
from pathlib import Path

Expand All @@ -16,6 +15,7 @@
# accepted, see https://peps.python.org/pep-0655/. This is to follow the
# interface objects in JavaScript that allow keys to not be required using ?.
from typing import TypedDict

from typing_extensions import NotRequired
else:
from typing_extensions import NotRequired, TypedDict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ export const computeStats = (

// Compute the point statistics for the KDE, and create labels to display them in the
// figures.
const mean = computeMean(marginalX);
const mean = computeMean(rawData);
const hdiBounds = hdiInterval(rawData, hdiProbability);
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound];
const y = interpolatePoints({x: marginalX, y: marginalY, points: x});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
/* import {calculateHistogram} from 'compute-histogram'; */
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import {linearRange, numericalSort, shape} from './array';
import {rankData, scaleToOne} from './dataTransformation';
import {mean as computeMean} from './pointStatistic';
Expand Down
61 changes: 2 additions & 59 deletions src/beanmachine/ppl/diagnostics/tools/js/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
"@jridgewell/sourcemap-codec" "^1.4.10"
"@jridgewell/trace-mapping" "^0.3.9"

"@jridgewell/[email protected]", "@jridgewell/resolve-uri@^3.0.3":
"@jridgewell/[email protected]":
version "3.1.0"
resolved "https://registry.yarnpkg.com/@jridgewell/resolve-uri/-/resolve-uri-3.1.0.tgz#2203b118c157721addfe69d47b70465463066d78"
integrity sha512-F2msla3tad+Mfht5cJq7LSXcdudKTWCVYUgw6pLFOOHSTtZlj6SWNYAp+AhuqLmWdBO2X5hPrLcu8cVP8fy28w==
Expand Down Expand Up @@ -201,7 +201,7 @@
resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee"
integrity sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==

"@types/node@*", "@types/node@^18.0.4":
"@types/node@*":
version "18.8.5"
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.8.5.tgz#6a31f820c1077c3f8ce44f9e203e68a176e8f59e"
integrity sha512-Bq7G3AErwe5A/Zki5fdD3O6+0zDChhg671NfPjtIcbtzDNZTv4NPKMRFr7gtYPG7y+B8uTiNK4Ngd9T0FTar6Q==
Expand All @@ -218,20 +218,6 @@
dependencies:
"@types/jquery" "*"

"@typescript-eslint/eslint-plugin@^5.30.5":
version "5.40.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/eslint-plugin/-/eslint-plugin-5.40.0.tgz#0159bb71410eec563968288a17bd4478cdb685bd"
integrity sha512-FIBZgS3DVJgqPwJzvZTuH4HNsZhHMa9SjxTKAZTlMsPw/UzpEjcf9f4dfgDJEHjK+HboUJo123Eshl6niwEm/Q==
dependencies:
"@typescript-eslint/scope-manager" "5.40.0"
"@typescript-eslint/type-utils" "5.40.0"
"@typescript-eslint/utils" "5.40.0"
debug "^4.3.4"
ignore "^5.2.0"
regexpp "^3.2.0"
semver "^7.3.7"
tsutils "^3.21.0"

"@typescript-eslint/parser@^5.30.5":
version "5.40.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/parser/-/parser-5.40.0.tgz#432bddc1fe9154945660f67c1ba6d44de5014840"
Expand All @@ -250,16 +236,6 @@
"@typescript-eslint/types" "5.40.0"
"@typescript-eslint/visitor-keys" "5.40.0"

"@typescript-eslint/[email protected]":
version "5.40.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/type-utils/-/type-utils-5.40.0.tgz#4964099d0158355e72d67a370249d7fc03331126"
integrity sha512-nfuSdKEZY2TpnPz5covjJqav+g5qeBqwSHKBvz7Vm1SAfy93SwKk/JeSTymruDGItTwNijSsno5LhOHRS1pcfw==
dependencies:
"@typescript-eslint/typescript-estree" "5.40.0"
"@typescript-eslint/utils" "5.40.0"
debug "^4.3.4"
tsutils "^3.21.0"

"@typescript-eslint/[email protected]":
version "5.40.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/types/-/types-5.40.0.tgz#8de07e118a10b8f63c99e174a3860f75608c822e"
Expand All @@ -278,19 +254,6 @@
semver "^7.3.7"
tsutils "^3.21.0"

"@typescript-eslint/[email protected]":
version "5.40.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/utils/-/utils-5.40.0.tgz#647f56a875fd09d33c6abd70913c3dd50759b772"
integrity sha512-MO0y3T5BQ5+tkkuYZJBjePewsY+cQnfkYeRqS6tPh28niiIwPnQ1t59CSRcs1ZwJJNOdWw7rv9pF8aP58IMihA==
dependencies:
"@types/json-schema" "^7.0.9"
"@typescript-eslint/scope-manager" "5.40.0"
"@typescript-eslint/types" "5.40.0"
"@typescript-eslint/typescript-estree" "5.40.0"
eslint-scope "^5.1.1"
eslint-utils "^3.0.0"
semver "^7.3.7"

"@typescript-eslint/[email protected]":
version "5.40.0"
resolved "https://registry.yarnpkg.com/@typescript-eslint/visitor-keys/-/visitor-keys-5.40.0.tgz#dd2d38097f68e0d2e1e06cb9f73c0173aca54b68"
Expand Down Expand Up @@ -684,11 +647,6 @@ core-js-pure@^3.25.1:
resolved "https://registry.yarnpkg.com/core-js-pure/-/core-js-pure-3.25.5.tgz#79716ba54240c6aa9ceba6eee08cf79471ba184d"
integrity sha512-oml3M22pHM+igfWHDfdLVq2ShWmjM2V4L+dQEBs0DWVIqEm9WHCwGAlZ6BmyBQGy5sFrJmcx+856D9lVKyGWYg==

create-require@^1.1.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/create-require/-/create-require-1.1.1.tgz#c1d7e8f1e5f6cfc9ff65f9cd352d37348756c333"
integrity sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==

cross-spawn@^7.0.2, cross-spawn@^7.0.3:
version "7.0.3"
resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.3.tgz#f73a85b9d5d41d045551c177e2882d4ac85728a6"
Expand Down Expand Up @@ -786,11 +744,6 @@ emoji-regex@^9.2.2:
resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-9.2.2.tgz#840c8803b0d8047f4ff0cf963176b32d4ef3ed72"
integrity sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==

emoji-regex@^9.2.2:
version "9.2.2"
resolved "https://registry.yarnpkg.com/emoji-regex/-/emoji-regex-9.2.2.tgz#840c8803b0d8047f4ff0cf963176b32d4ef3ed72"
integrity sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==

enhanced-resolve@^5.0.0, enhanced-resolve@^5.10.0:
version "5.10.0"
resolved "https://registry.yarnpkg.com/enhanced-resolve/-/enhanced-resolve-5.10.0.tgz#0dc579c3bb2a1032e357ac45b8f3a6f3ad4fb1e6"
Expand Down Expand Up @@ -1630,11 +1583,6 @@ js-sdsl@^4.1.4:
resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"
integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==

"js-tokens@^3.0.0 || ^4.0.0":
version "4.0.0"
resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499"
integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==

js-yaml@^4.1.0:
version "4.1.0"
resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-4.1.0.tgz#c1fb65f8f5017901cdd2c951864ba18458a10602"
Expand Down Expand Up @@ -2008,11 +1956,6 @@ prelude-ls@^1.2.1:
resolved "https://registry.yarnpkg.com/prelude-ls/-/prelude-ls-1.2.1.tgz#debc6489d7a6e6b0e7611888cec880337d316396"
integrity sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==

prettier@^2.7.1:
version "2.7.1"
resolved "https://registry.yarnpkg.com/prettier/-/prettier-2.7.1.tgz#e235806850d057f97bb08368a4f7d899f7760c64"
integrity sha512-ujppO+MkdPqoVINuDFDRLClm7D78qbDt0/NR+wp5FqEZOoTNAjPHWj17QRhu7geIHJfcNhRk1XVQmF8Bp3ye+g==

proj4@^2.7.5:
version "2.8.0"
resolved "https://registry.yarnpkg.com/proj4/-/proj4-2.8.0.tgz#b2cb8f3ccd56d4dcc7c3e46155cd02caa804b170"
Expand Down
10 changes: 3 additions & 7 deletions src/beanmachine/ppl/diagnostics/tools/marginal1d/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# LICENSE file in the root directory of this source tree.

"""Marginal 1D diagnostic tool for a Bean Machine model."""

from typing import TypeVar
from __future__ import annotations

from beanmachine.ppl.diagnostics.tools.marginal1d import utils
from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import (
Expand All @@ -16,9 +15,6 @@
from bokeh.models.callbacks import CustomJS


T = TypeVar("T", bound="Marginal1d")


class Marginal1d(DiagnosticToolBaseClass):
"""
Marginal 1D diagnostic tool.
Expand All @@ -40,10 +36,10 @@ class Marginal1d(DiagnosticToolBaseClass):
independently from a Python server.
"""

def __init__(self: T, mcs: MonteCarloSamples) -> None:
def __init__(self: Marginal1d, mcs: MonteCarloSamples) -> None:
super(Marginal1d, self).__init__(mcs)

def create_document(self: T) -> Model:
def create_document(self: Marginal1d) -> Model:
# Initialize widget values using Python.
rv_name = self.rv_names[0]
bw_factor = 1.0
Expand Down
1 change: 0 additions & 1 deletion src/beanmachine/ppl/diagnostics/tools/marginal1d/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

"""Marginal 1D diagnostic tool types for a Bean Machine model."""

from typing import Any, Dict, List, Union

from beanmachine.ppl.diagnostics.tools import TypedDict
Expand Down
1 change: 0 additions & 1 deletion src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

"""Methods used to generate the diagnostic tool."""

from typing import List

import numpy as np
Expand Down
24 changes: 16 additions & 8 deletions src/beanmachine/ppl/diagnostics/tools/trace/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
# LICENSE file in the root directory of this source tree.

"""Trace diagnostic tool for a Bean Machine model."""

from typing import TypeVar
from __future__ import annotations

from beanmachine.ppl.diagnostics.tools.trace import utils
from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import (
Expand All @@ -16,9 +15,6 @@
from bokeh.models.callbacks import CustomJS


T = TypeVar("T", bound="Trace")


class Trace(DiagnosticToolBaseClass):
"""Trace tool.
Expand All @@ -39,10 +35,10 @@ class Trace(DiagnosticToolBaseClass):
independently from a Python server.
"""

def __init__(self: T, mcs: MonteCarloSamples) -> None:
def __init__(self: Trace, mcs: MonteCarloSamples) -> None:
super(Trace, self).__init__(mcs)

def create_document(self: T) -> Model:
def create_document(self: Trace) -> Model:
# Initialize widget values using Python.
rv_name = self.rv_names[0]

Expand Down Expand Up @@ -88,10 +84,22 @@ def create_document(self: T) -> Model:
# Create the widgets for the tool using Python.
widgets = utils.create_widgets(rv_names=self.rv_names, rv_name=rv_name)

# Create the view of the tool and serialize it into HTML using static resources
# from Bokeh. Embedding the tool in this manner prevents external CDN calls for
# JavaScript resources, and prevents the user from having to know where the
# Bokeh server is.
tool_view = utils.create_view(figures=figures, widgets=widgets)

# Create callbacks for the tool using JavaScript.
callback_js = f"""
const rvName = widgets.rv_select.value;
const rvData = data[rvName];
let bw = 0.0;
// Remove the CSS classes that dim the tool output on initial load.
const toolTab = toolView.tabs[0];
const toolChildren = toolTab.child.children;
const dimmedComponent = toolChildren[1];
dimmedComponent.css_classes = [];
try {{
trace.update(
rvData,
Expand Down Expand Up @@ -125,6 +133,7 @@ def create_document(self: T) -> Model:
"sources": sources,
"figures": figures,
"tooltips": tooltips,
"toolView": tool_view,
}

# Each widget requires slightly different JS.
Expand All @@ -149,5 +158,4 @@ def create_document(self: T) -> Model:
widgets["bw_factor_slider"].js_on_change("value", slider_callback)
widgets["hdi_slider"].js_on_change("value", slider_callback)

tool_view = utils.create_view(figures=figures, widgets=widgets)
return tool_view
1 change: 0 additions & 1 deletion src/beanmachine/ppl/diagnostics/tools/trace/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

"""Trace diagnostic tool types for a Bean Machine model."""

from typing import Any, Dict, List, Union

from beanmachine.ppl.diagnostics.tools import NotRequired, TypedDict
Expand Down
24 changes: 5 additions & 19 deletions src/beanmachine/ppl/diagnostics/tools/trace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.

"""Methods used to generate the diagnostic tool."""

from typing import List

from beanmachine.ppl.diagnostics.tools.trace import typing
Expand Down Expand Up @@ -201,21 +200,18 @@ def create_figures(rv_name: str, num_chains: int) -> typing.Figures:
if figure_name == "marginals":
fig.title = "Marginal"
fig.xaxis.axis_label = rv_name
# fig.x_range = Range1d()
fig.yaxis.visible = False
elif figure_name == "forests":
fig.title = "Forest"
fig.xaxis.axis_label = rv_name
fig.yaxis.axis_label = "Chain"
fig.yaxis.minor_tick_line_color = None
fig.yaxis.ticker.desired_num_ticks = num_chains
# fig.x_range = Range1d()
elif figure_name == "traces":
fig.title = "Trace"
fig.xaxis.axis_label = "Draw from single chain"
fig.yaxis.axis_label = rv_name
fig.width = TRACE_PLOT_WIDTH
# fig.x_range = Range1d()
elif figure_name == "ranks":
fig.title = "Rank"
fig.xaxis.axis_label = "Rank from all chains"
Expand Down Expand Up @@ -399,8 +395,6 @@ def add_glyphs(
None
Adds data bound glyphs to the given figures directly.
"""
# range_min = []
# range_max = []
for figure_name, figure_sources in sources.items():
fig = figures[figure_name]
for chain_name, source in figure_sources.items():
Expand All @@ -417,11 +411,6 @@ def add_glyphs(
# its range stable are linked to the marginal figure's range below.
if figure_name == "marginals":
pass
# data = source["line"].data["x"]
# minimum = min(data) if len(data) != 0 else 0
# maximum = max(data) if len(data) != 0 else 1
# range_min.append(minimum)
# range_max.append(maximum)
elif figure_name == "forests":
fig.add_glyph(
source_or_glyph=source["circle"],
Expand All @@ -437,12 +426,7 @@ def add_glyphs(
name=chain_glyphs["quad"]["glyph"].name,
)
# Link figure ranges together.
# figures["marginals"].x_range = Range1d(
# start=min(range_min) if len(range_min) != 0 else 0,
# end=max(range_max) if len(range_max) != 0 else 1,
# )
figures["forests"].x_range = figures["marginals"].x_range
# figures["traces"].y_range = figures["marginals"].x_range


def create_annotations(figures: typing.Figures, num_chains: int) -> typing.Annotations:
Expand Down Expand Up @@ -573,21 +557,23 @@ def create_tooltips(
{
"line": HoverTool(
renderers=plotting_utils.filter_renderers(
fig, f"{figure_name}{chain_name.title()}LineGlyph"
fig,
f"{figure_name}{chain_name.title()}LineGlyph",
),
tooltips=[("Chain", "@chain"), ("Rank mean", "@rankMean")],
),
"quad": HoverTool(
renderers=plotting_utils.filter_renderers(
fig, f"{figure_name}{chain_name.title()}QuadGlyph"
fig,
f"{figure_name}{chain_name.title()}QuadGlyph",
),
tooltips=[
("Chain", "@chain"),
("Draws", "@draws"),
("Rank", "@rank"),
],
),
}
},
)
return output

Expand Down
Loading

0 comments on commit 0bc5cc0

Please sign in to comment.