From 3efa4ed1a2987d51b69d49f04d8e59f439481113 Mon Sep 17 00:00:00 2001 From: moriya Date: Sat, 4 Nov 2023 16:05:04 +0900 Subject: [PATCH] add nodeToStatus proto --- Makefile | 2 + guest/postfilter/postfilter.go | 42 +- internal/e2e/e2e.go | 5 + internal/e2e/go.mod | 3 +- internal/e2e/go.sum | 4 +- .../e2e/scheduler_perf/scheduler_perf_test.go | 2 +- kubernetes/proto/nodetostatus/api.pb.go | 102 ++++ kubernetes/proto/nodetostatus/api.proto | 21 + .../proto/nodetostatus/api_vtproto.pb.go | 561 ++++++++++++++++++ scheduler/go.mod | 3 +- scheduler/go.sum | 4 +- scheduler/plugin/host.go | 25 +- scheduler/plugin/mem.go | 33 +- 13 files changed, 792 insertions(+), 15 deletions(-) create mode 100644 kubernetes/proto/nodetostatus/api.pb.go create mode 100644 kubernetes/proto/nodetostatus/api.proto create mode 100644 kubernetes/proto/nodetostatus/api_vtproto.pb.go diff --git a/Makefile b/Makefile index 705f5c39..4e7a5a41 100644 --- a/Makefile +++ b/Makefile @@ -85,6 +85,8 @@ submodule-update: .PHONY: update-kubernetes-proto update-kubernetes-proto: proto-tools echo "Regenerate the Go protobuf code." + protoc ./kubernetes/proto/nodetostatus/api.proto --go-plugin_out=./kubernetes/proto/nodetostatus \ + --go-plugin_opt=Mkubernetes/proto/nodetostatus/api.proto=.; \ cd kubernetes/kubernetes/staging/src/; \ protoc ./k8s.io/apimachinery/pkg/api/resource/generated.proto --go-plugin_out=../../../proto \ --go-plugin_opt=Mk8s.io/apimachinery/pkg/api/resource/generated.proto=./resource; \ diff --git a/guest/postfilter/postfilter.go b/guest/postfilter/postfilter.go index 99620784..857ab6e2 100644 --- a/guest/postfilter/postfilter.go +++ b/guest/postfilter/postfilter.go @@ -25,7 +25,8 @@ import ( "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/cyclestate" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/imports" "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/plugin" - internalproto "sigs.k8s.io/kube-scheduler-wasm-extension/guest/internal/proto" + + "sigs.k8s.io/kube-scheduler-wasm-extension/kubernetes/proto/nodetostatus" ) // postfilter is the current plugin assigned with SetPlugin. @@ -75,10 +76,9 @@ func _postfilter() uint64 { //nolint return 0 } - // TODO: fix PostFilter // The parameters passed are lazy with regard to host functions. This means // a no-op plugin should not have any unmarshal penalty. - nominatedNodeName, nominatingMode, status := postfilter.PostFilter(cyclestate.Values, cyclestate.Pod, nil) + nominatedNodeName, nominatingMode, status := postfilter.PostFilter(cyclestate.Values, cyclestate.Pod, &nodeToStatusMap{}) cString := []byte(nominatedNodeName) if cString != nil { @@ -101,10 +101,42 @@ func (n *nodeToStatusMap) NodeToStatusMap() map[string]*api.Status { // lazyNodeToStatusMap returns NodeToStatusMap from imports.NodeToStatusMap. func (n *nodeToStatusMap) lazyNodeToStatusMap() map[string]*api.Status { - var msg api.NodeToStatusMap + var msg nodetostatus.NodeToStatusMap if err := imports.NodeToStatusMap(msg.UnmarshalVT); err != nil { panic(err.Error()) } - n.statusMap = &internalproto.NodeToStatusMap{Msg: &msg} + n.statusMap = convertNodeStatusMapType(msg.NodeStatus) return n.statusMap } + +func convertNodeStatusMapType(nodeToStatusMap map[string]*nodetostatus.Status) map[string]*api.Status { + converted := make(map[string]*api.Status) + + for key, value := range nodeToStatusMap { + convertedStatus := &api.Status{ + Code: convertCode(*value.Code), + Reason: *value.Reason, + } + converted[key] = convertedStatus + } + return converted +} + +func convertCode(code nodetostatus.StatusCode) api.StatusCode { + switch code { + case nodetostatus.StatusCode_STATUS_CODE_SUCCESS: + return api.StatusCodeSuccess + case nodetostatus.StatusCode_STATUS_CODE_ERROR: + return api.StatusCodeError + case nodetostatus.StatusCode_STATUS_CODE_UNSCHEDULABLE: + return api.StatusCodeUnschedulable + case nodetostatus.StatusCode_STATUS_CODE_UNSCHEDULABLE_AND_UNRESOLVABLE: + return api.StatusCodeUnschedulableAndUnresolvable + case nodetostatus.StatusCode_STATUS_CODE_WAIT: + return api.StatusCodeWait + case nodetostatus.StatusCode_STATUS_CODE_SKIP: + return api.StatusCodeSkip + default: + panic("StatusCode can't be converted") + } +} diff --git a/internal/e2e/e2e.go b/internal/e2e/e2e.go index 0174eb91..9038c284 100644 --- a/internal/e2e/e2e.go +++ b/internal/e2e/e2e.go @@ -23,6 +23,11 @@ func RunAll(ctx context.Context, t Testing, plugin framework.Plugin, pod *v1.Pod RequireSuccess(t, s) } + if postfilterP, ok := plugin.(framework.PostFilterPlugin); ok { + _, s = postfilterP.PostFilter(ctx, nil, pod, nil) + RequireSuccess(t, s) + } + if prescoreP, ok := plugin.(framework.PreScorePlugin); ok { s = prescoreP.PreScore(ctx, nil, pod, []*v1.Node{ni.Node()}) RequireSuccess(t, s) diff --git a/internal/e2e/go.mod b/internal/e2e/go.mod index e9ea32db..7ca976f5 100644 --- a/internal/e2e/go.mod +++ b/internal/e2e/go.mod @@ -165,7 +165,7 @@ require ( google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect google.golang.org/grpc v1.51.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.31.0 // indirect gopkg.in/gcfg.v1 v1.2.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect @@ -191,6 +191,7 @@ require ( k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.1.2 // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect + sigs.k8s.io/kube-scheduler-wasm-extension/kubernetes/proto v0.0.0-20230808005812-6708ed44fd99 // indirect sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect ) diff --git a/internal/e2e/go.sum b/internal/e2e/go.sum index 335f5b0a..93cdad68 100644 --- a/internal/e2e/go.sum +++ b/internal/e2e/go.sum @@ -846,8 +846,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/e2e/scheduler_perf/scheduler_perf_test.go b/internal/e2e/scheduler_perf/scheduler_perf_test.go index 2f44b8d6..52ce1e69 100644 --- a/internal/e2e/scheduler_perf/scheduler_perf_test.go +++ b/internal/e2e/scheduler_perf/scheduler_perf_test.go @@ -91,7 +91,7 @@ var ( Metrics: map[string]*labelValues{ "scheduler_framework_extension_point_duration_seconds": { label: extensionPointsLabelName, - values: []string{"PreFilter", "Filter", "PreScore", "Score"}, + values: []string{"PreFilter", "Filter", "PostFilter", "PreScore", "Score"}, }, "scheduler_scheduling_attempt_duration_seconds": nil, "scheduler_pod_scheduling_duration_seconds": nil, diff --git a/kubernetes/proto/nodetostatus/api.pb.go b/kubernetes/proto/nodetostatus/api.pb.go new file mode 100644 index 00000000..4df434ef --- /dev/null +++ b/kubernetes/proto/nodetostatus/api.pb.go @@ -0,0 +1,102 @@ +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.14.0 +// source: kubernetes/proto/nodetostatus/api.proto + +package nodetostatus + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type StatusCode int32 + +const ( + StatusCode_STATUS_CODE_SUCCESS StatusCode = 0 + StatusCode_STATUS_CODE_ERROR StatusCode = 1 + StatusCode_STATUS_CODE_UNSCHEDULABLE StatusCode = 2 + StatusCode_STATUS_CODE_UNSCHEDULABLE_AND_UNRESOLVABLE StatusCode = 3 + StatusCode_STATUS_CODE_WAIT StatusCode = 4 + StatusCode_STATUS_CODE_SKIP StatusCode = 5 +) + +// Enum value maps for StatusCode. +var ( + StatusCode_name = map[int32]string{ + 0: "STATUS_CODE_SUCCESS", + 1: "STATUS_CODE_ERROR", + 2: "STATUS_CODE_UNSCHEDULABLE", + 3: "STATUS_CODE_UNSCHEDULABLE_AND_UNRESOLVABLE", + 4: "STATUS_CODE_WAIT", + 5: "STATUS_CODE_SKIP", + } + StatusCode_value = map[string]int32{ + "STATUS_CODE_SUCCESS": 0, + "STATUS_CODE_ERROR": 1, + "STATUS_CODE_UNSCHEDULABLE": 2, + "STATUS_CODE_UNSCHEDULABLE_AND_UNRESOLVABLE": 3, + "STATUS_CODE_WAIT": 4, + "STATUS_CODE_SKIP": 5, + } +) + +func (x StatusCode) Enum() *StatusCode { + p := new(StatusCode) + *p = x + return p +} + +type Status struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Code *StatusCode `protobuf:"varint,1,req,name=code,enum=StatusCode" json:"code,omitempty"` + Reason *string `protobuf:"bytes,2,req,name=reason" json:"reason,omitempty"` +} + +func (x *Status) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *Status) GetCode() StatusCode { + if x != nil && x.Code != nil { + return *x.Code + } + return StatusCode_STATUS_CODE_SUCCESS +} + +func (x *Status) GetReason() string { + if x != nil && x.Reason != nil { + return *x.Reason + } + return "" +} + +type NodeToStatusMap struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + NodeStatus map[string]*Status `protobuf:"bytes,1,rep,name=nodeStatus" json:"nodeStatus,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` +} + +func (x *NodeToStatusMap) ProtoReflect() protoreflect.Message { + panic(`not implemented`) +} + +func (x *NodeToStatusMap) GetNodeStatus() map[string]*Status { + if x != nil { + return x.NodeStatus + } + return nil +} diff --git a/kubernetes/proto/nodetostatus/api.proto b/kubernetes/proto/nodetostatus/api.proto new file mode 100644 index 00000000..1f412e63 --- /dev/null +++ b/kubernetes/proto/nodetostatus/api.proto @@ -0,0 +1,21 @@ +syntax = "proto2"; + +option go_package = "sigs.k8s.io/kube-scheduler-wasm-extension/kubernetes/proto/nodetostatus"; + +enum StatusCode { + STATUS_CODE_SUCCESS = 0; + STATUS_CODE_ERROR = 1; + STATUS_CODE_UNSCHEDULABLE = 2; + STATUS_CODE_UNSCHEDULABLE_AND_UNRESOLVABLE = 3; + STATUS_CODE_WAIT = 4; + STATUS_CODE_SKIP = 5; + } + +message Status { + required StatusCode code = 1; + required string reason = 2; +}; + +message NodeToStatusMap { + map nodeStatus = 1; +} \ No newline at end of file diff --git a/kubernetes/proto/nodetostatus/api_vtproto.pb.go b/kubernetes/proto/nodetostatus/api_vtproto.pb.go new file mode 100644 index 00000000..288fc7a5 --- /dev/null +++ b/kubernetes/proto/nodetostatus/api_vtproto.pb.go @@ -0,0 +1,561 @@ +// Code generated by protoc-gen-go-plugin. DO NOT EDIT. +// versions: +// protoc-gen-go-plugin v0.1.0 +// protoc v3.14.0 +// source: kubernetes/proto/nodetostatus/api.proto + +package nodetostatus + +import ( + fmt "fmt" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + io "io" + bits "math/bits" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +func (m *Status) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Status) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *Status) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if m.Reason == nil { + return 0, fmt.Errorf("proto: required field reason not set") + } else { + i -= len(*m.Reason) + copy(dAtA[i:], *m.Reason) + i = encodeVarint(dAtA, i, uint64(len(*m.Reason))) + i-- + dAtA[i] = 0x12 + } + if m.Code == nil { + return 0, fmt.Errorf("proto: required field code not set") + } else { + i = encodeVarint(dAtA, i, uint64(*m.Code)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *NodeToStatusMap) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *NodeToStatusMap) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *NodeToStatusMap) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.NodeStatus) > 0 { + for k := range m.NodeStatus { + v := m.NodeStatus[k] + baseI := i + size, err := v.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + i -= len(k) + copy(dAtA[i:], k) + i = encodeVarint(dAtA, i, uint64(len(k))) + i-- + dAtA[i] = 0xa + i = encodeVarint(dAtA, i, uint64(baseI-i)) + i-- + dAtA[i] = 0xa + } + } + return len(dAtA) - i, nil +} + +func encodeVarint(dAtA []byte, offset int, v uint64) int { + offset -= sov(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *Status) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Code != nil { + n += 1 + sov(uint64(*m.Code)) + } + if m.Reason != nil { + l = len(*m.Reason) + n += 1 + l + sov(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *NodeToStatusMap) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if len(m.NodeStatus) > 0 { + for k, v := range m.NodeStatus { + _ = k + _ = v + l = 0 + if v != nil { + l = v.SizeVT() + } + l += 1 + sov(uint64(l)) + mapEntrySize := 1 + len(k) + sov(uint64(len(k))) + l + n += mapEntrySize + 1 + sov(uint64(mapEntrySize)) + } + } + n += len(m.unknownFields) + return n +} + +func sov(x uint64) (n int) { + return (bits.Len64(x|1) + 6) / 7 +} +func soz(x uint64) (n int) { + return sov(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *Status) UnmarshalVT(dAtA []byte) error { + var hasFields [1]uint64 + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Status: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Status: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Code", wireType) + } + var v StatusCode + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= StatusCode(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.Code = &v + hasFields[0] |= uint64(0x00000001) + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Reason", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(dAtA[iNdEx:postIndex]) + m.Reason = &s + iNdEx = postIndex + hasFields[0] |= uint64(0x00000002) + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + if hasFields[0]&uint64(0x00000001) == 0 { + return fmt.Errorf("proto: required field code not set") + } + if hasFields[0]&uint64(0x00000002) == 0 { + return fmt.Errorf("proto: required field reason not set") + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *NodeToStatusMap) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: NodeToStatusMap: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: NodeToStatusMap: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field NodeStatus", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.NodeStatus == nil { + m.NodeStatus = make(map[string]*Status) + } + var mapkey string + var mapvalue *Status + for iNdEx < postIndex { + entryPreIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + if fieldNum == 1 { + var stringLenmapkey uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLenmapkey |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLenmapkey := int(stringLenmapkey) + if intStringLenmapkey < 0 { + return ErrInvalidLength + } + postStringIndexmapkey := iNdEx + intStringLenmapkey + if postStringIndexmapkey < 0 { + return ErrInvalidLength + } + if postStringIndexmapkey > l { + return io.ErrUnexpectedEOF + } + mapkey = string(dAtA[iNdEx:postStringIndexmapkey]) + iNdEx = postStringIndexmapkey + } else if fieldNum == 2 { + var mapmsglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + mapmsglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if mapmsglen < 0 { + return ErrInvalidLength + } + postmsgIndex := iNdEx + mapmsglen + if postmsgIndex < 0 { + return ErrInvalidLength + } + if postmsgIndex > l { + return io.ErrUnexpectedEOF + } + mapvalue = &Status{} + if err := mapvalue.UnmarshalVT(dAtA[iNdEx:postmsgIndex]); err != nil { + return err + } + iNdEx = postmsgIndex + } else { + iNdEx = entryPreIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > postIndex { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + m.NodeStatus[mapkey] = mapvalue + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} + +func skip(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflow + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLength + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroup + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLength + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLength = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflow = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroup = fmt.Errorf("proto: unexpected end of group") +) diff --git a/scheduler/go.mod b/scheduler/go.mod index 5604b103..9d56474b 100644 --- a/scheduler/go.mod +++ b/scheduler/go.mod @@ -43,6 +43,7 @@ require ( k8s.io/klog/v2 v2.90.1 k8s.io/kubectl v0.27.3 k8s.io/kubernetes v1.27.3 + sigs.k8s.io/kube-scheduler-wasm-extension/kubernetes/proto v0.0.0-20230808005812-6708ed44fd99 ) require ( @@ -125,7 +126,7 @@ require ( google.golang.org/appengine v1.6.7 // indirect google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect google.golang.org/grpc v1.51.0 // indirect - google.golang.org/protobuf v1.28.1 // indirect + google.golang.org/protobuf v1.31.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/scheduler/go.sum b/scheduler/go.sum index 29f7c63d..4a33194f 100644 --- a/scheduler/go.sum +++ b/scheduler/go.sum @@ -646,8 +646,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/scheduler/plugin/host.go b/scheduler/plugin/host.go index 8f5ad2b6..afebf77c 100644 --- a/scheduler/plugin/host.go +++ b/scheduler/plugin/host.go @@ -18,12 +18,15 @@ package wasm import ( "context" + "strings" "github.com/tetratelabs/wazero" wazeroapi "github.com/tetratelabs/wazero/api" v1 "k8s.io/api/core/v1" "k8s.io/klog/v2" "k8s.io/kubernetes/pkg/scheduler/framework" + + proto "sigs.k8s.io/kube-scheduler-wasm-extension/kubernetes/proto/nodetostatus" ) const ( @@ -198,8 +201,9 @@ func k8sApiNodeToStatusMapFn(ctx context.Context, mod wazeroapi.Module, stack [] buf := uint32(stack[0]) bufLimit := bufLimit(stack[1]) - // nodeToStatusMap := paramsFromContext(ctx).nodeToStatusMap - stack[0] = uint64(marshalIfUnderLimit(mod.Memory(), nil, buf, bufLimit)) + nodeToStatusMap := paramsFromContext(ctx).nodeToStatusMap + protoMap := ConvertNodeToStatusMapToProtoMap(nodeToStatusMap) + stack[0] = uint64(marshalNodeToMapIfUnderLimit(mod.Memory(), protoMap, buf, bufLimit)) } type host struct { @@ -356,3 +360,20 @@ func k8sSchedulerResultStatusReasonFn(ctx context.Context, mod wazeroapi.Module, } paramsFromContext(ctx).resultStatusReason = reason } + +func ConvertNodeToStatusMapToProtoMap(statusMap map[string]*framework.Status) *proto.NodeToStatusMap { + var protoMap proto.NodeToStatusMap + separator := "," + + for key, status := range statusMap { + code := proto.StatusCode(int32(status.Code())) + reason := strings.Join(status.Reasons(), separator) + protobufStatus := &proto.Status{ + Code: &code, + Reason: &reason, + } + protoMap.NodeStatus[key] = protobufStatus + } + + return &protoMap +} diff --git a/scheduler/plugin/mem.go b/scheduler/plugin/mem.go index e23bc9a5..70fe2aab 100644 --- a/scheduler/plugin/mem.go +++ b/scheduler/plugin/mem.go @@ -16,7 +16,11 @@ package wasm -import wazeroapi "github.com/tetratelabs/wazero/api" +import ( + wazeroapi "github.com/tetratelabs/wazero/api" + + proto "sigs.k8s.io/kube-scheduler-wasm-extension/kubernetes/proto/nodetostatus" +) // bufLimit is the possibly zero maximum length of a result value to write in // bytes. If the actual value is larger than this, nothing is written to @@ -55,6 +59,33 @@ func marshalIfUnderLimit(mem wazeroapi.Memory, vt valueType, buf uint32, bufLimi return vLen } +func marshalNodeToMapIfUnderLimit(mem wazeroapi.Memory, vt *proto.NodeToStatusMap, buf uint32, bufLimit bufLimit) int { + // First, see if the caller passed enough memory to serialize the object. + vLen := vt.SizeVT() + if vLen == 0 { + return 0 // nothing to write + } + + // Next, see if the value will fit inside the buffer. + if vLen > int(bufLimit) { + // If it doesn't fit, the caller can decide to retry with a larger + // buffer or fail. + return vLen + } + + // Now, we know the value isn't too large to fit in the buffer. Write it + // directly to the Wasm memory. + if wasmMem, ok := mem.Read(buf, uint32(vLen)); !ok { + panic("out of memory") // Bug: caller passed a length outside memory + } else if _, err := vt.MarshalToSizedBufferVT(wasmMem); err != nil { + panic(err) // Bug: in marshaller. + } + + // Success: return the bytes written, so that the caller can unmarshal from + // a sized buffer. + return vLen +} + func writeStringIfUnderLimit(mem wazeroapi.Memory, v string, buf uint32, bufLimit bufLimit) int { vLen := len(v) if vLen == 0 {