-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
349 lines (313 loc) · 13.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import os
import sys
import time
import json
import wget
import yaml
import boto3
import base64
import urllib
import logging
import asyncio
import globals
import argparse
import paramiko
from utils import *
from constants import *
from pathlib import Path
from scp import SCPClient
from typing import Optional, List
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from botocore.exceptions import NoCredentialsError, ClientError
from globals import (
create_iam_instance_profile_arn,
get_region,
get_iam_role,
get_sg_id,
get_key_pair,
)
# Thread pool used to run blocking SSH/SCP helpers off the asyncio event loop.
executor = ThreadPoolExecutor()

# Module-level logger; handlers are attached by the basicConfig call below.
# (Previously `logger` was used throughout this file without a visible
# definition — presumably star-imported from utils; defined explicitly here.)
logger = logging.getLogger(__name__)

# Initialize global variables for this file
instance_id_list: List = []
fmbench_config_map: List = []
fmbench_post_startup_script_map: List = []
# Bug fix: was annotated `Dict`, which is never imported from `typing`
# (only Optional and List are) and would raise NameError at import time;
# the builtin `dict` is always available and means the same thing here.
instance_data_map: dict = {}

logging.basicConfig(
    level=logging.INFO,  # Set the log level to INFO
    # Define log message format
    format="[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("fmbench-orchestrator.log"),  # Log to a file
        logging.StreamHandler(),  # Also log to console
    ],
)
async def execute_fmbench(instance, formatted_script, remote_script_path):
    """
    Asynchronous wrapper for deploying an instance using synchronous functions.

    Waits for the instance's startup-complete flag, uploads any
    bring-your-own dataset / custom config+tokenizer, fills the script
    template with the remote config path, runs it over SSH, waits for the
    FMBench completion flag, retrieves the results folder, and optionally
    deletes the EC2 instance.

    :param instance: dict describing one instance (hostname, username,
        key_file_path, instance_id, region, fmbench_* settings, ...).
        Assumed shape matches entries built in instance_data_map — TODO confirm.
    :param formatted_script: bash script template; its ``{config_file}``
        placeholder is substituted before upload.
    :param remote_script_path: destination path for the script on the instance.
    """
    # Check for the startup completion flag. wait_for_flag is blocking, so it
    # is pushed onto the thread pool to keep the event loop responsive.
    startup_complete = await asyncio.get_event_loop().run_in_executor(
        executor,
        wait_for_flag,
        instance,
        STARTUP_COMPLETE_FLAG_FPATH,
        CLOUD_INITLOG_PATH,
    )

    if startup_complete:
        if instance['byo_dataset_fpath']:
            # Bring-your-own dataset: push it to the instance before running.
            await upload_byo_dataset(
                instance["hostname"],
                instance["username"],
                instance["key_file_path"],
                instance["byo_dataset_fpath"],
            )
        # Handle configuration file (download/upload) and get the remote path
        remote_config_path = await handle_config_file_async(instance)

        # Format the script with the remote config file path
        # Change this later to be a better implementation, right now it is bad.
        formatted_script = formatted_script.format(config_file=remote_config_path)
        print("Startup Script complete, executing fmbench now")

        if instance["fmbench_llm_config_fpath"]:
            logger.info("Going to use custom tokenizer and config")
            await upload_config_and_tokenizer(
                instance["hostname"],
                instance["username"],
                instance["key_file_path"],
                instance["fmbench_llm_config_fpath"],
                instance["fmbench_llm_tokenizer_fpath"],
                instance["fmbench_tokenizer_remote_dir"],
            )

        # Upload and execute the script on the instance (blocking SSH work,
        # again delegated to the thread pool).
        script_output = await asyncio.get_event_loop().run_in_executor(
            executor,
            upload_and_execute_script_invoke_shell,
            instance["hostname"],
            instance["username"],
            instance["key_file_path"],
            formatted_script,
            remote_script_path,
        )
        print(f"Script Output from {instance['hostname']}:\n{script_output}")

        # Check for the fmbench completion flag, bounded by the per-instance
        # timeout from the config.
        fmbench_complete = await asyncio.get_event_loop().run_in_executor(
            executor,
            wait_for_flag,
            instance,
            FMBENCH_TEST_COMPLETE_FLAG_FPATH,
            FMBENCH_LOG_PATH,
            instance["fmbench_complete_timeout"],
            SCRIPT_CHECK_INTERVAL_IN_SECONDS,
        )

        if fmbench_complete:
            logger.info("Fmbench Run successful, Getting the folders now")
            results_folder = os.path.join(
                RESULTS_DIR, globals.config_data["general"]["name"]
            )
            # Pull the results folder back to the orchestrator host.
            await asyncio.get_event_loop().run_in_executor(
                executor, check_and_retrieve_results_folder, instance, results_folder
            )
            if globals.config_data["run_steps"]["delete_ec2_instance"]:
                delete_ec2_instance(instance["instance_id"], instance["region"])
                # NOTE(review): mutating the module-level instance_id_list from
                # concurrently-gathered tasks — safe on asyncio's single thread,
                # but confirm nothing touches it from the executor threads.
                instance_id_list.remove(instance["instance_id"])
async def multi_deploy_fmbench(instance_details, remote_script_path):
    """
    Kick off an FMBench run on every instance and wait for all of them.

    For each instance, the post-startup bash script named in its details is
    read from disk and handed to execute_fmbench; all runs are awaited
    concurrently via asyncio.gather.
    """

    def _load_post_startup_script(instance):
        # Read the per-instance bash script from local disk, logging as we go.
        logger.info(f"Instance Details are: {instance}")
        logger.info(
            f"Attempting to open bash script at {instance['post_startup_script']}"
        )
        with open(instance["post_startup_script"]) as file:
            bash_script = file.read()
        logger.info("Read Bash Script")
        logger.info(f"Post startup script is: {bash_script}")
        return bash_script

    # Build one coroutine per instance, then run them all concurrently.
    coros = [
        execute_fmbench(
            instance, _load_post_startup_script(instance), remote_script_path
        )
        for instance in instance_details
    ]
    await asyncio.gather(*coros)
async def main():
    # Entry coroutine: fan FMBench runs out across all prepared instances.
    # NOTE(review): both arguments are read from module scope —
    # `instance_details` is assigned in the __main__ block below, and
    # `remote_script_path` is not defined anywhere in this file (presumably
    # it comes from the `constants` star import; confirm).
    await multi_deploy_fmbench(instance_details, remote_script_path)
if __name__ == "__main__":
    # CLI: the orchestrator is driven entirely by one YAML config file.
    parser = argparse.ArgumentParser(
        description="Run FMBench orchestrator with a specified config file."
    )
    parser.add_argument(
        "--config-file",
        type=str,
        help="Path to your Config File",
        required=False,
        default="configs/config.yml",
    )
    args = parser.parse_args()
    logger.info(f"main, {args} = args")

    # Load config into module-level state shared with the rest of this file.
    globals.config_data = load_yaml_file(args.config_file)
    logger.info(f"Loaded Config {globals.config_data}")

    # Read the Hugging Face token; it is spliced into the startup script
    # below so instances can pull gated models. Exit if the file is missing.
    hf_token_fpath = globals.config_data["aws"].get("hf_token_fpath")
    hf_token: Optional[str] = None
    logger.info(f"Got Hugging Face Token file path from config. {hf_token_fpath}")
    logger.info("Attempting to open it")
    if Path(hf_token_fpath).is_file():
        hf_token = Path(hf_token_fpath).read_text().strip()
    else:
        logger.error(f"{hf_token_fpath} does not exist, cannot continue")
        sys.exit(1)
    # NOTE(review): this logs the secret token in clear text — consider redacting.
    logger.info(f"read hugging face token {hf_token} from file path")
    assert len(hf_token) > 4, "Hf_token is too small or invalid, please check"

    for i in globals.config_data["instances"]:
        logger.info(f"Instance list is as follows: {i}")
    logger.info(f"Deploying Ec2 Instances")

    if globals.config_data["run_steps"]["deploy_ec2_instance"]:
        # Resolve the IAM instance profile: create one if requested, falling
        # back to the role already attached to the orchestrator's own machine.
        if globals.config_data["run_steps"]["create_iam_role"]:
            try:
                iam_arn = create_iam_instance_profile_arn()
            except Exception as e:
                logger.error(f"Cannot create IAM Role due to exception {e}")
                logger.info("Going to get iam role from the current instance")
                iam_arn = get_iam_role()
        else:
            try:
                iam_arn = get_iam_role()
            except Exception as e:
                logger.error(f"Cannot get IAM Role due to exception {e}")
        # NOTE(review): if get_iam_role() raised in the `else` branch above,
        # `iam_arn` is unbound here and this check raises NameError instead
        # of the intended NoCredentialsError — confirm intended behavior.
        if not iam_arn:
            raise NoCredentialsError(
                """Unable to locate credentials,
                Please check if an IAM role is
                attched to your instance."""
            )
        logger.info(f"iam arn: {iam_arn}")

        # WIP Parallelize This.
        num_instances: int = len(globals.config_data["instances"])
        for idx, instance in enumerate(globals.config_data["instances"]):
            idx += 1  # 1-based counter, for human-readable log messages only
            logger.info(f"going to create instance {idx} of {num_instances}")
            # Per-instance opt-out: `deploy: false` skips this entry entirely.
            deploy: bool = instance.get("deploy", True)
            if deploy is False:
                logger.warning(
                    f"deploy={deploy} for instance={json.dumps(instance, indent=2)}, skipping it..."
                )
                continue
            region = instance["region"]
            startup_script = instance["startup_script"]
            logger.info(f"Region Set for instance is: {region}")

            if globals.config_data["run_steps"]["security_group_creation"]:
                logger.info(
                    f"Creating Security Groups. getting them by name if they exist"
                )
                # NOTE(review): sg_id is only assigned when this step is
                # enabled, but it is passed to create_ec2_instance below
                # regardless — with the step disabled this raises NameError;
                # confirm intended.
                sg_id = get_sg_id(region)

            # Key pair for SSH access to the instance (filename + key name).
            PRIVATE_KEY_FNAME, PRIVATE_KEY_NAME = get_key_pair(region)
            # command_to_run = instance["command_to_run"]
            with open(f"{startup_script}", "r") as file:
                user_data_script = file.read()
            # Replace the hf token in the bash script to pull the HF model
            user_data_script = user_data_script.replace("__HF_TOKEN__", hf_token)

            # Branch 1: no instance_id in config -> launch a brand-new EC2
            # instance from the configured AMI / type / EBS settings.
            if instance.get("instance_id") is None:
                instance_type = instance["instance_type"]
                ami_id = instance["ami_id"]
                device_name = instance["device_name"]
                ebs_del_on_termination = instance["ebs_del_on_termination"]
                ebs_Iops = instance["ebs_Iops"]
                ebs_VolumeSize = instance["ebs_VolumeSize"]
                ebs_VolumeType = instance["ebs_VolumeType"]
                # Retrieve CapacityReservationId and CapacityReservationResourceGroupArn if they exist
                CapacityReservationId = instance.get("CapacityReservationId", None)
                CapacityReservationPreference = instance.get(
                    "CapacityReservationPreference", "none"
                )
                CapacityReservationResourceGroupArn = instance.get(
                    "CapacityReservationResourceGroupArn", None
                )
                # Initialize CapacityReservationTarget only if either CapacityReservationId or CapacityReservationResourceGroupArn is provided
                CapacityReservationTarget = {}
                if CapacityReservationId:
                    CapacityReservationTarget["CapacityReservationId"] = (
                        CapacityReservationId
                    )
                if CapacityReservationResourceGroupArn:
                    CapacityReservationTarget["CapacityReservationResourceGroupArn"] = (
                        CapacityReservationResourceGroupArn
                    )
                # If CapacityReservationTarget is empty, set it to None
                if not CapacityReservationTarget:
                    CapacityReservationTarget = None
                # user_data_script += command_to_run
                # Create an EC2 instance with the user data script
                instance_id = create_ec2_instance(
                    idx,
                    PRIVATE_KEY_NAME,
                    sg_id,
                    user_data_script,
                    ami_id,
                    instance_type,
                    iam_arn,
                    region,
                    device_name,
                    ebs_del_on_termination,
                    ebs_Iops,
                    ebs_VolumeSize,
                    ebs_VolumeType,
                    CapacityReservationPreference,
                    CapacityReservationTarget,
                )
                instance_id_list.append(instance_id)
                # Record everything execute_fmbench needs for this instance.
                instance_data_map[instance_id] = {
                    "fmbench_config": instance["fmbench_config"],
                    "post_startup_script": instance["post_startup_script"],
                    "fmbench_llm_tokenizer_fpath": instance.get(
                        "fmbench_llm_tokenizer_fpath"
                    ),
                    "fmbench_llm_config_fpath": instance.get(
                        "fmbench_llm_config_fpath"
                    ),
                    "fmbench_tokenizer_remote_dir": instance.get(
                        "fmbench_tokenizer_remote_dir"
                    ),
                    "fmbench_complete_timeout": instance["fmbench_complete_timeout"],
                    "region": instance["region"],
                    "PRIVATE_KEY_FNAME": PRIVATE_KEY_FNAME,
                    "byo_dataset_fpath": instance.get("byo_dataset_fpath")
                }

            # Branch 2: an existing instance_id was supplied -> reuse that
            # instance, requiring its private key file from the config.
            if instance.get("instance_id") is not None:
                instance_id = instance["instance_id"]
                # TODO: Check if host machine can open the private key provided, if it cant, raise exception
                PRIVATE_KEY_FNAME = instance["private_key_fname"]
                if not PRIVATE_KEY_FNAME:
                    logger.error(
                        "Private key not found, not adding instance to instance id list"
                    )
                if PRIVATE_KEY_FNAME:
                    instance_id_list.append(instance_id)
                    instance_data_map[instance_id] = {
                        "fmbench_config": instance["fmbench_config"],
                        "post_startup_script": instance["post_startup_script"],
                        "fmbench_llm_tokenizer_fpath": instance.get(
                            "fmbench_llm_tokenizer_fpath"
                        ),
                        "fmbench_llm_config_fpath": instance.get(
                            "fmbench_llm_config_fpath"
                        ),
                        "fmbench_tokenizer_remote_dir": instance.get(
                            "fmbench_tokenizer_remote_dir"
                        ),
                        "fmbench_complete_timeout": instance[
                            "fmbench_complete_timeout"
                        ],
                        "region": instance["region"],
                        "PRIVATE_KEY_FNAME": PRIVATE_KEY_FNAME,
                        "byo_dataset_fpath": instance.get("byo_dataset_fpath")
                    }
            logger.info(f"done creating instance {idx} of {num_instances}")

    # Give freshly launched instances time to boot before the SSH-based steps.
    sleep_time = 60
    logger.info(
        f"Going to Sleep for {sleep_time} seconds to make sure the instances are up"
    )
    time.sleep(sleep_time)

    if globals.config_data["run_steps"]["run_bash_script"]:
        instance_details = generate_instance_details(
            instance_id_list, instance_data_map
        )  # Call the async function
        asyncio.run(main())
    logger.info("all done")