Skip to content

Commit

Permalink
Merge pull request #23 from normal-computing/remove-diag-hess
Browse files Browse the repository at this point in the history
Remove diag Hessian
  • Loading branch information
SamDuffield authored Feb 21, 2024
2 parents 63ea636 + 4cdc8f1 commit d64a63d
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 498 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[2024-02-21 10:55:44,684] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs.\n"
]
}
],
"source": [
"import torch\n",
"from tqdm.auto import tqdm\n",
Expand All @@ -23,6 +31,20 @@
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "67ca6e2403094b0192d9baa19f6dda05",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading tokenizer_config.json: 0%| | 0.00/49.0 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
Expand All @@ -47,7 +69,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
Expand Down Expand Up @@ -121,7 +143,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "bdb87f3bddf94180a878ae6df9997ae3",
"model_id": "da1c946df9314a1991cd186811dcad48",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -136,13 +158,24 @@
"name": "stdout",
"output_type": "stream",
"text": [
"-1.4536418914794922\r"
"tensor(-1.5954)\r"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
"source": [
"# Optimize\n",
"ekf_transform = uqlib.ekf.diag_hessian.build(sub_param_to_log_lik, lr=1e-2, transition_sd=0.)\n",
"ekf_transform = uqlib.ekf.diag_fisher.build(sub_param_to_log_lik, lr=1e-2, transition_sd=0.)\n",
"ekf_state = ekf_transform.init(init_mean)\n",
"\n",
"progress_bar = tqdm(range(num_training_steps))\n",
Expand All @@ -160,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -173,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -194,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -222,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -245,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down
72 changes: 0 additions & 72 deletions tests/ekf/test_diag_hessian.py

This file was deleted.

101 changes: 0 additions & 101 deletions tests/laplace/test_diag_hessian.py

This file was deleted.

40 changes: 0 additions & 40 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from uqlib import (
model_to_function,
linearized_forward_diag,
hessian_diag,
diag_normal_log_prob,
diag_normal_sample,
extract_requires_grad,
Expand Down Expand Up @@ -119,45 +118,6 @@ def forward_func(p, batch):
assert torch.allclose(lin_cov[i], delta, atol=1e-5)


def test_hessian_diag():
# Test with a constant function
def const_fn(_):
return torch.tensor(3.0)

hessian_diag_func = hessian_diag(const_fn)
x = torch.tensor([1.0, 2.0])
result = hessian_diag_func(x)
expected = torch.zeros_like(x)
assert torch.equal(result, expected)

x = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
result = hessian_diag_func(x)
expected = tree_map(lambda v: torch.zeros_like(v), x)
for key in result:
assert torch.equal(result[key], expected[key])

# Test with a linear function
def linear_fn(x):
return x["a"].sum() + x["b"].sum()

hessian_diag_func = hessian_diag(linear_fn)
x = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
result = hessian_diag_func(x)
for key in result:
assert torch.equal(result[key], expected[key])

# Test with a quadratic function
def quad_fn(x):
return (x["a"] ** 2).sum() + (x["b"] ** 2).sum()

hessian_diag_func = hessian_diag(quad_fn)
x = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
result = hessian_diag_func(x)
expected = tree_map(lambda v: 2 * torch.ones_like(v), x)
for key in result:
assert torch.equal(result[key], expected[key])


def test_diag_normal_log_prob():
mean = {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}
sd_diag = {"a": torch.tensor([0.1, 0.2]), "b": torch.tensor([0.3, 0.4])}
Expand Down
1 change: 0 additions & 1 deletion uqlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from uqlib.utils import model_to_function
from uqlib.utils import linearized_forward_diag
from uqlib.utils import hvp
from uqlib.utils import hessian_diag
from uqlib.utils import diag_normal_log_prob
from uqlib.utils import diag_normal_sample
from uqlib.utils import tree_extract
Expand Down
1 change: 0 additions & 1 deletion uqlib/ekf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from uqlib.ekf import diag_fisher
from uqlib.ekf import diag_hessian
Loading

0 comments on commit d64a63d

Please sign in to comment.