Merge pull request #100 from cnellington/dev
Added base_predictors and save/load, updated demos
cnellington authored Jun 20, 2022
2 parents b70fa0f + db008ca commit e1b1a06
Showing 13 changed files with 871 additions and 183 deletions.
1 change: 1 addition & 0 deletions contextualized/__init__.py
@@ -0,0 +1 @@
from contextualized.utils import *
56 changes: 55 additions & 1 deletion contextualized/dags/notmad.py
@@ -1,3 +1,5 @@
import os
import dill as pickle
import numpy as np
import copy
from sklearn.model_selection import train_test_split
@@ -9,6 +11,7 @@

from contextualized.dags.notmad_helpers.tf_utils import NOTEARS_loss, DAG_loss
from contextualized.dags.notmad_helpers import graph_utils
from contextualized.dags.notmad_helpers.baselines import save_clusterednotears, load_clusterednotears


class NGAM(tf.keras.layers.Layer):
@@ -263,6 +266,32 @@ def __init__(self, context_shape, data_shape, n_archetypes,
init_compat=None,
freeze_compat=False
):
self.kwargs = {
'context_shape': context_shape,
'data_shape': data_shape,
'n_archetypes': n_archetypes,
'sample_specific_loss_params': sample_specific_loss_params,
'archetype_loss_params': archetype_loss_params,
'n_encoder_layers': n_encoder_layers,
'encoder_width': encoder_width,
'context_activity_regularizer': context_activity_regularizer,
'activation': activation,
'rank': rank,
'init_mat': init_mat,
'init_archs': init_archs,
'freeze_archs': freeze_archs,
'learning_rate': learning_rate,
'project_archs_to_dag': project_archs_to_dag,
'project_distance': project_distance,
'tf_dtype': tf_dtype,
'use_compatibility': use_compatibility,
'update_compat_by_grad': update_compat_by_grad,
'pop_model': pop_model,
'base_predictor': base_predictor,
'encoder_type': encoder_type,
'init_compat': init_compat,
'freeze_compat': freeze_compat,
}
super(NOTMAD, self).__init__()
encoder_input_shape = (context_shape[1], 1)
encoder_output_shape = (n_archetypes, )
@@ -394,7 +423,7 @@ def fit(self, C, X, epochs, batch_size,
# base_W = self.transform_to_low_rank(base_W)
else:
base_W = np.zeros((len(C), X.shape[-1], X.shape[-1])) # TODO: this is expensive.
self.model.fit({"C":C, "base_W": base_W},
self.history = self.model.fit({"C":C, "base_W": base_W},
y=X, batch_size=batch_size, epochs=epochs,
callbacks=callbacks, validation_split=val_split, verbose=0)

@@ -454,3 +483,28 @@ def predict_w(self, C, project_to_dag=False):
return np.array([graph_utils.project_to_dag(w)[0] for w in preds])
else:
return preds


def save_notmad(notmad, path):
if path[-1] != '/':
path += '/'
os.makedirs(path, exist_ok=True)
kwargs = notmad.kwargs.copy()
if kwargs['base_predictor'] is not None:
save_clusterednotears(kwargs['base_predictor'], path + 'base_predictor')
kwargs['base_predictor'] = True
pickle.dump(kwargs, open(path + 'kwargs.pkl', 'wb'))
notmad.model.save_weights(path + 'weights')


def load_notmad(path):
if path[-1] != '/':
path += '/'
kwargs = pickle.load(open(path + 'kwargs.pkl', 'rb'))
if kwargs['base_predictor'] is not None:
base_predictor = load_clusterednotears(path + 'base_predictor')
kwargs['base_predictor'] = base_predictor
notmad = NOTMAD(**kwargs)
notmad.model.load_weights(path + 'weights')
return notmad
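
For reference, a minimal save/load round trip with the helpers added above. This is a sketch, not repository code: it assumes notmad is an already-fitted NOTMAD instance and C is the context array it was fitted on.

from contextualized.dags.notmad import save_notmad, load_notmad

save_notmad(notmad, 'trained_notmad/')       # writes kwargs.pkl, TF weights, and the base predictor if one was set
restored = load_notmad('trained_notmad/')    # rebuilds NOTMAD(**kwargs) and restores the weights
W = restored.predict_w(C, project_to_dag=True)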

59 changes: 59 additions & 0 deletions contextualized/dags/notmad_helpers/baselines.py
@@ -1,4 +1,6 @@
import copy
import os
import dill as pickle
import numpy as np
from sklearn.cluster import KMeans
import tensorflow as tf
@@ -61,6 +63,13 @@ def __init__(self, loss_params, context_shape, W_shape,
learning_rate=1e-3,
tf_dtype=tf.dtypes.float32):
# super(NOTEARS, self).__init__()
self.kwargs = {
'loss_params': loss_params,
'context_shape': context_shape,
'W_shape': W_shape,
'learning_rate': learning_rate,
'tf_dtype': tf_dtype,
}
encoder_input_shape = (context_shape[1], 1)
self.context = tf.keras.layers.Input(
shape=encoder_input_shape, dtype=tf_dtype, name="C")
@@ -115,6 +124,23 @@ def get_w(self):
return self.W.W.numpy()


def save_notears(notears, path):
if path[-1] != '/':
path += '/'
os.makedirs(path, exist_ok=True)
pickle.dump(notears.kwargs, open(path + 'kwargs.pkl', 'wb'))
notears.model.save_weights(path + 'weights')


def load_notears(path):
if path[-1] != '/':
path += '/'
kwargs = pickle.load(open(path + 'kwargs.pkl', 'rb'))
notears = NOTEARS(**kwargs)
notears.model.load_weights(path + 'weights')
return notears


class ClusteredNOTEARS:
"""
Learn several NO-TEARS optimized DAGs based on a clustering function
@@ -124,6 +150,16 @@ def __init__(self, n_clusters, loss_params, context_shape, W_shape,
learning_rate=1e-3, clusterer=None, clusterer_fitted=False,
tf_dtype=tf.dtypes.float32):
# super(ClusteredNOTEARS, self).__init__()
self.kwargs = {
'n_clusters': n_clusters,
'loss_params': loss_params,
'context_shape': context_shape,
'W_shape': W_shape,
'learning_rate': learning_rate,
'clusterer': clusterer,
'clusterer_fitted': clusterer_fitted,
'tf_dtype': tf_dtype,
}
if clusterer is None:
self.clusterer = KMeans(n_clusters=n_clusters)
else:
@@ -161,3 +197,26 @@ def predict_w(self, C, project_to_dag=False):
def get_ws(self, project_to_dag=False):
# Already projected to DAG space, nothing to do here.
return np.array([model.get_w() for model in self.notears_models])


def save_clusterednotears(clusterednotears, path):
if path[-1] != '/':
path += '/'
os.makedirs(path, exist_ok=True)
kwargs = clusterednotears.kwargs.copy()
kwargs['clusterer'] = clusterednotears.clusterer
kwargs['clusterer_fitted'] = clusterednotears.clusterer_fitted
pickle.dump(kwargs, open(path + 'kwargs.pkl', 'wb'))
for i, notears in enumerate(clusterednotears.notears_models):
save_notears(notears, path + f'notears{i}')


def load_clusterednotears(path):
if path[-1] != '/':
path += '/'
kwargs = pickle.load(open(path + 'kwargs.pkl', 'rb'))
clusterednotears = ClusteredNOTEARS(**kwargs)
clusters = kwargs['n_clusters']
notears_models = [load_notears(path + f'notears{i}') for i in range(clusters)]
clusterednotears.notears_models = notears_models
return clusterednotears
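
The ClusteredNOTEARS helpers follow the same pattern: kwargs.pkl records the constructor arguments (including the fitted clusterer), and each per-cluster model is written by save_notears under notears{i}. A minimal round trip, sketched under the assumption that clustered is a fitted ClusteredNOTEARS instance and C is its context array:

from contextualized.dags.notmad_helpers.baselines import save_clusterednotears, load_clusterednotears

save_clusterednotears(clustered, 'clustered_notears/')
restored = load_clusterednotears('clustered_notears/')
W = restored.predict_w(C)    # per-sample weights taken from each sample's assigned cluster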
4 changes: 3 additions & 1 deletion contextualized/dags/notmad_helpers/graph_utils.py
@@ -72,7 +72,9 @@ def binary_search(arr, low, high, w): #low and high are indices
return -1

idx = binary_search(vals, low, high, w_dag) + 1
thresh = vals[idx]
thresh = np.max(vals) + 0.1
if idx > -1 and idx < len(vals):
thresh = vals[idx]
w_dag = trim_params(w_dag, thresh)

# Now add back in edges with weights smaller than the thresh that don't violate DAG-ness.
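
The new guard covers the case where idx falls outside vals (for example, when binary_search returns the last index, so idx equals len(vals)): the threshold then defaults to np.max(vals) + 0.1, above every candidate weight, and the add-back described in the preceding comment restores edges that do not violate DAG-ness. A small numeric sketch of the fallback (illustrative values, not repository code):

import numpy as np

vals = np.array([0.1, 0.3, 0.7])   # sorted candidate thresholds
idx = 3                            # binary_search returned the last index, so idx == len(vals)
thresh = np.max(vals) + 0.1        # fallback -> 0.8; previously vals[idx] here would raise IndexError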
6 changes: 2 additions & 4 deletions contextualized/easy/ContextualizedRegressor.py
@@ -1,13 +1,11 @@
import torch

from contextualized.regression import NaiveContextualizedRegression, ContextualizedRegression
from contextualized.regression import REGULARIZERS, LINK_FUNCTIONS, LOSSES

from contextualized.easy.wrappers import SKLearnInterface

# TODO: Multitask metamodels
# TODO: Task-specific link functions.
# TODO: Easier early stopping (right now, have to pass in 'callbacks' kwargs.
# TODO: Easier early stopping (right now, have to pass in 'callback_constructors' kwarg).


class ContextualizedRegressor(SKLearnInterface):
@@ -42,7 +40,7 @@ def __init__(self, **kwargs):
super().__init__(self.constructor)

def fit(self, C, X, Y, **kwargs):
# Merge kwards and self.constructor_kwargs, prioritizing more recent kwargs.
# Merge kwargs and self.constructor_kwargs, prioritizing more recent kwargs.
for k, v in self.constructor_kwargs.items():
if k not in kwargs:
kwargs[k] = v
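
This merge gives kwargs passed at fit time priority over those captured by the constructor; a minimal sketch of the effect (values illustrative, and it assumes learning_rate is accepted at construction as well as at fit):

model = ContextualizedRegressor(num_archetypes=4, learning_rate=1e-2)
model.fit(C, X, Y, learning_rate=1e-3, max_epochs=5)   # fit-time learning_rate (1e-3) wins over the constructor's 1e-2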
25 changes: 23 additions & 2 deletions contextualized/easy/tests/test_regressor.py
@@ -4,6 +4,25 @@
from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor


class DummyParamPredictor:
def __init__(self, beta_dim, mu_dim):
self.beta_dim = beta_dim
self.mu_dim = mu_dim

def predict_params(self, *args):
n = len(args[0])
return torch.zeros((n, *self.beta_dim)), torch.zeros((n, *self.mu_dim))


class DummyYPredictor:
def __init__(self, y_dim):
self.y_dim = y_dim

def predict_y(self, *args):
n = len(args[0])
return torch.zeros((n, *self.y_dim))


def quicktest(model, C, X, Y, **kwargs):
print(f'{type(model)} quicktest')
model.fit(C, X, Y, max_epochs=0)
@@ -43,7 +62,9 @@ def test_regressor():
C, X, Y = C.numpy(), X.numpy(), Y.numpy()

# Naive Multivariate
model = ContextualizedRegressor()
parambase = DummyParamPredictor((y_dim, x_dim), (y_dim, 1))
ybase = DummyYPredictor((y_dim, 1))
model = ContextualizedRegressor(base_param_predictor=parambase, base_y_predictor=ybase)
quicktest(model, C, X, Y, max_epochs=1)

model = ContextualizedRegressor(num_archetypes=0)
@@ -59,7 +80,7 @@

# With bootstrap
model = ContextualizedRegressor(num_archetypes=4, alpha=0.1,
l1_ratio=0.5, mu_ratio=0.9)
l1_ratio=0.5, mu_ratio=0.9, base_param_predictor=parambase, base_y_predictor=ybase)
quicktest(model, C, X, Y, max_epochs=1, n_bootstraps=2,
learning_rate=1e-3)

24 changes: 16 additions & 8 deletions contextualized/easy/wrappers.py
@@ -18,14 +18,17 @@ def _organize_kwargs(self, **kwargs):
acceptable_model_kwargs = [
'loss_fn', 'link_fn', 'univariate', 'encoder_type',
'encoder_kwargs', 'model_regularizer', 'num_archetypes',
'learning_rate'
'learning_rate', 'base_param_predictor', 'base_y_predictor'
]
acceptable_trainer_kwargs = [
'max_epochs', 'check_val_every_n_epoch', 'val_check_interval',
'callbacks'
]
acceptable_wrapper_kwargs = [
'n_bootstraps'
]
acceptable_fit_kwargs = []
data_kwargs, model_kwargs, trainer_kwargs, fit_kwargs = {}, {}, {}, {}
data_kwargs, model_kwargs, trainer_kwargs, fit_kwargs, wrapper_kwargs = {}, {}, {}, {}, {}
for k, v in kwargs.items():
if k in acceptable_data_kwargs:
data_kwargs[k] = v
@@ -35,6 +38,8 @@ def _organize_kwargs(self, **kwargs):
trainer_kwargs[k] = v
elif k in acceptable_fit_kwargs:
fit_kwargs[k] = v
elif k in acceptable_wrapper_kwargs:
wrapper_kwargs[k] = v
else:
print("Received unknown keyword argument {}, probably ignoring.".format(k))

@@ -44,7 +49,7 @@ def _organize_kwargs(self, **kwargs):
model_kwargs['context_dim'] = self.context_dim
model_kwargs['x_dim'] = self.x_dim
model_kwargs['y_dim'] = self.y_dim
self.n_bootstraps = kwargs.get("n_bootstraps", 1)
self.n_bootstraps = wrapper_kwargs.get("n_bootstraps", 1)

# Data kwargs
data_kwargs['context_dim'] = self.context_dim
@@ -53,8 +58,8 @@ def _organize_kwargs(self, **kwargs):
if 'C_val' not in data_kwargs or 'X_val' not in data_kwargs or 'Y_val' not in data_kwargs:
data_kwargs['val_split'] = data_kwargs.get('val_split', 0.2)

trainer_kwargs['callbacks'] = trainer_kwargs.get('callbacks',
[EarlyStopping(monitor='val_loss', mode='min', patience=1)]
trainer_kwargs['callback_constructors'] = trainer_kwargs.get('callback_constructors',
[lambda: EarlyStopping(monitor='val_loss', mode='min', patience=1)]
)

return data_kwargs, model_kwargs, trainer_kwargs, fit_kwargs
@@ -72,9 +77,12 @@ def fit(self, C, X, Y, **kwargs):
for i in range(self.n_bootstraps):
model = self.base_constructor(**model_kwargs)
train_dataloader, val_dataloader = self._build_dataloaders(C, X, Y, model, **data_kwargs)

# Makes a new trainer for each call to fit - bad practice, but necessary here.
trainer = RegressionTrainer(**trainer_kwargs)
# Makes a new trainer for each bootstrap fit - bad practice, but necessary here.
my_trainer_kwargs = {k: v for k, v in trainer_kwargs.items()}
# Must reconstruct the callbacks because they save state from fitting trajectories.
my_trainer_kwargs['callbacks'] = [f() for f in trainer_kwargs['callback_constructors']]
del my_trainer_kwargs['callback_constructors']
trainer = RegressionTrainer(**my_trainer_kwargs)
try:
trainer.fit(model, train_dataloader, val_dataloader, **fit_kwargs)
except:
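
Callbacks are rebuilt from zero-argument constructors because callbacks such as EarlyStopping keep fitting state (best score, patience counter) that should not leak from one bootstrap fit into the next. From the wrapper's side, bootstrapped fitting is requested through the n_bootstraps kwarg; a minimal usage sketch with illustrative data shapes:

import numpy as np
from contextualized.easy import ContextualizedRegressor

C = np.random.normal(size=(100, 5))
X = np.random.normal(size=(100, 10))
Y = np.random.normal(size=(100, 2))

model = ContextualizedRegressor(num_archetypes=4)
model.fit(C, X, Y, max_epochs=3, n_bootstraps=2)   # one freshly constructed trainer and callback list per bootstrap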
Diffs for the remaining changed files were not loaded.