Skip to content

Commit

Permalink
fixed the jax algorithm with what I discovered from the nn
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexYFM committed Sep 9, 2024
1 parent b8ca73c commit 668221d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion demo/dryvr_demo/vanderpol_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def plot_stars(stars: List[StarSet], dim1: int = None, dim2: int = None):
scenario.add_agent(car)
scenario.config.reachability_method = ReachabilityMethod.STAR_SETS
scenario.set_sensor(BaseStarSensor())
traces = scenario.verify(7, 0.05)
traces = scenario.verify(7, 0.1)

car1 = traces.nodes[0].trace['car1']
car1 = [star[1] for star in car1]
Expand Down
13 changes: 7 additions & 6 deletions verse/stars/star_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def forward(self, x):

num_epochs = 50 # sample number of epoch -- can play with this/set this as a hyperparameter
num_samples = 100 # number of samples per time step
lamb = 0.60
lamb = 10

T = 7
ts = 0.1
Expand Down Expand Up @@ -154,11 +154,11 @@ def sample_initial(num_samples: int = num_samples) -> List[List[float]]:
cont = lambda p, i: torch.linalg.vector_norm(torch.relu(C@torch.linalg.inv(bases[i])@(p-centers[i])-mu*g))
# cont = lambda p, i: torch.linalg.vector_norm(torch.relu([email protected](bases[i])@(p-centers[i])-torch.diag(mu)@g))
# cont = lambda p, i: torch.linalg.vector_norm(torch.relu([email protected](bases[i])@(p-center)-mu*g))
loss = (1-lamb)*mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
# loss = 25*torch.linalg.vector_norm(mu) + torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))
# loss = (1-lamb)*mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
loss = mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))

if i==len(times)-1 and (epoch+1)%10==0:
f = 1
# if i==len(times)-1 and (epoch+1)%10==0:
# f = 1
# Backward pass and optimize
# pretty sure I'll need to modify this if I'm not doing batch training
# will just putting optimizer on the earlier for loop help?
Expand All @@ -178,7 +178,8 @@ def sample_initial(num_samples: int = num_samples) -> List[List[float]]:
t = torch.tensor([times[i]], dtype=torch.float32)
mu = model(t)
cont = lambda p, i: torch.linalg.vector_norm(torch.relu(C@torch.linalg.inv(bases[i])@(p-centers[i])-mu*g))
loss = (1-lamb)*mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
loss = mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
# loss = (1-lamb)*mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
print(f'loss: {loss.item():.4f}, mu: {mu.item():.4f}, time: {t.item():.1f}')


Expand Down
6 changes: 3 additions & 3 deletions verse/stars/starset.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def gen_starset(points: np.ndarray, old_star: StarSet) -> StarSet:
def starset_loss(C: np.ndarray, g: np.ndarray, derived_basis: np.ndarray, points: np.ndarray, mu: float) -> float:
output = mu
x_0 = np.mean(points, axis=0) # this should be a parameter to optimze in the future but hold it here for now
V_m1 = np.linalg.inv(derived_basis) # derived_basis assumed to be invertible, may not necessarily be true right now
V_m1 = np.linalg.inv(derived_basis.T) # derived_basis assumed to be invertible, may not necessarily be true right now
for point in points:
contain = C@V_m1@(point-x_0)-mu*g ### kxm mxm nx1 - kx1 = kx1, should work so long as m=n which is the case if doing by PCA
output += jax.numpy.linalg.norm(jax.nn.relu(contain), ord=np.inf) ### unsure if l inf norm or any norm is the correct approach
Expand Down Expand Up @@ -743,8 +743,8 @@ def gen_starsets_post_sim(old_star: StarSet, sim: Callable, T: float = 7, ts: fl
post_points = np.array(post_points)
stars: List[StarSet] = []
for t in range(post_points.shape[1]): # pp has shape N x (T/dt) x (n + 1), so index using first
stars.append(gen_starset(post_points[:, t, 1:], old_star))
# stars.append(gen_starset_grad(post_points[:, t, 1:], old_star)) ### testing out new algorithm here, could also do so in startests if I remember
# stars.append(gen_starset(post_points[:, t, 1:], old_star))
stars.append(gen_starset_grad(post_points[:, t, 1:], old_star)) ### testing out new algorithm here, could also do so in startests if I remember
for t in range(post_points.shape[1]):
plt.scatter(post_points[:, t, 1], post_points[:, t, 2])
return stars
Expand Down

0 comments on commit 668221d

Please sign in to comment.