From 37d9c84d9c1fcde6380761634e771f6fd06bfa80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9A=96=E9=98=B3?= <532764597@qq.com> Date: Mon, 19 Aug 2024 18:56:20 +0800 Subject: [PATCH] feat(codec): Unknown Service Handler (#1321) --- pkg/remote/option.go | 4 +- pkg/unknownservice/service/unknown_service.go | 85 ++++++++ pkg/unknownservice/unknownservice_codec.go | 197 ++++++++++++++++++ .../unknownservice_codec_test.go | 119 +++++++++++ server/option.go | 16 +- server/server.go | 18 +- 6 files changed, 434 insertions(+), 5 deletions(-) create mode 100644 pkg/unknownservice/service/unknown_service.go create mode 100644 pkg/unknownservice/unknownservice_codec.go create mode 100644 pkg/unknownservice/unknownservice_codec_test.go diff --git a/pkg/remote/option.go b/pkg/remote/option.go index 71e81c2500..76ddcccdec 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -7,7 +7,6 @@ * * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and @@ -27,6 +26,7 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/pkg/unknownservice/service" ) // Option is used to pack the inbound and outbound handlers. @@ -113,6 +113,8 @@ type ServerOption struct { GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error + UnknownServiceHandler service.UnknownServiceHandler + Option // invoking chain with recv/send middlewares for streaming APIs diff --git a/pkg/unknownservice/service/unknown_service.go b/pkg/unknownservice/service/unknown_service.go new file mode 100644 index 0000000000..a599877d2f --- /dev/null +++ b/pkg/unknownservice/service/unknown_service.go @@ -0,0 +1,85 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +const ( + // UnknownService name + UnknownService = "$UnknownService" // private as "$" + // UnknownMethod name + UnknownMethod = "$UnknownMethod" +) + +type Args struct { + Request []byte + Method string + ServiceName string +} + +type Result struct { + Success []byte + Method string + ServiceName string +} + +type UnknownServiceHandler interface { + UnknownServiceHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error) +} + +// NewServiceInfo creates a new ServiceInfo containing unknown methods +func NewServiceInfo(pcType serviceinfo.PayloadCodec, service, method string) *serviceinfo.ServiceInfo { + methods := map[string]serviceinfo.MethodInfo{ + method: serviceinfo.NewMethodInfo(callHandler, newServiceArgs, newServiceResult, false), + } + handlerType := (*UnknownServiceHandler)(nil) + + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: service, + HandlerType: handlerType, + Methods: methods, + PayloadCodec: pcType, + Extra: make(map[string]interface{}), + } + + return svcInfo +} + +func callHandler(ctx context.Context, handler, arg, result interface{}) error { + realArg := arg.(*Args) + realResult := result.(*Result) + realResult.Method = realArg.Method + realResult.ServiceName = realArg.ServiceName + success, err := handler.(UnknownServiceHandler).UnknownServiceHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newServiceArgs() interface{} { + return &Args{} +} + +func newServiceResult() interface{} { + return &Result{} +} diff --git a/pkg/unknownservice/unknownservice_codec.go b/pkg/unknownservice/unknownservice_codec.go new file mode 100644 index 0000000000..4ca99fc23e --- /dev/null +++ b/pkg/unknownservice/unknownservice_codec.go @@ -0,0 +1,197 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unknownservice + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + + gthrift "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service" +) + +// UnknownCodec implements PayloadCodec +type unknownServiceCodec struct { + Codec remote.PayloadCodec +} + +// NewUnknownServiceCodec creates the unknown binary codec. +func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec { + return &unknownServiceCodec{code} +} + +// Marshal implements the remote.PayloadCodec interface. +func (c unknownServiceCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { + ink := msg.RPCInfo().Invocation() + data := msg.Data() + + res, ok := data.(*unknownservice.Result) + if !ok { + return c.Codec.Marshal(ctx, msg, out) + } + if msg.MessageType() == remote.Exception { + return c.Codec.Marshal(ctx, msg, out) + } + if ink, ok := ink.(rpcinfo.InvocationSetter); ok { + ink.SetMethodName(res.Method) + ink.SetServiceName(res.ServiceName) + } else { + return errors.New("the interface Invocation doesn't implement InvocationSetter") + } + + if res.Success == nil { + sz := gthrift.Binary.MessageBeginLength(msg.RPCInfo().Invocation().MethodName()) + if msg.ProtocolInfo().CodecType == serviceinfo.Thrift { + sz += gthrift.Binary.FieldStopLength() + buf, err := out.Malloc(sz) + if err != nil { + return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err)) + } + buf = gthrift.Binary.AppendMessageBegin(buf[:0], + msg.RPCInfo().Invocation().MethodName(), gthrift.TMessageType(msg.MessageType()), msg.RPCInfo().Invocation().SeqID()) + buf = gthrift.Binary.AppendFieldStop(buf) + _ = buf + } + + if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf { + buf, err := out.Malloc(sz) + if err != nil { + return perrors.NewProtocolError(fmt.Errorf("binary thrift generic marshal, remote.ByteBuffer Malloc err: %w", err)) + } + binary.BigEndian.PutUint32(buf, codec.ProtobufV1Magic+uint32(msg.MessageType())) + offset := 4 + offset += gthrift.Binary.WriteString(buf[offset:], res.Method) + offset += gthrift.Binary.WriteI32(buf[offset:], msg.RPCInfo().Invocation().SeqID()) + _ = buf + } + return nil + } + out.WriteBinary(res.Success) + return nil +} + +// Unmarshal implements the remote.PayloadCodec interface. +func (c unknownServiceCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { + ink := message.RPCInfo().Invocation() + magicAndMsgType, err := codec.PeekUint32(in) + if err != nil { + return err + } + msgType := magicAndMsgType & codec.FrontMask + if msgType == uint32(remote.Exception) { + return c.Codec.Unmarshal(ctx, message, in) + } + if err = codec.UpdateMsgType(msgType, message); err != nil { + return err + } + service, method, err := readDecode(message, in) + if err != nil { + return err + } + err = codec.SetOrCheckMethodName(method, message) + var te *remote.TransError + if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) { + svcInfo, err := message.SpecifyServiceInfo(unknownservice.UnknownService, unknownservice.UnknownMethod) + if err != nil { + return err + } + + if ink, ok := ink.(rpcinfo.InvocationSetter); ok { + ink.SetMethodName(unknownservice.UnknownMethod) + ink.SetPackageName(svcInfo.GetPackageName()) + ink.SetServiceName(unknownservice.UnknownService) + } else { + return errors.New("the interface Invocation doesn't implement InvocationSetter") + } + if err = codec.NewDataIfNeeded(unknownservice.UnknownMethod, message); err != nil { + return err + } + + data := message.Data() + + if data, ok := data.(*unknownservice.Args); ok { + data.Method = method + data.ServiceName = service + buf, err := in.Next(in.ReadableLen()) + if err != nil { + return err + } + data.Request = buf + } + return nil + } + + return c.Codec.Unmarshal(ctx, message, in) +} + +// Name implements the remote.PayloadCodec interface. +func (c unknownServiceCodec) Name() string { + return "unknownServiceCodec" +} + +func readDecode(message remote.Message, in remote.ByteBuffer) (string, string, error) { + code := message.ProtocolInfo().CodecType + if code == serviceinfo.Thrift || code == serviceinfo.Protobuf { + method, size, err := peekMethod(in) + if err != nil { + return "", "", err + } + + seqID, err := peekSeqID(in, size) + if err != nil { + return "", "", err + } + if err = codec.SetOrCheckSeqID(seqID, message); err != nil { + return "", "", err + } + return message.RPCInfo().Invocation().ServiceName(), method, nil + } + return "", "", nil +} + +func peekMethod(in remote.ByteBuffer) (string, int32, error) { + buf, err := in.Peek(8) + if err != nil { + return "", 0, err + } + buf = buf[4:] + size := int32(binary.BigEndian.Uint32(buf)) + buf, err = in.Peek(int(size + 8)) + if err != nil { + return "", 0, perrors.NewProtocolError(err) + } + buf = buf[8:] + method := string(buf) + return method, size + 8, nil +} + +func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) { + buf, err := in.Peek(int(size + 4)) + if err != nil { + return 0, perrors.NewProtocolError(err) + } + buf = buf[size:] + seqID := int32(binary.BigEndian.Uint32(buf)) + return seqID, nil +} diff --git a/pkg/unknownservice/unknownservice_codec_test.go b/pkg/unknownservice/unknownservice_codec_test.go new file mode 100644 index 0000000000..3a7f16d7b2 --- /dev/null +++ b/pkg/unknownservice/unknownservice_codec_test.go @@ -0,0 +1,119 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unknownservice + +import ( + "context" + "testing" + + gthrift "github.com/cloudwego/gopkg/protocol/thrift" + "github.com/cloudwego/kitex/internal/mocks" + mt "github.com/cloudwego/kitex/internal/mocks/thrift" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/unknownservice/service" + "github.com/cloudwego/kitex/transport" + "github.com/cloudwego/netpoll" +) + +var ( + thr = thrift.NewThriftCodec() + payloadCodec = unknownServiceCodec{thr} + svcInfo = mocks.ServiceInfo() +) + +func TestNormal(t *testing.T) { + sendMsg := initSendMsg(transport.TTHeader) + buf := netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) + ctx := context.Background() + err := payloadCodec.Marshal(ctx, sendMsg, buf) + test.Assert(t, err == nil, err) + err = buf.Flush() + test.Assert(t, err == nil, err) + recvMsg := initRecvMsg() + recvMsg.SetPayloadLen(buf.ReadableLen()) + _, size, err := peekMethod(buf) + test.Assert(t, err == nil, err) + err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + test.Assert(t, err == nil, err) + + req := (sendMsg.Data()).(*service.Result).Success + resp := (recvMsg.Data()).(*service.Args).Request + req = req[size+4:] + resp = resp[size+4:] + for i, item := range req { + test.Assert(t, item == resp[i]) + } + var _req mt.MockTestArgs + var _resp mt.MockTestArgs + reqMsg, err := _req.FastRead(req) + test.Assert(t, err == nil, err) + respMsg, err := _resp.FastRead(resp) + test.Assert(t, err == nil && reqMsg == respMsg, err) + test.Assert(t, len(_req.Req.StrList) == len(_resp.Req.StrList)) + test.Assert(t, len(_req.Req.StrMap) == len(_resp.Req.StrList)) + for i, item := range _req.Req.StrList { + test.Assert(t, item == _resp.Req.StrList[i]) + } + for k := range _resp.Req.StrMap { + test.Assert(t, _req.Req.StrMap[k] == _resp.Req.StrMap[k]) + } +} + +func initSendMsg(tp transport.Protocol) remote.Message { + var _args mt.MockTestArgs + _args.Req = prepareReq() + length := _args.BLength() + method := "mock" + size := gthrift.Binary.MessageBeginLength(method) + bytes := make([]byte, length+size) + offset := gthrift.Binary.WriteMessageBegin(bytes, method, int32(remote.Call), 1) + _args.FastWriteNocopy(bytes[offset:], nil) + arg := service.Result{Success: bytes, Method: method, ServiceName: ""} + ink := rpcinfo.NewInvocation("", service.UnknownMethod) + ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) + + msg := remote.NewMessage(&arg, svcInfo, ri, remote.Call, remote.Client) + + msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) + return msg +} + +func initRecvMsg() remote.Message { + arg := service.Args{Request: make([]byte, 0), Method: "mock", ServiceName: ""} + ink := rpcinfo.NewInvocation("", service.UnknownMethod) + ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) + svc := service.NewServiceInfo(svcInfo.PayloadCodec, service.UnknownService, service.UnknownMethod) + msg := remote.NewMessage(&arg, svc, ri, remote.Call, remote.Server) + return msg +} + +func prepareReq() *mt.MockReq { + strMap := make(map[string]string) + strMap["key1"] = "val1" + strMap["key2"] = "val2" + strList := []string{"str1", "str2"} + req := &mt.MockReq{ + Msg: "MockReq", + StrMap: strMap, + StrList: strList, + } + return req +} diff --git a/server/option.go b/server/option.go index a9c67fc37d..8c043cf398 100644 --- a/server/option.go +++ b/server/option.go @@ -22,8 +22,6 @@ import ( "net" "time" - "github.com/cloudwego/localsession/backup" - internal_server "github.com/cloudwego/kitex/internal/server" "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/klog" @@ -39,7 +37,10 @@ import ( "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" "github.com/cloudwego/kitex/pkg/streaming" + "github.com/cloudwego/kitex/pkg/unknownservice" + unknown "github.com/cloudwego/kitex/pkg/unknownservice/service" "github.com/cloudwego/kitex/pkg/utils" + "github.com/cloudwego/localsession/backup" ) // Option is the only way to config server. @@ -348,6 +349,17 @@ func WithGRPCUnknownServiceHandler(f func(ctx context.Context, methodName string }} } +// WithUnknownServiceHandler Inject an implementation of a method for handling unknown requests +// supporting only Thrift and Kitex protobuf protocols +func WithUnknownServiceHandler(f unknown.UnknownServiceHandler) Option { + return Option{F: func(o *internal_server.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("WithUnknownServiceHandler(%+v)", utils.GetFuncName(f))) + o.RemoteOpt.UnknownServiceHandler = f + remote.PutPayloadCode(serviceinfo.Thrift, unknownservice.NewUnknownServiceCodec(thrift.NewThriftCodec())) + remote.PutPayloadCode(serviceinfo.Protobuf, unknownservice.NewUnknownServiceCodec(protobuf.NewProtobufCodec())) + }} +} + // Deprecated: Use WithConnectionLimiter instead. func WithConcurrencyLimiter(conLimit limiter.ConcurrencyLimiter) Option { return Option{F: func(o *internal_server.Options, di *utils.Slice) { diff --git a/server/server.go b/server/server.go index 2556b1c81b..dcb2a4aa05 100644 --- a/server/server.go +++ b/server/server.go @@ -27,8 +27,6 @@ import ( "sync" "time" - "github.com/cloudwego/localsession/backup" - internal_server "github.com/cloudwego/kitex/internal/server" "github.com/cloudwego/kitex/pkg/acl" "github.com/cloudwego/kitex/pkg/diagnosis" @@ -45,6 +43,8 @@ import ( "github.com/cloudwego/kitex/pkg/rpcinfo" "github.com/cloudwego/kitex/pkg/serviceinfo" "github.com/cloudwego/kitex/pkg/stats" + unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service" + "github.com/cloudwego/localsession/backup" ) // Server is an abstraction of an RPC server. It accepts connections and dispatches them to the service @@ -102,6 +102,7 @@ func (s *server) init() { backup.Init(s.opt.BackupOpt) s.buildInvokeChain() s.buildStreamInvokeChain() + s.registerUnknownServiceHandler() } func fillContext(opt *internal_server.Options) context.Context { @@ -546,6 +547,19 @@ func (s *server) waitExit(errCh chan error) error { } } +func (s *server) registerUnknownServiceHandler() { + if s.opt.RemoteOpt.UnknownServiceHandler != nil { + if len(s.svcs.svcMap) == 1 && s.svcs.svcMap[serviceinfo.GenericService] != nil { + panic(errors.New("generic services do not support handling of unknown methods")) + } else { + serviceInfo := unknownservice.NewServiceInfo(serviceinfo.Thrift, unknownservice.UnknownService, unknownservice.UnknownMethod) + if err := s.RegisterService(serviceInfo, s.opt.RemoteOpt.UnknownServiceHandler); err != nil { + panic(err) + } + } + } +} + func (s *server) findAndSetDefaultService() { if len(s.svcs.svcMap) == 1 { s.targetSvcInfo = getDefaultSvcInfo(s.svcs)