fix: inference fault tolerance #108

Merged · 23 commits · Nov 2, 2023
b3eec53
fix: fix networking issue inference
ishaansehgal99 Oct 26, 2023
c1d3a18
nit: remove threading
ishaansehgal99 Oct 26, 2023
32586dd
fix: ensure child process
ishaansehgal99 Oct 27, 2023
933efc6
fix: upgrade nvidia pytorch
ishaansehgal99 Oct 30, 2023
86101a8
fix: lint
ishaansehgal99 Oct 30, 2023
b691d9b
fix: naming
ishaansehgal99 Oct 30, 2023
cbbebbd
fix: diff
ishaansehgal99 Oct 30, 2023
5248737
fix: fetch
ishaansehgal99 Oct 30, 2023
1690ccc
fix: log
ishaansehgal99 Oct 30, 2023
2ee601c
feat: add the headless service, add the resiliency to ensure cleanup o…
ishaansehgal99 Oct 31, 2023
224416a
fix: timeout error handling
ishaansehgal99 Nov 1, 2023
a91bded
fix: remove comments
ishaansehgal99 Nov 1, 2023
fda874a
feat: added torchrdzvparams, headless service
ishaansehgal99 Nov 1, 2023
cdf10e7
fix: simplify timeout
ishaansehgal99 Nov 1, 2023
a12f16c
fix: headless service variable fixes
ishaansehgal99 Nov 1, 2023
2f2821d
fix: shutdown
ishaansehgal99 Nov 1, 2023
86ff9e0
fix: dockerfile
ishaansehgal99 Nov 1, 2023
ca788a4
fix: update docker file paths to avoid conflicting volume mounts
ishaansehgal99 Nov 1, 2023
2213d85
fix: fix service naming, and add service namespace and ownerreference…
ishaansehgal99 Nov 2, 2023
8a9b4ae
fix: remove logs
ishaansehgal99 Nov 2, 2023
c52b40c
fix
ishaansehgal99 Nov 2, 2023
ac0e880
Merge branch 'main' of https://github.com/Azure/kdm into Ishaan/fix-w…
ishaansehgal99 Nov 2, 2023
fd3f9db
fix: rename create to useHeadlessService
ishaansehgal99 Nov 2, 2023
17 changes: 11 additions & 6 deletions .github/workflows/preset-image-build.yml
@@ -1,5 +1,4 @@
name: Build and Push Preset Models

on:
pull_request:
branches:
@@ -17,7 +16,7 @@ on:
- 'presets/llama-2-chat/**'
workflow_dispatch:
inputs:
image_tag:
image_tag_name:
description: 'Image Tag'
required: true

@@ -48,7 +47,13 @@ jobs:

- name: Get Modified files
run: |
files=$(git diff --name-only HEAD^ HEAD)
git fetch origin main:main
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "main" ] || [ "$current_branch" == "master" ]; then
files=$(git diff --name-only HEAD^ HEAD)
else
files=$(git diff --name-only main...HEAD)
fi
echo "Modified files: $files"
FILES_MODIFIED=""
while IFS= read -r file; do
@@ -93,9 +98,9 @@ jobs:
- name: Set Image Tag
id: set_tag
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag }}" ]]; then
if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag_name }}" ]]; then
echo "Using workflow dispatch to set image tag"
echo "image_tag=${{ github.event.inputs.image_tag }}" >> $GITHUB_OUTPUT
echo "image_tag=${{ github.event.inputs.image_tag_name }}" >> $GITHUB_OUTPUT
else
echo "Setting image tag based on version set"
echo "image_tag=${{ env.VERSION }}" >> $GITHUB_OUTPUT
@@ -175,7 +180,7 @@ jobs:
- name: 'Attach and Login to ACR'
id: acr_login
run: |
if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag }}" ]]; then
if [[ "${{ github.event_name }}" == "workflow_dispatch" && -n "${{ github.event.inputs.image_tag_name }}" ]]; then
ACR_NAME="aimodelsregistry"
else
ACR_NAME="aimodelsregistrytest"
4 changes: 2 additions & 2 deletions docker/presets/falcon/Dockerfile
@@ -1,5 +1,5 @@
# Use the NVIDIA PyTorch image as a base
FROM nvcr.io/nvidia/pytorch:23.06-py3
FROM nvcr.io/nvidia/pytorch:23.10-py3

# Set the working directory
WORKDIR /workspace/falcon
Expand All @@ -13,6 +13,6 @@ RUN pip install --no-cache-dir -r requirements.txt
ARG FALCON_MODEL_NAME

# Copy the entire model to the weights directory
COPY /home/falcon/${FALCON_MODEL_NAME} /workspace/falcon/weights
COPY /falcon/${FALCON_MODEL_NAME} /workspace/falcon/weights
# Copy the entire 'presets/falcon' folder to the working directory
COPY /home/presets/falcon /workspace/falcon
4 changes: 2 additions & 2 deletions docker/presets/llama-2/Dockerfile
@@ -10,7 +10,7 @@
# --build-arg SRC_DIR=/home/presets/llama-2-chat \
# -t llama-2-7b-chat:latest .

FROM nvcr.io/nvidia/pytorch:23.06-py3
FROM nvcr.io/nvidia/pytorch:23.10-py3
WORKDIR /workspace

RUN git clone https://github.com/facebookresearch/llama
Expand All @@ -24,5 +24,5 @@ RUN pip install 'uvicorn[standard]'
ARG LLAMA_VERSION
ARG SRC_DIR

ADD /home/llama/${LLAMA_VERSION} /workspace/llama/llama-2/weights
ADD /llama/${LLAMA_VERSION} /workspace/llama/llama-2/weights
ADD ${SRC_DIR} /workspace/llama/llama-2
18 changes: 13 additions & 5 deletions pkg/controllers/workspace_controller.go
@@ -93,7 +93,8 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka
return reconcile.Result{}, err
}

if err := c.ensureService(ctx, wObj); err != nil {
useHeadlessService := *(wObj.Resource.Count) > 1 && strings.Contains(string(wObj.Inference.Preset.Name), "llama")
if err := c.ensureService(ctx, wObj, useHeadlessService); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
"workspaceFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
Expand All @@ -102,7 +103,7 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka
return reconcile.Result{}, err
}

if err = c.applyInference(ctx, wObj); err != nil {
if err = c.applyInference(ctx, wObj, useHeadlessService); err != nil {
if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
"workspaceFailed", err.Error()); updateErr != nil {
klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
@@ -378,7 +379,7 @@ func (c *WorkspaceReconciler) ensureNodePlugins(ctx context.Context, wObj *kaito
}
}

func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error {
serviceType := corev1.ServiceTypeClusterIP
wAnnotation := wObj.GetAnnotations()

@@ -407,6 +408,13 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al
if err != nil {
return err
}
if useHeadlessService {
headlessService := resources.GenerateHeadlessServiceManifest(ctx, wObj)
err = resources.CreateResource(ctx, headlessService, c.Client)
if err != nil {
return err
}
}
return nil
}

@@ -447,7 +455,7 @@ func (c *WorkspaceReconciler) getInferenceObjFromPreset(ctx context.Context, wOb
}

// applyInference applies inference spec.
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace, useHeadlessService bool) error {

inferErr := func() error {
if wObj.Inference.Template != nil {
@@ -482,7 +490,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
}
} else if apierrors.IsNotFound(err) {
// Need to create a new workload
workloadObj, err := inference.CreatePresetInference(ctx, wObj, inferenceObj, c.Client)
workloadObj, err := inference.CreatePresetInference(ctx, wObj, inferenceObj, useHeadlessService, c.Client)
if err != nil {
return err
}
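For readers of this hunk: the new useHeadlessService flag is only true for distributed llama workspaces, since a single replica has no cross-pod rendezvous to coordinate and the Falcon preset runs as a plain Deployment. A minimal sketch of that predicate, assuming the field values shown in the diff (not a verbatim copy of the controller):

```go
package main

import (
	"fmt"
	"strings"
)

// needsHeadlessService mirrors the gating condition in addOrUpdateWorkspace:
// only multi-node llama presets coordinate workers over torchrun rendezvous,
// so only they need the stable per-pod DNS records a headless Service provides.
func needsHeadlessService(nodeCount int, presetName string) bool {
	return nodeCount > 1 && strings.Contains(presetName, "llama")
}

func main() {
	fmt.Println(needsHeadlessService(2, "llama-2-13b-chat")) // true
	fmt.Println(needsHeadlessService(1, "llama-2-7b-chat"))  // false
	fmt.Println(needsHeadlessService(2, "falcon-7b"))        // false
}
```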
19 changes: 19 additions & 0 deletions pkg/inference/preset-inference-types.go
@@ -19,6 +14,14 @@ const (
DefaultMasterPort = "29500"
)

// Torch Rendezvous Params
const (
DefaultMaxRestarts = "3"
DefaultRdzvId = "rdzv_id"
DefaultRdzvBackend = "c10d" // Pytorch Native Distributed data store
DefaultRdzvEndpoint = "localhost:29500" // llama-2-13b-chat-0.llama-headless.default.svc.cluster.local:29500
)

const (
DefaultConfigFile = "config.yaml"
DefaultNumProcesses = "1"
@@ -56,6 +64,13 @@ var (
"master_port": DefaultMasterPort,
}

defaultTorchRunRdzvParams = map[string]string{
"max_restarts": DefaultMaxRestarts,
"rdzv_id": DefaultRdzvId,
"rdzv_backend": DefaultRdzvBackend,
"rdzv_endpoint": DefaultRdzvEndpoint,
}

defaultAccelerateParams = map[string]string{
"config_file": DefaultConfigFile,
"num_processes": DefaultNumProcesses,
Expand All @@ -78,6 +93,7 @@ type PresetInferenceParam struct {
GPURequirement string
GPUMemoryRequirement string
TorchRunParams map[string]string
TorchRunRdzvParams map[string]string
ModelRunParams map[string]string
InferenceFile string
// DeploymentTimeout defines the maximum duration for pulling the Preset image.
Expand All @@ -104,6 +120,7 @@ var (
GPURequirement: "1",
GPUMemoryRequirement: "16Gi",
TorchRunParams: defaultTorchRunParams,
TorchRunRdzvParams: defaultTorchRunRdzvParams,
ModelRunParams: llamaRunParams,
InferenceFile: llamaChatInferenceFile,
DeploymentTimeout: time.Duration(10) * time.Minute,
Expand All @@ -120,6 +137,7 @@ var (
GPURequirement: "2",
GPUMemoryRequirement: "16Gi",
TorchRunParams: defaultTorchRunParams,
TorchRunRdzvParams: defaultTorchRunRdzvParams,
ModelRunParams: llamaRunParams,
InferenceFile: llamaChatInferenceFile,
DeploymentTimeout: time.Duration(20) * time.Minute,
Expand All @@ -136,6 +154,7 @@ var (
GPURequirement: "8",
GPUMemoryRequirement: "19Gi",
TorchRunParams: defaultTorchRunParams,
TorchRunRdzvParams: defaultTorchRunRdzvParams,
ModelRunParams: llamaRunParams,
InferenceFile: llamaChatInferenceFile,
DeploymentTimeout: time.Duration(30) * time.Minute,
12 changes: 10 additions & 2 deletions pkg/inference/preset-inferences.go
@@ -79,6 +79,13 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1
inferenceObj.TorchRunParams["master_addr"] = existingService.Spec.ClusterIP
inferenceObj.TorchRunParams["master_port"] = "29500"
}
if inferenceObj.TorchRunRdzvParams != nil {
inferenceObj.TorchRunRdzvParams["max_restarts"] = "3"
inferenceObj.TorchRunRdzvParams["rdzv_id"] = "job"
inferenceObj.TorchRunRdzvParams["rdzv_backend"] = "c10d"
inferenceObj.TorchRunRdzvParams["rdzv_endpoint"] =
fmt.Sprintf("%s-0.%s-headless.default.svc.cluster.local:29500", wObj.Name, wObj.Name)
}
} else if inferenceObj.ModelName == "Falcon" {
inferenceObj.TorchRunParams["config_file"] = "config.yaml"
inferenceObj.TorchRunParams["num_processes"] = "1"
Expand All @@ -90,7 +97,7 @@ func setTorchParams(ctx context.Context, kubeClient client.Client, wObj *kaitov1
}

func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
inferenceObj PresetInferenceParam, kubeClient client.Client) (client.Object, error) {
inferenceObj PresetInferenceParam, useHeadlessService bool, kubeClient client.Client) (client.Object, error) {
if inferenceObj.TorchRunParams != nil {
if err := setTorchParams(ctx, kubeClient, workspaceObj, inferenceObj); err != nil {
klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj)
Expand All @@ -105,7 +112,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
switch inferenceObj.ModelName {
case "LLaMa2":
depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount, useHeadlessService)
case "Falcon":
depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, inferenceObj.Image, inferenceObj.ImagePullSecrets, *workspaceObj.Resource.Count, commands,
containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
Expand All @@ -121,6 +128,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work

func prepareInferenceParameters(ctx context.Context, inferenceObj PresetInferenceParam) ([]string, corev1.ResourceRequirements) {
torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams)
modelCommand := buildCommandStr(inferenceObj.InferenceFile, inferenceObj.ModelRunParams)
commands := shellCommand(torchCommand + " " + modelCommand)

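Taken together, setTorchParams fills the rendezvous map and prepareInferenceParameters folds it into the launch command ahead of the model arguments. A rough sketch of that flattening, using a hypothetical helper rather than the repo's buildCommandStr (real flag order may differ):

```go
package main

import (
	"fmt"
	"sort"
)

// flagsFromParams is a hypothetical stand-in for buildCommandStr: it renders
// each key/value pair as "--key=value". Keys are sorted here only so the
// example output is deterministic.
func flagsFromParams(cmd string, params map[string]string) string {
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		cmd = fmt.Sprintf("%s --%s=%s", cmd, k, params[k])
	}
	return cmd
}

func main() {
	rdzv := map[string]string{
		"max_restarts":  "3",
		"rdzv_id":       "job",
		"rdzv_backend":  "c10d",
		"rdzv_endpoint": "llama-2-13b-chat-0.llama-2-13b-chat-headless.default.svc.cluster.local:29500",
	}
	fmt.Println(flagsFromParams("torchrun", rdzv))
	// torchrun --max_restarts=3 --rdzv_backend=c10d --rdzv_endpoint=llama-2-13b-chat-0.llama-2-13b-chat-headless.default.svc.cluster.local:29500 --rdzv_id=job
}
```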
44 changes: 42 additions & 2 deletions pkg/resources/manifests.go
@@ -17,6 +17,42 @@ import (

var controller = true

func GenerateHeadlessServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) *corev1.Service {
serviceName := fmt.Sprintf("%s-headless", workspaceObj.Name)
selector := make(map[string]string)
for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels {
selector[k] = v
}
return &corev1.Service{
ObjectMeta: v1.ObjectMeta{
Name: serviceName,
Namespace: workspaceObj.Namespace,
OwnerReferences: []v1.OwnerReference{
{
APIVersion: kaitov1alpha1.GroupVersion.String(),
Kind: "Workspace",
UID: workspaceObj.UID,
Name: workspaceObj.Name,
Controller: &controller,
},
},
},
Spec: corev1.ServiceSpec{
Selector: selector,
ClusterIP: "None",
Ports: []corev1.ServicePort{
{
Name: "torchrun",
Protocol: corev1.ProtocolTCP,
Port: 29500,
TargetPort: intstr.FromInt(29500),
},
},
PublishNotReadyAddresses: true,
},
}
}

func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, serviceType corev1.ServiceType, isStatefulSet bool) *corev1.Service {
selector := make(map[string]string)
for k, v := range workspaceObj.Resource.LabelSelector.MatchLabels {
@@ -71,7 +107,7 @@ func GenerateServiceManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Wo
func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string,
imagePullSecretRefs []corev1.LocalObjectReference, replicas int, commands []string, containerPorts []corev1.ContainerPort,
livenessProbe, readinessProbe *corev1.Probe, resourceRequirements corev1.ResourceRequirements,
tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount) *appsv1.StatefulSet {
tolerations []corev1.Toleration, volumes []corev1.Volume, volumeMount []corev1.VolumeMount, useHeadlessService bool) *appsv1.StatefulSet {

// Gather label requirements from workspaceObj's label selector
labelRequirements := make([]v1.LabelSelectorRequirement, 0, len(workspaceObj.Resource.LabelSelector.MatchLabels))
Expand All @@ -82,7 +118,7 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha
Values: []string{value},
})
}
return &appsv1.StatefulSet{
ss := &appsv1.StatefulSet{
ObjectMeta: v1.ObjectMeta{
Name: workspaceObj.Name,
Namespace: workspaceObj.Namespace,
@@ -136,6 +172,10 @@ func GenerateStatefulSetManifest(ctx context.Context, workspaceObj *kaitov1alpha
},
},
}
if useHeadlessService {
ss.Spec.ServiceName = fmt.Sprintf("%s-headless", workspaceObj.Name)
}
return ss
}

func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, imageName string,
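Why Spec.ServiceName matters here: a StatefulSet only gives its pods stable per-pod DNS records when paired with a governing headless Service, and PublishNotReadyAddresses keeps those records resolvable while workers are still starting, which torchrun rendezvous relies on. A small sketch of the resulting name, assuming the naming convention in the diff (setTorchParams hard-codes the default namespace):

```go
package main

import "fmt"

// headlessPodFQDN illustrates the per-pod DNS name produced once the
// StatefulSet's Spec.ServiceName points at the "<workspace>-headless" Service:
// <statefulset>-<ordinal>.<service>.<namespace>.svc.cluster.local.
// Pod 0 of this form (plus port 29500) is what setTorchParams uses as the
// torchrun rdzv_endpoint.
func headlessPodFQDN(workspaceName, namespace string, ordinal int) string {
	return fmt.Sprintf("%s-%d.%s-headless.%s.svc.cluster.local",
		workspaceName, ordinal, workspaceName, namespace)
}

func main() {
	fmt.Println(headlessPodFQDN("llama-2-13b-chat", "default", 0) + ":29500")
	// llama-2-13b-chat-0.llama-2-13b-chat-headless.default.svc.cluster.local:29500
}
```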