
Commit

fixed containment function but nn not able to learn complicated behavior -- in general it seems like the more lambda increases, the less fine mu becomes
AlexYFM committed Sep 11, 2024
1 parent 564e4fc commit 6ea75f5
Showing 2 changed files with 93 additions and 83 deletions.
142 changes: 71 additions & 71 deletions verse/stars/nn_results.csv
@@ -1,72 +1,72 @@
 time,mu,percent of points contained
-0.0,-0.00066148525,0.0
-0.1,0.077371135,0.2
-0.2,0.24127647,1.1
-0.3,0.406301,4.8
-0.4,0.5714733,10.9
-0.5,0.7369145,17.4
-0.6,0.9024419,25.6
-0.7,1.0679693,37.9
-0.8,1.2334975,50.6
-0.90000004,1.3990248,66.9
-1.0,1.5645533,75.1
-1.1,1.73008,84.2
-1.2,1.8956078,92.799995
-1.3,2.061135,97.2
-1.4,2.226663,99.5
-1.5,2.3921907,99.9
-1.6,2.5577173,100.0
-1.7,2.7232463,100.0
-1.8000001,2.888774,100.0
-1.9,3.0543008,100.0
-2.0,3.2198296,100.0
-2.1,3.385357,100.0
-2.2,3.550885,100.0
-2.3,3.7164116,100.0
-2.4,3.8819404,100.0
-2.5,4.0474677,100.0
-2.6000001,4.212995,100.0
-2.7,4.3785233,100.0
-2.8000002,4.544051,100.0
-2.9,4.709577,100.0
-3.0,4.875105,100.0
-3.1000001,5.0406337,100.0
-3.2,5.206161,100.0
-3.3,5.3716884,100.0
-3.4,5.5372157,100.0
-3.5,5.702745,100.0
-3.6000001,5.8682723,100.0
-3.7,6.0338016,100.0
-3.8,6.199328,100.0
-3.9,6.3648534,100.0
-4.0,6.5303807,100.0
-4.1,6.695911,100.0
-4.2,6.8614364,100.0
-4.3,7.026966,100.0
-4.4,7.192496,100.0
-4.5,7.3580213,100.0
-4.6,7.5235486,100.0
-4.7,7.6890736,100.0
-4.8,7.8546042,100.0
-4.9,8.02013,100.0
-5.0,8.18566,100.0
-5.1000004,8.351188,100.0
-5.2000003,8.516713,100.0
-5.3,8.682243,100.0
-5.4,8.847773,100.0
-5.5,9.013298,100.0
-5.6,9.178825,100.0
-5.7,9.344354,100.0
-5.7999997,9.50988,100.0
-5.9,9.675406,100.0
-6.0,9.840935,100.0
-6.1,10.006463,100.0
-6.2,10.171992,100.0
-6.2999997,10.337519,100.0
-6.4,10.503046,100.0
-6.5,10.668571,100.0
-6.6,10.834097,100.0
-6.7,10.999626,100.0
-6.8,11.165157,100.0
-6.9,11.330681,100.0
-7.0,11.49621,100.0
+0.0,0.09291038,0.8
+0.1,0.16880555,1.8
+0.2,0.25843868,2.8999999
+0.3,0.3481119,4.1
+0.4,0.4377851,6.8
+0.5,0.5274583,9.6
+0.6,0.61713165,11.599999
+0.7,0.70680517,17.1
+0.8,0.7964783,22.0
+0.90000004,0.88615125,27.8
+1.0,0.97582465,34.4
+1.1,1.0654979,41.5
+1.2,1.1551714,48.100002
+1.3,1.2448454,52.999996
+1.4,1.3345189,58.1
+1.5,1.424191,63.9
+1.6,1.5138638,71.3
+1.7,1.6035373,76.5
+1.8000001,1.6932112,83.0
+1.9,1.7828827,87.2
+2.0,1.8725567,91.799995
+2.1,1.9622312,95.200005
+2.2,2.0519025,96.4
+2.3,2.1415772,98.1
+2.4,2.2312512,99.8
+2.5,2.3209217,100.0
+2.6000001,2.410597,100.0
+2.7,2.5002697,100.0
+2.8000002,2.5899427,99.2
+2.9,2.6796165,98.299995
+3.0,2.7692878,98.299995
+3.1000001,2.8589652,98.299995
+3.2,2.9486353,98.799995
+3.3,3.0383105,99.2
+3.4,3.1279805,99.299995
+3.5,3.217655,99.299995
+3.6000001,3.3073316,99.9
+3.7,3.3970027,100.0
+3.8,3.4866755,100.0
+3.9,3.576347,100.0
+4.0,3.6660233,100.0
+4.1,3.7556927,100.0
+4.2,3.8453677,100.0
+4.3,3.9350421,100.0
+4.4,4.024714,100.0
+4.5,4.114386,100.0
+4.6,4.2040615,100.0
+4.7,4.293734,100.0
+4.8,4.3834085,100.0
+4.9,4.473078,100.0
+5.0,4.5627522,100.0
+5.1000004,4.652427,100.0
+5.2000003,4.7421017,100.0
+5.3,4.8317757,100.0
+5.4,4.9214516,100.0
+5.5,5.01112,100.0
+5.6,5.1007924,100.0
+5.7,5.1904635,100.0
+5.7999997,5.2801404,100.0
+5.9,5.3698173,100.0
+6.0,5.4594874,100.0
+6.1,5.549162,100.0
+6.2,5.638834,100.0
+6.2999997,5.7285037,100.0
+6.4,5.8181825,100.0
+6.5,5.9078546,100.0
+6.6,5.997528,100.0
+6.7,6.0872016,100.0
+6.8,6.1768727,100.0
+6.9,6.2665462,100.0
+7.0,6.3562202,100.0
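Both runs sweep time from 0.0 to 7.0 in steps of ts = 0.1. The removed rows were presumably generated before this commit (old containment() and lamb = 100); the added rows reflect the fixed containment() and lamb = 15 from the diff below, where mu(t) stays roughly half as large while containment still reaches 100%. A hedged sketch of how the "percent of points contained" column could be computed -- the helper below is hypothetical, not from the repo, and assumes a point counts as contained at time t when C @ bases[t]^{-1} @ (p - centers[t]) <= mu(t) * g holds componentwise:

import torch

# Hypothetical helper (not repo code): fraction of sampled points satisfying
# the star constraint C @ B_t^{-1} (p - c_t) <= mu_t * g at one time step.
def percent_contained(points_t: torch.Tensor,  # (n, dim) samples at time t
                      basis: torch.Tensor,     # (dim, dim) star basis B_t
                      center: torch.Tensor,    # (dim,) star center c_t
                      C: torch.Tensor,         # (k, dim) constraint matrix
                      g: torch.Tensor,         # (k,) constraint offsets
                      mu_t: float) -> float:
    x = torch.linalg.solve(basis, (points_t - center).T)  # (dim, n) star coordinates
    viol = C @ x - mu_t * g.unsqueeze(1)                  # (k, n) signed violations
    return 100.0 * (viol <= 0).all(dim=0).float().mean().item()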
34 changes: 22 additions & 12 deletions verse/stars/star_nn.py
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ def he_init(m):
 
 num_epochs = 50 # sample number of epoch -- can play with this/set this as a hyperparameter
 num_samples = 100 # number of samples per time step
-lamb = 100
+lamb = 15
 
 T = 7
 ts = 0.1
@@ -135,22 +135,28 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
     dim = points.shape[2]
 
     shifted_points = points - torch.stack(centers).unsqueeze(0)
-    shifted_points_flat = shifted_points.view(num_samples * len_times, dim)
-    bases_inv = torch.linalg.inv(torch.stack(bases)) #
+    shifted_points_flat = shifted_points.view(num_samples * len_times, dim) # (n_samples*len_times, dim)
+    bases_inv = torch.linalg.inv(torch.stack(bases)) # has shape (len_times, dim, dim)
 
-    transformed_points_flat = torch.bmm(shifted_points_flat.unsqueeze(1), bases_inv.repeat(num_samples, 1, 1)) # Shape: (n_samples * n_times, 1, point_dim)
+    bases_inv_repeated = bases_inv.repeat(num_samples, 1, 1) # Shape: (n_samples * n_times, point_dim, point_dim)
 
-    transformed_points = transformed_points_flat.squeeze(1) # Reshape back to (n_samples * n_times, point_dim)
-    transformed_points = transformed_points.view(num_samples, len_times, dim) # Reshape back to (n_samples, n_times, point_dim)
+    # Reshape bases_inv_repeated to (n_samples * n_times, point_dim, point_dim)
+    bases_inv_flat = bases_inv_repeated.view(num_samples * len_times, dim, dim)
+
+    # Perform batched matrix multiplication
+    transformed_points_flat = torch.bmm(bases_inv_flat, shifted_points_flat.unsqueeze(2)).squeeze(2) # Shape: (n_samples * n_times, point_dim)
+
+    transformed_points = transformed_points_flat.squeeze(1) # Reshape back to (n_samples * len_times, point_dim)
+    transformed_points = transformed_points.view(num_samples, len_times, dim) # Reshape back to (n_samples, len_times, dim)
 
     # Step 4: Apply C matrix (batch-matrix multiplication)
-    transformed_points = torch.matmul(transformed_points, C.T) # C has shape (point_dim, point_dim), apply to all points
+    transformed_points = torch.matmul(transformed_points, C.T) # C has shape (k, dim), apply to all points
 
-    # Step 5: Apply ReLU and subtract time-dependent mu
+    # Step 5: Apply ReLU and subtract time-dependent mu*g
     transformed_points = torch.relu(transformed_points - mu.view(1, len_times, 1) * g)
 
     # Step 6: Compute vector norm for each point for each time step
-    return torch.linalg.vector_norm(transformed_points)
+    return torch.linalg.vector_norm(transformed_points, dim=2)
 
 for epoch in range(num_epochs):
     # Zero the parameter gradients
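The rewrite above fixes the transform itself: the old torch.bmm multiplied each shifted point as a row vector on the left of the stacked inverse bases, computing p^T B^{-1} (equivalently B^{-T} p) instead of B^{-1} p, and the old return collapsed everything into one global norm. A minimal, self-contained sketch (illustrative shapes and random data, not repo code) checking the new batched transform against a naive per-time-step loop:

import torch

num_samples, len_times, dim = 5, 4, 3
points  = torch.randn(num_samples, len_times, dim)
bases   = [torch.eye(dim) + 0.1 * torch.randn(dim, dim) for _ in range(len_times)]  # well-conditioned random bases
centers = [torch.randn(dim) for _ in range(len_times)]

# Batched version, as in the new containment():
shifted   = points - torch.stack(centers).unsqueeze(0)       # (n, T, d)
flat      = shifted.view(num_samples * len_times, dim)       # row i*T + t  <->  (sample i, time t)
binv      = torch.linalg.inv(torch.stack(bases))             # (T, d, d)
binv_flat = binv.repeat(num_samples, 1, 1)                   # (n*T, d, d), same time-major order
batched   = torch.bmm(binv_flat, flat.unsqueeze(2)).squeeze(2).view(num_samples, len_times, dim)

# Naive reference: B_t^{-1} (p - c_t), one (sample, time) pair at a time
ref = torch.stack([torch.stack([torch.linalg.inv(bases[t]) @ (points[i, t] - centers[t])
                                for t in range(len_times)])
                   for i in range(num_samples)])
assert torch.allclose(batched, ref, atol=1e-5)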
@@ -213,7 +219,7 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
 #     # print(model.fc1.weight.grad, model.fc1.bias.grad)
 #     optimizer.step()
     mu = model(times.unsqueeze(1)) # get times in right form
-    loss = torch.sum(mu)/len(times)+lamb*containment(post_points[:, :, 1:], times, bases, centers)/(num_samples*len(times))
+    loss = (torch.sum(mu)+lamb*torch.sum(containment(post_points[:, :, 1:], times, bases, centers))/num_samples)/len(times)
     loss.backward()
     optimizer.step()
 
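The loss keeps the same two terms but changes how the containment penalty is aggregated: the old version scaled one global L2 norm over every (sample, time, output) entry, while the new one averages the per-(sample, time) norms that containment(..., dim=2) now returns, so lamb directly trades the mean mu(t) against the mean per-point violation. A small numeric illustration of the scale difference (random stand-in data; shapes match num_samples = 100 and the 71 time steps used here):

import torch

# v stands in for relu(C @ B^{-1}(p - c) - mu*g): (num_samples, len_times, k)
v = torch.relu(torch.randn(100, 71, 3))

old_penalty = torch.linalg.vector_norm(v) / (100 * 71)                     # global sqrt-of-sum-of-squares, averaged
new_penalty = torch.sum(torch.linalg.vector_norm(v, dim=2)) / (100 * 71)   # mean per-point norm
print(old_penalty.item(), new_penalty.item())  # the global norm contributes far less per point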
@@ -222,15 +228,19 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
 #     print(f'Loss: {loss.item():.4f}')
     if (epoch + 1) % 10 == 0:
         print(f'Epoch [{epoch + 1}/{num_epochs}] \n_____________\n')
-        # print("Gradients of weights and loss", model.fc1.weight.grad, model.fc1.bias.grad)
+        print("Gradients of weights and loss", model.fc1.weight.grad, model.fc1.bias.grad)
+        losses = 0
         for i in range(len(times)):
             t = torch.tensor([times[i]], dtype=torch.float32)
             mu = model(t)
             cont = lambda p, i: torch.linalg.vector_norm(torch.relu(C@torch.linalg.inv(bases[i])@(p-centers[i])-mu*g))
             loss = mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
             # loss = (1-lamb)*mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
             print(f'loss: {loss.item():.4f}, mu: {mu.item():.4f}, time: {t.item():.1f}')
-
+            losses += loss.item()
+        mu = model(times.unsqueeze(1)) # get times in right form
+        other_loss = torch.sum(mu)+lamb*torch.sum(containment(post_points[:, :, 1:], times, bases, centers))/(num_samples)
+        print(f'Losses: {losses/len(times):.4f}, ..., other loss {other_loss/len(times)}')
 
 # test the new model
 
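The new diagnostic accumulates the per-time losses and also prints other_loss from the batched containment(); both average mu(t) + lamb * mean_p ||relu(C @ B_t^{-1}(p - c_t) - mu(t)*g)|| over the time steps, so they should agree. A sketch of a sanity assertion that could sit after that loop (a fragment, reusing losses, other_loss, and times from the diff above):

# Sketch: the looped and batched diagnostics should match up to float error.
per_time_avg = losses / len(times)               # from the per-time loop above
batched_avg  = (other_loss / len(times)).item()  # from the batched containment()
assert abs(per_time_avg - batched_avg) < 1e-3, (per_time_avg, batched_avg)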
