diff --git a/.github/workflows/preset-image-build.yml b/.github/workflows/preset-image-build.yml index c0c7ef498..ea2fe382d 100644 --- a/.github/workflows/preset-image-build.yml +++ b/.github/workflows/preset-image-build.yml @@ -1,5 +1,4 @@ name: Build and Push Preset Models - on: pull_request: branches: @@ -17,7 +16,7 @@ on: - 'presets/llama-2-chat/**' workflow_dispatch: inputs: - image_tag: + image_tag_name: description: 'Image Tag' required: true @@ -48,7 +47,13 @@ jobs: - name: Get Modified files run: | - files=$(git diff --name-only HEAD^ HEAD) + git fetch origin main:main + current_branch=$(git rev-parse --abbrev-ref HEAD) + if [ "$current_branch" == "main" ] || [ "$current_branch" == "master" ]; then + files=$(git diff --name-only HEAD^ HEAD) + else + files=$(git diff --name-only main...HEAD) + fi echo "Modified files: $files" FILES_MODIFIED="" while IFS= read -r file; do @@ -93,9 +98,9 @@ jobs: - name: Set Image Tag id: set_tag run: | - if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag }}" ]]; then + if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag_name }}" ]]; then echo "Using workflow dispatch to set image tag" - echo "image_tag=${{ github.event.inputs.image_tag }}" >> $GITHUB_OUTPUT + echo "image_tag=${{ github.event.inputs.image_tag_name }}" >> $GITHUB_OUTPUT else echo "Setting image tag based on version set" echo "image_tag=${{ env.VERSION }}" >> $GITHUB_OUTPUT @@ -175,7 +180,7 @@ jobs: - name: 'Attach and Login to ACR' id: acr_login run: | - if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag }}" ]]; then + if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag_name }}" ]]; then ACR_NAME="aimodelsregistry" else ACR_NAME="aimodelsregistrytest" diff --git a/docker/presets/falcon/Dockerfile b/docker/presets/falcon/Dockerfile index d81b4e325..390e2c891 100644 --- a/docker/presets/falcon/Dockerfile +++ b/docker/presets/falcon/Dockerfile @@ -1,5 +1,5 @@ # Use the NVIDIA PyTorch image as a base -FROM nvcr.io/nvidia/pytorch:23.06-py3 +FROM nvcr.io/nvidia/pytorch:23.10-py3 # Set the working directory WORKDIR /workspace/falcon @@ -13,6 +13,6 @@ RUN pip install --no-cache-dir -r requirements.txt ARG FALCON_MODEL_NAME # Copy the entire model to the weights directory -COPY /home/falcon/${FALCON_MODEL_NAME} /workspace/falcon/weights +COPY /falcon/${FALCON_MODEL_NAME} /workspace/falcon/weights # Copy the entire 'presets/falcon' folder to the working directory COPY /home/presets/falcon /workspace/falcon diff --git a/docker/presets/llama-2/Dockerfile b/docker/presets/llama-2/Dockerfile index 3f091053b..250d525de 100644 --- a/docker/presets/llama-2/Dockerfile +++ b/docker/presets/llama-2/Dockerfile @@ -10,7 +10,7 @@ # --build-arg SRC_DIR=/home/presets/llama-2-chat \ # -t llama-2-7b-chat:latest . 
-FROM nvcr.io/nvidia/pytorch:23.06-py3 +FROM nvcr.io/nvidia/pytorch:23.10-py3 WORKDIR /workspace RUN git clone https://github.com/facebookresearch/llama @@ -24,5 +24,5 @@ RUN pip install 'uvicorn[standard]' ARG LLAMA_VERSION ARG SRC_DIR -ADD /home/llama/${LLAMA_VERSION} /workspace/llama/llama-2/weights +ADD /llama/${LLAMA_VERSION} /workspace/llama/llama-2/weights ADD ${SRC_DIR} /workspace/llama/llama-2 diff --git a/pkg/controllers/workspace_controller.go b/pkg/controllers/workspace_controller.go index 685cdd36b..27ef2ca98 100644 --- a/pkg/controllers/workspace_controller.go +++ b/pkg/controllers/workspace_controller.go @@ -93,7 +93,8 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err := c.ensureService(ctx, wObj); err != nil { + useHeadlessService := *(wObj.Resource.Count) > 1 && strings.Contains(string(wObj.Inference.Preset.Name), "llama") + if err := c.ensureService(ctx, wObj, useHeadlessService); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, "workspaceFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -102,7 +103,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka return reconcile.Result{}, err } - if err = c.applyInference(ctx, wObj); err != nil { + if err = c.applyInference(ctx, wObj, useHeadlessService); err != nil { if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse, "workspaceFailed", err.Error()); updateErr != nil { klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj)) @@ -378,7 +379,7 @@ func (c *WorkspaceReconciler) ensureNodePlugins(ctx context.Context, wObj *kaito } } -func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { +func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error { serviceType := corev1.ServiceTypeClusterIP wAnnotation := wObj.GetAnnotations() @@ -407,6 +408,13 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al if err != nil { return err } + if useHeadlessService { + headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj) + err = resources.CreateResource(ctx, headlessService, c.Client) + if err != nil { + return err + } + } return nil } @@ -447,7 +455,7 @@ func (c *WorkspaceReconciler) getInferenceObjFromPreset(ctx context.Context, wOb } // applyInference applies inference spec. 
-func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { +func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error { inferErr := func() error { if wObj.Inference.Template != nil { @@ -482,7 +490,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a } } else if apierrors.IsNotFound(err) { // Need to create a new workload - workloadObj, err := inference.CreatePresetInference(ctx, wObj, inferenceObj, c.Client) + workloadObj, err := inference.CreatePresetInference(ctx, wObj, inferenceObj, useHeadlessService, c.Client) if err != nil { return err } diff --git a/pkg/inference/preset-inference-types.go b/pkg/inference/preset-inference-types.go index 58bf8c35b..6eec355a4 100644 --- a/pkg/inference/preset-inference-types.go +++ b/pkg/inference/preset-inference-types.go @@ -19,6 +19,14 @@ const ( DefaultMasterPort = "29500" ) +// Torch Rendezvous Params +const ( + DefaultMaxRestarts = "3" + DefaultRdzvId = "rdzv_id" + DefaultRdzvBackend = "c10d" // Pytorch Native Distributed data store + DefaultRdzvEndpoint = "localhost:29500" // llama-2-13b-chat-0.llama-headless.default.svc.cluster.local:29500 +) + const ( DefaultConfigFile = "config.yaml" DefaultNumProcesses = "1" @@ -56,6 +64,13 @@ var ( "master_port": DefaultMasterPort, } + defaultTorchRunRdzvParams = map[string]string{ + "max_restarts": DefaultMaxRestarts, + "rdzv_id": DefaultRdzvId, + "rdzv_backend": DefaultRdzvBackend, + "rdzv_endpoint": DefaultRdzvEndpoint, + } + defaultAccelerateParams = map[string]string{ "config_file": DefaultConfigFile, "num_processes": DefaultNumProcesses, @@ -78,6 +93,7 @@ type PresetInferenceParam struct { GPURequirement string GPUMemoryRequirement string TorchRunParams map[string]string + TorchRunRdzvParams map[string]string ModelRunParams map[string]string InferenceFile string // DeploymentTimeout defines the maximum duration for pulling the Preset image. 
@@ -104,6 +120,7 @@ var ( GPURequirement: "1", GPUMemoryRequirement: "16Gi", TorchRunParams: defaultTorchRunParams, + TorchRunRdzvParams: defaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, InferenceFile: llamaChatInferenceFile, DeploymentTimeout: time.Duration(10) * time.Minute, @@ -120,6 +137,7 @@ var ( GPURequirement: "2", GPUMemoryRequirement: "16Gi", TorchRunParams: defaultTorchRunParams, + TorchRunRdzvParams: defaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, InferenceFile: llamaChatInferenceFile, DeploymentTimeout: time.Duration(20) * time.Minute, @@ -136,6 +154,7 @@ var ( GPURequirement: "8", GPUMemoryRequirement: "19Gi", TorchRunParams: defaultTorchRunParams, + TorchRunRdzvParams: defaultTorchRunRdzvParams, ModelRunParams: llamaRunParams, InferenceFile: llamaChatInferenceFile, DeploymentTimeout: time.Duration(30) * time.Minute, diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index 9bdb195b4..ae2a6aeb9 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -79,6 +79,13 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 inferenceObj.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP inferenceObj.TorchRunParams["master_port"] = "29500" } + if inferenceObj.TorchRunRdzvParams != nil { + inferenceObj.TorchRunRdzvParams["max_restarts"] = "3" + inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job" + inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d" + inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] = + fmt.Sprintf("%s-0.%s-headless.default.svc.cluster.local:29500", wObj.Name, wObj.Name) + } } else if inferenceObj.ModelName == "Falcon" { inferenceObj.TorchRunParams["config_file"] = "config.yaml" inferenceObj.TorchRunParams["num_processes"] = "1" @@ -90,7 +97,7 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1 } func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, - inferenceObj PresetInferenceParam, kubeClient client.Client) (client.Object, error) { + inferenceObj PresetInferenceParam, useHeadlessService bool, kubeClient client.Client) (client.Object, error) { if inferenceObj.TorchRunParams != nil { if err := setTorchParams(ctx, kubeClient, workspaceObj, inferenceObj); err != nil { klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj) @@ -105,7 +112,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work switch inferenceObj.ModelName { case "LLaMa2": depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands, - containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) + containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount, useHeadlessService) case "Falcon": depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands, containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount) @@ -121,6 +128,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work func prepareInferenceParameters(ctx context.Context, inferenceObj PresetInferenceParam) ([]string, corev1.ResourceRequirements) { torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams) + torchCommand = buildCommandStr(torchCommand, 
inferenceObj.TorchRunRdzvParams) modelCommand := buildCommandStr(inferenceObj.InferenceFile, inferenceObj.ModelRunParams) commands := shellCommand(torchCommand + " " + modelCommand) diff --git a/pkg/resources/manifests.go b/pkg/resources/manifests.go index a7cdc83de..195baff78 100644 --- a/pkg/resources/manifests.go +++ b/pkg/resources/manifests.go @@ -17,6 +17,42 @@ import ( var controller = true +func GenerateHeadlessServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) *corev1.Service { + serviceName := fmt.Sprintf("%s-headless", workspaceObj.Name) + selector := make(map[string]string) + for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels { + selector[k] = v + } + return &corev1.Service{ + ObjectMeta: v1.ObjectMeta{ + Name: serviceName, + Namespace: workspaceObj.Namespace, + OwnerReferences: []v1.OwnerReference{ + { + APIVersion: kaitov1alpha1.GroupVersion.String(), + Kind: "Workspace", + UID: workspaceObj.UID, + Name: workspaceObj.Name, + Controller: &controller, + }, + }, + }, + Spec: corev1.ServiceSpec{ + Selector: selector, + ClusterIP: "None", + Ports: []corev1.ServicePort{ + { + Name: "torchrun", + Protocol: corev1.ProtocolTCP, + Port: 29500, + TargetPort: intstr.FromInt(29500), + }, + }, + PublishNotReadyAddresses: true, + }, + } +} + func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, serviceType corev1.ServiceType, isStatefulSet bool) *corev1.Service { selector := make(map[string]string) for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels { @@ -71,7 +107,7 @@ func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Wo func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string, imagePullSecretRefs []corev1.LocalObjectReference, replicas int, commands []string, containerPorts []corev1.ContainerPort, livenessProbe, readinessProbe *corev1.Probe, resourceRequirements corev1.ResourceRequirements, - tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount) *appsv1.StatefulSet { + tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount, useHeadlessService bool) *appsv1.StatefulSet { // Gather label requirements from workspaceObj's label selector labelRequirements := make([]v1.LabelSelectorRequirement, 0, len(workspaceObj.Resource.LabelSelector.MatchLabels)) @@ -82,7 +118,7 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha Values: []string{value}, }) } - return &appsv1.StatefulSet{ + ss := &appsv1.StatefulSet{ ObjectMeta: v1.ObjectMeta{ Name: workspaceObj.Name, Namespace: workspaceObj.Namespace, @@ -136,6 +172,10 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha }, }, } + if useHeadlessService { + ss.Spec.ServiceName = fmt.Sprintf("%s-headless", workspaceObj.Name) + } + return ss } func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string, diff --git a/presets/llama-2-chat/inference-api.py b/presets/llama-2-chat/inference-api.py index 12063350f..fcb1c7cc1 100644 --- a/presets/llama-2-chat/inference-api.py +++ b/presets/llama-2-chat/inference-api.py @@ -4,7 +4,10 @@ import uvicorn from pydantic import BaseModel from typing import Optional +import multiprocessing +import multiprocessing.pool import threading +import functools from llama import Llama import torch @@ -25,6 +28,20 @@ should_shutdown = False +def timeout(max_timeout): + """Timeout 
decorator, parameter in seconds.""" + def timeout_decorator(item): + """Wrap the original function.""" + @functools.wraps(item) + def func_wrapper(*args, **kwargs): + """Closure for function.""" + with multiprocessing.pool.ThreadPool(processes=1) as pool: + async_result = pool.apply_async(item, args, kwargs) + # raises a TimeoutError if execution exceeds max_timeout + return async_result.get(max_timeout) + return func_wrapper + return timeout_decorator + def build_generator(params): """Build Llama generator from provided parameters.""" return Llama.build(**params) @@ -41,9 +58,31 @@ def broadcast_for_generation(input_string, max_gen_len, temperature, top_p): 'top_p': top_p }], src=0) +@timeout(180.0) +def master_inference(input_string, max_gen_len, temperature, top_p): + if dist.get_world_size() > 1: + try: + # Broadcast generation params to worker processes + broadcast_for_generation(input_string, max_gen_len, temperature, top_p) + except Exception as e: + print("Error in broadcast_for_generation:", str(e)) + raise + + # Master's own generation + try: + return generator.chat_completion( + input_string, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + except Exception as e: + print("Error in chat_completion:", str(e)) + raise + def shutdown_server(): """Shut down the server.""" - os.kill(os.getpid(), signal.SIGINT) + os.killpg(os.getpgrp(), signal.SIGTERM) # Default values for the generator gen_params = { @@ -98,19 +137,13 @@ def chat_completion(params: ChatParameters): temperature = parameters.get('temperature', 0.6) top_p = parameters.get('top_p', 0.9) - if dist.get_world_size() > 1: - # Broadcast generation params to worker processes - broadcast_for_generation(input_string, max_gen_len, temperature, top_p) - - # Master's own generation - try: - results = generator.chat_completion( - input_string, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) - except Exception as e: + try: + results = master_inference(input_string, max_gen_len, temperature, top_p) + except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Master Inference Failed - TimeoutError", e) + raise HTTPException(status_code=408, detail="Request Timed Out") raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) if len(results) == 0: @@ -147,34 +180,39 @@ def health_check(): return {"status": "Healthy"} def start_worker_server(): - uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) print(f"Worker {dist.get_rank()} HTTP health server started at port 5000") + uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) def worker_listen_tasks(): while True: worker_num = dist.get_rank() print(f"Worker {worker_num} ready to recieve next command") config = [None] * 3 # Command and its associated data - dist.broadcast_object_list(config, src=0) - command = config[0] - - if command == "generate": - try: - input_string = config[1] - parameters = config[2] - generator.chat_completion( - input_string, - max_gen_len=parameters.get('max_gen_len', None), - temperature=parameters.get('temperature', 0.6), - top_p=parameters.get('top_p', 0.9) - ) - print(f"Worker {worker_num} completed generation") - except Exception as e: - print(f"Error in generation: {str(e)}") - elif command == "shutdown": - print(f"Worker {worker_num} shutting down") - sys.exit(0) - + try: + dist.broadcast_object_list(config, src=0) + command = config[0] + + if command == "generate": + try: + input_string = config[1] + parameters = config[2] + 
generator.chat_completion( + input_string, + max_gen_len=parameters.get('max_gen_len', None), + temperature=parameters.get('temperature', 0.6), + top_p=parameters.get('top_p', 0.9), + ) + except Exception as e: + print(f"Error in generation: {str(e)}") + elif command == "shutdown": + print(f"Worker {worker_num} shutting down") + os.killpg(os.getpgrp(), signal.SIGTERM) + except torch.distributed.DistBackendError as e: + print("torch.distributed.DistBackendError", e) + os.killpg(os.getpgrp(), signal.SIGTERM) + except Exception as e: + print(f"Error in Worker Listen Task", e) + os.killpg(os.getpgrp(), signal.SIGTERM) if __name__ == "__main__": # Fetch the LOCAL_RANK environment variable to determine the rank of this process @@ -195,17 +233,22 @@ def worker_listen_tasks(): # Uncomment to enable worker logs # sys.stdout = sys.__stdout__ - # If the current process is the locally ranked 0 (i.e., the primary process) - # on its node, then it starts a worker server that exposes a health check endpoint. - if local_rank == 0: - app_worker = FastAPI() - setup_worker_routes() - - # Start the worker server in a separate thread. This worker server will - # provide a healthz endpoint for monitoring the health of the node. - server_thread = threading.Thread(target=start_worker_server, daemon=True) - server_thread.start() - - # Regardless of local rank, all non-globally-0-ranked processes will listen - # for tasks (like chat completion) from the main server. - worker_listen_tasks() + os.setpgrp() + try: + # If the current process is the locally ranked 0 (i.e., the primary process) + # on its node, then it starts a worker server that exposes a health check endpoint. + if local_rank == 0: + app_worker = FastAPI() + setup_worker_routes() + + # Start the worker server in a separate thread. This worker server will + # provide a healthz endpoint for monitoring the health of the node. + server_thread = threading.Thread(target=start_worker_server, daemon=True) + server_thread.start() + + # Regardless of local rank, all non-globally-0-ranked processes will listen + # for tasks (like chat completion) from the main server. 
+ worker_listen_tasks() + finally: + # Additional fail-safe (to ensure no lingering processes) + os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file diff --git a/presets/llama-2/inference-api.py b/presets/llama-2/inference-api.py index 70aa12e25..7e1810049 100644 --- a/presets/llama-2/inference-api.py +++ b/presets/llama-2/inference-api.py @@ -4,9 +4,10 @@ import uvicorn from pydantic import BaseModel from typing import Optional +import multiprocessing +import multiprocessing.pool import threading -import time -from multiprocessing import Value +import functools from llama import Llama import torch @@ -27,6 +28,20 @@ should_shutdown = False +def timeout(max_timeout): + """Timeout decorator, parameter in seconds.""" + def timeout_decorator(item): + """Wrap the original function.""" + @functools.wraps(item) + def func_wrapper(*args, **kwargs): + """Closure for function.""" + with multiprocessing.pool.ThreadPool(processes=1) as pool: + async_result = pool.apply_async(item, args, kwargs) + # raises a TimeoutError if execution exceeds max_timeout + return async_result.get(max_timeout) + return func_wrapper + return timeout_decorator + def build_generator(params): """Build Llama generator from provided parameters.""" return Llama.build(**params) @@ -43,9 +58,31 @@ def broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p): 'top_p': top_p }], src=0) +@timeout(180.0) +def master_inference(prompts, max_gen_len, temperature, top_p): + if dist.get_world_size() > 1: + try: + # Broadcast generation params to worker processes + broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p) + except Exception as e: + print("Error in broadcast_for_text_generation:", str(e)) + raise + + # Master's own generation + try: + return generator.text_completion( + prompts, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + except Exception as e: + print("Error in text_completion:", str(e)) + raise + def shutdown_server(): """Shut down the server.""" - os.kill(os.getpid(), signal.SIGINT) + os.killpg(os.getpgrp(), signal.SIGTERM) # Default values for the generator gen_params = { @@ -97,17 +134,13 @@ def generate_text(params: GenerationParameters): temperature = parameters.get('temperature', 0.6) top_p = parameters.get('top_p', 0.9) - if dist.get_world_size() > 1: - broadcast_for_text_generation(prompts, max_gen_len, temperature, top_p) - try: - results = generator.text_completion( - prompts, - max_gen_len=max_gen_len, - temperature=temperature, - top_p=top_p, - ) - except Exception as e: + results = master_inference(prompts, max_gen_len, temperature, top_p) + except Exception as e: + exception_type = type(e).__name__ + if exception_type == "TimeoutError": + print("Master Inference Failed - TimeoutError", e) + raise HTTPException(status_code=408, detail="Request Timed Out") raise HTTPException(status_code=400, detail="Request Failed: " + str(e)) if len(results) == 0: @@ -136,33 +169,39 @@ def health_check(): return {"status": "Healthy"} def start_worker_server(): - uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) print(f"Worker {dist.get_rank()} HTTP health server started at port 5000") + uvicorn.run(app=app_worker, host='0.0.0.0', port=5000) -def worker_listen_tasks(): +def worker_listen_tasks(): while True: worker_num = dist.get_rank() - print(f"Worker {worker_num} ready to receive next command") - config = [None] * 3 - dist.broadcast_object_list(config, src=0) - command = config[0] - - if command == "text_generate": - try: - prompts = 
config[1]
-                parameters = config[2]
-                generator.text_completion(
-                    prompts,
-                    max_gen_len=parameters.get('max_gen_len', None),
-                    temperature=parameters.get('temperature', 0.6),
-                    top_p=parameters.get('top_p', 0.9)
-                )
-                print(f"Worker {worker_num} completed generation")
-            except Exception as e:
-                print(f"Error in generation: {str(e)}")
-        elif command == "shutdown":
-            print(f"Worker {worker_num} shutting down")
-            sys.exit(0)
+        print(f"Worker {worker_num} ready to receive next command")
+        config = [None] * 3 # Command and its associated data
+        try:
+            dist.broadcast_object_list(config, src=0)
+            command = config[0]
+
+            if command == "text_generate":
+                try:
+                    input_string = config[1]
+                    parameters = config[2]
+                    generator.text_completion(
+                        input_string,
+                        max_gen_len=parameters.get('max_gen_len', None),
+                        temperature=parameters.get('temperature', 0.6),
+                        top_p=parameters.get('top_p', 0.9),
+                    )
+                except Exception as e:
+                    print(f"Error in generation: {str(e)}")
+            elif command == "shutdown":
+                print(f"Worker {worker_num} shutting down")
+                os.killpg(os.getpgrp(), signal.SIGTERM)
+        except torch.distributed.DistBackendError as e:
+            print("torch.distributed.DistBackendError", e)
+            os.killpg(os.getpgrp(), signal.SIGTERM)
+        except Exception as e:
+            print(f"Error in Worker Listen Task", e)
+            os.killpg(os.getpgrp(), signal.SIGTERM)

 if __name__ == "__main__":
     # Fetch the LOCAL_RANK environment variable to determine the rank of this process
@@ -183,17 +222,22 @@ def worker_listen_tasks():
     # Uncomment to enable worker logs
     # sys.stdout = sys.__stdout__

-    # If the current process is the locally ranked 0 (i.e., the primary process)
-    # on its node, then it starts a worker server that exposes a health check endpoint.
-    if local_rank == 0:
-        app_worker = FastAPI()
-        setup_worker_routes()
-
-        # Start the worker server in a separate thread. This worker server will
-        # provide a healthz endpoint for monitoring the health of the node.
-        server_thread = threading.Thread(target=start_worker_server, daemon=True)
-        server_thread.start()
-
-    # Regardless of local rank, all non-globally-0-ranked processes will listen
-    # for tasks (like text completion) from the main server.
-    worker_listen_tasks()
+    os.setpgrp()
+    try:
+        # If the current process is the locally ranked 0 (i.e., the primary process)
+        # on its node, then it starts a worker server that exposes a health check endpoint.
+        if local_rank == 0:
+            app_worker = FastAPI()
+            setup_worker_routes()
+
+            # Start the worker server in a separate thread. This worker server will
+            # provide a healthz endpoint for monitoring the health of the node.
+            server_thread = threading.Thread(target=start_worker_server, daemon=True)
+            server_thread.start()
+
+        # Regardless of local rank, all non-globally-0-ranked processes will listen
+        # for tasks (like text completion) from the main server.
+ worker_listen_tasks() + finally: + # Additional fail-safe (to ensure no lingering processes) + os.killpg(os.getpgrp(), signal.SIGTERM) \ No newline at end of file diff --git a/presets/test/docker.yaml b/presets/test/docker.yaml index ae1f4ee62..2fb4ab10c 100644 --- a/presets/test/docker.yaml +++ b/presets/test/docker.yaml @@ -21,9 +21,9 @@ spec: - name: host-volume mountPath: /home - name: llama-volume - mountPath: /home/llama + mountPath: /llama - name: falcon-volume - mountPath: /home/falcon + mountPath: /falcon volumes: - name: host-volume hostPath: diff --git a/presets/test/manifests/llama-headless.yaml b/presets/test/manifests/llama-headless.yaml new file mode 100644 index 000000000..e0514564f --- /dev/null +++ b/presets/test/manifests/llama-headless.yaml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: Service +metadata: + name: llama-headless +spec: + selector: + app: llama + clusterIP: None + ports: + - name: torchrun + protocol: TCP + port: 29500 + targetPort: 29500 + publishNotReadyAddresses: true \ No newline at end of file
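
Note on the timeout added to both inference-api.py files: master_inference is wrapped in a decorator that runs the call in a single-thread pool and waits on AsyncResult.get(180.0); if no result arrives in time, multiprocessing.TimeoutError propagates and the endpoint answers with HTTP 408. The sketch below reproduces the same pattern in isolation (names such as slow_task are illustrative only, not part of the repo):

import functools
import multiprocessing
import multiprocessing.pool
import time

def timeout(max_timeout):
    """Raise multiprocessing.TimeoutError if the wrapped call exceeds max_timeout seconds."""
    def timeout_decorator(item):
        @functools.wraps(item)
        def func_wrapper(*args, **kwargs):
            # Run the call in a one-thread pool and bound how long we wait for its result.
            with multiprocessing.pool.ThreadPool(processes=1) as pool:
                async_result = pool.apply_async(item, args, kwargs)
                return async_result.get(max_timeout)
        return func_wrapper
    return timeout_decorator

@timeout(1.0)
def slow_task():
    time.sleep(5)  # deliberately longer than the 1-second budget
    return "done"

try:
    slow_task()
except multiprocessing.TimeoutError:
    print("timed out")  # the API maps this case to HTTP 408

One caveat of the ThreadPool approach: the wrapped call keeps running in its worker thread after the timeout fires; the decorator only stops the caller from waiting, so cleanup still relies on the process-group SIGTERM handling that the servers also add.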