Skip to content

Commit

Permalink
refactor workload controller to prepare to upstream to Kueue
Browse files Browse the repository at this point in the history
  • Loading branch information
dgrove-oss committed Jan 2, 2025
1 parent b5ffafb commit 157596a
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 57 deletions.
63 changes: 6 additions & 57 deletions internal/controller/workload/workload_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ limitations under the License.
package workload

import (
"fmt"

"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
"k8s.io/apimachinery/pkg/runtime/schema"

"sigs.k8s.io/controller-runtime/pkg/builder"
Expand Down Expand Up @@ -75,72 +72,24 @@ func (aw *AppWrapper) GVK() schema.GroupVersionKind {
}

func (aw *AppWrapper) PodSets() []kueue.PodSet {
podSets := []kueue.PodSet{}
if err := utils.EnsureComponentStatusInitialized((*workloadv1beta2.AppWrapper)(aw)); err != nil {
// Kueue will raise an error on zero length PodSet. Unfortunately, the Kueue API prevents propagating the actual error
return podSets
}
for idx := range aw.Status.ComponentStatus {
if len(aw.Status.ComponentStatus[idx].PodSets) > 0 {
obj := &unstructured.Unstructured{}
if _, _, err := unstructured.UnstructuredJSONScheme.Decode(aw.Spec.Components[idx].Template.Raw, nil, obj); err != nil {
// Should be unreachable; Template.Raw validated by AppWrapper AdmissionController
return []kueue.PodSet{} // Kueue will raise an error on zero length PodSet.
}
for psIdx, podSet := range aw.Status.ComponentStatus[idx].PodSets {
replicas := utils.Replicas(podSet)
if template, err := utils.GetPodTemplateSpec(obj, podSet.Path); err == nil {
podSets = append(podSets, kueue.PodSet{
Name: fmt.Sprintf("%s-%v-%v", aw.Name, idx, psIdx),
Template: *template,
Count: replicas,
})
}
}
}
podSets, err := utils.GetPodSets((*workloadv1beta2.AppWrapper)(aw))
if err != nil {
// Kueue will raise an error on zero length PodSet; the Kueue GenericJob API prevents propagating the actual error.
return []kueue.PodSet{}
}
return podSets
}

// RunWithPodSetsInfo records the assigned PodSetInfos for each component and sets aw.spec.Suspend to false
func (aw *AppWrapper) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
if err := utils.EnsureComponentStatusInitialized((*workloadv1beta2.AppWrapper)(aw)); err != nil {
if err := utils.SetPodSetInfos((*workloadv1beta2.AppWrapper)(aw), podSetsInfo); err != nil {
return err
}
podSetsInfoIndex := 0
for idx := range aw.Spec.Components {
if len(aw.Spec.Components[idx].PodSetInfos) != len(aw.Status.ComponentStatus[idx].PodSets) {
aw.Spec.Components[idx].PodSetInfos = make([]workloadv1beta2.AppWrapperPodSetInfo, len(aw.Status.ComponentStatus[idx].PodSets))
}
for podSetIdx := range aw.Status.ComponentStatus[idx].PodSets {
podSetsInfoIndex += 1
if podSetsInfoIndex > len(podSetsInfo) {
continue // we will return an error below...continuing to get an accurate count for the error message
}
aw.Spec.Components[idx].PodSetInfos[podSetIdx] = workloadv1beta2.AppWrapperPodSetInfo{
Annotations: podSetsInfo[podSetsInfoIndex-1].Annotations,
Labels: podSetsInfo[podSetsInfoIndex-1].Labels,
NodeSelector: podSetsInfo[podSetsInfoIndex-1].NodeSelector,
Tolerations: podSetsInfo[podSetsInfoIndex-1].Tolerations,
}
}
}

if podSetsInfoIndex != len(podSetsInfo) {
return podset.BadPodSetsInfoLenError(podSetsInfoIndex, len(podSetsInfo))
}

aw.Spec.Suspend = false

return nil
}

// RestorePodSetsInfo clears the PodSetInfos saved by RunWithPodSetsInfo
func (aw *AppWrapper) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
for idx := range aw.Spec.Components {
aw.Spec.Components[idx].PodSetInfos = nil
}
return true
return utils.ClearPodSetInfos((*workloadv1beta2.AppWrapper)(aw))
}

func (aw *AppWrapper) Finished() (message string, success, finished bool) {
Expand Down
69 changes: 69 additions & 0 deletions pkg/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import (
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/utils/ptr"

kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
"sigs.k8s.io/kueue/pkg/podset"

workloadv1beta2 "github.com/project-codeflare/appwrapper/api/v1beta2"
)

Expand Down Expand Up @@ -327,6 +330,72 @@ func EnsureComponentStatusInitialized(aw *workloadv1beta2.AppWrapper) error {
return nil
}

// GetPodSets builds the kueue.PodSets that describe an AppWrapper.
//
// It first calls EnsureComponentStatusInitialized and returns that error, if any.
// Then, for every entry of aw.Status.ComponentStatus that declares at least one
// PodSet, it decodes the matching aw.Spec.Components template and emits one
// kueue.PodSet per declared PodSet, named "<aw.Name>-<component index>-<podset index>".
// A decode failure aborts with a nil slice and the decode error. A PodSet whose
// template GetPodTemplateSpec cannot resolve is left out of the result rather
// than reported as an error.
func GetPodSets(aw *workloadv1beta2.AppWrapper) ([]kueue.PodSet, error) {
	if err := EnsureComponentStatusInitialized(aw); err != nil {
		return nil, err
	}

	result := []kueue.PodSet{}
	for compIdx := range aw.Status.ComponentStatus {
		declared := aw.Status.ComponentStatus[compIdx].PodSets
		if len(declared) == 0 {
			// Nothing to contribute for this component; its template is not decoded.
			continue
		}

		decoded := &unstructured.Unstructured{}
		_, _, err := unstructured.UnstructuredJSONScheme.Decode(aw.Spec.Components[compIdx].Template.Raw, nil, decoded)
		if err != nil {
			// Should be unreachable; Template.Raw validated by AppWrapper AdmissionController
			return nil, err
		}

		for setIdx, declaredSet := range declared {
			count := Replicas(declaredSet)
			tmpl, tmplErr := GetPodTemplateSpec(decoded, declaredSet.Path)
			if tmplErr != nil {
				// Unresolvable path: this PodSet is skipped, not surfaced as an error.
				continue
			}
			result = append(result, kueue.PodSet{
				Name:     fmt.Sprintf("%s-%v-%v", aw.Name, compIdx, setIdx),
				Template: *tmpl,
				Count:    count,
			})
		}
	}
	return result, nil
}

// SetPodSetInfos copies the entries of podSetsInfo, in order, into the
// PodSetInfos of aw.Spec.Components.
//
// Components are visited in order; each one consumes as many entries of
// podSetsInfo as its aw.Status.ComponentStatus entry has PodSets. A component's
// PodSetInfos slice is reallocated when its length does not already match that
// PodSet count. Only Annotations, Labels, NodeSelector and Tolerations are copied.
//
// It returns the error from EnsureComponentStatusInitialized, or
// podset.BadPodSetsInfoLenError when the total number of PodSets differs from
// len(podSetsInfo). In the latter case, entries assigned before the mismatch
// was detected remain written to aw.
func SetPodSetInfos(aw *workloadv1beta2.AppWrapper, podSetsInfo []podset.PodSetInfo) error {
	if err := EnsureComponentStatusInitialized(aw); err != nil {
		return err
	}

	supplied := len(podSetsInfo)
	expected := 0 // running count of PodSets seen across all components
	for compIdx := range aw.Spec.Components {
		slots := len(aw.Status.ComponentStatus[compIdx].PodSets)
		if len(aw.Spec.Components[compIdx].PodSetInfos) != slots {
			aw.Spec.Components[compIdx].PodSetInfos = make([]workloadv1beta2.AppWrapperPodSetInfo, slots)
		}
		for slot := 0; slot < slots; slot++ {
			cursor := expected
			expected++
			if cursor >= supplied {
				// Ran out of supplied infos; keep counting so the error below reports the true total.
				continue
			}
			info := &podSetsInfo[cursor]
			aw.Spec.Components[compIdx].PodSetInfos[slot] = workloadv1beta2.AppWrapperPodSetInfo{
				Annotations:  info.Annotations,
				Labels:       info.Labels,
				NodeSelector: info.NodeSelector,
				Tolerations:  info.Tolerations,
			}
		}
	}

	if expected == supplied {
		return nil
	}
	return podset.BadPodSetsInfoLenError(expected, supplied)
}

// ClearPodSetInfos resets the PodSetInfos of every component in aw.Spec.Components
// to nil, undoing what SetPodSetInfos stored. It always returns true.
func ClearPodSetInfos(aw *workloadv1beta2.AppWrapper) bool {
	components := aw.Spec.Components
	for i := 0; i < len(components); i++ {
		components[i].PodSetInfos = nil
	}
	return true
}

// inferReplicas parses the value at the given path within obj as an int or return 1 or error
func inferReplicas(obj map[string]interface{}, path string) (int32, error) {
if path == "" {
Expand Down

0 comments on commit 157596a

Please sign in to comment.