Skip to content

Commit

Permalink
pushing notes from yesterday's meeting
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexYFM committed Jun 18, 2024
1 parent 2ef9665 commit 966ebf0
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 16 deletions.
Empty file added demo/fixed_points/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions demo/fixed_points/ball_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum, auto
import copy
from typing import List

class BallMode(Enum):
Normal = auto()

class State:
y: float
vy: float
agent_mode: BallMode

def __init__(self, y, vy, agent_mode: BallMode):
pass

def decisionLogic(ego: State):
output = copy.deepcopy(ego)

# TODO: Edit this part of decision logic
output = copy.deepcopy(ego)
if ego.y < 0:
output.vy = -ego.vy # arbitrary value to simulate the loss of energy from hitting the ground
output.y = 0
# if ego.vy!=0 and ((ego.vy<=0.01 and ego.vy>0) or (ego.vy>=-0.01 and ego.vy<0)):
# output.vy = 0

return output
112 changes: 112 additions & 0 deletions demo/fixed_points/bouncing_ball.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import Tuple, List

import numpy as np
from scipy.integrate import ode

from verse import BaseAgent, Scenario
from verse.analysis.utils import wrap_to_pi
from verse.analysis.analysis_tree import TraceType, AnalysisTree
from verse.parser import ControllerIR
from verse.analysis import AnalysisTreeNode, AnalysisTree, AnalysisTreeNodeType
import copy


### full disclosure, structure of file from mp4_p2
refine_profile = {
'R1': [0],
'R2': [0],
'R3': [0,0,0,3]
}

def tree_safe(tree: AnalysisTree):
for node in tree.nodes:
if node.assert_hits is not None:
return False
return True

class BallAgent(BaseAgent):
def __init__(
self,
id,
file_name
):
super().__init__(id, code = None, file_name = file_name)

@staticmethod
def dynamic(t, state):
y, vy = state
vy_dot = -9.81
return [vy, vy_dot]

def TC_simulate(
self, mode: List[str], init, time_bound, time_step, lane_map = None
) -> TraceType:
time_bound = float(time_bound)
num_points = int(np.ceil(time_bound / time_step))
trace = np.zeros((num_points + 1, 1 + len(init)))
trace[1:, 0] = [round(i * time_step, 10) for i in range(num_points)]
trace[0, 1:] = init
for i in range(num_points):
r = ode(self.dynamic)
r.set_initial_value(init)
res: np.ndarray = r.integrate(r.t + time_step)
init = res.flatten()
trace[i + 1, 0] = time_step * (i + 1)
trace[i + 1, 1:] = init
return trace

def dist(pnt1, pnt2):
return np.linalg.norm(
np.array(pnt1) - np.array(pnt2)
)

def get_extreme(rect1, rect2):
lb11 = rect1[0]
lb12 = rect1[1]
ub11 = rect1[2]
ub12 = rect1[3]

lb21 = rect2[0]
lb22 = rect2[1]
ub21 = rect2[2]
ub22 = rect2[3]

# Using rect 2 as reference
left = lb21 > ub11
right = ub21 < lb11
bottom = lb22 > ub12
top = ub22 < lb12

if top and left:
dist_min = dist((ub11, lb12),(lb21, ub22))
dist_max = dist((lb11, ub12),(ub21, lb22))
elif bottom and left:
dist_min = dist((ub11, ub12),(lb21, lb22))
dist_max = dist((lb11, lb12),(ub21, ub22))
elif top and right:
dist_min = dist((lb11, lb12), (ub21, ub22))
dist_max = dist((ub11, ub12), (lb21, lb22))
elif bottom and right:
dist_min = dist((lb11, ub12),(ub21, lb22))
dist_max = dist((ub11, lb12),(lb21, ub22))
elif left:
dist_min = lb21 - ub11
dist_max = np.sqrt((lb21 - ub11)**2 + max((ub22-lb12)**2, (ub12-lb22)**2))
elif right:
dist_min = lb11 - ub21
dist_max = np.sqrt((lb21 - ub11)**2 + max((ub22-lb12)**2, (ub12-lb22)**2))
elif top:
dist_min = lb12 - ub22
dist_max = np.sqrt((ub12 - lb22)**2 + max((ub21-lb11)**2, (ub11-lb21)**2))
elif bottom:
dist_min = lb22 - ub12
dist_max = np.sqrt((ub22 - lb12)**2 + max((ub21-lb11)**2, (ub11-lb21)**2))
else:
dist_min = 0
dist_max = max(
dist((lb11, lb12), (ub21, ub22)),
dist((lb11, ub12), (ub21, lb22)),
dist((ub11, lb12), (lb21, ub12)),
dist((ub11, ub12), (lb21, lb22))
)
return dist_min, dist_max
43 changes: 43 additions & 0 deletions demo/fixed_points/bouncing_ball_scenario.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from verse import Scenario, ScenarioConfig
from vehicle_controller import VehicleMode, TLMode

from verse.plotter.plotter2D import *
from verse.plotter.plotter3D_new import *
import plotly.graph_objects as go
import copy

###
from bouncing_ball import BallAgent
from ball_controller import BallMode

from z3 import *
from fixed_points import fixed_points_aa_branching, fixed_points_aa_branching_composed, contained_single, reach_at, fixed_points_sat, reach_at_fix, fixed_points_fix
from fixed_points import contain_all_fix, contain_all, pp_fix, pp_old

if __name__ == "__main__":

import os
script_dir = os.path.realpath(os.path.dirname(__file__))
input_code_name = os.path.join(script_dir, "ball_controller.py")
ball = BallAgent('ball', file_name=input_code_name)

scenario = Scenario(ScenarioConfig(init_seg_length=1, parallel=False))

scenario.add_agent(ball) ### need to add breakpoint around here to check decision_logic of agents

init_ball = [[10,2],[10,2]]
# # -----------------------------------------

scenario.set_init_single(
'ball', init_ball,(BallMode.Normal,)
)

trace = scenario.verify(7, 0.01)

pp_fix(reach_at_fix(trace, 0, 7))
print(f'Fixed points exists? {fixed_points_fix(trace)}')

fig = go.Figure()
fig = reachtube_tree(trace, None, fig, 0, 1, [0, 1], "fill", "trace")
# fig = simulation_tree(trace, None, fig, 1, 2, [1, 2], "fill", "trace")
fig.show()
3 changes: 2 additions & 1 deletion demo/fixed_points/fixed_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def reach_at(trace: AnalysisTree, t_lower: float = None, t_upper: float = None)
return reached

### revised version of top function
### will now return a list of the vertices of composed hyperrectangles (2 vertices per rect) index by node number
### will now return a list of the vertices of composed/product hyperrectangles (2 vertices per rect) index by node number
def reach_at_fix(tree: AnalysisTree, t_lower: float = None, t_upper: float = None) -> Dict[int, List[List[float]]]:
nodes: List[AnalysisTreeNode] = tree.nodes
agents = nodes[0].agent.keys() # list of agents
Expand Down Expand Up @@ -118,6 +118,7 @@ def reach_at_fix(tree: AnalysisTree, t_lower: float = None, t_upper: float = Non
node_counter += 1
return reached

#unit test this
def contain_all_fix(reach1: Dict[int, List[List[float]]], reach2: Dict[int, List[List[float]]]) -> Bool:
nodes = list(reach1.keys()) # this is abritrary, could be from either reach set, just need this
state_len = len(reach1[nodes[0]][0]) # taking the first vertex, could be any
Expand Down
36 changes: 21 additions & 15 deletions demo/fixed_points/traffic_signal_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from ball_scenario_branch import BallScenarioBranch
from ball_scenario_branch_nt import BallScenarioBranchNT
from z3 import *
from fixed_points import fixed_points, fixed_points_aa_branching, fixed_points_aa_branching_composed, contained_single, reach_at, fixed_points_sat
from fixed_points import fixed_points_aa_branching, fixed_points_aa_branching_composed, contained_single, reach_at, fixed_points_sat, reach_at_fix, fixed_points_fix
from fixed_points import contain_all_fix, contain_all, pp_fix, pp_old

if __name__ == "__main__":

Expand Down Expand Up @@ -57,8 +58,11 @@
# ----------- Simulate single: Uncomment this block to perform single simulation -------------
# trace = scenario.simulate(80, 0.1)
# trace = scenario.verify(80, 0.1)
# print(fixed_points_sat(trace, 80))
# print(len(trace.get_leaf_nodes(trace.root)), len(trace._get_all_nodes(trace.root)))
# pp_fix(reach_at_fix(trace, 0, 79.91))
# pp_fix(reach_at_fix(trace))
# pp_old(reach_at(trace, 0, 79.91))
# pp_old(reach_at(trace))
# print('Do fixed points exist in the scenario:', fixed_points_fix(trace, 80))
# avg_vel, unsafe_frac, unsafe_init = eval_velocity([trace])
# fig = go.Figure()
# fig = simulation_tree_3d(trace, fig,\
Expand All @@ -83,22 +87,24 @@
###

###
ball_scenario = BallScenario().scenario
# ball_scenario = BallScenario().scenario
ball_scenario_branch = BallScenarioBranch().scenario

# ball_scenario_branch_nt = BallScenarioBranchNT().scenario ### this scenario's verify doesn't really make any sense given its simulate -- figure out why
# ## trying to verify with two agents in NT takes forever for some reason
# trace = ball_scenario_branch_nt.verify(80, 0.1)
# trace = ball_scenario_branch_nt.simulate(80, 0.1)
# # ball_scenario_branch_nt = BallScenarioBranchNT().scenario ### this scenario's verify doesn't really make any sense given its simulate -- figure out why
# # ## trying to verify with two agents in NT takes forever for some reason
# # trace = ball_scenario_branch_nt.verify(80, 0.1)
# # trace = ball_scenario_branch_nt.simulate(80, 0.1)

trace = ball_scenario_branch.verify(40, 0.1)
# print(reach_at(trace)) ### print out more elegantly, for example, line breaks and cut off array floats, test out more thoroughly
# print(reach_at(trace, 39, 39.91)) ### needs to be slightly more than T-delta T due to small differences in real trace, could also fix in reach_at by rounding off the times dimension
print(fixed_points_sat(trace, 40, 0.01))

# fig = go.Figure()
# fig = simulation_tree(trace, None, fig, 1, 2, [1, 2], "fill", "trace")
# fig.show()
pp_fix(reach_at_fix(trace)) ### print out more elegantly, for example, line breaks and cut off array floats, test out more thoroughly
# pp_old(reach_at(trace, 39, 39.91)) ### needs to be slightly more than T-delta T due to small differences in real trace, could also fix in reach_at by rounding off the times dimension
pp_fix(reach_at_fix(trace, 0, 39.91))
# # print(fixed_points_sat(trace, 40, 0.01))
print(fixed_points_fix(trace, 40, 0.01))

fig = go.Figure()
fig = reachtube_tree(trace, None, fig, 1, 2, [1, 2], "fill", "trace")
fig.show()
# print(f'Do there exist fixed points? {fixed_points(ball_scenario, "red-ball", t=80)}')
# print(f'Do there exist fixed points? {fixed_points_aa_branching(ball_scenario, t=80)}')
# print(f'Do there exist fixed points? {fixed_points_aa_branching(ball_scenario_branch, t=80)}')
Expand Down

0 comments on commit 966ebf0

Please sign in to comment.