Skip to content

Commit

Permalink
Add empirical test for fleace
Browse files Browse the repository at this point in the history
  • Loading branch information
luciaquirke committed Apr 11, 2024
1 parent 2fdd284 commit 609731e
Showing 1 changed file with 70 additions and 0 deletions.
70 changes: 70 additions & 0 deletions experiments/fleace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def v_transform(K: torch.Tensor):
"""ReLU(Kernel) V-transform as described in Tensor Programs I: Wide Feedforward or
Recurrent Neural Networks of Any Architecture are Gaussian Processes https://arxiv.org/abs/1910.12478.
"""
diag = torch.diagonal(K)
scale = torch.sqrt(diag.unsqueeze(1) * diag.unsqueeze(0))
c = K / scale

return (
(1 / (2 * torch.pi))
* ((1 - c.pow(2)).sqrt() + (torch.pi - torch.acos(c)) * c)
* scale
)


def g(x: torch.Tensor, k: int, device="cpu"):
"""Apply `ReLU(x) + e` to x with e ~ N(0, 1) and return the first k dimensions."""
e = torch.randn(x.shape, device=device)

return (F.relu(x) + e)[..., :k]


def relu_eraser(x: torch.Tensor, n: int, k: int):
"""Use a closed-form solution for free-form LEACE for ReLU to remove linearly accessible
information about ReLU(x) from x as described in Non-Linear Least-Squares Concept Erasure.
"""
# E[Z | X]
f = F.relu(x)[:, :k]

# Closed form solution for E[X ReLU(X).T] in R^(n k)
# Note: non-centered cross covariance
cross_cov = torch.eye(n)[:, :k] * 0.5

# Closed form solution for E[ReLU(X) ReLU(X).T] in R^(k k)
V = v_transform(torch.eye(k))

A = cross_cov @ torch.linalg.pinv(V)

return x - (A @ f.T).T


def test_v_transform_monte_carlo():
num_samples = 100_000
dim = 10
X = torch.randn((num_samples, dim))

cov = (F.relu(X).T @ F.relu(X)) / (num_samples - 1)

torch.testing.assert_close(cov, v_transform(torch.eye(10)), rtol=0.01, atol=0.01)


def test_relu_linear_erasure():
batch_size = 100_000
n, k = 16, 8

x = torch.randn((batch_size, n))
z = g(x, k)
r_x = relu_eraser(x, n, k)

assert torch.norm((r_x.T @ z) / (batch_size - 1)) < 0.4


if __name__ == "__main__":
test_v_transform_monte_carlo()
test_relu_linear_erasure()

0 comments on commit 609731e

Please sign in to comment.