diff --git a/api/v1alpha1/condition_types.go b/api/v1alpha1/condition_types.go index c68c21361..d317dffdf 100644 --- a/api/v1alpha1/condition_types.go +++ b/api/v1alpha1/condition_types.go @@ -22,6 +22,9 @@ const ( // WorkspaceConditionTypeTuningJobStatus is the state when the tuning job starts normally. WorkspaceConditionTypeTuningJobStatus ConditionType = ConditionType("JobStarted") + // RAGEngineConditionTypeDeleting is the RAGEngine state when it starts to get deleted. + RAGEngineConditionTypeDeleting = ConditionType("RAGEngineDeleting") + //WorkspaceConditionTypeDeleting is the Workspace state when starts to get deleted. WorkspaceConditionTypeDeleting = ConditionType("WorkspaceDeleting") diff --git a/pkg/controllers/ragengine_controller.go b/pkg/controllers/ragengine_controller.go index 3c745e86d..9cc873d4c 100644 --- a/pkg/controllers/ragengine_controller.go +++ b/pkg/controllers/ragengine_controller.go @@ -32,6 +32,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/karpenter/pkg/apis/v1beta1" @@ -64,6 +65,15 @@ func (c *RAGEngineReconciler) Reconcile(ctx context.Context, req reconcile.Reque klog.InfoS("Reconciling", "RAG Engine", req.NamespacedName) + if err := c.ensureFinalizer(ctx, ragEngineObj); err != nil { + return reconcile.Result{}, err + } + + // Handle a ragengine that is being deleted: garbage collect all of its resources. 
+ if !ragEngineObj.DeletionTimestamp.IsZero() { + return c.deleteRAGEngine(ctx, ragEngineObj) + } + result, err := c.addRAGEngine(ctx, ragEngineObj) if err != nil { return result, err @@ -72,6 +82,18 @@ func (c *RAGEngineReconciler) Reconcile(ctx context.Context, req reconcile.Reque return result, nil } +func (c *RAGEngineReconciler) ensureFinalizer(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error { + if !controllerutil.ContainsFinalizer(ragEngineObj, consts.RAGEngineFinalizer) { + patch := client.MergeFrom(ragEngineObj.DeepCopy()) + controllerutil.AddFinalizer(ragEngineObj, consts.RAGEngineFinalizer) + if err := c.Client.Patch(ctx, ragEngineObj, patch); err != nil { + klog.ErrorS(err, "failed to ensure the finalizer to the ragengine", "ragengine", klog.KObj(ragEngineObj)) + return err + } + } + return nil +} + func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) { err := c.applyRAGEngineResource(ctx, ragEngineObj) if err != nil { @@ -80,6 +102,17 @@ func (c *RAGEngineReconciler) addRAGEngine(ctx context.Context, ragEngineObj *ka return reconcile.Result{}, nil } +func (c *RAGEngineReconciler) deleteRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (reconcile.Result, error) { + klog.InfoS("deleteRAGEngine", "ragengine", klog.KObj(ragEngineObj)) + err := c.updateStatusConditionIfNotMatch(ctx, ragEngineObj, kaitov1alpha1.RAGEngineConditionTypeDeleting, metav1.ConditionTrue, "ragengineDeleted", "ragengine is being deleted") + if err != nil { + klog.ErrorS(err, "failed to update ragengine status", "ragengine", klog.KObj(ragEngineObj)) + return reconcile.Result{}, err + } + + return c.garbageCollectRAGEngine(ctx, ragEngineObj) +} + // applyRAGEngineResource applies RAGEngine resource spec. 
func (c *RAGEngineReconciler) applyRAGEngineResource(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error { diff --git a/pkg/controllers/ragengine_gc_finalizer.go b/pkg/controllers/ragengine_gc_finalizer.go new file mode 100644 index 000000000..3ccc1b33e --- /dev/null +++ b/pkg/controllers/ragengine_gc_finalizer.go @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package controllers + +import ( + "context" + + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" + "github.com/azure/kaito/pkg/featuregates" + "github.com/azure/kaito/pkg/machine" + "github.com/azure/kaito/pkg/nodeclaim" + "github.com/azure/kaito/pkg/utils/consts" + "k8s.io/klog/v2" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" +) + +// garbageCollectRAGEngine deletes the machines (and, when the Karpenter feature gate is on, the nodeClaims) listed for the ragengine, then clears all finalizers on the object. +func (c *RAGEngineReconciler) garbageCollectRAGEngine(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) (ctrl.Result, error) { + klog.InfoS("garbageCollectRAGEngine", "ragengine", klog.KObj(ragEngineObj)) + + // List the machines associated with this ragengine. + mList, err := machine.ListMachines(ctx, ragEngineObj, c.Client) + if err != nil { + return ctrl.Result{}, err + } + // Delete every listed machine; stop and return the error on the first failed delete. + for i := range mList.Items { + if deleteErr := c.Delete(ctx, &mList.Items[i], &client.DeleteOptions{}); deleteErr != nil { + klog.ErrorS(deleteErr, "failed to delete the machine", "machine", klog.KObj(&mList.Items[i])) + return ctrl.Result{}, deleteErr + } + } + + if featuregates.FeatureGates[consts.FeatureFlagKarpenter] { + // List the nodeClaims associated with this ragengine. 
+ ncList, err := nodeclaim.ListNodeClaim(ctx, ragEngineObj, c.Client) + if err != nil { + return ctrl.Result{}, err + } + + // Delete every listed nodeClaim; stop and return the error on the first failed delete. + for i := range ncList.Items { + if deleteErr := c.Delete(ctx, &ncList.Items[i], &client.DeleteOptions{}); deleteErr != nil { + klog.ErrorS(deleteErr, "failed to delete the nodeClaim", "nodeClaim", klog.KObj(&ncList.Items[i])) + return ctrl.Result{}, deleteErr + } + } + } + + staleWObj := ragEngineObj.DeepCopy() + staleWObj.SetFinalizers(nil) + if updateErr := c.Update(ctx, staleWObj, &client.UpdateOptions{}); updateErr != nil { + klog.ErrorS(updateErr, "failed to remove the finalizer from the ragengine", + "ragengine", klog.KObj(ragEngineObj), "ragengine", klog.KObj(staleWObj)) + return ctrl.Result{}, updateErr + } + klog.InfoS("successfully removed the ragengine finalizers", + "ragengine", klog.KObj(ragEngineObj)) + controllerutil.RemoveFinalizer(ragEngineObj, consts.RAGEngineFinalizer) + return ctrl.Result{}, nil +} diff --git a/pkg/machine/machine.go b/pkg/machine/machine.go index 2843afa52..b996d9a82 100644 --- a/pkg/machine/machine.go +++ b/pkg/machine/machine.go @@ -42,7 +42,7 @@ var ( func GenerateMachineManifest(ctx context.Context, storageRequirement string, obj interface{}) *v1alpha5.Machine { // Determine the type of the input object and extract relevant fields - instanceType, namespace, name, labelSelector, err := resources.ExtractObjFields(obj) + instanceType, namespace, name, labelSelector, nameLabel, namespaceLabel, err := resources.ExtractObjFields(obj) if err != nil { klog.Error(err) return nil @@ -51,9 +51,9 @@ func GenerateMachineManifest(ctx context.Context, storageRequirement string, obj digest := sha256.Sum256([]byte(namespace + name + time.Now().Format("2006-01-02 15:04:05.000000000"))) // We make sure the nodeClaim name is not fixed to the object machineName := "ws" + hex.EncodeToString(digest[0:])[0:9] machineLabels := map[string]string{ - 
LabelProvisionerName: ProvisionerName, - kaitov1alpha1.LabelWorkspaceName: name, - kaitov1alpha1.LabelWorkspaceNamespace: namespace, + LabelProvisionerName: ProvisionerName, + nameLabel: name, + namespaceLabel: namespace, } if labelSelector != nil && len(labelSelector.MatchLabels) != 0 { @@ -146,7 +146,7 @@ func WaitForPendingMachines(ctx context.Context, obj interface{}, kubeClient cli var instanceType string // Determine the type of the input object and retrieve the InstanceType - instanceType, _, _, _, err := resources.ExtractObjFields(obj) + instanceType, _, _, _, _, _, err := resources.ExtractObjFields(obj) if err != nil { return err } diff --git a/pkg/nodeclaim/nodeclaim.go b/pkg/nodeclaim/nodeclaim.go index ddc034ae7..9d0cd60c9 100644 --- a/pkg/nodeclaim/nodeclaim.go +++ b/pkg/nodeclaim/nodeclaim.go @@ -47,7 +47,7 @@ func GenerateNodeClaimManifest(ctx context.Context, storageRequirement string, o klog.InfoS("GenerateNodeClaimManifest", "object", obj) // Determine the type of the input object and extract relevant fields - instanceType, namespace, name, labelSelector, err := resources.ExtractObjFields(obj) + instanceType, namespace, name, labelSelector, nameLabel, namespaceLabel, err := resources.ExtractObjFields(obj) if err != nil { klog.Error(err) return nil @@ -56,9 +56,9 @@ func GenerateNodeClaimManifest(ctx context.Context, storageRequirement string, o nodeClaimName := GenerateNodeClaimName(obj) nodeClaimLabels := map[string]string{ - LabelNodePool: KaitoNodePoolName, // Fake nodepool name to prevent Karpenter from scaling up. - kaitov1alpha1.LabelWorkspaceName: name, - kaitov1alpha1.LabelWorkspaceNamespace: namespace, + LabelNodePool: KaitoNodePoolName, // Fake nodepool name to prevent Karpenter from scaling up. 
+ nameLabel: name, + namespaceLabel: namespace, } if labelSelector != nil && len(labelSelector.MatchLabels) != 0 { nodeClaimLabels = lo.Assign(nodeClaimLabels, labelSelector.MatchLabels) @@ -143,7 +143,7 @@ func GenerateNodeClaimManifest(ctx context.Context, storageRequirement string, o // GenerateNodeClaimName generates a nodeClaim name from the given workspace or RAGEngine. func GenerateNodeClaimName(obj interface{}) string { // Determine the type of the input object and extract relevant fields - _, namespace, name, _, err := resources.ExtractObjFields(obj) + _, namespace, name, _, _, _, err := resources.ExtractObjFields(obj) if err != nil { return "" } @@ -250,7 +250,7 @@ func CreateKarpenterNodeClass(ctx context.Context, kubeClient client.Client) err func WaitForPendingNodeClaims(ctx context.Context, obj interface{}, kubeClient client.Client) error { // Determine the type of the input object and retrieve the InstanceType - instanceType, _, _, _, err := resources.ExtractObjFields(obj) + instanceType, _, _, _, _, _, err := resources.ExtractObjFields(obj) if err != nil { return err } diff --git a/pkg/resources/nodes.go b/pkg/resources/nodes.go index a18dbf327..adfeb0a0f 100644 --- a/pkg/resources/nodes.go +++ b/pkg/resources/nodes.go @@ -90,18 +90,23 @@ func CheckNvidiaPlugin(ctx context.Context, nodeObj *corev1.Node) bool { return false } -func ExtractObjFields(obj interface{}) (instanceType, namespace, name string, labelSelector *metav1.LabelSelector, err error) { +func ExtractObjFields(obj interface{}) (instanceType, namespace, name string, labelSelector *metav1.LabelSelector, + nameLabel, namespaceLabel string, err error) { switch o := obj.(type) { case *kaitov1alpha1.Workspace: instanceType = o.Resource.InstanceType namespace = o.Namespace name = o.Name labelSelector = o.Resource.LabelSelector + nameLabel = kaitov1alpha1.LabelWorkspaceName + namespaceLabel = kaitov1alpha1.LabelWorkspaceNamespace case *kaitov1alpha1.RAGEngine: instanceType = 
o.Spec.Compute.InstanceType namespace = o.Namespace name = o.Name labelSelector = o.Spec.Compute.LabelSelector + nameLabel = kaitov1alpha1.LabelRAGEngineName + namespaceLabel = kaitov1alpha1.LabelRAGEngineNamespace default: err = fmt.Errorf("unsupported object type: %T", obj) } diff --git a/pkg/utils/consts/consts.go b/pkg/utils/consts/consts.go index c498f89bd..908dc789e 100644 --- a/pkg/utils/consts/consts.go +++ b/pkg/utils/consts/consts.go @@ -5,7 +5,9 @@ package consts const ( // WorkspaceFinalizer is used to make sure that workspace controller handles garbage collection. - WorkspaceFinalizer = "workspace.finalizer.kaito.sh" + WorkspaceFinalizer = "workspace.finalizer.kaito.sh" + // RAGEngineFinalizer is used to make sure that ragengine controller handles garbage collection. + RAGEngineFinalizer = "ragengine.finalizer.kaito.sh" DefaultReleaseNamespaceEnvVar = "RELEASE_NAMESPACE" FeatureFlagKarpenter = "Karpenter" AzureCloudName = "azure"