diff --git a/utils/sherlock_script.py b/utils/sherlock_script.py index 003433c..f7b2dda 100644 --- a/utils/sherlock_script.py +++ b/utils/sherlock_script.py @@ -12,9 +12,14 @@ import torch import sys -if len(sys.argv) != 2: +if len(sys.argv) == 1 or len(sys.argv) > 3: raise RuntimeError() +elif len(sys.argv) == 2: + task_id = 0xEE364A +else len(sys.argv) == 3: + task_id = int(sys.argv[2]) + data_dir = "/oak/stanford/groups/candes/for_parth" cache_dir = "/scratch/groups/candes/parth" df = pd.read_csv(os.path.join(data_dir, "phenotypes.QC.britishonly.csv"), index_col=0) @@ -38,7 +43,7 @@ ) print(X.shape) -rng = np.random.default_rng(0xEE364A) +rng = np.random.default_rng(task_id) P = np.random.permutation(y.shape[-1]) n_train = P.size * 9 // 10 train_mask = P[:n_train] @@ -61,12 +66,12 @@ L = state.betas.shape[0] oos = np.empty(L) ins = np.empty(L) -y_hat_test = X_test @ state.betas.T -y_hat_train = X_train @ state.betas.T +y_hat_test = ad.diagnostic.predict(X_test, state.betas, state.intercepts) +y_hat_train = ad.diagnostic.predict(X_train, state.betas, state.intercepts) for i in range(L): - oos[i] = loss(torch.from_numpy(y_hat_test[:, i]), torch.from_numpy(y_test)) - ins[i] = loss(torch.from_numpy(y_hat_train[:, i]), torch.from_numpy(y_train)) + oos[i] = loss(torch.from_numpy(y_hat_test[i]), torch.from_numpy(y_test)) + ins[i] = loss(torch.from_numpy(y_hat_train[i]), torch.from_numpy(y_train)) ld, alo, ts, r2 = ai.get_alo_for_sweep(y_train, state, loss) -np.savez(sys.argv[-1], lamda=ld, alo=alo, oos=oos, in_sample=ins, ts=ts, r2=r2) +np.savez(sys.argv[1], lamda=ld, alo=alo, oos=oos, in_sample=ins, ts=ts, r2=r2)