diff --git a/pkg/remote/option.go b/pkg/remote/option.go index 71e81c2500..281ffb605f 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -21,6 +21,8 @@ import ( "net" "time" + "github.com/cloudwego/kitex/pkg/unknownservice/service" + "github.com/cloudwego/kitex/pkg/endpoint" "github.com/cloudwego/kitex/pkg/profiler" "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" @@ -113,6 +115,8 @@ type ServerOption struct { GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error + UnknownServiceHandler service.UnknownServiceHandler + Option // invoking chain with recv/send middlewares for streaming APIs diff --git a/pkg/unknownservice/service/unknown_service.go b/pkg/unknownservice/service/unknown_service.go new file mode 100644 index 0000000000..e1acac3f33 --- /dev/null +++ b/pkg/unknownservice/service/unknown_service.go @@ -0,0 +1,85 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package service + +import ( + "context" + + "github.com/cloudwego/kitex/pkg/serviceinfo" +) + +const ( + // UnknownService name + UnknownService = "$UnknownService" // private as "$" + // UnknownMethod name + UnknownMethod = "$UnknownMethod" +) + +type Args struct { + Request []byte + Method string + ServiceName string +} + +type Result struct { + Success []byte + Method string + ServiceName string +} + +type UnknownServiceHandler interface { + UnknownServiceHandler(ctx context.Context, serviceName, method string, request []byte) ([]byte, error) +} + +// NewServiceInfo create serviceInfo +func NewServiceInfo(pcType serviceinfo.PayloadCodec, service, method string) *serviceinfo.ServiceInfo { + methods := map[string]serviceinfo.MethodInfo{ + method: serviceinfo.NewMethodInfo(callHandler, newServiceArgs, newServiceResult, false), + } + handlerType := (*UnknownServiceHandler)(nil) + + svcInfo := &serviceinfo.ServiceInfo{ + ServiceName: service, + HandlerType: handlerType, + Methods: methods, + PayloadCodec: pcType, + Extra: make(map[string]interface{}), + } + + return svcInfo +} + +func callHandler(ctx context.Context, handler, arg, result interface{}) error { + realArg := arg.(*Args) + realResult := result.(*Result) + realResult.Method = realArg.Method + realResult.ServiceName = realArg.ServiceName + success, err := handler.(UnknownServiceHandler).UnknownServiceHandler(ctx, realArg.ServiceName, realArg.Method, realArg.Request) + if err != nil { + return err + } + realResult.Success = success + return nil +} + +func newServiceArgs() interface{} { + return &Args{} +} + +func newServiceResult() interface{} { + return &Result{} +} diff --git a/pkg/unknownservice/unknown.go b/pkg/unknownservice/unknown.go new file mode 100644 index 0000000000..43cb9c626f --- /dev/null +++ b/pkg/unknownservice/unknown.go @@ -0,0 +1,236 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unknownservice + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + + "github.com/cloudwego/kitex/pkg/protocol/bthrift" + thrift "github.com/cloudwego/kitex/pkg/protocol/bthrift/apache" + + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec" + "github.com/cloudwego/kitex/pkg/remote/codec/perrors" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/serviceinfo" + unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service" +) + +// UnknownCodec implements PayloadCodec +type unknownCodec struct { + Codec remote.PayloadCodec +} + +// NewUnknownServiceCodec creates the unknown binary codec. +func NewUnknownServiceCodec(code remote.PayloadCodec) remote.PayloadCodec { + return &unknownCodec{code} +} + +// Marshal implements the remote.PayloadCodec interface. +func (c unknownCodec) Marshal(ctx context.Context, msg remote.Message, out remote.ByteBuffer) error { + ink := msg.RPCInfo().Invocation() + data := msg.Data() + + res, ok := data.(*unknownservice.Result) + if !ok { + return c.Codec.Marshal(ctx, msg, out) + } + if len(res.Success) == 0 { + return errors.New("unknown messages cannot be empty") + } + if msg.MessageType() == remote.Exception { + return c.Codec.Marshal(ctx, msg, out) + } + if ink, ok := ink.(rpcinfo.InvocationSetter); ok { + ink.SetMethodName(res.Method) + ink.SetServiceName(res.ServiceName) + } else { + return errors.New("the interface Invocation doesn't implement InvocationSetter") + } + if err := encode(res, msg, out); err != nil { + return c.Codec.Marshal(ctx, msg, out) + } + return nil +} + +// Unmarshal implements the remote.PayloadCodec interface. +func (c unknownCodec) Unmarshal(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { + ink := message.RPCInfo().Invocation() + magicAndMsgType, err := codec.PeekUint32(in) + if err != nil { + return err + } + msgType := magicAndMsgType & codec.FrontMask + if msgType == uint32(remote.Exception) { + return c.Codec.Unmarshal(ctx, message, in) + } + if err = codec.UpdateMsgType(msgType, message); err != nil { + return err + } + service, method, err := readDecode(message, in) + if err != nil { + return err + } + err = codec.SetOrCheckMethodName(method, message) + var te *remote.TransError + if errors.As(err, &te) && (te.TypeID() == remote.UnknownMethod || te.TypeID() == remote.UnknownService) { + svcInfo, err := message.SpecifyServiceInfo(unknownservice.UnknownService, unknownservice.UnknownMethod) + if err != nil { + return err + } + + if ink, ok := ink.(rpcinfo.InvocationSetter); ok { + ink.SetMethodName(unknownservice.UnknownMethod) + ink.SetPackageName(svcInfo.GetPackageName()) + ink.SetServiceName(unknownservice.UnknownService) + } else { + return errors.New("the interface Invocation doesn't implement InvocationSetter") + } + if err = codec.NewDataIfNeeded(unknownservice.UnknownMethod, message); err != nil { + return err + } + + data := message.Data() + + if data, ok := data.(*unknownservice.Args); ok { + data.Method = method + data.ServiceName = service + buf, err := in.Next(in.ReadableLen()) + if err != nil { + return err + } + data.Request = buf + } + return nil + } + + return c.Codec.Unmarshal(ctx, message, in) +} + +// Name implements the remote.PayloadCodec interface. +func (c unknownCodec) Name() string { + return "unknownMethodCodec" +} + +func write(dst, src []byte) { + copy(dst, src) +} + +func readDecode(message remote.Message, in remote.ByteBuffer) (string, string, error) { + code := message.ProtocolInfo().CodecType + if code == serviceinfo.Thrift || code == serviceinfo.Protobuf { + method, size, err := peekMethod(in) + if err != nil { + return "", "", err + } + + seqID, err := peekSeqID(in, size) + if err != nil { + return "", "", err + } + if err = codec.SetOrCheckSeqID(seqID, message); err != nil { + return "", "", err + } + return message.RPCInfo().Invocation().ServiceName(), method, nil + } + return "", "", nil +} + +func peekMethod(in remote.ByteBuffer) (string, int32, error) { + buf, err := in.Peek(8) + if err != nil { + return "", 0, err + } + buf = buf[4:] + size := int32(binary.BigEndian.Uint32(buf)) + buf, err = in.Peek(int(size + 8)) + if err != nil { + return "", 0, perrors.NewProtocolError(err) + } + buf = buf[8:] + method := string(buf) + return method, size + 8, nil +} + +func peekSeqID(in remote.ByteBuffer, size int32) (int32, error) { + buf, err := in.Peek(int(size + 4)) + if err != nil { + return 0, perrors.NewProtocolError(err) + } + buf = buf[size:] + seqID := int32(binary.BigEndian.Uint32(buf)) + return seqID, nil +} + +func encode(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error { + if msg.ProtocolInfo().CodecType == serviceinfo.Thrift { + return encodeThrift(res, msg, out) + } + if msg.ProtocolInfo().CodecType == serviceinfo.Protobuf { + return encodeKitexProtobuf(res, msg, out) + } + return nil +} + +// encodeThrift Thrift encoder +func encodeThrift(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error { + nw, _ := out.(remote.NocopyWrite) + msgType := msg.MessageType() + ink := msg.RPCInfo().Invocation() + msgBeginLen := bthrift.Binary.MessageBeginLength(res.Method, thrift.TMessageType(msgType), ink.SeqID()) + msgEndLen := bthrift.Binary.MessageEndLength() + + buf, err := out.Malloc(msgBeginLen + len(res.Success) + msgEndLen) + if err != nil { + return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, Malloc failed: %s", err.Error())) + } + offset := bthrift.Binary.WriteMessageBegin(buf, res.Method, thrift.TMessageType(msgType), ink.SeqID()) + write(buf[offset:], res.Success) + bthrift.Binary.WriteMessageEnd(buf[offset:]) + if nw == nil { + // if nw is nil, FastWrite will act in Copy mode. + return nil + } + return nw.MallocAck(out.MallocLen()) +} + +// encodeKitexProtobuf Kitex Protobuf encoder +func encodeKitexProtobuf(res *unknownservice.Result, msg remote.Message, out remote.ByteBuffer) error { + ink := msg.RPCInfo().Invocation() + // 3.1 magic && msgType + if err := codec.WriteUint32(codec.ProtobufV1Magic+uint32(msg.MessageType()), out); err != nil { + return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write meta info failed: %s", err.Error())) + } + // 3.2 methodName + if _, err := codec.WriteString(res.Method, out); err != nil { + return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write method name failed: %s", err.Error())) + } + // 3.3 seqID + if err := codec.WriteUint32(uint32(ink.SeqID()), out); err != nil { + return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("protobuf marshal, write seqID failed: %s", err.Error())) + } + dataLen := len(res.Success) + buf, err := out.Malloc(dataLen) + if err != nil { + return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("protobuf malloc size %d failed: %s", dataLen, err.Error())) + } + write(buf, res.Success) + return nil +} diff --git a/pkg/unknownservice/unknown_test.go b/pkg/unknownservice/unknown_test.go new file mode 100644 index 0000000000..87bc0e9fb0 --- /dev/null +++ b/pkg/unknownservice/unknown_test.go @@ -0,0 +1,115 @@ +/* + * Copyright 2021 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unknownservice + +import ( + "context" + "testing" + + "github.com/cloudwego/netpoll" + + "github.com/cloudwego/kitex/internal/mocks" + mt "github.com/cloudwego/kitex/internal/mocks/thrift" + "github.com/cloudwego/kitex/internal/test" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/codec/thrift" + netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" + "github.com/cloudwego/kitex/pkg/rpcinfo" + "github.com/cloudwego/kitex/pkg/unknownservice/service" + "github.com/cloudwego/kitex/transport" +) + +var ( + thr = thrift.NewThriftCodec() + payloadCodec = unknownCodec{thr} + svcInfo = mocks.ServiceInfo() +) + +func TestNormal(t *testing.T) { + sendMsg := initSendMsg(transport.TTHeader) + buf := netpolltrans.NewReaderWriterByteBuffer(netpoll.NewLinkBuffer(1024)) + ctx := context.Background() + err := payloadCodec.Marshal(ctx, sendMsg, buf) + test.Assert(t, err == nil, err) + err = buf.Flush() + test.Assert(t, err == nil, err) + recvMsg := initRecvMsg() + recvMsg.SetPayloadLen(buf.ReadableLen()) + _, size, err := peekMethod(buf) + test.Assert(t, err == nil, err) + err = payloadCodec.Unmarshal(ctx, recvMsg, buf) + test.Assert(t, err == nil, err) + + req := (sendMsg.Data()).(*service.Result).Success + resp := (recvMsg.Data()).(*service.Args).Request + resp = resp[size+4:] + for i, item := range req { + test.Assert(t, item == resp[i]) + } + var _req mt.MockTestArgs + var _resp mt.MockTestArgs + reqMsg, err := _req.FastRead(req) + test.Assert(t, err == nil, err) + respMsg, err := _resp.FastRead(resp) + test.Assert(t, err == nil && reqMsg == respMsg, err) + test.Assert(t, len(_req.Req.StrList) == len(_resp.Req.StrList)) + test.Assert(t, len(_req.Req.StrMap) == len(_resp.Req.StrList)) + for i, item := range _req.Req.StrList { + test.Assert(t, item == _resp.Req.StrList[i]) + } + for k := range _resp.Req.StrMap { + test.Assert(t, _req.Req.StrMap[k] == _resp.Req.StrMap[k]) + } +} + +func initSendMsg(tp transport.Protocol) remote.Message { + var _args mt.MockTestArgs + _args.Req = prepareReq() + length := _args.BLength() + bytes := make([]byte, length) + _args.FastWriteNocopy(bytes, nil) + arg := service.Result{Success: bytes, Method: "mock", ServiceName: ""} + ink := rpcinfo.NewInvocation("", service.UnknownMethod) + ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) + + msg := remote.NewMessage(&arg, svcInfo, ri, remote.Call, remote.Client) + + msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec)) + return msg +} + +func initRecvMsg() remote.Message { + arg := service.Args{Request: make([]byte, 0), Method: "mock", ServiceName: ""} + ink := rpcinfo.NewInvocation("", service.UnknownMethod) + ri := rpcinfo.NewRPCInfo(nil, nil, ink, nil, nil) + svc := service.NewServiceInfo(svcInfo.PayloadCodec, service.UnknownService, service.UnknownMethod) + msg := remote.NewMessage(&arg, svc, ri, remote.Call, remote.Server) + return msg +} + +func prepareReq() *mt.MockReq { + strMap := make(map[string]string) + strMap["key1"] = "val1" + strMap["key2"] = "val2" + strList := []string{"str1", "str2"} + req := &mt.MockReq{ + Msg: "MockReq", + StrMap: strMap, + StrList: strList, + } + return req +} diff --git a/server/option.go b/server/option.go index a9c67fc37d..8b441f6235 100644 --- a/server/option.go +++ b/server/option.go @@ -22,6 +22,9 @@ import ( "net" "time" + "github.com/cloudwego/kitex/pkg/unknownservice" + unknown "github.com/cloudwego/kitex/pkg/unknownservice/service" + "github.com/cloudwego/localsession/backup" internal_server "github.com/cloudwego/kitex/internal/server" @@ -348,6 +351,17 @@ func WithGRPCUnknownServiceHandler(f func(ctx context.Context, methodName string }} } +// WithUnknownServiceHandler Inject an implementation of a method for handling unknown requests +// supporting only Thrift and Kitex protobuf protocols +func WithUnknownServiceHandler(f unknown.UnknownServiceHandler) Option { + return Option{F: func(o *internal_server.Options, di *utils.Slice) { + di.Push(fmt.Sprintf("WithUnknownMethodHandler(%+v)", utils.GetFuncName(f))) + o.RemoteOpt.UnknownServiceHandler = f + remote.PutPayloadCode(serviceinfo.Thrift, unknownservice.NewUnknownServiceCodec(thrift.NewThriftCodec())) + remote.PutPayloadCode(serviceinfo.Protobuf, unknownservice.NewUnknownServiceCodec(protobuf.NewProtobufCodec())) + }} +} + // Deprecated: Use WithConnectionLimiter instead. func WithConcurrencyLimiter(conLimit limiter.ConcurrencyLimiter) Option { return Option{F: func(o *internal_server.Options, di *utils.Slice) { diff --git a/server/server.go b/server/server.go index 2556b1c81b..5feec3c1d8 100644 --- a/server/server.go +++ b/server/server.go @@ -27,6 +27,7 @@ import ( "sync" "time" + unknownservice "github.com/cloudwego/kitex/pkg/unknownservice/service" "github.com/cloudwego/localsession/backup" internal_server "github.com/cloudwego/kitex/internal/server" @@ -102,6 +103,7 @@ func (s *server) init() { backup.Init(s.opt.BackupOpt) s.buildInvokeChain() s.buildStreamInvokeChain() + s.registerUnknownServiceHandler() } func fillContext(opt *internal_server.Options) context.Context { @@ -546,6 +548,19 @@ func (s *server) waitExit(errCh chan error) error { } } +func (s *server) registerUnknownServiceHandler() { + if s.opt.RemoteOpt.UnknownServiceHandler != nil { + if len(s.svcs.svcMap) == 1 && s.svcs.svcMap[serviceinfo.GenericService] != nil { + panic(errors.New("generic services do not support handling of unknown methods")) + } else { + serviceInfo := unknownservice.NewServiceInfo(serviceinfo.Thrift, unknownservice.UnknownService, unknownservice.UnknownMethod) + if err := s.RegisterService(serviceInfo, s.opt.RemoteOpt.UnknownServiceHandler); err != nil { + panic(err) + } + } + } +} + func (s *server) findAndSetDefaultService() { if len(s.svcs.svcMap) == 1 { s.targetSvcInfo = getDefaultSvcInfo(s.svcs)