From 9f4c49d49a6021bdefdc0f269f75f46bd0700f8a Mon Sep 17 00:00:00 2001 From: "vforfreedom96@gmail.com" Date: Sun, 14 Jul 2024 19:28:41 +0800 Subject: [PATCH] feat:add client stream mid --- cmd/begonia/main.go | 10 + config/settings.yml | 13 +- gateway.json | 299 ------------------ gateway/endpoint.go | 3 +- gateway/exception.go | 72 ++++- gateway/exception_test.go | 47 +++ gateway/gateway.go | 25 +- gateway/gateway_test.go | 35 ++ gateway/http.go | 1 - gateway/http_test.go | 3 +- gateway/middlewares.go | 86 +++-- gateway/middlewares_test.go | 30 +- gateway/protobuf.go | 4 + gateway/{grpc.go => proxy.go} | 176 +++++++++-- gateway/{grpc_test.go => proxy_test.go} | 68 +++- {internal/pkg/routers => gateway}/routers.go | 24 +- .../pkg/routers => gateway}/routers_test.go | 11 +- gateway/utils_test.go | 2 +- go.mod | 6 +- go.sum | 12 +- internal/biz/aksk.go | 4 +- internal/biz/aksk_test.go | 3 +- internal/biz/data_test.go | 2 +- internal/biz/endpoint/endpoint_test.go | 33 +- internal/biz/endpoint/utils.go | 3 +- internal/biz/endpoint/watcher.go | 15 +- internal/middleware/auth/ak_test.go | 43 ++- internal/middleware/auth/aksk.go | 64 +++- internal/middleware/auth/apikey.go | 65 +++- internal/middleware/auth/apikey_test.go | 20 +- internal/middleware/auth/auth.go | 29 +- internal/middleware/auth/auth_test.go | 246 +++++++++++++- internal/middleware/auth/headers.go | 225 +++++++------ internal/middleware/auth/headers_test.go | 52 +-- internal/middleware/auth/jwt.go | 81 +++-- internal/middleware/auth/jwt_test.go | 52 ++- internal/middleware/auth/stream.go | 53 ++-- internal/middleware/http.go | 25 +- internal/middleware/http_test.go | 37 ++- internal/middleware/middleware.go | 8 + internal/middleware/middleware_test.go | 1 + internal/middleware/rpc.go | 70 ++-- internal/middleware/rpc_test.go | 120 ++++++- internal/middleware/stream.go | 29 +- internal/middleware/vaildator.go | 21 +- internal/middleware/vaildator_test.go | 44 ++- internal/server/server.go | 6 +- internal/service/file_test.go | 1 + internal/service/tenant.go | 2 + 49 files changed, 1546 insertions(+), 735 deletions(-) delete mode 100644 gateway.json create mode 100644 gateway/gateway_test.go rename gateway/{grpc.go => proxy.go} (60%) rename gateway/{grpc_test.go => proxy_test.go} (65%) rename {internal/pkg/routers => gateway}/routers.go (87%) rename {internal/pkg/routers => gateway}/routers_test.go (77%) diff --git a/cmd/begonia/main.go b/cmd/begonia/main.go index 8002b6b..e3e499a 100644 --- a/cmd/begonia/main.go +++ b/cmd/begonia/main.go @@ -205,4 +205,14 @@ func main() { if err := cmd.Execute(); err != nil { log.Fatalf("failed to start begonia: %v", err) } + // env, _ := cmd.Flags().GetString("env") + // cnf, err := cmd.Flags().GetString("config") + // if err != nil { + // log.Fatalf("failed to get config: %v", err) + // } + // config := config.ReadConfigWithDir("dev", "/data/work/begonia-org/begonia/config/settings.yml") + // worker := internal.New(config, gateway.Log, "127.0.0.1:12138") + // hd, _ := os.UserHomeDir() + // _ = os.WriteFile(hd+"/.begonia/gateway.json", []byte(fmt.Sprintf(`{"addr":"http://%s"}`, "127.0.0.1:12138")), 0666) + // worker.Start() } diff --git a/config/settings.yml b/config/settings.yml index 0a5b32e..6d524f1 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -11,8 +11,8 @@ file: engines: - name: "FILE_ENGINE_MINIO" endpoint: "127.0.0.1:9000" - accessKey: "7OdVJK1alV8cpRMeBLBW" - secretKey: "2GNNRqqElReC1KnV3kX9jSjyLU4kwOaTZEqDS2vH" + accessKey: "rLV2Jjj2UbMWSJOhTOtZ" + secretKey: "OVyJOILwx4iVE0EVJB4CKB65j7xlhjT5q1aGCv5t" - name: "FILE_ENGINE_LOCAL" endpoint: "/data/work/begonia-org/begonia/upload" protos: @@ -70,11 +70,12 @@ gateway: - "example.com" plugins: local: - logger: 1 - exception: 0 + # 优先级越大越先执行 + exception: 4 + logger: 3 http: 2 - params_validator: 3 - auth: 4 + auth: 1 + params_validator: 0 # only_api_key_auth: 4 rpc: # - server: diff --git a/gateway.json b/gateway.json deleted file mode 100644 index 825d680..0000000 --- a/gateway.json +++ /dev/null @@ -1,299 +0,0 @@ -{ - "/helloworld.Greeter/SayHello": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "post" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/post" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHello", - "HttpUri": "/api/v1/example/post", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloBody": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "body" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/body" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHelloBody", - "HttpUri": "/api/v1/example/body", - "PathParams": [], - "InName": "HttpBody", - "OutName": "HttpBody", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "google.api", - "OutPkg": "google.api" - } - ], - "/helloworld.Greeter/SayHelloClientStream": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "client", - "stream" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/client/stream" - }, - "HttpMethod": "POST", - "FullMethodName": "/helloworld.Greeter/SayHelloClientStream", - "HttpUri": "/api/v1/example/client/stream", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "RepeatedReply", - "IsClientStream": true, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloError": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "error", - "test" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/error/test" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloError", - "HttpUri": "/api/v1/example/error/test", - "PathParams": [], - "InName": "ErrorRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloGet": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 1, - 0, - 4, - 1, - 5, - 3 - ], - "Pool": [ - "api", - "v1", - "example", - "name" - ], - "Verb": "", - "Fields": [ - "name" - ], - "Template": "/api/v1/example/{name}" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloGet", - "HttpUri": "/api/v1/example/{name}", - "PathParams": [ - "name" - ], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": false, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloServerSideEvent": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4, - 1, - 0, - 4, - 1, - 5, - 5 - ], - "Pool": [ - "api", - "v1", - "example", - "server", - "sse", - "name" - ], - "Verb": "", - "Fields": [ - "name" - ], - "Template": "/api/v1/example/server/sse/{name}" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloServerSideEvent", - "HttpUri": "/api/v1/example/server/sse/{name}", - "PathParams": [ - "name" - ], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": false, - "IsServerStream": true, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ], - "/helloworld.Greeter/SayHelloWebsocket": [ - { - "Pattern": {}, - "Template": { - "Version": 1, - "OpCodes": [ - 2, - 0, - 2, - 1, - 2, - 2, - 2, - 3, - 2, - 4 - ], - "Pool": [ - "api", - "v1", - "example", - "server", - "websocket" - ], - "Verb": "", - "Fields": null, - "Template": "/api/v1/example/server/websocket" - }, - "HttpMethod": "GET", - "FullMethodName": "/helloworld.Greeter/SayHelloWebsocket", - "HttpUri": "/api/v1/example/server/websocket", - "PathParams": [], - "InName": "HelloRequest", - "OutName": "HelloReply", - "IsClientStream": true, - "IsServerStream": true, - "Pkg": "helloworld", - "InPkg": "helloworld", - "OutPkg": "helloworld" - } - ] -} \ No newline at end of file diff --git a/gateway/endpoint.go b/gateway/endpoint.go index a91e85a..e5501b5 100644 --- a/gateway/endpoint.go +++ b/gateway/endpoint.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "fmt" "strings" loadbalance "github.com/begonia-org/go-loadbalancer" @@ -42,7 +43,7 @@ func (e *httpForwardGrpcEndpointImpl) Request(req GrpcRequest) (proto.Message, r return nil, runtime.ServerMetadata{ HeaderMD: make(map[string][]string), TrailerMD: make(map[string][]string), - }, err + }, fmt.Errorf("get conn error:%v", err) } defer e.pool.Release(req.GetContext(), cc) diff --git a/gateway/exception.go b/gateway/exception.go index 7f3da65..3f40dfa 100644 --- a/gateway/exception.go +++ b/gateway/exception.go @@ -9,8 +9,10 @@ import ( gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" + "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" ) type Exception struct { @@ -19,17 +21,30 @@ type Exception struct { name string } +func (e *Exception) setHeader(ctx context.Context) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + reqId := "" + if !ok || len(md.Get(XRequestID)) == 0 { + reqId = uuid.New().String() + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(XRequestID, reqId)) + + } else { + reqId = md.Get(XRequestID)[0] + } + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(XRequestID, reqId)) + + _ = grpc.SetHeader(ctx, metadata.Pairs(XRequestID, reqId)) + return ctx + +} func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { defer func() { if p := recover(); p != nil { - // buf := make([]byte, 1024) - // n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 - // stackTrace := string(buf[:n]) - // err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - // err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") err = e.handlePanic(p) } + }() + ctx = e.setHeader(ctx) resp, err = handler(ctx, req) if err == nil { return resp, err @@ -37,7 +52,7 @@ func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info return nil, err } func (e *Exception) handlePanic(p interface{}) error { - const maxFrames = 10 + const maxFrames = 15 var pcs [maxFrames]uintptr n := runtime.Callers(2, pcs[:]) // skip first 3 frames @@ -60,18 +75,49 @@ func (e *Exception) handlePanic(p interface{}) error { func (e *Exception) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { defer func() { if p := recover(); p != nil { - // buf := make([]byte, 512) - // n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 - // stackTrace := string(buf[:n]) - - // err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - // err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") - // _ = ss.SendMsg(err) err = e.handlePanic(p) } + }() + e.setHeader(ss.Context()) return handler(srv, ss) } +func (e *Exception) wrapHandlerWithPanicRecovery(handler grpc.StreamHandler) grpc.StreamHandler { + return func(srv any, stream grpc.ServerStream) (err error) { + // reqId := "" + defer func() { + if p := recover(); p != nil { + err = e.handlePanic(p) + } + }() + + return handler(srv, stream) + } +} +func (e *Exception) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + reqId := "" + md, ok := metadata.FromOutgoingContext(ctx) + if !ok || len(md.Get(XRequestID)) == 0 { + reqID := uuid.New().String() + reqId = reqID + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(XRequestID, reqId)) + + } else { + reqId = md.Get(XRequestID)[0] + md.Set(XRequestID, reqId) + ctx = metadata.NewOutgoingContext(ctx, md) + + } + + _ = grpc.SetHeader(ctx, metadata.Pairs(XRequestID, reqId)) + desc.Handler = e.wrapHandlerWithPanicRecovery(desc.Handler) + ss, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, err + + } + return ss, nil +} func NewException(log logger.Logger) *Exception { return &Exception{log: log, name: "exception"} } diff --git a/gateway/exception_test.go b/gateway/exception_test.go index be8175a..2f497bd 100644 --- a/gateway/exception_test.go +++ b/gateway/exception_test.go @@ -6,8 +6,10 @@ import ( "testing" "github.com/begonia-org/go-sdk/logger" + "github.com/google/uuid" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" ) type MiddlewaresTest struct { @@ -39,6 +41,21 @@ func (e *MiddlewaresTest) Name() string { return e.name } +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func TestUnaryInterceptor(t *testing.T) { c.Convey("TestUnaryInterceptor", t, func() { mid := NewException(Log) @@ -64,3 +81,33 @@ func TestUnaryInterceptor(t *testing.T) { }) } +func TestExceptionStreamClientInterceptor(t *testing.T) { + c.Convey("TestExceptionStreamClientInterceptor", t, func() { + mid := NewException(Log) + ctx := context.Background() + + desc := &grpc.StreamDesc{ + StreamName: "/INTEGRATION.TESTSERVICE/GET", + ClientStreams: true, + ServerStreams: true, + Handler: func(srv interface{}, ss grpc.ServerStream) error { + panic("test painc") + }, + } + st, err := mid.StreamClientInterceptor(ctx, desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: context.Background()}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + // has request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, uuid.New().String())), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: context.Background()}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, nil) + c.So(err, c.ShouldNotBeNil) + + }) +} diff --git a/gateway/gateway.go b/gateway/gateway.go index b5b6eca..4929efd 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -18,7 +18,7 @@ import ( ) type GrpcServerOptions struct { - Middlewares []GrpcProxyMiddleware + Middlewares []grpc.StreamClientInterceptor Options []grpc.ServerOption PoolOptions []loadbalance.PoolOptionsBuildOption HttpMiddlewares []runtime.ServeMuxOption @@ -28,6 +28,8 @@ type GatewayConfig struct { GatewayAddr string GrpcProxyAddr string } +type DynamicGrpcItem struct { +} type GatewayServer struct { grpcServer *grpc.Server httpGateway HttpEndpoint @@ -37,14 +39,15 @@ type GatewayServer struct { proxyAddr string opts *GrpcServerOptions mux *sync.Mutex + proxy *GrpcProxy } -func NewGrpcServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *grpc.Server { +func NewGrpcProxyServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *GrpcProxy { - proxy := NewGrpcProxy(lb, opts.Middlewares...) + proxy := NewGrpcProxy(lb,Log, opts.Middlewares...) - opts.Options = append(opts.Options, grpc.UnknownServiceHandler(proxy.Handler)) - return grpc.NewServer(opts.Options...) + opts.Options = append(opts.Options, grpc.UnknownServiceHandler(proxy.Do)) + return proxy } func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) (HttpEndpoint, error) { @@ -57,7 +60,9 @@ func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) ( } func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { lb := NewGrpcLoadBalancer() - grpcServer := NewGrpcServer(opts, lb) + gProxy := NewGrpcProxyServer(opts, lb) + opts.Options = append(opts.Options, grpc.UnknownServiceHandler(gProxy.Do)) + grpcServer := grpc.NewServer(opts.Options...) _, port, _ := net.SplitHostPort(cfg.GrpcProxyAddr) proxy := fmt.Sprintf("127.0.0.1:%s", port) @@ -77,6 +82,7 @@ func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { proxyAddr: cfg.GrpcProxyAddr, opts: opts, mux: &sync.Mutex{}, + proxy: gProxy, } // }) return gatewayS @@ -97,6 +103,13 @@ func (g *GatewayServer) RegisterLocalService(ctx context.Context, pd ProtobufDes g.grpcServer.RegisterService(sd, ss) return g.httpGateway.RegisterHandlerClient(ctx, pd, g.gatewayMux) } +func (g *GatewayServer) RegisterServiceWithProxy(pd ProtobufDescription) { + g.mux.Lock() + defer g.mux.Unlock() + g.proxy.buildServiceDesc(pd) + +} + func (g *GatewayServer) DeleteLocalService(pd ProtobufDescription) { g.mux.Lock() defer g.mux.Unlock() diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go new file mode 100644 index 0000000..b332fac --- /dev/null +++ b/gateway/gateway_test.go @@ -0,0 +1,35 @@ +package gateway + +import ( + "fmt" + "log" + "os" + "path/filepath" + "testing" + + "github.com/begonia-org/begonia/internal/pkg/config" + common "github.com/begonia-org/go-sdk/common/api/v1" +) + +func readDesc(conf *config.Config) (ProtobufDescription, error) { + desc := conf.GetLocalAPIDesc() + log.Printf("read desc file:%s", desc) + bin, err := os.ReadFile("/data/work/begonia-org/openRAG/openrag/pb/openrag.bin") + if err != nil { + return nil, fmt.Errorf("read desc file error:%w", err) + } + pd, err := NewDescriptionFromBinary(bin, filepath.Dir(desc)) + if err != nil { + return nil, err + } + err = pd.SetHttpResponse(common.E_HttpResponse) + if err != nil { + return nil, err + } + return pd, nil +} +func TestRegisterDynamicServices(t *testing.T) { + // pd, _ := readDesc(config.NewConfig(cfg.ReadConfig("test"))) + // _ = &GatewayServer{} + // gw.buildServiceDesc(pd) +} diff --git a/gateway/http.go b/gateway/http.go index db10a89..de1402d 100644 --- a/gateway/http.go +++ b/gateway/http.go @@ -421,7 +421,6 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu return } resp, md, err := h.client.Request(reqInstance) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) diff --git a/gateway/http_test.go b/gateway/http_test.go index b1cecfb..4c92959 100644 --- a/gateway/http_test.go +++ b/gateway/http_test.go @@ -51,7 +51,7 @@ var eps []loadbalance.Endpoint func newTestServer(gwPort, randomNumber int) (*GrpcServerOptions, *GatewayConfig) { opts := &GrpcServerOptions{ - Middlewares: make([]GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -131,6 +131,7 @@ func testRegisterClient(t *testing.T) { load, err := loadbalance.New(loadbalance.RRBalanceType, endps) c.So(err, c.ShouldBeNil) + gw.RegisterServiceWithProxy(pd) err = gw.RegisterService(context.Background(), pd, load) c.So(err, c.ShouldBeNil) c.So(gw.GetLoadbalanceName(), c.ShouldEqual, loadbalance.RRBalanceType) diff --git a/gateway/middlewares.go b/gateway/middlewares.go index 29e8629..c5c9669 100644 --- a/gateway/middlewares.go +++ b/gateway/middlewares.go @@ -2,13 +2,14 @@ package gateway import ( "context" + "fmt" "net/http" "strconv" "strings" "time" gosdk "github.com/begonia-org/go-sdk" - _ "github.com/begonia-org/go-sdk/api" + // _ "github.com/begonia-org/go-sdk/api" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" "github.com/google/uuid" @@ -189,12 +190,45 @@ func (log *LoggerMiddleware) UnaryInterceptor(ctx context.Context, req interface } }() - + // fmt.Printf("call logger mid\n") rsp, err = handler(ctx, req) elapsed := time.Since(now) + // fmt.Printf("logger error:%v", err) log.logger(ctx, info.FullMethod, err, elapsed) return } +func (log *LoggerMiddleware) wrapHandlerWithLogger(handler grpc.StreamHandler) grpc.StreamHandler { + return func(srv interface{}, ss grpc.ServerStream) (err error) { + now := time.Now() + defer func() { + if r := recover(); r != nil { + elapsed := time.Since(now) + method, _ := grpc.Method(ss.Context()) + err = fmt.Errorf("handle err:%v", r) + log.logger(ss.Context(), method, fmt.Errorf("handle err:%v",r), elapsed) + } + if md, ok := metadata.FromIncomingContext(ss.Context()); ok { + reqId := md.Get(XRequestID) + if len(reqId) > 0 { + _ = grpc.SendHeader(ss.Context(), metadata.Pairs(XRequestID, reqId[0])) + } + } + }() + err = handler(srv, ss) + elapsed := time.Since(now) + method, _ := grpc.Method(ss.Context()) + + log.logger(ss.Context(), method, err, elapsed) + return err + } + +} +func (log *LoggerMiddleware) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + + desc.Handler = log.wrapHandlerWithLogger(desc.Handler) + return streamer(ctx, desc, cc, method, opts...) + +} func (log *LoggerMiddleware) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { now := time.Now() defer func() { @@ -313,7 +347,7 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { log.WithField("status", code).Errorf(ctx, msg) w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) - // log.Infof(ctx, "error message:%s,err code:%d", data.Message, data.Code) + log.Infof(ctx, "error message:%s,err code:%d", data.Message, data.Code) bData, _ := protojson.Marshal(data) _, _ = w.Write(bData) return @@ -332,7 +366,7 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { w.Header().Del(key) if v != "" { if strings.EqualFold(httpKey, "Content-Type") { - if v == "application/grpc" { + if strings.EqualFold(v, "application/grpc") { continue } w.Header().Set(httpKey, v) @@ -347,6 +381,9 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { return } } + if strings.HasPrefix(strings.ToLower(httpKey), strings.ToLower("Grpc-")) { + continue + } headers = append(headers, http.CanonicalHeaderKey(httpKey)) w.Header().Set("Access-Control-Expose-Headers", strings.Join(headers, ",")) @@ -358,32 +395,37 @@ func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { } func HttpResponseBodyModify(ctx context.Context, w http.ResponseWriter, msg proto.Message) error { httpCode := http.StatusOK - for key, value := range w.Header() { - if strings.HasPrefix(key, "Grpc-Metadata-") { + // log.Printf("response message header:%v", w.Header()) + for key := range w.Header() { + if strings.HasPrefix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("Grpc-")) { w.Header().Del(key) + continue } - // log.Printf("send to client rsp header,key:%s,value:%s", key, value[0]) - writeHttpHeaders(w, key, value) - if strings.HasSuffix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("X-Http-Code")) { - codeStr := value[0] - code, err := strconv.ParseInt(codeStr, 10, 32) - if err != nil { - Log.Error(ctx, err) - return status.Error(codes.Internal, err.Error()) - } - httpCode = int(code) + } + if out, ok := runtime.ServerMetadataFromContext(ctx); ok { + // log.Printf("response message header from server metadata:%v", out.HeaderMD) + for key, value := range out.HeaderMD { - } + if strings.HasPrefix(strings.ToLower(key), strings.ToLower("Grpc-")) || strings.EqualFold(key, "content-type") { + continue - } + } + writeHttpHeaders(w, key, value) + if strings.HasSuffix(http.CanonicalHeaderKey(key), http.CanonicalHeaderKey("X-Http-Code")) { + codeStr := value[0] + code, err := strconv.ParseInt(codeStr, 10, 32) + if err != nil { + Log.Error(ctx, err) + return status.Error(codes.Internal, err.Error()) + } + httpCode = int(code) + + } - out, ok := metadata.FromIncomingContext(ctx) - if ok { - for k, v := range out { - writeHttpHeaders(w, k, v) } } + if httpCode != http.StatusOK { w.WriteHeader(httpCode) } diff --git a/gateway/middlewares_test.go b/gateway/middlewares_test.go index 688aa41..9e26339 100644 --- a/gateway/middlewares_test.go +++ b/gateway/middlewares_test.go @@ -5,7 +5,9 @@ import ( "net/http" "testing" + "github.com/agiledragon/gomonkey/v2" hello "github.com/begonia-org/go-sdk/api/example/v1" + "github.com/google/uuid" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -49,11 +51,37 @@ func TestLoggerMiddlewares(t *testing.T) { } c.So(f, c.ShouldNotPanic) f2 := func() { - _ = mid.StreamInterceptor(nil, &streamMock{}, &grpc.StreamServerInfo{FullMethod: "/test"}, func(srv interface{}, ss grpc.ServerStream) error { + _ = mid.StreamInterceptor(nil, &streamMock{ctx: context.Background()}, &grpc.StreamServerInfo{FullMethod: "/test"}, func(srv interface{}, ss grpc.ServerStream) error { panic("test") }) } + // f2() c.So(f2, c.ShouldNotPanic) + + desc := &grpc.StreamDesc{ + StreamName: "/INTEGRATION.TESTSERVICE/GET", + ClientStreams: true, + ServerStreams: true, + Handler: func(srv interface{}, ss grpc.ServerStream) error { + panic("test painc") + }, + } + patch:=gomonkey.ApplyFuncReturn(grpc.Method,"/INTEGRATION.TESTSERVICE/GET",true) + defer patch.Reset() + st, err := mid.StreamClientInterceptor(context.Background(), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(),metadata.New(make(map[string]string)))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + // has request id + st, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs(XRequestID, uuid.New().String())), desc, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: metadata.NewOutgoingContext(context.Background(),metadata.New(make(map[string]string)))}, nil + }) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + err = desc.Handler(nil, &streamMock{ctx: context.Background()}) + c.So(err, c.ShouldNotBeNil) }) } diff --git a/gateway/protobuf.go b/gateway/protobuf.go index d1b80b9..3ee2bd6 100644 --- a/gateway/protobuf.go +++ b/gateway/protobuf.go @@ -159,6 +159,7 @@ func NewDescriptionFromBinary(data []byte, outDir string) (ProtobufDescription, return nil, err } // desc.gatewayJsonSchema = filepath.Join(outDir, "gateway.json") + // log.Printf("GetFileDescriptorSet result is :%v",desc.GetFileDescriptorSet()) contents, err := register.Register(desc.GetFileDescriptorSet(), false, "") if err != nil { return nil, fmt.Errorf("Failed to register: %w", err) @@ -196,8 +197,11 @@ func (p *protobufDescription) GetMessageTypeByFullName(fullName string) protoref v := desc.(protoreflect.MessageDescriptor) return v } + // log.Printf("GetMessageTypeByFullName failed:%s", fullName) return nil } func (p *protobufDescription) GetGatewayJsonSchema() string { return p.gatewayJsonSchema } + + diff --git a/gateway/grpc.go b/gateway/proxy.go similarity index 60% rename from gateway/grpc.go rename to gateway/proxy.go index fe59bdf..573cfdd 100644 --- a/gateway/grpc.go +++ b/gateway/proxy.go @@ -2,13 +2,16 @@ package gateway import ( "context" + "errors" "fmt" "io" + "runtime/debug" "strings" "sync" "time" loadbalance "github.com/begonia-org/go-loadbalancer" + "github.com/begonia-org/go-sdk/logger" "github.com/spark-lence/tiga" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -16,7 +19,8 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" ) type grpcEndpointImpl struct { @@ -128,17 +132,34 @@ func (g *GrpcLoadBalancer) Select(method string, args ...interface{}) (loadbalan return nil, loadbalance.ErrNoEndpoint } +type IOType struct { + In protoreflect.MessageDescriptor + Out protoreflect.MessageDescriptor +} type GrpcProxyMiddleware func(srv interface{}, serverStream grpc.ServerStream) error type GrpcProxy struct { - lb *GrpcLoadBalancer - middlewares []GrpcProxyMiddleware + lb *GrpcLoadBalancer + // middlewares []grpc.StreamServerInterceptor + // chainStreamInts []grpc.StreamServerInterceptor + streamInt grpc.StreamClientInterceptor + chainClientStream []grpc.StreamClientInterceptor + ioType map[string]*IOType + log logger.Logger } -func NewGrpcProxy(lb *GrpcLoadBalancer, middlewares ...GrpcProxyMiddleware) *GrpcProxy { - return &GrpcProxy{ - lb: lb, - middlewares: middlewares, +func NewGrpcProxy(lb *GrpcLoadBalancer, log logger.Logger, middlewares ...grpc.StreamClientInterceptor) *GrpcProxy { + g := &GrpcProxy{ + lb: lb, + // middlewares: middlewares, + chainClientStream: middlewares, + ioType: make(map[string]*IOType), + log: log, } + g.chainStreamClientInterceptors() + return g +} +func (g *GrpcProxy) Register(lb loadbalance.LoadBalance, pd ProtobufDescription) { + g.lb.Register(lb, pd) } func (g *GrpcProxy) getClientIP(ctx context.Context) (string, error) { peer, ok := peer.FromContext(ctx) @@ -156,14 +177,53 @@ func (g *GrpcProxy) getXForward(ctx context.Context) []string { return md.Get("X-Forwarded-For") } -func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) error { +// chainStreamServerInterceptors chains all stream server interceptors into one. +func (g *GrpcProxy) chainStreamClientInterceptors() { + // Prepend opts.streamInt to the chaining interceptors if it exists, since streamInt will + // be executed before any other chained interceptors. + interceptors := g.chainClientStream + if g.streamInt != nil { + interceptors = append([]grpc.StreamClientInterceptor{g.streamInt}, g.chainClientStream...) + } - // 执行中间件 - for _, middleware := range g.middlewares { - if err := middleware(srv, serverStream); err != nil { - return err - } + var chainedInt grpc.StreamClientInterceptor + if len(interceptors) == 0 { + chainedInt = nil + } else if len(interceptors) == 1 { + chainedInt = interceptors[0] + } else { + chainedInt = g.chainStreamInterceptors(interceptors) } + + g.streamInt = chainedInt +} + +func (g *GrpcProxy) chainStreamInterceptors(interceptors []grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptors[0](ctx, desc, cc, method, g.getChainStreamHandler(interceptors, 0, streamer)) + } +} + +func (g *GrpcProxy) getChainStreamHandler(interceptors []grpc.StreamClientInterceptor, curr int, finalHandler grpc.Streamer) grpc.Streamer { + if curr == len(interceptors)-1 { + return finalHandler + } + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return interceptors[curr+1](ctx, desc, cc, method, g.getChainStreamHandler(interceptors, curr+1, finalHandler)) + } +} + +func (g *GrpcProxy) Do(srv interface{}, serverStream grpc.ServerStream) error { + + return g.Handler(srv, serverStream) +} +func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) error { + defer func() { + if p := recover(); p != nil { + s := debug.Stack() + g.log.Errorf(serverStream.Context(), "panic recover! p: %v stack:%s", p, s) + } + }() // 获取方法名 fullMethodName, ok := grpc.MethodFromServerStream(serverStream) if !ok { @@ -193,6 +253,7 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err defer endpoint.AfterTransform(serverStream.Context(), cn.((loadbalance.Connection))) conn := cn.(loadbalance.Connection).ConnInstance().(*grpc.ClientConn) + clientCtx, clientCancel := context.WithCancel(serverStream.Context()) defer clientCancel() proxyDesc := &grpc.StreamDesc{ @@ -202,28 +263,43 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err // 添加本地ip local, _ := tiga.GetLocalIP() xForwards = append(xForwards, local) - md, ok := metadata.FromIncomingContext(clientCtx) + md, ok := metadata.FromIncomingContext(serverStream.Context()) if !ok { md = metadata.MD{} } + md.Set("X-Forwarded-For", strings.Join(xForwards, ",")) clientCtx = metadata.NewOutgoingContext(clientCtx, md) + var clientStream grpc.ClientStream = nil - clientStream, err := grpc.NewClientStream(clientCtx, proxyDesc, conn, fullMethodName) - if err != nil { - return err + if len(g.chainClientStream) > 0 && g.streamInt != nil { + clientStream, err = g.streamInt(clientCtx, proxyDesc, conn, fullMethodName, g.getChainStreamHandler(g.chainClientStream, 0, func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return grpc.NewClientStream(ctx, desc, cc, method, opts...) + })) + if err != nil { + return fmt.Errorf("failed creating client stream from stream client int: %v", err) + } + } else { + clientStream, err = grpc.NewClientStream(clientCtx, proxyDesc, conn, fullMethodName) + if err != nil { + return err + } } + + // proxyStream := &proxyClientStream{ClientStream: clientStream, ctx: clientCtx} // 转发流量 // 从客户端到服务端 s2cErrChan := g.forwardServerToClient(serverStream, clientStream) + // 从服务端到客户端 c2sErrChan := g.forwardClientToServer(clientStream, serverStream) for i := 0; i < 2; i++ { select { case s2cErr := <-s2cErrChan: - if s2cErr == io.EOF { + if errors.Is(s2cErr, io.EOF) { // this is the happy case where the sender has encountered io.EOF, and won't be sending anymore./ // the clientStream>serverStream may continue pumping though. + // log.Printf("s2cErr:%v", s2cErr) err = clientStream.CloseSend() if err != nil { return status.Errorf(codes.Internal, "failed closing client stream: %v", err) @@ -241,7 +317,7 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err // will be nil. serverStream.SetTrailer(clientStream.Trailer()) // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. - if c2sErr != io.EOF { + if !errors.Is(c2sErr, io.EOF) { return c2sErr } return nil @@ -253,10 +329,20 @@ func (g *GrpcProxy) Handler(srv interface{}, serverStream grpc.ServerStream) err func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error { ret := make(chan error, 1) go func() { - f := &emptypb.Empty{} + defer func () { + if p := recover(); p != nil { + s := debug.Stack() + g.log.Errorf(dst.Context(), "panic recover! p: %v stack:%s", p, s) + } + }() + method, _ := grpc.Method(dst.Context()) + io := g.ioType[method] + f := dynamicpb.NewMessage(io.Out) + // f := &emptypb.Empty{} + for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - ret <- err // this can be io.EOF which is happy case + ret <- fmt.Errorf("fail forward client to server:%w", err) // this can be io.EOF which is happy case break } if i == 0 { @@ -264,18 +350,25 @@ func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.Server // received but must be written to server stream before the first msg is flushed. // This is the only place to do it nicely // 先转发header + // inMD, _ := metadata.FromIncomingContext(src.Context()) + // outMD, _ := metadata.FromOutgoingContext(src.Context()) + + // md := metadata.Join(inMD, outMD) + // m := make(map[string][]string) + md, err := src.Header() if err != nil { - ret <- err + ret <- fmt.Errorf("failed reading header from client stream: %w", err) break } + if err := dst.SendHeader(md); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending header to server stream: %w", err) break } } if err := dst.SendMsg(f); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending msg to server stream: %w", err) break } } @@ -286,17 +379,46 @@ func (g *GrpcProxy) forwardClientToServer(src grpc.ClientStream, dst grpc.Server func (g *GrpcProxy) forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error { ret := make(chan error, 1) go func() { - f := &emptypb.Empty{} + method, _ := grpc.Method(dst.Context()) + io := g.ioType[method] + f := dynamicpb.NewMessage(io.In) + for i := 0; ; i++ { if err := src.RecvMsg(f); err != nil { - ret <- err // this can be io.EOF which is happy case + ret <- fmt.Errorf("recv msg error:%w from client forward", err) // this can be io.EOF which is happy case break } + if err := dst.SendMsg(f); err != nil { - ret <- err + ret <- fmt.Errorf("failed sending msg to client stream: %w", err) break } } }() return ret } + +func (g *GrpcProxy) buildServiceDesc(pd ProtobufDescription) { + fd := pd.GetFileDescriptorSet() + // sds := make([]*grpc.ServiceDesc, 0) + for _, file := range fd.GetFile() { + sd := file.GetService() + + for _, service := range sd { + + methods := service.GetMethod() + for _, method := range methods { + in := method.GetInputType() + inDesc := pd.GetMessageTypeByFullName(strings.TrimPrefix(in, ".")) + outDesc := pd.GetMessageTypeByFullName(strings.TrimPrefix(method.GetOutputType(), ".")) + srv := fmt.Sprintf("/%s.%s/%s", file.GetPackage(), service.GetName(), method.GetName()) + g.ioType[srv] = &IOType{ + In: inDesc, + Out: outDesc, + } + + } + + } + } +} diff --git a/gateway/grpc_test.go b/gateway/proxy_test.go similarity index 65% rename from gateway/grpc_test.go rename to gateway/proxy_test.go index 2bbacbb..ce46070 100644 --- a/gateway/grpc_test.go +++ b/gateway/proxy_test.go @@ -14,6 +14,8 @@ import ( "time" "github.com/agiledragon/gomonkey/v2" + cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/internal/pkg/config" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" hello "github.com/begonia-org/go-sdk/api/example/v1" @@ -26,8 +28,11 @@ import ( ) type streamMock struct { + ctx context.Context +} +type clientStreamMock struct { + ctx context.Context } -type clientStreamMock struct{} func (*streamMock) SendHeader(md metadata.MD) error { return nil @@ -37,10 +42,11 @@ func (*streamMock) SetHeader(md metadata.MD) error { } func (*streamMock) SetTrailer(md metadata.MD) { } -func (*streamMock) Context() context.Context { - return context.Background() +func (s *streamMock) Context() context.Context { + return s.ctx } func (*streamMock) SendMsg(m interface{}) error { + time.Sleep(1 * time.Second) return nil } func (*streamMock) RecvMsg(m interface{}) error { @@ -94,13 +100,26 @@ func TestGrpcHandleErr(t *testing.T) { load, _ := loadbalance.New(loadbalance.RRBalanceType, endps) lb := NewGrpcLoadBalancer() lb.Register(load, pd) - mid := func(srv interface{}, serverStream grpc.ServerStream) error { - return nil + fullMethod1 := "" + mid := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + fullMethod1 = method + return streamer(ctx, desc, cc, method, opts...) + } + fullMethod2 := "" + mid2 := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + fullMethod2 = method + return streamer(ctx, desc, cc, method, opts...) + } + proxy := NewGrpcProxy(lb, Log, mid, mid2) + proxy.buildServiceDesc(pd) + + stream := &streamMock{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("uri", "/api/v1/example/server/websocket")), } - proxy := NewGrpcProxy(lb, mid) - stream := &streamMock{} patch := gomonkey.ApplyFuncReturn(grpc.MethodFromServerStream, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket"), true) - patch.ApplyFuncReturn(grpc.NewClientStream, &clientStreamMock{}, nil) + patch.ApplyFuncReturn(grpc.NewClientStream, &clientStreamMock{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("uri", "/api/v1/example/server/websocket")), + }, nil) addrs, _ := net.InterfaceAddrs() var localAddr net.Addr for _, addr := range addrs { @@ -121,10 +140,15 @@ func TestGrpcHandleErr(t *testing.T) { err error output []interface{} }{ + // { + // patch: metadata.FromIncomingContext, + // err: fmt.Errorf("metadata not exists in context"), + // output: []interface{}{nil, false}, + // }, { patch: mid, err: fmt.Errorf("mid handle err"), - output: []interface{}{fmt.Errorf("mid handle err")}, + output: []interface{}{nil, fmt.Errorf("mid handle err")}, }, { patch: (*clientStreamMock).CloseSend, @@ -153,14 +177,21 @@ func TestGrpcHandleErr(t *testing.T) { return io.EOF }) defer patch3.Reset() + patch6 := gomonkey.ApplyFuncReturn(grpc.Method, "/helloworld.Greeter/SayHelloWebsocket", true) + defer patch6.Reset() for _, caseV := range cases { patch2 := gomonkey.ApplyFuncReturn(caseV.patch, caseV.output...) defer patch2.Reset() - err = proxy.Handler(&hello.HelloRequest{}, stream) + + err = proxy.Do(&hello.HelloRequest{}, stream) + t.Log(caseV.err.Error()) + // t.Logf("err:%v", err.Error()) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, caseV.err.Error()) patch2.Reset() } + c.So(fullMethod1, c.ShouldEqual, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket")) + c.So(fullMethod2, c.ShouldEqual, strings.ToUpper("/helloworld.Greeter/SayHelloWebsocket")) patch3.Reset() errChan2 := make(chan error, 3) @@ -168,12 +199,27 @@ func TestGrpcHandleErr(t *testing.T) { errChan2 <- io.EOF errChan2 <- io.EOF patch4 := gomonkey.ApplyFuncReturn((*GrpcProxy).forwardServerToClient, errChan2) - err = proxy.Handler(&hello.HelloRequest{}, stream) + err = proxy.Do(&hello.HelloRequest{}, stream) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "proxying should never reach") defer patch4.Reset() patch.Reset() + patch4.Reset() + // patch6.Reset() + patch7:=gomonkey.ApplyFuncReturn(grpc.MethodFromServerStream,"/helloworld.Greeter/SayHelloWebsocket",false) + defer patch7.Reset() + + proxy2 := NewGrpcProxy(lb, Log) + err = proxy2.Do(&hello.HelloRequest{}, stream) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "stream not exists in context") }) } + +func TestProxyDo(t *testing.T) { + pd, _ := readDesc(config.NewConfig(cfg.ReadConfig("test"))) + p := &GrpcProxy{ioType: make(map[string]*IOType), lb: NewGrpcLoadBalancer(), log: Log} + p.buildServiceDesc(pd) +} diff --git a/internal/pkg/routers/routers.go b/gateway/routers.go similarity index 87% rename from internal/pkg/routers/routers.go rename to gateway/routers.go index 52ddb09..816a644 100644 --- a/internal/pkg/routers/routers.go +++ b/gateway/routers.go @@ -1,20 +1,10 @@ -package routers +package gateway import ( "fmt" - "log" "strings" "sync" - "github.com/begonia-org/begonia/gateway" - _ "github.com/begonia-org/go-sdk/api/app/v1" - _ "github.com/begonia-org/go-sdk/api/endpoint/v1" - _ "github.com/begonia-org/go-sdk/api/example/v1" - _ "github.com/begonia-org/go-sdk/api/iam/v1" - _ "github.com/begonia-org/go-sdk/api/plugin/v1" - _ "github.com/begonia-org/go-sdk/api/sys/v1" - _ "github.com/begonia-org/go-sdk/api/user/v1" - _ "github.com/begonia-org/go-sdk/common/api/v1" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/protobuf/proto" @@ -53,7 +43,7 @@ func NewHttpURIRouteToSrvMethod() *HttpURIRouteToSrvMethod { }) return httpURIRouteToSrvMethod } -func Get() *HttpURIRouteToSrvMethod { +func GetRouter() *HttpURIRouteToSrvMethod { return NewHttpURIRouteToSrvMethod() } @@ -61,7 +51,8 @@ func (r *HttpURIRouteToSrvMethod) AddRoute(uri string, srvMethod *APIMethodDetai r.mux.Lock() defer r.mux.Unlock() r.routers[uri] = srvMethod - r.grpcRouter[srvMethod.GrpcFullRouter] = srvMethod + // log.Printf("add srv method grpc router:%s,pointer:%p", srvMethod.GrpcFullRouter, r) + r.grpcRouter[strings.ToUpper(srvMethod.GrpcFullRouter)] = srvMethod } func (r *HttpURIRouteToSrvMethod) deleteRoute(uri string, grpcFullMethod string) { delete(r.routers, uri) @@ -72,6 +63,7 @@ func (r *HttpURIRouteToSrvMethod) GetRoute(uri string) *APIMethodDetails { return r.routers[uri] } func (r *HttpURIRouteToSrvMethod) GetRouteByGrpcMethod(method string) *APIMethodDetails { + // log.Printf("get grpc method,%v:%s,pointer:%p",r.grpcRouter,strings.ToUpper(method),r) return r.grpcRouter[strings.ToUpper(method)] } func (r *HttpURIRouteToSrvMethod) GetAllRoutes() map[string]*APIMethodDetails { @@ -98,7 +90,7 @@ func (r *HttpURIRouteToSrvMethod) getHttpRule(method *descriptorpb.MethodDescrip return nil } func (r *HttpURIRouteToSrvMethod) AddLocalSrv(fullMethod string) { - log.Printf("add local srv:%s", fullMethod) + // log.Printf("add local srv:%s", fullMethod) r.localSrv[strings.ToUpper(fullMethod)] = true } func (r *HttpURIRouteToSrvMethod) IsLocalSrv(fullMethod string) bool { @@ -156,7 +148,7 @@ func (r *HttpURIRouteToSrvMethod) addRouterDetails(serviceName string, useJsonRe } } -func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) { +func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd ProtobufDescription) { fds := pd.GetFileDescriptorSet() for _, fd := range fds.File { for _, service := range fd.Service { @@ -181,7 +173,7 @@ func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) } -func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd gateway.ProtobufDescription) { +func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd ProtobufDescription) { fds := pd.GetFileDescriptorSet() for _, fd := range fds.File { for _, service := range fd.Service { diff --git a/internal/pkg/routers/routers_test.go b/gateway/routers_test.go similarity index 77% rename from internal/pkg/routers/routers_test.go rename to gateway/routers_test.go index c07cd8d..be6046e 100644 --- a/internal/pkg/routers/routers_test.go +++ b/gateway/routers_test.go @@ -1,4 +1,4 @@ -package routers_test +package gateway_test import ( "path/filepath" @@ -6,15 +6,14 @@ import ( "testing" "github.com/begonia-org/begonia/gateway" - "github.com/begonia-org/begonia/internal/pkg/routers" c "github.com/smartystreets/goconvey/convey" ) func TestLoadAllRouters(t *testing.T) { c.Convey("TestLoadAllRouters", t, func() { - R := routers.NewHttpURIRouteToSrvMethod() + R := gateway.NewHttpURIRouteToSrvMethod() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata") pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) R.LoadAllRouters(pd) @@ -37,9 +36,9 @@ func TestLoadAllRouters(t *testing.T) { } func TestDeleteRouters(t *testing.T) { c.Convey("TestDeleteRouters", t, func() { - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + pbFile := filepath.Join((filepath.Dir(filepath.Dir(filename))), "testdata") pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) diff --git a/gateway/utils_test.go b/gateway/utils_test.go index b368ed0..ee16dc0 100644 --- a/gateway/utils_test.go +++ b/gateway/utils_test.go @@ -16,7 +16,7 @@ import ( func TestNewEndpoint(t *testing.T) { opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), diff --git a/go.mod b/go.mod index ff477f2..691f2a8 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/smartystreets/goconvey v1.8.1 github.com/spark-lence/tiga v0.0.0-20240707025120-c2e1f47f88dc github.com/spf13/cobra v1.8.0 - google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 + google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 ) @@ -80,7 +80,7 @@ require ( require ( github.com/agiledragon/gomonkey/v2 v2.11.0 github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a - github.com/begonia-org/go-sdk v0.0.0-20240707025218-b8159f0b2462 + github.com/begonia-org/go-sdk v0.0.0-20240714083941-00e95e667477 github.com/go-git/go-git/v5 v5.11.0 github.com/go-playground/validator/v10 v10.19.0 github.com/gorilla/websocket v1.5.0 @@ -140,7 +140,7 @@ require ( go.uber.org/zap v1.27.0 // indirect golang.org/x/sys v0.22.0 // indirect golang.org/x/text v0.16.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d // indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/go.sum b/go.sum index 2d1245a..e16b822 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/begonia-org/go-layered-cache v0.0.0-20240510102605-41bdb7aa07fa h1:DH github.com/begonia-org/go-layered-cache v0.0.0-20240510102605-41bdb7aa07fa/go.mod h1:xEqoca1vNGqH8CV7X9EzhDV5Ihtq9J95p7ZipzUB6pc= github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a h1:Mpw7T+90KC5QW7yCa8Nn/5psnlvsexipAOrQAcc7YE0= github.com/begonia-org/go-loadbalancer v0.0.0-20240519060752-71ca464f0f1a/go.mod h1:crPS67sfgmgv47psftwfmTMbmTfdepVm8MPeqApINlI= -github.com/begonia-org/go-sdk v0.0.0-20240707025218-b8159f0b2462 h1:qd5Fim08aHI+RlBqyZhJEWtuIm9Q03DMm+AToh4Uprc= -github.com/begonia-org/go-sdk v0.0.0-20240707025218-b8159f0b2462/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= +github.com/begonia-org/go-sdk v0.0.0-20240714083941-00e95e667477 h1:8LhQbM+Y51bsyo0Z2vfyx1ycjhv1UKL6YHJgx9eZyPg= +github.com/begonia-org/go-sdk v0.0.0-20240714083941-00e95e667477/go.mod h1:2mHpFudwolu6RHF18EX+lnFYyTNnwDxBD6JcfRcahz8= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -395,10 +395,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:BwIjyKYGsK9dMCBOorzRri8MQwmi7mT9rGHsCEinZkA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d h1:kHjw/5UfflP/L5EbledDrcG4C2597RtymmGRZvHiCuY= +google.golang.org/genproto/googleapis/api v0.0.0-20240711142825-46eb208f015d/go.mod h1:mw8MG/Qz5wfgYr6VqVCiZcHe/GJEfI+oGGDCohaVgB0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d h1:JU0iKnSg02Gmb5ZdV8nYsKEKsP6o/FGVWTrw4i1DA9A= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240711142825-46eb208f015d/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= diff --git a/internal/biz/aksk.go b/internal/biz/aksk.go index 6fa1329..7d81263 100644 --- a/internal/biz/aksk.go +++ b/internal/biz/aksk.go @@ -5,9 +5,9 @@ import ( "strings" "time" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -30,7 +30,7 @@ func NewAccessKeyAuth(app AppRepo, config *config.Config, log logger.Logger) *Ac } func IfNeedValidate(ctx context.Context, fullMethod string) bool { - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(strings.ToUpper(fullMethod)) if router == nil { return false diff --git a/internal/biz/aksk_test.go b/internal/biz/aksk_test.go index 373cc6e..06aef13 100644 --- a/internal/biz/aksk_test.go +++ b/internal/biz/aksk_test.go @@ -16,7 +16,6 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/pkg/utils" gosdk "github.com/begonia-org/go-sdk" @@ -138,7 +137,7 @@ func testIfNeedValidate(t *testing.T) { ok := biz.IfNeedValidate(context.TODO(), akskAccess) c.So(ok, c.ShouldBeFalse) - patch := gomonkey.ApplyFuncReturn((*routers.HttpURIRouteToSrvMethod).GetRouteByGrpcMethod, &routers.APIMethodDetails{AuthRequired: true}) + patch := gomonkey.ApplyFuncReturn((*gateway.HttpURIRouteToSrvMethod).GetRouteByGrpcMethod, &gateway.APIMethodDetails{AuthRequired: true}) defer patch.Reset() ok = biz.IfNeedValidate(context.TODO(), akskAccess) c.So(ok, c.ShouldBeTrue) diff --git a/internal/biz/data_test.go b/internal/biz/data_test.go index c8c693a..91c5958 100644 --- a/internal/biz/data_test.go +++ b/internal/biz/data_test.go @@ -116,7 +116,7 @@ func TestDo(t *testing.T) { _ = cache.Del(context.Background(), "begonia:user:black:lock") _ = cache.Del(context.Background(), "begonia:user:black:last_updated") opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), diff --git a/internal/biz/endpoint/endpoint_test.go b/internal/biz/endpoint/endpoint_test.go index b669b22..01fcd21 100644 --- a/internal/biz/endpoint/endpoint_test.go +++ b/internal/biz/endpoint/endpoint_test.go @@ -21,7 +21,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gwRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" goloadbalancer "github.com/begonia-org/go-loadbalancer" @@ -252,6 +251,8 @@ func testPatchEndpoint(t *testing.T) { c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "test watcher error") + + }) } @@ -374,7 +375,7 @@ func testWatcherUpdate(t *testing.T) { } val, _ := json.Marshal(value) opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), @@ -385,12 +386,12 @@ func testWatcherUpdate(t *testing.T) { GrpcProxyAddr: "127.0.0.1:12148", } gateway.New(gwCnf, opts) - routers.NewHttpURIRouteToSrvMethod() + gateway.NewHttpURIRouteToSrvMethod() c.Convey("Test Watcher Update", t, func() { err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldBeNil) - r := routers.Get() + r := gateway.GetRouter() detail := r.GetRoute("/api/v1/example/{name}") c.So(detail, c.ShouldNotBeNil) @@ -433,6 +434,28 @@ func testWatcherUpdate(t *testing.T) { c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrUnknownLoadBalancer.Error()) + + // SetHttpResponse err + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + conf := config.ReadConfig(env) + cnf := cfg.NewConfig(conf) + outDir := cnf.GetGatewayDescriptionOut() + _, filename, _, _ := runtime.Caller(0) + + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata", "helloworld.pb") + pb, err := os.ReadFile(pbFile) + c.So(err, c.ShouldBeNil) + pd, err := gateway.NewDescriptionFromBinary(pb, filepath.Join(outDir, "tmp-test")) + c.So(err, c.ShouldBeNil) + patch6 := gomonkey.ApplyMethodReturn(pd, "SetHttpResponse", fmt.Errorf("test SetHttpResponse error")) + defer patch6.Reset() + err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "test SetHttpResponse error") }) } func testWatcherDel(t *testing.T) { @@ -453,7 +476,7 @@ func testWatcherDel(t *testing.T) { c.Convey("Test Watcher Del", t, func() { err := watcher.Handle(context.TODO(), mvccpb.DELETE, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldBeNil) - r := routers.Get() + r := gateway.GetRouter() detail := r.GetRoute("/api/v1/example/{name}") c.So(detail, c.ShouldBeNil) }) diff --git a/internal/biz/endpoint/utils.go b/internal/biz/endpoint/utils.go index bcd5d4e..2de7aca 100644 --- a/internal/biz/endpoint/utils.go +++ b/internal/biz/endpoint/utils.go @@ -7,14 +7,13 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/grpc/codes" ) func deleteAll(ctx context.Context, pd gateway.ProtobufDescription) error { - routersList := routers.Get() + routersList := gateway.GetRouter() routersList.DeleteRouters(pd) gw := gateway.Get() gw.DeleteLoadBalance(pd) diff --git a/internal/biz/endpoint/watcher.go b/internal/biz/endpoint/watcher.go index 8149c29..edc414e 100644 --- a/internal/biz/endpoint/watcher.go +++ b/internal/biz/endpoint/watcher.go @@ -3,17 +3,17 @@ package endpoint import ( "context" "fmt" + "log" "sync" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" "go.etcd.io/etcd/api/v3/mvccpb" "encoding/json" - "github.com/begonia-org/begonia/gateway" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -38,7 +38,7 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) return nil } endpoint := &api.Endpoints{} - routersList := routers.NewHttpURIRouteToSrvMethod() + routersList := gateway.GetRouter() err := json.Unmarshal([]byte(value), endpoint) if err != nil { return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") @@ -62,15 +62,24 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) } // register routers // log.Print("register router") + err=pd.SetHttpResponse(common.E_HttpResponse) + if err != nil { + return gosdk.NewError(fmt.Errorf("set http response error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "set_http_response") + + } routersList.LoadAllRouters(pd) + // register service to gateway gw := gateway.Get() err = gw.RegisterService(ctx, pd, lb) + if err != nil { return gosdk.NewError(fmt.Errorf("register service error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "register_service") } + gw.RegisterServiceWithProxy(pd) // err = g.repo.PutTags(ctx, endpoint.Key, endpoint.Tags) + log.Printf("register service success") return nil } func (g *EndpointWatcher) Del(ctx context.Context, key string, value string) error { diff --git a/internal/middleware/auth/ak_test.go b/internal/middleware/auth/ak_test.go index 9517232..a074b8e 100644 --- a/internal/middleware/auth/ak_test.go +++ b/internal/middleware/auth/ak_test.go @@ -15,7 +15,6 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/middleware/auth" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" c "github.com/smartystreets/goconvey/convey" @@ -38,7 +37,7 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { ak.SetPriority(1) c.So(ak.Name(), c.ShouldEqual, "ak_auth") c.So(ak.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -58,11 +57,18 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { return fmt.Errorf("metadata not exists in context") }) c.So(err.Error(), c.ShouldContainSubstring, fmt.Errorf("metadata not exists in context").Error()) - - patch := gomonkey.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamRequestBefore, nil, nil) + patch := gomonkey.ApplyMethodReturn(akBiz, "AppValidator", "test", nil) + patch = patch.ApplyMethodReturn(akBiz, "GetAppOwner", "test", nil) + // patch := gomonkey.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamRequestBefore, nil, nil) patch = patch.ApplyFuncReturn((*auth.AccessKeyAuthMiddleware).StreamResponseAfter, fmt.Errorf("StreamResponseAfter err")) defer patch.Reset() - err = ak.StreamInterceptor(context.Background(), &testStream{ctx: context.Background()}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, func(srv any, stream grpc.ServerStream) error { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + err = ak.StreamInterceptor(ctx, &testStream{ctx: ctx}, &grpc.StreamServerInfo{FullMethod: "/integration.TestService/Get"}, func(srv any, ss grpc.ServerStream) error { + md, _ := metadata.FromIncomingContext(ss.Context()) + if len(md.Get(gosdk.HeaderXIdentity)) == 0 || md.Get(gosdk.HeaderXIdentity)[0] == "" { + t.Error("identity not exists in context") + return fmt.Errorf("identity not exists in context") + } return nil }) @@ -76,6 +82,31 @@ func TestAccessKeyAuthMiddleware(t *testing.T) { }) patch2.Reset() c.So(err.Error(), c.ShouldContainSubstring, fmt.Errorf("StreamRequestBefore err").Error()) + // do not need validate + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAccessKey, "test")) + _, err = ak.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + + // NO CONTEXT + _, err = ak.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + // get owner error + patch3:=gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, "", fmt.Errorf("get owner error")) + defer patch3.Reset() + // get owner error + in:=metadata.NewIncomingContext(context.Background(),metadata.Pairs(gosdk.HeaderXAccessKey,"test")) + _,err=ak.StreamRequestBefore(in,&testStream{ctx:in},&grpc.StreamServerInfo{FullMethod:"/integration.TestService/Get"},nil) + patch3.Reset() + c.So(err,c.ShouldNotBeNil) + c.So(err.Error(),c.ShouldContainSubstring,"get app owner error") + + + }) } func TestRequestBeforeErr(t *testing.T) { @@ -133,7 +164,7 @@ func TestValidateStream(t *testing.T) { ak := auth.NewAccessKeyAuth(akBiz, cnf, gateway.Log) patch := gomonkey.ApplyFuncReturn(gosdk.NewGatewayRequestFromGrpc, nil, fmt.Errorf("NewGatewayRequestFromGrpc err")) defer patch.Reset() - _, err := ak.ValidateStream(context.TODO(), nil, "", nil) + _, err := ak.ValidateStream(context.TODO(), nil, "") patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "NewGatewayRequestFromGrpc err") diff --git a/internal/middleware/auth/aksk.go b/internal/middleware/auth/aksk.go index 01c572c..a120260 100644 --- a/internal/middleware/auth/aksk.go +++ b/internal/middleware/auth/aksk.go @@ -4,9 +4,9 @@ import ( "context" "strings" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" "github.com/begonia-org/go-sdk/logger" "google.golang.org/grpc" @@ -34,7 +34,7 @@ func NewAccessKeyAuth(app *biz.AccessKeyAuth, config *config.Config, log logger. } func IfNeedValidate(ctx context.Context, fullMethod string) bool { - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(strings.ToUpper(fullMethod)) if router == nil { return false @@ -63,8 +63,8 @@ func (a *AccessKeyAuthMiddleware) RequestBefore(ctx context.Context, info *grpc. if !ok { md = metadata.MD{} } - // md.Set(gosdk.HeaderXIdentity, owner) - md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, owner)) + md.Set(gosdk.HeaderXIdentity, owner) + // md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, owner)) ctx = metadata.NewIncomingContext(ctx, md) // md2, _ := metadata.FromIncomingContext(ctx) @@ -73,21 +73,29 @@ func (a *AccessKeyAuthMiddleware) RequestBefore(ctx context.Context, info *grpc. } -func (a *AccessKeyAuthMiddleware) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *AccessKeyAuthMiddleware) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { ctx, err := a.RequestBefore(ctx, &grpc.UnaryServerInfo{FullMethod: fullName}, req) if err != nil { return ctx, err } - md, _ := metadata.FromIncomingContext(ctx) - if identity := md.Get(gosdk.HeaderXIdentity); len(identity) > 0 { - headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity[0]) - } + return ctx, nil } func (a *AccessKeyAuthMiddleware) StreamRequestBefore(ctx context.Context, ss grpc.ServerStream, info *grpc.StreamServerInfo, req interface{}) (grpc.ServerStream, error) { - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) - // defer grpcStream.Release() + if in, ok := metadata.FromIncomingContext(ctx); ok { + if ak := in.Get(gosdk.HeaderXAccessKey); len(ak) > 0 { + identity, err := a.app.GetAppOwner(ctx, ak[0]) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "get app owner error,%v", err) + } + md, _ := metadata.FromIncomingContext(ctx) + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.NewOutgoingContext(ctx, md) + } + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) return grpcStream, nil } @@ -103,6 +111,7 @@ func (a *AccessKeyAuthMiddleware) UnaryInterceptor(ctx context.Context, req any, defer func() { _ = a.ResponseAfter(ctx, info, req, resp) }() + resp, err = handler(ctx, req) return resp, err @@ -111,10 +120,12 @@ func (a *AccessKeyAuthMiddleware) StreamInterceptor(srv interface{}, ss grpc.Ser if !IfNeedValidate(ss.Context(), info.FullMethod) { return handler(srv, ss) } + grpcStream, err := a.StreamRequestBefore(ss.Context(), ss, info, srv) if err != nil { return err } + // log.Printf("AccessKeyAuthMiddleware StreamInterceptor") defer func() { err := a.StreamResponseAfter(ss.Context(), ss, info) if err != nil { @@ -145,3 +156,34 @@ func (a *AccessKeyAuthMiddleware) Priority() int { func (a *AccessKeyAuthMiddleware) Name() string { return a.name } +func (a *AccessKeyAuthMiddleware) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") + } + if AK := md.Get(gosdk.HeaderXAccessKey); len(AK) > 0 { + accessKey := AK[0] + identity, err := a.app.GetAppOwner(ctx, accessKey) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "get app owner error,%v", err) + } + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewOutgoingContext(ctx, md) + in,ok:=metadata.FromIncomingContext(ctx) + if !ok { + in=metadata.MD{} + } + in.Set(gosdk.HeaderXIdentity,identity) + ctx=metadata.NewIncomingContext(ctx,in) + // ctx = metadata.NewIncomingContext() + } + st, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, status.Errorf(codes.Internal, "streamer error,%v", err) + } + return st, nil +} diff --git a/internal/middleware/auth/apikey.go b/internal/middleware/auth/apikey.go index 99c6db6..2bfde06 100644 --- a/internal/middleware/auth/apikey.go +++ b/internal/middleware/auth/apikey.go @@ -3,7 +3,6 @@ package auth import ( "context" "fmt" - "strings" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/pkg" @@ -66,22 +65,31 @@ func NewApiKeyAuth(config *config.Config, authz *biz.AuthzUsecase) ApiKeyAuth { } func (a *ApiKeyAuthImpl) check(ctx context.Context) (string, error) { md, ok := metadata.FromIncomingContext(ctx) - if !ok { + out, outOK := metadata.FromOutgoingContext(ctx) + if !ok && !outOK { return "", gosdk.NewError(status.Errorf(codes.Unauthenticated, "metadata not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } // authorization := a.GetAuthorizationFromMetadata(md) apikeys := md.Get(gosdk.HeaderXApiKey) - if len(apikeys) == 0 { + outAPIKeys := out.Get(gosdk.HeaderXApiKey) + if len(apikeys) == 0 && len(outAPIKeys) == 0 { return "", gosdk.NewError(status.Errorf(codes.Unauthenticated, "apikey not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } - apikey := apikeys[0] - if apikey != a.config.GetAdminAPIKey() { + + apikey:="" + if len(apikeys) != 0 { + apikey = apikeys[0] + }else if len(outAPIKeys) != 0 { + apikey = outAPIKeys[0] + } + + if apikey != a.config.GetAdminAPIKey(){ return "", gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } return apikey, nil } -func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { apikey := "" var err error if apikey, err = a.check(ctx); err == nil && apikey != "" { @@ -92,7 +100,7 @@ func (a *ApiKeyAuthImpl) ValidateStream(ctx context.Context, req interface{}, fu md, _ := metadata.FromIncomingContext(ctx) md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) - headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity) + // headers.Set(strings.ToLower(gosdk.HeaderXIdentity), identity) return metadata.NewIncomingContext(ctx, md), err } return ctx, err @@ -102,7 +110,46 @@ func (a *ApiKeyAuthImpl) StreamInterceptor(srv interface{}, ss grpc.ServerStream if !IfNeedValidate(ss.Context(), info.FullMethod) { return handler(srv, ss) } - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) + ctx := ss.Context() + if apikey, err := a.check(ctx); err == nil && apikey != "" { + identity, err := a.authz.GetIdentity(ctx, gosdk.ApiKeyType, apikey) + if err != nil { + return gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + } + md, _ := metadata.FromIncomingContext(ctx) + md.Set(gosdk.HeaderXIdentity, identity) + ctx = metadata.NewIncomingContext(ctx, md) + ctx = metadata.AppendToOutgoingContext(ctx, gosdk.HeaderXIdentity, identity) + } else { + return gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) defer grpcStream.Release() - return handler(srv, grpcStream) + err := handler(srv, grpcStream) + + return err +} +func (a *ApiKeyAuthImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + if apikey, err := a.check(ctx); err == nil && apikey != "" { + identity, err := a.authz.GetIdentity(ctx, gosdk.ApiKeyType, apikey) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("query user id base on apikey err:%w", err), int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + } + md, _ := metadata.FromOutgoingContext(ctx) + md = metadata.Join(md, metadata.Pairs(gosdk.HeaderXIdentity, identity)) + ctx = metadata.NewOutgoingContext(ctx, md) + in, ok := metadata.FromIncomingContext(ctx) + if !ok { + in = metadata.New(make(map[string]string)) + } + in.Set(gosdk.HeaderXIdentity, identity) + // log.Printf("incoming identity:%s", identity) + ctx = metadata.NewIncomingContext(ctx, in) + return streamer(ctx, desc, cc, method, opts...) + } + return nil, gosdk.NewError(pkg.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } diff --git a/internal/middleware/auth/apikey_test.go b/internal/middleware/auth/apikey_test.go index 489d397..8c23a44 100644 --- a/internal/middleware/auth/apikey_test.go +++ b/internal/middleware/auth/apikey_test.go @@ -17,7 +17,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" c "github.com/smartystreets/goconvey/convey" @@ -39,7 +38,7 @@ func TestAPIKeyUnaryInterceptor(t *testing.T) { c.So(apikey.Name(), c.ShouldEqual, "api_key_auth") c.So(apikey.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -101,6 +100,18 @@ func TestAPIKeyUnaryInterceptor(t *testing.T) { patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "get user id base on apikey error") + // ValidateStream error + patch2 := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get user id base on apikey error")) + defer patch2.Reset() + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = apikey.(*auth.ApiKeyAuthImpl).ValidateStream(ctx, nil, "/integration.TestService/Get") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get user id base on apikey error") + patch2.Reset() + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) + _, err = apikey.(*auth.ApiKeyAuthImpl).ValidateStream(ctx, nil, "/integration.TestService/Get") + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) }) } func newAuthzBiz() *biz.AuthzUsecase { @@ -131,7 +142,7 @@ func TestApiKeyStreamInterceptor(t *testing.T) { c.So(apikey.Name(), c.ShouldEqual, "api_key_auth") c.So(apikey.Priority(), c.ShouldEqual, 1) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -155,12 +166,13 @@ func TestApiKeyStreamInterceptor(t *testing.T) { ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())), }}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { - err := ss.RecvMsg(srv) md, _ := metadata.FromIncomingContext(ss.Context()) if len(md.Get(gosdk.HeaderXIdentity)) == 0 || md.Get(gosdk.HeaderXIdentity)[0] == "" { t.Error("identity not exists in context") return fmt.Errorf("identity not exists in context") } + err := ss.RecvMsg(srv) + return err }) diff --git a/internal/middleware/auth/auth.go b/internal/middleware/auth/auth.go index acc21cd..9a0f8e3 100644 --- a/internal/middleware/auth/auth.go +++ b/internal/middleware/auth/auth.go @@ -31,6 +31,7 @@ func NewAuth(ak *AccessKeyAuthMiddleware, jwt *JWTAuth, apikey ApiKeyAuth) gosdk } func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + // fmt.Print("auth unary interceptor \n") if !IfNeedValidate(ctx, info.FullMethod) { return handler(ctx, req) } @@ -57,13 +58,14 @@ func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnarySe func (a *Auth) StreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if !IfNeedValidate(ss.Context(), info.FullMethod) { + // log.Printf("no need stream validate %s", info.FullMethod) return handler(srv, ss) } md, ok := metadata.FromIncomingContext(ss.Context()) if !ok { return status.Errorf(codes.Unauthenticated, "metadata not exists in context") } - xApiKey := md.Get("x-api-key") + xApiKey := md.Get(gosdk.HeaderXApiKey) if len(xApiKey) != 0 { return a.apikey.StreamInterceptor(srv, ss, info, handler) } @@ -89,3 +91,28 @@ func (a *Auth) Priority() int { func (a *Auth) Name() string { return a.name } + +func (a *Auth) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + // log.Printf("no need stream validate %s", info.FullMethod) + return streamer(ctx, desc, cc, method, opts...) + } + md, ok := metadata.FromOutgoingContext(ctx) + if !ok { + return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") + } + xApiKey := md.Get("x-api-key") + if len(xApiKey) != 0 { + return a.apikey.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) + } + authorization := a.jwt.GetAuthorizationFromMetadata(md) + + if authorization == "" { + return nil, gosdk.NewError(pkg.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + } + if strings.Contains(authorization, "Bearer") { + return a.jwt.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) + + } + return a.ak.StreamClientInterceptor(ctx, desc, cc, method, streamer, opts...) +} diff --git a/internal/middleware/auth/auth_test.go b/internal/middleware/auth/auth_test.go index 96ed475..4ae05d9 100644 --- a/internal/middleware/auth/auth_test.go +++ b/internal/middleware/auth/auth_test.go @@ -7,10 +7,12 @@ import ( "crypto/cipher" "crypto/rand" "crypto/rsa" + "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "fmt" + "io" "log" "net/http" "os" @@ -29,7 +31,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/pkg/utils" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" @@ -44,6 +45,21 @@ import ( type testStream struct { ctx context.Context } +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func (t *testStream) SetHeader(metadata.MD) error { return nil @@ -253,7 +269,7 @@ func TestUnaryInterceptor(t *testing.T) { } return nil, nil } - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -288,8 +304,11 @@ func TestUnaryInterceptor(t *testing.T) { c.Convey("TestUnaryInterceptor aksk", t, func() { u := &v1.Users{} bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) access, secret, appid := readInitAPP() sgin := gosdk.NewAppAuthSigner(access, secret) gwReq, err := gosdk.NewGatewayRequestFromHttp(req) @@ -342,7 +361,7 @@ func TestStreamInterceptor(t *testing.T) { mid := getMid() ctx := context.Background() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -353,6 +372,9 @@ func TestStreamInterceptor(t *testing.T) { ctx = metadata.NewIncomingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { err := ss.RecvMsg(srv) + if err != nil { + return err + } md, _ := metadata.FromIncomingContext(ss.Context()) if identify := md.Get(gosdk.HeaderXIdentity); len(identify) == 0 || identify[0] == "" { return fmt.Errorf("no app identity") @@ -360,7 +382,7 @@ func TestStreamInterceptor(t *testing.T) { if xAppKey := md.Get(gosdk.HeaderXApiKey); len(xAppKey) == 0 || xAppKey[0] == "" { return fmt.Errorf("no app key") } - return err + return ss.RecvMsg(srv) }) c.So(err, c.ShouldBeNil) ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) @@ -368,6 +390,27 @@ func TestStreamInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldNotBeNil) + // get identity err + patch := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get identity err")) + defer patch.Reset() + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "query user id base on apikey") + patch.Reset() + + // check apikey not match + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs(gosdk.HeaderXApiKey, "cnf.GetAdminAPIKey()")) + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) }) c.Convey("TestStreamInterceptor jwt", t, func() { @@ -390,13 +433,26 @@ func TestStreamInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldNotBeNil) + + ctx = metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "")) + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(),c.ShouldContainSubstring,pkg.ErrTokenMissing.Error()) }) - c.Convey("TestUnaryInterceptor aksk", t, func() { + c.Convey("TestStreamInterceptor aksk", t, func() { u := &v1.Users{} + bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) + + t.Logf("data:%s", string(bData)) req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) access, secret, appid := readInitAPP() sgin := gosdk.NewAppAuthSigner(access, secret) gwReq, err := gosdk.NewGatewayRequestFromHttp(req) @@ -426,13 +482,11 @@ func TestStreamInterceptor(t *testing.T) { return fmt.Errorf("no app identity") } if xAccessKey := md.Get(gosdk.HeaderXAccessKey); len(xAccessKey) == 0 || xAccessKey[0] == "" { - t.Logf("error metadata:%v", md) return fmt.Errorf("no app access key") } return err }) c.So(err, c.ShouldBeNil) - sign1 := gosdk.NewAppAuthSigner("ASDASDCASDFQ", "ASDASDCASDFQ") gwReq1, err := gosdk.NewGatewayRequestFromHttp(req) c.So(err, c.ShouldBeNil) @@ -452,13 +506,187 @@ func TestStreamInterceptor(t *testing.T) { }) c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAppSignatureInvalid.Error()) + + // recv msg err + ctx = metadata.NewIncomingContext(context.Background(), md) + patch4 := gomonkey.ApplyFuncReturn((*testStream).RecvMsg, fmt.Errorf("recv msg err")) + + defer patch4.Reset() + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + patch4.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "recv msg err") + // recv io.EOF + patch5 := gomonkey.ApplyFuncReturn((*testStream).RecvMsg, io.EOF) + defer patch5.Reset() + err = mid.StreamInterceptor(&v1.Users{}, &greeterSayHelloWebsocketServer{ServerStream: &testStream{ctx: ctx}}, &grpc.StreamServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/POST"}, func(srv interface{}, ss grpc.ServerStream) error { + err := ss.RecvMsg(srv) + return err + }) + patch5.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, io.EOF.Error()) + + }) +} +func TestStreamClientInterceptor(t *testing.T) { + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + config := config.ReadConfig(env) + cnf := cfg.NewConfig(config) + mid := getMid() + ctx := context.Background() + + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + + pd, _ := gateway.NewDescription(pbFile) + R.LoadAllRouters(pd) + c.Convey("TestStreamClientInterceptor apikey", t, func() { + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + st, err := mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + err = st.SendMsg(&v1.Users{}) + c.So(err, c.ShouldBeNil) + + // no need validate + st, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldBeEmpty) + // no outgoing context + _, err = mid.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + // do not need validate + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + // not match apikey + patch := gomonkey.ApplyFuncReturn((*cfg.Config).GetAdminAPIKey, "") + defer patch.Reset() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, pkg.ErrAPIKeyNotMatch.Error()) + patch.Reset() + // get owner error + patch2 := gomonkey.ApplyFuncReturn((*biz.AuthzUsecase).GetIdentity, "", fmt.Errorf("get owner err")) + defer patch2.Reset() + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs(gosdk.HeaderXApiKey, cnf.GetAdminAPIKey())) + _, err = mid.StreamClientInterceptor(ctx, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + patch2.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "query user id base on apikey") + + }) + c.Convey("TestStreamClientInterceptor jwt", t, func() { + jwt := getJWT() + outCTX := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorization", "Bearer "+jwt)) + st, err := mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldNotBeEmpty) + }) + c.Convey("TestStreamClientInterceptor aksk", t, func() { + u := &v1.Users{} + + bData, _ := json.Marshal(u) + hash := sha256.Sum256(bData) + hexStr := fmt.Sprintf("%x", hash) + + t.Logf("data:%s", string(bData)) + req, _ := http.NewRequest(http.MethodPost, "/test/post", bytes.NewReader(bData)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(gosdk.HeaderXContentSha256, hexStr) + access, secret, appid := readInitAPP() + sgin := gosdk.NewAppAuthSigner(access, secret) + gwReq, err := gosdk.NewGatewayRequestFromHttp(req) + c.So(err, c.ShouldBeNil) + err = sgin.SignRequest(gwReq) + c.So(err, c.ShouldBeNil) + md := metadata.New(make(map[string]string)) + + headers := gwReq.Headers + for _, k := range headers.Keys() { + // t.Logf("header:%s,value:%s", k, headers.Get(k)) + md.Append(k, headers.Get(k)) + } + md.Append("uri", "/test/post") + md.Append("x-http-method", http.MethodPost) + patch := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetSecret, secret, nil) + patch = patch.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, appid, nil) + defer patch.Reset() + outCTX := metadata.NewOutgoingContext(context.Background(), md) + st, err := mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok := metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldNotBeEmpty) + // no need validate + st, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + out, ok = metadata.FromOutgoingContext(st.Context()) + c.So(ok, c.ShouldBeTrue) + c.So(out.Get(gosdk.HeaderXIdentity), c.ShouldBeEmpty) + // no context + _, err = mid.StreamClientInterceptor(req.Context(), nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + + // streamer err + _, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, fmt.Errorf("streamer err") + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "streamer err") + + // get owner err + + patch2 := gomonkey.ApplyFuncReturn((*biz.AccessKeyAuth).GetAppOwner, "", fmt.Errorf("get owner err")) + defer patch2.Reset() + _, err = mid.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/POST", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get owner err") + }) } func TestTestUnaryInterceptorErr(t *testing.T) { mid := getMid() ctx := context.Background() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -485,7 +713,7 @@ func TestTestUnaryInterceptorErr(t *testing.T) { func TestStreamInterceptorErr(t *testing.T) { mid := getMid() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") diff --git a/internal/middleware/auth/headers.go b/internal/middleware/auth/headers.go index b47c0f1..7a44932 100644 --- a/internal/middleware/auth/headers.go +++ b/internal/middleware/auth/headers.go @@ -1,119 +1,134 @@ package auth -import ( - "context" - "net/http" - "sync" - - "google.golang.org/grpc" - "google.golang.org/grpc/metadata" -) - type Header interface { Set(key, value string) SendHeader(key, value string) } -type GrpcHeader struct { - in metadata.MD - ctx context.Context - out metadata.MD -} -type httpHeader struct { - w http.ResponseWriter - r *http.Request -} -type GrpcStreamHeader struct { - *GrpcHeader - ss grpc.ServerStream -} -var headerPool = &sync.Pool{ - New: func() interface{} { - return &GrpcHeader{} - }, -} -var httpHeaderPool = &sync.Pool{ - New: func() interface{} { - return &httpHeader{} - }, -} +// type GrpcHeader struct { +// in metadata.MD +// ctx context.Context +// out metadata.MD +// } +// type httpHeader struct { +// w http.ResponseWriter +// r *http.Request +// } +// type GrpcStreamHeader struct { +// *GrpcHeader +// ss grpc.ServerStream +// } -var grpcStreamHeaderPool = &sync.Pool{ - New: func() interface{} { - return &GrpcStreamHeader{} - }, -} +// var headerPool = &sync.Pool{ +// New: func() interface{} { +// return &GrpcHeader{} +// }, +// } +// var httpHeaderPool = &sync.Pool{ +// New: func() interface{} { +// return &httpHeader{} +// }, +// } -func (g *GrpcHeader) Release() { - g.ctx = nil - g.in = nil - g.out = nil - headerPool.Put(g) -} -func (g *GrpcHeader) Set(key, value string) { - g.in.Set(key, value) - md, _ := metadata.FromIncomingContext(g.ctx) - newMd := metadata.Join(md, g.in) - g.ctx = metadata.NewIncomingContext(g.ctx, newMd) +// var grpcStreamHeaderPool = &sync.Pool{ +// New: func() interface{} { +// return &GrpcStreamHeader{} +// }, +// } -} -func (g *GrpcHeader) SendHeader(key, value string) { - g.out.Append(key, value) - _ = grpc.SendHeader(g.ctx, g.out) - g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) -} -func (g *httpHeader) Release() { - g.w = nil - g.r = nil - httpHeaderPool.Put(g) -} -func (g *httpHeader) Set(key, value string) { - g.r.Header.Add(key, value) +// func (g *GrpcHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// headerPool.Put(g) +// } +// func (g *GrpcHeader) Set(key, value string) { +// g.in.Set(key, value) +// md, _ := metadata.FromIncomingContext(g.ctx) +// newMd := metadata.Join(md, g.in) +// g.ctx = metadata.NewIncomingContext(g.ctx, newMd) +// log.Printf("grpc header set key:%v value:%v", key, value) -} -func (g *httpHeader) SendHeader(key, value string) { - g.w.Header().Add(key, value) -} -func (g *GrpcStreamHeader) Release() { - g.ctx = nil - g.in = nil - g.out = nil - g.ss = nil - grpcStreamHeaderPool.Put(g) +// } +// func (g *GrpcHeader) SendHeader(key, value string) { +// g.out.Append(key, value) +// _ = grpc.SendHeader(g.ctx, g.out) +// g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) +// } +// func (g *httpHeader) Release() { +// g.w = nil +// g.r = nil +// httpHeaderPool.Put(g) +// } +// func (g *httpHeader) Set(key, value string) { +// g.r.Header.Add(key, value) -} -func (g *GrpcStreamHeader) Set(key, value string) { - g.in.Append(key, value) - newCtx := metadata.NewIncomingContext(g.ctx, g.in) - g.ctx = newCtx - _ = g.ss.SetHeader(g.in) -} -func (g *GrpcStreamHeader) SendHeader(key, value string) { - g.out.Append(key, value) - _ = g.ss.SendHeader(g.out) - g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) -} +// } +// func (g *httpHeader) SendHeader(key, value string) { +// g.w.Header().Add(key, value) +// } -func NewGrpcHeader(in metadata.MD, ctx context.Context, out metadata.MD) *GrpcHeader { - // return &GrpcHeader{in: in, ctx: ctx, out: out} - header := headerPool.Get().(*GrpcHeader) - header.in = in - header.ctx = ctx - header.out = out - return header -} -func NewHttpHeader(w http.ResponseWriter, r *http.Request) *httpHeader { - // return &httpHeader{w: w, r: r} - header := httpHeaderPool.Get().(*httpHeader) - header.w = w - header.r = r - return header -} +// func (g *GrpcStreamHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// g.ss = nil +// grpcStreamHeaderPool.Put(g) -func NewGrpcStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, ss grpc.ServerStream) *GrpcStreamHeader { - // return &GrpcStreamHeader{&GrpcHeader{in: in, ctx: ctx, out: out}, ss} - header := grpcStreamHeaderPool.Get().(*GrpcStreamHeader) - header.GrpcHeader = NewGrpcHeader(in, ctx, out) - header.ss = ss - return header -} +// } +// func (g *GrpcStreamHeader) Set(key, value string) { +// g.in.Set(key, value) +// newCtx := metadata.NewIncomingContext(g.ctx, g.in) +// g.ctx = newCtx +// _ = g.ss.SetHeader(g.in) +// } +// func (g *GrpcStreamHeader) SendHeader(key, value string) { +// g.out.Append(key, value) +// _ = g.ss.SendHeader(g.out) +// g.ctx = metadata.NewOutgoingContext(g.ctx, g.out) +// } + +// func (g *GrpcClientStreamHeader) Release() { +// g.ctx = nil +// g.in = nil +// g.out = nil +// g.cs = nil +// grpcClientHeaderPool.Put(g) + +// } +// func (g *GrpcClientStreamHeader) Set(key, value string) { +// g.out.Set(key, value) +// newCtx := metadata.NewOutgoingContext(g.ctx, g.in) +// g.ctx = newCtx +// // _ = g.cs.(g.in) + +// } +// func NewGrpcHeader(in metadata.MD, ctx context.Context, out metadata.MD) *GrpcHeader { +// // return &GrpcHeader{in: in, ctx: ctx, out: out} +// header := headerPool.Get().(*GrpcHeader) +// header.in = in +// header.ctx = ctx +// header.out = out +// return header +// } +// func NewHttpHeader(w http.ResponseWriter, r *http.Request) *httpHeader { +// // return &httpHeader{w: w, r: r} +// header := httpHeaderPool.Get().(*httpHeader) +// header.w = w +// header.r = r +// return header +// } + +// func NewGrpcStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, ss grpc.ServerStream) *GrpcStreamHeader { +// // return &GrpcStreamHeader{&GrpcHeader{in: in, ctx: ctx, out: out}, ss} +// header := grpcStreamHeaderPool.Get().(*GrpcStreamHeader) +// header.GrpcHeader = NewGrpcHeader(in, ctx, out) +// header.ss = ss +// return header +// } +// func NewGrpcClientStreamHeader(in metadata.MD, ctx context.Context, out metadata.MD, cs grpc.ClientStream) *GrpcClientStreamHeader { +// header := grpcClientHeaderPool.Get().(*GrpcClientStreamHeader) +// header.GrpcHeader = NewGrpcHeader(in, ctx, out) +// header.cs = cs +// return header +// } diff --git a/internal/middleware/auth/headers_test.go b/internal/middleware/auth/headers_test.go index 8344ba2..27cdf96 100644 --- a/internal/middleware/auth/headers_test.go +++ b/internal/middleware/auth/headers_test.go @@ -1,32 +1,32 @@ package auth_test -import ( - "net/http" - "testing" +// import ( +// "net/http" +// "testing" - "github.com/begonia-org/begonia/internal/middleware/auth" - c "github.com/smartystreets/goconvey/convey" -) +// "github.com/begonia-org/begonia/internal/middleware/auth" +// c "github.com/smartystreets/goconvey/convey" +// ) -type responseWriter struct { -} +// type responseWriter struct { +// } -func (r *responseWriter) Header() http.Header { - return make(http.Header) -} -func (r *responseWriter) Write([]byte) (int, error) { - return 0, nil -} -func (r *responseWriter) WriteHeader(int) { +// func (r *responseWriter) Header() http.Header { +// return make(http.Header) +// } +// func (r *responseWriter) Write([]byte) (int, error) { +// return 0, nil +// } +// func (r *responseWriter) WriteHeader(int) { -} -func TestHeaders(t *testing.T) { - c.Convey("TestHeaders", t, func() { - req, _ := http.NewRequest("GET", "http://localhost", nil) - h := auth.NewHttpHeader(&responseWriter{}, req) - c.So(h, c.ShouldNotBeNil) - h.Set("key", "value") - h.SendHeader("key", "value") - h.Release() - }) -} +// } +// func TestHeaders(t *testing.T) { +// c.Convey("TestHeaders", t, func() { +// req, _ := http.NewRequest("GET", "http://localhost", nil) +// h := auth.NewHttpHeader(&responseWriter{}, req) +// c.So(h, c.ShouldNotBeNil) +// h.Set("key", "value") +// h.SendHeader("key", "value") +// h.Release() +// }) +// } diff --git a/internal/middleware/auth/jwt.go b/internal/middleware/auth/jwt.go index 381a978..c5c35b2 100644 --- a/internal/middleware/auth/jwt.go +++ b/internal/middleware/auth/jwt.go @@ -104,18 +104,21 @@ func (a *JWTAuth) checkJWTItem(ctx context.Context, payload *api.BasicAuth, toke } return true, nil } -func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, rspHeader Header, reqHeader Header) (ok bool, err error) { +func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, io *metadata.MD) (ok bool, err error) { payload, errAuth := a.jwt2BasicAuth(authorization) err = errAuth if err != nil { return false, err } + io.Set("x-uid", payload.Uid) + strArr := strings.Split(authorization, " ") token := strArr[1] ok, err = a.checkJWTItem(ctx, payload, token) if err != nil || !ok { return false, err } + io.Set("x-token", token) left := payload.Expiration - time.Now().Unix() // expiration := a.config.GetJWTExpiration() @@ -156,63 +159,73 @@ func (a *JWTAuth) checkJWT(ctx context.Context, authorization string, rspHeader } // 旧token加入黑名单 go a.biz.PutBlackList(ctx, a.config.GetUserBlackListKey(tiga.GetMd5(token))) - rspHeader.SendHeader("Authorization", fmt.Sprintf("Bearer %s", newToken)) + // rspHeader.Set("Authorization", fmt.Sprintf("Bearer %s", newToken)) + _ = grpc.SetHeader(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", newToken))) token = newToken } - // 设置uid - reqHeader.Set("x-token", token) - reqHeader.Set("x-uid", payload.Uid) - reqHeader.Set(gosdk.HeaderXIdentity, payload.Uid) + io.Set("x-token", token) return true, nil } -func (a *JWTAuth) jwtValidator(ctx context.Context, headers Header) (context.Context, error) { - - md, _ := metadata.FromIncomingContext(ctx) +func (a *JWTAuth) jwtValidator(ctx context.Context) (context.Context, error) { + in, inOK := metadata.FromIncomingContext(ctx) + if !inOK { + in = metadata.MD{} + } + out, ok := metadata.FromOutgoingContext(ctx) + if !ok { + out = metadata.MD{} + } + md := metadata.Join(in, out) token := a.GetAuthorizationFromMetadata(md) if token == "" { - return nil, status.Errorf(codes.Unauthenticated, "token not exists in context") - } + if token = a.GetAuthorizationFromMetadata(out); token == "" { + return nil, status.Errorf(codes.Unauthenticated, "token not exists in context") + } - ok, err := a.checkJWT(ctx, token, headers, headers) + } + ioMD := metadata.New(make(map[string]string)) + ok, err := a.checkJWT(ctx, token, &ioMD) if err != nil || !ok { return nil, status.Errorf(codes.Unauthenticated, "check token error,%v", err) } + in.Set("x-token", ioMD.Get("x-token")...) + in.Set("x-uid", ioMD.Get("x-uid")...) + in.Set(gosdk.HeaderXIdentity, ioMD.Get("x-uid")...) + newCtx := metadata.NewIncomingContext(ctx, in) + + out.Set("x-token", ioMD.Get("x-token")...) + out.Set("x-uid", ioMD.Get("x-uid")...) + out.Set(gosdk.HeaderXIdentity, ioMD.Get("x-uid")...) - newCtx := metadata.NewIncomingContext(ctx, md) + newCtx = metadata.NewOutgoingContext(newCtx, out) return newCtx, nil // return handler(newCtx, req) } func (a *JWTAuth) RequestBefore(ctx context.Context, info *grpc.UnaryServerInfo, req interface{}) (context.Context, error) { - in, ok := metadata.FromIncomingContext(ctx) - if !ok { - return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") - } - out, ok := metadata.FromOutgoingContext(ctx) - if !ok { - out = metadata.MD{} - } - headers := NewGrpcHeader(in, ctx, out) - defer headers.Release() - _, err := a.jwtValidator(ctx, headers) + ctx, err := a.jwtValidator(ctx) if err != nil { return nil, err } - return headers.ctx, nil + return ctx, nil } -func (a *JWTAuth) ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) { +func (a *JWTAuth) ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) { // headers := NewGrpcStreamHeader(in, ctx, out,ss) - ctx, err := a.jwtValidator(ctx, headers) + ctx, err := a.jwtValidator(ctx) return ctx, err } func (a *JWTAuth) StreamRequestBefore(ctx context.Context, ss grpc.ServerStream, info *grpc.StreamServerInfo, req interface{}) (grpc.ServerStream, error) { - grpcStream := NewGrpcStream(ss, info.FullMethod, ss.Context(), a) + ctx, err := a.jwtValidator(ctx) + if err != nil { + return nil, err + } + grpcStream := NewGrpcStream(ss, info.FullMethod, ctx, a) return grpcStream, nil } @@ -272,3 +285,15 @@ func (jwt *JWTAuth) Priority() int { func (jwt *JWTAuth) Name() string { return jwt.name } + +func (jwt *JWTAuth) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + if !IfNeedValidate(ctx, method) { + return streamer(ctx, desc, cc, method, opts...) + } + + ctx, err := jwt.jwtValidator(ctx) + if err != nil { + return nil, err + } + return streamer(ctx, desc, cc, method, opts...) +} diff --git a/internal/middleware/auth/jwt_test.go b/internal/middleware/auth/jwt_test.go index d1e388e..af22742 100644 --- a/internal/middleware/auth/jwt_test.go +++ b/internal/middleware/auth/jwt_test.go @@ -19,7 +19,7 @@ import ( "github.com/begonia-org/begonia/internal/pkg" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" - "github.com/begonia-org/begonia/internal/pkg/routers" + gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" api "github.com/begonia-org/go-sdk/api/user/v1" "github.com/bsm/redislock" @@ -62,7 +62,7 @@ func TestJWTUnaryInterceptor(t *testing.T) { config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -85,14 +85,7 @@ func TestJWTUnaryInterceptor(t *testing.T) { }) c.So(err, c.ShouldBeNil) - _, err = jwt.UnaryInterceptor(context.Background(), &hello.HelloRequest{}, &grpc.UnaryServerInfo{ - FullMethod: "/integration.TestService/Get", - }, func(ctx context.Context, req interface{}) (interface{}, error) { - return nil, nil - - }) - c.So(err, c.ShouldNotBeNil) - c.So(err.Error(), c.ShouldContainSubstring, "metadata not exists in context") + _, err = jwt.UnaryInterceptor(metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test")), &hello.HelloRequest{}, &grpc.UnaryServerInfo{ FullMethod: "/integration.TestService/Get", @@ -289,7 +282,7 @@ func TestJWTStreamInterceptor(t *testing.T) { config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") @@ -320,3 +313,40 @@ func TestJWTStreamInterceptor(t *testing.T) { c.So(err, c.ShouldBeNil) }) } + +func TestJWTClientStream(t *testing.T){ + c.Convey("TestJWTClientStream", t, func(){ + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + config := config.ReadConfig(env) + cnf := cfg.NewConfig(config) + + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata") + + pd, _ := gateway.NewDescription(pbFile) + R.LoadAllRouters(pd) + user := data.NewUserRepo(config, gateway.Log) + userAuth := crypto.NewUsersAuth(cnf) + authzRepo := data.NewAuthzRepo(config, gateway.Log) + appRepo := data.NewAppRepo(config, gateway.Log) + authz := biz.NewAuthzUsecase(authzRepo, user, appRepo, gateway.Log, userAuth, cnf) + jwt := auth.NewJWTAuth(cnf, tiga.NewRedisDao(config), authz, gateway.Log) + jwt.SetPriority(1) + outCTX:=metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAuthorization, cnf.GetAdminAPIKey())) + _, err := jwt.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldBeNil) + + outCTX=metadata.NewOutgoingContext(context.Background(), metadata.Pairs(gosdk.HeaderXAuthorization, cnf.GetAdminAPIKey())) + _, err = jwt.StreamClientInterceptor(outCTX, nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }) + c.So(err, c.ShouldNotBeNil) + + }) +} \ No newline at end of file diff --git a/internal/middleware/auth/stream.go b/internal/middleware/auth/stream.go index 2fde49c..e4a89d8 100644 --- a/internal/middleware/auth/stream.go +++ b/internal/middleware/auth/stream.go @@ -2,24 +2,34 @@ package auth import ( "context" + "errors" + "io" "sync" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) type StreamValidator interface { - ValidateStream(ctx context.Context, req interface{}, fullName string, headers Header) (context.Context, error) + ValidateStream(ctx context.Context, req interface{}, fullName string) (context.Context, error) } type grpcServerStream struct { grpc.ServerStream - fullName string - validate StreamValidator - ctx context.Context + fullName string + validate StreamValidator + ctx context.Context + firstFrame bool } +// type grpcClientStream struct { +// grpc.ClientStream +// fullName string +// ctx context.Context +// validate StreamValidator +// firstFrame bool +// } + var streamPool = &sync.Pool{ New: func() interface{} { return &grpcServerStream{ @@ -28,11 +38,17 @@ var streamPool = &sync.Pool{ }, } +// var clientStreamPool = &sync.Pool{ +// New: func() interface{} { +// return &grpcClientStream{} +// }, +// } + func NewGrpcStream(s grpc.ServerStream, fullName string, ctx context.Context, validator StreamValidator) *grpcServerStream { stream := streamPool.Get().(*grpcServerStream) stream.ServerStream = s stream.fullName = fullName - stream.ctx = s.Context() + stream.ctx = ctx stream.validate = validator return stream } @@ -41,29 +57,28 @@ func (g *grpcServerStream) Release() { g.fullName = "" g.ServerStream = nil g.validate = nil + g.firstFrame = false streamPool.Put(g) } func (g *grpcServerStream) Context() context.Context { return g.ctx } func (s *grpcServerStream) RecvMsg(m interface{}) error { - if err := s.ServerStream.RecvMsg(m); err != nil { - return err + var err error + if err = s.ServerStream.RecvMsg(m); err != nil && !errors.Is(err, io.EOF) { + return status.Errorf(codes.Internal, "recv msg err:%s", err.Error()) } - in, ok := metadata.FromIncomingContext(s.Context()) - if !ok { - return status.Errorf(codes.Unauthenticated, "metadata not exists in context") + if err != nil { + return err } - out, ok := metadata.FromOutgoingContext(s.Context()) - if !ok { - out = metadata.MD{} + if !s.firstFrame { + ctx, err := s.validate.ValidateStream(s.Context(), m, s.fullName) + s.ctx = ctx + s.firstFrame = true + return err } - header := NewGrpcStreamHeader(in, s.Context(), out, s.ServerStream) - _, err := s.validate.ValidateStream(s.Context(), m, s.fullName, header) - s.ctx = header.ctx - header.Release() - return err + return nil } diff --git a/internal/middleware/http.go b/internal/middleware/http.go index 8bf0804..9e6d1ce 100644 --- a/internal/middleware/http.go +++ b/internal/middleware/http.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - "github.com/begonia-org/begonia/internal/pkg/routers" + "github.com/begonia-org/begonia/gateway" gosdk "github.com/begonia-org/go-sdk" _ "github.com/begonia-org/go-sdk/api/app/v1" _ "github.com/begonia-org/go-sdk/api/endpoint/v1" @@ -38,6 +38,7 @@ type HttpStream struct { grpc.ServerStream FullMethod string } + type Http struct { priority int name string @@ -53,12 +54,14 @@ func (s *HttpStream) SendMsg(m interface{}) error { if !strings.EqualFold(protocol[0], "application/json") { return s.ServerStream.SendMsg(m) } - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(s.FullMethod) // 对内置服务的http响应进行格式化 - if routersList.IsLocalSrv(s.FullMethod) || router.UseJsonResponse { + if routersList.IsLocalSrv(s.FullMethod) || (router != nil && router.UseJsonResponse) { rsp, _ := grpcToHttpResponse(m, nil) - return s.ServerStream.SendMsg(rsp) + // _=grpc.SetHeader(s.Context(),metadata.Pairs("content-type", "application/json")) + err := s.ServerStream.SendMsg(rsp) + return err } } return s.ServerStream.SendMsg(m) @@ -125,6 +128,9 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error if st.Code() == codes.Unimplemented { code = int32(common.Code_NOT_FOUND) } + if st.Code() == codes.InvalidArgument { + code = int32(common.Code_PARAMS_ERROR) + } return &common.HttpResponse{ Code: code, @@ -163,10 +169,11 @@ func (h *Http) UnaryInterceptor(ctx context.Context, req interface{}, info *grpc if !strings.EqualFold(protocol[0], "application/json") { return handler(ctx, req) } - routersList := routers.Get() + routersList := gateway.GetRouter() router := routersList.GetRouteByGrpcMethod(info.FullMethod) // 对内置服务的http响应进行格式化 if routersList.IsLocalSrv(info.FullMethod) || router.UseJsonResponse { + ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("content-type", "application/json")) rsp, err := handler(ctx, req) if _, ok := rsp.(*httpbody.HttpBody); ok { return rsp, err @@ -183,7 +190,13 @@ func (h *Http) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *gr return handler(srv, stream) } - +func (h *Http) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ss, err := streamer(ctx, desc, cc, method, opts...) + if err != nil { + return nil, status.Errorf(codes.Internal, "create client stream error:%s", err.Error()) + } + return ss, nil +} func NewHttp() *Http { return &Http{name: "http"} } diff --git a/internal/middleware/http_test.go b/internal/middleware/http_test.go index 805d4c0..fd17090 100644 --- a/internal/middleware/http_test.go +++ b/internal/middleware/http_test.go @@ -11,7 +11,6 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/middleware" "github.com/begonia-org/begonia/internal/pkg" - "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" hello "github.com/begonia-org/go-sdk/api/example/v1" user "github.com/begonia-org/go-sdk/api/user/v1" @@ -70,7 +69,7 @@ func (x *greeterSayHelloWebsocketServer) Context() context.Context { func TestStreamInterceptor(t *testing.T) { c.Convey("test stream interceptor", t, func() { mid := middleware.NewHttp() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") @@ -84,10 +83,37 @@ func TestStreamInterceptor(t *testing.T) { c.So(err, c.ShouldBeNil) }) } +func TestHttpStreamClientInterceptor(t *testing.T){ + c.Convey("test http stream client interceptor", t, func() { + mid := middleware.NewHttp() + R := gateway.GetRouter() + _, filename, _, _ := runtime.Caller(0) + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") + + pd, err := gateway.NewDescription(pbFile) + c.So(err, c.ShouldBeNil) + R.LoadAllRouters(pd) + stream, err := mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("grpcgateway-accept", "application/json")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + c.So(stream, c.ShouldNotBeNil) + + stream, err = mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("grpcgateway-accept", "application/json")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil,fmt.Errorf("new stream err") + }, + ) + c.So(err, c.ShouldNotBeNil) + c.So(stream, c.ShouldBeNil) + + }) + +} func TestUnaryInterceptor(t *testing.T) { c.Convey("test unary interceptor", t, func() { mid := middleware.NewHttp() - R := routers.Get() + R := gateway.GetRouter() _, filename, _, _ := runtime.Caller(0) pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata") @@ -120,6 +146,11 @@ func TestUnaryInterceptor(t *testing.T) { }) c.So(err, c.ShouldNotBeNil) c.So(req, c.ShouldNotBeNil) + req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, status.Error(codes.InvalidArgument, "test") + }) + c.So(err, c.ShouldNotBeNil) + c.So(req, c.ShouldNotBeNil) req, err = mid.UnaryInterceptor(ctx, &hello.HelloRequest{}, &grpc.UnaryServerInfo{FullMethod: "/INTEGRATION.TESTSERVICE/GET"}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, fmt.Errorf("test") diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index 234030f..e0d69f1 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -102,3 +102,11 @@ func (p *PluginsApply) StreamInterceptorChains() []grpc.StreamServerInterceptor } return chains } + +func (p *PluginsApply) StreamClientInterceptorChains() []grpc.StreamClientInterceptor { + chains := make([]grpc.StreamClientInterceptor, 0) + for _, plugin := range p.Plugins { + chains = append(chains, plugin.(gosdk.LocalPlugin).StreamClientInterceptor) + } + return chains +} diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 0f83484..7c3bf6d 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -37,6 +37,7 @@ func TestMiddlewareUnaryInterceptorChains(t *testing.T) { // mid.SetPriority(1) c.So(len(mid.StreamInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) c.So(len(mid.UnaryInterceptorChains()), c.ShouldBeGreaterThanOrEqualTo, 0) + c.So(len(mid.StreamClientInterceptorChains()),c.ShouldBeGreaterThanOrEqualTo, 0) plugins := cnf.GetPlugins() plugins["test"] = 1 diff --git a/internal/middleware/rpc.go b/internal/middleware/rpc.go index 010c377..bf3f045 100644 --- a/internal/middleware/rpc.go +++ b/internal/middleware/rpc.go @@ -22,22 +22,13 @@ import ( type RPCPluginCaller interface{} -// type rpcPluginCallerImpl struct { -// plugins gosdk.Plugins -// } -// func NewRPCPluginCaller() RPCPluginCaller { -// return &rpcPluginCallerImpl{ -// plugins: make(gosdk.Plugins, 0), -// } -// } type pluginImpl struct { priority int name string timeout time.Duration lb lb.LoadBalance - // api.PluginServiceClient } func (p *pluginImpl) SetPriority(priority int) { @@ -54,17 +45,16 @@ func (p *pluginImpl) Name() string { func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - md = metadata.New(nil) + md = metadata.New(make(map[string]string)) } - rsp, err := p.Apply(ctx, req, info.FullMethod) + rsp, header, err := p.Apply(ctx, req, info.FullMethod) if err != nil { return nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } + for k, v := range header { + md[k] = append(md[k], v...) - for k, v := range rsp.Metadata { - md.Append(k, v) } - newRequest := rsp.NewRequest if newRequest != nil { err = newRequest.UnmarshalTo(req.(proto.Message)) @@ -97,15 +87,15 @@ func (p *pluginImpl) getEndpoint(ctx context.Context) (lb.Endpoint, error) { return endpoint, nil } -func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName string) (*api.PluginResponse, error) { +func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName string) (*api.PluginResponse, metadata.MD, error) { endpoint, err := p.getEndpoint(ctx) if err != nil { - return nil, err + return nil, nil, err } cn, err := endpoint.Get(ctx) if err != nil { - return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + return nil, nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") } defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) @@ -113,14 +103,38 @@ func (p *pluginImpl) Apply(ctx context.Context, in interface{}, fullMethodName s plugin := api.NewPluginServiceClient(conn) anyReq, err := anypb.New(in.(proto.Message)) if err != nil { - return nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") + return nil, nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") } - return plugin.Apply(ctx, &api.PluginRequest{ + var header, trailer metadata.MD + rsp, err := plugin.Apply(ctx, &api.PluginRequest{ Request: anyReq, FullMethodName: fullMethodName, - }) - // return plugin.Call(ctx, anyReq, opts...) + }, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return nil, nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") + } + return rsp, header, nil +} +func (p *pluginImpl) Metadata(ctx context.Context, in *emptypb.Empty) (metadata.MD, error) { + endpoint, err := p.getEndpoint(ctx) + if err != nil { + return nil, err + } + cn, err := endpoint.Get(ctx) + if err != nil { + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + } + defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) + conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) + plugin := api.NewPluginServiceClient(conn) + var header, trailer metadata.MD + _, err = plugin.Metadata(ctx, in, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + return nil, gosdk.NewError(fmt.Errorf("call plugin metadata error:%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "metadata") + } + return header, nil + } func (p *pluginImpl) Info(ctx context.Context, in *emptypb.Empty, opts ...grpc.CallOption) (*api.PluginInfo, error) { endpoint, err := p.getEndpoint(ctx) @@ -137,14 +151,26 @@ func (p *pluginImpl) Info(ctx context.Context, in *emptypb.Empty, opts ...grpc.C return plugin.Info(ctx, in, opts...) } func (p *pluginImpl) StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + md, err := p.Metadata(ss.Context(), &emptypb.Empty{}) + if err != nil { + return err + + } + in,_ := metadata.FromIncomingContext(ss.Context()) - grpcStream := NewGrpcPluginStream(ss, info.FullMethod, ss.Context(), p) + ctx := metadata.NewIncomingContext(ss.Context(), metadata.Join(in, md)) + grpcStream := NewGrpcPluginStream(ss, info.FullMethod, ctx, p) if grpcStream != nil { defer grpcStream.Release() } return handler(srv, grpcStream) +} +func (p *pluginImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + + return streamer(ctx, desc, cc, method, opts...) + } func NewPluginImpl(lb lb.LoadBalance, name string, timeout time.Duration) *pluginImpl { return &pluginImpl{ diff --git a/internal/middleware/rpc_test.go b/internal/middleware/rpc_test.go index dd09e60..ed1505d 100644 --- a/internal/middleware/rpc_test.go +++ b/internal/middleware/rpc_test.go @@ -11,6 +11,7 @@ import ( "github.com/begonia-org/begonia/internal/middleware" goloadbalancer "github.com/begonia-org/go-loadbalancer" hello "github.com/begonia-org/go-sdk/api/example/v1" + api "github.com/begonia-org/go-sdk/api/plugin/v1" "github.com/begonia-org/go-sdk/example" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" @@ -20,6 +21,21 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) +type testClientStream struct { + ctx context.Context + grpc.ClientStream +} + +func (t *testClientStream) Context() context.Context { + return t.ctx +} +func (t *testClientStream) SendMsg(m interface{}) error { + return nil +} +func (t *testClientStream) RecvMsg(m interface{}) error { + return nil + +} func TestPluginUnaryInterceptor(t *testing.T) { c.Convey("test plugin unary interceptor", t, func() { go example.RunPlugins(":9850") @@ -57,19 +73,23 @@ func TestPluginUnaryInterceptor(t *testing.T) { return ss.RecvMsg(srv) }) c.So(err, c.ShouldBeNil) - patch2 := gomonkey.ApplyFuncSeq(metadata.FromIncomingContext, []gomonkey.OutputCell{{ - Values: gomonkey.Params{metadata.New(map[string]string{"test": "test"}), true}, - Times: 2, - }, - { - Values: gomonkey.Params{nil, false}, - }, - }) - defer patch2.Reset() + // patch2 := gomonkey.ApplyFuncSeq(metadata.FromIncomingContext, []gomonkey.OutputCell{{ + // Values: gomonkey.Params{metadata.New(map[string]string{"test": "test"}), true}, + // Times: 4, + // }, + // { + // Values: gomonkey.Params{nil, false}, + // }, + // }) + // defer patch2.Reset() err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { - return ss.RecvMsg(srv) + err := ss.RecvMsg(srv) + if err != nil { + t.Logf("recv msg error: %v", err) + } + return err }) - patch2.Reset() + // patch2.Reset() c.So(err, c.ShouldBeNil) }) @@ -142,20 +162,61 @@ func TestPluginUnaryInterceptorErr(t *testing.T) { patch4.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "unmarshal to request error") + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: context.Background()}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get metadata from context error") + // select err + patch5 := gomonkey.ApplyMethodReturn(lb, "Select", nil, fmt.Errorf("select endpoint error")) + defer patch5.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "select endpoint error") + patch5.Reset() + enp, _ := lb.Select("127.0.0.1") + patch6 := gomonkey.ApplyMethodReturn(enp, "Get", nil, fmt.Errorf("get connection error")) + defer patch6.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + patch6.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "get connection error") + // call plugin metadata err + cli := api.NewPluginServiceClient(nil) + patch7 := gomonkey.ApplyMethodReturn(cli, "Metadata", nil, fmt.Errorf("call test plugin error")) + defer patch7.Reset() + err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { + return ss.RecvMsg(srv) + }) + patch7.Reset() + c.So(err.Error(), c.ShouldContainSubstring, "call test plugin error") + // call apply err + patch8 := gomonkey.ApplyMethodReturn(cli, "Apply", nil, fmt.Errorf("call plugin error")) + defer patch8.Reset() + _, err = mid.UnaryInterceptor(metadata.NewIncomingContext(context.Background(), metadata.Pairs("X-Forwarded-For", "127.0.0.1:9090")), &hello.HelloRequest{}, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }) + patch8.Reset() + c.So(err, c.ShouldNotBeNil) + c.So(err.Error(), c.ShouldContainSubstring, "call plugin error") }) } func TestPluginStreamInterceptorErr(t *testing.T) { c.Convey("test plugin unary interceptor", t, func() { - go example.RunPlugins(":9850") + go example.RunPlugins(":9851") time.Sleep(2 * time.Second) lb := goloadbalancer.NewGrpcLoadBalance(&goloadbalancer.Server{ Name: "test", Endpoints: []goloadbalancer.EndpointServer{ { - Addr: "127.0.0.1:9850", + Addr: "127.0.0.1:9851", }, }, Pool: &goloadbalancer.PoolConfig{ @@ -176,7 +237,7 @@ func TestPluginStreamInterceptorErr(t *testing.T) { patch.Reset() c.So(err, c.ShouldNotBeNil) c.So(err.Error(), c.ShouldContainSubstring, "recv msg error") - patch1 := gomonkey.ApplyMethodReturn(mid, "Apply", nil, fmt.Errorf("call test plugin error")) + patch1 := gomonkey.ApplyMethodReturn(mid, "Apply", nil, nil, fmt.Errorf("call test plugin error")) defer patch1.Reset() err = mid.StreamInterceptor(&hello.HelloRequest{}, &testStream{ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("test", "test"))}, &grpc.StreamServerInfo{}, func(srv interface{}, ss grpc.ServerStream) error { return ss.RecvMsg(srv) @@ -197,3 +258,34 @@ func TestPluginStreamInterceptorErr(t *testing.T) { }) } + +func TestRPCStreamClientInterceptor(t *testing.T) { + c.Convey("test rpc stream client interceptor", t, func() { + go example.RunPlugins(":9852") + time.Sleep(2 * time.Second) + lb := goloadbalancer.NewGrpcLoadBalance(&goloadbalancer.Server{ + Name: "test", + Endpoints: []goloadbalancer.EndpointServer{ + { + Addr: "127.0.0.1:9852", + }, + }, + Pool: &goloadbalancer.PoolConfig{ + MaxOpenConns: 10, + MaxIdleConns: 5, + MaxActiveConns: 5, + }, + }) + mid := middleware.NewPluginImpl(lb, "test", 3*time.Second) + c.So(mid.Name(), c.ShouldEqual, "test") + mid.SetPriority(3) + st, err := mid.StreamClientInterceptor(metadata.NewOutgoingContext(context.Background(), metadata.Pairs("key", "value")), nil, nil, "/INTEGRATION.TESTSERVICE/GET", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return &testClientStream{ctx: ctx}, nil + }, + ) + c.So(err, c.ShouldBeNil) + c.So(st, c.ShouldNotBeNil) + + }, + ) +} diff --git a/internal/middleware/stream.go b/internal/middleware/stream.go index ec3ecda..4c5e93d 100644 --- a/internal/middleware/stream.go +++ b/internal/middleware/stream.go @@ -3,6 +3,7 @@ package middleware import ( "context" "fmt" + "log" "sync" gosdk "github.com/begonia-org/go-sdk" @@ -19,6 +20,12 @@ type grpcPluginStream struct { plugin *pluginImpl ctx context.Context } +// type grpcPluginClientStream struct { +// grpc.ClientStream +// fullName string +// plugin *pluginImpl +// ctx context.Context +// } var streamPool = &sync.Pool{ New: func() interface{} { @@ -28,11 +35,19 @@ var streamPool = &sync.Pool{ }, } +// func NewGrpcPluginClientStream(s grpc.ClientStream, fullName string, ctx context.Context, plugin *pluginImpl) *grpcPluginClientStream { +// return &grpcPluginClientStream{ +// ClientStream: s, +// fullName: fullName, +// ctx: ctx, +// plugin: plugin, +// } +// } func NewGrpcPluginStream(s grpc.ServerStream, fullName string, ctx context.Context, plugin *pluginImpl) *grpcPluginStream { stream := streamPool.Get().(*grpcPluginStream) stream.ServerStream = s stream.fullName = fullName - stream.ctx = s.Context() + stream.ctx = ctx stream.plugin = plugin return stream } @@ -46,22 +61,20 @@ func (g *grpcPluginStream) Release() { func (g *grpcPluginStream) Context() context.Context { return g.ctx } + func (s *grpcPluginStream) RecvMsg(m interface{}) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - rsp, err := s.plugin.Apply(s.Context(), m, s.fullName) + rsp, header, err := s.plugin.Apply(s.Context(), m, s.fullName) if err != nil { return gosdk.NewError(fmt.Errorf("call %s plugin error: %w", s.plugin.Name(), err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } md, ok := metadata.FromIncomingContext(s.ctx) if !ok { - md = metadata.New(nil) - } - for k, v := range rsp.Metadata { - md.Append(k, v) + md = metadata.New(make(map[string]string)) } newRequest := rsp.NewRequest if newRequest != nil { @@ -70,6 +83,10 @@ func (s *grpcPluginStream) RecvMsg(m interface{}) error { return gosdk.NewError(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") } } + log.Printf("grpcPluginStream server stream pointer:%p", s) + for k, v := range header { + md[k] = append(md[k], v...) + } s.ctx = metadata.NewIncomingContext(s.ctx, md) // s.ctx = metadata.NewIncomingContext(s.ctx, metadata.New(rsp.Metadata)) diff --git a/internal/middleware/vaildator.go b/internal/middleware/vaildator.go index 7b36bb5..990d03d 100644 --- a/internal/middleware/vaildator.go +++ b/internal/middleware/vaildator.go @@ -25,6 +25,11 @@ type validatePluginStream struct { ctx context.Context validator ParamsValidator } +// type validatePluginClientStream struct { +// grpc.ClientStream +// ctx context.Context +// validator ParamsValidator +// } var validatePluginStreamPool = &sync.Pool{ New: func() interface{} { @@ -56,7 +61,17 @@ func (p *validatePluginStream) RecvMsg(m interface{}) error { return err } - +// func (p *validatePluginClientStream) SendMsg(m interface{}) error { +// err := p.validator.ValidateParams(m) +// if err != nil { +// return err +// } +// err = p.ClientStream.SendMsg(m) +// if err != nil { +// return err +// } +// return nil +// } func getFieldNamesFromJSONTags(input interface{}) map[string]string { fieldMap := make(map[string]string) @@ -393,6 +408,7 @@ func (p *ParamsValidatorImpl) Name() string { } func (p *ParamsValidatorImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + fmt.Print("params validator unary interceptor\n") err = p.ValidateParams(req) if err != nil { return nil, err @@ -410,7 +426,10 @@ func (p *ParamsValidatorImpl) StreamInterceptor(srv interface{}, ss grpc.ServerS err := handler(srv, validateStream) return err } +func (p *ParamsValidatorImpl) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return streamer(ctx, desc, cc, method, opts...) +} func NewParamsValidator() ParamsValidator { v := &ParamsValidatorImpl{ diff --git a/internal/middleware/vaildator_test.go b/internal/middleware/vaildator_test.go index 095cf0d..92edbcc 100644 --- a/internal/middleware/vaildator_test.go +++ b/internal/middleware/vaildator_test.go @@ -11,36 +11,37 @@ import ( "google.golang.org/grpc" "google.golang.org/protobuf/types/known/fieldmaskpb" ) + func TestGetValidateTags(t *testing.T) { c.Convey("test get validate tags", t, func() { validator := middleware.NewParamsValidator() validator.SetPriority(1) - data1:=&struct{ - Name string `validate:"required" json:"name"` - Age int `validate:"required" json:"age"` - Data2 struct{ + data1 := &struct { + Name string `validate:"required" json:"name"` + Age int `validate:"required" json:"age"` + Data2 struct { Name string `validate:"required" json:"name"` - Age int `validate:"required" json:"age"` + Age int `validate:"required" json:"age"` } `validate:"required" json:"data2"` - Data2List []struct{ + Data2List []struct { Name string `validate:"required" json:"name"` - Age int `validate:"required" json:"age"` + Age int `validate:"required" json:"age"` } `validate:"required" json:"data2_list"` - Data2Map map[string]struct{ + Data2Map map[string]struct { Name string `validate:"required" json:"name"` - Age int `validate:"required" json:"age"` + Age int `validate:"required" json:"age"` } `validate:"required" json:"data2_map"` }{ Name: "test", Age: 19, - Data2:struct { + Data2: struct { Name string `validate:"required" json:"name"` Age int `validate:"required" json:"age"` }{ Name: "Jane", Age: 25, - } , + }, Data2List: []struct { Name string `validate:"required" json:"name"` Age int `validate:"required" json:"age"` @@ -56,7 +57,7 @@ func TestGetValidateTags(t *testing.T) { "key2": {Name: "David", Age: 40}, }, } - tags:=validator.(*middleware.ParamsValidatorImpl).GetValidateTags(data1) + tags := validator.(*middleware.ParamsValidatorImpl).GetValidateTags(data1) c.So(tags, c.ShouldNotBeNil) }) @@ -88,7 +89,7 @@ func TestValidatorUnaryInterceptor(t *testing.T) { UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_name", "sub_msg"}}, }, { - SubName: "test", + SubName: "test", SubAge: 19, SubMsg: "test", UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_msg"}}, @@ -143,7 +144,7 @@ func TestValidatorUnaryInterceptor(t *testing.T) { c.So(err.Error(), c.ShouldContainSubstring, "HelloRequestWithValidator.Subs[0].SubAge") req6 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) - req6.Sub.SubAge=16 + req6.Sub.SubAge = 16 _, err = validator.UnaryInterceptor(context.Background(), req6, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { return nil, nil }) @@ -152,8 +153,8 @@ func TestValidatorUnaryInterceptor(t *testing.T) { req7 := tiga.DeepCopy(req).(*hello.HelloRequestWithValidator) req7.Subs[1] = &hello.HelloSubRequest{ - SubAge: 19, - SubMsg: "test", + SubAge: 19, + SubMsg: "test", UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"sub_age"}}, } _, err = validator.UnaryInterceptor(context.Background(), req7, &grpc.UnaryServerInfo{}, func(ctx context.Context, req interface{}) (interface{}, error) { @@ -209,3 +210,14 @@ func TestValidatorStreamInterceptor(t *testing.T) { c.So(err, c.ShouldNotBeNil) }) } +func TestValidatorStreamClientInterceptor(t *testing.T) { + c.Convey("test stream client interceptor", t, func() { + validator := middleware.NewParamsValidator() + + _, err := validator.StreamClientInterceptor(context.Background(), nil, nil, "/INTEGRATION.TESTSERVICE/NOT_FOUND", func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + // return middleware.NewGrpcPluginClientStream(ctx, desc, cc, method, opts...),nil + return nil, nil + }) + c.So(err, c.ShouldBeNil) + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 4dc9558..ff73021 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,7 +13,6 @@ import ( "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/middleware" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/service" loadbalance "github.com/begonia-org/go-loadbalancer" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -52,7 +51,7 @@ func readDesc(conf *config.Config) (gateway.ProtobufDescription, error) { func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []service.Service, pluginApply *middleware.PluginsApply) *gateway.GatewayServer { // 参数选项 opts := &gateway.GrpcServerOptions{ - Middlewares: make([]gateway.GrpcProxyMiddleware, 0), + Middlewares: make([]grpc.StreamClientInterceptor, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]runtime.ServeMuxOption, 0), @@ -74,6 +73,7 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv // 中间件配置 opts.Options = append(opts.Options, grpc.ChainUnaryInterceptor(pluginApply.UnaryInterceptorChains()...)) opts.Options = append(opts.Options, grpc.ChainStreamInterceptor(pluginApply.StreamInterceptorChains()...)) + opts.Middlewares = append(opts.Middlewares, pluginApply.StreamClientInterceptorChains()...) pd, err := readDesc(conf) if err != nil { panic(err) @@ -84,7 +84,7 @@ func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []serv opts.HttpHandlers = append(opts.HttpHandlers, cors.Handle) gw := gateway.New(cfg, opts) - routersList := routers.Get() + routersList := gateway.GetRouter() for _, srv := range services { err := gw.RegisterLocalService(context.Background(), pd, srv.Desc(), srv) if err != nil { diff --git a/internal/service/file_test.go b/internal/service/file_test.go index fabe88e..12d4983 100644 --- a/internal/service/file_test.go +++ b/internal/service/file_test.go @@ -67,6 +67,7 @@ func makeBucket(t *testing.T) { rsp, err := apiClient.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + t.Logf("access key:%s",accessKey) minioFile := client.NewFilesAPI(apiAddr, accessKey, secret, api.FileEngine_FILE_ENGINE_MINIO) rsp, err = minioFile.CreateBucket(context.Background(), fileBucket, "test", false, true) c.So(err, c.ShouldBeNil) diff --git a/internal/service/tenant.go b/internal/service/tenant.go index c7abc27..e4dbe93 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -25,6 +25,8 @@ func NewTenantService(tenant *biz.TenantUsecase, cfg *config.Config, log logger. return &TenantService{tenant: tenant, cfg: cfg, log: log} } func (t *TenantService) Register(ctx context.Context, in *api.PostTenantRequest) (*api.Tenants, error) { + // st:=debug.Stack() + // fmt.Printf("TenantService stack:%s\n",st) identity := GetIdentity(ctx) if identity == "" { return nil, gosdk.NewError(pkg.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity")