diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts index d4c85f5e28..18a1d6ad5b 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal1d/callbacks.ts @@ -5,36 +5,20 @@ * LICENSE file in the root directory of this source tree. */ -import {Axis} from '@bokehjs/models/axes/axis'; -import {cumulativeSum} from '../stats/array'; +import {arrayMean, cumulativeSum} from '../stats/array'; import {scaleToOne} from '../stats/dataTransformation'; import { interval as hdiInterval, data as hdiData, } from '../stats/highestDensityInterval'; import {oneD} from '../stats/marginal'; -import {mean as computeMean} from '../stats/pointStatistic'; import {interpolatePoints} from '../stats/utils'; +import {updateAxisLabel} from '../utils/plottingUtils'; import * as interfaces from './interfaces'; // Define the names of the figures used for this Bokeh application. const figureNames = ['marginal', 'cumulative']; -/** - * Update the given Bokeh Axis object with the new label string. You must use this - * method to update axis strings using TypeScript, otherwise the ts compiler will throw - * a type check error. - * - * @param {Axis} axis - The Bokeh Axis object needing a new label. - * @param {string | null} label - The new label for the Bokeh Axis object. - */ -export const updateAxisLabel = (axis: Axis, label: string | null): void => { - // Type check requirement. - if ('axis_label' in axis) { - axis.axis_label = label; - } -}; - /** * Compute the following statistics for the given random variable data * @@ -72,7 +56,7 @@ export const computeStats = ( // Compute the point statistics for the KDE, and create labels to display them in the // figures. - const mean = computeMean(rawData); + const mean = arrayMean(rawData); const hdiBounds = hdiInterval(rawData, hdiProbability); const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound]; const y = interpolatePoints({x: marginalX, y: marginalY, points: x}); diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/callbacks.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/callbacks.ts new file mode 100644 index 0000000000..a1a88bd4c6 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/callbacks.ts @@ -0,0 +1,243 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {Axis} from '@bokehjs/models/axes/axis'; +import {arrayMean, linearRange} from '../stats/array'; +import { + data as computeHdiData, + data90Degrees, + interval as hdiInterval, +} from '../stats/highestDensityInterval'; +import {oneD} from '../stats/marginal'; +import {interpolatePoints} from '../stats/utils'; +import * as interfaces from './interfaces'; +import {updateAxisLabel} from '../utils/plottingUtils'; + +export const computeXData = ( + data: number[][], + hdiProbability: number, + bwFactor: number, +): interfaces.XData => { + const flatData = data.flat(); + + // Distribution + const distribution = oneD(flatData, bwFactor); + + // HDI + const hdi = computeHdiData(flatData, distribution.x, distribution.y, hdiProbability); + const hdiData = {base: hdi.base, lower: hdi.lower, upper: hdi.upper}; + + // Stats + const mean = arrayMean(flatData); + const hdiBounds = hdiInterval(flatData, hdiProbability); + const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound]; + const y = interpolatePoints({x: distribution.x, y: distribution.y, points: x}); + const text = [ + `Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`, + `Mean: ${mean.toFixed(3)}`, + `Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`, + ]; + const output = { + distribution: distribution, + hdi: hdiData, + stats: {x: x, y: y, text: text}, + }; + + return output; +}; + +export const computeYData = ( + data: number[][], + hdiProbability: number, + bwFactor: number, +): interfaces.YData => { + const flatData = data.flat(); + + // Distribution + const distribution = oneD(flatData, bwFactor); + + // HDI + const hdi = data90Degrees(flatData, distribution.x, distribution.y, hdiProbability); + const hdiData = { + lower: {base: hdi.upper.base, lower: hdi.upper.lower, upper: hdi.upper.upper}, + upper: {base: hdi.lower.base, lower: hdi.lower.lower, upper: hdi.lower.upper}, + }; + + // Stats + const mean = arrayMean(flatData); + const hdiBounds = hdiInterval(flatData, hdiProbability); + const x = [hdiBounds.lowerBound, mean, hdiBounds.upperBound]; + const y = interpolatePoints({x: distribution.x, y: distribution.y, points: x}); + const text = [ + `Lower HDI: ${hdiBounds.lowerBound.toFixed(3)}`, + `Mean: ${mean.toFixed(3)}`, + `Upper HDI: ${hdiBounds.upperBound.toFixed(3)}`, + ]; + const output = { + distribution: distribution, + hdi: hdiData, + stats: {x: x, y: y, text: text}, + }; + + return output; +}; + +export const computeXYData = ( + rawX: number[][], + computedX: interfaces.XData, + rawY: number[][], + computedY: interfaces.YData, +): interfaces.XYData => { + const flatDataX = rawX.flat(); + const flatDataY = rawY.flat(); + + // NOTE: Falling back to displaying data from the samples as the 2D KDE is not + // rendering properly. + const dataDistribution = {x: flatDataX, y: flatDataY}; + + // Stats: Create the stats for the 2D marginal. This is just a central point on the + // figure showing the mean values of both 1D marginals. + const stats = { + x: [computedX.stats.x[1]], + y: [computedY.stats.x[1]], + text: [ + `Mean: ${computedX.stats.x[1].toFixed(3)}/${computedY.stats.x[1].toFixed(3)}`, + ], + }; + + // HDI: Create the HDI guide lines in the 2D marginal distribution. These help the + // user understand how the 2D probability space is affected by changing the HDI + // regions of the 1D marginals independently. + const x = linearRange( + Math.min(...computedX.hdi.base), + Math.max(...computedX.hdi.base), + 1, + true, + 100, + ); + const y = linearRange( + Math.min(...computedY.hdi.lower.lower), + Math.max(...computedY.hdi.upper.upper), + 1, + true, + 100, + ); + const hdi = { + x: { + lower: {x: Array(y.length).fill(Math.min(...computedX.hdi.base)), y: y}, + upper: {x: Array(y.length).fill(Math.max(...computedX.hdi.base)), y: y}, + }, + y: { + lower: {x: x, y: Array(x.length).fill(Math.min(...computedY.hdi.lower.lower))}, + upper: {x: x, y: Array(x.length).fill(Math.max(...computedY.hdi.upper.upper))}, + }, + }; + const output = { + distribution: dataDistribution, + hdi: hdi, + stats: stats, + }; + return output; +}; + +export const computeData = ( + dataX: number[][], + hdiProbabilityX: number, + dataY: number[][], + hdiProbabilityY: number, + bwFactor: number, +): interfaces.Data => { + const xData = computeXData(dataX, hdiProbabilityX, bwFactor); + const yData = computeYData(dataY, hdiProbabilityY, bwFactor); + const xyData = computeXYData(dataX, xData, dataY, yData); + return {x: xData, y: yData, xy: xyData}; +}; + +export const update = ( + dataX: number[][], + hdiProbabilityX: number, + dataY: number[][], + hdiProbabilityY: number, + bwFactor: number, + xAxisLabel: string, + yAxisLabel: string, + figures: interfaces.Figures, + sources: interfaces.Sources, + tooltips: interfaces.Tooltips, + widgets: interfaces.Widgets, + glyphs: interfaces.Glyphs, +): number[] => { + const computedData = computeData( + dataX, + hdiProbabilityX, + dataY, + hdiProbabilityY, + bwFactor, + ); + // Update the x figure. + const xDistribution = { + x: computedData.x.distribution.x, + y: computedData.x.distribution.y, + }; + const bandwidthX = computedData.x.distribution.bandwidth; + sources.x.distribution.data = xDistribution; + sources.x.hdi.data = computedData.x.hdi; + sources.x.stats.data = computedData.x.stats; + tooltips.x.distribution.tooltips = [[xAxisLabel, '@x']]; + figures.xy.x_range = figures.x.x_range; + + // Update the y figure. + const yDistribution = { + x: computedData.y.distribution.y, + y: computedData.y.distribution.x, + }; + const bandwidthY = computedData.y.distribution.bandwidth; + sources.y.distribution.data = yDistribution; + sources.y.hdi.lower.data = computedData.y.hdi.lower; + sources.y.hdi.upper.data = computedData.y.hdi.upper; + const yStats = { + x: computedData.y.stats.y, + y: computedData.y.stats.x, + text: computedData.y.stats.text, + }; + sources.y.stats.data = yStats; + tooltips.y.distribution.tooltips = [[yAxisLabel, '@y']]; + figures.xy.y_range = figures.y.y_range; + + // Update the xy figure. + sources.xy.distribution.data = computedData.xy.distribution; + tooltips.xy.distribution.tooltips = [ + [xAxisLabel, '@x'], + [yAxisLabel, '@y'], + ]; + sources.xy.hdi.x.lower.data = computedData.xy.hdi.x.lower; + sources.xy.hdi.x.upper.data = computedData.xy.hdi.x.upper; + tooltips.xy.hdi.x.lower.tooltips = [[xAxisLabel, '@x']]; + tooltips.xy.hdi.x.upper.tooltips = [[xAxisLabel, '@x']]; + sources.xy.hdi.y.lower.data = computedData.xy.hdi.y.lower; + sources.xy.hdi.y.upper.data = computedData.xy.hdi.y.upper; + tooltips.xy.hdi.y.lower.tooltips = [[yAxisLabel, '@y']]; + tooltips.xy.hdi.y.upper.tooltips = [[yAxisLabel, '@y']]; + sources.xy.stats.data = computedData.xy.stats; + tooltips.xy.stats.tooltips = [ + [xAxisLabel, '@x'], + [yAxisLabel, '@y'], + ]; + + (window as any).data = computedData; + (window as any).figures = figures; + (window as any).glyphs = glyphs; + (window as any).sources = sources; + + updateAxisLabel(figures.xy.below[0] as Axis, xAxisLabel); + updateAxisLabel(figures.xy.left[0] as Axis, yAxisLabel); + + // Update widgets. + widgets.bw_div_x.text = `Bandwidth ${xAxisLabel}: ${bwFactor * bandwidthX}`; + widgets.bw_div_y.text = `Bandwidth ${yAxisLabel}: ${bwFactor * bandwidthY}`; + return [bandwidthX, bandwidthY]; +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/index.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/index.ts new file mode 100644 index 0000000000..1786ac0434 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/index.ts @@ -0,0 +1,12 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import * as marginal2d from './callbacks'; + +// The CustomJS methods used by Bokeh require us to make the JavaScript available in the +// browser, which is done by defining it below. +(window as any).marginal2d = marginal2d; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/interfaces.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/interfaces.ts new file mode 100644 index 0000000000..d5ae207801 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/marginal2d/interfaces.ts @@ -0,0 +1,121 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {Figure} from '@bokehjs/api/plotting'; +import {Circle, Line} from '@bokehjs/models/glyphs'; +import {ColumnDataSource} from '@bokehjs/models/sources/column_data_source'; +import {HoverTool} from '@bokehjs/models/tools/inspectors/hover_tool'; +import {Div} from '@bokehjs/models/widgets/div'; +import {Select} from '@bokehjs/models/widgets/selectbox'; +import {Slider} from '@bokehjs/models/widgets/slider'; + +export interface XData { + distribution: {x: number[]; y: number[]; bandwidth: number}; + hdi: {base: number[]; lower: number[]; upper: number[]}; + stats: {x: number[]; y: number[]; text: string[]}; +} + +export interface YData { + distribution: {x: number[]; y: number[]; bandwidth: number}; + hdi: { + lower: {base: number[]; lower: number[]; upper: number[]}; + upper: {base: number[]; lower: number[]; upper: number[]}; + }; + stats: {x: number[]; y: number[]; text: string[]}; +} + +export interface XYData { + distribution: {x: number[]; y: number[]}; + hdi: { + x: { + lower: {x: number[]; y: number[]}; + upper: {x: number[]; y: number[]}; + }; + y: { + lower: {x: number[]; y: number[]}; + upper: {x: number[]; y: number[]}; + }; + }; + stats: {x: number[]; y: number[]; text: string[]}; +} + +export interface Data { + x: XData; + y: YData; + xy: XYData; +} + +export interface Glyphs { + x: { + distribution: {glyph: Line; hover_glyph: Line}; + stats: {glyph: Circle; hover_glyph: Circle}; + }; + y: { + distribution: {glyph: Line; hover_glyph: Line}; + stats: {glyph: Circle; hover_glyph: Circle}; + }; + xy: { + distribution: Circle; + hdi: { + x: { + lower: {glyph: Line; hover_glyph: Line}; + upper: {glyph: Line; hover_glyph: Line}; + }; + y: { + lower: {glyph: Line; hover_glyph: Line}; + upper: {glyph: Line; hover_glyph: Line}; + }; + }; + stats: {glyph: Circle; hover_glyph: Circle}; + }; +} + +export interface Figures { + x: Figure; + y: Figure; + xy: Figure; +} + +export interface Sources { + x: {distribution: ColumnDataSource; hdi: ColumnDataSource; stats: ColumnDataSource}; + y: { + distribution: ColumnDataSource; + hdi: {lower: ColumnDataSource; upper: ColumnDataSource}; + stats: ColumnDataSource; + }; + xy: { + distribution: ColumnDataSource; + hdi: { + x: {lower: ColumnDataSource; upper: ColumnDataSource}; + y: {lower: ColumnDataSource; upper: ColumnDataSource}; + }; + stats: ColumnDataSource; + }; +} + +export interface Tooltips { + x: {distribution: HoverTool; stats: HoverTool}; + y: {distribution: HoverTool; stats: HoverTool}; + xy: { + distribution: HoverTool; + hdi: { + x: {lower: HoverTool; upper: HoverTool}; + y: {lower: HoverTool; upper: HoverTool}; + }; + stats: HoverTool; + }; +} + +export interface Widgets { + rv_select_x: Select; + rv_select_y: Select; + bw_factor_slider: Slider; + hdi_slider_x: Slider; + hdi_slider_y: Slider; + bw_div_x: Div; + bw_div_y: Div; +} diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts index 39dcf65bd0..ba5f5260e7 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/array.ts @@ -5,10 +5,33 @@ * LICENSE file in the root directory of this source tree. */ +/** + * Syntactic sugar for summing an array of numbers. + * + * @param {number[]} data - The array of data. + * @returns {number} The sum of the array of data. + */ +export const arraySum = (data: number[]): number => { + return data.reduce((previousValue, currentValue) => { + return previousValue + currentValue; + }); +}; + +/** + * Calculate the mean of the given array of data. + * + * @param {number[]} data - The array of data. + * @returns {number} The mean of the given data. + */ +export const arrayMean = (data: number[]): number => { + const dataSum = arraySum(data); + return dataSum / data.length; +}; + /** * Cumulative sum of the given data. * - * @param {number[]} data - Any array of data. + * @param {number[]} data - Any one dimensional array of data. * @returns {number[]} The cumulative sum of the given data. */ export const cumulativeSum = (data: number[]): number[] => { @@ -115,9 +138,9 @@ export const argSort = (data: number[]): number[] => { }; /** - * Count the number of time a value appears in an array. + * Count the number of times a value appears in an array. * - * @param {number[]} data - The numeric array to count objects for. + * @param {number[]} data - The numeric array that we will count the values in. * @returns {{[key: string]: number}} An object that contains the keys as the items in * the original array, and values that are counts of the key. */ diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/highestDensityInterval.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/highestDensityInterval.ts index e0356eb0b2..ab6e22609d 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/highestDensityInterval.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/highestDensityInterval.ts @@ -94,3 +94,82 @@ export const data = ( upperBound: hdi.upperBound, }; }; + +/** + * Create the HDI Bokeh annotation for a marginal that has been rotated and flipped. We + * must break up the annotation into a lower and upper component so that it renders + * correctly in the browser. + * + * @param {number[]} rvData - Raw random variable data from the model. + * @param {number[]} marginalX - The support of the Kernel Density Estimate of the + * random variable. + * @param {number[]} marginalY - The Kernel Density Estimate of the random variable. + * @param {number} hdiProbability - The highest density interval probability to use when + * calculating the HDI. + * @returns {{ + * lower: {base: number[]; lower: number[]; upper: number[]}; + * upper: {base: number[]; lower: number[]; upper: number[]}; + * lowerBound: number; + * upperBound: number; + * }} The lower and upper components are drawn on the same figure. + */ +export const data90Degrees = ( + rvData: number[], + marginalX: number[], + marginalY: number[], + hdiProbability: number, +): { + lower: { + base: number[]; + lower: number[]; + upper: number[]; + }; + upper: { + base: number[]; + lower: number[]; + upper: number[]; + }; + lowerBound: number; + upperBound: number; +} => { + const hdiBounds = interval(rvData, hdiProbability); + const hdiData = data(rvData, marginalX, marginalY, hdiProbability); + const hdiX = hdiData.base; + const n = hdiX.length; + const halfIndex = Math.floor(n / 2); + const xAtHalfIndex = hdiX[halfIndex]; + + const lowerBase = [0]; + const lowerLower = [hdiBounds.lowerBound]; + for (let i = 0; i < marginalX.length; i += 1) { + if (marginalX[i] <= xAtHalfIndex && marginalX[i] >= hdiBounds.lowerBound) { + lowerBase.push(marginalY[i]); + lowerLower.push(marginalX[i]); + } + } + const lowerUpper = Array(lowerBase.length).fill(xAtHalfIndex); + + const upperBase = [0]; + const upperUpper = [hdiBounds.upperBound]; + for (let i = marginalX.length - 1; i >= 0; i -= 1) { + if (marginalX[i] >= xAtHalfIndex && marginalX[i] <= hdiBounds.upperBound) { + upperBase.push(marginalY[i]); + upperUpper.push(marginalX[i]); + } + } + const upperLower = Array(upperBase.length).fill(xAtHalfIndex); + return { + lower: { + base: upperBase, + lower: upperLower, + upper: upperUpper, + }, + upper: { + base: lowerBase, + lower: lowerLower, + upper: lowerUpper, + }, + lowerBound: hdiBounds.lowerBound, + upperBound: hdiBounds.upperBound, + }; +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts index 40a4868571..003a72f2db 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/histogram.ts @@ -5,9 +5,8 @@ * LICENSE file in the root directory of this source tree. */ -import {linearRange, numericalSort, shape} from './array'; +import {arrayMean, linearRange, numericalSort, shape} from './array'; import {rankData, scaleToOne} from './dataTransformation'; -import {mean as computeMean} from './pointStatistic'; /** * Compute the histogram of the given data. @@ -16,7 +15,8 @@ import {mean as computeMean} from './pointStatistic'; * @param {number} [numBins] - The number of bins to use for the histogram. If none is * given, then we follow ArviZ's implementation by using twice then number of bins * of the Sturges formula. - * @returns {number[][]} [TODO:description] + * @returns {number[][]} A two-dimensional array where the first row is the bins and the + * second row is the histogram count. */ export const calculateHistogram = (data: number[], numBins: number = 0): number[][] => { const sortedData = numericalSort(data); @@ -123,7 +123,7 @@ export const rankHistogram = (data: number[][]): RankHistogram => { return value + i; }); - const chainRankMean = computeMean(chainCounts); + const chainRankMean = arrayMean(chainCounts); const left = binEdges.slice(0, binEdges.length - 1); const right = binEdges.slice(1); const binLabel = []; @@ -145,7 +145,7 @@ export const rankHistogram = (data: number[][]): RankHistogram => { line: {x: x, y: y}, chain: Array(x.length).fill(i + 1), rankMean: Array(x.length).fill(chainIndex - chainRankMean), - mean: Array(x.length).fill(computeMean(counts)), + mean: Array(x.length).fill(arrayMean(counts)), }; } return output; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/marginal.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/marginal.ts index fed70ee34c..87fde27d2e 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/marginal.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/marginal.ts @@ -6,6 +6,7 @@ */ import {density1d} from 'fast-kde/src/density1d'; +import {density2d} from 'fast-kde/src/density2d'; import {scaleToOne} from './dataTransformation'; /** @@ -47,3 +48,44 @@ export const oneD = ( bandwidth: kde1d.bandwidth(), }; }; + +/** + * Computes the 2D Kernel Density Estimate. + * + * @param {number[]} x - The raw random variable data of the model in the x direction. + * @param {number[]} y - The raw random variable data of the model in the y direction. + * @param {number} bwFactor - Multiplicative factor to be applied to the bandwidth when + * calculating the Kernel Density Estimate (KDE). + * @param {number[]} [bins] - The number of bins to use for calculating the 2D KDE. + * @returns {{x: number[]; y: number[]; z: number[]; bw: {x: number; y: number}}} + * The computed 2D KDE with bandwidths for both sets of data. + */ +export const twoD = ( + x: number[], + y: number[], + bwFactor: number, + bins: number[] = [128, 128], +): {x: number[]; y: number[]; z: number[]; bw: {x: number; y: number}} => { + // Prepare the random variables for calculating the 2D KDE using fast-kde. + const data = []; + for (let i: number = 0; i < x.length; i += 1) { + data.push({u: x[i], v: y[i]}); + } + + // Calculate the 2D KDE. + const kde2d = density2d(data, {x: 'u', y: 'v', bins: bins, adjust: bwFactor, pad: 3}); + const [bwX, bwY] = kde2d.bandwidth(); + + // Extract the 2D data points from the 2D KDE calculation. + const points: {x: number; y: number; z: number}[] = [...kde2d]; + const X: number[] = []; + const Y: number[] = []; + const Z: number[] = []; + for (let i: number = 0; i < points.length; i += 1) { + X[i] = points[i].x; + Y[i] = points[i].y; + Z[i] = points[i].z; + } + + return {x: X, y: Y, z: Z, bw: {x: bwX, y: bwY}}; +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/pointStatistic.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/stats/pointStatistic.ts deleted file mode 100644 index 3ee47b9876..0000000000 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/stats/pointStatistic.ts +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -/** - * Syntactic sugar for summing an array of numbers. - * - * @param {number[]} data - The array of data. - * @returns {number} The sum of the array of data. - */ -export const sum = (data: number[]): number => { - return data.reduce((previousValue, currentValue) => { - return previousValue + currentValue; - }); -}; - -/** - * Calculate the mean of the given array of data. - * - * @param {number[]} data - The array of data. - * @returns {number} The mean of the given data. - */ -export const mean = (data: number[]): number => { - const dataSum = sum(data); - return dataSum / data.length; -}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts index 1c4d4ab362..05fded3c17 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/trace/callbacks.ts @@ -5,31 +5,15 @@ * LICENSE file in the root directory of this source tree. */ -import {Axis} from '@bokehjs/models/axes/axis'; import * as interfaces from './interfaces'; -import {linearRange, shape} from '../stats/array'; +import {arrayMean, linearRange, shape} from '../stats/array'; import {interval as hdiInterval} from '../stats/highestDensityInterval'; import {rankHistogram} from '../stats/histogram'; import {oneD} from '../stats/marginal'; -import {mean} from '../stats/pointStatistic'; +import {updateAxisLabel} from '../utils/plottingUtils'; const figureNames = ['marginals', 'forests', 'traces', 'ranks']; -/** - * Update the given Bokeh Axis object with the new label string. You must use this - * method to update axis strings using TypeScript, otherwise the ts compiler will throw - * a type check error. - * - * @param {Axis} axis - The Bokeh Axis object needing a new label. - * @param {string | null} label - The new label for the Bokeh Axis object. - */ -export const updateAxisLabel = (axis: Axis, label: string | null): void => { - // Type check requirement. - if ('axis_label' in axis) { - axis.axis_label = label; - } -}; - /** * Compute data for the trace diagnostic tool. * @@ -69,7 +53,7 @@ export const computeData = ( const chainName = `chain${chainIndex}`; const chainData = data[j]; const marginal = oneD(chainData, bwFactor); - const marginalMean = mean(marginal.x); + const marginalMean = arrayMean(marginal.x); let hdiBounds; switch (figureName) { case 'marginals': diff --git a/src/beanmachine/ppl/diagnostics/tools/js/src/utils/plottingUtils.ts b/src/beanmachine/ppl/diagnostics/tools/js/src/utils/plottingUtils.ts new file mode 100644 index 0000000000..27730e62d6 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/js/src/utils/plottingUtils.ts @@ -0,0 +1,24 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import {Axis} from '@bokehjs/models/axes/axis'; + +/** + * Update the given Bokeh Axis object with the new label string. You must use this + * method to update axis strings using TypeScript, otherwise the ts compiler will throw + * a type check error. + * + * @param {Axis} axis - The Bokeh Axis object needing a new label. + * @param {string | null} label - The new label for the Bokeh Axis object. + */ +// eslint-disable-next-line import/prefer-default-export +export const updateAxisLabel = (axis: Axis, label: string | null): void => { + // Type check requirement. + if ('axis_label' in axis) { + axis.axis_label = label; + } +}; diff --git a/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json b/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json index 2212d273f3..22146c9a06 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json +++ b/src/beanmachine/ppl/diagnostics/tools/js/tsconfig.json @@ -23,12 +23,7 @@ "node_modules/@bokeh/bokehjs/build/js/lib/*", "node_modules/@bokeh/bokehjs/build/js/types/*" ], - "fast-kde/*": ["node_modules/fast-kde/*"], - "ndarray/*": ["node_modules/ndarray/*"], - "ndarray-fft/*": [ - "node_modules/ndarray-fft/*", - "bokeh_extensions/types/ndarray-fft.d.ts" - ] + "fast-kde/*": ["node_modules/fast-kde/*"] }, "resolveJsonModule": true, "skipLibCheck": true, diff --git a/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js b/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js index e13c2507f6..b7ed1b9be7 100644 --- a/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js +++ b/src/beanmachine/ppl/diagnostics/tools/js/webpack.config.js @@ -10,38 +10,27 @@ const path = require('path'); module.exports = { entry: { marginal1d: './src/marginal1d/index.ts', + marginal2d: './src/marginal2d/index.ts', trace: './src/trace/index.ts', }, output: { filename: '[name].js', path: path.resolve(__dirname, 'dist'), }, + mode: 'production', module: { rules: [ { test: /\.ts$/, use: 'ts-loader', - exclude: /node_modules\/(?!(@bokeh\/bokehjs\/build\/js\/lib)\/).*/, + exclude: /node_modules/, }, ], }, - target: 'web', - mode: 'production', resolve: { extensions: ['.ts', '.js'], - modules: ['./stats', './interfaces', './types', './node_modules'], alias: { - 'fast-kde/src/density1d': path.resolve( - __dirname, - 'node_modules/fast-kde/src/density1d.js', - ), - '@bokehjs/models/ranges/range1d': path.resolve( - __dirname, - 'node_modules/@bokeh/bokehjs/build/js/lib/models/ranges/range1d.js', - ), + '@bokehjs': path.resolve(__dirname, './node_modules/@bokeh/bokehjs/build/js/lib'), }, }, - optimization: { - minimize: false, - }, }; diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py b/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py index ff148175c3..3827776eda 100644 --- a/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py +++ b/src/beanmachine/ppl/diagnostics/tools/marginal1d/utils.py @@ -62,7 +62,7 @@ def create_sources() -> typing.Sources: - """Create Bokeh sources from the given data that will be bound to glyphs. + """Create Bokeh sources that will be bound to glyphs. Returns ------- diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal2d/__init__.py b/src/beanmachine/ppl/diagnostics/tools/marginal2d/__init__.py new file mode 100644 index 0000000000..7bec24cb17 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/marginal2d/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal2d/tool.py b/src/beanmachine/ppl/diagnostics/tools/marginal2d/tool.py new file mode 100644 index 0000000000..1103483140 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/marginal2d/tool.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Marginal 2D diagnostic tool for a Bean Machine model.""" +from __future__ import annotations + +from beanmachine.ppl.diagnostics.tools.marginal2d import utils +from beanmachine.ppl.diagnostics.tools.utils.diagnostic_tool_base import ( + DiagnosticToolBaseClass, +) +from beanmachine.ppl.inference.monte_carlo_samples import MonteCarloSamples +from bokeh.models import Model +from bokeh.models.callbacks import CustomJS + + +class Marginal2d(DiagnosticToolBaseClass): + """ + Marginal 2D diagnostic tool. + + Args: + mcs (MonteCarloSamples): The return object from running a Bean Machine model. + + Attributes: + data (Dict[str, List[List[float]]]): JSON serializable representation of the + given `mcs` object. + rv_names (List[str]): The list of random variables string names for the given + model. + num_chains (int): The number of chains of the model. + num_draws (int): The number of draws of the model for each chain. + palette (List[str]): A list of color values used for the glyphs in the figures. + The colors are specifically chosen from the Colorblind palette defined in + Bokeh. + tool_js (str):The JavaScript callbacks needed to render the Bokeh tool + independently from a Python server. + """ + + def __init__(self: Marginal2d, mcs: MonteCarloSamples) -> None: + super(Marginal2d, self).__init__(mcs) + + def create_document(self: Marginal2d) -> Model: + # Initialize widget values using Python. + rv_name_x = self.rv_names[0] + rv_name_y = self.rv_names[0] + bw_factor = 1.0 + bandwidth_x = 1.0 + bandwidth_y = 1.0 + + # NOTE: We are going to use Python and Bokeh to render the tool in the notebook + # output cell, however, we WILL NOT use Python to calculate any of the + # statistics displayed in the tool. We do this so we can make the BROWSER + # run all the calculations based on user interactions. If we did not + # employ this strategy, then the initial display a user would receive + # would be calculated by Python, and any subsequent updates would be + # calculated by JavaScript. The side-effect of having two backends + # calculate data could cause the figures to flicker, which would not be a + # good end user experience. + # + # Bokeh 3.0 is implementing an "on load" feature, which would nullify this + # requirement, and until that version is released, we have to employ this + # work-around. + + # Create empty Bokeh sources using Python. + sources = utils.create_sources() + + # Create empty figures for the tool using Python. + figures = utils.create_figures(rv_name_x=rv_name_x, rv_name_y=rv_name_y) + + # Create empty glyphs and attach them to the figures using Python. + glyphs = utils.create_glyphs() + utils.add_glyphs(figures=figures, glyphs=glyphs, sources=sources) + + # Create empty annotations and attach them to the figures using Python. + annotations = utils.create_annotations(sources=sources) + utils.add_annotations(figures=figures, annotations=annotations) + + # Create empty tool tips and attach them to the figures using Python. + tooltips = utils.create_tooltips( + rv_name_x=rv_name_x, + rv_name_y=rv_name_y, + figures=figures, + ) + utils.add_tooltips(figures=figures, tooltips=tooltips) + + # Create the widgets for the tool using Python. + widgets = utils.create_widgets( + rv_name_x=rv_name_x, + rv_name_y=rv_name_y, + rv_names=self.rv_names, + bw_factor=bw_factor, + bandwidth_x=bandwidth_x, + bandwidth_y=bandwidth_y, + ) + + # Create the view of the tool and serialize it into HTML using static resources + # from Bokeh. Embedding the tool in this manner prevents external CDN calls for + # JavaScript resources, and prevents the user from having to know where the + # Bokeh server is. + tool_view = utils.create_view(figures=figures, widgets=widgets) + + # Create callbacks for the tool using JavaScript. + callback_js = f""" + const rvNameX = widgets.rv_select_x.value; + const rvDataX = data[rvNameX]; + const rvNameY = widgets.rv_select_y.value; + const rvDataY = data[rvNameY]; + const hdiProbabilityX = widgets.hdi_slider_x.value / 100; + const hdiProbabilityY = widgets.hdi_slider_y.value / 100; + const bwFactor = widgets.bw_factor_slider.value; + // Remove the CSS classes that dim the tool output on initial load. + const rvX = widgets.rv_select_x.value; + const rvY = widgets.rv_select_y.value; + const rvXCheck = rvX !== 'Select a random variable...'; + const rvYCheck = rvY !== 'Select a random variable...'; + const cssRemovalCheck = rvXCheck && rvYCheck; + if (cssRemovalCheck) {{ + const toolTab = toolView.tabs[0]; + const toolChildren = toolTab.child.children; + const dimmedComponent = toolChildren[1]; + dimmedComponent.css_classes = []; + }}; + try {{ + const [bandwidthX, bandwidthY] = marginal2d.update( + rvDataX, + hdiProbabilityX, + rvDataY, + hdiProbabilityY, + bwFactor, + rvNameX, + rvNameY, + figures, + sources, + tooltips, + widgets, + glyphs, + ); + }} catch (error) {{ + {self.tool_js} + const [bandwidthX, bandwidthY] = marginal2d.update( + rvDataX, + hdiProbabilityX, + rvDataY, + hdiProbabilityY, + bwFactor, + rvNameX, + rvNameY, + figures, + sources, + tooltips, + widgets, + glyphs, + ); + }} + figures.xy.x_range = figures.x.x_range; + figures.xy.y_range = figures.x.y_range; + """ + + # Each widget requires the following dictionary for the CustomJS method. Notice + # that the callback_js object above uses the names defined as keys in the below + # object with values defined by the Python objects. + callback_arguments = { + "data": self.data, + "widgets": widgets, + "sources": sources, + "figures": figures, + "tooltips": tooltips, + "toolView": tool_view, + "widgets": widgets, + "glyphs": glyphs, + } + + # Each widget requires slightly different JS, except for the sliders. + rv_select_x_js = f""" + widgets.bw_factor_slider.value = 1.0; + widgets.hdi_slider_x.value = 89; + widgets.hdi_slider_y.value = 89; + {callback_js}; + """ + rv_select_x_callback = CustomJS(args=callback_arguments, code=rv_select_x_js) + rv_select_y_js = f""" + widgets.bw_factor_slider.value = 1.0; + widgets.hdi_slider_x.value = 89; + widgets.hdi_slider_y.value = 89; + {callback_js}; + """ + rv_select_y_callback = CustomJS(args=callback_arguments, code=rv_select_y_js) + slider_js = f""" + {callback_js}; + """ + slider_callback = CustomJS(args=callback_arguments, code=slider_js) + + # Tell Python to use the JavaScript. + widgets["rv_select_x"].js_on_change("value", rv_select_x_callback) + widgets["rv_select_y"].js_on_change("value", rv_select_y_callback) + widgets["bw_factor_slider"].js_on_change("value", slider_callback) + widgets["hdi_slider_x"].js_on_change("value", slider_callback) + widgets["hdi_slider_y"].js_on_change("value", slider_callback) + + return tool_view diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal2d/typing.py b/src/beanmachine/ppl/diagnostics/tools/marginal2d/typing.py new file mode 100644 index 0000000000..6ce3c13b94 --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/marginal2d/typing.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Marginal 2D diagnostic tool types for a Bean Machine model.""" +from __future__ import annotations + +from typing import Any, Dict, List, Union + +from beanmachine.ppl.diagnostics.tools import TypedDict +from bokeh.models.annotations import Band +from bokeh.models.glyphs import Circle, Image, Line +from bokeh.models.sources import ColumnDataSource +from bokeh.models.tools import HoverTool +from bokeh.models.widgets.inputs import Select +from bokeh.models.widgets.markups import Div +from bokeh.models.widgets.sliders import Slider +from bokeh.plotting import Figure + + +# NOTE: These are the types pyre gives us when using `reveal_type(...)` on the outputs +# of the methods. +Data = Dict[ + str, + Union[ + Dict[ + str, + Union[ + Dict[str, List[Any]], + Dict[str, Dict[str, List[Any]]], + Dict[str, Union[List[Any], float]], + ], + ], + Dict[ + str, + Union[Dict[str, List[Any]], Dict[str, Dict[str, Dict[str, List[Any]]]]], + ], + Dict[str, Union[Dict[str, List[Any]], Dict[str, Union[List[Any], float]]]], + ], +] +Sources = Dict[ + str, + Union[ + Dict[str, Union[Dict[str, Dict[str, ColumnDataSource]], ColumnDataSource]], + Dict[str, Union[Dict[str, ColumnDataSource], ColumnDataSource]], + Dict[str, ColumnDataSource], + ], +] +Figures = Dict[str, Any] +Glyphs = Dict[ + str, + Union[ + Dict[ + str, + Union[ + Dict[str, Dict[str, Dict[str, Line]]], + Dict[str, Circle], + ], + ], + Dict[str, Union[Dict[str, Circle], Dict[str, Line]]], + ], +] +Annotations = Dict[str, Union[Dict[str, Band], Band]] +Tooltips = Dict[ + str, + Union[ + Dict[str, Union[Dict[str, Dict[str, HoverTool]], HoverTool]], + Dict[str, HoverTool], + ], +] +Widgets = Dict[str, Union[Div, Select, Slider]] + + +# NOTE: TypedDict objects are for reference only. Due to the way pyre accesses keys in +# dictionaries, and how NumPy casts arrays when using tolist(), we are unable to +# use them, but they provide semantic information for the different types. + + +class XSource(TypedDict): # pyre-ignore + distribution: ColumnDataSource + hdi: ColumnDataSource + stats: ColumnDataSource + + +class YHDISource(TypedDict): # pyre-ignore + lower: ColumnDataSource + upper: ColumnDataSource + + +class YSource(TypedDict): # pyre-ignore + distribution: ColumnDataSource + hdi: YHDISource + stats: ColumnDataSource + + +class XYHDIBoundsSource(TypedDict): # pyre-ignore + lower: ColumnDataSource + upper: ColumnDataSource + + +class XYHDISource(TypedDict): # pyre-ignore + x: XYHDIBoundsSource + y: XYHDIBoundsSource + + +class XYSource(TypedDict): # pyre-ignore + distribution: ColumnDataSource + hdi: XYHDISource + stats: ColumnDataSource + + +class _Sources(TypedDict): # pyre-ignore + x: XSource + y: YSource + xy: XYSource + + +class _Figures(TypedDict): # pyre-ignore + x: Figure + y: Figure + xy: Figure + + +class LineGlyph(TypedDict): # pyre-ignore + glyph: Line + hover_glyph: Line + + +class CircleGlyph(TypedDict): # pyre-ignore + glyph: Circle + hover_glyph: Circle + + +class XorYGlyphs(TypedDict): # pyre-ignore + distribution: LineGlyph + stats: CircleGlyph + + +class LowerOrUpperHDIGlyphs(TypedDict): # pyre-ignore + lower: LineGlyph + upper: LineGlyph + + +class XYHDIGlyphs(TypedDict): # pyre-ignore + x: LowerOrUpperHDIGlyphs + y: LowerOrUpperHDIGlyphs + + +class XYGlyphs(TypedDict): # pyre-ignore + distribution: Image + hdi: XYHDIGlyphs + stats: CircleGlyph + + +class _Glyphs(TypedDict): # pyre-ignore + x: XorYGlyphs + y: XorYGlyphs + xy: XYGlyphs + + +class YAnnotations(TypedDict): # pyre-ignore + lower: Band + upper: Band + + +class _Annotations(TypedDict): # pyre-ignore + x: Band + y: YAnnotations + + +class XorYTooltips(TypedDict): # pyre-ignore + distribution: HoverTool + stats: HoverTool + + +class LowerOrUpperTooltips(TypedDict): # pyre-ignore + lower: HoverTool + upper: HoverTool + + +class XYHDITooltips(TypedDict): # pyre-ignore + x: LowerOrUpperTooltips + y: LowerOrUpperTooltips + + +class XYTooltips(TypedDict): # pyre-ignore + hdi: XYHDITooltips + stats: HoverTool + + +class _Tooltips(TypedDict): # pyre-ignore + x: XorYTooltips + y: XorYTooltips + xy: XYTooltips + + +class _Widgets(TypedDict): # pyre-ignore + rv_select_x: Select + rv_select_y: Select + hdi_slider_x: Slider + hdi_slider_y: Slider + bw_div_x: Div + bw_div_y: Div diff --git a/src/beanmachine/ppl/diagnostics/tools/marginal2d/utils.py b/src/beanmachine/ppl/diagnostics/tools/marginal2d/utils.py new file mode 100644 index 0000000000..2a19d0c29c --- /dev/null +++ b/src/beanmachine/ppl/diagnostics/tools/marginal2d/utils.py @@ -0,0 +1,776 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Methods used to generate the diagnostic tool.""" +from __future__ import annotations + +from typing import List + +import numpy as np +from beanmachine.ppl.diagnostics.tools.marginal2d import typing +from beanmachine.ppl.diagnostics.tools.utils.plotting_utils import ( + choose_palette, + create_toolbar, + filter_renderers, +) +from bokeh.models.annotations import Band +from bokeh.models.glyphs import Circle, Line +from bokeh.models.layouts import Column, GridBox, Row +from bokeh.models.sources import ColumnDataSource +from bokeh.models.tools import HoverTool +from bokeh.models.widgets.inputs import Select +from bokeh.models.widgets.markups import Div +from bokeh.models.widgets.panels import Panel, Tabs +from bokeh.models.widgets.sliders import Slider +from bokeh.plotting import figure + + +MARGINAL1D_PLOT_WIDTH = 500 +MARGINAL1D_PLOT_HEIGHT = 100 +MARGINAL2D_PLOT_WIDTH = MARGINAL1D_PLOT_WIDTH +MARGINAL2D_PLOT_HEIGHT = MARGINAL2D_PLOT_WIDTH +# Define what the empty data object looks like in order to make the browser handle all +# computations. +EMPTY_DATA = { + "x": { + "distribution": {"x": [], "y": [], "bandwidth": np.NaN}, + "hdi": {"base": [], "lower": [], "upper": []}, + "stats": {"x": [], "y": [], "text": []}, + "labels": { + "x": [], + "y": [], + "text": [], + "text_align": [], + "x_offset": [], + "y_offset": [], + }, + }, + "y": { + "distribution": {"x": [], "y": [], "bandwidth": np.NaN}, + "hdi": { + "lower": {"base": [], "lower": [], "upper": []}, + "upper": {"base": [], "lower": [], "upper": []}, + }, + "stats": {"x": [], "y": [], "text": []}, + "labels": { + "x": [], + "y": [], + "text": [], + "text_align": [], + "x_offset": [], + "y_offset": [], + }, + }, + "xy": { + "distribution": {"x": [], "y": []}, + "hdi": { + "x": {"lower": {"x": [], "y": []}, "upper": {"x": [], "y": []}}, + "y": {"lower": {"x": [], "y": []}, "upper": {"x": [], "y": []}}, + }, + "stats": {"x": [], "y": [], "text": []}, + "labels": { + "x": [], + "y": [], + "text": [], + "text_align": [], + "x_offset": [], + "y_offset": [], + }, + }, +} + + +def create_sources() -> typing.Sources: + """ + Create Bokeh sources that will be bound to glyphs. + + Returns: + typing.Sources: A dictionary of Bokeh ColumnDataSource objects. + """ + output = { + "x": { + "distribution": ColumnDataSource({"x": [], "y": []}), + "hdi": ColumnDataSource({"base": [], "lower": [], "upper": []}), + "stats": ColumnDataSource({"x": [], "y": [], "text": []}), + }, + "y": { + "distribution": ColumnDataSource({"x": [], "y": []}), + "hdi": { + "lower": ColumnDataSource({"base": [], "lower": [], "upper": []}), + "upper": ColumnDataSource({"base": [], "lower": [], "upper": []}), + }, + "stats": ColumnDataSource({"x": [], "y": [], "text": []}), + }, + "xy": { + "distribution": ColumnDataSource({"x": [], "y": []}), + "hdi": { + "x": { + "lower": ColumnDataSource({"x": [], "y": []}), + "upper": ColumnDataSource({"x": [], "y": []}), + }, + "y": { + "lower": ColumnDataSource({"x": [], "y": []}), + "upper": ColumnDataSource({"x": [], "y": []}), + }, + }, + "stats": ColumnDataSource({"x": [], "y": [], "text": []}), + }, + } + return output + + +def create_figures(rv_name_x: str, rv_name_y: str) -> typing.Figures: + """ + Create the Bokeh figures used for the tool. + + Args: + rv_name_x (str): The name of the random variable data in the x-direction. + rv_name_y (str): The name of the random variable data in the y-direction. + + Returns: + typing.Figures: A dictionary of Bokeh Figure objects. + """ + # x figure + x = figure( + outline_line_color=None, + min_border=None, + width=MARGINAL1D_PLOT_WIDTH, + height=MARGINAL1D_PLOT_HEIGHT, + name="x", + y_axis_label=rv_name_y, + ) + # NOTE: The extra steps we are taking for customizing both the x and y figures stem + # from the fact that the plots will shift to fit in their allotted space if no + # axis exists. In order to prevent these visual shifts, we keep the axes + # around and manipulate them so you cannot see them. Unfortunately we must + # keep the major tick labels around and color them white, otherwise the + # problem will persist. + x.grid.visible = False + x.xaxis.visible = False + x.x_range.range_padding = 0 + x.y_range.range_padding = 0 + x.yaxis.axis_label_text_color = None + x.yaxis.axis_line_color = None + x.yaxis.major_label_text_color = "white" + x.yaxis.major_tick_line_color = None + x.yaxis.minor_tick_line_color = None + + # y figure + y = figure( + outline_line_color=None, + min_border=None, + width=MARGINAL1D_PLOT_HEIGHT, + height=MARGINAL1D_PLOT_WIDTH, + name="y", + x_axis_label=rv_name_x, + ) + y.grid.visible = False + y.yaxis.visible = False + y.x_range.range_padding = 0 + y.y_range.range_padding = 0 + y.xaxis.axis_label_text_color = None + y.xaxis.axis_line_color = None + y.xaxis.major_label_text_color = "white" + y.xaxis.major_tick_line_color = None + y.xaxis.minor_tick_line_color = None + + # xy figure + xy = figure( + outline_line_color="black", + min_border=None, + width=MARGINAL2D_PLOT_WIDTH, + height=MARGINAL2D_PLOT_HEIGHT, + x_axis_label=rv_name_x, + y_axis_label=rv_name_y, + x_range=x.x_range, + y_range=y.y_range, + name="xy", + ) + xy.grid.visible = False + xy.x_range.range_padding = 0 + xy.y_range.range_padding = 0 + + output = {"x": x, "y": y, "xy": xy} + return output + + +def create_glyphs() -> typing.Glyphs: + """ + Create the glyphs used for the figures of the tool. + + Returns: + typing.Glyphs: A dictionary of Bokeh Glyphs objects. + """ + palette = choose_palette(4) + glyph_color = palette[0] + hover_glyph_color = palette[1] + mean_color = palette[3] + output = { + "x": { + "distribution": { + "glyph": Line( + x="x", + y="y", + line_alpha=0.7, + line_color=glyph_color, + line_width=2.0, + name="xDistribution", + ), + "hover_glyph": Line( + x="x", + y="y", + line_alpha=1.0, + line_color=hover_glyph_color, + line_width=2.0, + ), + }, + "stats": { + "glyph": Circle( + x="x", + y="y", + size=10, + fill_alpha=1.0, + fill_color=glyph_color, + line_color="white", + name="xStats", + ), + "hover_glyph": Circle( + x="x", + y="y", + size=10, + fill_alpha=1.0, + fill_color=hover_glyph_color, + line_color="black", + ), + }, + }, + "y": { + "distribution": { + "glyph": Line( + x="x", + y="y", + line_alpha=0.7, + line_color=glyph_color, + line_width=2.0, + name="yDistribution", + ), + "hover_glyph": Line( + x="x", + y="y", + line_alpha=1.0, + line_color=hover_glyph_color, + line_width=2.0, + ), + }, + "stats": { + "glyph": Circle( + x="x", + y="y", + size=10, + fill_alpha=1.0, + fill_color=glyph_color, + line_color="white", + name="yStats", + ), + "hover_glyph": Circle( + x="x", + y="y", + size=10, + fill_alpha=1.0, + fill_color=hover_glyph_color, + line_color="black", + ), + }, + }, + "xy": { + "distribution": { + "glyph": Circle( + x="x", + y="y", + size=5, + fill_alpha=0.4, + line_color="white", + fill_color=glyph_color, + name="xyDistribution", + ), + "hover_glyph": Circle( + x="x", + y="y", + size=5, + fill_alpha=1.0, + line_color="black", + fill_color=hover_glyph_color, + ), + }, + "hdi": { + "x": { + "lower": { + "glyph": Line( + x="x", + y="y", + line_alpha=0.7, + line_color="black", + line_width=2.0, + name="xyLowerXHDI", + ), + "hover_glyph": Line( + x="x", + y="y", + line_alpha=1.0, + line_color="black", + line_width=2.0, + ), + }, + "upper": { + "glyph": Line( + x="x", + y="y", + line_alpha=0.7, + line_color="black", + line_width=2.0, + name="xyUpperXHDI", + ), + "hover_glyph": Line( + x="x", + y="y", + line_alpha=1.0, + line_color="black", + line_width=2.0, + ), + }, + }, + "y": { + "lower": { + "glyph": Line( + x="x", + y="y", + line_alpha=0.7, + line_color="black", + line_width=2.0, + name="xyLowerYHDI", + ), + "hover_glyph": Line( + x="x", + y="y", + line_alpha=1.0, + line_color="black", + line_width=2.0, + ), + }, + "upper": { + "glyph": Line( + x="x", + y="y", + line_alpha=0.7, + line_color="black", + line_width=2.0, + name="xyUpperYHDI", + ), + "hover_glyph": Line( + x="x", + y="y", + line_alpha=1.0, + line_color="black", + line_width=2.0, + ), + }, + }, + }, + "stats": { + "glyph": Circle( + x="x", + y="y", + size=20, + fill_alpha=0.5, + fill_color=mean_color, + line_color="white", + name="xyStats", + ), + "hover_glyph": Circle( + x="x", + y="y", + size=20, + fill_alpha=1.0, + fill_color=mean_color, + line_color="black", + ), + }, + }, + } + return output + + +def add_glyphs( + figures: typing.Figures, + glyphs: typing.Glyphs, + sources: typing.Sources, +) -> None: + """ + Bind source data to glyphs and add the glyphs to the given figures. + + Args: + figures (typing.Figures): A dictionary of Bokeh Figure objects. + glyphs (typing.Glyphs): A dictionary of Bokeh Glyphs objects. + sources (typing.Sources): A dictionary of Bokeh ColumnDataSource objects. + + Returns + None: Adds data bound glyphs to the given figures directly. + """ + # x figure + figures["x"].add_glyph( + source_or_glyph=sources["x"]["distribution"], + glyph=glyphs["x"]["distribution"]["glyph"], + hover_glyph=glyphs["x"]["distribution"]["hover_glyph"], + name="xDistribution", + ) + figures["x"].add_glyph( + source_or_glyph=sources["x"]["stats"], + glyph=glyphs["x"]["stats"]["glyph"], + hover_glyph=glyphs["x"]["stats"]["hover_glyph"], + name="xStats", + ) + + # y figure + figures["y"].add_glyph( + source_or_glyph=sources["y"]["distribution"], + glyph=glyphs["y"]["distribution"]["glyph"], + hover_glyph=glyphs["y"]["distribution"]["hover_glyph"], + name="yDistribution", + ) + figures["y"].add_glyph( + source_or_glyph=sources["y"]["stats"], + glyph=glyphs["y"]["stats"]["glyph"], + hover_glyph=glyphs["y"]["stats"]["hover_glyph"], + name="yStats", + ) + + # xy figure + figures["xy"].add_glyph( + source_or_glyph=sources["xy"]["distribution"], + glyph=glyphs["xy"]["distribution"]["glyph"], + hover_glyph=glyphs["xy"]["distribution"]["hover_glyph"], + name="xyDistribution", + ) + figures["xy"].add_glyph( + source_or_glyph=sources["xy"]["hdi"]["x"]["lower"], + glyph=glyphs["xy"]["hdi"]["x"]["lower"]["glyph"], + hover_glyph=glyphs["xy"]["hdi"]["x"]["lower"]["hover_glyph"], + name="xyHDIXLower", + ) + figures["xy"].add_glyph( + source_or_glyph=sources["xy"]["hdi"]["x"]["upper"], + glyph=glyphs["xy"]["hdi"]["x"]["upper"]["glyph"], + hover_glyph=glyphs["xy"]["hdi"]["x"]["upper"]["hover_glyph"], + name="xyHDIXUpper", + ) + figures["xy"].add_glyph( + source_or_glyph=sources["xy"]["hdi"]["y"]["lower"], + glyph=glyphs["xy"]["hdi"]["y"]["lower"]["glyph"], + hover_glyph=glyphs["xy"]["hdi"]["y"]["lower"]["hover_glyph"], + name="xyHDIYLower", + ) + figures["xy"].add_glyph( + source_or_glyph=sources["xy"]["hdi"]["y"]["upper"], + glyph=glyphs["xy"]["hdi"]["y"]["upper"]["glyph"], + hover_glyph=glyphs["xy"]["hdi"]["y"]["upper"]["hover_glyph"], + name="xyHDIYUpper", + ) + figures["xy"].add_glyph( + source_or_glyph=sources["xy"]["stats"], + glyph=glyphs["xy"]["stats"]["glyph"], + hover_glyph=glyphs["xy"]["stats"]["hover_glyph"], + name="xyStats", + ) + + +def create_annotations(sources: typing.Sources) -> typing.Annotations: + """ + Create any annotations for the figures of the tool. + + Args: + sources (typing.Sources): A dictionary of Bokeh ColumnDataSource objects. + + Returns: + typing.Annotations: A dictionary of Bokeh Annotation objects. + """ + palette = choose_palette(1) + color = palette[0] + output = { + "x": Band( + base="base", + lower="lower", + upper="upper", + source=sources["x"]["hdi"], + level="underlay", + fill_color=color, + fill_alpha=0.2, + line_color=None, + name="xHDI", + ), + "y": { + "lower": Band( + base="base", + lower="lower", + upper="upper", + source=sources["y"]["hdi"]["lower"], + level="underlay", + fill_color=color, + fill_alpha=0.2, + line_color=None, + name="yLowerHDI", + ), + "upper": Band( + base="base", + lower="lower", + upper="upper", + source=sources["y"]["hdi"]["upper"], + level="underlay", + fill_color=color, + fill_alpha=0.2, + line_color=None, + name="yUpperHDI", + ), + }, + } + return output + + +def add_annotations(figures: typing.Figures, annotations: typing.Annotations) -> None: + """ + Add the given annotations to the given figures of the tool. + + Args: + figures (typing.Figures): A dictionary of Bokeh Figure objects. + annotations (typing.Annotations): A dictionary of Bokeh Annotation objects. + + Returns: + None: Adds annotations directly to the given figures. + """ + figures["x"].add_layout(annotations["x"]) + figures["y"].add_layout(annotations["y"]["lower"]) + figures["y"].add_layout(annotations["y"]["upper"]) + + +def create_tooltips( + rv_name_x: str, + rv_name_y: str, + figures: typing.Figures, +) -> typing.Tooltips: + """ + Create hover tools for the glyphs used in the figures of the tool. + + Args: + rv_name_x (str): The name of the random variable data in the x-direction. + rv_name_y (str): The name of the random variable data in the y-direction. + figures (typing.Figures): A dictionary of Bokeh Figure objects. + + Returns: + typing.Tooltips: A dictionary of Bokeh HoverTools objects. + """ + x_dist = filter_renderers(figure=figures["x"], search="xDistribution") + x_stats = filter_renderers(figure=figures["x"], search="xStats") + y_dist = filter_renderers(figure=figures["y"], search="yDistribution") + y_stats = filter_renderers(figure=figures["y"], search="yStats") + xy_dist = filter_renderers(figure=figures["xy"], search="xyDistribution") + xy_lower_x_hdi = filter_renderers(figure=figures["xy"], search="xyLowerXHDI") + xy_upper_x_hdi = filter_renderers(figure=figures["xy"], search="xyUpperXHDI") + xy_lower_y_hdi = filter_renderers(figure=figures["xy"], search="xyLowerYHDI") + xy_upper_y_hdi = filter_renderers(figure=figures["xy"], search="xyUpperYHDI") + xy_stats = filter_renderers(figure=figures["xy"], search="xyStats") + output = { + "x": { + "distribution": HoverTool(renderers=x_dist, tooltips=[(rv_name_x, "@x")]), + "stats": HoverTool(renderers=x_stats, tooltips=[("", "@text")]), + }, + "y": { + "distribution": HoverTool(renderers=y_dist, tooltips=[(rv_name_y, "@y")]), + "stats": HoverTool(renderers=y_stats, tooltips=[("", "@text")]), + }, + "xy": { + "distribution": HoverTool( + renderers=xy_dist, + tooltips=[(rv_name_x, "@x"), (rv_name_y, "@y")], + ), + "hdi": { + "x": { + "lower": HoverTool( + renderers=xy_lower_x_hdi, + tooltips=[(rv_name_x, "@x")], + ), + "upper": HoverTool( + renderers=xy_upper_x_hdi, + tooltips=[(rv_name_x, "@x")], + ), + }, + "y": { + "lower": HoverTool( + renderers=xy_lower_y_hdi, + tooltips=[(rv_name_y, "@y")], + ), + "upper": HoverTool( + renderers=xy_upper_y_hdi, + tooltips=[(rv_name_y, "@y")], + ), + }, + }, + "stats": HoverTool( + renderers=xy_stats, + tooltips=[(rv_name_x, "@x"), (rv_name_y, "@y")], + ), + }, + } + return output + + +def add_tooltips(figures: typing.Figures, tooltips: typing.Tooltips) -> None: + """ + Add the given tools to the figures. + + Args: + figures (typing.Figures): A dictionary of Bokeh Figure objects. + tooltips (typing.Tooltips): A dictionary of Bokeh HoverTools objects. + + Returns: + None: Adds the tooltips directly to the given figures. + """ + figures["x"].add_tools(tooltips["x"]["distribution"]) + figures["x"].add_tools(tooltips["x"]["stats"]) + figures["y"].add_tools(tooltips["y"]["distribution"]) + figures["y"].add_tools(tooltips["y"]["stats"]) + figures["xy"].add_tools(tooltips["xy"]["distribution"]) + figures["xy"].add_tools(tooltips["xy"]["stats"]) + figures["xy"].add_tools(tooltips["xy"]["hdi"]["x"]["lower"]) + figures["xy"].add_tools(tooltips["xy"]["hdi"]["x"]["upper"]) + figures["xy"].add_tools(tooltips["xy"]["hdi"]["y"]["lower"]) + figures["xy"].add_tools(tooltips["xy"]["hdi"]["y"]["upper"]) + + +def create_widgets( + rv_name_x: str, + rv_name_y: str, + rv_names: List[str], + bw_factor: float, + bandwidth_x: float, + bandwidth_y: float, +) -> typing.Widgets: + """ + Create the widgets used in the tool. + + Args: + rv_name_x (str): The name of the random variable along the x-axis. + rv_name_y (str): The name of the random variable along the y-axis. + rv_names (List[str]): A list of all available random variable names. + bw_factor (float): Multiplicative factor used when calculating the kernel + density estimate. + bandwidth_x (float): The bandwidth used to calculate the KDE along the x-axis. + bandwidth_y (float): The bandwidth used to calculate the KDE along the y-axis. + + Returns: + typing.Widgets: A dictionary of Bokeh widget objects. + """ + output = { + "rv_select_x": Select(value=rv_name_x, options=rv_names, title="Query (x)"), + "rv_select_y": Select(value=rv_name_y, options=rv_names, title="Query (y)"), + "bw_factor_slider": Slider( + title="Bandwidth factor", + start=0.01, + end=2.00, + value=1.00, + step=0.01, + ), + "hdi_slider_x": Slider(start=1, end=99, step=1, value=89, title="HDI (x)"), + "hdi_slider_y": Slider(start=1, end=99, step=1, value=89, title="HDI (y)"), + "bw_div_x": Div(text=f"Bandwidth {rv_name_x}: {bw_factor * bandwidth_x}"), + "bw_div_y": Div(text=f"Bandwidth {rv_name_y}: {bw_factor * bandwidth_y}"), + } + return output + + +def help_page() -> Div: + """ + Help tab for the tool. + + Returns: + Div: Bokeh Div widget containing the help tab information. + """ + text = """ +

+ Joint plot +

+

+ A joint plot shows univariate marginals along the x and y axes. The + central figure shows the bivariate marginal of both random variables. +

+ """ + output = Div( + text=text, + disable_math=False, + min_width=MARGINAL1D_PLOT_WIDTH + MARGINAL2D_PLOT_WIDTH, + ) + return output + + +def create_figure_grid(figures: typing.Figures) -> Row: + """Layout the given figures in a grid, and make one toolbar. + + Parameters + ---------- + figures : typing.Figures + A dictionary of Bokeh Figure objects. + + Returns + ------- + Row + A Bokeh layout object. + """ + toolbar = create_toolbar(list(figures.values())) + grid_box = GridBox( + children=[ + [figures["x"], 0, 0], + [figures["xy"], 1, 0], + [figures["y"], 1, 1], + ], + ) + return Row(children=[grid_box, toolbar]) + + +def create_view(figures: typing.Figures, widgets: typing.Widgets) -> Tabs: + """ + Create the tool view. + + Args: + figures (typing.Figures): A dictionary of Bokeh Figure objects. + widgets (typing.Widgets): A dictionary of Bokeh widget objects. + + Returns: + Tabs: Bokeh Tabs objects. + """ + help_panel = Panel(child=help_page(), title="Help", name="helpPanel") + figure_grid = create_figure_grid(figures) + tool_panel = Panel( + child=Column( + children=[ + Row(children=[widgets["rv_select_x"], widgets["rv_select_y"]]), + Row( + children=[ + figure_grid, + Column( + children=[ + widgets["bw_factor_slider"], + widgets["hdi_slider_x"], + widgets["hdi_slider_y"], + widgets["bw_div_x"], + widgets["bw_div_y"], + ] + ), + ], + css_classes=["bm-tool-loading", "arcs"], + ), + ], + ), + title="Marginal 2D", + name="toolPanel", + ) + return Tabs(tabs=[tool_panel, help_panel]) diff --git a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py index 27f1d1f8be..2059e33669 100644 --- a/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py +++ b/src/beanmachine/ppl/diagnostics/tools/utils/model_serializers.py @@ -12,8 +12,10 @@ def serialize_bm(samples: MonteCarloSamples) -> Dict[str, List[List[float]]]: """ Convert Bean Machine models to a JSON serializable object. + Args: samples (MonteCarloSamples): Output of a model from Bean Machine. + Returns Dict[str, List[List[float]]]: The JSON serializable object for use in the diagnostics tools. diff --git a/src/beanmachine/ppl/diagnostics/tools/viz.py b/src/beanmachine/ppl/diagnostics/tools/viz.py index c149ba9d7e..33cf1899ca 100644 --- a/src/beanmachine/ppl/diagnostics/tools/viz.py +++ b/src/beanmachine/ppl/diagnostics/tools/viz.py @@ -59,6 +59,18 @@ def marginal1d(self: DiagnosticsTools) -> None: Marginal1d(self.mcs).show() + @_requires_dev_packages + def marginal2d(self: DiagnosticsTools) -> None: + """ + Marginal 2D diagnostic tool for a Bean Machine model. + + Returns: + None: Displays the tool directly in a Jupyter notebook. + """ + from beanmachine.ppl.diagnostics.tools.marginal2d.tool import Marginal2d + + Marginal2d(self.mcs).show() + @_requires_dev_packages def trace(self: DiagnosticsTools) -> None: """