From e6cd4f058603852b12175447335383cd141d3d9d Mon Sep 17 00:00:00 2001 From: "vforfreedom96@gmail.com" Date: Sun, 7 Jul 2024 17:21:05 +0800 Subject: [PATCH] feat:add validate field ext and protobuf message validator --- .github/workflows/go.yml | 2 +- cmd/begonia/main.go | 10 + config/settings.yml | 13 +- gateway.json | 299 ------------------ gateway/endpoint.go | 4 +- gateway/exception.go | 77 ++++- gateway/exception_test.go | 47 +++ gateway/gateway.go | 25 +- gateway/gateway_test.go | 35 ++ gateway/http.go | 27 +- gateway/http_test.go | 12 +- gateway/middlewares.go | 176 ++++++++--- gateway/middlewares_test.go | 110 ++++++- gateway/protobuf.go | 2 + gateway/{grpc.go => proxy.go} | 182 +++++++++-- gateway/{grpc_test.go => proxy_test.go} | 75 ++++- {internal/pkg/routers => gateway}/routers.go | 50 +-- .../pkg/routers => gateway}/routers_test.go | 16 +- gateway/serialization.go | 41 +-- gateway/serialization_test.go | 77 +++++ gateway/types.go | 2 +- gateway/utils_test.go | 2 +- go.mod | 8 +- go.sum | 30 +- internal/biz/aksk.go | 4 +- internal/biz/aksk_test.go | 3 +- internal/biz/data_test.go | 2 +- internal/biz/endpoint/endpoint_test.go | 30 +- internal/biz/endpoint/utils.go | 3 +- internal/biz/endpoint/watcher.go | 15 +- internal/data/app.go | 2 +- internal/data/app_test.go | 10 +- internal/data/curd_test.go | 18 +- internal/middleware/auth/ak_test.go | 41 ++- internal/middleware/auth/aksk.go | 64 +++- internal/middleware/auth/apikey.go | 63 +++- internal/middleware/auth/apikey_test.go | 20 +- internal/middleware/auth/auth.go | 29 +- internal/middleware/auth/auth_test.go | 246 +++++++++++++- internal/middleware/auth/headers.go | 225 +++++++------ internal/middleware/auth/headers_test.go | 52 +-- internal/middleware/auth/jwt.go | 81 +++-- internal/middleware/auth/jwt_test.go | 52 ++- internal/middleware/auth/stream.go | 53 ++-- internal/middleware/http.go | 72 ++--- internal/middleware/http_test.go | 42 ++- internal/middleware/middleware.go | 8 + internal/middleware/middleware_test.go | 1 + internal/middleware/protobuf_validate.go | 187 +++++++++++ internal/middleware/rpc.go | 72 +++-- internal/middleware/rpc_test.go | 120 ++++++- internal/middleware/stream.go | 30 +- internal/middleware/vaildator.go | 270 ++++++---------- internal/middleware/vaildator_test.go | 277 ++++++++++++++-- internal/server/server.go | 8 +- internal/service/file_test.go | 1 + internal/service/tenant.go | 2 + testdata/desc.pb | Bin 11893 -> 14147 bytes testdata/gateway.json | 298 +++++++++++++++++ testdata/helloworld.pb | Bin 12161 -> 12234 bytes testdata/helloworld.proto | 102 ++++++ testdata/options.proto | 6 +- testdata/test.proto | 5 +- 63 files changed, 2787 insertions(+), 1049 deletions(-) delete mode 100644 gateway.json create mode 100644 gateway/gateway_test.go rename gateway/{grpc.go => proxy.go} (58%) rename gateway/{grpc_test.go => proxy_test.go} (63%) rename {internal/pkg/routers => gateway}/routers.go (74%) rename {internal/pkg/routers => gateway}/routers_test.go (75%) create mode 100644 internal/middleware/protobuf_validate.go create mode 100644 testdata/helloworld.proto diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 89381f3..d8fae3f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -73,7 +73,7 @@ jobs: version: "22.2" - name: Install protoc-gen-grpc-gateway run: | - git clone https://github.com/geebytes/grpc-gateway.git + git clone --recursive https://github.com/geebytes/grpc-gateway.git cd grpc-gateway go install ./protoc-gen-grpc-gateway - name: Test diff --git a/cmd/begonia/main.go b/cmd/begonia/main.go index 8002b6b..e3e499a 100644 --- a/cmd/begonia/main.go +++ b/cmd/begonia/main.go @@ -205,4 +205,14 @@ func main() { if err := cmd.Execute(); err != nil { log.Fatalf("failed to start begonia: %v", err) } + // env, _ := cmd.Flags().GetString("env") + // cnf, err := cmd.Flags().GetString("config") + // if err != nil { + // log.Fatalf("failed to get config: %v", err) + // } + // config := config.ReadConfigWithDir("dev", "/data/work/begonia-org/begonia/config/settings.yml") + // worker := internal.New(config, gateway.Log, "127.0.0.1:12138") + // hd, _ := os.UserHomeDir() + // _ = os.WriteFile(hd+"/.begonia/gateway.json", []byte(fmt.Sprintf(`{"addr":"http://%s"}`, "127.0.0.1:12138")), 0666) + // worker.Start() } diff --git a/config/settings.yml b/config/settings.yml index 0a5b32e..6d524f1 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -11,8 +11,8 @@ file: engines: - name: "FILE_ENGINE_MINIO" endpoint: "127.0.0.1:9000" - accessKey: "7OdVJK1alV8cpRMeBLBW" - secretKey: "2GNNRqqElReC1KnV3kX9jSjyLU4kwOaTZEqDS2vH" + accessKey: "rLV2Jjj2UbMWSJOhTOtZ" + secretKey: "OVyJOILwx4iVE0EVJB4CKB65j7xlhjT5q1aGCv5t" - name: "FILE_ENGINE_LOCAL" endpoint: "/data/work/begonia-org/begonia/upload" protos: @@ -70,11 +70,12 @@ gateway: - "example.com" plugins: local: - logger: 1 - exception: 0 + # 优先级越大越先执行 + exception: 4 + logger: 3 http: 2 - params_validator: 3 - auth: 4 + auth: 1 + params_validator: 0 # only_api_key_auth: 4 rpc: # - server: diff --git a/gateway.json b/gateway.json deleted file mode 100644 index 825d680..0000000 --- a/gateway.json +++ /dev/null @@ -1,299 +0,0 @@ -{ - "/helloworld.Greeter/SayHello": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "post" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/post" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHello", - "HttpUri": "/api/v1/example/post", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloBody": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "body" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/body" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHelloBody", - "HttpUri": "/api/v1/example/body", - "PathParams": [], - "InName": "HttpBody", - "OutName": "HttpBody", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "google.api", - "OutPkg": "google.api" - } - ], - "/helloworld.Greeter/SayHelloClientStream": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "client", - "stream" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/client/stream" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHelloClientStream", - "HttpUri": "/api/v1/example/client/stream", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "RepeatedReply", - "IsClientStream": true, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloError": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "error", - "test" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/error/test" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloError", - "HttpUri": "/api/v1/example/error/test", - "PathParams": [], - "InName": "ErrorRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloGet": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 1, - 0, - 4, - 1, - 5, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "name" - ], - "Verb": "", - "Fields": [ - "name" - ], - "Template": "/api/v1/example/{name}" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloGet", - "HttpUri": "/api/v1/example/{name}", - "PathParams": [ - "name" - ], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloServerSideEvent": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4, - 1, - 0, - 4, - 1, - 5, - 5 - ], - "Pool": [ - "api", - "v1", - "example", - "server", - "sse", - "name" - ], - "Verb": "", - "Fields": [ - "name" - ], - "Template": "/api/v1/example/server/sse/{name}" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloServerSideEvent", - "HttpUri": "/api/v1/example/server/sse/{name}", - "PathParams": [ - "name" - ], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": true, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloWebsocket": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "server", - "websocket" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/server/websocket" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloWebsocket", - "HttpUri": "/api/v1/example/server/websocket", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": true, - "IsServerStream": true, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ] -} \ No newline at end of file diff --git a/gateway/endpoint.go b/gateway/endpoint.go index a91e85a..219e763 100644 --- a/gateway/endpoint.go +++ b/gateway/endpoint.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "strings" loadbalance "github.com/begonia-org/go-loadbalancer" @@ -42,7 +43,7 @@ func (e *httpForwardGrpcEndpointImpl) Request(req GrpcRequest) (proto.Message, r return nil, runtime.ServerMetadata{ HeaderMD: make(map[string][]string), TrailerMD: make(map[string][]string), - }, err + }, fmt.Errorf("get conn error:%v", err) } defer e.pool.Release(req.GetContext(), cc) @@ -51,6 +52,7 @@ func (e *httpForwardGrpcEndpointImpl) Request(req GrpcRequest) (proto.Message, r in := req.GetIn() ctx := req.GetContext() err = conn.Invoke(ctx, req.GetFullMethodName(), in, out, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + // log.Printf("request %s out:%v",req.GetFullMethodName(), out.ProtoReflect().Type().Descriptor().FullName()) return out, metadata, err } diff --git a/gateway/exception.go b/gateway/exception.go index 7f3da65..017c188 100644 --- a/gateway/exception.go +++ b/gateway/exception.go @@ -9,8 +9,10 @@ import ( gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" + "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" ) type Exception struct { @@ -19,17 +21,34 @@ type Exception struct { name string } +func (e *Exception) setHeader(ctx context.Context) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + reqId := "" + if !ok || len(md.Get(XRequestID)) == 0 { + reqId = uuid.New().String() + if !ok { + md = metadata.New(make(map[string]string)) + } + md.Set(XRequestID, reqId) + ctx = metadata.NewIncomingContext(ctx, md) + + } else { + reqId = md.Get(XRequestID)[0] + } + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(XRequestID, reqId)) + + _ = grpc.SetHeader(ctx, metadata.Pairs(XRequestID, reqId)) + return ctx + +} func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { defer func() { if p := recover(); p != nil { - // buf := make([]byte, 1024) - // n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 - // stackTrace := string(buf[:n]) - // err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - // err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") err = e.handlePanic(p) } + }() + ctx = e.setHeader(ctx) resp, err = handler(ctx, req) if err == nil { return resp, err @@ -37,7 +56,7 @@ func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info return nil, err } func (e *Exception) handlePanic(p interface{}) error { - const maxFrames = 10 + const maxFrames = 15 var pcs [maxFrames]uintptr n := runtime.Callers(2, pcs[:]) // skip first 3 frames @@ -60,18 +79,50 @@ func (e *Exception) handlePanic(p interface{}) error { func (e *Exception) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { defer func() { if p := recover(); p != nil { - // buf := make([]byte, 512) - // n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 - // stackTrace := string(buf[:n]) - - // err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - // err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") - // _ = ss.SendMsg(err) err = e.handlePanic(p) } + }() + e.setHeader(ss.Context()) return handler(srv, ss) } +func (e *Exception) wrapHandlerWithPanicRecovery(handler grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) (err error) { + // reqId := "" + defer func() { + if p := recover(); p != nil { + err = e.handlePanic(p) + } + }() + + return handler(srv, stream) + } +} +func (e *Exception) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + reqId := "" + md, ok := metadata.FromOutgoingContext(ctx) + if !ok || len(md.Get(XRequestID)) == 0 { + reqID := uuid.New().String() + reqId = reqID + if !ok { + md = metadata.New(make(map[string]string)) + } + md.Set(XRequestID, reqId) + ctx = metadata.NewOutgoingContext(ctx, md) + + } else { + reqId = md.Get(XRequestID)[0] + } + + _ = grpc.SetHeader(ctx, metadata.Pairs(XRequestID, reqId)) + desc.Handler = e.wrapHandlerWithPanicRecovery(desc.Handler) + ss, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, err + + } + return ss, nil +} func NewException(log logger.Logger) *Exception { return &Exception{log: log, name: "exception"} } diff --git a/gateway/exception_test.go b/gateway/exception_test.go index be8175a..2f497bd 100644 --- a/gateway/exception_test.go +++ b/gateway/exception_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/begonia-org/go-sdk/logger" + "github.com/google/uuid" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type MiddlewaresTest struct { @@ -39,6 +41,21 @@ func (e *MiddlewaresTest) Name() string { return e.name } +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func TestUnaryInterceptor(t *testing.T) { c.Convey("TestUnaryInterceptor", t, func() { mid := NewException(Log) @@ -64,3 +81,33 @@ func TestUnaryInterceptor(t *testing.T) { }) } +func TestExceptionStreamClientInterceptor(t *testing.T) { + c.Convey("TestExceptionStreamClientInterceptor", t, func() { + mid := NewException(Log) + ctx := context.Background() + + desc := &grpc.StreamDesc{ + StreamName: "/INTEGRATION.TESTSERVICE/GET", + ClientStreams: true, + ServerStreams: true, + Handler: func(srv interface{}, ss grpc.ServerStream) error { + panic("test painc") + }, + } + st, err := mid.StreamClientInterceptor(ctx, desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: context.Background()}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + // has request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, uuid.New().String())), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: context.Background()}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, nil) + c.So(err, c.ShouldNotBeNil) + + }) +} diff --git a/gateway/gateway.go b/gateway/gateway.go index b5b6eca..926d613 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -18,7 +18,7 @@ import ( ) type GrpcServerOptions struct { - Middlewares []GrpcProxyMiddleware + Middlewares []grpc.StreamClientInterceptor Options []grpc.ServerOption PoolOptions []loadbalance.PoolOptionsBuildOption HttpMiddlewares []runtime.ServeMuxOption @@ -28,6 +28,8 @@ type GatewayConfig struct { GatewayAddr string GrpcProxyAddr string } +type DynamicGrpcItem struct { +} type GatewayServer struct { grpcServer *grpc.Server httpGateway HttpEndpoint @@ -37,14 +39,15 @@ type GatewayServer struct { proxyAddr string opts *GrpcServerOptions mux *sync.Mutex + proxy *GrpcProxy } -func NewGrpcServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *grpc.Server { +func NewGrpcProxyServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *GrpcProxy { - proxy := NewGrpcProxy(lb, opts.Middlewares...) + proxy := NewGrpcProxy(lb, Log, opts.Middlewares...) - opts.Options = append(opts.Options, grpc.UnknownServiceHandler(proxy.Handler)) - return grpc.NewServer(opts.Options...) + opts.Options = append(opts.Options, grpc.UnknownServiceHandler(proxy.Do)) + return proxy } func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) (HttpEndpoint, error) { @@ -57,7 +60,9 @@ func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) ( } func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { lb := NewGrpcLoadBalancer() - grpcServer := NewGrpcServer(opts, lb) + gProxy := NewGrpcProxyServer(opts, lb) + opts.Options = append(opts.Options, grpc.UnknownServiceHandler(gProxy.Do)) + grpcServer := grpc.NewServer(opts.Options...) _, port, _ := net.SplitHostPort(cfg.GrpcProxyAddr) proxy := fmt.Sprintf("127.0.0.1:%s", port) @@ -77,6 +82,7 @@ func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { proxyAddr: cfg.GrpcProxyAddr, opts: opts, mux: &sync.Mutex{}, + proxy: gProxy, } // }) return gatewayS @@ -97,6 +103,13 @@ func (g *GatewayServer) RegisterLocalService(ctx context.Context, pd ProtobufDes g.grpcServer.RegisterService(sd, ss) return g.httpGateway.RegisterHandlerClient(ctx, pd, g.gatewayMux) } +func (g *GatewayServer) RegisterServiceWithProxy(pd ProtobufDescription) { + g.mux.Lock() + defer g.mux.Unlock() + g.proxy.buildServiceDesc(pd) + +} + func (g *GatewayServer) DeleteLocalService(pd ProtobufDescription) { g.mux.Lock() defer g.mux.Unlock() diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go new file mode 100644 index 0000000..1283661 --- /dev/null +++ b/gateway/gateway_test.go @@ -0,0 +1,35 @@ +package gateway + +import ( + "fmt" + "log" + "os" + "path/filepath" + "testing" + + "github.com/begonia-org/begonia/internal/pkg/config" + common "github.com/begonia-org/go-sdk/common/api/v1" +) + +func readDesc(conf *config.Config) (ProtobufDescription, error) { + desc := conf.GetLocalAPIDesc() + log.Printf("read desc file:%s", desc) + bin, err := os.ReadFile(desc) + if err != nil { + return nil, fmt.Errorf("read desc file error:%w", err) + } + pd, err := NewDescriptionFromBinary(bin, filepath.Dir(desc)) + if err != nil { + return nil, err + } + err = pd.SetHttpResponse(common.E_HttpResponse) + if err != nil { + return nil, err + } + return pd, nil +} +func TestRegisterDynamicServices(t *testing.T) { + // pd, _ := readDesc(config.NewConfig(cfg.ReadConfig("test"))) + // _ = &GatewayServer{} + // gw.buildServiceDesc(pd) +} diff --git a/gateway/http.go b/gateway/http.go index db10a89..2e28889 100644 --- a/gateway/http.go +++ b/gateway/http.go @@ -5,8 +5,10 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "io" + "log" "net/http" "os" "strings" @@ -88,6 +90,9 @@ func loadHttpEndpointItem(pd ProtobufDescription, descFile string) ([]*HttpEndpo return nil, fmt.Errorf("Failed to unmarshal %s file: %w,%s", descFile, err, string(data)) } for _, binds := range items { + if len(binds) == 0 { + continue + } item := binds[0] // 设置入参和出参 item.In = pd.GetMessageTypeByName(item.InPkg, item.InName) @@ -208,7 +213,7 @@ func (h *HttpEndpointImpl) serverStreamRequest(ctx context.Context, item *HttpEn dec := marshaler.NewDecoder(req.Body) err := dec.Decode(protoReq) - if err != nil && err != io.EOF { + if err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } } @@ -368,6 +373,7 @@ func (h *HttpEndpointImpl) DeleteEndpoint(ctx context.Context, pd ProtobufDescri } return nil } + func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd ProtobufDescription, mux *runtime.ServeMux) error { h.mux.Lock() defer h.mux.Unlock() @@ -384,12 +390,17 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu // log.Printf("register endpoint %s: %s %v", strings.ToUpper(item.HttpMethod), item.HttpUri, item.Pattern) mux.Handle(strings.ToUpper(item.HttpMethod), item.Pattern, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { if req.Header.Get("accept") == "" || req.Header.Get("accept") == "*/*" { - req.Header.Set("accept", "application/json") + if item.IsServerStream && !item.IsClientStream { + req.Header.Set("accept", "text/event-stream") + } else if !item.IsClientStream && !item.IsServerStream { + req.Header.Set("accept", "application/json") + } } // log.Printf("request content-type:%s", req.Header.Get("content-type")) ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + // log.Printf("outbound marshaler:%s", outboundMarshaler.ContentType(req)) var err error var annotatedContext context.Context // 添加sha256 hash @@ -421,12 +432,12 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } resp, md, err := h.client.Request(reqInstance) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } + // log.Printf("response marshaler:%s",outboundMarshaler.ContentType(resp)) runtime.ForwardResponseMessage(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) } else if item.IsServerStream && !item.IsClientStream { // 服务端推流,升级为sse服务 @@ -436,9 +447,14 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") recv := func() (proto.Message, error) { - return resp.Recv() + + rsp, err := resp.Recv() + return rsp, err + } runtime.ForwardResponseStream(annotatedContext, mux, outboundMarshaler, w, req, recv, mux.GetForwardResponseOptions()...) } else if !item.IsServerStream && item.IsClientStream { @@ -460,6 +476,7 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } // defer ws.Close() + log.Printf("upgrade to websocket:%v", outboundMarshaler.ContentType(req)) stream, md, err := h.stream(annotatedContext, item, inboundMarshaler, ws) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) diff --git a/gateway/http_test.go b/gateway/http_test.go index b1cecfb..bfe6b37 100644 --- a/gateway/http_test.go +++ b/gateway/http_test.go @@ -51,7 +51,7 @@ var eps []loadbalance.Endpoint func newTestServer(gwPort, randomNumber int) (*GrpcServerOptions, *GatewayConfig) { opts := &GrpcServerOptions{ - Middlewares: make([]GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -68,6 +68,8 @@ func newTestServer(gwPort, randomNumber int) (*GrpcServerOptions, *GatewayConfig opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption(gwRuntime.MIMEWildcard, NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption("application/octet-stream", NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption("text/event-stream", NewEventSourceMarshaler())) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithStreamErrorHandler(HandleServerStreamError(Log))) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption(ClientStreamContentType, NewProtobufWithLengthPrefix())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMetadata(IncomingHeadersToMetadata)) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithErrorHandler(HandleErrorWithLogger(Log))) @@ -131,6 +133,7 @@ func testRegisterClient(t *testing.T) { load, err := loadbalance.New(loadbalance.RRBalanceType, endps) c.So(err, c.ShouldBeNil) + gw.RegisterServiceWithProxy(pd) err = gw.RegisterService(context.Background(), pd, load) c.So(err, c.ShouldBeNil) c.So(gw.GetLoadbalanceName(), c.ShouldEqual, loadbalance.RRBalanceType) @@ -319,8 +322,6 @@ func testRequestPost(t *testing.T) { func testServerSideEvent(t *testing.T) { c.Convey("test server side event", t, func() { url := fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/server/sse/world?msg=hello", gwPort) - // t.Logf("url:%s", url) - // time.Sleep(30 * time.Second) client := sse.NewClient(url, func(c *sse.Client) { c.ReconnectStrategy = &backoff.StopBackOff{} }) @@ -364,6 +365,7 @@ func testWebsocket(t *testing.T) { _, message, err := conn.ReadMessage() c.So(err, c.ShouldBeNil) reply := &hello.HelloReply{} + // t.Logf("read message:%s", string(message)) err = json.Unmarshal(message, reply) c.So(err, c.ShouldBeNil) c.So(reply.Message, c.ShouldEqual, fmt.Sprintf("hello-%d-%d", i, i)) @@ -661,6 +663,10 @@ func testRequestError(t *testing.T) { patch: (*GrpcLoadBalancer).Select, output: []interface{}{nil, fmt.Errorf("test select error")}, }, + { + patch: (*goloadbalancer.ConnPool).Get, + output: []interface{}{nil, fmt.Errorf("test get error")}, + }, { patch: (*GrpcProxy).forwardServerToClient, output: []interface{}{errChan}, diff --git a/gateway/middlewares.go b/gateway/middlewares.go index 29e8629..ef6ef24 100644 --- a/gateway/middlewares.go +++ b/gateway/middlewares.go @@ -2,13 +2,14 @@ package gateway import ( "context" + "fmt" "net/http" "strconv" "strings" "time" gosdk "github.com/begonia-org/go-sdk" - _ "github.com/begonia-org/go-sdk/api" + // _ "github.com/begonia-org/go-sdk/api" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" "github.com/google/uuid" @@ -189,12 +190,45 @@ func (log *LoggerMiddleware) UnaryInterceptor(ctx context.Context, req interface } }() - + // fmt.Printf("call logger mid\n") rsp, err = handler(ctx, req) elapsed := time.Since(now) + // fmt.Printf("logger error:%v", err) log.logger(ctx, info.FullMethod, err, elapsed) return } +func (log *LoggerMiddleware) wrapHandlerWithLogger(handler grpc.StreamHandler) grpc.StreamHandler { + return func(srv interface{}, ss grpc.ServerStream) (err error) { + now := time.Now() + defer func() { + if r := recover(); r != nil { + elapsed := time.Since(now) + method, _ := grpc.Method(ss.Context()) + err = fmt.Errorf("handle err:%v", r) + log.logger(ss.Context(), method, fmt.Errorf("handle err:%v", r), elapsed) + } + if md, ok := metadata.FromIncomingContext(ss.Context()); ok { + reqId := md.Get(XRequestID) + if len(reqId) > 0 { + _ = grpc.SendHeader(ss.Context(), metadata.Pairs(XRequestID, reqId[0])) + } + } + }() + err = handler(srv, ss) + elapsed := time.Since(now) + method, _ := grpc.Method(ss.Context()) + + log.logger(ss.Context(), method, err, elapsed) + return err + } + +} +func (log *LoggerMiddleware) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + + desc.Handler = log.wrapHandlerWithLogger(desc.Handler) + return streamer(ctx, desc, cc, method, opts...) + +} func (log *LoggerMiddleware) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { now := time.Now() defer func() { @@ -252,6 +286,42 @@ func clientMessageFromCode(code codes.Code) string { } } +// HandleServerStreamError handle server stream error +// +// convert error to status.Status, and get error message from error details, +// try to get error message from error details message or ToClientMessage, if not found, get error message from error code, +// if not found, return "internal error" +func HandleServerStreamError(logger logger.Logger) runtime.StreamErrorHandlerFunc { + return func(ctx context.Context, err error) *status.Status { + if st, ok := status.FromError(err); ok { + details := st.Details() + message := clientMessageFromCode(st.Code()) + for _, detail := range details { + var errDetail *common.Errors = new(common.Errors) + + if d, ok := detail.(*common.Errors); ok { + + errDetail = d + } else if anyType, ok := detail.(*anypb.Any); ok { + if err := anyType.UnmarshalTo(errDetail); err != nil { + Log.Errorf(ctx, "unmarshal error details err:%v", err) + continue + } + } + if errDetail.Message != "" { + message = errDetail.Message + } + if errDetail.ToClientMessage != "" { + message = errDetail.ToClientMessage + } + } + return status.New(st.Code(), message) + } + return status.New(codes.Internal, fmt.Sprintf("Unknown error:%v", err)) + + } + +} func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { codes := getClientMessageMap() return func(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, req *http.Request, err error) { @@ -272,48 +342,53 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { data := &common.HttpResponse{} data.Code = int32(common.Code_INTERNAL_ERROR) data.Message = "internal error" + // fmt.Printf("sse error type:%T, error:%v", err, err) if st, ok := status.FromError(err); ok { msg := st.Message() details := st.Details() data.Message = clientMessageFromCode(st.Code()) - // log.Info(ctx, fmt.Sprintf("error message:%s", data.Message)) - // data.Data = &structpb.Struct{} for _, detail := range details { - if anyType, ok := detail.(*anypb.Any); ok { - var errDetail common.Errors - if err := anyType.UnmarshalTo(&errDetail); err == nil { - rspCode := float64(errDetail.Code) - log = log.WithFields(logrus.Fields{ - "status": int(rspCode), - "file": errDetail.File, - "line": errDetail.Line, - "fn": errDetail.Fn, - }) - - msg := codes[int32(errDetail.Code)] - // log.Infof(ctx, "error message:%s,err code:%d", msg, errDetail.Code) - // log.Infof(ctx, "codes map:%v", codes) - if errDetail.ToClientMessage != "" { - msg = errDetail.ToClientMessage - } - - data.Code = errDetail.Code - data.Message = msg - data.Data = &structpb.Struct{} - break + var errDetail *common.Errors = new(common.Errors) + + if d, ok := detail.(*common.Errors); ok { + + errDetail = d + } else if anyType, ok := detail.(*anypb.Any); ok { + if err := anyType.UnmarshalTo(errDetail); err != nil { + log.Errorf(ctx, "error type:%T, error:%v", err, err) + continue } } else { log.Errorf(ctx, "error type:%T, error:%v", err, err) } + if errDetail.Message != "" { + rspCode := float64(errDetail.Code) + log = log.WithFields(logrus.Fields{ + "status": int(rspCode), + "file": errDetail.File, + "line": errDetail.Line, + "fn": errDetail.Fn, + }) + + msg := codes[int32(errDetail.Code)] + if errDetail.ToClientMessage != "" { + msg = errDetail.ToClientMessage + } + data.Code = errDetail.Code + data.Message = msg + data.Data = &structpb.Struct{} + break + } } + // fmt.Printf("error message:%s,err code:%d", data.Message, st.Code()) code = runtime.HTTPStatusFromCode(st.Code()) log.WithField("status", code).Errorf(ctx, msg) w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) - // log.Infof(ctx, "error message:%s,err code:%d", data.Message, data.Code) + log.Errorf(ctx, "error message:%s,err code:%d", data.Message, data.Code) bData, _ := protojson.Marshal(data) _, _ = w.Write(bData) return @@ -331,8 +406,9 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { for _, v := range value { w.Header().Del(key) if v != "" { + // log.Printf("http key:%s, value:%s", httpKey, v) if strings.EqualFold(httpKey, "Content-Type") { - if v == "application/grpc" { + if strings.EqualFold(v, "application/grpc") { continue } w.Header().Set(httpKey, v) @@ -347,6 +423,9 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { return } } + if strings.HasPrefix(strings.ToLower(httpKey), strings.ToLower("Grpc-")) { + continue + } headers = append(headers, http.CanonicalHeaderKey(httpKey)) w.Header().Set("Access-Control-Expose-Headers", strings.Join(headers, ",")) @@ -358,32 +437,37 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { } func HttpResponseBodyModify(ctx context.Context, w http.ResponseWriter, msg proto.Message) error { httpCode := http.StatusOK - for key, value := range w.Header() { - if strings.HasPrefix(key, "Grpc-Metadata-") { + // log.Printf("response message header:%v", w.Header()) + for key := range w.Header() { + if strings.HasPrefix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("Grpc-")) { w.Header().Del(key) + continue } - // log.Printf("send to client rsp header,key:%s,value:%s", key, value[0]) - writeHttpHeaders(w, key, value) - if strings.HasSuffix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("X-Http-Code")) { - codeStr := value[0] - code, err := strconv.ParseInt(codeStr, 10, 32) - if err != nil { - Log.Error(ctx, err) - return status.Error(codes.Internal, err.Error()) - } - httpCode = int(code) + } + if out, ok := runtime.ServerMetadataFromContext(ctx); ok { + // log.Printf("response message header from server metadata:%v", out.HeaderMD) + for key, value := range out.HeaderMD { - } + if strings.HasPrefix(strings.ToLower(key), strings.ToLower("Grpc-")) || strings.EqualFold(key, "content-type") { + continue - } + } + writeHttpHeaders(w, key, value) + if strings.HasSuffix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("X-Http-Code")) { + codeStr := value[0] + code, err := strconv.ParseInt(codeStr, 10, 32) + if err != nil { + Log.Error(ctx, err) + return status.Error(codes.Internal, err.Error()) + } + httpCode = int(code) + + } - out, ok := metadata.FromIncomingContext(ctx) - if ok { - for k, v := range out { - writeHttpHeaders(w, k, v) } } + if httpCode != http.StatusOK { w.WriteHeader(httpCode) } diff --git a/gateway/middlewares_test.go b/gateway/middlewares_test.go index 688aa41..c647f16 100644 --- a/gateway/middlewares_test.go +++ b/gateway/middlewares_test.go @@ -2,14 +2,21 @@ package gateway import ( "context" + "fmt" "net/http" "testing" + "github.com/agiledragon/gomonkey/v2" hello "github.com/begonia-org/go-sdk/api/example/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/google/uuid" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/anypb" ) type responseWriter struct { @@ -49,11 +56,46 @@ func TestLoggerMiddlewares(t *testing.T) { } c.So(f, c.ShouldNotPanic) f2 := func() { - _ = mid.StreamInterceptor(nil, &streamMock{}, &grpc.StreamServerInfo{FullMethod: "/test"}, func(srv interface{}, ss grpc.ServerStream) error { + _ = mid.StreamInterceptor(nil, &streamMock{ctx: context.Background()}, &grpc.StreamServerInfo{FullMethod: "/test"}, func(srv interface{}, ss grpc.ServerStream) error { panic("test") }) } + // f2() c.So(f2, c.ShouldNotPanic) + + desc := &grpc.StreamDesc{ + StreamName: "/INTEGRATION.TESTSERVICE/GET", + ClientStreams: true, + ServerStreams: true, + Handler: func(srv interface{}, ss grpc.ServerStream) error { + panic("test painc") + }, + } + patch := gomonkey.ApplyFuncReturn(grpc.Method, "/INTEGRATION.TESTSERVICE/GET", true) + defer patch.Reset() + st, err := mid.StreamClientInterceptor(context.Background(), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + // has request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, uuid.New().String())), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, &streamMock{ctx: context.Background()}) + c.So(err, c.ShouldNotBeNil) + // no request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, "test")), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, "test"))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, &streamMock{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs(XRequestID, "test"))}) + c.So(err, c.ShouldNotBeNil) + }) } @@ -72,7 +114,73 @@ func TestHttpResponseBodyModify(t *testing.T) { resp := &responseWriter{header: make(http.Header)} ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(XAccessKey, "123456")) + header := metadata.New(map[string]string{"Content-Type": "application/grpc", "Grpc-Metadata-key": "value", "Grpc-Key": "grpc-val"}) + patch := gomonkey.ApplyFuncReturn(runtime.ServerMetadataFromContext, runtime.ServerMetadata{HeaderMD: header, TrailerMD: metadata.New(make(map[string]string))}, true) + defer patch.Reset() resp2 := HttpResponseBodyModify(ctx, resp, &hello.HelloReply{}) c.So(resp2, c.ShouldBeNil) + for k, v := range header { + writeHttpHeaders(resp, k, v) + } + }) +} + +func TestHandleErrorWithLogger(t *testing.T) { + c.Convey("TestHandleErrorWithLogger", t, func() { + f := HandleErrorWithLogger(Log) + resp := &responseWriter{header: make(http.Header)} + req, _ := http.NewRequest("Get", "http://www.example.com", nil) + st := status.New(codes.NotFound, "not found") + srvErr := &common.Errors{ + Code: int32(common.Code_NOT_FOUND), + Message: "not found", + Action: "action", + File: "file", + Line: int32(0), + Fn: "funcName", + } + st, _ = st.WithDetails(srvErr) + f(metadata.NewIncomingContext(context.Background(), metadata.Pairs(XRequestID, "123456")), &runtime.ServeMux{}, nil, resp, req, st.Err()) + st1 := status.New(codes.NotFound, "not found") + st1, _ = st1.WithDetails(&common.APIResponse{}) + f(metadata.NewIncomingContext(context.Background(), metadata.Pairs(XRequestID, "123456")), &runtime.ServeMux{}, nil, resp, req, st1.Err()) + + }) +} + +func TestHandleServerStreamError(t *testing.T) { + c.Convey("TestHandleServerStreamError", t, func() { + f := HandleServerStreamError(Log) + st := status.New(codes.NotFound, "not found") + srvErr := &common.Errors{ + Code: int32(common.Code_NOT_FOUND), + Message: "not found", + Action: "action", + File: "file", + Line: int32(0), + Fn: "funcName", + ToClientMessage: "not found resource", + } + st, _ = st.WithDetails(srvErr) + err := f(context.Background(), st.Err()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "not found resource") + + err = f(context.Background(), status.Error(codes.Internal, "internal error")) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "Unknown error") + err = f(context.Background(), fmt.Errorf("test error")) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "test error") + + patch := gomonkey.ApplyFuncReturn((*anypb.Any).UnmarshalTo, fmt.Errorf("test error")) + defer patch.Reset() + ay, _ := anypb.New(srvErr) + st = status.New(codes.NotFound, "not found") + + st, _ = st.WithDetails(ay) + err = f(context.Background(), st.Err()) + c.So(err, c.ShouldNotBeNil) + c.So(err.Err().Error(), c.ShouldContainSubstring, "The requested resource is not found.") }) } diff --git a/gateway/protobuf.go b/gateway/protobuf.go index d1b80b9..2bce1d6 100644 --- a/gateway/protobuf.go +++ b/gateway/protobuf.go @@ -159,6 +159,7 @@ func NewDescriptionFromBinary(data []byte, outDir string) (ProtobufDescription, return nil, err } // desc.gatewayJsonSchema = filepath.Join(outDir, "gateway.json") + // log.Printf("GetFileDescriptorSet result is :%v",desc.GetFileDescriptorSet()) contents, err := register.Register(desc.GetFileDescriptorSet(), false, "") if err != nil { return nil, fmt.Errorf("Failed to register: %w", err) @@ -196,6 +197,7 @@ func (p *protobufDescription) GetMessageTypeByFullName(fullName string) protoref v := desc.(protoreflect.MessageDescriptor) return v } + // log.Printf("GetMessageTypeByFullName failed:%s", fullName) return nil } func (p *protobufDescription) GetGatewayJsonSchema() string { diff --git a/gateway/grpc.go b/gateway/proxy.go similarity index 58% rename from gateway/grpc.go rename to gateway/proxy.go index fe59bdf..77bc98b 100644 --- a/gateway/grpc.go +++ b/gateway/proxy.go @@ -2,13 +2,16 @@ package gateway import ( "context" + "errors" "fmt" "io" + "runtime/debug" "strings" "sync" "time" loadbalance "github.com/begonia-org/go-loadbalancer" + "github.com/begonia-org/go-sdk/logger" "github.com/spark-lence/tiga" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -16,7 +19,8 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) type grpcEndpointImpl struct { @@ -128,17 +132,34 @@ func (g *GrpcLoadBalancer) Select(method string, args ...interface{}) (loadbalan return nil, loadbalance.ErrNoEndpoint } +type IOType struct { + In protoreflect.MessageDescriptor + Out protoreflect.MessageDescriptor +} type GrpcProxyMiddleware func(srv interface{}, serverStream grpc.ServerStream) error type GrpcProxy struct { - lb *GrpcLoadBalancer - middlewares []GrpcProxyMiddleware + lb *GrpcLoadBalancer + // middlewares []grpc.StreamServerInterceptor + // chainStreamInts []grpc.StreamServerInterceptor + streamInt grpc.StreamClientInterceptor + chainClientStream []grpc.StreamClientInterceptor + ioType map[string]*IOType + log logger.Logger } -func NewGrpcProxy(lb *GrpcLoadBalancer, middlewares ...GrpcProxyMiddleware) *GrpcProxy { - return &GrpcProxy{ - lb: lb, - middlewares: middlewares, +func NewGrpcProxy(lb *GrpcLoadBalancer, log logger.Logger, middlewares ...grpc.StreamClientInterceptor) *GrpcProxy { + g := &GrpcProxy{ + lb: lb, + // middlewares: middlewares, + chainClientStream: middlewares, + ioType: make(map[string]*IOType), + log: log, } + g.chainStreamClientInterceptors() + return g +} +func (g *GrpcProxy) Register(lb loadbalance.LoadBalance, pd ProtobufDescription) { + g.lb.Register(lb, pd) } func (g *GrpcProxy) getClientIP(ctx context.Context) (string, error) { peer, ok := peer.FromContext(ctx) @@ -156,14 +177,53 @@ func (g *GrpcProxy) getXForward(ctx context.Context) []string { return md.Get("X-Forwarded-For") } -func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) error { +// chainStreamServerInterceptors chains all stream server interceptors into one. +func (g *GrpcProxy) chainStreamClientInterceptors() { + // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will + // be executed before any other chained interceptors. + interceptors := g.chainClientStream + if g.streamInt != nil { + interceptors = append([]grpc.StreamClientInterceptor{g.streamInt}, g.chainClientStream...) + } - // 执行中间件 - for _, middleware := range g.middlewares { - if err := middleware(srv, serverStream); err != nil { - return err - } + var chainedInt grpc.StreamClientInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = g.chainStreamInterceptors(interceptors) } + + g.streamInt = chainedInt +} + +func (g *GrpcProxy) chainStreamInterceptors(interceptors []grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptors[0](ctx, desc, cc, method, g.getChainStreamHandler(interceptors, 0, streamer)) + } +} + +func (g *GrpcProxy) getChainStreamHandler(interceptors []grpc.StreamClientInterceptor, curr int, finalHandler grpc.Streamer) grpc.Streamer { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptors[curr+1](ctx, desc, cc, method, g.getChainStreamHandler(interceptors, curr+1, finalHandler)) + } +} + +func (g *GrpcProxy) Do(srv interface{}, serverStream grpc.ServerStream) error { + + return g.Handler(srv, serverStream) +} +func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) error { + defer func() { + if p := recover(); p != nil { + s := debug.Stack() + g.log.Errorf(serverStream.Context(), "panic recover! p: %v stack:%s", p, s) + } + }() // 获取方法名 fullMethodName, ok := grpc.MethodFromServerStream(serverStream) if !ok { @@ -193,6 +253,7 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err defer endpoint.AfterTransform(serverStream.Context(), cn.((loadbalance.Connection))) conn := cn.(loadbalance.Connection).ConnInstance().(*grpc.ClientConn) + clientCtx, clientCancel := context.WithCancel(serverStream.Context()) defer clientCancel() proxyDesc := &grpc.StreamDesc{ @@ -202,38 +263,53 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err // 添加本地ip local, _ := tiga.GetLocalIP() xForwards = append(xForwards, local) - md, ok := metadata.FromIncomingContext(clientCtx) + md, ok := metadata.FromIncomingContext(serverStream.Context()) if !ok { md = metadata.MD{} } + md.Set("X-Forwarded-For", strings.Join(xForwards, ",")) clientCtx = metadata.NewOutgoingContext(clientCtx, md) + var clientStream grpc.ClientStream = nil - clientStream, err := grpc.NewClientStream(clientCtx, proxyDesc, conn, fullMethodName) - if err != nil { - return err + if len(g.chainClientStream) > 0 && g.streamInt != nil { + clientStream, err = g.streamInt(clientCtx, proxyDesc, conn, fullMethodName, g.getChainStreamHandler(g.chainClientStream, 0, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return grpc.NewClientStream(ctx, desc, cc, method, opts...) + })) + if err != nil { + return fmt.Errorf("failed creating client stream from stream client int: %v", err) + } + } else { + clientStream, err = grpc.NewClientStream(clientCtx, proxyDesc, conn, fullMethodName) + if err != nil { + return err + } } + + // proxyStream := &proxyClientStream{ClientStream: clientStream, ctx: clientCtx} // 转发流量 // 从客户端到服务端 s2cErrChan := g.forwardServerToClient(serverStream, clientStream) + // 从服务端到客户端 c2sErrChan := g.forwardClientToServer(clientStream, serverStream) for i := 0; i < 2; i++ { select { case s2cErr := <-s2cErrChan: - if s2cErr == io.EOF { + if errors.Is(s2cErr, io.EOF) { // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./ // the clientStream>serverStream may continue pumping though. + // log.Printf("s2cErr:%v", s2cErr) err = clientStream.CloseSend() if err != nil { - return status.Errorf(codes.Internal, "failed closing client stream: %v", err) + return fmt.Errorf("failed closing client stream: %w", err) } } else { // however, we may have gotten a receive error (stream disconnected, a read error etc) in which case we need // to cancel the clientStream to the backend, let all of its goroutines be freed up by the CancelFunc and // exit with an error to the stack clientCancel() - return status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) + return fmt.Errorf("failed proxying s2c: %w", s2cErr) } case c2sErr := <-c2sErrChan: // This happens when the clientStream has nothing else to offer (io.EOF), returned a gRPC error. In those two @@ -241,7 +317,8 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err // will be nil. serverStream.SetTrailer(clientStream.Trailer()) // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. - if c2sErr != io.EOF { + if !errors.Is(c2sErr, io.EOF) { + // log.Printf("c2sErr:%v", c2sErr) return c2sErr } return nil @@ -253,10 +330,21 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error { ret := make(chan error, 1) go func() { - f := &emptypb.Empty{} + defer func() { + if p := recover(); p != nil { + s := debug.Stack() + g.log.Errorf(dst.Context(), "panic recover! p: %v stack:%s", p, s) + ret <- fmt.Errorf("panic recover! p: %v stack:%s", p, s) + } + }() + method, _ := grpc.Method(dst.Context()) + io := g.ioType[method] + f := dynamicpb.NewMessage(io.Out) + // f := &emptypb.Empty{} + for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - ret <- err // this can be io.EOF which is happy case + ret <- fmt.Errorf("fail forward client to server:%w", err) // this can be io.EOF which is happy case break } if i == 0 { @@ -264,18 +352,25 @@ func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.Server // received but must be written to server stream before the first msg is flushed. // This is the only place to do it nicely // 先转发header + // inMD, _ := metadata.FromIncomingContext(src.Context()) + // outMD, _ := metadata.FromOutgoingContext(src.Context()) + + // md := metadata.Join(inMD, outMD) + // m := make(map[string][]string) + md, err := src.Header() if err != nil { - ret <- err + ret <- fmt.Errorf("failed reading header from client stream: %w", err) break } + if err := dst.SendHeader(md); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending header to server stream: %w", err) break } } if err := dst.SendMsg(f); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending msg to server stream: %w", err) break } } @@ -286,17 +381,46 @@ func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.Server func (g *GrpcProxy) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error { ret := make(chan error, 1) go func() { - f := &emptypb.Empty{} + method, _ := grpc.Method(dst.Context()) + io := g.ioType[method] + f := dynamicpb.NewMessage(io.In) + for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - ret <- err // this can be io.EOF which is happy case + ret <- fmt.Errorf("recv msg error:%w from client forward", err) // this can be io.EOF which is happy case break } + if err := dst.SendMsg(f); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending msg to client stream: %w", err) break } } }() return ret } + +func (g *GrpcProxy) buildServiceDesc(pd ProtobufDescription) { + fd := pd.GetFileDescriptorSet() + // sds := make([]*grpc.ServiceDesc, 0) + for _, file := range fd.GetFile() { + sd := file.GetService() + + for _, service := range sd { + + methods := service.GetMethod() + for _, method := range methods { + in := method.GetInputType() + inDesc := pd.GetMessageTypeByFullName(strings.TrimPrefix(in, ".")) + outDesc := pd.GetMessageTypeByFullName(strings.TrimPrefix(method.GetOutputType(), ".")) + srv := fmt.Sprintf("/%s.%s/%s", file.GetPackage(), service.GetName(), method.GetName()) + g.ioType[srv] = &IOType{ + In: inDesc, + Out: outDesc, + } + + } + + } + } +} diff --git a/gateway/grpc_test.go b/gateway/proxy_test.go similarity index 63% rename from gateway/grpc_test.go rename to gateway/proxy_test.go index 2bbacbb..2acf50e 100644 --- a/gateway/grpc_test.go +++ b/gateway/proxy_test.go @@ -14,6 +14,8 @@ import ( "time" "github.com/agiledragon/gomonkey/v2" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/internal/pkg/config" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" hello "github.com/begonia-org/go-sdk/api/example/v1" @@ -26,8 +28,11 @@ import ( ) type streamMock struct { + ctx context.Context +} +type clientStreamMock struct { + ctx context.Context } -type clientStreamMock struct{} func (*streamMock) SendHeader(md metadata.MD) error { return nil @@ -37,10 +42,11 @@ func (*streamMock) SetHeader(md metadata.MD) error { } func (*streamMock) SetTrailer(md metadata.MD) { } -func (*streamMock) Context() context.Context { - return context.Background() +func (s *streamMock) Context() context.Context { + return s.ctx } func (*streamMock) SendMsg(m interface{}) error { + time.Sleep(1 * time.Second) return nil } func (*streamMock) RecvMsg(m interface{}) error { @@ -94,13 +100,27 @@ func TestGrpcHandleErr(t *testing.T) { load, _ := loadbalance.New(loadbalance.RRBalanceType, endps) lb := NewGrpcLoadBalancer() lb.Register(load, pd) - mid := func(srv interface{}, serverStream grpc.ServerStream) error { - return nil + fullMethod1 := "" + mid := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + fullMethod1 = method + return streamer(ctx, desc, cc, method, opts...) + } + fullMethod2 := "" + mid2 := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + fullMethod2 = method + return streamer(ctx, desc, cc, method, opts...) + } + proxy := NewGrpcProxy(lb, Log, mid, mid2) + proxy.buildServiceDesc(pd) + proxy.Register(load, pd) + + stream := &streamMock{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("uri", "/api/v1/example/server/websocket")), } - proxy := NewGrpcProxy(lb, mid) - stream := &streamMock{} patch := gomonkey.ApplyFuncReturn(grpc.MethodFromServerStream, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket"), true) - patch.ApplyFuncReturn(grpc.NewClientStream, &clientStreamMock{}, nil) + patch.ApplyFuncReturn(grpc.NewClientStream, &clientStreamMock{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("uri", "/api/v1/example/server/websocket")), + }, nil) addrs, _ := net.InterfaceAddrs() var localAddr net.Addr for _, addr := range addrs { @@ -121,10 +141,15 @@ func TestGrpcHandleErr(t *testing.T) { err error output []interface{} }{ + // { + // patch: metadata.FromIncomingContext, + // err: fmt.Errorf("metadata not exists in context"), + // output: []interface{}{nil, false}, + // }, { patch: mid, err: fmt.Errorf("mid handle err"), - output: []interface{}{fmt.Errorf("mid handle err")}, + output: []interface{}{nil, fmt.Errorf("mid handle err")}, }, { patch: (*clientStreamMock).CloseSend, @@ -153,14 +178,21 @@ func TestGrpcHandleErr(t *testing.T) { return io.EOF }) defer patch3.Reset() + patch6 := gomonkey.ApplyFuncReturn(grpc.Method, "/helloworld.Greeter/SayHelloWebsocket", true) + defer patch6.Reset() for _, caseV := range cases { patch2 := gomonkey.ApplyFuncReturn(caseV.patch, caseV.output...) defer patch2.Reset() - err = proxy.Handler(&hello.HelloRequest{}, stream) + + err = proxy.Do(&hello.HelloRequest{}, stream) + t.Log(caseV.err.Error()) + // t.Logf("err:%v", err.Error()) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, caseV.err.Error()) patch2.Reset() } + c.So(fullMethod1, c.ShouldEqual, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket")) + c.So(fullMethod2, c.ShouldEqual, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket")) patch3.Reset() errChan2 := make(chan error, 3) @@ -168,12 +200,33 @@ func TestGrpcHandleErr(t *testing.T) { errChan2 <- io.EOF errChan2 <- io.EOF patch4 := gomonkey.ApplyFuncReturn((*GrpcProxy).forwardServerToClient, errChan2) - err = proxy.Handler(&hello.HelloRequest{}, stream) + err = proxy.Do(&hello.HelloRequest{}, stream) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "proxying should never reach") defer patch4.Reset() patch.Reset() + patch4.Reset() + // patch6.Reset() + patch7 := gomonkey.ApplyFuncReturn(grpc.MethodFromServerStream, "/helloworld.Greeter/SayHelloWebsocket", false) + defer patch7.Reset() + + proxy2 := NewGrpcProxy(lb, Log, mid) + err = proxy2.Do(&hello.HelloRequest{}, stream) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "stream not exists in context") + + proxy2.chainStreamClientInterceptors() + f := func() { + proxy.forwardClientToServer(nil, stream) + } + c.So(f, c.ShouldNotPanic) }) } + +func TestProxyDo(t *testing.T) { + pd, _ := readDesc(config.NewConfig(cfg.ReadConfig("test"))) + p := &GrpcProxy{ioType: make(map[string]*IOType), lb: NewGrpcLoadBalancer(), log: Log} + p.buildServiceDesc(pd) +} diff --git a/internal/pkg/routers/routers.go b/gateway/routers.go similarity index 74% rename from internal/pkg/routers/routers.go rename to gateway/routers.go index 52ddb09..b1377e7 100644 --- a/internal/pkg/routers/routers.go +++ b/gateway/routers.go @@ -1,20 +1,10 @@ -package routers +package gateway import ( "fmt" - "log" "strings" "sync" - "github.com/begonia-org/begonia/gateway" - _ "github.com/begonia-org/go-sdk/api/app/v1" - _ "github.com/begonia-org/go-sdk/api/endpoint/v1" - _ "github.com/begonia-org/go-sdk/api/example/v1" - _ "github.com/begonia-org/go-sdk/api/iam/v1" - _ "github.com/begonia-org/go-sdk/api/plugin/v1" - _ "github.com/begonia-org/go-sdk/api/sys/v1" - _ "github.com/begonia-org/go-sdk/api/user/v1" - _ "github.com/begonia-org/go-sdk/common/api/v1" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/protobuf/proto" @@ -53,7 +43,7 @@ func NewHttpURIRouteToSrvMethod() *HttpURIRouteToSrvMethod { }) return httpURIRouteToSrvMethod } -func Get() *HttpURIRouteToSrvMethod { +func GetRouter() *HttpURIRouteToSrvMethod { return NewHttpURIRouteToSrvMethod() } @@ -61,7 +51,8 @@ func (r *HttpURIRouteToSrvMethod) AddRoute(uri string, srvMethod *APIMethodDetai r.mux.Lock() defer r.mux.Unlock() r.routers[uri] = srvMethod - r.grpcRouter[srvMethod.GrpcFullRouter] = srvMethod + // log.Printf("add srv method grpc router:%s,pointer:%p", srvMethod.GrpcFullRouter, r) + r.grpcRouter[strings.ToUpper(srvMethod.GrpcFullRouter)] = srvMethod } func (r *HttpURIRouteToSrvMethod) deleteRoute(uri string, grpcFullMethod string) { delete(r.routers, uri) @@ -72,6 +63,7 @@ func (r *HttpURIRouteToSrvMethod) GetRoute(uri string) *APIMethodDetails { return r.routers[uri] } func (r *HttpURIRouteToSrvMethod) GetRouteByGrpcMethod(method string) *APIMethodDetails { + // log.Printf("get grpc method,%v:%s,pointer:%p",r.grpcRouter,strings.ToUpper(method),r) return r.grpcRouter[strings.ToUpper(method)] } func (r *HttpURIRouteToSrvMethod) GetAllRoutes() map[string]*APIMethodDetails { @@ -86,7 +78,14 @@ func (r *HttpURIRouteToSrvMethod) getServiceOptionByExt(service *descriptorpb.Se } return nil } - +func (r *HttpURIRouteToSrvMethod) getMethodOptionByExt(method *descriptorpb.MethodDescriptorProto, ext protoreflect.ExtensionType) interface{} { + if options := method.GetOptions(); options != nil { + if ext := proto.GetExtension(options, ext); ext != nil { + return ext + } + } + return nil +} func (r *HttpURIRouteToSrvMethod) getHttpRule(method *descriptorpb.MethodDescriptorProto) *annotations.HttpRule { if options := method.GetOptions(); options != nil { if ext := proto.GetExtension(options, annotations.E_Http); ext != nil { @@ -98,7 +97,7 @@ func (r *HttpURIRouteToSrvMethod) getHttpRule(method *descriptorpb.MethodDescrip return nil } func (r *HttpURIRouteToSrvMethod) AddLocalSrv(fullMethod string) { - log.Printf("add local srv:%s", fullMethod) + // log.Printf("add local srv:%s", fullMethod) r.localSrv[strings.ToUpper(fullMethod)] = true } func (r *HttpURIRouteToSrvMethod) IsLocalSrv(fullMethod string) bool { @@ -156,7 +155,13 @@ func (r *HttpURIRouteToSrvMethod) addRouterDetails(serviceName string, useJsonRe } } -func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) { + +// LoadAllRouters Load all routers from protobuf description +// for service methods, if the method has a google.api.http annotation, then add the router +// to the router list, and set the authRequired flag to true if the method has a pb.auth_required annotation, +// if the method has a pb.http_response annotation, then set the useJsonResponse flag to true, +// if the method has a pb.dont_use_http_response annotation, then set the useJsonResponse flag to false. +func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd ProtobufDescription) { fds := pd.GetFileDescriptorSet() for _, fd := range fds.File { for _, service := range fd.Service { @@ -167,13 +172,20 @@ func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) if authRequiredExt := r.getServiceOptionByExt(service, common.E_AuthReqiured); authRequiredExt != nil { authRequired, _ = authRequiredExt.(bool) } - if httpResponseExt := r.getServiceOptionByExt(service, common.E_HttpResponse); httpResponseExt != nil { + if httpResponseExt := r.getServiceOptionByExt(service, common.E_HttpResponse); httpResponseExt != nil && httpResponseExt.(string) != "" { httpResponse = true } // 遍历服务中的所有方法 for _, method := range service.GetMethod() { key := fmt.Sprintf("/%s.%s/%s", fd.GetPackage(), service.GetName(), method.GetName()) - r.addRouterDetails(strings.ToUpper(key), httpResponse, authRequired, method) + // log.Printf("add router:%s,%v", key, httpResponse) + // do not use HttpResponse for this method if it is set + dontUseHttpResponse := r.getMethodOptionByExt(method, common.E_DontUseHttpResponse) + useHttpResponse := httpResponse + if dontUseHttpResponse != nil && dontUseHttpResponse.(bool) { + useHttpResponse = false + } + r.addRouterDetails(strings.ToUpper(key), useHttpResponse, authRequired, method) } } @@ -181,7 +193,7 @@ func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) } -func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd gateway.ProtobufDescription) { +func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd ProtobufDescription) { fds := pd.GetFileDescriptorSet() for _, fd := range fds.File { for _, service := range fd.Service { diff --git a/internal/pkg/routers/routers_test.go b/gateway/routers_test.go similarity index 75% rename from internal/pkg/routers/routers_test.go rename to gateway/routers_test.go index c07cd8d..ed3e8b5 100644 --- a/internal/pkg/routers/routers_test.go +++ b/gateway/routers_test.go @@ -1,4 +1,4 @@ -package routers_test +package gateway_test import ( "path/filepath" @@ -6,15 +6,14 @@ import ( "testing" "github.com/begonia-org/begonia/gateway" - "github.com/begonia-org/begonia/internal/pkg/routers" c "github.com/smartystreets/goconvey/convey" ) func TestLoadAllRouters(t *testing.T) { c.Convey("TestLoadAllRouters", t, func() { - R := routers.NewHttpURIRouteToSrvMethod() + R := gateway.NewHttpURIRouteToSrvMethod() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata") pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) R.LoadAllRouters(pd) @@ -32,14 +31,19 @@ func TestLoadAllRouters(t *testing.T) { d, ok := rs["/test/custom"] c.So(ok, c.ShouldBeTrue) c.So(d.ServiceName, c.ShouldEqual, "/INTEGRATION.TESTSERVICE/CUSTOM") + c.So(d.UseJsonResponse, c.ShouldBeTrue) + + r := rs["/test/body"] + c.So(r, c.ShouldNotBeNil) + c.So(r.UseJsonResponse, c.ShouldBeFalse) }) } func TestDeleteRouters(t *testing.T) { c.Convey("TestDeleteRouters", t, func() { - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + pbFile := filepath.Join((filepath.Dir(filepath.Dir(filename))), "testdata") pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) diff --git a/gateway/serialization.go b/gateway/serialization.go index b63cb05..65f22f2 100644 --- a/gateway/serialization.go +++ b/gateway/serialization.go @@ -15,6 +15,7 @@ import ( common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "google.golang.org/genproto/googleapis/api/httpbody" + spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" @@ -36,25 +37,6 @@ type BinaryDecoder struct { marshaler runtime.Marshaler } -// type JSONDecoder struct{ -// *runtime.DecoderWrapper -// } -// func (d *JSONDecoder) Decode(v interface{}) error { -// if response, ok := v.(map[string]interface{}); ok { -// if _, ok := response["result"]; ok { -// v = response["result"] -// } - -// } -// if msg,ok:=v.(protoreflect.Message);ok{ - -// } -// return d.DecoderWrapper.Decode(v) - -// } -// var typeOfBytes = reflect.TypeOf([]byte(nil)) -// var typeOfHttpbody = reflect.TypeOf(&httpbody.HttpBody{}) - func (d *BinaryDecoder) Decode(v interface{}) error { if v == nil { return nil @@ -147,6 +129,11 @@ func (m *RawBinaryUnmarshaler) Marshal(v interface{}) ([]byte, error) { } } } + if resp, ok := v.(map[string]interface{}); ok { + if _, ok := resp["result"]; ok { + v = resp["result"] + } + } return m.Marshaler.Marshal(v) } @@ -156,17 +143,24 @@ func (m *EventSourceMarshaler) ContentType(v interface{}) string { func (m *EventSourceMarshaler) Marshal(v interface{}) ([]byte, error) { if response, ok := v.(map[string]interface{}); ok { - // result:=response if _, ok := response["result"]; ok { v = response["result"] } - } - // 在这里定义你的自定义序列化逻辑 + if response, ok := v.(map[string]proto.Message); ok { + if _, ok := response["error"]; ok { + v = response["error"] + } + } + // build event stream format line by line if stream, ok := v.(*common.EventStream); ok { line := fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", stream.Id, stream.Event, stream.Retry, stream.Data) return []byte(line), nil - + } + // build error message + if stream, ok := v.(*spb.Status); ok { + line := fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 0, "error", 0, stream.GetMessage()) + return []byte(line), nil } return m.JSONPb.Marshal(v) } @@ -199,7 +193,6 @@ func (m *JSONMarshaler) Marshal(v interface{}) ([]byte, error) { } } - if response, ok := v.(*dynamicpb.Message); ok { // log.Println("实际类型,", response.Type().Descriptor().Name()) byteData, err := m.JSONPb.Marshal(response) diff --git a/gateway/serialization_test.go b/gateway/serialization_test.go index 4eab9e5..0040e03 100644 --- a/gateway/serialization_test.go +++ b/gateway/serialization_test.go @@ -7,10 +7,15 @@ import ( "testing" "github.com/agiledragon/gomonkey/v2" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" c "github.com/smartystreets/goconvey/convey" "google.golang.org/genproto/googleapis/api/httpbody" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/dynamicpb" + "google.golang.org/protobuf/types/known/anypb" ) func TestRawBinaryUnmarshaler(t *testing.T) { @@ -108,3 +113,75 @@ func TestRawBinaryDecodeErr(t *testing.T) { }) } +func TestJSONMarshaler(t *testing.T) { + c.Convey("TestJSONMarshaler", t, func() { + marshaler := NewJSONMarshaler() + data := map[string]interface{}{ + "test": "test", + } + buf, err := marshaler.Marshal(data) + c.So(err, c.ShouldBeNil) + c.So(string(buf), c.ShouldEqual, `{"test":"test"}`) + + httpBody := &httpbody.HttpBody{ + ContentType: "application/octet-stream-test", + Data: []byte("test"), + } + msg2 := dynamicpb.NewMessage(httpBody.ProtoReflect().Descriptor()).New() + patch := gomonkey.ApplyFuncReturn((*runtime.JSONPb).Marshal, nil, fmt.Errorf("runtime.JSONPb{}.Marshal: nil")) + defer patch.Reset() + _, err = marshaler.Marshal(msg2) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "runtime.JSONPb{}.Marshal: nil") + }) +} + +func TestEventSourceMarshaler(t *testing.T) { + c.Convey("TestEventSourceMarshaler", t, func() { + marshaler := NewEventSourceMarshaler() + cases := []struct { + data interface{} + err error + exception string + }{ + { + data: map[string]interface{}{ + "result": &common.EventStream{ + Event: "test", + Id: 1, + Data: "test", + Retry: 0, + }, + }, + err: nil, + exception: fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 1, "test", 0, "test"), + }, + { + data: &common.EventStream{ + Event: "test-data", + Id: 1, + Data: "test-data", + Retry: 0, + }, + err: nil, + exception: fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 1, "test-data", 0, "test-data"), + }, + { + data: map[string]proto.Message{ + "error": &spb.Status{ + Message: "test error", + Code: int32(codes.Internal), + Details: []*anypb.Any{}, + }, + }, + err: nil, + exception: fmt.Sprintf("id: %d\nevent: %s\nretry: %d\ndata: %s\n", 0, "error", 0, "test error"), + }, + } + for _, caseV := range cases { + buf, err := marshaler.Marshal(caseV.data) + c.So(err, c.ShouldBeNil) + c.So(string(buf), c.ShouldEqual, caseV.exception) + } + }) +} diff --git a/gateway/types.go b/gateway/types.go index ed309ac..924f5ed 100644 --- a/gateway/types.go +++ b/gateway/types.go @@ -47,7 +47,7 @@ func (x *serverSideStreamClient) buildEventStreamResponse(dpm *dynamicpb.Message return nil, err } - + // log.Printf("buildEventStreamResponse data:%s", string(data)) commonEvent := &common.EventStream{ Event: string(dpm.Descriptor().Name()), Id: atomic.LoadInt64(&x.ID), diff --git a/gateway/utils_test.go b/gateway/utils_test.go index b368ed0..ee16dc0 100644 --- a/gateway/utils_test.go +++ b/gateway/utils_test.go @@ -16,7 +16,7 @@ import ( func TestNewEndpoint(t *testing.T) { opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), diff --git a/go.mod b/go.mod index 060e446..e53e5b3 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/smartystreets/goconvey v1.8.1 github.com/spark-lence/tiga v0.0.0-20240707025120-c2e1f47f88dc github.com/spf13/cobra v1.8.0 - google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 + google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 ) @@ -80,15 +80,17 @@ require ( require ( github.com/agiledragon/gomonkey/v2 v2.11.0 github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a - github.com/begonia-org/go-sdk v0.0.0-20240707025218-b8159f0b2462 + github.com/begonia-org/go-sdk v0.0.0-20240722164044-30e4fbeca04c github.com/go-git/go-git/v5 v5.11.0 github.com/go-playground/validator/v10 v10.19.0 github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 + github.com/iancoleman/strcase v0.3.0 github.com/minio/minio-go/v7 v7.0.71 github.com/r3labs/sse/v2 v2.10.0 go.etcd.io/etcd/api/v3 v3.5.14 go.etcd.io/etcd/client/v3 v3.5.14 + google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade gopkg.in/cenkalti/backoff.v1 v1.1.0 ) @@ -118,7 +120,6 @@ require ( github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect - github.com/iancoleman/strcase v0.3.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect @@ -140,7 +141,6 @@ require ( go.uber.org/zap v1.27.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/go.sum b/go.sum index 708579d..f271a6b 100644 --- a/go.sum +++ b/go.sum @@ -15,7 +15,6 @@ github.com/agiledragon/gomonkey/v2 v2.11.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoa github.com/allegro/bigcache/v3 v3.1.0 h1:H2Vp8VOvxcrB91o86fUSVJFqeuz8kpyyB02eH3bSzwk= github.com/allegro/bigcache/v3 v3.1.0/go.mod h1:aPyh7jEvrog9zAwx5N7+JUQX5dZTSGpxF1LAR4dr35I= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= -github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= @@ -23,12 +22,8 @@ github.com/begonia-org/go-layered-cache v0.0.0-20240510102605-41bdb7aa07fa h1:DH github.com/begonia-org/go-layered-cache v0.0.0-20240510102605-41bdb7aa07fa/go.mod h1:xEqoca1vNGqH8CV7X9EzhDV5Ihtq9J95p7ZipzUB6pc= github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a h1:Mpw7T+90KC5QW7yCa8Nn/5psnlvsexipAOrQAcc7YE0= github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a/go.mod h1:crPS67sfgmgv47psftwfmTMbmTfdepVm8MPeqApINlI= -github.com/begonia-org/go-sdk v0.0.0-20240704075659-182a1008f0ab h1:zr5JyiG4eIkM80SIkRiu/4ulX0735OMS5vPnzB+9e0s= -github.com/begonia-org/go-sdk v0.0.0-20240704075659-182a1008f0ab/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= -github.com/begonia-org/go-sdk v0.0.0-20240704083802-6c7fd1cb3fbc h1:z8zW/vMZkEOxrYSKY6N/yX+tVhSqt1kWJKPuUcJX4to= -github.com/begonia-org/go-sdk v0.0.0-20240704083802-6c7fd1cb3fbc/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= -github.com/begonia-org/go-sdk v0.0.0-20240707025218-b8159f0b2462 h1:qd5Fim08aHI+RlBqyZhJEWtuIm9Q03DMm+AToh4Uprc= -github.com/begonia-org/go-sdk v0.0.0-20240707025218-b8159f0b2462/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= +github.com/begonia-org/go-sdk v0.0.0-20240722164044-30e4fbeca04c h1:Al780pVPkCmZ9n+NrjFZbTVzbKrd82dVMVfKn+twacQ= +github.com/begonia-org/go-sdk v0.0.0-20240722164044-30e4fbeca04c/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -123,7 +118,6 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= @@ -169,7 +163,6 @@ github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ib github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= -github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -211,7 +204,6 @@ github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0= github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I= github.com/redis/go-redis/v9 v9.5.3 h1:fOAp1/uJG+ZtcITgZOfYFmTKPE7n4Vclj1wZFgRciUU= github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= -github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= @@ -239,8 +231,6 @@ github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sS github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/spark-lence/tiga v0.0.0-20240704075559-c825086a1249 h1:Rk1qEMiMzTeZddgGotXAtXNK4sryZh2hhG3i5ncFQ0Y= -github.com/spark-lence/tiga v0.0.0-20240704075559-c825086a1249/go.mod h1:h7BTZeR6xD6+tr3ClEhHC1PeXPOn3jRt7NnThQg1JvE= github.com/spark-lence/tiga v0.0.0-20240707025120-c2e1f47f88dc h1:F+20XqYEhBTSprENDS0dcNznnwPV2Nq0mT6H+yUB/po= github.com/spark-lence/tiga v0.0.0-20240707025120-c2e1f47f88dc/go.mod h1:h7BTZeR6xD6+tr3ClEhHC1PeXPOn3jRt7NnThQg1JvE= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= @@ -309,8 +299,6 @@ golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= @@ -339,8 +327,6 @@ golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -371,8 +357,6 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -384,7 +368,6 @@ golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -412,10 +395,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:BwIjyKYGsK9dMCBOorzRri8MQwmi7mT9rGHsCEinZkA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade h1:WxZOF2yayUHpHSbUE6NMzumUzBxYc3YGwo0YHnbzsJY= +google.golang.org/genproto/googleapis/api v0.0.0-20240722135656-d784300faade/go.mod h1:mw8MG/Qz5wfgYr6VqVCiZcHe/GJEfI+oGGDCohaVgB0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade h1:oCRSWfwGXQsqlVdErcyTt4A93Y8fo0/9D4b1gnI++qo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240722135656-d784300faade/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= @@ -425,7 +408,6 @@ gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UD gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= diff --git a/internal/biz/aksk.go b/internal/biz/aksk.go index 6fa1329..7d81263 100644 --- a/internal/biz/aksk.go +++ b/internal/biz/aksk.go @@ -5,9 +5,9 @@ import ( "strings" "time" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -30,7 +30,7 @@ func NewAccessKeyAuth(app AppRepo, config *config.Config, log logger.Logger) *Ac } func IfNeedValidate(ctx context.Context, fullMethod string) bool { - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(strings.ToUpper(fullMethod)) if router == nil { return false diff --git a/internal/biz/aksk_test.go b/internal/biz/aksk_test.go index 373cc6e..06aef13 100644 --- a/internal/biz/aksk_test.go +++ b/internal/biz/aksk_test.go @@ -16,7 +16,6 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/pkg/utils" gosdk "github.com/begonia-org/go-sdk" @@ -138,7 +137,7 @@ func testIfNeedValidate(t *testing.T) { ok := biz.IfNeedValidate(context.TODO(), akskAccess) c.So(ok, c.ShouldBeFalse) - patch := gomonkey.ApplyFuncReturn((*routers.HttpURIRouteToSrvMethod).GetRouteByGrpcMethod, &routers.APIMethodDetails{AuthRequired: true}) + patch := gomonkey.ApplyFuncReturn((*gateway.HttpURIRouteToSrvMethod).GetRouteByGrpcMethod, &gateway.APIMethodDetails{AuthRequired: true}) defer patch.Reset() ok = biz.IfNeedValidate(context.TODO(), akskAccess) c.So(ok, c.ShouldBeTrue) diff --git a/internal/biz/data_test.go b/internal/biz/data_test.go index c8c693a..91c5958 100644 --- a/internal/biz/data_test.go +++ b/internal/biz/data_test.go @@ -116,7 +116,7 @@ func TestDo(t *testing.T) { _ = cache.Del(context.Background(), "begonia:user:black:lock") _ = cache.Del(context.Background(), "begonia:user:black:last_updated") opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), diff --git a/internal/biz/endpoint/endpoint_test.go b/internal/biz/endpoint/endpoint_test.go index b669b22..549050f 100644 --- a/internal/biz/endpoint/endpoint_test.go +++ b/internal/biz/endpoint/endpoint_test.go @@ -21,7 +21,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gwRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" goloadbalancer "github.com/begonia-org/go-loadbalancer" @@ -374,7 +373,7 @@ func testWatcherUpdate(t *testing.T) { } val, _ := json.Marshal(value) opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -385,12 +384,12 @@ func testWatcherUpdate(t *testing.T) { GrpcProxyAddr: "127.0.0.1:12148", } gateway.New(gwCnf, opts) - routers.NewHttpURIRouteToSrvMethod() + gateway.NewHttpURIRouteToSrvMethod() c.Convey("Test Watcher Update", t, func() { err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldBeNil) - r := routers.Get() + r := gateway.GetRouter() detail := r.GetRoute("/api/v1/example/{name}") c.So(detail, c.ShouldNotBeNil) @@ -433,6 +432,27 @@ func testWatcherUpdate(t *testing.T) { c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrUnknownLoadBalancer.Error()) + // SetHttpResponse err + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + conf := config.ReadConfig(env) + cnf := cfg.NewConfig(conf) + outDir := cnf.GetGatewayDescriptionOut() + _, filename, _, _ := runtime.Caller(0) + + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata", "helloworld.pb") + pb, err := os.ReadFile(pbFile) + c.So(err, c.ShouldBeNil) + pd, err := gateway.NewDescriptionFromBinary(pb, filepath.Join(outDir, "tmp-test")) + c.So(err, c.ShouldBeNil) + patch6 := gomonkey.ApplyMethodReturn(pd, "SetHttpResponse", fmt.Errorf("test SetHttpResponse error")) + defer patch6.Reset() + err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "test SetHttpResponse error") }) } func testWatcherDel(t *testing.T) { @@ -453,7 +473,7 @@ func testWatcherDel(t *testing.T) { c.Convey("Test Watcher Del", t, func() { err := watcher.Handle(context.TODO(), mvccpb.DELETE, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldBeNil) - r := routers.Get() + r := gateway.GetRouter() detail := r.GetRoute("/api/v1/example/{name}") c.So(detail, c.ShouldBeNil) }) diff --git a/internal/biz/endpoint/utils.go b/internal/biz/endpoint/utils.go index bcd5d4e..2de7aca 100644 --- a/internal/biz/endpoint/utils.go +++ b/internal/biz/endpoint/utils.go @@ -7,14 +7,13 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/grpc/codes" ) func deleteAll(ctx context.Context, pd gateway.ProtobufDescription) error { - routersList := routers.Get() + routersList := gateway.GetRouter() routersList.DeleteRouters(pd) gw := gateway.Get() gw.DeleteLoadBalance(pd) diff --git a/internal/biz/endpoint/watcher.go b/internal/biz/endpoint/watcher.go index 8149c29..b8365e5 100644 --- a/internal/biz/endpoint/watcher.go +++ b/internal/biz/endpoint/watcher.go @@ -3,17 +3,17 @@ package endpoint import ( "context" "fmt" + "log" "sync" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" "go.etcd.io/etcd/api/v3/mvccpb" "encoding/json" - "github.com/begonia-org/begonia/gateway" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -38,7 +38,7 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) return nil } endpoint := &api.Endpoints{} - routersList := routers.NewHttpURIRouteToSrvMethod() + routersList := gateway.GetRouter() err := json.Unmarshal([]byte(value), endpoint) if err != nil { return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") @@ -62,15 +62,24 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) } // register routers // log.Print("register router") + err = pd.SetHttpResponse(common.E_HttpResponse) + if err != nil { + return gosdk.NewError(fmt.Errorf("set http response error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "set_http_response") + + } routersList.LoadAllRouters(pd) + // register service to gateway gw := gateway.Get() err = gw.RegisterService(ctx, pd, lb) + if err != nil { return gosdk.NewError(fmt.Errorf("register service error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "register_service") } + gw.RegisterServiceWithProxy(pd) // err = g.repo.PutTags(ctx, endpoint.Key, endpoint.Tags) + log.Printf("register service success") return nil } func (g *EndpointWatcher) Del(ctx context.Context, key string, value string) error { diff --git a/internal/data/app.go b/internal/data/app.go index 682b082..41c3d11 100644 --- a/internal/data/app.go +++ b/internal/data/app.go @@ -100,7 +100,7 @@ func (a *appRepoImpl) GetSecret(ctx context.Context, accessKey string) (string, cacheKey := a.cfg.GetAPPAccessKey(accessKey) secretBytes, err := a.local.Get(ctx, cacheKey) secret := string(secretBytes) - if err != nil { + if err != nil || secret == "" { apps, err := a.Get(ctx, accessKey) if err != nil || apps.Secret == "" { return "", fmt.Errorf("get app secret failed: %w", err) diff --git a/internal/data/app_test.go b/internal/data/app_test.go index 5a79d8b..9ba296a 100644 --- a/internal/data/app_test.go +++ b/internal/data/app_test.go @@ -316,10 +316,16 @@ func delTest(t *testing.T) { env = begonia.Env } repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) - + // set boolean err + patch4 := gomonkey.ApplyFuncReturn((*curdImpl).SetBoolean, fmt.Errorf("set boolean error")) + defer patch4.Reset() + err := repo.Del(context.TODO(), appid) + patch4.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "set boolean error") patch := gomonkey.ApplyFuncReturn(getPrimaryColumnValue, nil, fmt.Errorf("getPrimaryColumnValue,error")) defer patch.Reset() - err := repo.Del(context.TODO(), appid) + err = repo.Del(context.TODO(), appid) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "getPrimaryColumnValue,error") patch.Reset() diff --git a/internal/data/curd_test.go b/internal/data/curd_test.go index c0f8ebe..4ec50c9 100644 --- a/internal/data/curd_test.go +++ b/internal/data/curd_test.go @@ -1,15 +1,18 @@ package data import ( + "context" "testing" + "github.com/agiledragon/gomonkey/v2" api "github.com/begonia-org/go-sdk/api/app/v1" c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" ) func TestAssertDeletedModel(t *testing.T) { c.Convey("test assert deleted model", t, func() { - curd := &curdImpl{} + curd := &curdImpl{db: &tiga.MySQLDao{}} v, ok := curd.assertDeletedModel(&struct{}{}) c.So(ok, c.ShouldBeFalse) c.So(v, c.ShouldBeNil) @@ -24,6 +27,9 @@ func TestAssertDeletedModel(t *testing.T) { c.So(err, c.ShouldNotBeNil) err = curd.SetDatetimeAt(&api.Apps{}, "deleted_at_test") c.So(err, c.ShouldNotBeNil) + patch := gomonkey.ApplyFuncReturn(tiga.MySQLDao.Begin, nil) + defer patch.Reset() + curd.BeginTx(context.Background()) }) } func TestGetPrimaryColumnValueErr(t *testing.T) { @@ -31,15 +37,5 @@ func TestGetPrimaryColumnValueErr(t *testing.T) { _, err := getPrimaryColumnValue(make(map[string]interface{}), "primary") c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "not a struct type") - - // _, err = getPrimaryColumnValue(&struct { - // Primary string - // Name string - // }{ - // Primary: "primary", - // Name: "name", - // }, "primary") - // c.So(err, c.ShouldNotBeNil) - // c.So(err.Error(), c.ShouldContainSubstring, "not found primary column") }) } diff --git a/internal/middleware/auth/ak_test.go b/internal/middleware/auth/ak_test.go index 9517232..62168ae 100644 --- a/internal/middleware/auth/ak_test.go +++ b/internal/middleware/auth/ak_test.go @@ -15,7 +15,6 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/middleware/auth" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" c "github.com/smartystreets/goconvey/convey" @@ -38,7 +37,7 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { ak.SetPriority(1) c.So(ak.Name(), c.ShouldEqual, "ak_auth") c.So(ak.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -58,11 +57,18 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { return fmt.Errorf("metadata not exists in context") }) c.So(err.Error(), c.ShouldContainSubstring, fmt.Errorf("metadata not exists in context").Error()) - - patch := gomonkey.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamRequestBefore, nil, nil) + patch := gomonkey.ApplyMethodReturn(akBiz, "AppValidator", "test", nil) + patch = patch.ApplyMethodReturn(akBiz, "GetAppOwner", "test", nil) + // patch := gomonkey.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamRequestBefore, nil, nil) patch = patch.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamResponseAfter, fmt.Errorf("StreamResponseAfter err")) defer patch.Reset() - err = ak.StreamInterceptor(context.Background(), &testStream{ctx: context.Background()}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, func(srv any, stream grpc.ServerStream) error { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + err = ak.StreamInterceptor(ctx, &testStream{ctx: ctx}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, func(srv any, ss grpc.ServerStream) error { + md, _ := metadata.FromIncomingContext(ss.Context()) + if len(md.Get(gosdk.HeaderXIdentity)) == 0 || md.Get(gosdk.HeaderXIdentity)[0] == "" { + t.Error("identity not exists in context") + return fmt.Errorf("identity not exists in context") + } return nil }) @@ -76,6 +82,29 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { }) patch2.Reset() c.So(err.Error(), c.ShouldContainSubstring, fmt.Errorf("StreamRequestBefore err").Error()) + // do not need validate + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + _, err = ak.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + + // NO CONTEXT + _, err = ak.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + // get owner error + patch3 := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, "", fmt.Errorf("get owner error")) + defer patch3.Reset() + // get owner error + in := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + _, err = ak.StreamRequestBefore(in, &testStream{ctx: in}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, nil) + patch3.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get app owner error") + }) } func TestRequestBeforeErr(t *testing.T) { @@ -133,7 +162,7 @@ func TestValidateStream(t *testing.T) { ak := auth.NewAccessKeyAuth(akBiz, cnf, gateway.Log) patch := gomonkey.ApplyFuncReturn(gosdk.NewGatewayRequestFromGrpc, nil, fmt.Errorf("NewGatewayRequestFromGrpc err")) defer patch.Reset() - _, err := ak.ValidateStream(context.TODO(), nil, "", nil) + _, err := ak.ValidateStream(context.TODO(), nil, "") patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "NewGatewayRequestFromGrpc err") diff --git a/internal/middleware/auth/aksk.go b/internal/middleware/auth/aksk.go index 01c572c..905b57b 100644 --- a/internal/middleware/auth/aksk.go +++ b/internal/middleware/auth/aksk.go @@ -4,9 +4,9 @@ import ( "context" "strings" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" "github.com/begonia-org/go-sdk/logger" "google.golang.org/grpc" @@ -34,7 +34,7 @@ func NewAccessKeyAuth(app *biz.AccessKeyAuth, config *config.Config, log logger. } func IfNeedValidate(ctx context.Context, fullMethod string) bool { - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(strings.ToUpper(fullMethod)) if router == nil { return false @@ -63,8 +63,8 @@ func (a *AccessKeyAuthMiddleware) RequestBefore(ctx context.Context, info *grpc. if !ok { md = metadata.MD{} } - // md.Set(gosdk.HeaderXIdentity, owner) - md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, owner)) + md.Set(gosdk.HeaderXIdentity, owner) + // md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, owner)) ctx = metadata.NewIncomingContext(ctx, md) // md2, _ := metadata.FromIncomingContext(ctx) @@ -73,21 +73,29 @@ func (a *AccessKeyAuthMiddleware) RequestBefore(ctx context.Context, info *grpc. } -func (a *AccessKeyAuthMiddleware) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *AccessKeyAuthMiddleware) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { ctx, err := a.RequestBefore(ctx, &grpc.UnaryServerInfo{FullMethod: fullName}, req) if err != nil { return ctx, err } - md, _ := metadata.FromIncomingContext(ctx) - if identity := md.Get(gosdk.HeaderXIdentity); len(identity) > 0 { - headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity[0]) - } + return ctx, nil } func (a *AccessKeyAuthMiddleware) StreamRequestBefore(ctx context.Context, ss grpc.ServerStream, info *grpc.StreamServerInfo, req interface{}) (grpc.ServerStream, error) { - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) - // defer grpcStream.Release() + if in, ok := metadata.FromIncomingContext(ctx); ok { + if ak := in.Get(gosdk.HeaderXAccessKey); len(ak) > 0 { + identity, err := a.app.GetAppOwner(ctx, ak[0]) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "get app owner error,%v", err) + } + md, _ := metadata.FromIncomingContext(ctx) + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.NewOutgoingContext(ctx, md) + } + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) return grpcStream, nil } @@ -103,6 +111,7 @@ func (a *AccessKeyAuthMiddleware) UnaryInterceptor(ctx context.Context, req any, defer func() { _ = a.ResponseAfter(ctx, info, req, resp) }() + resp, err = handler(ctx, req) return resp, err @@ -111,10 +120,12 @@ func (a *AccessKeyAuthMiddleware) StreamInterceptor(srv interface{}, ss grpc.Ser if !IfNeedValidate(ss.Context(), info.FullMethod) { return handler(srv, ss) } + grpcStream, err := a.StreamRequestBefore(ss.Context(), ss, info, srv) if err != nil { return err } + // log.Printf("AccessKeyAuthMiddleware StreamInterceptor") defer func() { err := a.StreamResponseAfter(ss.Context(), ss, info) if err != nil { @@ -145,3 +156,34 @@ func (a *AccessKeyAuthMiddleware) Priority() int { func (a *AccessKeyAuthMiddleware) Name() string { return a.name } +func (a *AccessKeyAuthMiddleware) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") + } + if AK := md.Get(gosdk.HeaderXAccessKey); len(AK) > 0 { + accessKey := AK[0] + identity, err := a.app.GetAppOwner(ctx, accessKey) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "get app owner error,%v", err) + } + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewOutgoingContext(ctx, md) + in, ok := metadata.FromIncomingContext(ctx) + if !ok { + in = metadata.MD{} + } + in.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, in) + // ctx = metadata.NewIncomingContext() + } + st, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, status.Errorf(codes.Internal, "streamer error,%v", err) + } + return st, nil +} diff --git a/internal/middleware/auth/apikey.go b/internal/middleware/auth/apikey.go index 99c6db6..0039a6e 100644 --- a/internal/middleware/auth/apikey.go +++ b/internal/middleware/auth/apikey.go @@ -3,7 +3,6 @@ package auth import ( "context" "fmt" - "strings" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/pkg" @@ -66,22 +65,31 @@ func NewApiKeyAuth(config *config.Config, authz *biz.AuthzUsecase) ApiKeyAuth { } func (a *ApiKeyAuthImpl) check(ctx context.Context) (string, error) { md, ok := metadata.FromIncomingContext(ctx) - if !ok { + out, outOK := metadata.FromOutgoingContext(ctx) + if !ok && !outOK { return "", gosdk.NewError(status.Errorf(codes.Unauthenticated, "metadata not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } // authorization := a.GetAuthorizationFromMetadata(md) apikeys := md.Get(gosdk.HeaderXApiKey) - if len(apikeys) == 0 { + outAPIKeys := out.Get(gosdk.HeaderXApiKey) + if len(apikeys) == 0 && len(outAPIKeys) == 0 { return "", gosdk.NewError(status.Errorf(codes.Unauthenticated, "apikey not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } - apikey := apikeys[0] + + apikey := "" + if len(apikeys) != 0 { + apikey = apikeys[0] + } else if len(outAPIKeys) != 0 { + apikey = outAPIKeys[0] + } + if apikey != a.config.GetAdminAPIKey() { return "", gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } return apikey, nil } -func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { apikey := "" var err error if apikey, err = a.check(ctx); err == nil && apikey != "" { @@ -92,7 +100,7 @@ func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fu md, _ := metadata.FromIncomingContext(ctx) md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) - headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity) + // headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity) return metadata.NewIncomingContext(ctx, md), err } return ctx, err @@ -102,7 +110,46 @@ func (a *ApiKeyAuthImpl) StreamInterceptor(srv interface{}, ss grpc.ServerStream if !IfNeedValidate(ss.Context(), info.FullMethod) { return handler(srv, ss) } - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) + ctx := ss.Context() + if apikey, err := a.check(ctx); err == nil && apikey != "" { + identity, err := a.authz.GetIdentity(ctx, gosdk.ApiKeyType, apikey) + if err != nil { + return gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + } + md, _ := metadata.FromIncomingContext(ctx) + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.AppendToOutgoingContext(ctx, gosdk.HeaderXIdentity, identity) + } else { + return gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) defer grpcStream.Release() - return handler(srv, grpcStream) + err := handler(srv, grpcStream) + + return err +} +func (a *ApiKeyAuthImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + if apikey, err := a.check(ctx); err == nil && apikey != "" { + identity, err := a.authz.GetIdentity(ctx, gosdk.ApiKeyType, apikey) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + } + md, _ := metadata.FromOutgoingContext(ctx) + md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) + ctx = metadata.NewOutgoingContext(ctx, md) + in, ok := metadata.FromIncomingContext(ctx) + if !ok { + in = metadata.New(make(map[string]string)) + } + in.Set(gosdk.HeaderXIdentity, identity) + // log.Printf("incoming identity:%s", identity) + ctx = metadata.NewIncomingContext(ctx, in) + return streamer(ctx, desc, cc, method, opts...) + } + return nil, gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } diff --git a/internal/middleware/auth/apikey_test.go b/internal/middleware/auth/apikey_test.go index 489d397..8c23a44 100644 --- a/internal/middleware/auth/apikey_test.go +++ b/internal/middleware/auth/apikey_test.go @@ -17,7 +17,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" c "github.com/smartystreets/goconvey/convey" @@ -39,7 +38,7 @@ func TestAPIKeyUnaryInterceptor(t *testing.T) { c.So(apikey.Name(), c.ShouldEqual, "api_key_auth") c.So(apikey.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -101,6 +100,18 @@ func TestAPIKeyUnaryInterceptor(t *testing.T) { patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "get user id base on apikey error") + // ValidateStream error + patch2 := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get user id base on apikey error")) + defer patch2.Reset() + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = apikey.(*auth.ApiKeyAuthImpl).ValidateStream(ctx, nil, "/integration.TestService/Get") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get user id base on apikey error") + patch2.Reset() + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) + _, err = apikey.(*auth.ApiKeyAuthImpl).ValidateStream(ctx, nil, "/integration.TestService/Get") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) }) } func newAuthzBiz() *biz.AuthzUsecase { @@ -131,7 +142,7 @@ func TestApiKeyStreamInterceptor(t *testing.T) { c.So(apikey.Name(), c.ShouldEqual, "api_key_auth") c.So(apikey.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -155,12 +166,13 @@ func TestApiKeyStreamInterceptor(t *testing.T) { ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())), }}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { - err := ss.RecvMsg(srv) md, _ := metadata.FromIncomingContext(ss.Context()) if len(md.Get(gosdk.HeaderXIdentity)) == 0 || md.Get(gosdk.HeaderXIdentity)[0] == "" { t.Error("identity not exists in context") return fmt.Errorf("identity not exists in context") } + err := ss.RecvMsg(srv) + return err }) diff --git a/internal/middleware/auth/auth.go b/internal/middleware/auth/auth.go index acc21cd..9a0f8e3 100644 --- a/internal/middleware/auth/auth.go +++ b/internal/middleware/auth/auth.go @@ -31,6 +31,7 @@ func NewAuth(ak *AccessKeyAuthMiddleware, jwt *JWTAuth, apikey ApiKeyAuth) gosdk } func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // fmt.Print("auth unary interceptor \n") if !IfNeedValidate(ctx, info.FullMethod) { return handler(ctx, req) } @@ -57,13 +58,14 @@ func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnarySe func (a *Auth) StreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if !IfNeedValidate(ss.Context(), info.FullMethod) { + // log.Printf("no need stream validate %s", info.FullMethod) return handler(srv, ss) } md, ok := metadata.FromIncomingContext(ss.Context()) if !ok { return status.Errorf(codes.Unauthenticated, "metadata not exists in context") } - xApiKey := md.Get("x-api-key") + xApiKey := md.Get(gosdk.HeaderXApiKey) if len(xApiKey) != 0 { return a.apikey.StreamInterceptor(srv, ss, info, handler) } @@ -89,3 +91,28 @@ func (a *Auth) Priority() int { func (a *Auth) Name() string { return a.name } + +func (a *Auth) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + // log.Printf("no need stream validate %s", info.FullMethod) + return streamer(ctx, desc, cc, method, opts...) + } + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") + } + xApiKey := md.Get("x-api-key") + if len(xApiKey) != 0 { + return a.apikey.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) + } + authorization := a.jwt.GetAuthorizationFromMetadata(md) + + if authorization == "" { + return nil, gosdk.NewError(pkg.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + } + if strings.Contains(authorization, "Bearer") { + return a.jwt.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) + + } + return a.ak.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) +} diff --git a/internal/middleware/auth/auth_test.go b/internal/middleware/auth/auth_test.go index 96ed475..69eb004 100644 --- a/internal/middleware/auth/auth_test.go +++ b/internal/middleware/auth/auth_test.go @@ -7,10 +7,12 @@ import ( "crypto/cipher" "crypto/rand" "crypto/rsa" + "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "fmt" + "io" "log" "net/http" "os" @@ -29,7 +31,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/pkg/utils" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" @@ -44,6 +45,21 @@ import ( type testStream struct { ctx context.Context } +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func (t *testStream) SetHeader(metadata.MD) error { return nil @@ -253,7 +269,7 @@ func TestUnaryInterceptor(t *testing.T) { } return nil, nil } - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -288,8 +304,11 @@ func TestUnaryInterceptor(t *testing.T) { c.Convey("TestUnaryInterceptor aksk", t, func() { u := &v1.Users{} bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) access, secret, appid := readInitAPP() sgin := gosdk.NewAppAuthSigner(access, secret) gwReq, err := gosdk.NewGatewayRequestFromHttp(req) @@ -342,7 +361,7 @@ func TestStreamInterceptor(t *testing.T) { mid := getMid() ctx := context.Background() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -353,6 +372,9 @@ func TestStreamInterceptor(t *testing.T) { ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { err := ss.RecvMsg(srv) + if err != nil { + return err + } md, _ := metadata.FromIncomingContext(ss.Context()) if identify := md.Get(gosdk.HeaderXIdentity); len(identify) == 0 || identify[0] == "" { return fmt.Errorf("no app identity") @@ -360,7 +382,7 @@ func TestStreamInterceptor(t *testing.T) { if xAppKey := md.Get(gosdk.HeaderXApiKey); len(xAppKey) == 0 || xAppKey[0] == "" { return fmt.Errorf("no app key") } - return err + return ss.RecvMsg(srv) }) c.So(err, c.ShouldBeNil) ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) @@ -368,6 +390,27 @@ func TestStreamInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldNotBeNil) + // get identity err + patch := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get identity err")) + defer patch.Reset() + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "query user id base on apikey") + patch.Reset() + + // check apikey not match + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) }) c.Convey("TestStreamInterceptor jwt", t, func() { @@ -390,13 +433,26 @@ func TestStreamInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldNotBeNil) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "")) + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrTokenMissing.Error()) }) - c.Convey("TestUnaryInterceptor aksk", t, func() { + c.Convey("TestStreamInterceptor aksk", t, func() { u := &v1.Users{} + bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) + + t.Logf("data:%s", string(bData)) req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) access, secret, appid := readInitAPP() sgin := gosdk.NewAppAuthSigner(access, secret) gwReq, err := gosdk.NewGatewayRequestFromHttp(req) @@ -426,13 +482,11 @@ func TestStreamInterceptor(t *testing.T) { return fmt.Errorf("no app identity") } if xAccessKey := md.Get(gosdk.HeaderXAccessKey); len(xAccessKey) == 0 || xAccessKey[0] == "" { - t.Logf("error metadata:%v", md) return fmt.Errorf("no app access key") } return err }) c.So(err, c.ShouldBeNil) - sign1 := gosdk.NewAppAuthSigner("ASDASDCASDFQ", "ASDASDCASDFQ") gwReq1, err := gosdk.NewGatewayRequestFromHttp(req) c.So(err, c.ShouldBeNil) @@ -452,13 +506,187 @@ func TestStreamInterceptor(t *testing.T) { }) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAppSignatureInvalid.Error()) + + // recv msg err + ctx = metadata.NewIncomingContext(context.Background(), md) + patch4 := gomonkey.ApplyFuncReturn((*testStream).RecvMsg, fmt.Errorf("recv msg err")) + + defer patch4.Reset() + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + patch4.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "recv msg err") + // recv io.EOF + patch5 := gomonkey.ApplyFuncReturn((*testStream).RecvMsg, io.EOF) + defer patch5.Reset() + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + patch5.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, io.EOF.Error()) + + }) +} +func TestStreamClientInterceptor(t *testing.T) { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + config := config.ReadConfig(env) + cnf := cfg.NewConfig(config) + mid := getMid() + ctx := context.Background() + + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + + pd, _ := gateway.NewDescription(pbFile) + R.LoadAllRouters(pd) + c.Convey("TestStreamClientInterceptor apikey", t, func() { + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + st, err := mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + err = st.SendMsg(&v1.Users{}) + c.So(err, c.ShouldBeNil) + + // no need validate + st, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldBeEmpty) + // no outgoing context + _, err = mid.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + // do not need validate + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + // not match apikey + patch := gomonkey.ApplyFuncReturn((*cfg.Config).GetAdminAPIKey, "") + defer patch.Reset() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) + patch.Reset() + // get owner error + patch2 := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get owner err")) + defer patch2.Reset() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + patch2.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "query user id base on apikey") + + }) + c.Convey("TestStreamClientInterceptor jwt", t, func() { + jwt := getJWT() + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Bearer "+jwt)) + st, err := mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldNotBeEmpty) + }) + c.Convey("TestStreamClientInterceptor aksk", t, func() { + u := &v1.Users{} + + bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) + + t.Logf("data:%s", string(bData)) + req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) + access, secret, appid := readInitAPP() + sgin := gosdk.NewAppAuthSigner(access, secret) + gwReq, err := gosdk.NewGatewayRequestFromHttp(req) + c.So(err, c.ShouldBeNil) + err = sgin.SignRequest(gwReq) + c.So(err, c.ShouldBeNil) + md := metadata.New(make(map[string]string)) + + headers := gwReq.Headers + for _, k := range headers.Keys() { + // t.Logf("header:%s,value:%s", k, headers.Get(k)) + md.Append(k, headers.Get(k)) + } + md.Append("uri", "/test/post") + md.Append("x-http-method", http.MethodPost) + patch := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetSecret, secret, nil) + patch = patch.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, appid, nil) + defer patch.Reset() + outCTX := metadata.NewOutgoingContext(context.Background(), md) + st, err := mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldNotBeEmpty) + // no need validate + st, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok = metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldBeEmpty) + // no context + _, err = mid.StreamClientInterceptor(req.Context(), nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + + // streamer err + _, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, fmt.Errorf("streamer err") + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "streamer err") + + // get owner err + + patch2 := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, "", fmt.Errorf("get owner err")) + defer patch2.Reset() + _, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get owner err") + }) } func TestTestUnaryInterceptorErr(t *testing.T) { mid := getMid() ctx := context.Background() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -485,7 +713,7 @@ func TestTestUnaryInterceptorErr(t *testing.T) { func TestStreamInterceptorErr(t *testing.T) { mid := getMid() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") diff --git a/internal/middleware/auth/headers.go b/internal/middleware/auth/headers.go index b47c0f1..7a44932 100644 --- a/internal/middleware/auth/headers.go +++ b/internal/middleware/auth/headers.go @@ -1,119 +1,134 @@ package auth -import ( - "context" - "net/http" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" -) - type Header interface { Set(key, value string) SendHeader(key, value string) } -type GrpcHeader struct { - in metadata.MD - ctx context.Context - out metadata.MD -} -type httpHeader struct { - w http.ResponseWriter - r *http.Request -} -type GrpcStreamHeader struct { - *GrpcHeader - ss grpc.ServerStream -} -var headerPool = &sync.Pool{ - New: func() interface{} { - return &GrpcHeader{} - }, -} -var httpHeaderPool = &sync.Pool{ - New: func() interface{} { - return &httpHeader{} - }, -} +// type GrpcHeader struct { +// in metadata.MD +// ctx context.Context +// out metadata.MD +// } +// type httpHeader struct { +// w http.ResponseWriter +// r *http.Request +// } +// type GrpcStreamHeader struct { +// *GrpcHeader +// ss grpc.ServerStream +// } -var grpcStreamHeaderPool = &sync.Pool{ - New: func() interface{} { - return &GrpcStreamHeader{} - }, -} +// var headerPool = &sync.Pool{ +// New: func() interface{} { +// return &GrpcHeader{} +// }, +// } +// var httpHeaderPool = &sync.Pool{ +// New: func() interface{} { +// return &httpHeader{} +// }, +// } -func (g *GrpcHeader) Release() { - g.ctx = nil - g.in = nil - g.out = nil - headerPool.Put(g) -} -func (g *GrpcHeader) Set(key, value string) { - g.in.Set(key, value) - md, _ := metadata.FromIncomingContext(g.ctx) - newMd := metadata.Join(md, g.in) - g.ctx = metadata.NewIncomingContext(g.ctx, newMd) +// var grpcStreamHeaderPool = &sync.Pool{ +// New: func() interface{} { +// return &GrpcStreamHeader{} +// }, +// } -} -func (g *GrpcHeader) SendHeader(key, value string) { - g.out.Append(key, value) - _ = grpc.SendHeader(g.ctx, g.out) - g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) -} -func (g *httpHeader) Release() { - g.w = nil - g.r = nil - httpHeaderPool.Put(g) -} -func (g *httpHeader) Set(key, value string) { - g.r.Header.Add(key, value) +// func (g *GrpcHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// headerPool.Put(g) +// } +// func (g *GrpcHeader) Set(key, value string) { +// g.in.Set(key, value) +// md, _ := metadata.FromIncomingContext(g.ctx) +// newMd := metadata.Join(md, g.in) +// g.ctx = metadata.NewIncomingContext(g.ctx, newMd) +// log.Printf("grpc header set key:%v value:%v", key, value) -} -func (g *httpHeader) SendHeader(key, value string) { - g.w.Header().Add(key, value) -} -func (g *GrpcStreamHeader) Release() { - g.ctx = nil - g.in = nil - g.out = nil - g.ss = nil - grpcStreamHeaderPool.Put(g) +// } +// func (g *GrpcHeader) SendHeader(key, value string) { +// g.out.Append(key, value) +// _ = grpc.SendHeader(g.ctx, g.out) +// g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) +// } +// func (g *httpHeader) Release() { +// g.w = nil +// g.r = nil +// httpHeaderPool.Put(g) +// } +// func (g *httpHeader) Set(key, value string) { +// g.r.Header.Add(key, value) -} -func (g *GrpcStreamHeader) Set(key, value string) { - g.in.Append(key, value) - newCtx := metadata.NewIncomingContext(g.ctx, g.in) - g.ctx = newCtx - _ = g.ss.SetHeader(g.in) -} -func (g *GrpcStreamHeader) SendHeader(key, value string) { - g.out.Append(key, value) - _ = g.ss.SendHeader(g.out) - g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) -} +// } +// func (g *httpHeader) SendHeader(key, value string) { +// g.w.Header().Add(key, value) +// } -func NewGrpcHeader(in metadata.MD, ctx context.Context, out metadata.MD) *GrpcHeader { - // return &GrpcHeader{in: in, ctx: ctx, out: out} - header := headerPool.Get().(*GrpcHeader) - header.in = in - header.ctx = ctx - header.out = out - return header -} -func NewHttpHeader(w http.ResponseWriter, r *http.Request) *httpHeader { - // return &httpHeader{w: w, r: r} - header := httpHeaderPool.Get().(*httpHeader) - header.w = w - header.r = r - return header -} +// func (g *GrpcStreamHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// g.ss = nil +// grpcStreamHeaderPool.Put(g) -func NewGrpcStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, ss grpc.ServerStream) *GrpcStreamHeader { - // return &GrpcStreamHeader{&GrpcHeader{in: in, ctx: ctx, out: out}, ss} - header := grpcStreamHeaderPool.Get().(*GrpcStreamHeader) - header.GrpcHeader = NewGrpcHeader(in, ctx, out) - header.ss = ss - return header -} +// } +// func (g *GrpcStreamHeader) Set(key, value string) { +// g.in.Set(key, value) +// newCtx := metadata.NewIncomingContext(g.ctx, g.in) +// g.ctx = newCtx +// _ = g.ss.SetHeader(g.in) +// } +// func (g *GrpcStreamHeader) SendHeader(key, value string) { +// g.out.Append(key, value) +// _ = g.ss.SendHeader(g.out) +// g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) +// } + +// func (g *GrpcClientStreamHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// g.cs = nil +// grpcClientHeaderPool.Put(g) + +// } +// func (g *GrpcClientStreamHeader) Set(key, value string) { +// g.out.Set(key, value) +// newCtx := metadata.NewOutgoingContext(g.ctx, g.in) +// g.ctx = newCtx +// // _ = g.cs.(g.in) + +// } +// func NewGrpcHeader(in metadata.MD, ctx context.Context, out metadata.MD) *GrpcHeader { +// // return &GrpcHeader{in: in, ctx: ctx, out: out} +// header := headerPool.Get().(*GrpcHeader) +// header.in = in +// header.ctx = ctx +// header.out = out +// return header +// } +// func NewHttpHeader(w http.ResponseWriter, r *http.Request) *httpHeader { +// // return &httpHeader{w: w, r: r} +// header := httpHeaderPool.Get().(*httpHeader) +// header.w = w +// header.r = r +// return header +// } + +// func NewGrpcStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, ss grpc.ServerStream) *GrpcStreamHeader { +// // return &GrpcStreamHeader{&GrpcHeader{in: in, ctx: ctx, out: out}, ss} +// header := grpcStreamHeaderPool.Get().(*GrpcStreamHeader) +// header.GrpcHeader = NewGrpcHeader(in, ctx, out) +// header.ss = ss +// return header +// } +// func NewGrpcClientStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, cs grpc.ClientStream) *GrpcClientStreamHeader { +// header := grpcClientHeaderPool.Get().(*GrpcClientStreamHeader) +// header.GrpcHeader = NewGrpcHeader(in, ctx, out) +// header.cs = cs +// return header +// } diff --git a/internal/middleware/auth/headers_test.go b/internal/middleware/auth/headers_test.go index 8344ba2..27cdf96 100644 --- a/internal/middleware/auth/headers_test.go +++ b/internal/middleware/auth/headers_test.go @@ -1,32 +1,32 @@ package auth_test -import ( - "net/http" - "testing" +// import ( +// "net/http" +// "testing" - "github.com/begonia-org/begonia/internal/middleware/auth" - c "github.com/smartystreets/goconvey/convey" -) +// "github.com/begonia-org/begonia/internal/middleware/auth" +// c "github.com/smartystreets/goconvey/convey" +// ) -type responseWriter struct { -} +// type responseWriter struct { +// } -func (r *responseWriter) Header() http.Header { - return make(http.Header) -} -func (r *responseWriter) Write([]byte) (int, error) { - return 0, nil -} -func (r *responseWriter) WriteHeader(int) { +// func (r *responseWriter) Header() http.Header { +// return make(http.Header) +// } +// func (r *responseWriter) Write([]byte) (int, error) { +// return 0, nil +// } +// func (r *responseWriter) WriteHeader(int) { -} -func TestHeaders(t *testing.T) { - c.Convey("TestHeaders", t, func() { - req, _ := http.NewRequest("GET", "http://localhost", nil) - h := auth.NewHttpHeader(&responseWriter{}, req) - c.So(h, c.ShouldNotBeNil) - h.Set("key", "value") - h.SendHeader("key", "value") - h.Release() - }) -} +// } +// func TestHeaders(t *testing.T) { +// c.Convey("TestHeaders", t, func() { +// req, _ := http.NewRequest("GET", "http://localhost", nil) +// h := auth.NewHttpHeader(&responseWriter{}, req) +// c.So(h, c.ShouldNotBeNil) +// h.Set("key", "value") +// h.SendHeader("key", "value") +// h.Release() +// }) +// } diff --git a/internal/middleware/auth/jwt.go b/internal/middleware/auth/jwt.go index 381a978..c5c35b2 100644 --- a/internal/middleware/auth/jwt.go +++ b/internal/middleware/auth/jwt.go @@ -104,18 +104,21 @@ func (a *JWTAuth) checkJWTItem(ctx context.Context, payload *api.BasicAuth, toke } return true, nil } -func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, rspHeader Header, reqHeader Header) (ok bool, err error) { +func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, io *metadata.MD) (ok bool, err error) { payload, errAuth := a.jwt2BasicAuth(authorization) err = errAuth if err != nil { return false, err } + io.Set("x-uid", payload.Uid) + strArr := strings.Split(authorization, " ") token := strArr[1] ok, err = a.checkJWTItem(ctx, payload, token) if err != nil || !ok { return false, err } + io.Set("x-token", token) left := payload.Expiration - time.Now().Unix() // expiration := a.config.GetJWTExpiration() @@ -156,63 +159,73 @@ func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, rspHeader } // 旧token加入黑名单 go a.biz.PutBlackList(ctx, a.config.GetUserBlackListKey(tiga.GetMd5(token))) - rspHeader.SendHeader("Authorization", fmt.Sprintf("Bearer %s", newToken)) + // rspHeader.Set("Authorization", fmt.Sprintf("Bearer %s", newToken)) + _ = grpc.SetHeader(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", newToken))) token = newToken } - // 设置uid - reqHeader.Set("x-token", token) - reqHeader.Set("x-uid", payload.Uid) - reqHeader.Set(gosdk.HeaderXIdentity, payload.Uid) + io.Set("x-token", token) return true, nil } -func (a *JWTAuth) jwtValidator(ctx context.Context, headers Header) (context.Context, error) { - - md, _ := metadata.FromIncomingContext(ctx) +func (a *JWTAuth) jwtValidator(ctx context.Context) (context.Context, error) { + in, inOK := metadata.FromIncomingContext(ctx) + if !inOK { + in = metadata.MD{} + } + out, ok := metadata.FromOutgoingContext(ctx) + if !ok { + out = metadata.MD{} + } + md := metadata.Join(in, out) token := a.GetAuthorizationFromMetadata(md) if token == "" { - return nil, status.Errorf(codes.Unauthenticated, "token not exists in context") - } + if token = a.GetAuthorizationFromMetadata(out); token == "" { + return nil, status.Errorf(codes.Unauthenticated, "token not exists in context") + } - ok, err := a.checkJWT(ctx, token, headers, headers) + } + ioMD := metadata.New(make(map[string]string)) + ok, err := a.checkJWT(ctx, token, &ioMD) if err != nil || !ok { return nil, status.Errorf(codes.Unauthenticated, "check token error,%v", err) } + in.Set("x-token", ioMD.Get("x-token")...) + in.Set("x-uid", ioMD.Get("x-uid")...) + in.Set(gosdk.HeaderXIdentity, ioMD.Get("x-uid")...) + newCtx := metadata.NewIncomingContext(ctx, in) + + out.Set("x-token", ioMD.Get("x-token")...) + out.Set("x-uid", ioMD.Get("x-uid")...) + out.Set(gosdk.HeaderXIdentity, ioMD.Get("x-uid")...) - newCtx := metadata.NewIncomingContext(ctx, md) + newCtx = metadata.NewOutgoingContext(newCtx, out) return newCtx, nil // return handler(newCtx, req) } func (a *JWTAuth) RequestBefore(ctx context.Context, info *grpc.UnaryServerInfo, req interface{}) (context.Context, error) { - in, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") - } - out, ok := metadata.FromOutgoingContext(ctx) - if !ok { - out = metadata.MD{} - } - headers := NewGrpcHeader(in, ctx, out) - defer headers.Release() - _, err := a.jwtValidator(ctx, headers) + ctx, err := a.jwtValidator(ctx) if err != nil { return nil, err } - return headers.ctx, nil + return ctx, nil } -func (a *JWTAuth) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *JWTAuth) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { // headers := NewGrpcStreamHeader(in, ctx, out,ss) - ctx, err := a.jwtValidator(ctx, headers) + ctx, err := a.jwtValidator(ctx) return ctx, err } func (a *JWTAuth) StreamRequestBefore(ctx context.Context, ss grpc.ServerStream, info *grpc.StreamServerInfo, req interface{}) (grpc.ServerStream, error) { - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) + ctx, err := a.jwtValidator(ctx) + if err != nil { + return nil, err + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) return grpcStream, nil } @@ -272,3 +285,15 @@ func (jwt *JWTAuth) Priority() int { func (jwt *JWTAuth) Name() string { return jwt.name } + +func (jwt *JWTAuth) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + + ctx, err := jwt.jwtValidator(ctx) + if err != nil { + return nil, err + } + return streamer(ctx, desc, cc, method, opts...) +} diff --git a/internal/middleware/auth/jwt_test.go b/internal/middleware/auth/jwt_test.go index d1e388e..a6d5bef 100644 --- a/internal/middleware/auth/jwt_test.go +++ b/internal/middleware/auth/jwt_test.go @@ -19,7 +19,7 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" + gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" api "github.com/begonia-org/go-sdk/api/user/v1" "github.com/bsm/redislock" @@ -62,7 +62,7 @@ func TestJWTUnaryInterceptor(t *testing.T) { config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -85,15 +85,6 @@ func TestJWTUnaryInterceptor(t *testing.T) { }) c.So(err, c.ShouldBeNil) - _, err = jwt.UnaryInterceptor(context.Background(), &hello.HelloRequest{}, &grpc.UnaryServerInfo{ - FullMethod: "/integration.TestService/Get", - }, func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, nil - - }) - c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") - _, err = jwt.UnaryInterceptor(metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test")), &hello.HelloRequest{}, &grpc.UnaryServerInfo{ FullMethod: "/integration.TestService/Get", }, func(ctx context.Context, req interface{}) (interface{}, error) { @@ -289,7 +280,7 @@ func TestJWTStreamInterceptor(t *testing.T) { config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -320,3 +311,40 @@ func TestJWTStreamInterceptor(t *testing.T) { c.So(err, c.ShouldBeNil) }) } + +func TestJWTClientStream(t *testing.T) { + c.Convey("TestJWTClientStream", t, func() { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + config := config.ReadConfig(env) + cnf := cfg.NewConfig(config) + + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + + pd, _ := gateway.NewDescription(pbFile) + R.LoadAllRouters(pd) + user := data.NewUserRepo(config, gateway.Log) + userAuth := crypto.NewUsersAuth(cnf) + authzRepo := data.NewAuthzRepo(config, gateway.Log) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) + jwt := auth.NewJWTAuth(cnf, tiga.NewRedisDao(config), authz, gateway.Log) + jwt.SetPriority(1) + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAuthorization, cnf.GetAdminAPIKey())) + _, err := jwt.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + + outCTX = metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAuthorization, cnf.GetAdminAPIKey())) + _, err = jwt.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + + }) +} diff --git a/internal/middleware/auth/stream.go b/internal/middleware/auth/stream.go index 2fde49c..e4a89d8 100644 --- a/internal/middleware/auth/stream.go +++ b/internal/middleware/auth/stream.go @@ -2,24 +2,34 @@ package auth import ( "context" + "errors" + "io" "sync" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) type StreamValidator interface { - ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) + ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) } type grpcServerStream struct { grpc.ServerStream - fullName string - validate StreamValidator - ctx context.Context + fullName string + validate StreamValidator + ctx context.Context + firstFrame bool } +// type grpcClientStream struct { +// grpc.ClientStream +// fullName string +// ctx context.Context +// validate StreamValidator +// firstFrame bool +// } + var streamPool = &sync.Pool{ New: func() interface{} { return &grpcServerStream{ @@ -28,11 +38,17 @@ var streamPool = &sync.Pool{ }, } +// var clientStreamPool = &sync.Pool{ +// New: func() interface{} { +// return &grpcClientStream{} +// }, +// } + func NewGrpcStream(s grpc.ServerStream, fullName string, ctx context.Context, validator StreamValidator) *grpcServerStream { stream := streamPool.Get().(*grpcServerStream) stream.ServerStream = s stream.fullName = fullName - stream.ctx = s.Context() + stream.ctx = ctx stream.validate = validator return stream } @@ -41,29 +57,28 @@ func (g *grpcServerStream) Release() { g.fullName = "" g.ServerStream = nil g.validate = nil + g.firstFrame = false streamPool.Put(g) } func (g *grpcServerStream) Context() context.Context { return g.ctx } func (s *grpcServerStream) RecvMsg(m interface{}) error { - if err := s.ServerStream.RecvMsg(m); err != nil { - return err + var err error + if err = s.ServerStream.RecvMsg(m); err != nil && !errors.Is(err, io.EOF) { + return status.Errorf(codes.Internal, "recv msg err:%s", err.Error()) } - in, ok := metadata.FromIncomingContext(s.Context()) - if !ok { - return status.Errorf(codes.Unauthenticated, "metadata not exists in context") + if err != nil { + return err } - out, ok := metadata.FromOutgoingContext(s.Context()) - if !ok { - out = metadata.MD{} + if !s.firstFrame { + ctx, err := s.validate.ValidateStream(s.Context(), m, s.fullName) + s.ctx = ctx + s.firstFrame = true + return err } - header := NewGrpcStreamHeader(in, s.Context(), out, s.ServerStream) - _, err := s.validate.ValidateStream(s.Context(), m, s.fullName, header) - s.ctx = header.ctx - header.Release() - return err + return nil } diff --git a/internal/middleware/http.go b/internal/middleware/http.go index 8bf0804..a1b4bf6 100644 --- a/internal/middleware/http.go +++ b/internal/middleware/http.go @@ -3,9 +3,8 @@ package middleware import ( "context" "fmt" - "strings" - "github.com/begonia-org/begonia/internal/pkg/routers" + "github.com/begonia-org/begonia/gateway" gosdk "github.com/begonia-org/go-sdk" _ "github.com/begonia-org/go-sdk/api/app/v1" _ "github.com/begonia-org/go-sdk/api/endpoint/v1" @@ -19,7 +18,6 @@ import ( "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -38,28 +36,24 @@ type HttpStream struct { grpc.ServerStream FullMethod string } + type Http struct { priority int name string } func (s *HttpStream) SendMsg(m interface{}) error { - ctx := s.ServerStream.Context() - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return s.ServerStream.SendMsg(m) - } - if protocol, ok := md["grpcgateway-accept"]; ok { - if !strings.EqualFold(protocol[0], "application/json") { + + routersList := gateway.GetRouter() + router := routersList.GetRouteByGrpcMethod(s.FullMethod) + // 对内置服务的http响应进行格式化 + if routersList.IsLocalSrv(s.FullMethod) || (router != nil && router.UseJsonResponse) { + if _, ok := m.(*httpbody.HttpBody); ok { return s.ServerStream.SendMsg(m) } - routersList := routers.Get() - router := routersList.GetRouteByGrpcMethod(s.FullMethod) - // 对内置服务的http响应进行格式化 - if routersList.IsLocalSrv(s.FullMethod) || router.UseJsonResponse { - rsp, _ := grpcToHttpResponse(m, nil) - return s.ServerStream.SendMsg(rsp) - } + rsp, _ := grpcToHttpResponse(m, s.Context().Err()) + err := s.ServerStream.SendMsg(rsp) + return err } return s.ServerStream.SendMsg(m) } @@ -94,7 +88,6 @@ func toStructMessage(msg protoreflect.ProtoMessage) (*structpb.Struct, error) { return structMsg, nil } func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error) { - if err != nil { if st, ok := status.FromError(err); ok { details := st.Details() @@ -102,7 +95,6 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error if anyType, ok := detail.(*anypb.Any); ok { var errDetail common.Errors var stErr = anyType.UnmarshalTo(&errDetail) - if stErr == nil { rspCode := int32(errDetail.Code) codesMap := getClientMessageMap() @@ -125,7 +117,13 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error if st.Code() == codes.Unimplemented { code = int32(common.Code_NOT_FOUND) } + if st.Code() == codes.InvalidArgument { + code = int32(common.Code_PARAMS_ERROR) + } + if st.Code() == codes.AlreadyExists { + code = int32(common.Code_CONFLICT) + } return &common.HttpResponse{ Code: code, Message: st.Message(), @@ -155,25 +153,19 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error }, err } func (h *Http) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return handler(ctx, req) - } - if protocol, ok := md["grpcgateway-accept"]; ok { - if !strings.EqualFold(protocol[0], "application/json") { - return handler(ctx, req) - } - routersList := routers.Get() - router := routersList.GetRouteByGrpcMethod(info.FullMethod) - // 对内置服务的http响应进行格式化 - if routersList.IsLocalSrv(info.FullMethod) || router.UseJsonResponse { - rsp, err := handler(ctx, req) - if _, ok := rsp.(*httpbody.HttpBody); ok { - return rsp, err - } - return grpcToHttpResponse(rsp, err) + + routersList := gateway.GetRouter() + router := routersList.GetRouteByGrpcMethod(info.FullMethod) + // 对内置服务的http响应进行格式化 + if routersList.IsLocalSrv(info.FullMethod) || router.UseJsonResponse { + // ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("content-type", "application/json")) + rsp, err := handler(ctx, req) + if _, ok := rsp.(*httpbody.HttpBody); ok { + return rsp, err } + return grpcToHttpResponse(rsp, err) } + // } return handler(ctx, req) } @@ -183,7 +175,13 @@ func (h *Http) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *gr return handler(srv, stream) } - +func (h *Http) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ss, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, status.Errorf(codes.Internal, "create client stream error:%s", err.Error()) + } + return ss, nil +} func NewHttp() *Http { return &Http{name: "http"} } diff --git a/internal/middleware/http_test.go b/internal/middleware/http_test.go index 805d4c0..be75896 100644 --- a/internal/middleware/http_test.go +++ b/internal/middleware/http_test.go @@ -11,7 +11,6 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/middleware" "github.com/begonia-org/begonia/internal/pkg" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" user "github.com/begonia-org/go-sdk/api/user/v1" @@ -70,7 +69,7 @@ func (x *greeterSayHelloWebsocketServer) Context() context.Context { func TestStreamInterceptor(t *testing.T) { c.Convey("test stream interceptor", t, func() { mid := middleware.NewHttp() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") @@ -84,10 +83,37 @@ func TestStreamInterceptor(t *testing.T) { c.So(err, c.ShouldBeNil) }) } +func TestHttpStreamClientInterceptor(t *testing.T) { + c.Convey("test http stream client interceptor", t, func() { + mid := middleware.NewHttp() + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") + + pd, err := gateway.NewDescription(pbFile) + c.So(err, c.ShouldBeNil) + R.LoadAllRouters(pd) + stream, err := mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("grpcgateway-accept", "application/json")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + c.So(stream, c.ShouldNotBeNil) + + stream, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("grpcgateway-accept", "application/json")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, fmt.Errorf("new stream err") + }, + ) + c.So(err, c.ShouldNotBeNil) + c.So(stream, c.ShouldBeNil) + + }) + +} func TestUnaryInterceptor(t *testing.T) { c.Convey("test unary interceptor", t, func() { mid := middleware.NewHttp() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") @@ -120,6 +146,16 @@ func TestUnaryInterceptor(t *testing.T) { }) c.So(err, c.ShouldNotBeNil) c.So(req, c.ShouldNotBeNil) + req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.Error(codes.InvalidArgument, "test") + }) + c.So(err, c.ShouldNotBeNil) + c.So(req, c.ShouldNotBeNil) + req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.Error(codes.AlreadyExists, "test") + }) + c.So(err, c.ShouldNotBeNil) + c.So(req, c.ShouldNotBeNil) req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, fmt.Errorf("test") diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 234030f..e0d69f1 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -102,3 +102,11 @@ func (p *PluginsApply) StreamInterceptorChains() []grpc.StreamServerInterceptor } return chains } + +func (p *PluginsApply) StreamClientInterceptorChains() []grpc.StreamClientInterceptor { + chains := make([]grpc.StreamClientInterceptor, 0) + for _, plugin := range p.Plugins { + chains = append(chains, plugin.(gosdk.LocalPlugin).StreamClientInterceptor) + } + return chains +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 0f83484..5d631a9 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -37,6 +37,7 @@ func TestMiddlewareUnaryInterceptorChains(t *testing.T) { // mid.SetPriority(1) c.So(len(mid.StreamInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) c.So(len(mid.UnaryInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) + c.So(len(mid.StreamClientInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) plugins := cnf.GetPlugins() plugins["test"] = 1 diff --git a/internal/middleware/protobuf_validate.go b/internal/middleware/protobuf_validate.go new file mode 100644 index 0000000..ae025d8 --- /dev/null +++ b/internal/middleware/protobuf_validate.go @@ -0,0 +1,187 @@ +package middleware + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/go-playground/validator/v10" + "github.com/iancoleman/strcase" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" +) + +type protobufValidator struct { + validate *validator.Validate +} + +// NewProtobufValidate Create a new protobuf validator +// +// The validate parameter is a validator instance that can be used to validate the structure of the protobuf message +func NewProtobufValidate(validate *validator.Validate) *protobufValidator { + return &protobufValidator{validate: validate} +} + +// getValue Get the value of the field +func (p *protobufValidator) getValue(v protoreflect.Value, k protoreflect.Kind, f protoreflect.FieldDescriptor) interface{} { + switch k { + case protoreflect.BoolKind: + return v.Bool() + case protoreflect.StringKind: + return v.String() + case protoreflect.Int32Kind, protoreflect.Int64Kind: + return v.Int() + case protoreflect.Uint32Kind, protoreflect.Uint64Kind: + return v.Uint() + case protoreflect.FloatKind, protoreflect.DoubleKind: + return v.Float() + case protoreflect.MessageKind: + return v.Message().Interface() + case protoreflect.EnumKind: + return f.Enum().Values().ByNumber(v.Enum()).Name() + case protoreflect.BytesKind: + return v.Bytes() + + } + return nil +} + +// getFieldTag Get the tag of the field +// +// The validateTag parameter is the validate tag of the field +// For Enum fields, the oneof tag is added to the validate tag +// Json tag is added to the field tag +func (p *protobufValidator) getFieldTag(field protoreflect.FieldDescriptor, validateTag interface{}) string { + tag := fmt.Sprintf(`json:"%s"`, field.JSONName()) + if validateTag != nil { + validate := validateTag.(string) + tag = fmt.Sprintf(`json:"%s" validate:"%s"`, field.JSONName(), validate) + } + if field.Enum() != nil && !strings.Contains(tag, "oneof") { + oneOfEnum := make([]string, 0) + for i := 0; i < field.Enum().Values().Len(); i++ { + oneOfEnum = append(oneOfEnum, string(field.Enum().Values().Get(i).Name())) + } + if !field.IsList() && !field.IsMap() { + tag = fmt.Sprintf(`json:"%s" validate:"oneof=%s"`, field.JSONName(), strings.Join(oneOfEnum, " ")) + } + } + return tag +} + +// handleFieldValue Handle the field value, convert the protobuf message field to a struct field by recursion +// +// see: `reflect.StructField.Type` +func (p *protobufValidator) handleFieldValue(field protoreflect.FieldDescriptor, fieldValue protoreflect.Value, ext protoreflect.ExtensionType) interface{} { + if field.IsMap() { + // convert map field to struct map field + mapTyp := make(map[string]interface{}) + mapValue := fieldValue.Map() + mapValue.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool { + if field.MapValue().Kind() == protoreflect.MessageKind { + mapTyp[key.String()] = p.protobufToStructType(value.Message().Interface(), ext).Interface() + } else { + mapTyp[key.String()] = p.getValue(value, field.MapValue().Kind(), field.MapValue()) + } + return true + }) + return mapTyp + } + + if field.Kind() == protoreflect.MessageKind { + if field.IsList() { + + list := make([]interface{}, 0) + for j := 0; j < fieldValue.List().Len(); j++ { + list = append(list, p.protobufToStructType(fieldValue.List().Get(j).Message().Interface(), ext).Interface()) + } + return list + } + return p.protobufToStructType(fieldValue.Message().Interface(), ext).Interface() + } + + if field.IsList() { + + list := make([]interface{}, 0) + for j := 0; j < fieldValue.List().Len(); j++ { + list = append(list, p.getValue(fieldValue.List().Get(j), field.Kind(), field)) + } + return list + } + + return p.getValue(fieldValue, field.Kind(), field) +} + +// isValid Determine whether the field value is valid, +// +// If the field is a list or map, the function will return true if the field is valid +func (p *protobufValidator) isValid(value protoreflect.Value, f protoreflect.FieldDescriptor) bool { + if f.IsList() { + return value.List().IsValid() + } + if f.IsMap() { + return value.Map().IsValid() + } + switch f.Kind() { + case protoreflect.MessageKind: + return value.Message().IsValid() + case protoreflect.StringKind: + return value.String() != "" + default: + return value.IsValid() + } +} + +// protobufToStructType Convert the protobuf message to a struct type +// +// see: `reflect.StructField` +func (p *protobufValidator) protobufToStructType(message proto.Message, ext protoreflect.ExtensionType) reflect.Value { + md := message.ProtoReflect().Descriptor() + fieldsValues := make(map[string]reflect.Value) + structFields := make([]reflect.StructField, 0) + + for i := 0; i < md.Fields().Len(); i++ { + field := md.Fields().Get(i) + fieldName := strcase.ToCamel(string(field.Name())) + fieldValue := message.ProtoReflect().Get(field) + value := p.handleFieldValue(field, fieldValue, ext) + validateTag := proto.GetExtension(field.Options(), ext) + tag := p.getFieldTag(field, validateTag) + + structFields = append(structFields, reflect.StructField{ + Name: fieldName, + Type: reflect.TypeOf(value), + Tag: reflect.StructTag(tag), + }) + fieldsValues[fieldName] = reflect.ValueOf(value) + if !p.isValid(fieldValue, field) { + fieldsValues[fieldName] = reflect.Zero(reflect.TypeOf(value)) + } + } + + structType := reflect.StructOf(structFields) + newTypVal := reflect.New(structType) + for k, v := range fieldsValues { + newTypVal.Elem().FieldByName(k).Set(v) + } + return newTypVal +} + +func (p *protobufValidator) Protobuf(message proto.Message, ext protoreflect.ExtensionType) error { + v := p.protobufToStructType(message, ext).Interface() + return p.validate.Struct(v) +} + +func (p *protobufValidator) NewStructFromProtobuf(message proto.Message, ext protoreflect.ExtensionType) interface{} { + return p.protobufToStructType(message, ext).Interface() +} + +func (p *protobufValidator) ProtobufPartial(message proto.Message, ext protoreflect.ExtensionType, fields ...string) error { + v := p.protobufToStructType(message, ext).Interface() + return p.validate.StructPartial(v, fields...) +} +func (p *protobufValidator) ProtobufPartialCtx(ctx context.Context, message proto.Message, ext protoreflect.ExtensionType, fields ...string) error { + v := p.protobufToStructType(message, ext).Interface() + return p.validate.StructPartialCtx(ctx, v, fields...) +} diff --git a/internal/middleware/rpc.go b/internal/middleware/rpc.go index 010c377..13e7df2 100644 --- a/internal/middleware/rpc.go +++ b/internal/middleware/rpc.go @@ -22,22 +22,11 @@ import ( type RPCPluginCaller interface{} -// type rpcPluginCallerImpl struct { -// plugins gosdk.Plugins -// } - -// func NewRPCPluginCaller() RPCPluginCaller { -// return &rpcPluginCallerImpl{ -// plugins: make(gosdk.Plugins, 0), -// } -// } - type pluginImpl struct { priority int name string timeout time.Duration lb lb.LoadBalance - // api.PluginServiceClient } func (p *pluginImpl) SetPriority(priority int) { @@ -54,17 +43,16 @@ func (p *pluginImpl) Name() string { func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - md = metadata.New(nil) + md = metadata.New(make(map[string]string)) } - rsp, err := p.Apply(ctx, req, info.FullMethod) + rsp, header, err := p.Apply(ctx, req, info.FullMethod) if err != nil { return nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } + for k, v := range header { + md[k] = append(md[k], v...) - for k, v := range rsp.Metadata { - md.Append(k, v) } - newRequest := rsp.NewRequest if newRequest != nil { err = newRequest.UnmarshalTo(req.(proto.Message)) @@ -97,15 +85,15 @@ func (p *pluginImpl) getEndpoint(ctx context.Context) (lb.Endpoint, error) { return endpoint, nil } -func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName string) (*api.PluginResponse, error) { +func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName string) (*api.PluginResponse, metadata.MD, error) { endpoint, err := p.getEndpoint(ctx) if err != nil { - return nil, err + return nil, nil, err } cn, err := endpoint.Get(ctx) if err != nil { - return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + return nil, nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") } defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) @@ -113,14 +101,38 @@ func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName s plugin := api.NewPluginServiceClient(conn) anyReq, err := anypb.New(in.(proto.Message)) if err != nil { - return nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") + return nil, nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") } - return plugin.Apply(ctx, &api.PluginRequest{ + var header, trailer metadata.MD + rsp, err := plugin.Apply(ctx, &api.PluginRequest{ Request: anyReq, FullMethodName: fullMethodName, - }) - // return plugin.Call(ctx, anyReq, opts...) + }, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return nil, nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") + } + return rsp, header, nil +} +func (p *pluginImpl) Metadata(ctx context.Context, in *emptypb.Empty) (metadata.MD, error) { + endpoint, err := p.getEndpoint(ctx) + if err != nil { + return nil, err + } + cn, err := endpoint.Get(ctx) + if err != nil { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + } + defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) + conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) + plugin := api.NewPluginServiceClient(conn) + var header, trailer metadata.MD + _, err = plugin.Metadata(ctx, in, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("call plugin metadata error:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "metadata") + } + return header, nil + } func (p *pluginImpl) Info(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*api.PluginInfo, error) { endpoint, err := p.getEndpoint(ctx) @@ -137,14 +149,26 @@ func (p *pluginImpl) Info(ctx context.Context, in *emptypb.Empty, opts ...grpc.C return plugin.Info(ctx, in, opts...) } func (p *pluginImpl) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + md, err := p.Metadata(ss.Context(), &emptypb.Empty{}) + if err != nil { + return err + + } + in, _ := metadata.FromIncomingContext(ss.Context()) - grpcStream := NewGrpcPluginStream(ss, info.FullMethod, ss.Context(), p) + ctx := metadata.NewIncomingContext(ss.Context(), metadata.Join(in, md)) + grpcStream := NewGrpcPluginStream(ss, info.FullMethod, ctx, p) if grpcStream != nil { defer grpcStream.Release() } return handler(srv, grpcStream) +} +func (p *pluginImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + + return streamer(ctx, desc, cc, method, opts...) + } func NewPluginImpl(lb lb.LoadBalance, name string, timeout time.Duration) *pluginImpl { return &pluginImpl{ diff --git a/internal/middleware/rpc_test.go b/internal/middleware/rpc_test.go index dd09e60..ed1505d 100644 --- a/internal/middleware/rpc_test.go +++ b/internal/middleware/rpc_test.go @@ -11,6 +11,7 @@ import ( "github.com/begonia-org/begonia/internal/middleware" goloadbalancer "github.com/begonia-org/go-loadbalancer" hello "github.com/begonia-org/go-sdk/api/example/v1" + api "github.com/begonia-org/go-sdk/api/plugin/v1" "github.com/begonia-org/go-sdk/example" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" @@ -20,6 +21,21 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func TestPluginUnaryInterceptor(t *testing.T) { c.Convey("test plugin unary interceptor", t, func() { go example.RunPlugins(":9850") @@ -57,19 +73,23 @@ func TestPluginUnaryInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldBeNil) - patch2 := gomonkey.ApplyFuncSeq(metadata.FromIncomingContext, []gomonkey.OutputCell{{ - Values: gomonkey.Params{metadata.New(map[string]string{"test": "test"}), true}, - Times: 2, - }, - { - Values: gomonkey.Params{nil, false}, - }, - }) - defer patch2.Reset() + // patch2 := gomonkey.ApplyFuncSeq(metadata.FromIncomingContext, []gomonkey.OutputCell{{ + // Values: gomonkey.Params{metadata.New(map[string]string{"test": "test"}), true}, + // Times: 4, + // }, + // { + // Values: gomonkey.Params{nil, false}, + // }, + // }) + // defer patch2.Reset() err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { - return ss.RecvMsg(srv) + err := ss.RecvMsg(srv) + if err != nil { + t.Logf("recv msg error: %v", err) + } + return err }) - patch2.Reset() + // patch2.Reset() c.So(err, c.ShouldBeNil) }) @@ -142,20 +162,61 @@ func TestPluginUnaryInterceptorErr(t *testing.T) { patch4.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "unmarshal to request error") + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: context.Background()}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get metadata from context error") + // select err + patch5 := gomonkey.ApplyMethodReturn(lb, "Select", nil, fmt.Errorf("select endpoint error")) + defer patch5.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "select endpoint error") + patch5.Reset() + enp, _ := lb.Select("127.0.0.1") + patch6 := gomonkey.ApplyMethodReturn(enp, "Get", nil, fmt.Errorf("get connection error")) + defer patch6.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + patch6.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get connection error") + // call plugin metadata err + cli := api.NewPluginServiceClient(nil) + patch7 := gomonkey.ApplyMethodReturn(cli, "Metadata", nil, fmt.Errorf("call test plugin error")) + defer patch7.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + patch7.Reset() + c.So(err.Error(), c.ShouldContainSubstring, "call test plugin error") + // call apply err + patch8 := gomonkey.ApplyMethodReturn(cli, "Apply", nil, fmt.Errorf("call plugin error")) + defer patch8.Reset() + _, err = mid.UnaryInterceptor(metadata.NewIncomingContext(context.Background(), metadata.Pairs("X-Forwarded-For", "127.0.0.1:9090")), &hello.HelloRequest{}, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + patch8.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "call plugin error") }) } func TestPluginStreamInterceptorErr(t *testing.T) { c.Convey("test plugin unary interceptor", t, func() { - go example.RunPlugins(":9850") + go example.RunPlugins(":9851") time.Sleep(2 * time.Second) lb := goloadbalancer.NewGrpcLoadBalance(&goloadbalancer.Server{ Name: "test", Endpoints: []goloadbalancer.EndpointServer{ { - Addr: "127.0.0.1:9850", + Addr: "127.0.0.1:9851", }, }, Pool: &goloadbalancer.PoolConfig{ @@ -176,7 +237,7 @@ func TestPluginStreamInterceptorErr(t *testing.T) { patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "recv msg error") - patch1 := gomonkey.ApplyMethodReturn(mid, "Apply", nil, fmt.Errorf("call test plugin error")) + patch1 := gomonkey.ApplyMethodReturn(mid, "Apply", nil, nil, fmt.Errorf("call test plugin error")) defer patch1.Reset() err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { return ss.RecvMsg(srv) @@ -197,3 +258,34 @@ func TestPluginStreamInterceptorErr(t *testing.T) { }) } + +func TestRPCStreamClientInterceptor(t *testing.T) { + c.Convey("test rpc stream client interceptor", t, func() { + go example.RunPlugins(":9852") + time.Sleep(2 * time.Second) + lb := goloadbalancer.NewGrpcLoadBalance(&goloadbalancer.Server{ + Name: "test", + Endpoints: []goloadbalancer.EndpointServer{ + { + Addr: "127.0.0.1:9852", + }, + }, + Pool: &goloadbalancer.PoolConfig{ + MaxOpenConns: 10, + MaxIdleConns: 5, + MaxActiveConns: 5, + }, + }) + mid := middleware.NewPluginImpl(lb, "test", 3*time.Second) + c.So(mid.Name(), c.ShouldEqual, "test") + mid.SetPriority(3) + st, err := mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("key", "value")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + }, + ) +} diff --git a/internal/middleware/stream.go b/internal/middleware/stream.go index ec3ecda..fab289a 100644 --- a/internal/middleware/stream.go +++ b/internal/middleware/stream.go @@ -3,6 +3,7 @@ package middleware import ( "context" "fmt" + "log" "sync" gosdk "github.com/begonia-org/go-sdk" @@ -20,6 +21,13 @@ type grpcPluginStream struct { ctx context.Context } +// type grpcPluginClientStream struct { +// grpc.ClientStream +// fullName string +// plugin *pluginImpl +// ctx context.Context +// } + var streamPool = &sync.Pool{ New: func() interface{} { return &grpcPluginStream{ @@ -28,11 +36,19 @@ var streamPool = &sync.Pool{ }, } +// func NewGrpcPluginClientStream(s grpc.ClientStream, fullName string, ctx context.Context, plugin *pluginImpl) *grpcPluginClientStream { +// return &grpcPluginClientStream{ +// ClientStream: s, +// fullName: fullName, +// ctx: ctx, +// plugin: plugin, +// } +// } func NewGrpcPluginStream(s grpc.ServerStream, fullName string, ctx context.Context, plugin *pluginImpl) *grpcPluginStream { stream := streamPool.Get().(*grpcPluginStream) stream.ServerStream = s stream.fullName = fullName - stream.ctx = s.Context() + stream.ctx = ctx stream.plugin = plugin return stream } @@ -46,22 +62,20 @@ func (g *grpcPluginStream) Release() { func (g *grpcPluginStream) Context() context.Context { return g.ctx } + func (s *grpcPluginStream) RecvMsg(m interface{}) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - rsp, err := s.plugin.Apply(s.Context(), m, s.fullName) + rsp, header, err := s.plugin.Apply(s.Context(), m, s.fullName) if err != nil { return gosdk.NewError(fmt.Errorf("call %s plugin error: %w", s.plugin.Name(), err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } md, ok := metadata.FromIncomingContext(s.ctx) if !ok { - md = metadata.New(nil) - } - for k, v := range rsp.Metadata { - md.Append(k, v) + md = metadata.New(make(map[string]string)) } newRequest := rsp.NewRequest if newRequest != nil { @@ -70,6 +84,10 @@ func (s *grpcPluginStream) RecvMsg(m interface{}) error { return gosdk.NewError(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") } } + log.Printf("grpcPluginStream server stream pointer:%p", s) + for k, v := range header { + md[k] = append(md[k], v...) + } s.ctx = metadata.NewIncomingContext(s.ctx, md) // s.ctx = metadata.NewIncomingContext(s.ctx, metadata.New(rsp.Metadata)) diff --git a/internal/middleware/vaildator.go b/internal/middleware/vaildator.go index fe0461e..91536fe 100644 --- a/internal/middleware/vaildator.go +++ b/internal/middleware/vaildator.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "reflect" "strings" @@ -25,6 +26,16 @@ type validatePluginStream struct { ctx context.Context validator ParamsValidator } +type ValidateError struct { + error + Field string +} + +// type validatePluginClientStream struct { +// grpc.ClientStream +// ctx context.Context +// validator ParamsValidator +// } var validatePluginStreamPool = &sync.Pool{ New: func() interface{} { @@ -57,6 +68,7 @@ func (p *validatePluginStream) RecvMsg(m interface{}) error { } +// func getFieldNamesFromProto(input interface{}) map[string]string {} func getFieldNamesFromJSONTags(input interface{}) map[string]string { fieldMap := make(map[string]string) @@ -83,131 +95,19 @@ func getFieldNamesFromJSONTags(input interface{}) map[string]string { return fieldMap } -func (p *ParamsValidatorImpl) MergeMaps(maps ...map[string]string) map[string]string { - result := make(map[string]string) - for _, m := range maps { - // if m == nil { - // continue - // } - for k, v := range m { - result[k] = v - } - } - return result -} -func (p *ParamsValidatorImpl) GetValidateTags(v interface{}) map[string]string { - val := reflect.ValueOf(v) - if val.Kind() == reflect.Ptr { - val = val.Elem() - } - if val.Kind() != reflect.Struct { - return nil - } - rules := make(map[string]string) - typ := val.Type() - for i := 0; i < val.NumField(); i++ { - field := typ.Field(i) - fieldVal := val.Field(i) - jsonTag := strings.Split(field.Tag.Get("json"), ",")[0] - validateTag := field.Tag.Get("validate") - - if jsonTag == "" || jsonTag == "-" || validateTag == "" { - continue - } - - // 处理嵌套字段名 - fieldName := jsonTag - fieldName = strcase.ToCamel(fieldName) - rules[fieldName] = validateTag - - // 处理嵌套结构体和指针结构体 - switch fieldVal.Kind() { - case reflect.Struct: - r := p.GetValidateTags(fieldVal.Interface()) - p.validate.RegisterStructValidationMapRules(r, fieldVal.Interface()) - case reflect.Ptr: - if !fieldVal.IsNil() { - r := p.GetValidateTags(fieldVal.Interface()) - p.validate.RegisterStructValidationMapRules(r, fieldVal.Interface()) - } - case reflect.Slice: - elemType := fieldVal.Type().Elem() - if elemType.Kind() == reflect.Struct { - elem := reflect.New(elemType).Elem().Interface() - p.validate.RegisterStructValidationMapRules(p.GetValidateTags(elem), elem) - } else if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct { - elem := reflect.New(elemType.Elem()).Elem().Interface() - p.validate.RegisterStructValidationMapRules(p.GetValidateTags(elem), elem) - } - case reflect.Map: - elemType := fieldVal.Type().Elem() - if elemType.Kind() == reflect.Struct { - elem := reflect.New(elemType).Elem().Interface() - p.validate.RegisterStructValidationMapRules(p.GetValidateTags(elem), elem) - } else if elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Struct { - elem := reflect.New(elemType.Elem()).Elem().Interface() - p.validate.RegisterStructValidationMapRules(p.GetValidateTags(elem), elem) - } - } - } - return rules -} - -// GetValidateMapRules 获取待验证字段的规则 -// 返回值格式为map[字段名]规则 -// 规则格式参考github.com/go-playground/validator/v10 -func (p *ParamsValidatorImpl) GetValidateMapRules(v interface{}) map[string]string { - if message, ok := v.(protoreflect.ProtoMessage); ok { - md := message.ProtoReflect().Descriptor() - rules := make(map[string]string) - - // 遍历所有字段 - for i := 0; i < md.Fields().Len(); i++ { - field := md.Fields().Get(i) - vd, ok := proto.GetExtension(field.Options(), common.E_Validate).(string) - if !ok { - continue - } - key := string(field.Name()) - key = strcase.ToCamel(key) - rules[key] = vd - if field.Kind() == protoreflect.MessageKind && !field.IsMap() && !field.IsList() { - nestedRules := p.GetValidateMapRules(message.ProtoReflect().Get(field).Message().Interface()) - - p.validate.RegisterStructValidationMapRules(nestedRules, message.ProtoReflect().Get(field).Message().Interface()) - - } - - // 处理列表字段 - if field.IsList() { - list := message.ProtoReflect().Get(field).List() - if field.Kind() == protoreflect.MessageKind { - nestedRules := p.GetValidateMapRules(list.NewElement().Message().Interface()) - p.validate.RegisterStructValidationMapRules(nestedRules, list.NewElement().Message().Interface()) - } - } - - // 处理映射字段 - if field.IsMap() { - p.validate.RegisterStructValidationMapRules(p.GetValidateMapRules(message.ProtoReflect().Get(field).Interface()), message.ProtoReflect().Get(field).Map().NewValue().Interface()) - - } - } - return rules - } - return nil -} // isRequiredField 检查字段是否是必填字段 // 通过proto文件中的validate标签或者struct tag中的validate标签判断 func (p *ParamsValidatorImpl) isRequiredField(field interface{}) bool { if fd, ok := field.(protoreflect.FieldDescriptor); ok && fd != nil { + if v, ok := proto.GetExtension(fd.Options(), common.E_Validate).(string); ok && strings.Contains(v, "required") { return true } } + if fd, ok := field.(reflect.StructField); ok && fd.Tag.Get("validate") != "" { if v, ok := fd.Tag.Lookup("validate"); ok && strings.Contains(v, "required") { return true @@ -218,32 +118,19 @@ func (p *ParamsValidatorImpl) isRequiredField(field interface{}) bool { return false } -// removeDuplicates 去除重复的字段 -func (p *ParamsValidatorImpl) removeDuplicates(arr []string) []string { - seen := make(map[string]bool) - result := []string{} - - for _, value := range arr { - if !seen[value] { - seen[value] = true - result = append(result, value) - } - } - - return result -} - // getValidatePath 获取待验证字段的路径 // 路径格式参考validate.StructPartial func (p *ParamsValidatorImpl) getValidatePath(message protoreflect.ProtoMessage, field string, parent string) []string { fieldsName := make([]string, 0) md := message.ProtoReflect().Descriptor() - + // log.Printf("get validate path,field:%s,parent:%s", field, parent) if fd := md.Fields().ByJSONName(field); fd != nil { fieldName := strcase.ToCamel(string(fd.Name())) + // fieldName := fd.JSONName() if parent != "" { fieldName = parent + "." + fieldName } + fieldsName = append(fieldsName, fieldName) if fd.Kind() == protoreflect.MessageKind { if fd.IsList() { @@ -262,6 +149,9 @@ func (p *ParamsValidatorImpl) getValidatePath(message protoreflect.ProtoMessage, item := value.Message().Interface() fieldsName = append(fieldsName, p.FiltersFields(item, fmt.Sprintf("%s[%v]", fieldName, key.Interface()))...) + } else { + // log.Printf("map key path:%v", fmt.Sprintf("%s[%v]", fieldName, key.Interface())) + fieldsName = append(fieldsName, fmt.Sprintf("%s[%v]", fieldName, key.Interface())) } return true }) @@ -269,14 +159,57 @@ func (p *ParamsValidatorImpl) getValidatePath(message protoreflect.ProtoMessage, nestedMessage := message.ProtoReflect().Get(fd).Message().Interface() fieldsName = append(fieldsName, p.FiltersFields(nestedMessage, fieldName)...) } - } else { - fieldsName = append(fieldsName, fieldName) } } return fieldsName } +// FiltersFields 从FieldMask中获取过滤字段,获取待验证字段 +// required 字段优先级高于FieldMask +func (p *ParamsValidatorImpl) FiltersMessageFields(v interface{}) []string { + // fieldsMap := getFieldNamesFromJSONTags(v) + requiredFields := make([]string, 0) + maskFields := make([]string, 0) + + if message, ok := v.(protoreflect.ProtoMessage); ok { + md := message.ProtoReflect().Descriptor() + + // 遍历所有字段 + for i := 0; i < md.Fields().Len(); i++ { + field := md.Fields().Get(i) + // require 字段必须校验 + if p.isRequiredField(field) { + // log.Printf("required field:%s", field.JSONName()) + requiredFields = append(requiredFields, p.getValidatePath(message, field.JSONName(), "")...) + } + + // 检查字段是否是FieldMask类型 + if field.Kind() == protoreflect.MessageKind && !field.IsList() && !field.IsMap() { + + // 获取字段的值(确保它是FieldMask类型) + fieldValue := message.ProtoReflect().Get(field).Message() + mask, ok := fieldValue.Interface().(*fieldmaskpb.FieldMask) + if mask == nil || !ok { + continue + } + paths := make([]string, 0) + paths = append(paths, mask.Paths...) + for _, path := range paths { + maskField := strcase.ToCamel(path) + // if parent != "" { + // maskField = fmt.Sprintf("%s.%s", parent, strcase.ToCamel(path)) + // } + maskFields = append(maskFields, maskField) + maskFields = append(maskFields, p.getValidatePath(message, path, "")...) + } + } + } + return append(requiredFields, maskFields...) + } + return nil +} + // FiltersFields 从FieldMask中获取过滤字段,获取待验证字段 // required 字段优先级高于FieldMask func (p *ParamsValidatorImpl) FiltersFields(v interface{}, parent string) []string { @@ -287,17 +220,16 @@ func (p *ParamsValidatorImpl) FiltersFields(v interface{}, parent string) []stri if message, ok := v.(protoreflect.ProtoMessage); ok { md := message.ProtoReflect().Descriptor() val := reflect.ValueOf(v) + typ := reflect.TypeOf(v) if val.Kind() == reflect.Ptr { val = val.Elem() - } - if val.Kind() != reflect.Struct { - return nil + typ = typ.Elem() } isRequired := false for k := range fieldsMap { field := md.Fields().ByJSONName(k) - if val.Kind() != reflect.Struct { - st, ok := reflect.TypeOf(val).FieldByName(fieldsMap[k]) + st, ok := typ.FieldByName(fieldsMap[k]) + if val.Kind() == reflect.Struct { isRequired = ok && p.isRequiredField(st) } @@ -335,43 +267,41 @@ func (p *ParamsValidatorImpl) FiltersFields(v interface{}, parent string) []stri } return nil } -func RegisterCustomValidators(v *validator.Validate) { - _ = v.RegisterValidation("required_if", requiredIf) -} - -// requiredIf 自定义验证器逻辑 -func requiredIf(fl validator.FieldLevel) bool { - param := fl.Param() - field := fl.Field() - // 获取参数字段值 - paramField := fl.Parent().FieldByName(param) +func (p *ParamsValidatorImpl) ValidateParams(v interface{}) error { + // p.validate.Struct() + var err error + if message, ok := v.(proto.Message); ok { + filters := p.FiltersMessageFields(v) + duplicateFilters := make([]string, 0) + fieldsSet := make(map[string]struct{}) + for _, f := range filters { + if _, ok := fieldsSet[f]; !ok { + fieldsSet[f] = struct{}{} + duplicateFilters = append(duplicateFilters, f) + } + } - // 如果参数字段为空,当前字段必须非空 - if paramField.String() == "" { - return field.String() != "" - } + pv := NewProtobufValidate(p.validate) + // log.Printf("validate fields:%v", duplicateFilters) + err = pv.ProtobufPartial(message, common.E_Validate, duplicateFilters...) + // err = p.ValidateProtoMessage(message, common.E_Validate, fieldsSet, strcase.ToCamel(string(message.ProtoReflect().Descriptor().Name()))+".") - return true -} + } else { -func (p *ParamsValidatorImpl) ValidateParams(v interface{}) error { - rules := p.GetValidateMapRules(v) - rules = p.MergeMaps(rules, p.GetValidateTags(v)) - p.validate.RegisterStructValidationMapRules(rules, v) - - filters := p.FiltersFields(v, "") - filters = p.removeDuplicates(filters) - err := p.validate.Struct(v) - if len(filters) > 0 { - err = p.validate.StructPartial(v, filters...) + err = gosdk.NewError(fmt.Errorf("params validation failed: params is not a proto.Message"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage("params validation failed: unsupported type")) } + fieldName := "" - if errs, ok := err.(validator.ValidationErrors); ok { - clientMsg := fmt.Sprintf("params %s validation failed with %v,except %s", errs[0].Namespace(), errs[0].Value(), errs[0].ActualTag()) - return gosdk.NewError(fmt.Errorf("params %s validation failed: %v", errs[0].Namespace(), errs[0].Value()), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage(clientMsg)) + validateErr := validator.ValidationErrors{} + if errors.As(err, &validateErr) { + if validateErr[0].Namespace() != "" { + fieldName = validateErr[0].Namespace() + } + clientMsg := fmt.Sprintf("params %s validation failed with %v,except %s", fieldName, validateErr[0].Value(), validateErr[0].ActualTag()) + return gosdk.NewError(fmt.Errorf("params %s validation failed: %v due to %v", fieldName, validateErr[0].Value(), validateErr[0].ActualTag()), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage(clientMsg)) } - return nil + return err } func (p *ParamsValidatorImpl) SetPriority(priority int) { @@ -386,6 +316,7 @@ func (p *ParamsValidatorImpl) Name() string { } func (p *ParamsValidatorImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // fmt.Print("params validator unary interceptor\n") err = p.ValidateParams(req) if err != nil { return nil, err @@ -403,13 +334,16 @@ func (p *ParamsValidatorImpl) StreamInterceptor(srv interface{}, ss grpc.ServerS err := handler(srv, validateStream) return err } +func (p *ParamsValidatorImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return streamer(ctx, desc, cc, method, opts...) +} func NewParamsValidator() ParamsValidator { v := &ParamsValidatorImpl{ validate: validator.New(), } - RegisterCustomValidators(v.validate) + // RegisterCustomValidators(v.validate) return v } diff --git a/internal/middleware/vaildator_test.go b/internal/middleware/vaildator_test.go index 20ac235..93beacb 100644 --- a/internal/middleware/vaildator_test.go +++ b/internal/middleware/vaildator_test.go @@ -6,23 +6,61 @@ import ( "github.com/begonia-org/begonia/internal/middleware" hello "github.com/begonia-org/go-sdk/api/example/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/go-playground/validator/v10" c "github.com/smartystreets/goconvey/convey" "github.com/spark-lence/tiga" "google.golang.org/grpc" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/dynamicpb" "google.golang.org/protobuf/types/known/fieldmaskpb" ) -func TestValidatorUnaryInterceptor(t *testing.T) { - c.Convey("test validator unary interceptor", t, func() { - validator := middleware.NewParamsValidator() +type HelloSubRequest struct { + SubMsg string `protobuf:"bytes,1,opt,name=sub_msg,proto3" json:"sub_msg,omitempty"` + // @gotags: validate:"required" + SubName string `protobuf:"bytes,2,opt,name=sub_name,proto3" json:"sub_name,omitempty" validate:"required"` + // @gotags: validate:"required,gte=18,lte=35" + SubAge int32 `protobuf:"varint,4,opt,name=sub_age,proto3" json:"sub_age,omitempty" validate:"required,gte=18,lte=35"` + UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,3,opt,name=update_mask,proto3" json:"update_mask,omitempty"` +} - validator.SetPriority(1) - c.So(validator.Priority(), c.ShouldEqual, 1) - c.So(validator.Name(), c.ShouldEqual, "ParamsValidator") +type HelloRequestWithValidator struct { + + // @gotags: validate:"required" + Msg string `protobuf:"bytes,1,opt,name=msg,proto3" json:"msg,omitempty" validate:"required"` + // @gotags: validate:"required" + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty" validate:"required"` + Age int32 `protobuf:"varint,3,opt,name=age,proto3" json:"age,omitempty" validate:"required,gte=18,lte=35"` + Sub *HelloSubRequest `protobuf:"bytes,4,opt,name=sub,proto3" json:"sub,omitempty"` + // @gotags: validate:"required,dive" + Subs []*HelloSubRequest `protobuf:"bytes,5,rep,name=subs,proto3" json:"subs,omitempty" validate:"required,dive"` + UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,6,opt,name=update_mask,proto3" json:"update_mask,omitempty"` + // @gotags: validate:"required,dive" + SubMap map[string]*HelloSubRequest `protobuf:"bytes,7,rep,name=sub_map,proto3" json:"sub_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3" validate:"required,dive"` + // @gotags: validate:"required" + SubMap2 map[string]string `protobuf:"bytes,8,rep,name=sub_map2,proto3" json:"sub_map2,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3" validate:"required"` +} + +func TestValidateDynamicProtoMessage(t *testing.T) { + c.Convey("test dynamic proto message", t, func() { req := &hello.HelloRequestWithValidator{ - Name: "test", - Msg: "test", - Age: 19, + Name: "test", + Msg: "test", + Age: 16, + FloatNum: 0.0, + BoolData: true, + ExEnum: hello.ExampleEnum_EX_RUNNING, + ExEnums: []hello.ExampleEnum{ + hello.ExampleEnum_EX_RUNNING, + }, + EnumMap: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_RUNNING, + }, + EnumMap2: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_UNKNOWN, + }, + Strs: []string{"test"}, Sub: &hello.HelloSubRequest{ SubMsg: "test", SubAge: 19, @@ -37,6 +75,7 @@ func TestValidatorUnaryInterceptor(t *testing.T) { UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, }, { + SubName: "test", SubAge: 19, SubMsg: "test", UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_msg"}}, @@ -54,6 +93,77 @@ func TestValidatorUnaryInterceptor(t *testing.T) { }, UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub", "subs", "sub_map", "sub_map2"}}, } + + dpb := dynamicpb.NewMessage(req.ProtoReflect().Descriptor()) + b, _ := protojson.Marshal(req) + _ = protojson.Unmarshal(b, dpb) + + pv := middleware.NewProtobufValidate(validator.New()) + err := pv.Protobuf(dpb, common.E_Validate) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Age") + }) + +} +func TestValidateProtoMessage(t *testing.T) { + req := &hello.HelloRequestWithValidator{ + Name: "test", + Msg: "test", + Age: 19, + Age2: 19, + FloatNum: 1.1, + BoolData: true, + BytesData: []byte("test"), + ExEnum: hello.ExampleEnum_EX_RUNNING, + ExEnums: []hello.ExampleEnum{ + hello.ExampleEnum_EX_RUNNING, + }, + EnumMap: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_RUNNING, + }, + EnumMap2: map[string]hello.ExampleEnum{ + "test": hello.ExampleEnum_EX_UNKNOWN, + }, + Strs: []string{"test"}, + Sub: &hello.HelloSubRequest{ + SubMsg: "test", + SubAge: 19, + SubName: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, + }, + Subs: []*hello.HelloSubRequest{ + { + SubAge: 19, + SubName: "test", + SubMsg: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, + }, + { + SubName: "test", + SubAge: 19, + SubMsg: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_msg"}}, + }, + }, + SubMap: map[string]*hello.HelloSubRequest{ + "TEST1": { + SubName: "test", + SubAge: 19, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_age"}}, + }, + }, + SubMap2: map[string]string{ + "TEST1": "test", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"name", "msg", "sub", "subs", "sub_map", "sub_map2"}}, + } + c.Convey("test validator unary interceptor", t, func() { + validator := middleware.NewParamsValidator() + + validator.SetPriority(1) + c.So(validator.Priority(), c.ShouldEqual, 1) + c.So(validator.Name(), c.ShouldEqual, "ParamsValidator") + _, err := validator.UnaryInterceptor(context.Background(), req, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) @@ -64,7 +174,7 @@ func TestValidatorUnaryInterceptor(t *testing.T) { return nil, nil }) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "HelloRequestWithValidator.Age") + c.So(err.Error(), c.ShouldContainSubstring, "Age") req3 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) req3.Subs[1].SubAge = 16 @@ -72,14 +182,14 @@ func TestValidatorUnaryInterceptor(t *testing.T) { return nil, nil }) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "HelloRequestWithValidator.Subs[1].SubAge") + c.So(err.Error(), c.ShouldContainSubstring, "Subs[1].SubAge") req4 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) req4.SubMap["TEST1"].SubAge = 16 _, err = validator.UnaryInterceptor(context.Background(), req4, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "HelloRequestWithValidator.SubMap[TEST1].SubAge") + c.So(err.Error(), c.ShouldContainSubstring, "SubMap[TEST1].SubAge") req5 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) req5.Subs[0] = &hello.HelloSubRequest{ SubName: "test2", @@ -88,35 +198,122 @@ func TestValidatorUnaryInterceptor(t *testing.T) { return nil, nil }) c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "HelloRequestWithValidator.Subs[0].SubAge") + c.So(err.Error(), c.ShouldContainSubstring, "Subs[0].SubAge") - }) -} -func TestRequireIf(t *testing.T) { - c.Convey("test require if", t, func() { - v := []struct { - Field string `validate:"required_if=Field2"` - Field2 string - Field3 string `validate:"required_if=Field2"` - }{{ - Field: "", - Field2: "test", - Field3: "", - }, - { - Field: "", - Field2: "", - Field3: "", - }, + req6 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req6.Sub.SubAge = 16 + _, err = validator.UnaryInterceptor(context.Background(), req6, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub.SubAge") + + req7 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req7.Subs[1] = &hello.HelloSubRequest{ + SubAge: 19, + SubMsg: "test", + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_age"}}, } - validator := middleware.NewParamsValidator() + _, err = validator.UnaryInterceptor(context.Background(), req7, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Subs[1].SubName") - err := validator.ValidateParams(v[0]) + req8 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req8.SubMap2 = nil + _, err = validator.UnaryInterceptor(context.Background(), req8, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "SubMap2") + + req9 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req9.ExEnum = hello.ExampleEnum_EX_UNKNOWN + _, err = validator.UnaryInterceptor(context.Background(), req9, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "ExEnum") + req10 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req10.Sub = nil + _, err = validator.UnaryInterceptor(context.Background(), req10, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub") + + req11 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req11.EnumMap = nil + _, err = validator.UnaryInterceptor(context.Background(), req11, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "EnumMap") + + req12 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req12.EnumMap2 = nil + req12.UpdateMask.Paths = append(req12.UpdateMask.Paths, "enum_map2") + _, err = validator.UnaryInterceptor(context.Background(), req12, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) c.So(err, c.ShouldBeNil) - err = validator.ValidateParams(v[1]) + + req14 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req14.Subs = nil + _, err = validator.UnaryInterceptor(context.Background(), req14, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Subs") + req15 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req15.Name = "hello" + _, err = validator.UnaryInterceptor(context.Background(), req15, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub2") + req16 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) + req16.Sub = nil + _, err = validator.UnaryInterceptor(context.Background(), req16, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "Sub") + + st := struct { + Name string `validate:"required"` + }{} + _, err = validator.UnaryInterceptor(context.Background(), st, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "params is not a proto.Message") + }) + c.Convey("test NewStructFromProtobuf", t, func() { + validate := validator.New() + vd := middleware.NewProtobufValidate(validate) + v := vd.NewStructFromProtobuf(tiga.DeepCopy(req).(*hello.HelloRequestWithValidator), common.E_Validate) + err := validate.Struct(v) + c.So(err, c.ShouldBeNil) + err = vd.ProtobufPartialCtx(context.Background(), tiga.DeepCopy(req).(*hello.HelloRequestWithValidator), common.E_Validate) + c.So(err, c.ShouldBeNil) + }) +} +func TestValidator(t *testing.T) { + st := struct { + IntNum int `validate:"required"` + BoolField bool `validate:"required"` + }{ + IntNum: 1, + BoolField: false, + } + v := validator.New() + err := v.Struct(st) + t.Log(err) } func TestValidatorStreamInterceptor(t *testing.T) { c.Convey("test stream interceptor", t, func() { @@ -138,3 +335,15 @@ func TestValidatorStreamInterceptor(t *testing.T) { c.So(err, c.ShouldNotBeNil) }) } +func TestValidatorStreamClientInterceptor(t *testing.T) { + c.Convey("test stream client interceptor", t, func() { + validator := middleware.NewParamsValidator() + + _, err := validator.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // return middleware.NewGrpcPluginClientStream(ctx, desc, cc, method, opts...),nil + return nil, nil + }) + c.So(err, c.ShouldBeNil) + + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 4dc9558..8b498fe 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,7 +13,6 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/middleware" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/service" loadbalance "github.com/begonia-org/go-loadbalancer" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -52,7 +51,7 @@ func readDesc(conf *config.Config) (gateway.ProtobufDescription, error) { func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []service.Service, pluginApply *middleware.PluginsApply) *gateway.GatewayServer { // 参数选项 opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]runtime.ServeMuxOption, 0), @@ -63,10 +62,12 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("application/x-www-form-urlencoded", gateway.NewFormUrlEncodedMarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption(runtime.MIMEWildcard, gateway.NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("application/octet-stream", gateway.NewRawBinaryUnmarshaler())) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("text/event-stream", gateway.NewEventSourceMarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMetadata(gateway.IncomingHeadersToMetadata)) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithErrorHandler(gateway.HandleErrorWithLogger(gateway.Log))) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithForwardResponseOption(gateway.HttpResponseBodyModify)) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithStreamErrorHandler(gateway.HandleServerStreamError(gateway.Log))) // opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithRoutingErrorHandler(middleware.HandleRoutingError)) // 连接池配置 opts.PoolOptions = append(opts.PoolOptions, loadbalance.WithMaxActiveConns(100)) @@ -74,6 +75,7 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv // 中间件配置 opts.Options = append(opts.Options, grpc.ChainUnaryInterceptor(pluginApply.UnaryInterceptorChains()...)) opts.Options = append(opts.Options, grpc.ChainStreamInterceptor(pluginApply.StreamInterceptorChains()...)) + opts.Middlewares = append(opts.Middlewares, pluginApply.StreamClientInterceptorChains()...) pd, err := readDesc(conf) if err != nil { panic(err) @@ -84,7 +86,7 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv opts.HttpHandlers = append(opts.HttpHandlers, cors.Handle) gw := gateway.New(cfg, opts) - routersList := routers.Get() + routersList := gateway.GetRouter() for _, srv := range services { err := gw.RegisterLocalService(context.Background(), pd, srv.Desc(), srv) if err != nil { diff --git a/internal/service/file_test.go b/internal/service/file_test.go index fabe88e..d0e6c36 100644 --- a/internal/service/file_test.go +++ b/internal/service/file_test.go @@ -67,6 +67,7 @@ func makeBucket(t *testing.T) { rsp, err := apiClient.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + t.Logf("access key:%s", accessKey) minioFile := client.NewFilesAPI(apiAddr, accessKey, secret, api.FileEngine_FILE_ENGINE_MINIO) rsp, err = minioFile.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldBeNil) diff --git a/internal/service/tenant.go b/internal/service/tenant.go index c7abc27..e4dbe93 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -25,6 +25,8 @@ func NewTenantService(tenant *biz.TenantUsecase, cfg *config.Config, log logger. return &TenantService{tenant: tenant, cfg: cfg, log: log} } func (t *TenantService) Register(ctx context.Context, in *api.PostTenantRequest) (*api.Tenants, error) { + // st:=debug.Stack() + // fmt.Printf("TenantService stack:%s\n",st) identity := GetIdentity(ctx) if identity == "" { return nil, gosdk.NewError(pkg.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") diff --git a/testdata/desc.pb b/testdata/desc.pb index 59f32fdd0bcffa3e2ac2e867d5e328a2e8597158..49a730b5692bdf8d4bd58a98db8f1811a424c469 100644 GIT binary patch delta 2087 zcmbVNOK%%h6z=sSwvQiiZ-X;+^ROpX_-Ap{!=^d3@h@&YU~v zulIlM_P<{GKJXMZ^8LFvZlZ-}2y#tbZ?>son7Xak4jW3NMcZ1nqgs!?T9kJ35WC9_ z_t_cGK-f_214}}=U?E1lL-`W~SL%AF=m8adpmJ)MxXZD3(R(HLzvbnWAmI2$?H`2P zLSeoCczHIE|7;r42ejSRk9D)Hc{mtGmaO~sjd?kHj$kjx{W34Rzpl(VuR@|?nz~8o z<3noMI1Ry$)nwuk6VG2UhEQGCs1#tPFcI%E%i7{9&duto9WN_JLi}Cc02pK$?H-mO z(xHxJsZBbdV`p*XlrJT3LXuE}sy5ZwljW~N#AGim>a{vv$`nU^D2|ATA0|HyL8@}1 zc8oT?KVWi%inBya){W;N)@bXhUF{xrq#$buA|nG{0Pwt;YO7L;xdUX(!n12Vn^uK8 z#M}wuj4Z;f>ut{+X6`77)qDfK31O9gcWH`wQl(eMi>)v@o7_pf7 zjNz0QpZxZ_8@MVvGm2`ol%rb;^(Is?bj#jAtGEhrAO9Y;FF>&%C?uSjWcZL@vXAg4 zCvdz9$nKpBFkHeXFzbulqvjDc_gWgQ9MSIo73_@NSb0S#Uc&iF#g@QUEQ>0h_;rMj zVaBJv!~Oz#(J@&y1h}HWL5Vid(jAaI5>%+N?7tYRSki zjaH54QAqkYVgsCR*!WU7b9iMzFo&~~!YN-VZzwj~DCHJhYw5+NY1E5@(YD6=U!E2> zXP>V#e+AxhH;Ri(P4-oe)t+)NTzB6Ff?M1BwJEVZZ^Osoek=bsy0-BEGMe7CtA`e? z9*P_^Ul8f`wzn}77b(CD;Q4cK2n?wd7ZguH(^hf{ml?Dnt3HSe8{SKiW c_nWoXW;3F=1O`!t!h811j}q!XU;6^R0MJ~gNB{r; delta 99 zcmV-p0G$8BZuMNVuL5Tw0<8qI%`f%=lNvHOu>r3F1E>NDlWsB{vo<6P0kbM57Xkt6 zvr{r=0RsL6v%52^0S8J71VV3Qd6UmG8wB5(iz diff --git a/testdata/gateway.json b/testdata/gateway.json index 14807a4..b829aeb 100644 --- a/testdata/gateway.json +++ b/testdata/gateway.json @@ -1,4 +1,302 @@ { + "/helloworld.Greeter/SayHello": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3 + ], + "Pool": [ + "api", + "v1", + "example", + "post" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/post" + }, + "HttpMethod": "POST", + "FullMethodName": "/helloworld.Greeter/SayHello", + "HttpUri": "/api/v1/example/post", + "PathParams": [], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloBody": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3 + ], + "Pool": [ + "api", + "v1", + "example", + "body" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/body" + }, + "HttpMethod": "POST", + "FullMethodName": "/helloworld.Greeter/SayHelloBody", + "HttpUri": "/api/v1/example/body", + "PathParams": [], + "InName": "HttpBody", + "OutName": "HttpBody", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "google.api", + "OutPkg": "google.api" + } + ], + "/helloworld.Greeter/SayHelloClientStream": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4 + ], + "Pool": [ + "api", + "v1", + "example", + "client", + "stream" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/client/stream" + }, + "HttpMethod": "POST", + "FullMethodName": "/helloworld.Greeter/SayHelloClientStream", + "HttpUri": "/api/v1/example/client/stream", + "PathParams": [], + "InName": "HelloRequest", + "OutName": "RepeatedReply", + "IsClientStream": true, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloError": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4 + ], + "Pool": [ + "api", + "v1", + "example", + "error", + "test" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/error/test" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloError", + "HttpUri": "/api/v1/example/error/test", + "PathParams": [], + "InName": "ErrorRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloGet": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 1, + 0, + 4, + 1, + 5, + 3 + ], + "Pool": [ + "api", + "v1", + "example", + "name" + ], + "Verb": "", + "Fields": [ + "name" + ], + "Template": "/api/v1/example/{name}" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloGet", + "HttpUri": "/api/v1/example/{name}", + "PathParams": [ + "name" + ], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": false, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloRPC": [], + "/helloworld.Greeter/SayHelloServerSideEvent": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4, + 1, + 0, + 4, + 1, + 5, + 5 + ], + "Pool": [ + "api", + "v1", + "example", + "server", + "sse", + "name" + ], + "Verb": "", + "Fields": [ + "name" + ], + "Template": "/api/v1/example/server/sse/{name}" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloServerSideEvent", + "HttpUri": "/api/v1/example/server/sse/{name}", + "PathParams": [ + "name" + ], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": false, + "IsServerStream": true, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], + "/helloworld.Greeter/SayHelloWebsocket": [ + { + "Pattern": {}, + "Template": { + "Version": 1, + "OpCodes": [ + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 3, + 2, + 4 + ], + "Pool": [ + "api", + "v1", + "example", + "server", + "websocket" + ], + "Verb": "", + "Fields": null, + "Template": "/api/v1/example/server/websocket" + }, + "HttpMethod": "GET", + "FullMethodName": "/helloworld.Greeter/SayHelloWebsocket", + "HttpUri": "/api/v1/example/server/websocket", + "PathParams": [], + "InName": "HelloRequest", + "OutName": "HelloReply", + "IsClientStream": true, + "IsServerStream": true, + "Pkg": "helloworld", + "InPkg": "helloworld", + "OutPkg": "helloworld" + } + ], "/integration.TestService/Body": [ { "Pattern": {}, diff --git a/testdata/helloworld.pb b/testdata/helloworld.pb index 82797fb1c310292ba49907ecf36de5827c3eb45d..54336ca1e9e632989f3b3062e45038364316d2e0 100644 GIT binary patch delta 154 zcmZpSKNY`W0~6!p%^R8i@i2Bz-oSrT%ZN*`IJKxOGdVTBxTGjGF*h?WU4l)4QG>&Z zH7zlxI5kKBD!+N7K%Y2MJO5@!jl0ZDTiGTn=$ztndNzra#9vCF?;xgC$SY2<(K3e0|2sT BH4*>- delta 127 zcmX>V-x$AP0~6!J%^R8i@i2By-oSrT!+=YmIJKxOwJ5&0q$o8pH#09?f=z)@gTsn7 zEitD!HHaS~y?L`hpE%P2zRk`WcbSA*1f(Zhd)~)QZI1 cf}B+S#DYxyGQ&ko%pN}BNo)m0`6c