Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Marginal diagnostic tool #1691

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,11 @@ botorch/qmc/sobol.c*
# Sphinx documentation
sphinx/build/

# Docusaurus
# Docusaurus and diagnostic tools
website/build/
website/i18n/
website/node_modules/
node_modules

# Tutorials
docs/overview/tutorials/*/*.mdx
Expand Down
2 changes: 2 additions & 0 deletions src/beanmachine/ppl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from . import experimental
from .diagnostics import Diagnostics
from .diagnostics.common_statistics import effective_sample_size, r_hat, split_r_hat
from .diagnostics.tools import viz
from .inference import (
CompositionalInference,
empirical,
Expand Down Expand Up @@ -60,4 +61,5 @@
"random_variable",
"simulate",
"split_r_hat",
"viz",
]
22 changes: 22 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa

"""Visual diagnostic tools for Bean Machine models."""

import sys
from pathlib import Path


if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict


TOOLS_DIR = Path(__file__).parent.resolve()
JS_DIR = TOOLS_DIR.joinpath("js")
JS_DIST_DIR = JS_DIR.joinpath("dist")
75 changes: 75 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/.eslintrc.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

const OFF = 0;
const WARNING = 1;
const ERROR = 2;

module.exports = {
root: true,
env: {
browser: true,
commonjs: true,
jest: true,
node: true,
},
parser: '@typescript-eslint/parser',
parserOptions: {
allowImportExportEverywhere: true,
},
extends: ['airbnb', 'prettier', 'plugin:import/typescript'],
plugins: ['prefer-arrow'],
rules: {
// Allow more than 1 class per file.
'max-classes-per-file': ['error', {ignoreExpressions: true, max: 2}],
// Allow snake_case.
camelcase: [
OFF,
{
properties: 'never',
ignoreDestructuring: true,
ignoreImports: true,
ignoreGlobals: true,
},
],
'no-underscore-dangle': OFF,
// Arrow function rules.
'prefer-arrow/prefer-arrow-functions': [
ERROR,
{
disallowPrototype: true,
singleReturnOnly: false,
classPropertiesAllowed: false,
},
],
'prefer-arrow-callback': [ERROR, {allowNamedFunctions: true}],
'arrow-parens': [ERROR, 'always'],
'arrow-body-style': [ERROR, 'always'],
'func-style': [ERROR, 'declaration', {allowArrowFunctions: true}],
'react/function-component-definition': [
ERROR,
{
namedComponents: 'arrow-function',
unnamedComponents: 'arrow-function',
},
],
// Ignore the global require, since some required packages are BrowserOnly.
'global-require': 0,
// We reassign several parameter objects since Bokeh is just updating values in the
// them.
'no-param-reassign': 0,
// Ignore certain webpack alias because it can't be resolved
'import/no-unresolved': [
ERROR,
{ignore: ['^@theme', '^@docusaurus', '^@generated', '^@bokeh']},
],
'import/extensions': OFF,
'object-shorthand': [ERROR, 'never'],
'prefer-destructuring': [WARNING, {object: true, array: true}],
'no-nested-ternary': 0,
},
};
54 changes: 54 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"name": "visual-diagnostic-tools",
"version": "0.1.0",
"description": "",
"license": "MIT",
"keywords": [],
"repository": {},
"scripts": {
"build": "webpack"
},
"dependencies": {
"@bokeh/bokehjs": "^2.4.3",
"fast-kde": "^0.2.1"
},
"devDependencies": {
"@types/node": "^18.0.4",
"@typescript-eslint/eslint-plugin": "^5.30.5",
"@typescript-eslint/parser": "^5.30.5",
"eslint": "^8.19.0",
"eslint-config-airbnb": "^19.0.4",
"eslint-config-prettier": "^8.5.0",
"eslint-plugin-import": "^2.26.0",
"eslint-plugin-jsx-a11y": "^6.5.1",
"eslint-plugin-prefer-arrow": "^1.2.3",
"eslint-plugin-react": "^7.28.0",
"eslint-plugin-react-hooks": "^4.3.0",
"ts-loader": "^9.3.1",
"ts-node": "^10.9.1",
"typescript": "^4.7.4",
"webpack": "^5.74.0",
"webpack-cli": "^4.10.0"
},
"overrides": {
"cwise": "$cwise",
"minimist": "$minimist",
"quote-stream": "$quote-stream",
"static-eval": "$static-eval",
"static-module": "$static-module",
"typedarray-pool": "$typedarray-pool"
},
"peerDependencies": {
"@types/cwise": "^1.0.4",
"@types/minimist": "^1.2.2",
"@types/static-eval": "^0.2.31",
"@types/typedarray-pool": "^1.1.2",
"buffer": "^6.0.3",
"cwise": "^1.0.10",
"minimist": "^1.2.6",
"quote-stream": "^1.0.2",
"static-eval": "2.1.0",
"static-module": "^3.0.4",
"typedarray-pool": "^1.2.0"
}
}
190 changes: 190 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import {Axis} from '@bokehjs/models/axes/axis';
import {cumulativeSum} from '../stats/array';
import {scaleToOne} from '../stats/dataTransformation';
import {
interval as hdiInterval,
data as hdiData,
} from '../stats/highestDensityInterval';
import {oneD} from '../stats/marginal';
import {mean as computeMean} from '../stats/pointStatistic';
import {interpolatePoints} from '../stats/utils';
import * as interfaces from './interfaces';

// Define the names of the figures used for this Bokeh application.
const figureNames = ['marginal', 'cumulative'];

/**
* Update the given Bokeh Axis object with the new label string. You must use this
* method to update axis strings using TypeScript, otherwise the ts compiler will throw
* a type check error.
*
* @param {Axis} axis - The Bokeh Axis object needing a new label.
* @param {string | null} label - The new label for the Bokeh Axis object.
*/
export const updateAxisLabel = (axis: Axis, label: string | null): void => {
// Type check requirement.
if ('axis_label' in axis) {
axis.axis_label = label;
}
};

/**
* Compute the following statistics for the given random variable data
*
* - lower bound for the highest density interval calculated from the marginalX;
* - mean of the rawData;
* - upper bound for the highest density interval calculated from the marginalY.
*
* @param {number[]} rawData - Raw random variable data from the model.
* @param {number[]} marginalX - The support of the Kernel Density Estimate of the
* random variable.
* @param {number[]} marginalY - The Kernel Density Estimate of the random variable.
* @param {number | null} [hdiProb=null] - The highest density interval probability
* value. If the default value is not overwritten, then the default HDI probability
* is 0.89. See Statistical Rethinking by McElreath for a description as to why this
* value is the default.
* @param {string[]} [text_align=['right', 'center', 'left']] - How to align the text
* shown in the figure for the point statistics.
* @param {number[]} [x_offset=[-5, 0, 5]] - Offset values for the text along the
* x-axis.
* @param {number[]} [y_offset=[0, 10, 0]] - Offset values for the text along the
* y-axis
* @returns {interfaces.LabelsData} Object containing the computed stats.
*/
export const computeStats = (
rawData: number[],
marginalX: number[],
marginalY: number[],
hdiProb: number | null = null,
text_align: string[] = ['right', 'center', 'left'],
x_offset: number[] = [-5, 0, 5],
y_offset: number[] = [0, 10, 0],
): interfaces.LabelsData => {
// Set the default value to 0.89 if no default value has been given.
const hdiProbability = hdiProb ?? 0.89;

// Compute the point statistics for the KDE, and create labels to display them in the
// figures.
const mean = computeMean(marginalX);
const hdiBounds = hdiInterval(rawData, hdiProbability);
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound];
const y = interpolatePoints({x: marginalX, y: marginalY, points: x});
const text = [
`Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`,
`Mean: ${mean.toFixed(3)}`,
`Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`,
];

return {
x: x,
y: y,
text: text,
text_align: text_align,
x_offset: x_offset,
y_offset: y_offset,
};
};

/**
* Compute data for the one-dimensional marginal diagnostic tool.
*
* @param {number[]} data - Raw random variable data from the model.
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when
* calculating the Kernel Density Estimate (KDE).
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @returns {interfaces.Data} The marginal distribution and cumulative
* distribution calculated from the given random variable data. Point statistics are
* also calculated.
*/
export const computeData = (
data: number[],
bwFactor: number,
hdiProbability: number,
): interfaces.Data => {
const output = {} as interfaces.Data;
for (let i = 0; i < figureNames.length; i += 1) {
const figureName = figureNames[i];
output[figureName] = {} as interfaces.GlyphData;

// Compute the one-dimensional KDE and its cumulative distribution.
const distribution = oneD(data, bwFactor);
switch (figureName) {
case 'cumulative':
distribution.y = scaleToOne(cumulativeSum(distribution.y));
break;
default:
break;
}

// Compute the point statistics for the given data.
const stats = computeStats(data, distribution.x, distribution.y, hdiProbability);

output[figureName] = {
distribution: distribution,
hdi: hdiData(data, distribution.x, distribution.y, hdiProbability),
stats: {x: stats.x, y: stats.y, text: stats.text},
labels: stats,
};
}
return output;
};

/**
* Callback used to update the Bokeh application in the notebook.
*
* @param {number[]} data - Raw random variable data from the model.
* @param {string} rvName - The name of the random variable from the model.
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when
* calculating the kernel density estimate.
* @param {number} hdiProbability - The highest density interval probability to use when
* calculating the HDI.
* @param {interfaces.Sources} sources - Bokeh sources used to render glyphs in the
* application.
* @param {interfaces.Figures} figures - Bokeh figures shown in the application.
* @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs.
* @returns {number} We display the value of the bandwidth used for computing the Kernel
* Density Estimate in a div, and must return that value here in order to update the
* value displayed to the user.
*/
export const update = (
data: number[],
rvName: string,
bwFactor: number,
hdiProbability: number,
sources: interfaces.Sources,
figures: interfaces.Figures,
tooltips: interfaces.Tooltips,
): number => {
const computedData = computeData(data, bwFactor, hdiProbability);
for (let i = 0; i < figureNames.length; i += 1) {
// Update all sources with new data calculated above.
const figureName = figureNames[i];
sources[figureName].distribution.data = {
x: computedData[figureName].distribution.x,
y: computedData[figureName].distribution.y,
};
sources[figureName].hdi.data = {
base: computedData[figureName].hdi.base,
lower: computedData[figureName].hdi.lower,
upper: computedData[figureName].hdi.upper,
};
sources[figureName].stats.data = computedData[figureName].stats;
sources[figureName].labels.data = computedData[figureName].labels;

// Update the axes labels.
updateAxisLabel(figures[figureName].below[0], rvName);

// Update the tooltips.
tooltips[figureName].stats.tooltips = [['', '@text']];
tooltips[figureName].distribution.tooltips = [[rvName, '@x']];
}
return computedData.marginal.distribution.bandwidth;
};
12 changes: 12 additions & 0 deletions src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import * as marginal1d from './callbacks';

// The CustomJS methods used by Bokeh require us to make the JavaScript available in the
// browser, which is done by defining it below.
(window as any).marginal1d = marginal1d;
Loading