Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Commit

Permalink
minibmg: use NUTS from bmg. (#1794)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1794

Use NUTS from bmg as a library from minibmg.  Currently uses (slow) reverse-mode AD for the gradients.

Reviewed By: horizon-blue

Differential Revision: D40356996

fbshipit-source-id: ebd1ebc054f736e57259f16b1c21d3e659950b97
  • Loading branch information
Neal Gafter authored and facebook-github-bot committed Oct 30, 2022
1 parent f7da1a1 commit 1f04780
Show file tree
Hide file tree
Showing 8 changed files with 517 additions and 40 deletions.
46 changes: 46 additions & 0 deletions minibmg/graph_properties/unobserved_samples.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "beanmachine/minibmg/graph_properties/unobserved_samples.h"
#include <exception>
#include <list>
#include <map>
#include <memory>
#include <unordered_set>

namespace {

using namespace beanmachine::minibmg;

class unobserved_samples_property
: public Property<unobserved_samples_property, Graph, std::vector<Nodep>> {
public:
std::vector<Nodep>* create(const Graph& g) const override {
auto result = new std::vector<Nodep>{};
std::unordered_set<Nodep> observed_samples;
for (auto& p : g.observations) {
observed_samples.insert(p.first);
}
for (auto& node : g) {
if (std::dynamic_pointer_cast<const ScalarSampleNode>(node) &&
!observed_samples.contains(node)) {
result->push_back(node);
}
}
return result;
}
};

} // namespace

namespace beanmachine::minibmg {

const std::vector<Nodep>& unobserved_samples(const Graph& graph) {
return *unobserved_samples_property::get(graph);
}

} // namespace beanmachine::minibmg
19 changes: 19 additions & 0 deletions minibmg/graph_properties/unobserved_samples.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <list>
#include <unordered_set>
#include "beanmachine/minibmg/graph.h"
#include "beanmachine/minibmg/node.h"

namespace beanmachine::minibmg {

const std::vector<Nodep>& unobserved_samples(const Graph& graph);

} // namespace beanmachine::minibmg
174 changes: 174 additions & 0 deletions minibmg/inference/global_state.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "beanmachine/minibmg/inference/global_state.h"
#include <math.h>
#include <memory>
#include "beanmachine/graph/global/global_state.h"
#include "beanmachine/minibmg/eval.h"
#include "beanmachine/minibmg/graph_properties/unobserved_samples.h"

namespace beanmachine::minibmg {

MinibmgGlobalState::MinibmgGlobalState(beanmachine::minibmg::Graph& graph)
: graph{graph}, world{hmc_world_0(graph)} {
samples.clear();
// Since we only support scalars, we count the unobserved samples by ones.
int num_unobserved_samples = -graph.observations.size();
for (auto& node : graph) {
if (std::dynamic_pointer_cast<const ScalarSampleNode>(node)) {
num_unobserved_samples++;
}
}
flat_size = num_unobserved_samples;
}

void MinibmgGlobalState::initialize_values(
beanmachine::graph::InitType init_type,
uint seed) {
std::mt19937 gen(31 * seed + 17);
std::vector<double>& samples = unconstrained_values;
switch (init_type) {
case graph::InitType::PRIOR: {
// Evaluate the graph, saving samples.
auto read_variable = [](const std::string&, const unsigned) -> Real {
// there are no variables, so we don't have to read them.
throw std::logic_error("models do not contain variables");
};
auto my_sampler = [&samples](
const Distribution<Real>& distribution,
std::mt19937& gen) -> SampledValue<Real> {
auto result = sample_from_distribution(distribution, gen);
// save the proposed value
samples.push_back(result.unconstrained.as_double());
return result;
};
auto eval_result = eval_graph<Real>(
graph,
gen,
/* read_variable= */ read_variable,
real_eval_data,
/* run_queries= */ false,
/* eval_log_prob= */ true,
/* sampler = */ my_sampler);
} break;
case graph::InitType::RANDOM: {
std::uniform_real_distribution<> uniform_real_distribution(-2, 2);
for (int i = 0; i < flat_size; i++) {
samples.push_back(uniform_real_distribution(gen));
}
} break;
default: {
for (int i = 0; i < flat_size; i++) {
samples.push_back(0);
}
} break;
}

// update and backup values, gradients, and log_prob
update_log_prob();
update_backgrad();
backup_unconstrained_values();
backup_unconstrained_grads();
}

void MinibmgGlobalState::backup_unconstrained_values() {
saved_unconstrained_values = unconstrained_values;
}

void MinibmgGlobalState::backup_unconstrained_grads() {
saved_unconstrained_grads = unconstrained_grads;
}

void MinibmgGlobalState::revert_unconstrained_values() {
unconstrained_values = saved_unconstrained_values;
}

void MinibmgGlobalState::revert_unconstrained_grads() {
unconstrained_grads = saved_unconstrained_grads;
}

void MinibmgGlobalState::add_to_stochastic_unconstrained_nodes(
Eigen::VectorXd& increment) {
if (increment.size() != flat_size) {
throw std::invalid_argument(
"The size of increment is inconsistent with the values in the graph");
}
for (int i = 0; i < flat_size; i++) {
unconstrained_values[i] += increment[i];
}
}

void MinibmgGlobalState::get_flattened_unconstrained_values(
Eigen::VectorXd& flattened_values) {
flattened_values.resize(flat_size);
for (int i = 0; i < flat_size; i++) {
flattened_values[i] = unconstrained_values[i];
}
}

void MinibmgGlobalState::set_flattened_unconstrained_values(
Eigen::VectorXd& flattened_values) {
if (flattened_values.size() != flat_size) {
throw std::invalid_argument(
"The size of flattened_values is inconsistent with the values in the graph");
}
for (int i = 0; i < flat_size; i++) {
unconstrained_values[i] = flattened_values[i];
}
}

void MinibmgGlobalState::get_flattened_unconstrained_grads(
Eigen::VectorXd& flattened_grad) {
flattened_grad.resize(flat_size);
for (int i = 0; i < flat_size; i++) {
flattened_grad[i] = unconstrained_grads[i];
}
}

double MinibmgGlobalState::get_log_prob() {
return log_prob;
}

void MinibmgGlobalState::update_log_prob() {
log_prob = world->log_prob(this->unconstrained_values);
}

void MinibmgGlobalState::update_backgrad() {
unconstrained_grads = world->gradients(this->unconstrained_values);
}

void MinibmgGlobalState::collect_sample() {
auto queries = world->queries(this->unconstrained_values);
std::vector<beanmachine::graph::NodeValue> compat_query;
for (auto v : queries) {
compat_query.emplace_back(v);
}
this->samples.emplace_back(compat_query);
}

std::vector<std::vector<beanmachine::graph::NodeValue>>&
MinibmgGlobalState::get_samples() {
return samples;
}

void MinibmgGlobalState::set_default_transforms() {
// minibmg always uses the default transforms
}

void MinibmgGlobalState::set_agg_type(
beanmachine::graph::AggregationType agg_type) {
if (agg_type != beanmachine::graph::AggregationType::NONE) {
throw std::logic_error("unimplemented AggregationType");
}
}

void MinibmgGlobalState::clear_samples() {
samples.clear();
}

} // namespace beanmachine::minibmg
66 changes: 66 additions & 0 deletions minibmg/inference/global_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include "beanmachine/graph/global/global_state.h"
#include "beanmachine/graph/graph.h"
#include "beanmachine/minibmg/ad/real.h"
#include "beanmachine/minibmg/ad/reverse.h"
#include "beanmachine/minibmg/graph.h"
#include "hmc_world.h"

namespace beanmachine::minibmg {

// using namespace beanmachine::graph;

// Global state, an implementation of beanmachine::graph::GlobalState which is
// needed to use the NUTS api from bmg.
class MinibmgGlobalState : public beanmachine::graph::GlobalState {
public:
explicit MinibmgGlobalState(beanmachine::minibmg::Graph& graph);
void initialize_values(beanmachine::graph::InitType init_type, uint seed)
override;
void backup_unconstrained_values() override;
void backup_unconstrained_grads() override;
void revert_unconstrained_values() override;
void revert_unconstrained_grads() override;
void add_to_stochastic_unconstrained_nodes(
Eigen::VectorXd& increment) override;
void get_flattened_unconstrained_values(
Eigen::VectorXd& flattened_values) override;
void set_flattened_unconstrained_values(
Eigen::VectorXd& flattened_values) override;
void get_flattened_unconstrained_grads(
Eigen::VectorXd& flattened_grad) override;
double get_log_prob() override;
void update_log_prob() override;
void update_backgrad() override;
void collect_sample() override;
std::vector<std::vector<beanmachine::graph::NodeValue>>& get_samples()
override;
void set_default_transforms() override;
void set_agg_type(beanmachine::graph::AggregationType) override;
void clear_samples() override;

private:
const beanmachine::minibmg::Graph& graph;
const std::unique_ptr<const HMCWorld> world;
std::vector<std::vector<beanmachine::graph::NodeValue>> samples;
int flat_size;
double log_prob;
std::vector<double> unconstrained_values;
std::vector<double> unconstrained_grads;
std::vector<double> saved_unconstrained_values;
std::vector<double> saved_unconstrained_grads;

// scratchpads for evaluation
std::unordered_map<Nodep, Reverse<Real>> reverse_eval_data;
std::unordered_map<Nodep, Real> real_eval_data;
};

} // namespace beanmachine::minibmg
Loading

0 comments on commit 1f04780

Please sign in to comment.