From b3eec537f60c70b3212929efb2e77c78e4a8c413 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 25 Oct 2023 19:10:14 -0700 Subject: [PATCH 01/22] fix: fix networking issue inference --- presets/llama-2-chat/inference-api.py | 7 ++++--- presets/llama-2/inference-api.py | 8 ++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 12063350f..8e7bb2377 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -4,6 +4,7 @@ import uvicorn from pydantic import BaseModel from typing import Optional +from multiprocessing import Process import threading from llama import Llama @@ -147,8 +148,8 @@ def health_check(): return {"status": "Healthy"} def start_worker_server(): - uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) print(f"Worker {dist.get_rank()} HTTP health server started at port 5000") + uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) def worker_listen_tasks(): while True: @@ -203,8 +204,8 @@ def worker_listen_tasks(): # Start the worker server in a separate thread. This worker server will # provide a healthz endpoint for monitoring the health of the node. - server_thread = threading.Thread(target=start_worker_server, daemon=True) - server_thread.start() + server_process = Process(target=start_worker_server) + server_process.start() # Regardless of local rank, all non-globally-0-ranked processes will listen # for tasks (like chat completion) from the main server. diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 70aa12e25..16843e3dc 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -4,7 +4,7 @@ import uvicorn from pydantic import BaseModel from typing import Optional -import threading +from multiprocessing import Process import time from multiprocessing import Value @@ -136,8 +136,8 @@ def health_check(): return {"status": "Healthy"} def start_worker_server(): - uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) print(f"Worker {dist.get_rank()} HTTP health server started at port 5000") + uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) def worker_listen_tasks(): while True: @@ -191,8 +191,8 @@ def worker_listen_tasks(): # Start the worker server in a separate thread. This worker server will # provide a healthz endpoint for monitoring the health of the node. - server_thread = threading.Thread(target=start_worker_server, daemon=True) - server_thread.start() + server_process = Process(target=start_worker_server) + server_process.start() # Regardless of local rank, all non-globally-0-ranked processes will listen # for tasks (like text completion) from the main server. 
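The patch above serves the worker's health endpoint from a separate process instead of a thread, and logs before calling uvicorn.run(), which blocks until the server exits. A minimal standalone sketch of that arrangement, detached from the llama-specific code (port 5000 and the function names mirror the patch; the placeholder worker loop is an illustrative stand-in, not the real broadcast loop):

```python
# Sketch, not the patched file itself: a /healthz endpoint served from a
# separate process so the main process can keep listening for distributed tasks.
from multiprocessing import Process
import time

import uvicorn
from fastapi import FastAPI

app_worker = FastAPI()

@app_worker.get("/healthz")
def health_check():
    return {"status": "Healthy"}

def start_worker_server():
    # Log before uvicorn.run(): the call blocks, so nothing placed after it
    # runs until the server shuts down.
    print("Worker HTTP health server started at port 5000")
    uvicorn.run(app=app_worker, host="0.0.0.0", port=5000)

def worker_listen_tasks():
    # Stand-in for the torch.distributed broadcast loop in inference-api.py.
    while True:
        time.sleep(1)

if __name__ == "__main__":
    # Running the health server in its own process keeps /healthz responsive
    # even when the task loop blocks inside a collective call.
    server_process = Process(target=start_worker_server)
    server_process.start()
    worker_listen_tasks()
```
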
From c1d3a18dccc23e5daa83519f2fff2e6b33a36bbb Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 25 Oct 2023 19:18:37 -0700 Subject: [PATCH 02/22] nit: remove threading --- presets/llama-2-chat/inference-api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 8e7bb2377..8c1fd876d 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -5,7 +5,6 @@ from pydantic import BaseModel from typing import Optional from multiprocessing import Process -import threading from llama import Llama import torch From 32586dd0a4f3bcfb8dc0ba2c11e8fe3556b04f3c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Fri, 27 Oct 2023 01:47:27 -0700 Subject: [PATCH 03/22] fix: ensure child process --- presets/llama-2-chat/inference-api.py | 51 +++++++++++++----------- presets/llama-2/inference-api.py | 57 +++++++++++++++------------ 2 files changed, 59 insertions(+), 49 deletions(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 8c1fd876d..f80bf3949 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -4,7 +4,7 @@ import uvicorn from pydantic import BaseModel from typing import Optional -from multiprocessing import Process +import multiprocessing from llama import Llama import torch @@ -155,26 +155,31 @@ def worker_listen_tasks(): worker_num = dist.get_rank() print(f"Worker {worker_num} ready to recieve next command") config = [None] * 3 # Command and its associated data - dist.broadcast_object_list(config, src=0) - command = config[0] - - if command == "generate": - try: - input_string = config[1] - parameters = config[2] - generator.chat_completion( - input_string, - max_gen_len=parameters.get('max_gen_len', None), - temperature=parameters.get('temperature', 0.6), - top_p=parameters.get('top_p', 0.9) - ) - print(f"Worker {worker_num} completed generation") - except Exception as e: - print(f"Error in generation: {str(e)}") - elif command == "shutdown": - print(f"Worker {worker_num} shutting down") - sys.exit(0) - + try: + print(f"Worker {worker_num} entered broadcast listen") + dist.broadcast_object_list(config, src=0) + print(f"Worker {worker_num} left broadcast listen") + command = config[0] + + if command == "generate": + try: + input_string = config[1] + parameters = config[2] + print(f"Worker {worker_num} started generation") + generator.chat_completion( + input_string, + max_gen_len=parameters.get('max_gen_len', None), + temperature=parameters.get('temperature', 0.6), + top_p=parameters.get('top_p', 0.9) + ) + print(f"Worker {worker_num} completed generation") + except Exception as e: + print(f"Error in generation: {str(e)}") + elif command == "shutdown": + print(f"Worker {worker_num} shutting down") + sys.exit(0) + except Exception as e: + print(f"Error in Worker Listen Task", e) if __name__ == "__main__": # Fetch the LOCAL_RANK environment variable to determine the rank of this process @@ -201,9 +206,9 @@ def worker_listen_tasks(): app_worker = FastAPI() setup_worker_routes() - # Start the worker server in a separate thread. This worker server will + # Start the worker server in a separate process. This worker server will # provide a healthz endpoint for monitoring the health of the node. 
- server_process = Process(target=start_worker_server) + server_process = multiprocessing.Process(target=start_worker_server, daemon=True) server_process.start() # Regardless of local rank, all non-globally-0-ranked processes will listen diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 16843e3dc..3532b65f2 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -4,9 +4,8 @@ import uvicorn from pydantic import BaseModel from typing import Optional -from multiprocessing import Process +import multiprocessing import time -from multiprocessing import Value from llama import Llama import torch @@ -139,30 +138,36 @@ def start_worker_server(): print(f"Worker {dist.get_rank()} HTTP health server started at port 5000") uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) -def worker_listen_tasks(): +def worker_listen_tasks(): while True: worker_num = dist.get_rank() - print(f"Worker {worker_num} ready to receive next command") - config = [None] * 3 - dist.broadcast_object_list(config, src=0) - command = config[0] - - if command == "text_generate": - try: - prompts = config[1] - parameters = config[2] - generator.text_completion( - prompts, - max_gen_len=parameters.get('max_gen_len', None), - temperature=parameters.get('temperature', 0.6), - top_p=parameters.get('top_p', 0.9) - ) - print(f"Worker {worker_num} completed generation") - except Exception as e: - print(f"Error in generation: {str(e)}") - elif command == "shutdown": - print(f"Worker {worker_num} shutting down") - sys.exit(0) + print(f"Worker {worker_num} ready to recieve next command") + config = [None] * 3 # Command and its associated data + try: + print(f"Worker {worker_num} entered broadcast listen") + dist.broadcast_object_list(config, src=0) + print(f"Worker {worker_num} left broadcast listen") + command = config[0] + + if command == "text_generate": + try: + input_string = config[1] + parameters = config[2] + print(f"Worker {worker_num} started generation") + generator.text_completion( + input_string, + max_gen_len=parameters.get('max_gen_len', None), + temperature=parameters.get('temperature', 0.6), + top_p=parameters.get('top_p', 0.9) + ) + print(f"Worker {worker_num} completed generation") + except Exception as e: + print(f"Error in generation: {str(e)}") + elif command == "shutdown": + print(f"Worker {worker_num} shutting down") + sys.exit(0) + except Exception as e: + print(f"Error in Worker Listen Task", e) if __name__ == "__main__": # Fetch the LOCAL_RANK environment variable to determine the rank of this process @@ -189,9 +194,9 @@ def worker_listen_tasks(): app_worker = FastAPI() setup_worker_routes() - # Start the worker server in a separate thread. This worker server will + # Start the worker server in a separate process. This worker server will # provide a healthz endpoint for monitoring the health of the node. 
- server_process = Process(target=start_worker_server) + server_process = multiprocessing.Process(target=start_worker_server, daemon=True) server_process.start() # Regardless of local rank, all non-globally-0-ranked processes will listen From 933efc61fbae11dc908aab7d8118a4e2e417d4b5 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 29 Oct 2023 20:03:23 -0700 Subject: [PATCH 04/22] fix: upgrade nvidia pytorch --- docker/presets/llama-2/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/presets/llama-2/Dockerfile b/docker/presets/llama-2/Dockerfile index 3f091053b..8fa70579c 100644 --- a/docker/presets/llama-2/Dockerfile +++ b/docker/presets/llama-2/Dockerfile @@ -10,7 +10,7 @@ # --build-arg SRC_DIR=/home/presets/llama-2-chat \ # -t llama-2-7b-chat:latest . -FROM nvcr.io/nvidia/pytorch:23.06-py3 +FROM nvcr.io/nvidia/pytorch:23.10-py3 WORKDIR /workspace RUN git clone https://github.com/facebookresearch/llama From 86101a8347c651e92d4ce4aa51c69470fb2037f5 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 29 Oct 2023 20:19:23 -0700 Subject: [PATCH 05/22] fix: lint --- .github/workflows/preset-image-build.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/preset-image-build.yml b/.github/workflows/preset-image-build.yml index 347728a75..009592378 100644 --- a/.github/workflows/preset-image-build.yml +++ b/.github/workflows/preset-image-build.yml @@ -1,5 +1,4 @@ name: Build and Push Preset Models - on: pull_request: branches: From b691d9bcea394e38872a956c022b1bd338ead24e Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 29 Oct 2023 20:27:40 -0700 Subject: [PATCH 06/22] fix: naming --- .github/workflows/preset-image-build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/preset-image-build.yml b/.github/workflows/preset-image-build.yml index 009592378..2569bd180 100644 --- a/.github/workflows/preset-image-build.yml +++ b/.github/workflows/preset-image-build.yml @@ -20,7 +20,7 @@ on: description: 'Release (yes/no)' required: true default: 'no' - image_tag: + image_tag_name: description: 'Image Tag' required: true @@ -93,9 +93,9 @@ jobs: - name: Set Image Tag id: set_tag run: | - if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag }}" ]]; then + if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag_name }}" ]]; then echo "Using workflow dispatch to set image tag" - echo "image_tag=${{ github.event.inputs.image_tag }}" >> $GITHUB_OUTPUT + echo "image_tag=${{ github.event.inputs.image_tag_name }}" >> $GITHUB_OUTPUT elif [[ "${{ github.event_name }}" == "push" && "${{ github.ref }}" == "refs/heads/main" ]]; then echo "Setting image tag to be latest" echo "image_tag=latest" >> $GITHUB_OUTPUT From cbbebbdd28f985e3b3898f719f4b74feb9db4e9d Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 29 Oct 2023 20:43:38 -0700 Subject: [PATCH 07/22] fix: diff --- .github/workflows/preset-image-build.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/preset-image-build.yml b/.github/workflows/preset-image-build.yml index 2569bd180..81f8e5ccb 100644 --- a/.github/workflows/preset-image-build.yml +++ b/.github/workflows/preset-image-build.yml @@ -48,7 +48,12 @@ jobs: - name: Get Modified files run: | - files=$(git diff --name-only HEAD^ HEAD) + current_branch=$(git rev-parse --abbrev-ref HEAD) + if [ "$current_branch" == "main" ] || [ "$current_branch" == "master" ]; 
then + files=$(git diff --name-only HEAD^ HEAD) + else + files=$(git diff --name-only main...HEAD) + fi echo "Modified files: $files" FILES_MODIFIED="" while IFS= read -r file; do From 52487370488db742ace29530630251f3d2363b86 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 29 Oct 2023 20:46:18 -0700 Subject: [PATCH 08/22] fix: fetch --- .github/workflows/preset-image-build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/preset-image-build.yml b/.github/workflows/preset-image-build.yml index 81f8e5ccb..d9f3c38b8 100644 --- a/.github/workflows/preset-image-build.yml +++ b/.github/workflows/preset-image-build.yml @@ -48,6 +48,7 @@ jobs: - name: Get Modified files run: | + git fetch origin main:main current_branch=$(git rev-parse --abbrev-ref HEAD) if [ "$current_branch" == "main" ] || [ "$current_branch" == "master" ]; then files=$(git diff --name-only HEAD^ HEAD) From 1690ccc2a179bea6631f4cca20c8d57dd4627c1c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Sun, 29 Oct 2023 21:05:55 -0700 Subject: [PATCH 09/22] fix: log --- presets/llama-2-chat/inference-api.py | 5 +++++ presets/llama-2/inference-api.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index f80bf3949..325de1f26 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -21,6 +21,8 @@ parser.add_argument("--max_seq_len", type=int, default=128, help="Maximum sequence length.") parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size.") parser.add_argument("--model_parallel_size", type=int, default=int(os.environ.get("WORLD_SIZE", 1)), help="Model parallel size.") +parser.add_argument("--local-rank", type=int, default=int(os.environ.get("WORLD_SIZE", 1)), help="Model parallel size.") + args = parser.parse_args() should_shutdown = False @@ -178,6 +180,9 @@ def worker_listen_tasks(): elif command == "shutdown": print(f"Worker {worker_num} shutting down") sys.exit(0) + except torch.distributed.DistBackendError as e: + print("torch.distributed.DistBackendError:") + print(e) except Exception as e: print(f"Error in Worker Listen Task", e) diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 3532b65f2..0af022148 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -166,6 +166,9 @@ def worker_listen_tasks(): elif command == "shutdown": print(f"Worker {worker_num} shutting down") sys.exit(0) + except torch.distributed.DistBackendError as e: + print("torch.distributed.DistBackendError:") + print(e) except Exception as e: print(f"Error in Worker Listen Task", e) From 2ee601c4e14992651d6157d15640dc6f81583f42 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Mon, 30 Oct 2023 20:41:16 -0700 Subject: [PATCH 10/22] feat: add the headless service, add the resliency to ensure cleanup of processes upon termination --- presets/llama-2-chat/inference-api.py | 44 +++++++++++++--------- presets/llama-2/inference-api.py | 40 ++++++++++++-------- presets/test/manifests/llama-headless.yaml | 14 +++++++ 3 files changed, 64 insertions(+), 34 deletions(-) create mode 100644 presets/test/manifests/llama-headless.yaml diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 325de1f26..97184f4e8 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -21,8 +21,6 @@ parser.add_argument("--max_seq_len", type=int, default=128, 
help="Maximum sequence length.") parser.add_argument("--max_batch_size", type=int, default=4, help="Maximum batch size.") parser.add_argument("--model_parallel_size", type=int, default=int(os.environ.get("WORLD_SIZE", 1)), help="Model parallel size.") -parser.add_argument("--local-rank", type=int, default=int(os.environ.get("WORLD_SIZE", 1)), help="Model parallel size.") - args = parser.parse_args() should_shutdown = False @@ -181,10 +179,11 @@ def worker_listen_tasks(): print(f"Worker {worker_num} shutting down") sys.exit(0) except torch.distributed.DistBackendError as e: - print("torch.distributed.DistBackendError:") - print(e) + print("torch.distributed.DistBackendError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) except Exception as e: print(f"Error in Worker Listen Task", e) + os.killpg(os.getpgrp(), signal.SIGTERM) if __name__ == "__main__": # Fetch the LOCAL_RANK environment variable to determine the rank of this process @@ -205,17 +204,26 @@ def worker_listen_tasks(): # Uncomment to enable worker logs # sys.stdout = sys.__stdout__ - # If the current process is the locally ranked 0 (i.e., the primary process) - # on its node, then it starts a worker server that exposes a health check endpoint. - if local_rank == 0: - app_worker = FastAPI() - setup_worker_routes() - - # Start the worker server in a separate process. This worker server will - # provide a healthz endpoint for monitoring the health of the node. - server_process = multiprocessing.Process(target=start_worker_server, daemon=True) - server_process.start() - - # Regardless of local rank, all non-globally-0-ranked processes will listen - # for tasks (like chat completion) from the main server. - worker_listen_tasks() + os.setpgrp() + server_process = None + try: + # If the current process is the locally ranked 0 (i.e., the primary process) + # on its node, then it starts a worker server that exposes a health check endpoint. + if local_rank == 0: + app_worker = FastAPI() + setup_worker_routes() + + # Start the worker server in a separate process. This worker server will + # provide a healthz endpoint for monitoring the health of the node. + server_process = multiprocessing.Process(target=start_worker_server, daemon=True) + server_process.start() + + # Regardless of local rank, all non-globally-0-ranked processes will listen + # for tasks (like chat completion) from the main server. + worker_listen_tasks() + finally: + if server_process: + server_process.terminate() + server_process.join() + # Additional fail-safe (to ensure no lingering processes) + os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 0af022148..58bb4f79a 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -167,8 +167,7 @@ def worker_listen_tasks(): print(f"Worker {worker_num} shutting down") sys.exit(0) except torch.distributed.DistBackendError as e: - print("torch.distributed.DistBackendError:") - print(e) + print("torch.distributed.DistBackendError", e) except Exception as e: print(f"Error in Worker Listen Task", e) @@ -191,17 +190,26 @@ def worker_listen_tasks(): # Uncomment to enable worker logs # sys.stdout = sys.__stdout__ - # If the current process is the locally ranked 0 (i.e., the primary process) - # on its node, then it starts a worker server that exposes a health check endpoint. - if local_rank == 0: - app_worker = FastAPI() - setup_worker_routes() - - # Start the worker server in a separate process. 
This worker server will - # provide a healthz endpoint for monitoring the health of the node. - server_process = multiprocessing.Process(target=start_worker_server, daemon=True) - server_process.start() - - # Regardless of local rank, all non-globally-0-ranked processes will listen - # for tasks (like text completion) from the main server. - worker_listen_tasks() + os.setpgrp() + server_process = None + try: + # If the current process is the locally ranked 0 (i.e., the primary process) + # on its node, then it starts a worker server that exposes a health check endpoint. + if local_rank == 0: + app_worker = FastAPI() + setup_worker_routes() + + # Start the worker server in a separate process. This worker server will + # provide a healthz endpoint for monitoring the health of the node. + server_process = multiprocessing.Process(target=start_worker_server, daemon=True) + server_process.start() + + # Regardless of local rank, all non-globally-0-ranked processes will listen + # for tasks (like chat completion) from the main server. + worker_listen_tasks() + finally: + if server_process: + server_process.terminate() + server_process.join() + # Additional fail-safe (to ensure no lingering processes) + os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file diff --git a/presets/test/manifests/llama-headless.yaml b/presets/test/manifests/llama-headless.yaml new file mode 100644 index 000000000..e0514564f --- /dev/null +++ b/presets/test/manifests/llama-headless.yaml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: Service +metadata: + name: llama-headless +spec: + selector: + app: llama + clusterIP: None + ports: + - name: torchrun + protocol: TCP + port: 29500 + targetPort: 29500 + publishNotReadyAddresses: true \ No newline at end of file From 224416a2fa590cce5dc2cff768b82d9cd77ff28f Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 31 Oct 2023 19:25:35 -0700 Subject: [PATCH 11/22] fix: timeout error handling --- presets/llama-2-chat/inference-api.py | 78 +++++++++++++++++++-------- presets/llama-2/inference-api.py | 78 +++++++++++++++++++-------- 2 files changed, 113 insertions(+), 43 deletions(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 97184f4e8..b734686b9 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -5,6 +5,9 @@ from pydantic import BaseModel from typing import Optional import multiprocessing +import multiprocessing.pool +import threading +import functools from llama import Llama import torch @@ -25,6 +28,20 @@ should_shutdown = False +def timeout(max_timeout): + """Timeout decorator, parameter in seconds.""" + def timeout_decorator(item): + """Wrap the original function.""" + @functools.wraps(item) + def func_wrapper(*args, **kwargs): + """Closure for function.""" + with multiprocessing.pool.ThreadPool(processes=1) as pool: + async_result = pool.apply_async(item, args, kwargs) + # raises a TimeoutError if execution exceeds max_timeout + return async_result.get(max_timeout) + return func_wrapper + return timeout_decorator + def build_generator(params): """Build Llama generator from provided parameters.""" return Llama.build(**params) @@ -33,6 +50,7 @@ def broadcast_for_shutdown(): """Broadcasts shutdown command to worker processes.""" dist.broadcast_object_list(["shutdown", None, None], src=0) +@timeout(60.0) def broadcast_for_generation(input_string, max_gen_len, temperature, top_p): """Broadcasts generation parameters to worker processes.""" 
dist.broadcast_object_list(["generate", input_string, { @@ -41,6 +59,15 @@ def broadcast_for_generation(input_string, max_gen_len, temperature, top_p): 'top_p': top_p }], src=0) +@timeout(60.0) +def inference(input_string, max_gen_len, temperature, top_p): + return generator.chat_completion( + input_string, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + def shutdown_server(): """Shut down the server.""" os.kill(os.getpid(), signal.SIGINT) @@ -99,18 +126,24 @@ def chat_completion(params: ChatParameters): top_p = parameters.get('top_p', 0.9) if dist.get_world_size() > 1: - # Broadcast generation params to worker processes - broadcast_for_generation(input_string, max_gen_len, temperature, top_p) + try: + # Broadcast generation params to worker processes + broadcast_for_generation(input_string, max_gen_len, temperature, top_p) + except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Broadcast failed - TimeoutError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) + raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) # Master's own generation try: - results = generator.chat_completion( - input_string, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) + results = inference(input_string, max_gen_len, temperature, top_p) except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Inference failed - TimeoutError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) if len(results) == 0: @@ -166,14 +199,17 @@ def worker_listen_tasks(): input_string = config[1] parameters = config[2] print(f"Worker {worker_num} started generation") - generator.chat_completion( - input_string, - max_gen_len=parameters.get('max_gen_len', None), - temperature=parameters.get('temperature', 0.6), - top_p=parameters.get('top_p', 0.9) - ) + inference(input_string, + parameters.get('max_gen_len', None), + parameters.get('temperature', 0.6), + parameters.get('top_p', 0.9) + ) print(f"Worker {worker_num} completed generation") except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Inference failed - TimeoutError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) print(f"Error in generation: {str(e)}") elif command == "shutdown": print(f"Worker {worker_num} shutting down") @@ -205,7 +241,7 @@ def worker_listen_tasks(): # sys.stdout = sys.__stdout__ os.setpgrp() - server_process = None + server_thread = None try: # If the current process is the locally ranked 0 (i.e., the primary process) # on its node, then it starts a worker server that exposes a health check endpoint. @@ -213,17 +249,17 @@ def worker_listen_tasks(): app_worker = FastAPI() setup_worker_routes() - # Start the worker server in a separate process. This worker server will + # Start the worker server in a separate thread. This worker server will # provide a healthz endpoint for monitoring the health of the node. - server_process = multiprocessing.Process(target=start_worker_server, daemon=True) - server_process.start() + server_thread = threading.Thread(target=start_worker_server, daemon=True) + server_thread.start() # Regardless of local rank, all non-globally-0-ranked processes will listen # for tasks (like chat completion) from the main server. 
worker_listen_tasks() finally: - if server_process: - server_process.terminate() - server_process.join() + # if server_thread: + # server_thread.terminate() + # server_thread.join() # Additional fail-safe (to ensure no lingering processes) os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 58bb4f79a..49ab98e6b 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -5,7 +5,9 @@ from pydantic import BaseModel from typing import Optional import multiprocessing -import time +import multiprocessing.pool +import threading +import functools from llama import Llama import torch @@ -26,6 +28,20 @@ should_shutdown = False +def timeout(max_timeout): + """Timeout decorator, parameter in seconds.""" + def timeout_decorator(item): + """Wrap the original function.""" + @functools.wraps(item) + def func_wrapper(*args, **kwargs): + """Closure for function.""" + with multiprocessing.pool.ThreadPool(processes=1) as pool: + async_result = pool.apply_async(item, args, kwargs) + # raises a TimeoutError if execution exceeds max_timeout + return async_result.get(max_timeout) + return func_wrapper + return timeout_decorator + def build_generator(params): """Build Llama generator from provided parameters.""" return Llama.build(**params) @@ -34,6 +50,7 @@ def broadcast_for_shutdown(): """Broadcasts shutdown command to worker processes.""" dist.broadcast_object_list(["shutdown", None, None], src=0) +@timeout(60.0) def broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p): """Broadcasts generation parameters to worker processes.""" dist.broadcast_object_list(["text_generate", prompts, { @@ -42,6 +59,15 @@ def broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p): 'top_p': top_p }], src=0) +@timeout(60.0) +def inference(input_string, max_gen_len, temperature, top_p): + return generator.text_completion( + input_string, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + def shutdown_server(): """Shut down the server.""" os.kill(os.getpid(), signal.SIGINT) @@ -97,16 +123,22 @@ def generate_text(params: GenerationParameters): top_p = parameters.get('top_p', 0.9) if dist.get_world_size() > 1: - broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p) - - try: - results = generator.text_completion( - prompts, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) + try: + broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p) + except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Broadcast failed - TimeoutError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) + raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) + + try: + results = inference(prompts, max_gen_len, temperature, top_p) except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Inference failed - TimeoutError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) if len(results) == 0: @@ -154,22 +186,27 @@ def worker_listen_tasks(): input_string = config[1] parameters = config[2] print(f"Worker {worker_num} started generation") - generator.text_completion( - input_string, - max_gen_len=parameters.get('max_gen_len', None), - temperature=parameters.get('temperature', 0.6), - top_p=parameters.get('top_p', 0.9) - ) + inference(input_string, 
+ parameters.get('max_gen_len', None), + parameters.get('temperature', 0.6), + parameters.get('top_p', 0.9) + ) print(f"Worker {worker_num} completed generation") except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Inference failed - TimeoutError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) print(f"Error in generation: {str(e)}") elif command == "shutdown": print(f"Worker {worker_num} shutting down") sys.exit(0) except torch.distributed.DistBackendError as e: print("torch.distributed.DistBackendError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) except Exception as e: print(f"Error in Worker Listen Task", e) + os.killpg(os.getpgrp(), signal.SIGTERM) if __name__ == "__main__": # Fetch the LOCAL_RANK environment variable to determine the rank of this process @@ -199,17 +236,14 @@ def worker_listen_tasks(): app_worker = FastAPI() setup_worker_routes() - # Start the worker server in a separate process. This worker server will + # Start the worker server in a separate thread. This worker server will # provide a healthz endpoint for monitoring the health of the node. - server_process = multiprocessing.Process(target=start_worker_server, daemon=True) - server_process.start() + server_thread = threading.Thread(target=start_worker_server, daemon=True) + server_thread.start() # Regardless of local rank, all non-globally-0-ranked processes will listen # for tasks (like chat completion) from the main server. worker_listen_tasks() finally: - if server_process: - server_process.terminate() - server_process.join() # Additional fail-safe (to ensure no lingering processes) os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file From a91bdedc5dab4a9c9f1eee1237ad3e1883a476c9 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Tue, 31 Oct 2023 19:28:57 -0700 Subject: [PATCH 12/22] fix: remove comments --- presets/llama-2-chat/inference-api.py | 4 ---- presets/llama-2/inference-api.py | 1 - 2 files changed, 5 deletions(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index b734686b9..65fe121e1 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -241,7 +241,6 @@ def worker_listen_tasks(): # sys.stdout = sys.__stdout__ os.setpgrp() - server_thread = None try: # If the current process is the locally ranked 0 (i.e., the primary process) # on its node, then it starts a worker server that exposes a health check endpoint. @@ -258,8 +257,5 @@ def worker_listen_tasks(): # for tasks (like chat completion) from the main server. worker_listen_tasks() finally: - # if server_thread: - # server_thread.terminate() - # server_thread.join() # Additional fail-safe (to ensure no lingering processes) os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 49ab98e6b..a027ac598 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -228,7 +228,6 @@ def worker_listen_tasks(): # sys.stdout = sys.__stdout__ os.setpgrp() - server_process = None try: # If the current process is the locally ranked 0 (i.e., the primary process) # on its node, then it starts a worker server that exposes a health check endpoint. 
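The preceding patches wrap the distributed calls in a thread-pool-based timeout and, when a rank hangs or dies, tear down the whole process group rather than leaving orphaned workers behind. A minimal sketch of that pattern in isolation (the 10-second limit and the slow_broadcast() stand-in are illustrative assumptions; the real code applies the decorator to dist.broadcast_object_list() and the llama generator calls):

```python
# Sketch of the timeout-plus-cleanup pattern from the patches above.
import functools
import multiprocessing.pool
import os
import signal
import time

def timeout(max_timeout):
    """Raise multiprocessing.TimeoutError if the wrapped call exceeds max_timeout seconds."""
    def timeout_decorator(item):
        @functools.wraps(item)
        def func_wrapper(*args, **kwargs):
            # Run the call in a one-thread pool; get() enforces the deadline.
            # A timed-out call keeps running in its worker thread, which is
            # why the group-wide SIGTERM below still matters.
            with multiprocessing.pool.ThreadPool(processes=1) as pool:
                async_result = pool.apply_async(item, args, kwargs)
                return async_result.get(max_timeout)
        return func_wrapper
    return timeout_decorator

@timeout(10.0)
def slow_broadcast():
    # Stand-in for a collective call that hangs because a peer rank died.
    time.sleep(60)

if __name__ == "__main__":
    os.setpgrp()  # POSIX-only: make this process the leader of its own group
    try:
        slow_broadcast()
    except multiprocessing.TimeoutError as e:
        # The class name is "TimeoutError", which is what the patched API
        # code checks before escalating.
        print("Broadcast timed out; tearing down the process group", e)
        os.killpg(os.getpgrp(), signal.SIGTERM)
```
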
From fda874af071c09b47ffef5ec265b3a1491b4439a Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 1 Nov 2023 10:12:24 -0700 Subject: [PATCH 13/22] feat: added torchrdzvparams, headless service --- pkg/controllers/workspace_controller.go | 6 +++++ pkg/inference/preset-inference-types.go | 19 +++++++++++++++ pkg/inference/preset-inferences.go | 8 +++++++ pkg/resources/manifests.go | 32 ++++++++++++++++++++++++- 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index b5b42f9ce..58b650c05 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -407,6 +407,12 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al if err != nil { return err } + headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj) + err = resources.CreateResource(ctx, headlessService, c.Client) + if err != nil { + return err + } + wObj.Annotations["headlessServiceCreated"] = "true" return nil } diff --git a/pkg/inference/preset-inference-types.go b/pkg/inference/preset-inference-types.go index 6ee8e9e8a..50c4a6ccf 100644 --- a/pkg/inference/preset-inference-types.go +++ b/pkg/inference/preset-inference-types.go @@ -19,6 +19,14 @@ const ( DefaultMasterPort = "29500" ) +// Torch Rendezvous Params +const ( + DefaultMaxRestarts = "3" + DefaultRdzvId = "rdzv_id" + DefaultRdzvBackend = "c10d" + DefaultRdzvEndpoint = "localhost:29500" // llama-2-13b-chat-0.llama-headless.default.svc.cluster.local:29500 +) + const ( DefaultConfigFile = "config.yaml" DefaultNumProcesses = "1" @@ -60,6 +68,13 @@ var ( "master_port": DefaultMasterPort, } + defaultTorchRunRdzvParams = map[string]string{ + "max_restarts": DefaultMaxRestarts, + "rdzv_id": DefaultRdzvId, + "rdzv_backend": DefaultRdzvBackend, + "rdzv_endpoint": DefaultRdzvEndpoint, + } + defaultAccelerateParams = map[string]string{ "config_file": DefaultConfigFile, "num_processes": DefaultNumProcesses, @@ -82,6 +97,7 @@ type PresetInferenceParam struct { GPURequirement string GPUMemoryRequirement string TorchRunParams map[string]string + TorchRunRdzvParams map[string]string ModelRunParams map[string]string InferenceFile string // DeploymentTimeout defines the maximum duration for pulling the Preset image. 
@@ -108,6 +124,7 @@ var ( GPURequirement: "1", GPUMemoryRequirement: "16Gi", TorchRunParams: defaultTorchRunParams, + TorchRunRdzvParams: defaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, InferenceFile: llamaChatInferenceFile, DeploymentTimeout: time.Duration(10) * time.Minute, @@ -124,6 +141,7 @@ var ( GPURequirement: "2", GPUMemoryRequirement: "16Gi", TorchRunParams: defaultTorchRunParams, + TorchRunRdzvParams: defaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, InferenceFile: llamaChatInferenceFile, DeploymentTimeout: time.Duration(20) * time.Minute, @@ -140,6 +158,7 @@ var ( GPURequirement: "8", GPUMemoryRequirement: "19Gi", TorchRunParams: defaultTorchRunParams, + TorchRunRdzvParams: defaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, InferenceFile: llamaChatInferenceFile, DeploymentTimeout: time.Duration(30) * time.Minute, diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 8d2af6d8e..2c2b81c83 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -81,6 +81,13 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 inferenceObj.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP inferenceObj.TorchRunParams["master_port"] = "29500" } + if inferenceObj.TorchRunRdzvParams != nil { + inferenceObj.TorchRunRdzvParams["max_restarts"] = "3" + inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job" + inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d" + inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] = + fmt.Sprintf("%s-0.llama-headless.default.svc.cluster.local:29500", wObj.Inference.Preset.Name) + } } else if inferenceObj.ModelName == "Falcon" { inferenceObj.TorchRunParams["config_file"] = "config.yaml" inferenceObj.TorchRunParams["num_processes"] = "1" @@ -168,6 +175,7 @@ func checkResourceStatus(obj client.Object, kubeClient client.Client, timeoutDur func prepareInferenceParameters(ctx context.Context, inferenceObj PresetInferenceParam) ([]string, corev1.ResourceRequirements) { torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) + torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams) modelCommand := buildCommandStr(inferenceObj.InferenceFile, inferenceObj.ModelRunParams) commands := shellCommand(torchCommand + " " + modelCommand) diff --git a/pkg/resources/manifests.go b/pkg/resources/manifests.go index a7cdc83de..a08b64b4f 100644 --- a/pkg/resources/manifests.go +++ b/pkg/resources/manifests.go @@ -17,6 +17,32 @@ import ( var controller = true +func GenerateHeadlessServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) *corev1.Service { + serviceName := fmt.Sprintf("%s-headless", workspaceObj.Inference.Preset.Name) + selector := make(map[string]string) + for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels { + selector[k] = v + } + return &corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + }, + Spec: corev1.ServiceSpec{ + Selector: selector, + ClusterIP: "None", + Ports: []corev1.ServicePort{ + { + Name: "torchrun", + Protocol: corev1.ProtocolTCP, + Port: 29500, + TargetPort: intstr.FromInt(29500), + }, + }, + PublishNotReadyAddresses: true, + }, + } +} + func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, serviceType corev1.ServiceType, isStatefulSet bool) *corev1.Service { selector := make(map[string]string) for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels { @@ -82,7 +108,7 @@ func 
GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha Values: []string{value}, }) } - return &appsv1.StatefulSet{ + ss := &appsv1.StatefulSet{ ObjectMeta: v1.ObjectMeta{ Name: workspaceObj.Name, Namespace: workspaceObj.Namespace, @@ -136,6 +162,10 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha }, }, } + if val, ok := workspaceObj.Annotations["headlessServiceCreated"]; ok && val == "true" { + ss.Spec.ServiceName = fmt.Sprintf("%s-headless", workspaceObj.Inference.Preset.Name) + } + return ss } func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string, From cdf10e7886078c55a0082d8f1a0d6c7cc99f809e Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 1 Nov 2023 11:59:39 -0700 Subject: [PATCH 14/22] fix: simplify timeout --- presets/llama-2-chat/inference-api.py | 59 ++++++++++++-------------- presets/llama-2/inference-api.py | 61 +++++++++++++-------------- 2 files changed, 58 insertions(+), 62 deletions(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 65fe121e1..e79da0d7d 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -50,7 +50,6 @@ def broadcast_for_shutdown(): """Broadcasts shutdown command to worker processes.""" dist.broadcast_object_list(["shutdown", None, None], src=0) -@timeout(60.0) def broadcast_for_generation(input_string, max_gen_len, temperature, top_p): """Broadcasts generation parameters to worker processes.""" dist.broadcast_object_list(["generate", input_string, { @@ -59,14 +58,27 @@ def broadcast_for_generation(input_string, max_gen_len, temperature, top_p): 'top_p': top_p }], src=0) -@timeout(60.0) -def inference(input_string, max_gen_len, temperature, top_p): - return generator.chat_completion( +@timeout(180.0) +def master_inference(input_string, max_gen_len, temperature, top_p): + if dist.get_world_size() > 1: + try: + # Broadcast generation params to worker processes + broadcast_for_generation(input_string, max_gen_len, temperature, top_p) + except Exception as e: + print("Error in broadcast_for_generation:", str(e)) + raise + + # Master's own generation + try: + return generator.chat_completion( input_string, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, ) + except Exception as e: + print("Error in chat_completion:", str(e)) + raise def shutdown_server(): """Shut down the server.""" @@ -125,25 +137,13 @@ def chat_completion(params: ChatParameters): temperature = parameters.get('temperature', 0.6) top_p = parameters.get('top_p', 0.9) - if dist.get_world_size() > 1: - try: - # Broadcast generation params to worker processes - broadcast_for_generation(input_string, max_gen_len, temperature, top_p) - except Exception as e: - exception_type = type(e).__name__ - if exception_type == "TimeoutError": - print("Broadcast failed - TimeoutError", e) - os.killpg(os.getpgrp(), signal.SIGTERM) - raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) - - # Master's own generation - try: - results = inference(input_string, max_gen_len, temperature, top_p) - except Exception as e: + try: + results = master_inference(input_string, max_gen_len, temperature, top_p) + except Exception as e: exception_type = type(e).__name__ if exception_type == "TimeoutError": - print("Inference failed - TimeoutError", e) - os.killpg(os.getpgrp(), signal.SIGTERM) + print("Master Inference Failed - TimeoutError", e) + raise HTTPException(status_code=408, 
detail="Request Timed Out") raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) if len(results) == 0: @@ -199,21 +199,18 @@ def worker_listen_tasks(): input_string = config[1] parameters = config[2] print(f"Worker {worker_num} started generation") - inference(input_string, - parameters.get('max_gen_len', None), - parameters.get('temperature', 0.6), - parameters.get('top_p', 0.9) - ) + generator.chat_completion( + input_string, + max_gen_len=parameters.get('max_gen_len', None), + temperature=parameters.get('temperature', 0.6), + top_p=parameters.get('top_p', 0.9), + ) print(f"Worker {worker_num} completed generation") except Exception as e: - exception_type = type(e).__name__ - if exception_type == "TimeoutError": - print("Inference failed - TimeoutError", e) - os.killpg(os.getpgrp(), signal.SIGTERM) print(f"Error in generation: {str(e)}") elif command == "shutdown": print(f"Worker {worker_num} shutting down") - sys.exit(0) + os.killpg(os.getpgrp(), signal.SIGTERM) except torch.distributed.DistBackendError as e: print("torch.distributed.DistBackendError", e) os.killpg(os.getpgrp(), signal.SIGTERM) diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index a027ac598..68a8c67cb 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -50,7 +50,6 @@ def broadcast_for_shutdown(): """Broadcasts shutdown command to worker processes.""" dist.broadcast_object_list(["shutdown", None, None], src=0) -@timeout(60.0) def broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p): """Broadcasts generation parameters to worker processes.""" dist.broadcast_object_list(["text_generate", prompts, { @@ -59,18 +58,31 @@ def broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p): 'top_p': top_p }], src=0) -@timeout(60.0) -def inference(input_string, max_gen_len, temperature, top_p): - return generator.text_completion( - input_string, +@timeout(180.0) +def master_inference(prompts, max_gen_len, temperature, top_p): + if dist.get_world_size() > 1: + try: + # Broadcast generation params to worker processes + broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p) + except Exception as e: + print("Error in broadcast_for_text_generation:", str(e)) + raise + + # Master's own generation + try: + return generator.text_completion( + prompts, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, ) + except Exception as e: + print("Error in text_completion:", str(e)) + raise def shutdown_server(): """Shut down the server.""" - os.kill(os.getpid(), signal.SIGINT) + os.killpg(os.getpgrp(), signal.SIGTERM) # Default values for the generator gen_params = { @@ -122,23 +134,13 @@ def generate_text(params: GenerationParameters): temperature = parameters.get('temperature', 0.6) top_p = parameters.get('top_p', 0.9) - if dist.get_world_size() > 1: - try: - broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p) - except Exception as e: - exception_type = type(e).__name__ - if exception_type == "TimeoutError": - print("Broadcast failed - TimeoutError", e) - os.killpg(os.getpgrp(), signal.SIGTERM) - raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) - - try: - results = inference(prompts, max_gen_len, temperature, top_p) - except Exception as e: + try: + results = master_inference(prompts, max_gen_len, temperature, top_p) + except Exception as e: exception_type = type(e).__name__ if exception_type == "TimeoutError": - print("Inference failed - TimeoutError", e) - 
os.killpg(os.getpgrp(), signal.SIGTERM) + print("Master Inference Failed - TimeoutError", e) + raise HTTPException(status_code=408, detail="Request Timed Out") raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) if len(results) == 0: @@ -186,21 +188,18 @@ def worker_listen_tasks(): input_string = config[1] parameters = config[2] print(f"Worker {worker_num} started generation") - inference(input_string, - parameters.get('max_gen_len', None), - parameters.get('temperature', 0.6), - parameters.get('top_p', 0.9) - ) + generator.text_completion( + input_string, + max_gen_len=parameters.get('max_gen_len', None), + temperature=parameters.get('temperature', 0.6), + top_p=parameters.get('top_p', 0.9), + ) print(f"Worker {worker_num} completed generation") except Exception as e: - exception_type = type(e).__name__ - if exception_type == "TimeoutError": - print("Inference failed - TimeoutError", e) - os.killpg(os.getpgrp(), signal.SIGTERM) print(f"Error in generation: {str(e)}") elif command == "shutdown": print(f"Worker {worker_num} shutting down") - sys.exit(0) + os.killpg(os.getpgrp(), signal.SIGTERM) except torch.distributed.DistBackendError as e: print("torch.distributed.DistBackendError", e) os.killpg(os.getpgrp(), signal.SIGTERM) From a12f16ca4579d2cab0be832d37dc66033a3976f5 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 1 Nov 2023 14:00:50 -0700 Subject: [PATCH 15/22] fix: headless service variable fixes --- pkg/controllers/workspace_controller.go | 22 ++++++++++++---------- pkg/inference/preset-inference-types.go | 2 +- pkg/inference/preset-inferences.go | 6 +++--- pkg/resources/manifests.go | 8 ++++---- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index 58b650c05..eaf30d3c5 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -93,7 +93,8 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err := c.ensureService(ctx, wObj); err != nil { + createHeadlessService := *(wObj.Resource.Count) > 1 && strings.Contains(string(wObj.Inference.Preset.Name), "llama") + if err := c.ensureService(ctx, wObj, createHeadlessService); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, "workspaceFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -102,7 +103,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj); err != nil { + if err = c.applyInference(ctx, wObj, createHeadlessService); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, "workspaceFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -378,7 +379,7 @@ func (c *WorkspaceReconciler) ensureNodePlugins(ctx context.Context, wObj *kaito } } -func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { +func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace, createHeadlessService bool) error { serviceType := corev1.ServiceTypeClusterIP wAnnotation := wObj.GetAnnotations() @@ -407,12 
+408,13 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al if err != nil { return err } - headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj) - err = resources.CreateResource(ctx, headlessService, c.Client) - if err != nil { - return err + if createHeadlessService { + headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj) + err = resources.CreateResource(ctx, headlessService, c.Client) + if err != nil { + return err + } } - wObj.Annotations["headlessServiceCreated"] = "true" return nil } @@ -453,7 +455,7 @@ func (c *WorkspaceReconciler) getInferenceObjFromPreset(ctx context.Context, wOb } // applyInference applies inference spec. -func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { +func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace, createHeadlessService bool) error { if wObj.Inference.Template != nil { // TODO: handle update @@ -489,7 +491,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a return nil } - err = inference.CreatePresetInference(ctx, wObj, inferenceObj, c.Client) + err = inference.CreatePresetInference(ctx, wObj, inferenceObj, createHeadlessService, c.Client) if err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeInferenceStatus, metav1.ConditionFalse, "WorkspaceInferenceStatusFailed", err.Error()); updateErr != nil { diff --git a/pkg/inference/preset-inference-types.go b/pkg/inference/preset-inference-types.go index 50c4a6ccf..a7902447a 100644 --- a/pkg/inference/preset-inference-types.go +++ b/pkg/inference/preset-inference-types.go @@ -23,7 +23,7 @@ const ( const ( DefaultMaxRestarts = "3" DefaultRdzvId = "rdzv_id" - DefaultRdzvBackend = "c10d" + DefaultRdzvBackend = "c10d" // Pytorch Native Distributed data store DefaultRdzvEndpoint = "localhost:29500" // llama-2-13b-chat-0.llama-headless.default.svc.cluster.local:29500 ) diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 2c2b81c83..8db451702 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -86,7 +86,7 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job" inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d" inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] = - fmt.Sprintf("%s-0.llama-headless.default.svc.cluster.local:29500", wObj.Inference.Preset.Name) + fmt.Sprintf("%s-0.llama-headless.default.svc.cluster.local:29500", wObj.Name) } } else if inferenceObj.ModelName == "Falcon" { inferenceObj.TorchRunParams["config_file"] = "config.yaml" @@ -99,7 +99,7 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj PresetInferenceParam, kubeClient client.Client) error { + inferenceObj PresetInferenceParam, createHeadlessService bool, kubeClient client.Client) error { if inferenceObj.TorchRunParams != nil { if err := setTorchParams(ctx, kubeClient, workspaceObj, inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -114,7 +114,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work switch inferenceObj.ModelName { case "LLaMa2": depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, 
inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands, - containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount, createHeadlessService) case "Falcon": depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands, containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) diff --git a/pkg/resources/manifests.go b/pkg/resources/manifests.go index a08b64b4f..505930253 100644 --- a/pkg/resources/manifests.go +++ b/pkg/resources/manifests.go @@ -18,7 +18,7 @@ import ( var controller = true func GenerateHeadlessServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) *corev1.Service { - serviceName := fmt.Sprintf("%s-headless", workspaceObj.Inference.Preset.Name) + serviceName := fmt.Sprintf("%s-headless", workspaceObj.Name) selector := make(map[string]string) for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels { selector[k] = v @@ -97,7 +97,7 @@ func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Wo func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string, imagePullSecretRefs []corev1.LocalObjectReference, replicas int, commands []string, containerPorts []corev1.ContainerPort, livenessProbe, readinessProbe *corev1.Probe, resourceRequirements corev1.ResourceRequirements, - tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount) *appsv1.StatefulSet { + tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount, createHeadlessService bool) *appsv1.StatefulSet { // Gather label requirements from workspaceObj's label selector labelRequirements := make([]v1.LabelSelectorRequirement, 0, len(workspaceObj.Resource.LabelSelector.MatchLabels)) @@ -162,8 +162,8 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha }, }, } - if val, ok := workspaceObj.Annotations["headlessServiceCreated"]; ok && val == "true" { - ss.Spec.ServiceName = fmt.Sprintf("%s-headless", workspaceObj.Inference.Preset.Name) + if createHeadlessService { + ss.Spec.ServiceName = fmt.Sprintf("%s-headless", workspaceObj.Name) } return ss } From 2f2821d200d4708df96953e0da31c33942df2cd0 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 1 Nov 2023 14:03:04 -0700 Subject: [PATCH 16/22] fix: shutdown --- presets/llama-2-chat/inference-api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index e79da0d7d..ed80be8f0 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -82,7 +82,7 @@ def master_inference(input_string, max_gen_len, temperature, top_p): def shutdown_server(): """Shut down the server.""" - os.kill(os.getpid(), signal.SIGINT) + os.killpg(os.getpgrp(), signal.SIGTERM) # Default values for the generator gen_params = { From 86ff9e0daa1d2ab79c19ebc953bd9c3d12a562ca Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 1 Nov 2023 14:04:37 -0700 Subject: [PATCH 17/22] fix: dockerfile --- docker/presets/falcon/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/presets/falcon/Dockerfile b/docker/presets/falcon/Dockerfile index d81b4e325..f8e3dfa1d 100644 --- 
a/docker/presets/falcon/Dockerfile +++ b/docker/presets/falcon/Dockerfile @@ -1,5 +1,5 @@ # Use the NVIDIA PyTorch image as a base -FROM nvcr.io/nvidia/pytorch:23.06-py3 +FROM nvcr.io/nvidia/pytorch:23.10-py3 # Set the working directory WORKDIR /workspace/falcon From ca788a43f68b01c3addf15242d3aadb3ec069a7c Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Wed, 1 Nov 2023 15:51:40 -0700 Subject: [PATCH 18/22] fix: update docker file paths to avoid conflicting volume mounts --- docker/presets/falcon/Dockerfile | 2 +- docker/presets/llama-2/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/presets/falcon/Dockerfile b/docker/presets/falcon/Dockerfile index f8e3dfa1d..390e2c891 100644 --- a/docker/presets/falcon/Dockerfile +++ b/docker/presets/falcon/Dockerfile @@ -13,6 +13,6 @@ RUN pip install --no-cache-dir -r requirements.txt ARG FALCON_MODEL_NAME # Copy the entire model to the weights directory -COPY /home/falcon/${FALCON_MODEL_NAME} /workspace/falcon/weights +COPY /falcon/${FALCON_MODEL_NAME} /workspace/falcon/weights # Copy the entire 'presets/falcon' folder to the working directory COPY /home/presets/falcon /workspace/falcon diff --git a/docker/presets/llama-2/Dockerfile b/docker/presets/llama-2/Dockerfile index 8fa70579c..250d525de 100644 --- a/docker/presets/llama-2/Dockerfile +++ b/docker/presets/llama-2/Dockerfile @@ -24,5 +24,5 @@ RUN pip install 'uvicorn[standard]' ARG LLAMA_VERSION ARG SRC_DIR -ADD /home/llama/${LLAMA_VERSION} /workspace/llama/llama-2/weights +ADD /llama/${LLAMA_VERSION} /workspace/llama/llama-2/weights ADD ${SRC_DIR} /workspace/llama/llama-2 From 2213d85e5452cbc01efa281c22e507965e384711 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 2 Nov 2023 10:59:45 -0700 Subject: [PATCH 19/22] fix: fix service naming, and add service namespace and ownerreferences which are required --- pkg/inference/preset-inferences.go | 2 +- pkg/resources/manifests.go | 12 +++++++++++- presets/test/docker.yaml | 6 +++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 8db451702..1518c3da7 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -86,7 +86,7 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job" inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d" inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] = - fmt.Sprintf("%s-0.llama-headless.default.svc.cluster.local:29500", wObj.Name) + fmt.Sprintf("%s-0.%s-headless.default.svc.cluster.local:29500", wObj.Name, wObj.Name) } } else if inferenceObj.ModelName == "Falcon" { inferenceObj.TorchRunParams["config_file"] = "config.yaml" diff --git a/pkg/resources/manifests.go b/pkg/resources/manifests.go index 505930253..6ed99330e 100644 --- a/pkg/resources/manifests.go +++ b/pkg/resources/manifests.go @@ -25,7 +25,17 @@ func GenerateHeadlessServiceManifest(ctx context.Context, workspaceObj *kaitov1a } return &corev1.Service{ ObjectMeta: v1.ObjectMeta{ - Name: serviceName, + Name: serviceName, + Namespace: workspaceObj.Namespace, + OwnerReferences: []v1.OwnerReference{ + { + APIVersion: kaitov1alpha1.GroupVersion.String(), + Kind: "Workspace", + UID: workspaceObj.UID, + Name: workspaceObj.Name, + Controller: &controller, + }, + }, }, Spec: corev1.ServiceSpec{ Selector: selector, diff --git a/presets/test/docker.yaml b/presets/test/docker.yaml index ef4e152a4..2fb4ab10c 100644 --- 
a/presets/test/docker.yaml +++ b/presets/test/docker.yaml @@ -21,13 +21,13 @@ spec: - name: host-volume mountPath: /home - name: llama-volume - mountPath: /home/llama + mountPath: /llama - name: falcon-volume - mountPath: /home/falcon + mountPath: /falcon volumes: - name: host-volume hostPath: - path: /actions-runner/_work/kdm/kdm + path: /home/runner-0/runner/_work/kaito/kaito type: Directory - name: llama-volume hostPath: From 8a9b4aee129c0d3ba04923d7d7d1703727222881 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 2 Nov 2023 13:46:00 -0700 Subject: [PATCH 20/22] fix: remove logs --- presets/llama-2-chat/inference-api.py | 4 ---- presets/llama-2/inference-api.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index ed80be8f0..fcb1c7cc1 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -189,23 +189,19 @@ def worker_listen_tasks(): print(f"Worker {worker_num} ready to recieve next command") config = [None] * 3 # Command and its associated data try: - print(f"Worker {worker_num} entered broadcast listen") dist.broadcast_object_list(config, src=0) - print(f"Worker {worker_num} left broadcast listen") command = config[0] if command == "generate": try: input_string = config[1] parameters = config[2] - print(f"Worker {worker_num} started generation") generator.chat_completion( input_string, max_gen_len=parameters.get('max_gen_len', None), temperature=parameters.get('temperature', 0.6), top_p=parameters.get('top_p', 0.9), ) - print(f"Worker {worker_num} completed generation") except Exception as e: print(f"Error in generation: {str(e)}") elif command == "shutdown": diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 68a8c67cb..7e1810049 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -178,23 +178,19 @@ def worker_listen_tasks(): print(f"Worker {worker_num} ready to recieve next command") config = [None] * 3 # Command and its associated data try: - print(f"Worker {worker_num} entered broadcast listen") dist.broadcast_object_list(config, src=0) - print(f"Worker {worker_num} left broadcast listen") command = config[0] if command == "text_generate": try: input_string = config[1] parameters = config[2] - print(f"Worker {worker_num} started generation") generator.text_completion( input_string, max_gen_len=parameters.get('max_gen_len', None), temperature=parameters.get('temperature', 0.6), top_p=parameters.get('top_p', 0.9), ) - print(f"Worker {worker_num} completed generation") except Exception as e: print(f"Error in generation: {str(e)}") elif command == "shutdown": From c52b40c684a833f9afdeb941bb83562d8e6730b8 Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 2 Nov 2023 13:47:16 -0700 Subject: [PATCH 21/22] fix --- api/v1alpha1/zz_generated.deepcopy.go | 17 ++--------------- cmd/main.go | 4 ++-- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 4461ee6a0..bfdc825b4 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1,21 +1,8 @@ //go:build !ignore_autogenerated // +build !ignore_autogenerated -/* -Copyright 2023. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. // Code generated by controller-gen. DO NOT EDIT. diff --git a/cmd/main.go b/cmd/main.go index 0ceaa87c1..9c71ad26c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,5 +1,5 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. package main import ( From fd3f9dbfb0a5280eda4a890f7b5241e46514610d Mon Sep 17 00:00:00 2001 From: ishaansehgal99 Date: Thu, 2 Nov 2023 14:01:56 -0700 Subject: [PATCH 22/22] fix: rename create to useHeadlessService --- pkg/controllers/workspace_controller.go | 14 +++++++------- pkg/inference/preset-inferences.go | 4 ++-- pkg/resources/manifests.go | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index 777baff78..27ef2ca98 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -93,8 +93,8 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - createHeadlessService := *(wObj.Resource.Count) > 1 && strings.Contains(string(wObj.Inference.Preset.Name), "llama") - if err := c.ensureService(ctx, wObj, createHeadlessService); err != nil { + useHeadlessService := *(wObj.Resource.Count) > 1 && strings.Contains(string(wObj.Inference.Preset.Name), "llama") + if err := c.ensureService(ctx, wObj, useHeadlessService); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, "workspaceFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -103,7 +103,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj, createHeadlessService); err != nil { + if err = c.applyInference(ctx, wObj, useHeadlessService); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, "workspaceFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -379,7 +379,7 @@ func (c *WorkspaceReconciler) ensureNodePlugins(ctx context.Context, wObj *kaito } } -func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace, createHeadlessService bool) error { +func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error { serviceType := corev1.ServiceTypeClusterIP wAnnotation := wObj.GetAnnotations() @@ -408,7 +408,7 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al if err != nil { return err } - if createHeadlessService { + if useHeadlessService { headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj) err = resources.CreateResource(ctx, headlessService, c.Client) if err 
!= nil { @@ -455,7 +455,7 @@ func (c *WorkspaceReconciler) getInferenceObjFromPreset(ctx context.Context, wOb } // applyInference applies inference spec. -func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace, createHeadlessService bool) error { +func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error { inferErr := func() error { if wObj.Inference.Template != nil { @@ -490,7 +490,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a } } else if apierrors.IsNotFound(err) { // Need to create a new workload - workloadObj, err := inference.CreatePresetInference(ctx, wObj, inferenceObj, createHeadlessService, c.Client) + workloadObj, err := inference.CreatePresetInference(ctx, wObj, inferenceObj, useHeadlessService, c.Client) if err != nil { return err } diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 46c1b4eb9..ae2a6aeb9 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -97,7 +97,7 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj PresetInferenceParam, createHeadlessService bool, kubeClient client.Client) (client.Object, error) { + inferenceObj PresetInferenceParam, useHeadlessService bool, kubeClient client.Client) (client.Object, error) { if inferenceObj.TorchRunParams != nil { if err := setTorchParams(ctx, kubeClient, workspaceObj, inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -112,7 +112,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work switch inferenceObj.ModelName { case "LLaMa2": depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands, - containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount, createHeadlessService) + containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount, useHeadlessService) case "Falcon": depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands, containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) diff --git a/pkg/resources/manifests.go b/pkg/resources/manifests.go index 6ed99330e..195baff78 100644 --- a/pkg/resources/manifests.go +++ b/pkg/resources/manifests.go @@ -107,7 +107,7 @@ func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Wo func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string, imagePullSecretRefs []corev1.LocalObjectReference, replicas int, commands []string, containerPorts []corev1.ContainerPort, livenessProbe, readinessProbe *corev1.Probe, resourceRequirements corev1.ResourceRequirements, - tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount, createHeadlessService bool) *appsv1.StatefulSet { + tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount, useHeadlessService bool) *appsv1.StatefulSet { // Gather label requirements from workspaceObj's label selector labelRequirements := make([]v1.LabelSelectorRequirement, 0, 
len(workspaceObj.Resource.LabelSelector.MatchLabels)) @@ -172,7 +172,7 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha }, }, } - if createHeadlessService { + if useHeadlessService { ss.Spec.ServiceName = fmt.Sprintf("%s-headless", workspaceObj.Name) } return ss
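
Taken together, the headless-Service patches above establish a single naming contract: when a llama workspace requests more than one node, the controller creates a headless Service named `<workspace>-headless` (owned by the Workspace and placed in its namespace), the StatefulSet's ServiceName is set to the same value, and setTorchParams points torchrun's rendezvous at pod 0 through that Service. The following is a minimal, self-contained Go sketch of that contract only; the `workspaceMeta` struct and the helper names are illustrative stand-ins for this note, not types or functions from the kaito codebase.

package main

import (
	"fmt"
	"strings"
)

// workspaceMeta is an illustrative stand-in for the handful of Workspace
// fields the patches rely on (wObj.Name, *wObj.Resource.Count and
// wObj.Inference.Preset.Name); it is not a type from the kaito codebase.
type workspaceMeta struct {
	Name      string
	NodeCount int
	Preset    string
}

// useHeadlessService mirrors the condition introduced in addOrUpdateWorkspace:
// only multi-node llama workspaces get a headless Service.
func useHeadlessService(w workspaceMeta) bool {
	return w.NodeCount > 1 && strings.Contains(w.Preset, "llama")
}

// headlessServiceName mirrors GenerateHeadlessServiceManifest after the
// rename: the Service is named after the workspace, not the preset.
func headlessServiceName(w workspaceMeta) string {
	return fmt.Sprintf("%s-headless", w.Name)
}

// rdzvEndpoint mirrors setTorchParams: pod 0 of the StatefulSet is reached
// through the headless Service, and the c10d rendezvous listens on 29500.
// The "default" namespace is hardcoded here exactly as in the patch.
func rdzvEndpoint(w workspaceMeta) string {
	return fmt.Sprintf("%s-0.%s.default.svc.cluster.local:29500", w.Name, headlessServiceName(w))
}

func main() {
	// Hypothetical workspace values, chosen only to show the resulting names.
	w := workspaceMeta{Name: "workspace-llama-2-13b-chat", NodeCount: 2, Preset: "llama-2-13b-chat"}
	if useHeadlessService(w) {
		fmt.Println("headless service:", headlessServiceName(w))
		fmt.Println("rdzv endpoint:   ", rdzvEndpoint(w))
	}
}

The StatefulSet's Spec.ServiceName must equal the headless Service name for this to work, which is why GenerateStatefulSetManifest now takes the same useHeadlessService flag instead of reading the removed "headlessServiceCreated" annotation: without a matching ServiceName, Kubernetes never publishes the per-pod DNS record that the rendezvous endpoint assumes, and the multi-node torchrun job cannot form its process group.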