diff --git a/pkg/inference/preset-llama2-inferences.go b/pkg/inference/preset-llama2-inferences.go index bde80855a..905634076 100644 --- a/pkg/inference/preset-llama2-inferences.go +++ b/pkg/inference/preset-llama2-inferences.go @@ -16,6 +16,9 @@ import ( ) const ( + Preset2ATimeout = 10 + Preset2BTimeout = 20 + Preset2CTimeout = 30 RegistryName = "aimodelsregistry.azurecr.io" PresetSetModelllama2AChatImage = RegistryName + "/llama-2-7b-chat:latest" PresetSetModelllama2BChatImage = RegistryName + "/llama-2-13b-chat:latest" @@ -98,10 +101,7 @@ func CreateLLAMA2APresetModel(ctx context.Context, workspaceObj *kdmv1alpha1.Wor return err } - ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - - if err := checkResourceStatus(ctxWithTimeout, depObj, kubeClient); err != nil { + if err := checkResourceStatus(depObj, kubeClient, Preset2ATimeout); err != nil { return err } return nil @@ -136,10 +136,7 @@ func CreateLLAMA2BPresetModel(ctx context.Context, workspaceObj *kdmv1alpha1.Wor return err } - ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 20*time.Minute) - defer cancel() - - if err := checkResourceStatus(ctxWithTimeout, depObj, kubeClient); err != nil { + if err := checkResourceStatus(depObj, kubeClient, Preset2BTimeout); err != nil { return err } return nil @@ -175,27 +172,26 @@ func CreateLLAMA2CPresetModel(ctx context.Context, workspaceObj *kdmv1alpha1.Wor return err } - ctxWithTimeout, cancel := context.WithTimeout(context.Background(), 30*time.Minute) - defer cancel() - - if err := checkResourceStatus(ctxWithTimeout, depObj, kubeClient); err != nil { + if err := checkResourceStatus(depObj, kubeClient, Preset2CTimeout); err != nil { return err } return nil } -func checkResourceStatus(ctx context.Context, obj client.Object, kubeClient client.Client) error { +func checkResourceStatus(obj client.Object, kubeClient client.Client, timeoutDuration int) error { klog.InfoS("checkResourceStatus", "resource", obj.GetName()) + + // Use Context for timeout + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutDuration)*time.Minute) + defer cancel() + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() - // Use context for timeout - timeoutChan := ctx.Done() - for { select { - case <-timeoutChan: - return fmt.Errorf("check resource status timed out. resource %s is not ready", obj.GetName()) + case <-ctx.Done(): + return ctx.Err() case <-ticker.C: key := client.ObjectKey{