This repository has been archived by the owner on Dec 18, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #1794 Use NUTS from bmg as a library from minibmg. Currently uses (slow) reverse-mode AD for the gradients. Reviewed By: horizon-blue Differential Revision: D40356996 fbshipit-source-id: ebd1ebc054f736e57259f16b1c21d3e659950b97
- Loading branch information
1 parent
f7da1a1
commit 1f04780
Showing
8 changed files
with
517 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "beanmachine/minibmg/graph_properties/unobserved_samples.h" | ||
#include <exception> | ||
#include <list> | ||
#include <map> | ||
#include <memory> | ||
#include <unordered_set> | ||
|
||
namespace { | ||
|
||
using namespace beanmachine::minibmg; | ||
|
||
class unobserved_samples_property | ||
: public Property<unobserved_samples_property, Graph, std::vector<Nodep>> { | ||
public: | ||
std::vector<Nodep>* create(const Graph& g) const override { | ||
auto result = new std::vector<Nodep>{}; | ||
std::unordered_set<Nodep> observed_samples; | ||
for (auto& p : g.observations) { | ||
observed_samples.insert(p.first); | ||
} | ||
for (auto& node : g) { | ||
if (std::dynamic_pointer_cast<const ScalarSampleNode>(node) && | ||
!observed_samples.contains(node)) { | ||
result->push_back(node); | ||
} | ||
} | ||
return result; | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace beanmachine::minibmg { | ||
|
||
const std::vector<Nodep>& unobserved_samples(const Graph& graph) { | ||
return *unobserved_samples_property::get(graph); | ||
} | ||
|
||
} // namespace beanmachine::minibmg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <list> | ||
#include <unordered_set> | ||
#include "beanmachine/minibmg/graph.h" | ||
#include "beanmachine/minibmg/node.h" | ||
|
||
namespace beanmachine::minibmg { | ||
|
||
const std::vector<Nodep>& unobserved_samples(const Graph& graph); | ||
|
||
} // namespace beanmachine::minibmg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "beanmachine/minibmg/inference/global_state.h" | ||
#include <math.h> | ||
#include <memory> | ||
#include "beanmachine/graph/global/global_state.h" | ||
#include "beanmachine/minibmg/eval.h" | ||
#include "beanmachine/minibmg/graph_properties/unobserved_samples.h" | ||
|
||
namespace beanmachine::minibmg { | ||
|
||
MinibmgGlobalState::MinibmgGlobalState(beanmachine::minibmg::Graph& graph) | ||
: graph{graph}, world{hmc_world_0(graph)} { | ||
samples.clear(); | ||
// Since we only support scalars, we count the unobserved samples by ones. | ||
int num_unobserved_samples = -graph.observations.size(); | ||
for (auto& node : graph) { | ||
if (std::dynamic_pointer_cast<const ScalarSampleNode>(node)) { | ||
num_unobserved_samples++; | ||
} | ||
} | ||
flat_size = num_unobserved_samples; | ||
} | ||
|
||
void MinibmgGlobalState::initialize_values( | ||
beanmachine::graph::InitType init_type, | ||
uint seed) { | ||
std::mt19937 gen(31 * seed + 17); | ||
std::vector<double>& samples = unconstrained_values; | ||
switch (init_type) { | ||
case graph::InitType::PRIOR: { | ||
// Evaluate the graph, saving samples. | ||
auto read_variable = [](const std::string&, const unsigned) -> Real { | ||
// there are no variables, so we don't have to read them. | ||
throw std::logic_error("models do not contain variables"); | ||
}; | ||
auto my_sampler = [&samples]( | ||
const Distribution<Real>& distribution, | ||
std::mt19937& gen) -> SampledValue<Real> { | ||
auto result = sample_from_distribution(distribution, gen); | ||
// save the proposed value | ||
samples.push_back(result.unconstrained.as_double()); | ||
return result; | ||
}; | ||
auto eval_result = eval_graph<Real>( | ||
graph, | ||
gen, | ||
/* read_variable= */ read_variable, | ||
real_eval_data, | ||
/* run_queries= */ false, | ||
/* eval_log_prob= */ true, | ||
/* sampler = */ my_sampler); | ||
} break; | ||
case graph::InitType::RANDOM: { | ||
std::uniform_real_distribution<> uniform_real_distribution(-2, 2); | ||
for (int i = 0; i < flat_size; i++) { | ||
samples.push_back(uniform_real_distribution(gen)); | ||
} | ||
} break; | ||
default: { | ||
for (int i = 0; i < flat_size; i++) { | ||
samples.push_back(0); | ||
} | ||
} break; | ||
} | ||
|
||
// update and backup values, gradients, and log_prob | ||
update_log_prob(); | ||
update_backgrad(); | ||
backup_unconstrained_values(); | ||
backup_unconstrained_grads(); | ||
} | ||
|
||
void MinibmgGlobalState::backup_unconstrained_values() { | ||
saved_unconstrained_values = unconstrained_values; | ||
} | ||
|
||
void MinibmgGlobalState::backup_unconstrained_grads() { | ||
saved_unconstrained_grads = unconstrained_grads; | ||
} | ||
|
||
void MinibmgGlobalState::revert_unconstrained_values() { | ||
unconstrained_values = saved_unconstrained_values; | ||
} | ||
|
||
void MinibmgGlobalState::revert_unconstrained_grads() { | ||
unconstrained_grads = saved_unconstrained_grads; | ||
} | ||
|
||
void MinibmgGlobalState::add_to_stochastic_unconstrained_nodes( | ||
Eigen::VectorXd& increment) { | ||
if (increment.size() != flat_size) { | ||
throw std::invalid_argument( | ||
"The size of increment is inconsistent with the values in the graph"); | ||
} | ||
for (int i = 0; i < flat_size; i++) { | ||
unconstrained_values[i] += increment[i]; | ||
} | ||
} | ||
|
||
void MinibmgGlobalState::get_flattened_unconstrained_values( | ||
Eigen::VectorXd& flattened_values) { | ||
flattened_values.resize(flat_size); | ||
for (int i = 0; i < flat_size; i++) { | ||
flattened_values[i] = unconstrained_values[i]; | ||
} | ||
} | ||
|
||
void MinibmgGlobalState::set_flattened_unconstrained_values( | ||
Eigen::VectorXd& flattened_values) { | ||
if (flattened_values.size() != flat_size) { | ||
throw std::invalid_argument( | ||
"The size of flattened_values is inconsistent with the values in the graph"); | ||
} | ||
for (int i = 0; i < flat_size; i++) { | ||
unconstrained_values[i] = flattened_values[i]; | ||
} | ||
} | ||
|
||
void MinibmgGlobalState::get_flattened_unconstrained_grads( | ||
Eigen::VectorXd& flattened_grad) { | ||
flattened_grad.resize(flat_size); | ||
for (int i = 0; i < flat_size; i++) { | ||
flattened_grad[i] = unconstrained_grads[i]; | ||
} | ||
} | ||
|
||
double MinibmgGlobalState::get_log_prob() { | ||
return log_prob; | ||
} | ||
|
||
void MinibmgGlobalState::update_log_prob() { | ||
log_prob = world->log_prob(this->unconstrained_values); | ||
} | ||
|
||
void MinibmgGlobalState::update_backgrad() { | ||
unconstrained_grads = world->gradients(this->unconstrained_values); | ||
} | ||
|
||
void MinibmgGlobalState::collect_sample() { | ||
auto queries = world->queries(this->unconstrained_values); | ||
std::vector<beanmachine::graph::NodeValue> compat_query; | ||
for (auto v : queries) { | ||
compat_query.emplace_back(v); | ||
} | ||
this->samples.emplace_back(compat_query); | ||
} | ||
|
||
std::vector<std::vector<beanmachine::graph::NodeValue>>& | ||
MinibmgGlobalState::get_samples() { | ||
return samples; | ||
} | ||
|
||
void MinibmgGlobalState::set_default_transforms() { | ||
// minibmg always uses the default transforms | ||
} | ||
|
||
void MinibmgGlobalState::set_agg_type( | ||
beanmachine::graph::AggregationType agg_type) { | ||
if (agg_type != beanmachine::graph::AggregationType::NONE) { | ||
throw std::logic_error("unimplemented AggregationType"); | ||
} | ||
} | ||
|
||
void MinibmgGlobalState::clear_samples() { | ||
samples.clear(); | ||
} | ||
|
||
} // namespace beanmachine::minibmg |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* | ||
* This source code is licensed under the MIT license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include "beanmachine/graph/global/global_state.h" | ||
#include "beanmachine/graph/graph.h" | ||
#include "beanmachine/minibmg/ad/real.h" | ||
#include "beanmachine/minibmg/ad/reverse.h" | ||
#include "beanmachine/minibmg/graph.h" | ||
#include "hmc_world.h" | ||
|
||
namespace beanmachine::minibmg { | ||
|
||
// using namespace beanmachine::graph; | ||
|
||
// Global state, an implementation of beanmachine::graph::GlobalState which is | ||
// needed to use the NUTS api from bmg. | ||
class MinibmgGlobalState : public beanmachine::graph::GlobalState { | ||
public: | ||
explicit MinibmgGlobalState(beanmachine::minibmg::Graph& graph); | ||
void initialize_values(beanmachine::graph::InitType init_type, uint seed) | ||
override; | ||
void backup_unconstrained_values() override; | ||
void backup_unconstrained_grads() override; | ||
void revert_unconstrained_values() override; | ||
void revert_unconstrained_grads() override; | ||
void add_to_stochastic_unconstrained_nodes( | ||
Eigen::VectorXd& increment) override; | ||
void get_flattened_unconstrained_values( | ||
Eigen::VectorXd& flattened_values) override; | ||
void set_flattened_unconstrained_values( | ||
Eigen::VectorXd& flattened_values) override; | ||
void get_flattened_unconstrained_grads( | ||
Eigen::VectorXd& flattened_grad) override; | ||
double get_log_prob() override; | ||
void update_log_prob() override; | ||
void update_backgrad() override; | ||
void collect_sample() override; | ||
std::vector<std::vector<beanmachine::graph::NodeValue>>& get_samples() | ||
override; | ||
void set_default_transforms() override; | ||
void set_agg_type(beanmachine::graph::AggregationType) override; | ||
void clear_samples() override; | ||
|
||
private: | ||
const beanmachine::minibmg::Graph& graph; | ||
const std::unique_ptr<const HMCWorld> world; | ||
std::vector<std::vector<beanmachine::graph::NodeValue>> samples; | ||
int flat_size; | ||
double log_prob; | ||
std::vector<double> unconstrained_values; | ||
std::vector<double> unconstrained_grads; | ||
std::vector<double> saved_unconstrained_values; | ||
std::vector<double> saved_unconstrained_grads; | ||
|
||
// scratchpads for evaluation | ||
std::unordered_map<Nodep, Reverse<Real>> reverse_eval_data; | ||
std::unordered_map<Nodep, Real> real_eval_data; | ||
}; | ||
|
||
} // namespace beanmachine::minibmg |
Oops, something went wrong.