-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathHELLM.py
111 lines (102 loc) · 3.12 KB
/
HELLM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import yaml
import numpy as np
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from scenario.scenario import Scenario
from LLMDriver.driverAgent import DriverAgent
from LLMDriver.outputAgent import OutputParser
from LLMDriver.customTools import (
getAvailableActions,
getAvailableLanes,
getLaneInvolvedCar,
isChangeLaneConflictWithCar,
isAccelerationConflictWithCar,
isKeepSpeedConflictWithCar,
isDecelerationSafe,
isActionSafe,
)
# Load the API settings from config.yaml. The context manager makes sure the
# file handle is closed as soon as parsing is done.
# NOTE(review): yaml.FullLoader can build more than plain data types; if
# config.yaml could ever come from an untrusted source, use yaml.safe_load.
with open('config.yaml') as config_file:
    OPENAI_CONFIG = yaml.load(config_file, Loader=yaml.FullLoader)

# Build the chat model for the configured backend. The credentials are
# exported through environment variables before the client is constructed.
if OPENAI_CONFIG['OPENAI_API_TYPE'] == 'azure':
    os.environ["OPENAI_API_TYPE"] = OPENAI_CONFIG['OPENAI_API_TYPE']
    os.environ["OPENAI_API_VERSION"] = OPENAI_CONFIG['AZURE_API_VERSION']
    os.environ["OPENAI_API_BASE"] = OPENAI_CONFIG['AZURE_API_BASE']
    os.environ["OPENAI_API_KEY"] = OPENAI_CONFIG['AZURE_API_KEY']
    llm = AzureChatOpenAI(
        deployment_name=OPENAI_CONFIG['AZURE_MODEL'],
        temperature=0,
        max_tokens=1024,
        request_timeout=60
    )
elif OPENAI_CONFIG['OPENAI_API_TYPE'] == 'openai':
    os.environ["OPENAI_API_KEY"] = OPENAI_CONFIG['OPENAI_KEY']
    llm = ChatOpenAI(
        temperature=0,
        model_name='gpt-3.5-turbo-1106',  # or any other model with 8k+ context
        max_tokens=1024,
        request_timeout=60
    )
else:
    # Fail here with a clear message instead of hitting a NameError on `llm`
    # later in the script.
    raise ValueError(
        "Unsupported OPENAI_API_TYPE in config.yaml: "
        f"{OPENAI_CONFIG['OPENAI_API_TYPE']!r} (expected 'azure' or 'openai')"
    )
# base setting
# Passed to the observation config below as "vehicles_count" and reused later
# in the script when the Scenario is created.
vehicleCount = 15

# environment setting
config = {
    "observation": {
        "type": "Kinematics",
        "features": ["presence", "x", "y", "vx", "vy"],
        "absolute": True,
        "normalize": False,
        "vehicles_count": vehicleCount,
        "see_behind": True,
    },
    "action": {
        "type": "DiscreteMetaAction",
        "target_speeds": np.linspace(0, 32, 9),
    },
    "duration": 40,
    "vehicles_density": 2,
    "show_trajectories": True,
    "render_agent": True,
}

env = gym.make('highway-v0', render_mode="rgb_array")
# Call configure() on the unwrapped environment, the same way
# set_record_video_wrapper() is reached below, rather than relying on the
# object returned by gym.make to forward the attribute.
env.unwrapped.configure(config)

# Wrap the environment so episodes are recorded under ./results-video.
env = RecordVideo(
    env, './results-video',
    name_prefix="highwayv0"
)
env.unwrapped.set_record_video_wrapper(env)

# The reset happens after configure() so the first observation comes from the
# configured environment.
obs, info = env.reset()
env.render()
# scenario and driver agent setting
# exist_ok=True creates the results directory when missing without the
# check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs('results-db', exist_ok=True)
database = "results-db/highwayv0.db"
sce = Scenario(vehicleCount, database)

# Tools handed to the driver agent. getAvailableActions is built from the
# environment, isActionSafe takes no arguments, and the rest are built from
# the shared Scenario instance.
toolModels = [
    getAvailableActions(env),
    getAvailableLanes(sce),
    getLaneInvolvedCar(sce),
    isChangeLaneConflictWithCar(sce),
    isAccelerationConflictWithCar(sce),
    isKeepSpeedConflictWithCar(sce),
    isDecelerationSafe(sce),
    isActionSafe(),
]

DA = DriverAgent(llm, toolModels, sce, verbose=True)
outputParser = OutputParser(sce, llm)
# Main simulation loop: each frame, refresh the scenario from the latest
# observation, let the driver agent run, parse its exported thoughts into an
# output dict, then step the environment with output["action_id"].
output = None  # the first agentRun() call receives None as the previous output
done = truncated = False
frame = 0
try:
    while not (done or truncated):
        sce.upateVehicles(obs, frame)
        DA.agentRun(output)
        da_output = DA.exportThoughts()
        output = outputParser.agentRun(da_output)
        env.render()
        # NOTE(review): this calls capture_frame() right away and stores its
        # return value as the callback; confirm whether the bound method
        # (without the call) was intended.
        env.unwrapped.automatic_rendering_callback = env.video_recorder.capture_frame()
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
        # Unpack in that order so `truncated` is updated and the loop stops
        # when the episode is truncated.
        obs, reward, done, truncated, info = env.step(output["action_id"])
        print(output)
        frame += 1
finally:
    # Always close the environment, even if the agent or the step raises.
    env.close()