forked from probsys/hierarchical-irm
-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a sampling method to HIRM. #154
Merged
+194
−25
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,9 +9,9 @@ | |
#include <unordered_map> | ||
#include <vector> | ||
|
||
#include "distributions/get_distribution.hh" | ||
#include "domain.hh" | ||
#include "util_hash.hh" | ||
#include "distributions/get_distribution.hh" | ||
|
||
typedef std::vector<T_item> T_items; | ||
typedef VectorIntHash H_items; | ||
|
@@ -21,44 +21,59 @@ class Relation { | |
public: | ||
typedef T ValueType; | ||
|
||
virtual void incorporate(std::mt19937* prng, const T_items& items, ValueType value) = 0; | ||
virtual void incorporate(std::mt19937* prng, const T_items& items, | ||
ValueType value) = 0; | ||
|
||
virtual void unincorporate(const T_items& items) = 0; | ||
|
||
virtual double logp(const T_items& items, ValueType value, std::mt19937* prng) = 0; | ||
virtual double logp(const T_items& items, ValueType value, | ||
std::mt19937* prng) = 0; | ||
|
||
virtual double logp_score() const = 0; | ||
|
||
virtual double cluster_or_prior_logp(std::mt19937* prng, const T_items& items, const ValueType& value) const = 0; | ||
virtual double cluster_or_prior_logp(std::mt19937* prng, const T_items& items, | ||
const ValueType& value) const = 0; | ||
|
||
virtual ValueType sample_at_items(std::mt19937* prng, const T_items& items) const = 0; | ||
virtual ValueType sample_at_items(std::mt19937* prng, | ||
const T_items& items) const = 0; | ||
|
||
virtual void incorporate_to_cluster(const T_items& items, const ValueType& value) = 0; | ||
// Takes a sample from the cluster containing `items` and incorporates it. | ||
virtual void incorporate_sample(std::mt19937* prng, const T_items& items) = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a description of this new method. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
||
virtual void incorporate_to_cluster(const T_items& items, | ||
const ValueType& value) = 0; | ||
|
||
virtual void unincorporate_from_cluster(const T_items& items) = 0; | ||
|
||
// TODO(emilyaf): Standardize passing PRNG first or last. | ||
virtual std::vector<double> logp_gibbs_exact( | ||
const Domain& domain, const T_item& item, std::vector<int> tables, | ||
std::mt19937* prng) = 0; | ||
virtual std::vector<double> logp_gibbs_exact(const Domain& domain, | ||
const T_item& item, | ||
std::vector<int> tables, | ||
std::mt19937* prng) = 0; | ||
|
||
virtual void set_cluster_assignment_gibbs(const Domain& domain, const T_item& item, | ||
int table, std::mt19937* prng) = 0; | ||
virtual void set_cluster_assignment_gibbs(const Domain& domain, | ||
const T_item& item, int table, | ||
std::mt19937* prng) = 0; | ||
|
||
virtual void transition_cluster_hparams(std::mt19937* prng, int num_theta_steps) = 0; | ||
virtual void transition_cluster_hparams(std::mt19937* prng, | ||
int num_theta_steps) = 0; | ||
|
||
// Accessor/convenience methods, mostly for subclass members that can't be accessed through the base class. | ||
// Accessor/convenience methods, mostly for subclass members that can't be | ||
// accessed through the base class. | ||
virtual const std::vector<Domain*>& get_domains() const = 0; | ||
|
||
virtual const ValueType& get_value(const T_items& items) const = 0; | ||
|
||
virtual const std::unordered_map<const T_items, ValueType, H_items>& get_data() const = 0; | ||
virtual const std::unordered_map<const T_items, ValueType, H_items>& | ||
get_data() const = 0; | ||
|
||
virtual void update_value(const T_items& items, const ValueType& value) = 0; | ||
|
||
virtual std::vector<int> get_cluster_assignment(const T_items& items) const = 0; | ||
virtual std::vector<int> get_cluster_assignment( | ||
const T_items& items) const = 0; | ||
|
||
virtual bool has_observation(const Domain& domain, const T_item& item) const = 0; | ||
virtual bool has_observation(const Domain& domain, | ||
const T_item& item) const = 0; | ||
|
||
// Convert a string to ValueType. | ||
ValueType from_string(const std::string& s) { | ||
|
@@ -69,7 +84,6 @@ class Relation { | |
}; | ||
|
||
virtual ~Relation() = default; | ||
|
||
}; | ||
|
||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add brief descriptions of these methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.