diff --git a/guest/api/types.go b/guest/api/types.go index 1e5579b1..72e59fa1 100644 --- a/guest/api/types.go +++ b/guest/api/types.go @@ -67,6 +67,22 @@ type FilterPlugin interface { Filter(state CycleState, pod proto.Pod, nodeInfo NodeInfo) *Status } +// PostFilterPlugin is a WebAssembly implementation of framework.PostFilterPlugin. +type PostFilterPlugin interface { + Plugin + + PostFilter(state CycleState, pod proto.Pod, filteredNodeStatusMap NodeToStatusMap) (nominatedNodeName string, nominatingMode NominatingMode, status *Status) +} + +// NominatingMode is the Mode which is returned from PostFilter. +type NominatingMode int32 + +// These are predefined modes +const ( + ModeNoop NominatingMode = iota + ModeOverride +) + // EnqueueExtensions is a WebAssembly implementation of framework.EnqueueExtensions. type EnqueueExtensions interface { EventsToRegister() []ClusterEvent @@ -97,3 +113,8 @@ type NodeInfo interface { Node() proto.Node } + +// PostFilterPlugin use nodeToStatusMap +type NodeToStatusMap interface { + Map() map[string]StatusCode +} diff --git a/guest/internal/imports/host.go b/guest/internal/imports/host.go index 4ac7794f..584ccf30 100644 --- a/guest/internal/imports/host.go +++ b/guest/internal/imports/host.go @@ -17,6 +17,7 @@ package imports import ( + "encoding/json" "runtime" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/api" @@ -66,3 +67,17 @@ func Pod(updater func([]byte) error) error { return k8sApiPod(ptr, limit) }, updater) } + +func NodeToStatusMap() map[string]api.StatusCode { + // Wrap to avoid TinyGo 0.28: cannot use an exported function as value + jsonStr := mem.GetString(func(ptr uint32, limit mem.BufLimit) (len uint32) { + return k8sSchedulerNodeToStatusMap(ptr, limit) + }) + byte := []byte(jsonStr) + var nodeToMap map[string]api.StatusCode + err := json.Unmarshal(byte, &nodeToMap) + if err != nil { + panic(err) + } + return nodeToMap +} diff --git a/guest/internal/imports/imports.go b/guest/internal/imports/imports.go index 
797ed2f1..bbcf2079 100644 --- a/guest/internal/imports/imports.go +++ b/guest/internal/imports/imports.go @@ -29,5 +29,8 @@ func k8sApiNodeName(ptr uint32, limit mem.BufLimit) (len uint32) //go:wasmimport k8s.io/api pod func k8sApiPod(ptr uint32, limit mem.BufLimit) (len uint32) +//go:wasmimport k8s.io/scheduler nodeToStatusMap +func k8sSchedulerNodeToStatusMap(ptr uint32, limit mem.BufLimit) (len uint32) + //go:wasmimport k8s.io/scheduler result.status_reason func k8sSchedulerResultStatusReason(ptr, size uint32) diff --git a/guest/internal/imports/imports_stub.go b/guest/internal/imports/imports_stub.go index 63ef99b4..22a03b57 100644 --- a/guest/internal/imports/imports_stub.go +++ b/guest/internal/imports/imports_stub.go @@ -29,5 +29,8 @@ func k8sApiNodeName(uint32, mem.BufLimit) (len uint32) { return } // k8sApiPod is stubbed for compilation outside TinyGo. func k8sApiPod(uint32, mem.BufLimit) (len uint32) { return } +// k8sSchedulerNodeToStatusMap is stubbed for compilation outside TinyGo. +func k8sSchedulerNodeToStatusMap(uint32, mem.BufLimit) (len uint32) { return } + // k8sSchedulerResultStatusReason is stubbed for compilation outside TinyGo. 
func k8sSchedulerResultStatusReason(uint32, uint32) {} diff --git a/guest/plugin/plugin.go b/guest/plugin/plugin.go index 15c668b7..7d8b9af0 100644 --- a/guest/plugin/plugin.go +++ b/guest/plugin/plugin.go @@ -21,6 +21,7 @@ import ( "sigs.k8s.io/kube-scheduler-wasm-extension/guest/enqueue" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/filter" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/prefilter" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/postfilter" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/prescore" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/score" ) @@ -51,6 +52,9 @@ func Set(plugin api.Plugin) { if plugin, ok := plugin.(api.FilterPlugin); ok { filter.SetPlugin(plugin) } + if plugin, ok := plugin.(api.PostFilterPlugin); ok { + postfilter.SetPlugin(plugin) + } if plugin, ok := plugin.(api.PreScorePlugin); ok { prescore.SetPlugin(plugin) } diff --git a/guest/postfilter/imports.go b/guest/postfilter/imports.go new file mode 100644 index 00000000..7fa755d1 --- /dev/null +++ b/guest/postfilter/imports.go @@ -0,0 +1,22 @@ +//go:build tinygo.wasm + +/* + Copyright 2023 The Kubernetes Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package postfilter + +//go:wasmimport k8s.io/scheduler result.nominated_node_name +func setNominatedNodeNameResult(ptr, size uint32) diff --git a/guest/postfilter/imports_stub.go b/guest/postfilter/imports_stub.go new file mode 100644 index 00000000..83e49d80 --- /dev/null +++ b/guest/postfilter/imports_stub.go @@ -0,0 +1,22 @@ +//go:build !tinygo.wasm + +/* + Copyright 2023 The Kubernetes Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package postfilter + +// setNominatedNodeNameResult is stubbed for compilation outside TinyGo. +func setNominatedNodeNameResult(uint32, uint32) {} diff --git a/guest/postfilter/postfilter.go b/guest/postfilter/postfilter.go new file mode 100644 index 00000000..200d9f44 --- /dev/null +++ b/guest/postfilter/postfilter.go @@ -0,0 +1,100 @@ +/* + Copyright 2023 The Kubernetes Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Package postfilter exports an api.PostFilterPlugin to the host. 
+package postfilter + +import ( + "runtime" + + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/api" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/cyclestate" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/imports" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/mem" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/plugin" +) + +// postfilter is the current plugin assigned with SetPlugin. +var postfilter api.PostFilterPlugin + +// SetPlugin should be called in `main` to assign an api.PostFilterPlugin +// instance. +// +// For example: +// +// func main() { +// plugin := filterPlugin{} +// postfilter.SetPlugin(plugin) +// filter.SetPlugin(plugin) +// } +// +// type filterPlugin struct{} +// +// func (filterPlugin) PostFilter(state api.CycleState, pod proto.Pod, filteredNodeStatusMap api.NodeToStatusMap) (int32, status *api.Status) { +// // Write state you need on Filter +// } +// +// func (filterPlugin) Filter(state api.CycleState, pod api.Pod, nodeInfo api.NodeInfo) (status *api.Status) { +// var Filter int32 +// // Derive Filter for the node name using state set on PreFilter! +// return Filter, nil +// } +func SetPlugin(postfilterPlugin api.PostFilterPlugin) { + if postfilterPlugin == nil { + panic("nil postfilterPlugin") + } + postfilter = postfilterPlugin + plugin.MustSet(postfilterPlugin) +} + +// prevent unused lint errors (lint is run with normal go). +var _ func() uint64 = _postfilter + +// _postfilter is only exported to the host. +// +//export postfilter +func _postfilter() uint64 { //nolint + + if postfilter == nil { // Then, the user didn't define one. + // Unlike most plugins we always export postfilter so that we can reset + // the cycle state: return success to avoid no-op overhead. + return 0 + } + + // The parameters passed are lazy with regard to host functions. This means + // a no-op plugin should not have any unmarshal penalty. 
+ nominatedNodeName, nominatingMode, status := postfilter.PostFilter(cyclestate.Values, cyclestate.Pod, &nodeToStatusMap{}) + ptr, size := mem.StringToPtr(nominatedNodeName) + setNominatedNodeNameResult(ptr, size) + runtime.KeepAlive(nominatedNodeName) // until ptr is no longer needed. + + return (uint64(nominatingMode) << uint64(32)) | uint64(imports.StatusToCode(status)) +} + +type nodeToStatusMap struct { + statusMap map[string]api.StatusCode +} + +func (n *nodeToStatusMap) Map() map[string]api.StatusCode { + return n.lazyNodeToStatusMap() +} + +// lazyNodeToStatusMap returns NodeToStatusMap from imports.NodeToStatusMap. +func (n *nodeToStatusMap) lazyNodeToStatusMap() map[string]api.StatusCode { + nodeMap := imports.NodeToStatusMap() + n.statusMap = nodeMap + return n.statusMap +} diff --git a/guest/prefilter/prefilter.go b/guest/prefilter/prefilter.go index 73db7179..26745fa0 100644 --- a/guest/prefilter/prefilter.go +++ b/guest/prefilter/prefilter.go @@ -35,7 +35,7 @@ import ( // // type filterPlugin struct{} // -// func (filterPlugin) PreFilter(state api.CycleState, pod proto.Pod, nodeList proto.NodeList) { +// func (filterPlugin) PreFilter(state api.CycleState, pod proto.Pod) (nodeNames []string, status *Status) { // // Write state you need on Filter // } // diff --git a/guest/testdata/cyclestate/main.go b/guest/testdata/cyclestate/main.go index 4b4bbf80..1266aefa 100644 --- a/guest/testdata/cyclestate/main.go +++ b/guest/testdata/cyclestate/main.go @@ -26,6 +26,7 @@ import ( "sigs.k8s.io/kube-scheduler-wasm-extension/guest/api/proto" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/enqueue" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/filter" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/postfilter" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/prefilter" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/prescore" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/score" @@ -56,6 +57,7 @@ func main() { enqueue.SetPlugin(plugin) 
prefilter.SetPlugin(plugin) filter.SetPlugin(plugin) + postfilter.SetPlugin(plugin) prescore.SetPlugin(plugin) score.SetPlugin(plugin) } @@ -82,11 +84,11 @@ type preFilterStateVal map[string]any type preScoreStateVal map[string]any func (statePlugin) PreFilter(state api.CycleState, pod proto.Pod) (nodeNames []string, status *api.Status) { - if nextPodSpec := pod.Spec(); unsafe.Pointer(nextPodSpec) == unsafe.Pointer(podSpec) { + nextPodSpec := pod.Spec() + if unsafe.Pointer(nextPodSpec) == unsafe.Pointer(podSpec) { panic("didn't reset pod on pre-filter") - } else { - podSpec = nextPodSpec } + podSpec = nextPodSpec mustNotScoreState(state) if _, ok := state.Read(preFilterStateKey); ok { panic("didn't reset filter state on pre-filter") @@ -109,10 +111,23 @@ func (statePlugin) Filter(state api.CycleState, pod proto.Pod, _ api.NodeInfo) ( return } -func (statePlugin) PreScore(state api.CycleState, pod proto.Pod, _ proto.NodeList) *api.Status { +func (statePlugin) PostFilter(state api.CycleState, pod proto.Pod, _ api.NodeToStatusMap) (nominatedNodeName string, nominatingMode api.NominatingMode, status *api.Status) { if unsafe.Pointer(pod.Spec()) != unsafe.Pointer(podSpec) { panic("didn't cache pod from filter") } + mustNotScoreState(state) + if val, ok := state.Read(preFilterStateKey); !ok { + panic("didn't propagate state from pre-filter") + } else if _, ok = val.(preFilterStateVal)["filter"]; !ok { + panic("filter value lost propagating from filter") + } + return +} + +func (statePlugin) PreScore(state api.CycleState, pod proto.Pod, _ proto.NodeList) *api.Status { + if unsafe.Pointer(pod.Spec()) != unsafe.Pointer(podSpec) { + panic("didn't cache pod from pre-filter") + } mustFilterState(state) if _, ok := state.Read(preScoreStateKey); ok { panic("didn't reset score state on pre-score") diff --git a/guest/testdata/cyclestate/main.wasm b/guest/testdata/cyclestate/main.wasm index b9e1c031..1f143c0e 100755 Binary files a/guest/testdata/cyclestate/main.wasm and 
b/guest/testdata/cyclestate/main.wasm differ diff --git a/guest/testdata/filter/main.go b/guest/testdata/filter/main.go index 05a31010..bb3df312 100644 --- a/guest/testdata/filter/main.go +++ b/guest/testdata/filter/main.go @@ -22,12 +22,14 @@ import ( "sigs.k8s.io/kube-scheduler-wasm-extension/guest/api" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/api/proto" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/filter" + "sigs.k8s.io/kube-scheduler-wasm-extension/guest/postfilter" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/prefilter" ) type extensionPoints interface { api.PreFilterPlugin api.FilterPlugin + api.PostFilterPlugin } func main() { @@ -40,10 +42,13 @@ func main() { plugin = filterPlugin{} case "preFilter": plugin = preFilterPlugin{} + case "postFilter": + plugin = postFilterPlugin{} } } prefilter.SetPlugin(plugin) filter.SetPlugin(plugin) + postfilter.SetPlugin(plugin) } // noopPlugin doesn't do anything, except evaluate each parameter. @@ -62,6 +67,13 @@ func (noopPlugin) Filter(state api.CycleState, pod proto.Pod, nodeInfo api.NodeI return } +func (noopPlugin) PostFilter(state api.CycleState, pod proto.Pod, nodeMap api.NodeToStatusMap) (nominatedNodeName string, nominatingMode api.NominatingMode, status *api.Status) { + _, _ = state.Read("ok") + _ = pod.Spec() + _ = nodeMap.Map() + return +} + // preFilterPlugin schedules a node if its name equals its pod spec. type preFilterPlugin struct{ noopPlugin } @@ -96,3 +108,28 @@ func (filterPlugin) Filter(_ api.CycleState, pod proto.Pod, nodeInfo api.NodeInf Reason: podSpecNodeName + " != " + nodeName, } } + +type postFilterPlugin struct{ noopPlugin } + +func (postFilterPlugin) PostFilter(_ api.CycleState, pod proto.Pod, nodeMap api.NodeToStatusMap) (string, api.NominatingMode, *api.Status) { + // First, check if the pod spec node name is empty. If so, pass! 
+ podSpecNodeName := pod.Spec().GetNodeName() + if len(podSpecNodeName) == 0 { + return "", 0, nil + } + m := nodeMap.Map() + if m == nil { + return "", 0, nil + } + // If nominatedNodeName is schedulable, pass! + if val, ok := m[podSpecNodeName]; ok { + if val == 0 { + return podSpecNodeName, api.ModeOverride, nil + } + } + // Otherwise, this is unschedulableAndUnresolvable, so note the reason. + return podSpecNodeName, api.ModeNoop, &api.Status{ + Code: api.StatusCodeUnschedulableAndUnresolvable, + Reason: podSpecNodeName + " is unschedulable", + } +} diff --git a/guest/testdata/filter/main.wasm b/guest/testdata/filter/main.wasm index bc2802c7..3d345dca 100755 Binary files a/guest/testdata/filter/main.wasm and b/guest/testdata/filter/main.wasm differ diff --git a/internal/e2e/e2e.go b/internal/e2e/e2e.go index 0174eb91..9038c284 100644 --- a/internal/e2e/e2e.go +++ b/internal/e2e/e2e.go @@ -23,6 +23,11 @@ func RunAll(ctx context.Context, t Testing, plugin framework.Plugin, pod *v1.Pod RequireSuccess(t, s) } + if postfilterP, ok := plugin.(framework.PostFilterPlugin); ok { + _, s = postfilterP.PostFilter(ctx, nil, pod, nil) + RequireSuccess(t, s) + } + if prescoreP, ok := plugin.(framework.PreScorePlugin); ok { s = prescoreP.PreScore(ctx, nil, pod, []*v1.Node{ni.Node()}) RequireSuccess(t, s) diff --git a/internal/e2e/go.mod b/internal/e2e/go.mod index e9ea32db..2d0655a6 100644 --- a/internal/e2e/go.mod +++ b/internal/e2e/go.mod @@ -165,7 +165,7 @@ require ( google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect google.golang.org/grpc v1.51.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect gopkg.in/gcfg.v1 v1.2.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect diff --git a/internal/e2e/go.sum b/internal/e2e/go.sum index 335f5b0a..93cdad68 100644 --- a/internal/e2e/go.sum +++ 
b/internal/e2e/go.sum @@ -846,8 +846,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/e2e/scheduler_perf/scheduler_perf_test.go b/internal/e2e/scheduler_perf/scheduler_perf_test.go index 565e83cd..071f091e 100644 --- a/internal/e2e/scheduler_perf/scheduler_perf_test.go +++ b/internal/e2e/scheduler_perf/scheduler_perf_test.go @@ -91,7 +91,7 @@ var ( Metrics: map[string]*labelValues{ "scheduler_framework_extension_point_duration_seconds": { label: extensionPointsLabelName, - values: []string{"PreFilter", "Filter", "PreScore", "Score"}, + values: []string{"PreFilter", "Filter", "PostFilter", "PreScore", "Score"}, }, "scheduler_scheduling_attempt_duration_seconds": nil, "scheduler_pod_scheduling_duration_seconds": nil, diff --git a/scheduler/go.mod b/scheduler/go.mod index e90af7ca..c13b068c 100644 --- a/scheduler/go.mod +++ b/scheduler/go.mod @@ -126,7 +126,7 @@ require ( google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect 
google.golang.org/grpc v1.51.0 // indirect - google.golang.org/protobuf v1.28.1 // indirect + google.golang.org/protobuf v1.31.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/scheduler/go.sum b/scheduler/go.sum index 19041223..0b14f5c1 100644 --- a/scheduler/go.sum +++ b/scheduler/go.sum @@ -647,8 +647,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/scheduler/plugin/guest.go b/scheduler/plugin/guest.go index 37ed5b6c..333c2b60 100644 --- a/scheduler/plugin/guest.go +++ b/scheduler/plugin/guest.go @@ -28,23 +28,25 @@ import ( ) const ( - guestExportMemory = "memory" - guestExportEnqueue = "enqueue" - guestExportPreFilter = "prefilter" - guestExportFilter = "filter" - guestExportPreScore = "prescore" - guestExportScore = "score" + guestExportMemory = "memory" + guestExportEnqueue = "enqueue" + guestExportPreFilter = "prefilter" + guestExportFilter = "filter" + guestExportPostFilter = "postfilter" + 
guestExportPreScore = "prescore" + guestExportScore = "score" ) type guest struct { - guest wazeroapi.Module - out *bytes.Buffer - enqueueFn wazeroapi.Function - prefilterFn wazeroapi.Function - filterFn wazeroapi.Function - prescoreFn wazeroapi.Function - scoreFn wazeroapi.Function - callStack []uint64 + guest wazeroapi.Module + out *bytes.Buffer + enqueueFn wazeroapi.Function + prefilterFn wazeroapi.Function + filterFn wazeroapi.Function + postfilterFn wazeroapi.Function + prescoreFn wazeroapi.Function + scoreFn wazeroapi.Function + callStack []uint64 } func compileGuest(ctx context.Context, runtime wazero.Runtime, guestBin []byte) (guest wazero.CompiledModule, err error) { @@ -82,14 +84,15 @@ func (pl *wasmPlugin) newGuest(ctx context.Context) (*guest, error) { callStack := make([]uint64, 1) return &guest{ - guest: g, - out: &out, - enqueueFn: g.ExportedFunction(guestExportEnqueue), - prefilterFn: g.ExportedFunction(guestExportPreFilter), - filterFn: g.ExportedFunction(guestExportFilter), - prescoreFn: g.ExportedFunction(guestExportPreScore), - scoreFn: g.ExportedFunction(guestExportScore), - callStack: callStack, + guest: g, + out: &out, + enqueueFn: g.ExportedFunction(guestExportEnqueue), + prefilterFn: g.ExportedFunction(guestExportPreFilter), + filterFn: g.ExportedFunction(guestExportFilter), + postfilterFn: g.ExportedFunction(guestExportPostFilter), + prescoreFn: g.ExportedFunction(guestExportPreScore), + scoreFn: g.ExportedFunction(guestExportScore), + callStack: callStack, }, nil } @@ -131,6 +134,23 @@ func (g *guest) filter(ctx context.Context) *framework.Status { return framework.NewStatus(framework.Code(statusCode), statusReason) } +// postFilter calls guestExportPostFilter. 
+func (g *guest) postFilter(ctx context.Context) (*framework.PostFilterResult, *framework.Status) { + defer g.out.Reset() + callStack := g.callStack + if err := g.postfilterFn.CallWithStack(ctx, callStack); err != nil { + return nil, framework.AsStatus(decorateError(g.out, guestExportPostFilter, err)) + } + nominatedNodeName := paramsFromContext(ctx).resultNominatedNodeName + nominatingMode := framework.NominatingMode(int32(callStack[0] >> 32)) + + statusCode := int32(callStack[0]) + statusReason := paramsFromContext(ctx).resultStatusReason + + nominatingInfo := &framework.NominatingInfo{NominatedNodeName: nominatedNodeName, NominatingMode: nominatingMode} + return &framework.PostFilterResult{NominatingInfo: nominatingInfo}, framework.NewStatus(framework.Code(statusCode), statusReason) +} + // preScore calls guestExportPreScore. func (g *guest) preScore(ctx context.Context) *framework.Status { defer g.out.Reset() @@ -188,6 +208,11 @@ func detectInterfaces(exportedFns map[string]wazeroapi.FunctionDefinition) (inte return 0, fmt.Errorf("wasm: guest exports the wrong signature for func[%s]. should be () -> (i32)", name) } e |= iFilterPlugin + case guestExportPostFilter: + if len(f.ParamTypes()) != 0 || !bytes.Equal(f.ResultTypes(), []wazeroapi.ValueType{i64}) { + return 0, fmt.Errorf("wasm: guest exports the wrong signature for func[%s]. should be () -> (i64)", name) + } + e |= iPostFilterPlugin case guestExportPreScore: if len(f.ParamTypes()) != 0 || !bytes.Equal(f.ResultTypes(), []wazeroapi.ValueType{i32}) { return 0, fmt.Errorf("wasm: guest exports the wrong signature for func[%s]. 
should be () -> (i32)", name) diff --git a/scheduler/plugin/host.go b/scheduler/plugin/host.go index 5befeb10..03db2e0a 100644 --- a/scheduler/plugin/host.go +++ b/scheduler/plugin/host.go @@ -18,6 +18,7 @@ package wasm import ( "context" + "encoding/json" "github.com/tetratelabs/wazero" wazeroapi "github.com/tetratelabs/wazero/api" @@ -27,22 +28,24 @@ import ( ) const ( - i32 = wazeroapi.ValueTypeI32 - i64 = wazeroapi.ValueTypeI64 - k8sApi = "k8s.io/api" - k8sApiNode = "node" - k8sApiNodeList = "nodeList" - k8sApiNodeName = "nodeName" - k8sApiPod = "pod" - k8sKlog = "k8s.io/klog" - k8sKlogLog = "log" - k8sKlogLogs = "logs" - k8sKlogSeverity = "severity" - k8sScheduler = "k8s.io/scheduler" - k8sSchedulerGetConfig = "get_config" - k8sSchedulerResultClusterEvents = "result.cluster_events" - k8sSchedulerResultNodeNames = "result.node_names" - k8sSchedulerResultStatusReason = "result.status_reason" + i32 = wazeroapi.ValueTypeI32 + i64 = wazeroapi.ValueTypeI64 + k8sApi = "k8s.io/api" + k8sApiNode = "node" + k8sApiNodeList = "nodeList" + k8sApiNodeName = "nodeName" + k8sApiPod = "pod" + k8sApiNodeToStatusMap = "nodeToStatusMap" + k8sKlog = "k8s.io/klog" + k8sKlogLog = "log" + k8sKlogLogs = "logs" + k8sKlogSeverity = "severity" + k8sScheduler = "k8s.io/scheduler" + k8sSchedulerGetConfig = "get_config" + k8sSchedulerResultClusterEvents = "result.cluster_events" + k8sSchedulerResultNodeNames = "result.node_names" + k8sSchedulerResultNominatedNodeName = "result.nominated_node_name" + k8sSchedulerResultStatusReason = "result.status_reason" ) func instantiateHostApi(ctx context.Context, runtime wazero.Runtime) (wazeroapi.Module, error) { @@ -90,8 +93,14 @@ func instantiateHostScheduler(ctx context.Context, runtime wazero.Runtime, guest WithGoModuleFunction(wazeroapi.GoModuleFunc(k8sSchedulerResultNodeNamesFn), []wazeroapi.ValueType{i32, i32}, []wazeroapi.ValueType{}). WithParameterNames("buf", "buf_len").Export(k8sSchedulerResultNodeNames). NewFunctionBuilder(). 
+ WithGoModuleFunction(wazeroapi.GoModuleFunc(k8sSchedulerResultNominatedNodeNameFn), []wazeroapi.ValueType{i32, i32}, []wazeroapi.ValueType{}). + WithParameterNames("buf", "buf_len").Export(k8sSchedulerResultNominatedNodeName). + NewFunctionBuilder(). WithGoModuleFunction(wazeroapi.GoModuleFunc(k8sSchedulerResultStatusReasonFn), []wazeroapi.ValueType{i32, i32}, []wazeroapi.ValueType{}). WithParameterNames("buf", "buf_len").Export(k8sSchedulerResultStatusReason). + NewFunctionBuilder(). + WithGoModuleFunction(wazeroapi.GoModuleFunc(k8sSchedulerNodeToStatusMapFn), []wazeroapi.ValueType{i32, i32}, []wazeroapi.ValueType{i32}). + WithParameterNames("buf", "buf_limit").Export(k8sApiNodeToStatusMap). Instantiate(ctx) } @@ -122,12 +131,18 @@ type stack struct { // pod is used by guest.filterFn and guest.scoreFn pod *v1.Pod + // nodeToStatusMap is used by guest.postfilterFn + nodeToStatusMap map[string]*framework.Status + // resultClusterEvents is returned by guest.enqueueFn resultClusterEvents []framework.ClusterEvent // resultNodeNames is returned by guest.prefilterFn resultNodeNames []string + // resultNominatedNodeName is returned by guest.postfilterFn + resultNominatedNodeName string + // reason returned by all guest exports except guest.enqueueFn // // It is a field to avoid compiler-specific malloc/free functions, and to @@ -179,6 +194,20 @@ func k8sApiPodFn(ctx context.Context, mod wazeroapi.Module, stack []uint64) { stack[0] = uint64(marshalIfUnderLimit(mod.Memory(), pod, buf, bufLimit)) } +// k8sSchedulerNodeToStatusMapFn is a function used by the host to send the nodeStatusMap. 
+func k8sSchedulerNodeToStatusMapFn(ctx context.Context, mod wazeroapi.Module, stack []uint64) { + buf := uint32(stack[0]) + bufLimit := bufLimit(stack[1]) + + nodeToStatusMap := paramsFromContext(ctx).nodeToStatusMap + nodeCodeMap := nodeStatusMapToMap(nodeToStatusMap) + mapByte, err := json.Marshal(nodeCodeMap) + if err != nil { + panic(err) + } + stack[0] = uint64(writeStringIfUnderLimit(mod.Memory(), string(mapByte), buf, bufLimit)) +} + type host struct { guestConfig string logSeverity int32 @@ -303,6 +332,21 @@ func k8sSchedulerResultNodeNamesFn(ctx context.Context, mod wazeroapi.Module, st paramsFromContext(ctx).resultNodeNames = nodeNames } +// k8sSchedulerResultNominatedNodeNameFn is a function used by the wasm guest to set the +// nominated node name result from guestExportPostFilter. +func k8sSchedulerResultNominatedNodeNameFn(ctx context.Context, mod wazeroapi.Module, stack []uint64) { + buf := uint32(stack[0]) + bufLen := uint32(stack[1]) + + var nominatedNodeName string + if b, ok := mod.Memory().Read(buf, bufLen); !ok { + panic("out of memory reading nominatedNodeName") + } else { + nominatedNodeName = string(b) + } + paramsFromContext(ctx).resultNominatedNodeName = nominatedNodeName +} + // k8sSchedulerResultStatusReasonFn is a function used by the wasm guest to set the // framework.Status reason result from all functions. func k8sSchedulerResultStatusReasonFn(ctx context.Context, mod wazeroapi.Module, stack []uint64) { @@ -318,3 +362,14 @@ func k8sSchedulerResultStatusReasonFn(ctx context.Context, mod wazeroapi.Module, } paramsFromContext(ctx).resultStatusReason = reason } + +// Converts nodeToStatusMap to a map with node names as keys and their scores as integer values. 
+func nodeStatusMapToMap(originalMap map[string]*framework.Status) map[string]int {
+	newMap := make(map[string]int, len(originalMap)) // at most one entry per input key
+	for key, value := range originalMap {
+		if value != nil { // nil statuses are left out of the result
+			newMap[key] = int(value.Code())
+		}
+	}
+	return newMap
+}
diff --git a/scheduler/plugin/plugin.go b/scheduler/plugin/plugin.go
index 1b3b07e1..424aad62 100644
--- a/scheduler/plugin/plugin.go
+++ b/scheduler/plugin/plugin.go
@@ -262,13 +262,23 @@ func (pl *wasmPlugin) Filter(ctx context.Context, _ *framework.CycleState, pod *
 var _ framework.PostFilterPlugin = (*wasmPlugin)(nil)
 
 // PostFilter implements the same method as documented on framework.PostFilterPlugin.
-func (pl *wasmPlugin) PostFilter(ctx context.Context, state *framework.CycleState, pod *v1.Pod, filteredNodeStatusMap framework.NodeToStatusMap) (*framework.PostFilterResult, *framework.Status) {
+func (pl *wasmPlugin) PostFilter(ctx context.Context, state *framework.CycleState, pod *v1.Pod, filteredNodeStatusMap framework.NodeToStatusMap) (result *framework.PostFilterResult, status *framework.Status) {
 	// We implement PostFilterPlugin with FilterPlugin, even when the guest doesn't.
 	if pl.guestInterfaces&iPostFilterPlugin == 0 {
 		return nil, nil // unimplemented
 	}
 
-	panic("TODO: scheduling: PostFilter")
+	// Add the stack to the go context so that the corresponding host function
+	// can look them up.
+	params := &stack{pod: pod, nodeToStatusMap: filteredNodeStatusMap}
+	ctx = context.WithValue(ctx, stackKey{}, params)
+	// The results are named so that the guest callback can assign them.
+	if err := pl.pool.doWithSchedulingGuest(ctx, pod.UID, func(g *guest) {
+		result, status = g.postFilter(ctx)
+	}); err != nil {
+		status = framework.AsStatus(err)
+	}
+	return result, status
 }
 
 var _ framework.PreScorePlugin = (*wasmPlugin)(nil)
diff --git a/scheduler/plugin/plugin_test.go b/scheduler/plugin/plugin_test.go
index bd6b8e82..7b7a4fcb 100644
--- a/scheduler/plugin/plugin_test.go
+++ b/scheduler/plugin/plugin_test.go
@@ -561,6 +561,124 @@ wasm stack trace:
 	}
 }
 
+// TestPostFilter runs PostFilter against the TinyGo test guest, a guest that
+// returns host-set globals, and a guest that panics.
+func TestPostFilter(t *testing.T) {
+	tests := []struct {
+		name                  string
+		guestURL              string
+		args                  []string
+		globals               map[string]int32
+		pod                   *v1.Pod
+		nodeToStatusMap       map[string]*framework.Status
+		expectedResult        *framework.PostFilterResult
+		expectedStatusCode    framework.Code
+		expectedStatusMessage string
+	}{
+		{
+			name:               "success",
+			args:               []string{"test", "postFilter"},
+			pod:                test.PodSmall,
+			nodeToStatusMap:    map[string]*framework.Status{test.NodeSmallName: framework.NewStatus(framework.Success, "")},
+			expectedResult:     &framework.PostFilterResult{NominatingInfo: &framework.NominatingInfo{NominatedNodeName: "good-node", NominatingMode: framework.ModeOverride}},
+			expectedStatusCode: framework.Success,
+		},
+		{
+			name:                  "unschedulable",
+			args:                  []string{"test", "postFilter"},
+			pod:                   test.PodSmall,
+			nodeToStatusMap:       map[string]*framework.Status{test.NodeSmallName: framework.NewStatus(framework.Unschedulable, "")},
+			expectedResult:        &framework.PostFilterResult{NominatingInfo: &framework.NominatingInfo{NominatedNodeName: "good-node", NominatingMode: framework.ModeNoop}},
+			expectedStatusMessage: "good-node is unschedulable",
+			expectedStatusCode:    framework.UnschedulableAndUnresolvable,
+		},
+		{
+			name:               "min statusCode",
+			guestURL:           test.URLTestPostFilterFromGlobal,
+			pod:                test.PodSmall,
+			globals:            map[string]int32{"status_code": math.MinInt32},
+			expectedResult:     &framework.PostFilterResult{NominatingInfo: &framework.NominatingInfo{NominatedNodeName: "", NominatingMode: 0}},
+			expectedStatusCode: math.MinInt32,
+		},
+		{
+			name:               "max statusCode",
+			guestURL:           test.URLTestPostFilterFromGlobal,
+			pod:                test.PodSmall,
+			globals:            map[string]int32{"status_code": math.MaxInt32},
+			expectedResult:     &framework.PostFilterResult{NominatingInfo: &framework.NominatingInfo{NominatedNodeName: "", NominatingMode: 0}},
+			expectedStatusCode: math.MaxInt32,
+		},
+		{
+			name:               "min nominatingMode",
+			guestURL:           test.URLTestPostFilterFromGlobal,
+			pod:                test.PodSmall,
+			globals:            map[string]int32{"nominating_mode": math.MinInt32},
+			expectedResult:     &framework.PostFilterResult{NominatingInfo: &framework.NominatingInfo{NominatedNodeName: "", NominatingMode: math.MinInt32}},
+			expectedStatusCode: framework.Success,
+		},
+		{
+			name:               "max nominatingMode",
+			guestURL:           test.URLTestPostFilterFromGlobal,
+			pod:                test.PodSmall,
+			globals:            map[string]int32{"nominating_mode": math.MaxInt32},
+			expectedResult:     &framework.PostFilterResult{NominatingInfo: &framework.NominatingInfo{NominatedNodeName: "", NominatingMode: math.MaxInt32}},
+			expectedStatusCode: framework.Success,
+		},
+		{
+			name:               "panic",
+			guestURL:           test.URLErrorPanicOnPostFilter,
+			pod:                test.PodSmall,
+			expectedStatusCode: framework.Error,
+			expectedStatusMessage: `wasm: postfilter error: panic!
+wasm error: unreachable
+wasm stack trace:
+	panic_on_postfilter.$1() i64`,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			guestURL := tc.guestURL
+			if guestURL == "" {
+				guestURL = test.URLTestFilter
+			}
+
+			p, err := wasm.NewFromConfig(ctx, "wasm", wasm.PluginConfig{GuestURL: guestURL, Args: tc.args})
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer p.(io.Closer).Close()
+
+			if len(tc.globals) > 0 {
+				pl := wasm.NewTestWasmPlugin(p)
+				pl.SetGlobals(tc.globals)
+			}
+
+			result, status := p.(framework.PostFilterPlugin).PostFilter(ctx, nil, tc.pod, tc.nodeToStatusMap)
+			if want, have := tc.expectedResult, result; !reflect.DeepEqual(want, have) {
+				// Either side may be nil (the "panic" case expects no result), so
+				// avoid dereferencing the results directly when reporting.
+				t.Fatalf("unexpected result: want %#v, have %#v", nominatingInfoOf(want), nominatingInfoOf(have))
+			}
+			if want, have := tc.expectedStatusMessage, status.Message(); want != have {
+				t.Fatalf("unexpected status message: want %v, have %v", want, have)
+			}
+			if want, have := tc.expectedStatusCode, status.Code(); want != have {
+				t.Fatalf("unexpected status code: want %v, have %v", want, have)
+			}
+		})
+	}
+}
+
+// nominatingInfoOf returns r.NominatingInfo, or nil when r itself is nil, so
+// that test failures can be reported without a nil pointer dereference.
+func nominatingInfoOf(r *framework.PostFilterResult) *framework.NominatingInfo {
+	if r == nil {
+		return nil
+	}
+	return r.NominatingInfo
+}
+
 func TestPreScore(t *testing.T) {
 	tests := []struct {
 		name string
diff --git a/scheduler/test/testdata.go b/scheduler/test/testdata.go
index 52219ca3..058d290f 100644
--- a/scheduler/test/testdata.go
+++ b/scheduler/test/testdata.go
@@ -26,6 +26,8 @@ var URLErrorPanicOnPreFilter = localURL(pathWatError("panic_on_prefilter"))
 
 var URLErrorPanicOnFilter = localURL(pathWatError("panic_on_filter"))
 
+var URLErrorPanicOnPostFilter = localURL(pathWatError("panic_on_postfilter"))
+
 var URLErrorPanicOnPreScore = localURL(pathWatError("panic_on_prescore"))
 
 var URLErrorPreScoreWithoutScore = localURL(pathWatError("prescore_without_score"))
@@ -48,6 +50,8 @@ var URLTestFilter = localURL(pathTinyGoTest("filter"))
 
 var URLTestFilterFromGlobal = localURL(pathWatTest("filter_from_global"))
 
+var URLTestPostFilterFromGlobal = localURL(pathWatTest("postfilter_from_global"))
+
 var URLTestPreScoreFromGlobal = localURL(pathWatTest("prescore_from_global"))
 
 var URLTestScore = localURL(pathTinyGoTest("score"))
@@ -64,8 +68,11 @@ var NodeReal = func() *v1.Node {
 	return &node
 }()
 
+// NodeSmallName is the name of NodeSmall.
+var NodeSmallName = "good-node"
+
 // NodeSmall is the smallest node that works with URLExampleFilterSimple.
-var NodeSmall = &v1.Node{ObjectMeta: apimeta.ObjectMeta{Name: "good-node"}}
+var NodeSmall = &v1.Node{ObjectMeta: apimeta.ObjectMeta{Name: NodeSmallName}}
 
 //go:embed testdata/yaml/pod.yaml
 var yamlPodReal string
diff --git a/scheduler/test/testdata/error/panic_on_postfilter.wasm b/scheduler/test/testdata/error/panic_on_postfilter.wasm
new file mode 100644
index 00000000..512a3f64
Binary files /dev/null and b/scheduler/test/testdata/error/panic_on_postfilter.wasm differ
diff --git a/scheduler/test/testdata/error/panic_on_postfilter.wat b/scheduler/test/testdata/error/panic_on_postfilter.wat
new file mode 100644
index 00000000..069461f4
--- /dev/null
+++ b/scheduler/test/testdata/error/panic_on_postfilter.wat
@@ -0,0 +1,32 @@
+;; panic_on_postfilter is a postfilter which issues an unreachable instruction
+;; after writing an error to stdout. This simulates a panic in TinyGo.
+(module $panic_on_postfilter
+  ;; Import the fd_write function from wasi, used in TinyGo for println.
+  (import "wasi_snapshot_preview1" "fd_write"
+    (func $wasi.fd_write (param $fd i32) (param $iovs i32) (param $iovs_len i32) (param $result.size i32) (result (;errno;) i32)))
+
+  ;; Allocate the minimum amount of memory, 1 page (64KB).
+  (memory (export "memory") 1 1)
+
+  ;; Pre-populate memory with the panic message, in iovec format
+  (data (i32.const 0) "\08") ;; iovs[0].offset
+  (data (i32.const 4) "\06") ;; iovs[0].length
+  (data (i32.const 8) "panic!") ;; iovs[0]
+
+  ;; On postfilter, write "panic!" to stdout and crash.
+  (func (export "postfilter") (result i64)
+    ;; Write the panic to stdout via its iovec [offset, len].
+    (call $wasi.fd_write
+      (i32.const 1) ;; stdout
+      (i32.const 0) ;; where's the iovec
+      (i32.const 1) ;; only one iovec
+      (i32.const 0) ;; overwrite the iovec with the ignored result.
+    )
+    drop ;; ignore the errno returned
+
+    ;; Issue the unreachable instruction instead of returning a code
+    (unreachable))
+
+  ;; We require exporting filter with postfilter
+  (func (export "filter") (result i32) (unreachable))
+)
diff --git a/scheduler/test/testdata/test/all_noop.wasm b/scheduler/test/testdata/test/all_noop.wasm
index 151c227e..50dfa3f2 100644
Binary files a/scheduler/test/testdata/test/all_noop.wasm and b/scheduler/test/testdata/test/all_noop.wasm differ
diff --git a/scheduler/test/testdata/test/all_noop.wat b/scheduler/test/testdata/test/all_noop.wat
index d505c42a..7c80ae68 100644
--- a/scheduler/test/testdata/test/all_noop.wat
+++ b/scheduler/test/testdata/test/all_noop.wat
@@ -6,5 +6,6 @@
 
   (func (export "prefilter") (result i32) (return (i32.const 0)))
   (func (export "filter") (result i32) (return (i32.const 0)))
+  (func (export "postfilter") (result i32) (return (i32.const 0)))
   (func (export "score") (result i64) (return (i64.const 0)))
 )
diff --git a/scheduler/test/testdata/test/postfilter_from_global.wasm b/scheduler/test/testdata/test/postfilter_from_global.wasm
new file mode 100644
index 00000000..29ee71f2
Binary files /dev/null and b/scheduler/test/testdata/test/postfilter_from_global.wasm differ
diff --git a/scheduler/test/testdata/test/postfilter_from_global.wat b/scheduler/test/testdata/test/postfilter_from_global.wat
new file mode 100644
index 00000000..208cb2f4
--- /dev/null
+++ b/scheduler/test/testdata/test/postfilter_from_global.wat
@@ -0,0 +1,33 @@
+;; postfilter_from_global lets us test the value range of nominating_mode and status_code
+(module $postfilter_from_global
+
+  ;; Allocate the minimum amount of memory, 1 page (64KB).
+  (memory (export "memory") 1 1)
+
+  ;; nominating_mode is set by the host.
+  (global $nominating_mode (export "nominating_mode_global") (mut i32) (i32.const 0))
+  ;; status_code is set by the host.
+  (global $status_code (export "status_code_global") (mut i32) (i32.const 0))
+
+  (func (export "postfilter") (result i64)
+    ;; var nominating_mode int32
+    (local $nominating_mode i32)
+
+    ;; var status_code int32
+    (local $status_code i32)
+
+    ;; nominating_mode = global.nominating_mode
+    (local.set $nominating_mode (global.get $nominating_mode))
+
+    ;; status_code = global.status_code
+    (local.set $status_code (global.get $status_code))
+
+    ;; return uint64(nominating_mode) << 32 | uint64(status_code)
+    (return
+      (i64.or
+        (i64.shl (i64.extend_i32_u (local.get $nominating_mode)) (i64.const 32))
+        (i64.extend_i32_u (local.get $status_code)))))
+
+  ;; We require exporting filter with postfilter
+  (func (export "filter") (result i32) (unreachable))
+)