This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: This commit includes the marginal 1D diagnostic tool with JavaScript callbacks.

### Motivation

This PR completes one tool that uses Bokeh and JavaScript callbacks to create an interactive tool that can be viewed in Jupyter. It heavily refactors the code in PR #1631, since pure Python callbacks were found not to function properly with internal tools.

### Changes proposed

A new `tools` folder in the `diagnostics` folder contains the proposed changes. In this folder there is a `js` folder that contains all the JavaScript callbacks needed for the Bokeh tool. The tool creates plots of marginal distributions for each random variable of the model. The output is a self-contained HTML object that can be rendered in Jupyter without any external CDN calls for JS resources.

Pull Request resolved: #1691

Test Plan: Unit tests for the Python and JavaScript code will be added in a later commit. For now, testing consisted of running the tool in the Coin Flipping tutorial and inspecting the output to ensure only static resources were used.

### Types of changes

- [ ] Docs change / refactoring / dependency upgrade
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

### Checklist

- [x] My code follows the code style of this project.
- [ ] My change requires a change to the documentation.
- [ ] I have updated the documentation accordingly.
- [x] I have read the **[CONTRIBUTING](https://github.com/facebookresearch/beanmachine/blob/main/CONTRIBUTING.md)** document.
- [ ] I have added tests to cover my changes.
- [ ] All new and existing tests passed.
- [x] The title of my pull request is a short description of the requested changes.

### TODO

- [ ] Python unit tests
- [ ] JavaScript unit tests
- [ ] Figure out if the build should run `npm run build` for the tools, or if we should just have the minified code for the JS callbacks in the code base.

Reviewed By: feynmanliang

Differential Revision: D39714194

Pulled By: horizon-blue

fbshipit-source-id: 4d87a9fb4108c093f327a94ea33d604dcda68dc8
1 parent edf7424, commit 91bce6a

Showing 29 changed files with 4,893 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# flake8: noqa | ||
|
||
"""Visual diagnostic tools for Bean Machine models.""" | ||
|
||
import sys | ||
from pathlib import Path | ||
|
||
|
||
if sys.version_info >= (3, 8): | ||
    from typing import TypedDict | ||
else: | ||
    from typing_extensions import TypedDict | ||
|
||
|
||
TOOLS_DIR = Path(__file__).parent.resolve() | ||
JS_DIR = TOOLS_DIR.joinpath("js") | ||
JS_DIST_DIR = JS_DIR.joinpath("dist") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/** | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
const OFF = 0; | ||
const WARNING = 1; | ||
const ERROR = 2; | ||
|
||
module.exports = { | ||
root: true, | ||
env: { | ||
browser: true, | ||
commonjs: true, | ||
jest: true, | ||
node: true, | ||
}, | ||
parser: '@typescript-eslint/parser', | ||
parserOptions: { | ||
allowImportExportEverywhere: true, | ||
}, | ||
extends: ['airbnb', 'prettier', 'plugin:import/typescript'], | ||
plugins: ['prefer-arrow'], | ||
rules: { | ||
// Allow more than 1 class per file. | ||
'max-classes-per-file': ['error', {ignoreExpressions: true, max: 2}], | ||
// Allow snake_case. | ||
camelcase: [ | ||
OFF, | ||
{ | ||
properties: 'never', | ||
ignoreDestructuring: true, | ||
ignoreImports: true, | ||
ignoreGlobals: true, | ||
}, | ||
], | ||
'no-underscore-dangle': OFF, | ||
// Arrow function rules. | ||
'prefer-arrow/prefer-arrow-functions': [ | ||
ERROR, | ||
{ | ||
disallowPrototype: true, | ||
singleReturnOnly: false, | ||
classPropertiesAllowed: false, | ||
}, | ||
], | ||
'prefer-arrow-callback': [ERROR, {allowNamedFunctions: true}], | ||
'arrow-parens': [ERROR, 'always'], | ||
'arrow-body-style': [ERROR, 'always'], | ||
'func-style': [ERROR, 'declaration', {allowArrowFunctions: true}], | ||
'react/function-component-definition': [ | ||
ERROR, | ||
{ | ||
namedComponents: 'arrow-function', | ||
unnamedComponents: 'arrow-function', | ||
}, | ||
], | ||
// Ignore the global require, since some required packages are BrowserOnly. | ||
'global-require': 0, | ||
// We reassign several parameter objects since Bokeh is just updating values in | ||
// them. | ||
'no-param-reassign': 0, | ||
// Ignore certain webpack aliases because they can't be resolved. | ||
'import/no-unresolved': [ | ||
ERROR, | ||
{ignore: ['^@theme', '^@docusaurus', '^@generated', '^@bokeh']}, | ||
], | ||
'import/extensions': OFF, | ||
'object-shorthand': [ERROR, 'never'], | ||
'prefer-destructuring': [WARNING, {object: true, array: true}], | ||
'no-nested-ternary': 0, | ||
}, | ||
}; |
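
To make the arrow-function and object rules in the configuration above more concrete, here is a small sketch of TypeScript written to satisfy them. The `summarize` helper and its snake_case field names are hypothetical and not part of this commit; the comments describe how the rules appear to apply, assuming the configuration is used as shown.

```typescript
// Hypothetical helper, not part of the commit; it only illustrates the style
// that the lint rules above appear to enforce.
interface Summary {
  sample_count: number; // snake_case is accepted because `camelcase` is OFF.
  total: number;
}

// 'prefer-arrow/prefer-arrow-functions': use an arrow function, not `function`.
// 'arrow-parens': always wrap parameters in parentheses, even a single one.
// 'arrow-body-style': always use a braced body with an explicit `return`.
const summarize = (values: number[]): Summary => {
  const total = values.reduce((acc, value) => {
    return acc + value;
  }, 0);
  const sample_count = values.length;
  // 'object-shorthand': 'never' requires each `key: value` pair to be spelled out.
  return {sample_count: sample_count, total: total};
};
```

The explicit `{x: x, y: y}` objects returned in `callbacks.ts` follow the same `object-shorthand` convention.
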
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"arrowParens": "always", | ||
"bracketSpacing": false, | ||
"printWidth": 88, | ||
"proseWrap": "never", | ||
"singleQuote": true, | ||
"trailingComma": "all" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
{ | ||
"name": "visual-diagnostic-tools", | ||
"version": "0.1.0", | ||
"description": "", | ||
"license": "MIT", | ||
"keywords": [], | ||
"repository": {}, | ||
"scripts": { | ||
"build": "webpack" | ||
}, | ||
"dependencies": { | ||
"@bokeh/bokehjs": "^2.4.3", | ||
"fast-kde": "^0.2.1" | ||
}, | ||
"devDependencies": { | ||
"@types/node": "^18.0.4", | ||
"@typescript-eslint/eslint-plugin": "^5.30.5", | ||
"@typescript-eslint/parser": "^5.30.5", | ||
"eslint": "^8.19.0", | ||
"eslint-config-airbnb": "^19.0.4", | ||
"eslint-config-prettier": "^8.5.0", | ||
"eslint-plugin-import": "^2.26.0", | ||
"eslint-plugin-jsx-a11y": "^6.5.1", | ||
"eslint-plugin-prefer-arrow": "^1.2.3", | ||
"eslint-plugin-react": "^7.28.0", | ||
"eslint-plugin-react-hooks": "^4.3.0", | ||
"prettier": "^2.7.1", | ||
"ts-loader": "^9.3.1", | ||
"ts-node": "^10.9.1", | ||
"typescript": "^4.7.4", | ||
"webpack": "^5.74.0", | ||
"webpack-cli": "^4.10.0" | ||
}, | ||
"overrides": { | ||
"cwise": "$cwise", | ||
"minimist": "$minimist", | ||
"quote-stream": "$quote-stream", | ||
"static-eval": "$static-eval", | ||
"static-module": "$static-module", | ||
"typedarray-pool": "$typedarray-pool" | ||
}, | ||
"peerDependencies": { | ||
"@types/cwise": "^1.0.4", | ||
"@types/minimist": "^1.2.2", | ||
"@types/static-eval": "^0.2.31", | ||
"@types/typedarray-pool": "^1.1.2", | ||
"buffer": "^6.0.3", | ||
"cwise": "^1.0.10", | ||
"minimist": "^1.2.6", | ||
"quote-stream": "^1.0.2", | ||
"static-eval": "2.1.0", | ||
"static-module": "^3.0.4", | ||
"typedarray-pool": "^1.2.0" | ||
} | ||
} |
190 changes: 190 additions & 0 deletions
src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
/** | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
import {Axis} from '@bokehjs/models/axes/axis'; | ||
import {cumulativeSum} from '../stats/array'; | ||
import {scaleToOne} from '../stats/dataTransformation'; | ||
import { | ||
interval as hdiInterval, | ||
data as hdiData, | ||
} from '../stats/highestDensityInterval'; | ||
import {oneD} from '../stats/marginal'; | ||
import {mean as computeMean} from '../stats/pointStatistic'; | ||
import {interpolatePoints} from '../stats/utils'; | ||
import * as interfaces from './interfaces'; | ||
|
||
// Define the names of the figures used for this Bokeh application. | ||
const figureNames = ['marginal', 'cumulative']; | ||
|
||
/** | ||
* Update the given Bokeh Axis object with the new label string. You must use this | ||
* method to update axis strings using TypeScript, otherwise the ts compiler will throw | ||
* a type check error. | ||
* | ||
* @param {Axis} axis - The Bokeh Axis object needing a new label. | ||
* @param {string | null} label - The new label for the Bokeh Axis object. | ||
*/ | ||
export const updateAxisLabel = (axis: Axis, label: string | null): void => { | ||
// Type check requirement. | ||
if ('axis_label' in axis) { | ||
axis.axis_label = label; | ||
} | ||
}; | ||
|
||
/** | ||
* Compute the following statistics for the given random variable data: | ||
* | ||
* - lower bound for the highest density interval calculated from the rawData; | ||
* - mean of the rawData; | ||
* - upper bound for the highest density interval calculated from the rawData. | ||
* | ||
* @param {number[]} rawData - Raw random variable data from the model. | ||
* @param {number[]} marginalX - The support of the Kernel Density Estimate of the | ||
* random variable. | ||
* @param {number[]} marginalY - The Kernel Density Estimate of the random variable. | ||
* @param {number | null} [hdiProb=null] - The highest density interval probability | ||
* value. If the default value is not overwritten, then the default HDI probability | ||
* is 0.89. See Statistical Rethinking by McElreath for a description as to why this | ||
* value is the default. | ||
* @param {string[]} [text_align=['right', 'center', 'left']] - How to align the text | ||
* shown in the figure for the point statistics. | ||
* @param {number[]} [x_offset=[-5, 0, 5]] - Offset values for the text along the | ||
* x-axis. | ||
* @param {number[]} [y_offset=[0, 10, 0]] - Offset values for the text along the | ||
* y-axis. | ||
* @returns {interfaces.LabelsData} Object containing the computed stats. | ||
*/ | ||
export const computeStats = ( | ||
rawData: number[], | ||
marginalX: number[], | ||
marginalY: number[], | ||
hdiProb: number | null = null, | ||
text_align: string[] = ['right', 'center', 'left'], | ||
x_offset: number[] = [-5, 0, 5], | ||
y_offset: number[] = [0, 10, 0], | ||
): interfaces.LabelsData => { | ||
// Fall back to 0.89 if no HDI probability has been given. | ||
const hdiProbability = hdiProb ?? 0.89; | ||
|
||
// Compute the point statistics for the KDE, and create labels to display them in the | ||
// figures. | ||
const mean = computeMean(rawData); | ||
const hdiBounds = hdiInterval(rawData, hdiProbability); | ||
const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound]; | ||
const y = interpolatePoints({x: marginalX, y: marginalY, points: x}); | ||
const text = [ | ||
`Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`, | ||
`Mean: ${mean.toFixed(3)}`, | ||
`Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`, | ||
]; | ||
|
||
return { | ||
x: x, | ||
y: y, | ||
text: text, | ||
text_align: text_align, | ||
x_offset: x_offset, | ||
y_offset: y_offset, | ||
}; | ||
}; | ||
|
||
/** | ||
* Compute data for the one-dimensional marginal diagnostic tool. | ||
* | ||
* @param {number[]} data - Raw random variable data from the model. | ||
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when | ||
* calculating the Kernel Density Estimate (KDE). | ||
* @param {number} hdiProbability - The highest density interval probability to use when | ||
* calculating the HDI. | ||
* @returns {interfaces.Data} The marginal distribution and cumulative | ||
* distribution calculated from the given random variable data. Point statistics are | ||
* also calculated. | ||
*/ | ||
export const computeData = ( | ||
data: number[], | ||
bwFactor: number, | ||
hdiProbability: number, | ||
): interfaces.Data => { | ||
const output = {} as interfaces.Data; | ||
for (let i = 0; i < figureNames.length; i += 1) { | ||
const figureName = figureNames[i]; | ||
output[figureName] = {} as interfaces.GlyphData; | ||
|
||
// Compute the one-dimensional KDE and its cumulative distribution. | ||
const distribution = oneD(data, bwFactor); | ||
switch (figureName) { | ||
case 'cumulative': | ||
distribution.y = scaleToOne(cumulativeSum(distribution.y)); | ||
break; | ||
default: | ||
break; | ||
} | ||
|
||
// Compute the point statistics for the given data. | ||
const stats = computeStats(data, distribution.x, distribution.y, hdiProbability); | ||
|
||
output[figureName] = { | ||
distribution: distribution, | ||
hdi: hdiData(data, distribution.x, distribution.y, hdiProbability), | ||
stats: {x: stats.x, y: stats.y, text: stats.text}, | ||
labels: stats, | ||
}; | ||
} | ||
return output; | ||
}; | ||
|
||
/** | ||
* Callback used to update the Bokeh application in the notebook. | ||
* | ||
* @param {number[]} data - Raw random variable data from the model. | ||
* @param {string} rvName - The name of the random variable from the model. | ||
* @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when | ||
* calculating the kernel density estimate. | ||
* @param {number} hdiProbability - The highest density interval probability to use when | ||
* calculating the HDI. | ||
* @param {interfaces.Sources} sources - Bokeh sources used to render glyphs in the | ||
* application. | ||
* @param {interfaces.Figures} figures - Bokeh figures shown in the application. | ||
* @param {interfaces.Tooltips} tooltips - Bokeh tooltips shown on the glyphs. | ||
* @returns {number} We display the value of the bandwidth used for computing the Kernel | ||
* Density Estimate in a div, and must return that value here in order to update the | ||
* value displayed to the user. | ||
*/ | ||
export const update = ( | ||
data: number[], | ||
rvName: string, | ||
bwFactor: number, | ||
hdiProbability: number, | ||
sources: interfaces.Sources, | ||
figures: interfaces.Figures, | ||
tooltips: interfaces.Tooltips, | ||
): number => { | ||
const computedData = computeData(data, bwFactor, hdiProbability); | ||
for (let i = 0; i < figureNames.length; i += 1) { | ||
// Update all sources with new data calculated above. | ||
const figureName = figureNames[i]; | ||
sources[figureName].distribution.data = { | ||
x: computedData[figureName].distribution.x, | ||
y: computedData[figureName].distribution.y, | ||
}; | ||
sources[figureName].hdi.data = { | ||
base: computedData[figureName].hdi.base, | ||
lower: computedData[figureName].hdi.lower, | ||
upper: computedData[figureName].hdi.upper, | ||
}; | ||
sources[figureName].stats.data = computedData[figureName].stats; | ||
sources[figureName].labels.data = computedData[figureName].labels; | ||
|
||
// Update the axes labels. | ||
updateAxisLabel(figures[figureName].below[0], rvName); | ||
|
||
// Update the tooltips. | ||
tooltips[figureName].stats.tooltips = [['', '@text']]; | ||
tooltips[figureName].distribution.tooltips = [[rvName, '@x']]; | ||
} | ||
return computedData.marginal.distribution.bandwidth; | ||
}; |
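
The helpers imported from `../stats/*` are not shown in this part of the diff. As a rough guide to what `computeData` and `computeStats` rely on, the sketch below gives stand-in versions of the cumulative-curve and HDI steps. All three function names are hypothetical. The assumptions are that `scaleToOne` divides by the largest value and that the HDI is the narrowest window over the sorted samples; the real implementations may differ.

```typescript
// Illustrative stand-ins for helpers imported from '../stats/*' in callbacks.ts.
// Their real implementations are not shown in this diff, so everything below is
// an assumption about their behavior.

// Running total: [1, 2, 3] -> [1, 3, 6].
const cumulativeSumSketch = (values: number[]): number[] => {
  let runningTotal = 0;
  return values.map((value) => {
    runningTotal += value;
    return runningTotal;
  });
};

// Rescale so the largest entry equals 1 (assumed meaning of `scaleToOne`).
// Assumes at least one strictly positive value.
const scaleToOneSketch = (values: number[]): number[] => {
  const maxValue = Math.max(...values);
  return values.map((value) => {
    return value / maxValue;
  });
};

interface HdiBounds {
  lowerBound: number;
  upperBound: number;
}

// Narrowest interval spanning floor(prob * n) steps of the sorted samples.
const hdiIntervalSketch = (samples: number[], prob: number): HdiBounds => {
  if (samples.length < 2 || prob <= 0 || prob >= 1) {
    throw new Error('Need at least two samples and 0 < prob < 1.');
  }
  const sorted = samples.slice().sort((a, b) => {
    return a - b;
  });
  const span = Math.floor(prob * sorted.length);
  let best = 0;
  for (let i = 1; i + span < sorted.length; i += 1) {
    if (sorted[i + span] - sorted[i] < sorted[best + span] - sorted[best]) {
      best = i;
    }
  }
  return {lowerBound: sorted[best], upperBound: sorted[best + span]};
};
```

Under these assumptions, a density evaluated on an evenly spaced grid becomes an approximate cumulative distribution via `scaleToOneSketch(cumulativeSumSketch(y))`, which mirrors the `'cumulative'` branch in `computeData`.
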
12 changes: 12 additions & 0 deletions
src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/index.ts
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
/** | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
import * as marginal1d from './callbacks'; | ||
|
||
// The CustomJS methods used by Bokeh require us to make the JavaScript available in the | ||
// browser, which is done by defining it below. | ||
(window as any).marginal1d = marginal1d; |
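
The global assignment in `index.ts` matters because callback code supplied to Bokeh as a string cannot import modules. The following sketch illustrates that pattern in isolation. `demoCallbacks` and `callbackBody` are hypothetical names. `globalThis` is used instead of `window` so that the snippet also runs outside a browser, and `new Function` stands in for however Bokeh evaluates the string, which is an assumption here.

```typescript
// A hypothetical namespace standing in for the real `marginal1d` callbacks module.
const demoCallbacks = {
  double: (value: number): number => {
    return value * 2;
  },
};

// `globalThis` is `window` in a browser; using it keeps the sketch runnable in
// Node as well. The commit itself assigns to `window` directly.
(globalThis as any).demoCallbacks = demoCallbacks;

// Code evaluated later from a string has no access to module imports, so it
// reaches the functions through the global name instead.
const callbackBody = 'return demoCallbacks.double(21);';
const result: number = new Function(callbackBody)();
```

In the tool itself, the equivalent string would presumably call something like `marginal1d.update(...)` with the sources, figures, and tooltips passed in from Python.
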