diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index ee1089e..27170c3 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -5,8 +5,9 @@ import ( "github.com/begonia-org/begonia/config" "github.com/begonia-org/begonia/internal" - "github.com/begonia-org/begonia/transport" "github.com/spf13/cobra" + "github.com/begonia-org/begonia/gateway" + ) // var ProviderSet = wire.NewSet(NewMasterCmd) @@ -42,7 +43,7 @@ func NewGatewayCmd() *cobra.Command { // name, _ := cmd.Flags().GetString("name") env, _ := cmd.Flags().GetString("env") config := config.ReadConfig(env) - worker := internal.New(config, transport.Log, endpoint) + worker := internal.New(config, gateway.Log, endpoint) worker.Start() }, diff --git a/config/settings.yml b/config/settings.yml index 7f88e6a..95f2fde 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -56,6 +56,7 @@ gateway: cors: - "localhost" - "127.0.0.1:8081" + - "example.com" plugins: local: logger: 1 diff --git a/transport/endpoint.go b/gateway/endpoint.go similarity index 96% rename from transport/endpoint.go rename to gateway/endpoint.go index ce56aba..02238f9 100644 --- a/transport/endpoint.go +++ b/gateway/endpoint.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "context" @@ -50,7 +50,7 @@ func (e *httpForwardGrpcEndpointImpl) Request(req GrpcRequest) (proto.Message, r in := req.GetIn() ctx := req.GetContext() - err = conn.Invoke(ctx, req.GetFullMethodName(), in, out,grpc.Header(&metadata.HeaderMD),grpc.Trailer(&metadata.TrailerMD)) + err = conn.Invoke(ctx, req.GetFullMethodName(), in, out, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return out, metadata, err } diff --git a/transport/endpoint_test.go b/gateway/endpoint_test.go similarity index 83% rename from transport/endpoint_test.go rename to gateway/endpoint_test.go index f38bd52..65da7f9 100644 --- a/transport/endpoint_test.go +++ b/gateway/endpoint_test.go @@ -1,5 +1,4 @@ -package transport - +package gateway import ( "context" @@ -21,7 +20,7 @@ func TestRequest(t *testing.T) { // Out: &v1.HelloReply{}, // Ctx: context.Background(), // } - request:=NewGrpcRequest(context.Background(),nil,nil,"helloworld.Greeter/SayHello",WithIn(&v1.HelloRequest{Msg: "begonia"}),WithOut(&v1.HelloReply{})) + request := NewGrpcRequest(context.Background(), nil, nil, "helloworld.Greeter/SayHello", WithIn(&v1.HelloRequest{Msg: "begonia"}), WithOut(&v1.HelloReply{})) pool := NewGrpcConnPool("127.0.0.1:12138") endpoint := NewEndpoint(pool) reply, metadata, err := endpoint.Request(request) @@ -32,4 +31,4 @@ func TestRequest(t *testing.T) { c.So(metadata, c.ShouldNotBeNil) }) -} \ No newline at end of file +} diff --git a/transport/gateway.go b/gateway/gateway.go similarity index 90% rename from transport/gateway.go rename to gateway/gateway.go index 8fb88cc..681bdac 100644 --- a/transport/gateway.go +++ b/gateway/gateway.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "context" @@ -17,8 +17,6 @@ import ( "google.golang.org/grpc" ) - - type GrpcServerOptions struct { Middlewares []GrpcProxyMiddleware Options []grpc.ServerOption @@ -41,7 +39,6 @@ type GatewayServer struct { mux *sync.Mutex } - func NewGrpcServer(opts *GrpcServerOptions, lb *GrpcLoadBalancer) *grpc.Server { proxy := NewGrpcProxy(lb, opts.Middlewares...) @@ -57,7 +54,6 @@ func NewHttpServer(addr string, poolOpt ...loadbalance.PoolOptionsBuildOption) ( return NewHttpEndpoint(endpoint) - } func NewGateway(cfg *GatewayConfig, opts *GrpcServerOptions) *GatewayServer { lb := NewGrpcLoadBalancer() @@ -101,7 +97,7 @@ func (g *GatewayServer) DeleteLocalService(pd ProtobufDescription) { g.mux.Lock() defer g.mux.Unlock() g.proxyLB.Delete(pd) - _= g.DeleteHandlerClient(context.Background(), pd) + _ = g.DeleteHandlerClient(context.Background(), pd) } func (g *GatewayServer) GetLoadbalanceName() loadbalance.BalanceType { return g.proxyLB.Name() @@ -115,6 +111,19 @@ func (g *GatewayServer) Start() { g.grpcServer.ServeHTTP(w, r) } else { g.gatewayMux.ServeHTTP(w, r) + + // var handler = func(h http.Handler) http.Handler { + // return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // g.gatewayMux.ServeHTTP(w, r) + // }) + // } + // var httpHandle http.Handler + // for _, h := range g.opts.HttpHandlers { + // // handler = h(handler) + // // handler=h(handler) + // // httpHandle = + // } + // handler.ServeHTTP(w, r) } }), &http2.Server{}) diff --git a/transport/grpc.go b/gateway/grpc.go similarity index 99% rename from transport/grpc.go rename to gateway/grpc.go index ed55d3d..78c0d58 100644 --- a/transport/grpc.go +++ b/gateway/grpc.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "context" diff --git a/transport/http.go b/gateway/http.go similarity index 96% rename from transport/http.go rename to gateway/http.go index 94af421..30965ad 100644 --- a/transport/http.go +++ b/gateway/http.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "bytes" @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" + "github.com/spark-lence/tiga" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -306,7 +307,7 @@ func (h *HttpEndpointImpl) inParamsHandle(params map[string]string, req *http.Re field := fields.ByName(protoreflect.Name(param)) if field == nil { // Log.Errorf("inParamsHandle no parameter %s", param) - return status.Errorf(codes.InvalidArgument, "no parameter %s", param) + return status.Errorf(codes.InvalidArgument, "no such parameter %s", param) } in.Set(field, protoreflect.ValueOfString(msg)) @@ -379,7 +380,7 @@ func (h *HttpEndpointImpl) newRequest(ctx context.Context, item *HttpEndpointIte for _, param := range item.PathParams { val, ok = pathParams[param] if !ok { - return nil, status.Errorf(codes.InvalidArgument, "missing parameter %s", param) + continue } msg, err := runtime.String(val) if err != nil { @@ -388,9 +389,12 @@ func (h *HttpEndpointImpl) newRequest(ctx context.Context, item *HttpEndpointIte fields := in.Descriptor().Fields() field := fields.ByName(protoreflect.Name(param)) if field == nil { - return nil, status.Errorf(codes.InvalidArgument, "no parameter %s", param) + continue + } + if err := tiga.SetFieldValueFromString(in, field, msg); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "set field value error: %v", err) + } - in.Set(field, protoreflect.ValueOfString(msg)) } query := req.URL.Query() @@ -401,10 +405,14 @@ func (h *HttpEndpointImpl) newRequest(ctx context.Context, item *HttpEndpointIte } fields := in.Descriptor().Fields() field := fields.ByName(protoreflect.Name(k)) + if field == nil { - return nil, status.Errorf(codes.InvalidArgument, "no parameter %s", k) + continue + } + + if err := tiga.SetFieldValueFromString(in, field, msg); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "set field value error: %v", err) } - in.Set(field, protoreflect.ValueOfString(msg)) } grpcReq := NewGrpcRequest(ctx, item.In, item.Out, item.FullMethodName, WithIn(in), WithOut(dynamicpb.NewMessage(item.Out))) return grpcReq, nil @@ -470,11 +478,7 @@ func (h *HttpEndpointImpl) RegisterHandlerClient(ctx context.Context, pd Protobu req.Header.Set(GatewayXParams, strings.Join(params, ",")) } - // annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, item.FullMethodName, runtime.WithHTTPPathPattern(item.HttpUri)) - // if err != nil { - // runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - // return - // } + annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, item.FullMethodName, runtime.WithHTTPPathPattern(item.HttpUri)) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) diff --git a/transport/http_test.go b/gateway/http_test.go similarity index 71% rename from transport/http_test.go rename to gateway/http_test.go index ddbdf9d..4066c79 100644 --- a/transport/http_test.go +++ b/gateway/http_test.go @@ -1,4 +1,4 @@ -package transport_test +package gateway_test import ( "bytes" @@ -19,8 +19,11 @@ import ( "testing" "time" - "github.com/begonia-org/begonia/transport" - "github.com/begonia-org/begonia/transport/serialization" + "github.com/begonia-org/begonia" + "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" + "github.com/begonia-org/begonia/gateway/serialization" + cfg "github.com/begonia-org/begonia/internal/pkg/config" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" hello "github.com/begonia-org/go-sdk/api/example/v1" @@ -31,26 +34,27 @@ import ( "github.com/r3labs/sse/v2" c "github.com/smartystreets/goconvey/convey" // 别名导入 "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/reflect/protoregistry" "gopkg.in/cenkalti/backoff.v1" ) -// var endpoint transport.HttpEndpoint +// var endpoint gateway.HttpEndpoint var gwPort = 9527 -var gw *transport.GatewayServer +var gw *gateway.GatewayServer var onceInit sync.Once var randomNumber int -func newTestServer(gwPort, randomNumber int) (*transport.GrpcServerOptions, *transport.GatewayConfig) { - opts := &transport.GrpcServerOptions{ - Middlewares: make([]transport.GrpcProxyMiddleware, 0), +func newTestServer(gwPort, randomNumber int) (*gateway.GrpcServerOptions, *gateway.GatewayConfig) { + opts := &gateway.GrpcServerOptions{ + Middlewares: make([]gateway.GrpcProxyMiddleware, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), HttpHandlers: make([]func(http.Handler) http.Handler, 0), } - gwCnf := &transport.GatewayConfig{ + gwCnf := &gateway.GatewayConfig{ GatewayAddr: fmt.Sprintf("127.0.0.1:%d", gwPort), GrpcProxyAddr: fmt.Sprintf("127.0.0.1:%d", randomNumber+1), } @@ -62,15 +66,26 @@ func newTestServer(gwPort, randomNumber int) (*transport.GrpcServerOptions, *tra opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption("application/octet-stream", serialization.NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption("text/event-stream", serialization.NewEventSourceMarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMarshalerOption(serialization.ClientStreamContentType, serialization.NewProtobufWithLengthPrefix())) - opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMetadata(transport.IncomingHeadersToMetadata)) - opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithErrorHandler(transport.HandleErrorWithLogger(transport.Log))) - opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithForwardResponseOption(transport.HttpResponseBodyModify)) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithMetadata(gateway.IncomingHeadersToMetadata)) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithErrorHandler(gateway.HandleErrorWithLogger(gateway.Log))) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, gwRuntime.WithForwardResponseOption(gateway.HttpResponseBodyModify)) opts.PoolOptions = append(opts.PoolOptions, loadbalance.WithMaxActiveConns(100)) opts.PoolOptions = append(opts.PoolOptions, loadbalance.WithPoolSize(128)) - loggerMid:=transport.NewLoggerMiddleware(transport.Log) + loggerMid := gateway.NewLoggerMiddleware(gateway.Log) opts.Options = append(opts.Options, grpc.ChainStreamInterceptor(loggerMid.StreamInterceptor)) opts.Options = append(opts.Options, grpc.ChainUnaryInterceptor(loggerMid.UnaryInterceptor)) + env := "dev" + if begonia.Env != "" { + env = begonia.Env + } + // log.Printf("env: %s", env) + cnf := config.ReadConfig(env) + conf := cfg.NewConfig(cnf) + cors := &gateway.CorsHandler{ + Cors: conf.GetCorsConfig(), + } + opts.HttpHandlers = append(opts.HttpHandlers, cors.Handle) return opts, gwCnf } func init() { @@ -84,7 +99,7 @@ func init() { gwPort = randomNumber // gw = newTestServer(gwPort, randomNumber) opts, cnf := newTestServer(gwPort, randomNumber) - gw = transport.New(cnf, opts) + gw = gateway.New(cnf, opts) }) @@ -92,13 +107,16 @@ func init() { func testRegisterClient(t *testing.T) { c.Convey("test register client", t, func() { _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "internal", "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata", "helloworld.pb") pb, err := os.ReadFile(pbFile) c.So(err, c.ShouldBeNil) - pd, err := transport.NewDescriptionFromBinary(pb, filepath.Join("tmp", "test-pd")) + pd, err := gateway.NewDescriptionFromBinary(pb, filepath.Join("tmp", "test-pd")) + // t.Logf("pd:%+v", pd.GetGatewayJsonSchema()) c.So(err, c.ShouldBeNil) + c.So(pd.GetMessageTypeByFullName("helloworld.HelloRequest"), c.ShouldNotBeNil) + c.So(pd.GetDescription(),c.ShouldNotBeEmpty) helloAddr := fmt.Sprintf("127.0.0.1:%d", randomNumber+2) - endps, err := transport.NewLoadBalanceEndpoint(loadbalance.RRBalanceType, []*api.EndpointMeta{{ + endps, err := gateway.NewLoadBalanceEndpoint(loadbalance.RRBalanceType, []*api.EndpointMeta{{ Addr: helloAddr, Weight: 0, }}) @@ -119,13 +137,16 @@ func testRequestGet(t *testing.T) { c.Convey("test request GET", t, func() { url := fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/world?msg=hello", gwPort) r, err := http.NewRequest(http.MethodGet, url, nil) + r.Header.Set("x-uid","12345678") + // r.Header.Set("Origin", "http://www.example.com") c.So(err, c.ShouldBeNil) resp, err := http.DefaultClient.Do(r) c.So(err, c.ShouldBeNil) c.So(resp.StatusCode, c.ShouldEqual, http.StatusOK) - + c.So(resp.Header.Get("test"), c.ShouldBeEmpty) + c.So(resp.Header.Get("trace_id"), c.ShouldEqual, "123456") defer resp.Body.Close() body, err := io.ReadAll(resp.Body) c.So(err, c.ShouldBeNil) @@ -137,6 +158,35 @@ func testRequestGet(t *testing.T) { c.So(rsp.Message, c.ShouldEqual, "hello") c.So(rsp.Name, c.ShouldEqual, "world") + url = fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/http-code?msg=hello", gwPort) + r, err = http.NewRequest(http.MethodGet, url, nil) + // r.Header.Set("Origin", "http://www.example.com") + c.So(err, c.ShouldBeNil) + + resp, err = http.DefaultClient.Do(r) + c.So(err, c.ShouldBeNil) + c.So(resp.StatusCode, c.ShouldEqual, http.StatusIMUsed) + + }) +} +func testCors(t *testing.T) { + c.Convey("test cors", t, func() { + url := fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/world?msg=hello", gwPort) + r, err := http.NewRequest(http.MethodOptions, url, nil) + r.Header.Set("Origin", "http://www.example.com") + c.So(err, c.ShouldBeNil) + + resp, err := http.DefaultClient.Do(r) + c.So(err, c.ShouldBeNil) + c.So(resp.StatusCode, c.ShouldEqual, http.StatusNoContent) + c.So(resp.Header.Get("Access-Control-Allow-Origin"), c.ShouldEqual, "http://www.example.com") + + r, err = http.NewRequest(http.MethodOptions, url, nil) + r.Header.Set("Origin", "http://www.begonia-org.com") + c.So(err, c.ShouldBeNil) + resp, err = http.DefaultClient.Do(r) + c.So(err, c.ShouldBeNil) + c.So(resp.StatusCode, c.ShouldEqual, http.StatusForbidden) }) } func testRequestPost(t *testing.T) { @@ -337,10 +387,10 @@ func testClientStreamRequest(t *testing.T) { func testDeleteEndpoint(t *testing.T) { c.Convey("test delete endpoint", t, func() { _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "internal", "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata", "helloworld.pb") pb, err := os.ReadFile(pbFile) c.So(err, c.ShouldBeNil) - pd, err := transport.NewDescriptionFromBinary(pb, filepath.Join("tmp", "test-pd")) + pd, err := gateway.NewDescriptionFromBinary(pb, filepath.Join("tmp", "test-pd")) c.So(err, c.ShouldBeNil) err = gw.DeleteHandlerClient(context.TODO(), pd) c.So(err, c.ShouldBeNil) @@ -358,15 +408,15 @@ func testDeleteEndpoint(t *testing.T) { }) } func testRegisterLocalService(t *testing.T) { - var pd transport.ProtobufDescription + var pd gateway.ProtobufDescription var gwPort int - var localGW *transport.GatewayServer + var localGW *gateway.GatewayServer c.Convey("test register local service", t, func() { _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "internal", "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata", "helloworld.pb") pb, err := os.ReadFile(pbFile) c.So(err, c.ShouldBeNil) - pd, err = transport.NewDescriptionFromBinary(pb, filepath.Join("tmp", "test-pd-2")) + pd, err = gateway.NewDescriptionFromBinary(pb, filepath.Join("tmp", "test-pd-2")) c.So(err, c.ShouldBeNil) exampleServer := example.NewExampleServer() @@ -378,7 +428,7 @@ func testRegisterLocalService(t *testing.T) { gwPort = randomNumber // gw = newTestServer(gwPort, randomNumber) opts, cnf := newTestServer(gwPort, randomNumber) - localGW = transport.NewGateway(cnf, opts) + localGW = gateway.NewGateway(cnf, opts) err = localGW.RegisterLocalService(context.Background(), pd, exampleServer.Desc(), exampleServer) c.So(err, c.ShouldBeNil) @@ -419,15 +469,14 @@ func testRegisterLocalService(t *testing.T) { c.So(resp.StatusCode, c.ShouldEqual, http.StatusNotFound) }) } - func testLoadGlobalTypes(t *testing.T) { c.Convey("test load global types", t, func() { _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "internal", "integration", "testdata") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "testdata") os.Remove(filepath.Join(pbFile, "desc.pb")) os.Remove(filepath.Join(pbFile, "gateway.json")) - pd, err := transport.NewDescription(pbFile) + pd, err := gateway.NewDescription(pbFile) c.So(err, c.ShouldBeNil) err = gw.RegisterHandlerClient(context.Background(), pd) c.So(err, c.ShouldBeNil) @@ -443,15 +492,87 @@ func testLoadGlobalTypes(t *testing.T) { c.So(err, c.ShouldBeNil) }) } +func testHttpError(t *testing.T) { + c.Convey("test http error", t, func() { + // url := fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/error?code=%d&msg=ok", codes.OK, gwPort) + cases := []struct { + code int32 + msg string + expectHttpCode int + expectInternalCode int32 + }{ + { + code: int32(codes.OK), + msg: "ok", + expectHttpCode: http.StatusOK, + expectInternalCode: int32(common.Code_OK), + }, + { + code: int32(codes.Internal), + msg: codes.Internal.String(), + expectHttpCode: http.StatusInternalServerError, + expectInternalCode: int32(common.Code_INTERNAL_ERROR), + }, + { + code: int32(codes.InvalidArgument), + msg: codes.InvalidArgument.String(), + expectHttpCode: http.StatusBadRequest, + expectInternalCode: int32(common.Code_PARAMS_ERROR), + }, + { + code: int32(codes.NotFound), + msg: codes.NotFound.String(), + expectHttpCode: http.StatusNotFound, + expectInternalCode: int32(common.Code_NOT_FOUND), + }, + { + code: int32(codes.PermissionDenied), + msg: codes.PermissionDenied.String(), + expectHttpCode: http.StatusForbidden, + expectInternalCode: int32(common.Code_AUTH_ERROR), + }, + { + code: int32(codes.Unauthenticated), + msg: codes.Unauthenticated.String(), + expectHttpCode: http.StatusUnauthorized, + expectInternalCode: int32(common.Code_AUTH_ERROR), + }, + { + code: int32(codes.ResourceExhausted), + msg: codes.ResourceExhausted.String(), + expectHttpCode: http.StatusTooManyRequests, + expectInternalCode: int32(common.Code_RESOURCE_EXHAUSTED), + }, + { + code: int32(codes.DeadlineExceeded), + msg: codes.DeadlineExceeded.String(), + expectHttpCode: http.StatusGatewayTimeout, + expectInternalCode: int32(common.Code_TIMEOUT_ERROR), + }, + } + for _, v := range cases { + url := fmt.Sprintf("http://127.0.0.1:%d/api/v1/example/error/test?msg=ok&code=%d", gwPort, v.code) + r, err := http.NewRequest(http.MethodGet, url, nil) + c.So(err, c.ShouldBeNil) + c.So(r, c.ShouldNotBeNil) + resp, err := http.DefaultClient.Do(r) + c.So(err, c.ShouldBeNil) + c.So(resp.StatusCode, c.ShouldEqual, v.expectHttpCode) + + } + }) +} func TestHttp(t *testing.T) { t.Run("testRegisterClient", testRegisterClient) t.Run("testRequestGet", testRequestGet) + t.Run("testCors", testCors) t.Run("testRequestPost", testRequestPost) t.Run("testServerSideEvent", testServerSideEvent) t.Run("testWebsocket", testWebsocket) t.Run("testClientStreamRequest", testClientStreamRequest) t.Run("testLoadGlobalTypes", testLoadGlobalTypes) + t.Run("testHttpError", testHttpError) t.Run("testDeleteEndpoint", testDeleteEndpoint) t.Run("testRegisterLocalService", testRegisterLocalService) // time.Sleep(30 * time.Second) diff --git a/transport/logger.go b/gateway/logger.go similarity index 99% rename from transport/logger.go rename to gateway/logger.go index bf135f2..0123d08 100644 --- a/transport/logger.go +++ b/gateway/logger.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "context" @@ -138,7 +138,7 @@ func (l *LoggerImpl) Error(ctx context.Context, err error) { func (f *loggerFormatter) getFormatterFields(data logrus.Fields) string { bData, _ := json.Marshal(data) - return string(bData) + return string(bData) + "\n" } diff --git a/transport/marshaller.go b/gateway/marshaller.go similarity index 76% rename from transport/marshaller.go rename to gateway/marshaller.go index 504a4bc..77373a9 100644 --- a/transport/marshaller.go +++ b/gateway/marshaller.go @@ -1,4 +1,4 @@ -package transport +package gateway type FormatDataDecoder interface { SetBoundary(string) diff --git a/transport/middlewares.go b/gateway/middlewares.go similarity index 91% rename from transport/middlewares.go rename to gateway/middlewares.go index e12ae4c..e870f93 100644 --- a/transport/middlewares.go +++ b/gateway/middlewares.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "context" @@ -43,13 +43,14 @@ func preflightHandler(w http.ResponseWriter, _ *http.Request) { // methods := []string{"GET", "HEAD", "POST", "PUT", "DELETE"} w.Header().Set("Access-Control-Allow-Methods", "*") w.Header().Set("Access-Control-Expose-Headers", "*") + w.WriteHeader(http.StatusNoContent) } -type CorsMiddleware struct { +type CorsHandler struct { Cors []string } -func (cors *CorsMiddleware) Handle(h http.Handler) http.Handler { +func (cors *CorsHandler) Handle(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if clientOrigin := r.Header.Get("Origin"); clientOrigin != "" { var isAllowed bool @@ -66,22 +67,15 @@ func (cors *CorsMiddleware) Handle(h http.Handler) http.Handler { preflightHandler(w, r) return } + }else{ + Log.Errorf(r.Context(), "origin:%s not allowed", clientOrigin) + w.WriteHeader(http.StatusForbidden) + return } } h.ServeHTTP(w, r) }) } -func RequestIDMiddleware(ctx context.Context, r *http.Request) metadata.MD { - md, ok := runtime.ServerMetadataFromContext(ctx) - if !ok { - return nil - } - if val := md.HeaderMD.Get(XRequestID); len(val) > 0 { - r.Header.Set(XRequestID, val[0]) - } - return md.HeaderMD - -} func IncomingHeadersToMetadata(ctx context.Context, req *http.Request) metadata.MD { // 创建一个新的 metadata.MD 实例 @@ -156,8 +150,9 @@ func (log *LoggerMiddleware) Name() string { } func (log *LoggerMiddleware) logger(ctx context.Context, fullMethod string, err error, elapsed time.Duration) { code := status.Code(err) + httpCode := runtime.HTTPStatusFromCode(code) logger := log.log.WithFields(logrus.Fields{ - "status": code, + "status": httpCode, "elapsed": elapsed.String(), "name": fullMethod, "module": "request", @@ -191,9 +186,7 @@ func (log *LoggerMiddleware) StreamInterceptor(srv interface{}, ss grpc.ServerSt log.logger(ss.Context(), info.FullMethod, err, elapsed) } }() - // md, ok := metadata.FromIncomingContext(ss.Context()) - // reqId := uuid.New().String() ctx := ss.Context() err = handler(srv, ss) @@ -235,7 +228,6 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { statusCode := http.StatusOK - // file, line, fn, _ := errors.GetOneLineSource(err) log := logger.WithFields(logrus.Fields{ "status": statusCode, @@ -246,12 +238,10 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { data.Code = int32(common.Code_INTERNAL_ERROR) data.Message = "internal error" if st, ok := status.FromError(err); ok { - // rspCode := float64(common.Code_INTERNAL_ERROR) msg := st.Message() details := st.Details() data.Message = clientMessageFromCode(st.Code()) - // code = runtime.HTTPStatusFromCode(st.Code()) for _, detail := range details { if anyType, ok := detail.(*anypb.Any); ok { var errDetail common.Errors @@ -263,8 +253,7 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { "line": errDetail.Line, "fn": errDetail.Fn, }) - // log.Error(msg) - // fmt.Printf("error code:%d", errDetail.Code) + msg := codes[int32(errDetail.Code)] if errDetail.ToClientMessage != "" { msg = errDetail.ToClientMessage @@ -282,7 +271,8 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { } code = runtime.HTTPStatusFromCode(st.Code()) - log.Errorf(ctx, msg) + + log.WithField("status", code).Errorf(ctx, msg) w.Header().Set("Content-Type", "application/json") w.WriteHeader(code) bData, _ := protojson.Marshal(data) @@ -298,6 +288,7 @@ func HandleErrorWithLogger(logger logger.Logger) runtime.ErrorHandlerFunc { } } func writeHttpHeaders(w http.ResponseWriter, key string, value []string) { + // del Grpc-Metadata if httpKey := gosdk.GetHttpHeaderKey(key); httpKey != "" { for _, v := range value { w.Header().Del(key) diff --git a/gateway/plugin.go b/gateway/plugin.go new file mode 100644 index 0000000..fa0ed3c --- /dev/null +++ b/gateway/plugin.go @@ -0,0 +1 @@ +package gateway diff --git a/transport/protobuf.go b/gateway/protobuf.go similarity index 99% rename from transport/protobuf.go rename to gateway/protobuf.go index 48eaa65..4435f18 100644 --- a/transport/protobuf.go +++ b/gateway/protobuf.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "encoding/json" diff --git a/transport/protos/Makefile b/gateway/protos/Makefile similarity index 100% rename from transport/protos/Makefile rename to gateway/protos/Makefile diff --git a/transport/protos/desc.pb b/gateway/protos/desc.pb similarity index 100% rename from transport/protos/desc.pb rename to gateway/protos/desc.pb diff --git a/transport/protos/gateway.json b/gateway/protos/gateway.json similarity index 100% rename from transport/protos/gateway.json rename to gateway/protos/gateway.json diff --git a/internal/integration/testdata/google/api/annotations.proto b/gateway/protos/google/api/annotations.proto similarity index 100% rename from internal/integration/testdata/google/api/annotations.proto rename to gateway/protos/google/api/annotations.proto diff --git a/internal/integration/testdata/google/api/http.proto b/gateway/protos/google/api/http.proto similarity index 100% rename from internal/integration/testdata/google/api/http.proto rename to gateway/protos/google/api/http.proto diff --git a/internal/integration/testdata/google/protobuf/any.proto b/gateway/protos/google/protobuf/any.proto similarity index 100% rename from internal/integration/testdata/google/protobuf/any.proto rename to gateway/protos/google/protobuf/any.proto diff --git a/internal/integration/testdata/google/protobuf/descriptor.proto b/gateway/protos/google/protobuf/descriptor.proto similarity index 99% rename from internal/integration/testdata/google/protobuf/descriptor.proto rename to gateway/protos/google/protobuf/descriptor.proto index 3b38675..5154e5a 100755 --- a/internal/integration/testdata/google/protobuf/descriptor.proto +++ b/gateway/protos/google/protobuf/descriptor.proto @@ -305,7 +305,7 @@ message MethodDescriptorProto { // Identifies if client streams multiple client messages optional bool client_streaming = 5 [default = false]; // Identifies if server streams multiple server messages - optional bool server_streaming = 6 [default = false]; + optional bool service_streaming = 6 [default = false]; } // =================================================================== diff --git a/internal/integration/testdata/google/protobuf/field_mask.proto b/gateway/protos/google/protobuf/field_mask.proto similarity index 100% rename from internal/integration/testdata/google/protobuf/field_mask.proto rename to gateway/protos/google/protobuf/field_mask.proto diff --git a/internal/integration/testdata/google/protobuf/timestamp.proto b/gateway/protos/google/protobuf/timestamp.proto similarity index 100% rename from internal/integration/testdata/google/protobuf/timestamp.proto rename to gateway/protos/google/protobuf/timestamp.proto diff --git a/transport/protos/helloworld.proto b/gateway/protos/helloworld.proto similarity index 100% rename from transport/protos/helloworld.proto rename to gateway/protos/helloworld.proto diff --git a/transport/protos/http.proto b/gateway/protos/http.proto similarity index 100% rename from transport/protos/http.proto rename to gateway/protos/http.proto diff --git a/transport/request.go b/gateway/request.go similarity index 99% rename from transport/request.go rename to gateway/request.go index 54200a5..80335b7 100644 --- a/transport/request.go +++ b/gateway/request.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "context" diff --git a/gateway/request_test.go b/gateway/request_test.go new file mode 100644 index 0000000..e63ea73 --- /dev/null +++ b/gateway/request_test.go @@ -0,0 +1,43 @@ +package gateway_test + +import ( + "context" + "net/http" + "testing" + + "github.com/begonia-org/begonia/gateway/serialization" + hello "github.com/begonia-org/go-sdk/api/example/v1" + c "github.com/smartystreets/goconvey/convey" + "google.golang.org/grpc" + "github.com/begonia-org/begonia/gateway" + +) + +func TestBuildGrpcRequest(t *testing.T) { + c.Convey("TestBuildGrpcRequest", t, func() { + in := &hello.HelloRequest{} + out := &hello.HelloReply{} + httpReq, _ := http.NewRequest("GET", "http://127.0.0.1:8080", nil) + + req := gateway.NewGrpcRequest(context.Background(), + in.ProtoReflect().Descriptor(), + out.ProtoReflect().Descriptor(), + "helloworld.Greeter/SayHello", + gateway.WithGatewayCallOptions(grpc.CompressorCallOption{}), + gateway.WithGatewayMarshaler(serialization.NewJSONMarshaler()), + gateway.WithGatewayPathParams(map[string]string{"key": "value"}), + gateway.WithGatewayReq(httpReq), + gateway.WithIn(in), + gateway.WithOut(out), + ) + c.So(req.GetFullMethodName(), c.ShouldEqual, "helloworld.Greeter/SayHello") + c.So(len(req.GetCallOptions()), c.ShouldEqual, 1) + c.So(req.GetMarshaler(), c.ShouldHaveSameTypeAs, serialization.NewJSONMarshaler()) + c.So(req.GetPathParams(), c.ShouldResemble, map[string]string{"key": "value"}) + c.So(req.GetReq().URL.String(), c.ShouldEqual, httpReq.URL.String()) + c.So(req.GetIn(), c.ShouldHaveSameTypeAs, in) + c.So(req.GetOut(), c.ShouldHaveSameTypeAs, out) + c.So(req.GetInType(), c.ShouldEqual, in.ProtoReflect().Descriptor()) + c.So(req.GetOutType(), c.ShouldEqual, out.ProtoReflect().Descriptor()) + }) +} diff --git a/transport/serialization/formdata.go b/gateway/serialization/formdata.go similarity index 88% rename from transport/serialization/formdata.go rename to gateway/serialization/formdata.go index 6a9c0bb..f1c5a4e 100644 --- a/transport/serialization/formdata.go +++ b/gateway/serialization/formdata.go @@ -8,7 +8,7 @@ import ( "strconv" "strings" - "github.com/begonia-org/begonia/internal/pkg/errors" + gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "google.golang.org/grpc/codes" @@ -57,12 +57,12 @@ func (f *FormDataDecoder) Decode(v interface{}) error { file := files[0] fd, err := file.Open() if err != nil { - return errors.New(fmt.Errorf("read file from form data error,%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") + return gosdk.NewError(fmt.Errorf("read file from form data error,%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") } fileBytes, err := io.ReadAll(fd) fd.Close() if err != nil { - return errors.New(fmt.Errorf("read file from form data error,%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") + return gosdk.NewError(fmt.Errorf("read file from form data error,%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") } if _, ok := formData.Value[key]; !ok { @@ -74,7 +74,7 @@ func (f *FormDataDecoder) Decode(v interface{}) error { if pb, ok := v.(protoreflect.ProtoMessage); ok { err := parseFormToProto(formData.Value, pb) if err != nil { - return errors.New(fmt.Errorf("parse form data to proto error,%w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "parse_form_data") + return gosdk.NewError(fmt.Errorf("parse form data to proto error,%w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "parse_form_data") } return nil } @@ -211,16 +211,16 @@ func NewFormDataMarshaler() *FormDataMarshaler { func (f *FormUrlEncodedDecoder) Decode(v interface{}) error { buf, err := io.ReadAll(f.r) if err != nil && err != io.EOF { - return errors.New(fmt.Errorf("read form data error,%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_format") + return gosdk.NewError(fmt.Errorf("read form data error,%w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_format") } data, err := url.ParseQuery(string(buf)) if err != nil { - return errors.New(fmt.Errorf("parse form data error,%w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "parse_form_data") + return gosdk.NewError(fmt.Errorf("parse form data error,%w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "parse_form_data") } if pb, ok := v.(protoreflect.ProtoMessage); ok { err := parseFormToProto(data, pb) if err != nil { - return errors.New(fmt.Errorf("parse form data to proto error,%w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "parse_form_data") + return gosdk.NewError(fmt.Errorf("parse form data to proto error,%w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "parse_form_data") } return nil } diff --git a/transport/serialization/formdata_test.go b/gateway/serialization/formdata_test.go similarity index 100% rename from transport/serialization/formdata_test.go rename to gateway/serialization/formdata_test.go diff --git a/transport/serialization/mask.go b/gateway/serialization/mask.go similarity index 100% rename from transport/serialization/mask.go rename to gateway/serialization/mask.go diff --git a/transport/serialization/mask_test.go b/gateway/serialization/mask_test.go similarity index 100% rename from transport/serialization/mask_test.go rename to gateway/serialization/mask_test.go diff --git a/transport/serialization/serialization.go b/gateway/serialization/serialization.go similarity index 100% rename from transport/serialization/serialization.go rename to gateway/serialization/serialization.go diff --git a/transport/serialization/stream.go b/gateway/serialization/stream.go similarity index 100% rename from transport/serialization/stream.go rename to gateway/serialization/stream.go diff --git a/transport/transport.go b/gateway/transport.go similarity index 93% rename from transport/transport.go rename to gateway/transport.go index 0084449..4f1773f 100644 --- a/transport/transport.go +++ b/gateway/transport.go @@ -1,4 +1,4 @@ -package transport +package gateway import "sync" diff --git a/transport/types.go b/gateway/types.go similarity index 99% rename from transport/types.go rename to gateway/types.go index 1a3fcae..96f1946 100644 --- a/transport/types.go +++ b/gateway/types.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "sync/atomic" diff --git a/transport/utils.go b/gateway/utils.go similarity index 94% rename from transport/utils.go rename to gateway/utils.go index ddaedd6..2ff3eb6 100644 --- a/transport/utils.go +++ b/gateway/utils.go @@ -1,10 +1,10 @@ -package transport +package gateway import ( "fmt" loadbalance "github.com/begonia-org/go-loadbalancer" - api "github.com/begonia-org/go-sdk/api/endpoint/v1" + api "github.com/begonia-org/go-sdk/api/endpoint/v1" ) func isASCIILower(c byte) bool { @@ -33,7 +33,7 @@ func NewLoadBalanceEndpoint(lb loadbalance.BalanceType, endpoints []*api.Endpoin opts := gw.GetOptions() for _, ep := range endpoints { pool := NewGrpcConnPool(ep.GetAddr(), opts.PoolOptions...) - eps = append(eps,NewGrpcEndpoint(ep.GetAddr(), pool)) + eps = append(eps, NewGrpcEndpoint(ep.GetAddr(), pool)) } switch lb { case loadbalance.RRBalanceType: diff --git a/transport/utils_test.go b/gateway/utils_test.go similarity index 76% rename from transport/utils_test.go rename to gateway/utils_test.go index c75a4eb..9ba1f99 100644 --- a/transport/utils_test.go +++ b/gateway/utils_test.go @@ -1,4 +1,4 @@ -package transport_test +package gateway_test import ( "fmt" @@ -6,27 +6,27 @@ import ( "reflect" "testing" - "github.com/begonia-org/begonia/transport" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" gwRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" c "github.com/smartystreets/goconvey/convey" "google.golang.org/grpc" + "github.com/begonia-org/begonia/gateway" ) func TestNewEndpoint(t *testing.T) { - opts := &transport.GrpcServerOptions{ - Middlewares: make([]transport.GrpcProxyMiddleware, 0), + opts := &gateway.GrpcServerOptions{ + Middlewares: make([]gateway.GrpcProxyMiddleware, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), HttpHandlers: make([]func(http.Handler) http.Handler, 0), } - gwCnf := &transport.GatewayConfig{ + gwCnf := &gateway.GatewayConfig{ GatewayAddr: "127.0.0.1:9527", GrpcProxyAddr: "127.0.0.1:12148", } - transport.New(gwCnf, opts) + gateway.New(gwCnf, opts) meta := []*api.EndpointMeta{{ Addr: "127.0.0.1:12138", Weight: 0, @@ -50,7 +50,7 @@ func TestNewEndpoint(t *testing.T) { name: string(loadbalance.RRBalanceType), endpoints: meta, exceptErr: nil, - exceptEndpointType: reflect.TypeOf(transport.NewGrpcEndpoint("", nil)).Elem(), + exceptEndpointType: reflect.TypeOf(gateway.NewGrpcEndpoint("", nil)).Elem(), }, { name: string(loadbalance.WRRBalanceType), @@ -68,7 +68,7 @@ func TestNewEndpoint(t *testing.T) { name: string(loadbalance.ConsistentHashBalanceType), endpoints: meta, exceptErr: nil, - exceptEndpointType: reflect.TypeOf(transport.NewGrpcEndpoint("", nil)).Elem(), + exceptEndpointType: reflect.TypeOf(gateway.NewGrpcEndpoint("", nil)).Elem(), }, { name: string(loadbalance.LCBalanceType), @@ -97,7 +97,7 @@ func TestNewEndpoint(t *testing.T) { } c.Convey("TestNewEndpoint", t, func() { for _, testCase := range cases { - enps, err := transport.NewLoadBalanceEndpoint(loadbalance.BalanceType(testCase.name), testCase.endpoints) + enps, err := gateway.NewLoadBalanceEndpoint(loadbalance.BalanceType(testCase.name), testCase.endpoints) if testCase.exceptErr != nil { c.So(err, c.ShouldNotBeNil) } else { @@ -111,9 +111,9 @@ func TestNewEndpoint(t *testing.T) { func TestJSONCamelCase(t *testing.T) { c.Convey("TestJSONCamelCase", t, func() { - c.So(transport.JSONCamelCase("testCase"), c.ShouldEqual, "testCase") - c.So(transport.JSONCamelCase("testcasetest"), c.ShouldEqual, "testcasetest") - c.So(transport.JSONCamelCase("test_case_test_test"), c.ShouldEqual, "testCaseTestTest") - c.So(transport.JSONCamelCase("test_case_test_test_test"), c.ShouldEqual, "testCaseTestTestTest") + c.So(gateway.JSONCamelCase("testCase"), c.ShouldEqual, "testCase") + c.So(gateway.JSONCamelCase("testcasetest"), c.ShouldEqual, "testcasetest") + c.So(gateway.JSONCamelCase("test_case_test_test"), c.ShouldEqual, "testCaseTestTest") + c.So(gateway.JSONCamelCase("test_case_test_test_test"), c.ShouldEqual, "testCaseTestTestTest") }) -} \ No newline at end of file +} diff --git a/transport/websocket.go b/gateway/websocket.go similarity index 92% rename from transport/websocket.go rename to gateway/websocket.go index be6cf5b..19da3e0 100644 --- a/transport/websocket.go +++ b/gateway/websocket.go @@ -1,4 +1,4 @@ -package transport +package gateway import ( "bytes" @@ -22,7 +22,7 @@ type websocketForwarder struct { responseType int } -func NewWebsocketForwarder(w http.ResponseWriter, req *http.Request,responseType int) (WebsocketForwarder, error) { +func NewWebsocketForwarder(w http.ResponseWriter, req *http.Request, responseType int) (WebsocketForwarder, error) { var upgrader = websocket.Upgrader{ // 允许所有CORS请求 CheckOrigin: func(r *http.Request) bool { return true }, @@ -50,7 +50,7 @@ func (w *websocketForwarder) Read() ([]byte, error) { } func (w *websocketForwarder) NextReader() (io.Reader, error) { - _,reader,err:= w.websocket.NextReader() + _, reader, err := w.websocket.NextReader() if err != nil { return nil, err } diff --git a/go.mod b/go.mod index 3898f4f..adf01a4 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/cockroachdb/errors v1.11.1 github.com/google/wire v0.6.0 github.com/smartystreets/goconvey v1.8.1 - github.com/spark-lence/tiga v0.0.0-20240510102710-93bf07b60b07 + github.com/spark-lence/tiga v0.0.0-20240517061929-e81eba889226 github.com/spf13/cobra v1.8.0 google.golang.org/genproto/googleapis/api v0.0.0-20240515191416-fc5f0ca64291 google.golang.org/grpc v1.64.0 @@ -80,7 +80,7 @@ require ( require ( github.com/agiledragon/gomonkey/v2 v2.11.0 github.com/begonia-org/go-loadbalancer v0.0.0-20240515153502-b1d83dda8ae3 - github.com/begonia-org/go-sdk v0.0.0-20240516160356-c90fe583fd2e + github.com/begonia-org/go-sdk v0.0.0-20240517084829-c3cdf5e5e1eb github.com/go-git/go-git/v5 v5.11.0 github.com/go-playground/validator/v10 v10.19.0 github.com/gorilla/websocket v1.5.0 diff --git a/go.sum b/go.sum index 2daded8..4d473b1 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,10 @@ github.com/begonia-org/go-sdk v0.0.0-20240515083527-ef2ff6b73539 h1:M7pPon2kyX2M github.com/begonia-org/go-sdk v0.0.0-20240515083527-ef2ff6b73539/go.mod h1:I70a3fiAADGrOoOC3lv408rFcTRhTwLt3pwr6cQwB4Y= github.com/begonia-org/go-sdk v0.0.0-20240516160356-c90fe583fd2e h1:VwPf1HI//SopJpJtWHtQd6JreryTtu8s1m0zQo+Jeqc= github.com/begonia-org/go-sdk v0.0.0-20240516160356-c90fe583fd2e/go.mod h1:I70a3fiAADGrOoOC3lv408rFcTRhTwLt3pwr6cQwB4Y= +github.com/begonia-org/go-sdk v0.0.0-20240517035447-b6ee0a94bc66 h1:ejdny9b1oeioMFd8IWQcY9uDpDOAlTnMAJvhsalHdjs= +github.com/begonia-org/go-sdk v0.0.0-20240517035447-b6ee0a94bc66/go.mod h1:I70a3fiAADGrOoOC3lv408rFcTRhTwLt3pwr6cQwB4Y= +github.com/begonia-org/go-sdk v0.0.0-20240517084829-c3cdf5e5e1eb h1:H58fLjtWA5CZSl20GLyi+3xnQiRiP47aEub7YWFjuQc= +github.com/begonia-org/go-sdk v0.0.0-20240517084829-c3cdf5e5e1eb/go.mod h1:I70a3fiAADGrOoOC3lv408rFcTRhTwLt3pwr6cQwB4Y= github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -226,6 +230,10 @@ github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9yS github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spark-lence/tiga v0.0.0-20240510102710-93bf07b60b07 h1:BLvfdUv/NKsRirhQbUum2eduFHfu0tcWyF4fHhP62zY= github.com/spark-lence/tiga v0.0.0-20240510102710-93bf07b60b07/go.mod h1:jo3Qr3EkFkOX0GUD6c4YyLoJESTYtd2hfYP9HZk8s8I= +github.com/spark-lence/tiga v0.0.0-20240517030839-e2e8385d3629 h1:J5qEeswukKY+wAnQLgHFz8T9/vItTF4lga7JDjwz8fE= +github.com/spark-lence/tiga v0.0.0-20240517030839-e2e8385d3629/go.mod h1:MSL8X9t+qvpQ4Tq3vVPKncq9RJcCzF2XGEWkCuNhm6Q= +github.com/spark-lence/tiga v0.0.0-20240517061929-e81eba889226 h1:WKMb1r+0r5lDKrynMlC9v62EajmtLXMIIk6NzrVbcNs= +github.com/spark-lence/tiga v0.0.0-20240517061929-e81eba889226/go.mod h1:MSL8X9t+qvpQ4Tq3vVPKncq9RJcCzF2XGEWkCuNhm6Q= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= diff --git a/internal/biz/aksk.go b/internal/biz/aksk.go index 23f93f4..15494cd 100644 --- a/internal/biz/aksk.go +++ b/internal/biz/aksk.go @@ -57,34 +57,34 @@ func (a *AccessKeyAuth) AppValidator(ctx context.Context, req *gosdk.GatewayRequ } } if xDate == "" { - return "", errors.New(errors.ErrAppXDateMissing, int32(api.APPSvrCode_APP_XDATE_MISSING_ERR), codes.Unauthenticated, "app_timestamp") + return "", gosdk.NewError(errors.ErrAppXDateMissing, int32(api.APPSvrCode_APP_XDATE_MISSING_ERR), codes.Unauthenticated, "app_timestamp") } if auth == "" { - return "", errors.New(errors.ErrAppSignatureMissing, int32(api.APPSvrCode_APP_AUTH_MISSING_ERR), codes.Unauthenticated, "app_signature") + return "", gosdk.NewError(errors.ErrAppSignatureMissing, int32(api.APPSvrCode_APP_AUTH_MISSING_ERR), codes.Unauthenticated, "app_signature") } if accessKey == "" { - return "", errors.New(errors.ErrAppAccessKeyMissing, int32(api.APPSvrCode_APP_ACCESS_KEY_MISSING_ERR), codes.Unauthenticated, "app_access_key") + return "", gosdk.NewError(errors.ErrAppAccessKeyMissing, int32(api.APPSvrCode_APP_ACCESS_KEY_MISSING_ERR), codes.Unauthenticated, "app_access_key") } t, err := time.Parse(gosdk.DateFormat, xDate) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Unauthenticated, "sign_request") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Unauthenticated, "sign_request") } // check timestamp if time.Since(t).Abs() > time.Minute*1 { - return "", errors.New(errors.ErrRequestExpired, int32(api.APPSvrCode_APP_REQUEST_EXPIRED_ERR), codes.DeadlineExceeded, "app_timestamp") + return "", gosdk.NewError(errors.ErrRequestExpired, int32(api.APPSvrCode_APP_REQUEST_EXPIRED_ERR), codes.DeadlineExceeded, "app_timestamp") } secret, err := a.app.GetSecret(ctx, accessKey) if err != nil { - return "", errors.New(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_secret") + return "", gosdk.NewError(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_secret") } signer := gosdk.NewAppAuthSigner(accessKey, secret) sign, err := signer.Sign(req) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Unauthenticated, "sign_request") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Unauthenticated, "sign_request") } if sign != a.getSignature(auth) { - return "", errors.New(errors.ErrAppSignatureInvalid, int32(api.APPSvrCode_APP_SIGNATURE_ERR), codes.Unauthenticated, "app签名校验") + return "", gosdk.NewError(errors.ErrAppSignatureInvalid, int32(api.APPSvrCode_APP_SIGNATURE_ERR), codes.Unauthenticated, "app签名校验") } return accessKey, nil } @@ -104,7 +104,7 @@ func (a *AccessKeyAuth) getSignature(auth string) string { func (a *AccessKeyAuth) GetSecret(ctx context.Context, accessKey string) (string, error) { secret, err := a.app.GetSecret(ctx, accessKey) if err != nil { - return "", errors.New(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_secret") + return "", gosdk.NewError(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_secret") } return secret, nil } @@ -112,7 +112,7 @@ func (a *AccessKeyAuth) GetSecret(ctx context.Context, accessKey string) (string func (a *AccessKeyAuth) GetAppid(ctx context.Context, accessKey string) (string, error) { appid, err := a.app.GetAppid(ctx, accessKey) if err != nil { - return "", errors.New(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_secret") + return "", gosdk.NewError(err, int32(api.APPSvrCode_APP_UNKNOWN), codes.Unauthenticated, "app_secret") } return appid, nil } diff --git a/internal/biz/aksk_test.go b/internal/biz/aksk_test.go index 986992a..b35bde0 100644 --- a/internal/biz/aksk_test.go +++ b/internal/biz/aksk_test.go @@ -16,9 +16,9 @@ import ( cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/errors" "github.com/begonia-org/begonia/internal/pkg/utils" - "github.com/begonia-org/begonia/transport" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/app/v1" + "github.com/begonia-org/begonia/gateway" c "github.com/smartystreets/goconvey/convey" "github.com/spark-lence/tiga" @@ -55,9 +55,9 @@ func newAKSK() *biz.AccessKeyAuth { env = begonia.Env } config := config.ReadConfig(env) - repo := data.NewAppRepo(config, transport.Log) + repo := data.NewAppRepo(config, gateway.Log) cnf := cfg.NewConfig(config) - return biz.NewAccessKeyAuth(repo, cnf, transport.Log) + return biz.NewAccessKeyAuth(repo, cnf, gateway.Log) } func testGetSecret(t *testing.T) { @@ -66,7 +66,7 @@ func testGetSecret(t *testing.T) { env = begonia.Env } config := config.ReadConfig(env) - repo := data.NewAppRepo(config, transport.Log) + repo := data.NewAppRepo(config, gateway.Log) snk, _ := tiga.NewSnowflake(1) access, _ := utils.GenerateRandomString(32) akskAccess = access diff --git a/internal/biz/app.go b/internal/biz/app.go index 41e5e02..42a9f51 100644 --- a/internal/biz/app.go +++ b/internal/biz/app.go @@ -1,179 +1,179 @@ package biz import ( - "context" - "crypto/rand" - "fmt" - "strings" - "time" - - "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/errors" - api "github.com/begonia-org/go-sdk/api/app/v1" - common "github.com/begonia-org/go-sdk/common/api/v1" - "github.com/spark-lence/tiga" - "google.golang.org/grpc/codes" - "google.golang.org/protobuf/types/known/timestamppb" - "gorm.io/gorm" + "context" + "crypto/rand" + "fmt" + "strings" + "time" + + "github.com/begonia-org/begonia/internal/pkg/config" + gosdk "github.com/begonia-org/go-sdk" + api "github.com/begonia-org/go-sdk/api/app/v1" + common "github.com/begonia-org/go-sdk/common/api/v1" + "github.com/spark-lence/tiga" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" ) type AppRepo interface { - Add(ctx context.Context, apps *api.Apps) error - Get(ctx context.Context, key string) (*api.Apps, error) - Cache(ctx context.Context, prefix string, models *api.Apps, exp time.Duration) error - Del(ctx context.Context, key string) error - List(ctx context.Context, tags []string, status []api.APPStatus, page, pageSize int32) ([]*api.Apps, error) - Patch(ctx context.Context, model *api.Apps) error - GetSecret(ctx context.Context, accessKey string) (string, error) - GetAppid(ctx context.Context, accessKey string) (string, error) + Add(ctx context.Context, apps *api.Apps) error + Get(ctx context.Context, key string) (*api.Apps, error) + Cache(ctx context.Context, prefix string, models *api.Apps, exp time.Duration) error + Del(ctx context.Context, key string) error + List(ctx context.Context, tags []string, status []api.APPStatus, page, pageSize int32) ([]*api.Apps, error) + Patch(ctx context.Context, model *api.Apps) error + GetSecret(ctx context.Context, accessKey string) (string, error) + GetAppid(ctx context.Context, accessKey string) (string, error) } type AppUsecase struct { - repo AppRepo - config *config.Config - snowflake *tiga.Snowflake + repo AppRepo + config *config.Config + snowflake *tiga.Snowflake } func NewAppUsecase(repo AppRepo, config *config.Config) *AppUsecase { - sn, _ := tiga.NewSnowflake(1) - return &AppUsecase{repo: repo, config: config, snowflake: sn} + sn, _ := tiga.NewSnowflake(1) + return &AppUsecase{repo: repo, config: config, snowflake: sn} } func GenerateRandomString(n int) (string, error) { - const lettersAndDigits = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - b := make([]byte, n) - if _, err := rand.Read(b); err != nil { - return "", fmt.Errorf("Failed to generate random string: %w", err) - } - - for i := 0; i < n; i++ { - // 将随机字节转换为lettersAndDigits中的一个有效字符 - b[i] = lettersAndDigits[b[i]%byte(len(lettersAndDigits))] - } - - return string(b), nil + const lettersAndDigits = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("Failed to generate random string: %w", err) + } + + for i := 0; i < n; i++ { + // 将随机字节转换为lettersAndDigits中的一个有效字符 + b[i] = lettersAndDigits[b[i]%byte(len(lettersAndDigits))] + } + + return string(b), nil } func (a *AppUsecase) newApp() *api.Apps { - return &api.Apps{ - Status: api.APPStatus_APP_ENABLED, - IsDeleted: false, - } + return &api.Apps{ + Status: api.APPStatus_APP_ENABLED, + IsDeleted: false, + } } func (a *AppUsecase) CreateApp(ctx context.Context, in *api.AppsRequest, owner string) (*api.Apps, error) { - appid := GenerateAppid(a.snowflake) - accessKey, err := GenerateAppAccessKey() - if err != nil { - return nil, errors.New(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "generate_app_access_key") - - } - secret, err := GenerateAppSecret() - if err != nil { - return nil, errors.New(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "generate_app_secret_key") - } - app := a.newApp() - app.AccessKey = accessKey - app.Secret = secret - app.Appid = appid - app.Name = in.Name - app.Description = in.Description - app.Tags = in.Tags - app.Owner = owner - err = a.Put(ctx, app, owner) - if err != nil { - return nil, err - - } - return app, nil + appid := GenerateAppid(a.snowflake) + accessKey, err := GenerateAppAccessKey() + if err != nil { + return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "generate_app_access_key") + + } + secret, err := GenerateAppSecret() + if err != nil { + return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "generate_app_secret_key") + } + app := a.newApp() + app.AccessKey = accessKey + app.Secret = secret + app.Appid = appid + app.Name = in.Name + app.Description = in.Description + app.Tags = in.Tags + app.Owner = owner + err = a.Put(ctx, app, owner) + if err != nil { + return nil, err + + } + return app, nil } func GenerateAppid(snowflake *tiga.Snowflake) string { - appid := snowflake.GenerateIDString() - return appid + appid := snowflake.GenerateIDString() + return appid } func GenerateAppAccessKey() (string, error) { - return GenerateRandomString(32) + return GenerateRandomString(32) } func GenerateAppSecret() (string, error) { - return GenerateRandomString(64) + return GenerateRandomString(64) } // AddApps 新增并缓存app func (a *AppUsecase) Put(ctx context.Context, apps *api.Apps, owner string) (err error) { - defer func() { - if err != nil { - // log.Println(err) - if strings.Contains(err.Error(), "Duplicate entry") { - err = errors.New(err, int32(api.APPSvrCode_APP_DUPLICATE_ERR), codes.AlreadyExists, "commit_app") - } else { - err = errors.New(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "cache_apps") - - } - } - }() + defer func() { + if err != nil { + // log.Println(err) + if strings.Contains(err.Error(), "Duplicate entry") { + err = gosdk.NewError(err, int32(api.APPSvrCode_APP_DUPLICATE_ERR), codes.AlreadyExists, "commit_app") + } else { + err = gosdk.NewError(err, int32(api.APPSvrCode_APP_CREATE_ERR), codes.Internal, "cache_apps") + + } + } + }() if apps.Appid == "" { apps.Appid = GenerateAppid(a.snowflake) apps.AccessKey, _ = GenerateAppAccessKey() apps.Secret, _ = GenerateAppSecret() } - apps.Owner = owner - - err = a.repo.Add(ctx, apps) - if err != nil { - return err - } - prefix := a.config.GetAPPAccessKeyPrefix() - err = a.repo.Cache(ctx, prefix, apps, time.Duration(0)*time.Second) - return err - // return a.repo.AddApps(ctx, apps) + apps.Owner = owner + + err = a.repo.Add(ctx, apps) + if err != nil { + return err + } + prefix := a.config.GetAPPAccessKeyPrefix() + err = a.repo.Cache(ctx, prefix, apps, time.Duration(0)*time.Second) + return err + // return a.repo.AddApps(ctx, apps) } func (a *AppUsecase) Get(ctx context.Context, key string) (*api.Apps, error) { - app, err := a.repo.Get(ctx, key) - if err != nil { - return nil, errors.New(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "get_app") + app, err := a.repo.Get(ctx, key) + if err != nil { + return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "get_app") - } - return app, nil + } + return app, nil } func (a *AppUsecase) Cache(ctx context.Context, prefix string, models *api.Apps, exp time.Duration) error { - return a.repo.Cache(ctx, prefix, models, exp) + return a.repo.Cache(ctx, prefix, models, exp) } func (a *AppUsecase) Del(ctx context.Context, key string) error { - err := a.repo.Del(ctx, key) - if err != nil { - if strings.Contains(err.Error(), gorm.ErrRecordNotFound.Error()) { - return errors.New(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "delete_app") - } - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete_app") - - } - return nil + err := a.repo.Del(ctx, key) + if err != nil { + if strings.Contains(err.Error(), gorm.ErrRecordNotFound.Error()) { + return gosdk.NewError(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "delete_app") + } + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete_app") + + } + return nil } func (a *AppUsecase) Patch(ctx context.Context, in *api.AppsRequest, owner string) (*api.Apps, error) { - app, err := a.Get(ctx, in.Appid) - if err != nil { - return nil, errors.New(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "get_app") - } - app.Name = in.Name - app.Description = in.Description - app.Tags = in.Tags - app.UpdatedAt = timestamppb.Now() - app.UpdateMask = in.UpdateMask - - err = a.repo.Patch(ctx, app) - if err != nil { - if strings.Contains(err.Error(), "Duplicate entry") { - return nil, errors.New(err, int32(api.APPSvrCode_APP_DUPLICATE_ERR), codes.Internal, "update_app") - } - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "update_app") - } - return app, nil + app, err := a.Get(ctx, in.Appid) + if err != nil { + return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "get_app") + } + app.Name = in.Name + app.Description = in.Description + app.Tags = in.Tags + app.UpdatedAt = timestamppb.Now() + app.UpdateMask = in.UpdateMask + + err = a.repo.Patch(ctx, app) + if err != nil { + if strings.Contains(err.Error(), "Duplicate entry") { + return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_DUPLICATE_ERR), codes.Internal, "update_app") + } + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "update_app") + } + return app, nil } func (a *AppUsecase) List(ctx context.Context, in *api.AppsListRequest) ([]*api.Apps, error) { - apps, err := a.repo.List(ctx, in.Tags, in.Status, in.Page, in.PageSize) - if err != nil { - return nil, errors.New(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "list_app") - } - return apps, nil + apps, err := a.repo.List(ctx, in.Tags, in.Status, in.Page, in.PageSize) + if err != nil { + return nil, gosdk.NewError(err, int32(api.APPSvrCode_APP_NOT_FOUND_ERR), codes.NotFound, "list_app") + } + return apps, nil } diff --git a/internal/biz/app_test.go b/internal/biz/app_test.go index 13f91e3..899aa6a 100644 --- a/internal/biz/app_test.go +++ b/internal/biz/app_test.go @@ -8,12 +8,13 @@ import ( "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/data" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/utils" - "github.com/begonia-org/begonia/transport" api "github.com/begonia-org/go-sdk/api/app/v1" + c "github.com/smartystreets/goconvey/convey" "github.com/spark-lence/tiga" "google.golang.org/protobuf/types/known/fieldmaskpb" @@ -32,7 +33,7 @@ func newAppBiz() *biz.AppUsecase { env = begonia.Env } config := config.ReadConfig(env) - repo := data.NewAppRepo(config, transport.Log) + repo := data.NewAppRepo(config, gateway.Log) cnf := cfg.NewConfig(config) return biz.NewAppUsecase(repo, cnf) } @@ -55,7 +56,7 @@ func testPutApp(t *testing.T) { secret = app.Secret appid = app.Appid - layered := data.NewLayered(config.ReadConfig("dev"), transport.Log) + layered := data.NewLayered(config.ReadConfig("dev"), gateway.Log) env := "dev" if begonia.Env != "" { env = begonia.Env diff --git a/internal/biz/authz.go b/internal/biz/authz.go index ba9e374..011e2e1 100644 --- a/internal/biz/authz.go +++ b/internal/biz/authz.go @@ -10,6 +10,7 @@ import ( "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" "github.com/begonia-org/begonia/internal/pkg/errors" + gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/user/v1" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" @@ -45,7 +46,7 @@ func (u *AuthzUsecase) DelToken(ctx context.Context, key string) error { func (u *AuthzUsecase) AuthSeed(ctx context.Context, in *api.AuthLogAPIRequest) (string, error) { token, err := u.authCrypto.GenerateAuthSeed(in.Token) if err != nil { - return "", errors.New(fmt.Errorf("auth seed generate %w", err), int32(api.UserSvrCode_USER_LOGIN_ERR), codes.InvalidArgument, "generate_seed") + return "", gosdk.NewError(fmt.Errorf("auth seed generate %w", err), int32(api.UserSvrCode_USER_LOGIN_ERR), codes.InvalidArgument, "generate_seed") } return token, nil @@ -62,18 +63,18 @@ func (u *AuthzUsecase) getUserAuth(_ context.Context, in *api.LoginAPIRequest) ( now := time.Now().Unix() if now-timestamp > 60 { - return nil, errors.New(errors.ErrTokenExpired, int32(api.UserSvrCode_USER_TOKEN_EXPIRE_ERR.Number()), codes.InvalidArgument, "种子有效期校验") + return nil, gosdk.NewError(errors.ErrTokenExpired, int32(api.UserSvrCode_USER_TOKEN_EXPIRE_ERR.Number()), codes.InvalidArgument, "种子有效期校验") } auth := in.Auth authBytes, err := u.authCrypto.RSADecrypt(auth) if err != nil { - return nil, errors.New(errors.ErrAuthDecrypt, int32(api.UserSvrCode_USER_AUTH_DECRYPT_ERR.Number()), codes.InvalidArgument, "login_info_rsa") + return nil, gosdk.NewError(errors.ErrAuthDecrypt, int32(api.UserSvrCode_USER_AUTH_DECRYPT_ERR.Number()), codes.InvalidArgument, "login_info_rsa") } userAuth := &api.UserAuth{} err = json.Unmarshal([]byte(authBytes), userAuth) if err != nil { - return nil, errors.New(errors.ErrDecode, int32(common.Code_AUTH_ERROR), codes.InvalidArgument, "login_info_decode") + return nil, gosdk.NewError(errors.ErrDecode, int32(common.Code_AUTH_ERROR), codes.InvalidArgument, "login_info_decode") } return userAuth, nil } @@ -100,7 +101,7 @@ func (u *AuthzUsecase) GenerateJWT(ctx context.Context, user *api.Users, isKeepL token, err := tiga.GenerateJWT(payload, secret) if err != nil { - return "", errors.New(err, int32(api.UserSvrCode_USER_UNKNOWN), codes.Internal, "jwt_generate") + return "", gosdk.NewError(err, int32(api.UserSvrCode_USER_UNKNOWN), codes.Internal, "jwt_generate") } return token, nil @@ -116,7 +117,7 @@ func (u *AuthzUsecase) Login(ctx context.Context, in *api.LoginAPIRequest) (*api key, iv := u.config.GetAesConfig() account, err := tiga.EncryptAES([]byte(key), userAuth.Account, iv) if err != nil { - err := errors.New(errors.ErrEncrypt, int32(api.UserSvrCode_USER_ACCOUNT_ERR), codes.InvalidArgument, "accout_encrypt") + err := gosdk.NewError(errors.ErrEncrypt, int32(api.UserSvrCode_USER_ACCOUNT_ERR), codes.InvalidArgument, "accout_encrypt") return nil, err } @@ -125,15 +126,15 @@ func (u *AuthzUsecase) Login(ctx context.Context, in *api.LoginAPIRequest) (*api if err == nil || strings.Contains(err.Error(), "not found") { err = errors.ErrUserNotFound } - err := errors.New(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "user_query") + err := gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "user_query") return nil, err } if user.Password != userAuth.Password { - err := errors.New(errors.ErrUserPasswordInvalid, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "password_match") + err := gosdk.NewError(errors.ErrUserPasswordInvalid, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "password_match") return nil, err } if user.Status != api.USER_STATUS_ACTIVE { - err := errors.New(errors.ErrUserDisabled, int32(api.UserSvrCode_USER_DISABLED_ERR), codes.Unauthenticated, "user_query") + err := gosdk.NewError(errors.ErrUserDisabled, int32(api.UserSvrCode_USER_DISABLED_ERR), codes.Unauthenticated, "user_query") return nil, err } @@ -154,15 +155,15 @@ func (u *AuthzUsecase) Logout(ctx context.Context, req *api.LogoutAPIRequest) er md, ok := metadata.FromIncomingContext(ctx) if !ok { - return errors.New(errors.ErrNoMetadata, int32(common.Code_METADATA_MISSING), codes.InvalidArgument, "metadata_missing") + return gosdk.NewError(errors.ErrNoMetadata, int32(common.Code_METADATA_MISSING), codes.InvalidArgument, "metadata_missing") } token := md.Get("x-token") if len(token) == 0 { - return errors.New(errors.ErrTokenMissing, int32(common.Code_TOKEN_NOT_FOUND), codes.InvalidArgument, "token_missing") + return gosdk.NewError(errors.ErrTokenMissing, int32(common.Code_TOKEN_NOT_FOUND), codes.InvalidArgument, "token_missing") } err := u.repo.PutBlackList(ctx, tiga.GetMd5(token[0])) if err != nil { - return errors.New(err, int32(common.Code_AUTH_ERROR), codes.Internal, "add_black_list") + return gosdk.NewError(err, int32(common.Code_AUTH_ERROR), codes.Internal, "add_black_list") } return nil diff --git a/internal/biz/authz_test.go b/internal/biz/authz_test.go index a06f7ef..d043c64 100644 --- a/internal/biz/authz_test.go +++ b/internal/biz/authz_test.go @@ -19,13 +19,14 @@ import ( "github.com/agiledragon/gomonkey/v2" "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/data" cfg "github.com/begonia-org/begonia/internal/pkg/config" + "github.com/begonia-org/begonia/internal/pkg/crypto" "github.com/begonia-org/begonia/internal/pkg/errors" "github.com/begonia-org/begonia/internal/pkg/utils" - "github.com/begonia-org/begonia/transport" v1 "github.com/begonia-org/go-sdk/api/user/v1" "github.com/spark-lence/tiga" "google.golang.org/grpc/metadata" @@ -43,11 +44,11 @@ func newAuthzBiz() *biz.AuthzUsecase { env = begonia.Env } config := config.ReadConfig(env) - repo := data.NewAuthzRepo(config, transport.Log) - user := data.NewUserRepo(config, transport.Log) + repo := data.NewAuthzRepo(config, gateway.Log) + user := data.NewUserRepo(config, gateway.Log) crypto := crypto.NewUsersAuth() cnf := cfg.NewConfig(config) - return biz.NewAuthzUsecase(repo, user, transport.Log, crypto, cnf) + return biz.NewAuthzUsecase(repo, user, gateway.Log, crypto, cnf) } func testAuthSeed(t *testing.T) { diff --git a/internal/biz/data_test.go b/internal/biz/data_test.go index 6b32a02..73d4ea6 100644 --- a/internal/biz/data_test.go +++ b/internal/biz/data_test.go @@ -9,11 +9,11 @@ import ( "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/biz/endpoint" "github.com/begonia-org/begonia/internal/data" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/transport" loadbalance "github.com/begonia-org/go-loadbalancer" appApi "github.com/begonia-org/go-sdk/api/app/v1" api "github.com/begonia-org/go-sdk/api/user/v1" @@ -30,11 +30,11 @@ func newDataOperatorUsecase() *biz.DataOperatorUsecase { env = begonia.Env } config := config.ReadConfig(env) - repo := data.NewEndpointRepo(config, transport.Log) - repoData := data.NewOperator(config, transport.Log) + repo := data.NewEndpointRepo(config, gateway.Log) + repoData := data.NewOperator(config, gateway.Log) cnf := cfg.NewConfig(config) watcher := endpoint.NewWatcher(cnf, repo) - return biz.NewDataOperatorUsecase(repoData, cnf, transport.Log, watcher, repo) + return biz.NewDataOperatorUsecase(repoData, cnf, gateway.Log, watcher, repo) } func TestDo(t *testing.T) { @@ -47,21 +47,21 @@ func TestDo(t *testing.T) { } config := config.ReadConfig(env) cnf := cfg.NewConfig(config) - cache := data.NewLayered(config, transport.Log) + cache := data.NewLayered(config, gateway.Log) _ = cache.Del(context.Background(), "begonia:user:black:lock") _ = cache.Del(context.Background(), "begonia:user:black:last_updated") - opts := &transport.GrpcServerOptions{ - Middlewares: make([]transport.GrpcProxyMiddleware, 0), + opts := &gateway.GrpcServerOptions{ + Middlewares: make([]gateway.GrpcProxyMiddleware, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), HttpHandlers: make([]func(http.Handler) http.Handler, 0), } - gwCnf := &transport.GatewayConfig{ + gwCnf := &gateway.GatewayConfig{ GatewayAddr: "127.0.0.1:9527", GrpcProxyAddr: "127.0.0.1:12148", } - transport.New(gwCnf, opts) + gateway.New(gwCnf, opts) c.Convey("test data operator do success", t, func() { u1 := &api.Users{ Name: fmt.Sprintf("user-data-operator-%s", time.Now().Format("20060102150405")), diff --git a/internal/biz/endpoint/endpoint.go b/internal/biz/endpoint/endpoint.go index 919d0df..af349c9 100644 --- a/internal/biz/endpoint/endpoint.go +++ b/internal/biz/endpoint/endpoint.go @@ -11,6 +11,7 @@ import ( "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/errors" loadbalance "github.com/begonia-org/go-loadbalancer" + gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/endpoint/v1" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/grpc/codes" @@ -44,7 +45,7 @@ func NewEndpointUsecase(repo EndpointRepo, file *file.FileUsecase, config *confi func (e *EndpointUsecase) AddConfig(ctx context.Context, srvConfig *api.EndpointSrvConfig) (string, error) { if !loadbalance.CheckBalanceType(srvConfig.Balance) { - return "", errors.New(errors.ErrUnknownLoadBalancer, int32(api.EndpointSvrStatus_NOT_SUPPORT_BALANCE), codes.InvalidArgument, "balance_type") + return "", gosdk.NewError(errors.ErrUnknownLoadBalancer, int32(api.EndpointSvrStatus_NOT_SUPPORT_BALANCE), codes.InvalidArgument, "balance_type") } id := e.snk.GenerateIDString() @@ -64,7 +65,7 @@ func (e *EndpointUsecase) AddConfig(ctx context.Context, srvConfig *api.Endpoint log.Printf("endpoint add tags :%v", srvConfig.Tags) err := e.repo.Put(ctx, endpoint) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "put_endpoint") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "put_endpoint") } return id, nil @@ -74,12 +75,12 @@ func (e *EndpointUsecase) Patch(ctx context.Context, srvConfig *api.EndpointSrvU patch := make(map[string]interface{}) bSrvConfig, err := json.Marshal(srvConfig) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "marshal_config") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "marshal_config") } svrConfigPatch := make(map[string]interface{}) err = json.Unmarshal(bSrvConfig, &svrConfigPatch) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_config") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_config") } // 过滤掉不允许修改的字段 @@ -92,7 +93,7 @@ func (e *EndpointUsecase) Patch(ctx context.Context, srvConfig *api.EndpointSrvU patch["updated_at"] = updated_at err = e.repo.Patch(ctx, srvConfig.UniqueKey, patch) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "patch_config") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "patch_config") } return updated_at, nil } @@ -105,17 +106,17 @@ func (u *EndpointUsecase) Get(ctx context.Context, uniqueKey string) (*api.Endpo detailsKey := u.config.GetServiceKey(uniqueKey) value, err := u.repo.Get(ctx, detailsKey) if err != nil { - return nil, errors.New(fmt.Errorf("%s:%w", errors.ErrEndpointNotExists.Error(), err), int32(common.Code_NOT_FOUND), codes.NotFound, "get_endpoint") + return nil, gosdk.NewError(fmt.Errorf("%s:%w", errors.ErrEndpointNotExists.Error(), err), int32(common.Code_NOT_FOUND), codes.NotFound, "get_endpoint") } if value == "" { - return nil, errors.New(errors.ErrEndpointNotExists, int32(common.Code_NOT_FOUND), codes.NotFound, "get_endpoint") + return nil, gosdk.NewError(errors.ErrEndpointNotExists, int32(common.Code_NOT_FOUND), codes.NotFound, "get_endpoint") } // log.Printf("get endpoint value:%s", value) endpoint := &api.Endpoints{} err = json.Unmarshal([]byte(value), endpoint) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") } return endpoint, nil @@ -126,7 +127,7 @@ func (u *EndpointUsecase) List(ctx context.Context, in *api.ListEndpointRequest) if len(in.Tags) > 0 { ks, err := u.repo.GetKeysByTags(ctx, in.Tags) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_keys_by_tags") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_keys_by_tags") } keys = append(keys, ks...) diff --git a/internal/biz/endpoint/endpoint_test.go b/internal/biz/endpoint/endpoint_test.go index 5ddb2fd..d1105c4 100644 --- a/internal/biz/endpoint/endpoint_test.go +++ b/internal/biz/endpoint/endpoint_test.go @@ -15,13 +15,13 @@ import ( "github.com/agiledragon/gomonkey/v2" "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz/endpoint" "github.com/begonia-org/begonia/internal/data" cfg "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/errors" - + "github.com/begonia-org/begonia/internal/pkg/routers" - "github.com/begonia-org/begonia/transport" gwRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" goloadbalancer "github.com/begonia-org/go-loadbalancer" @@ -43,14 +43,14 @@ func newEndpointBiz() *endpoint.EndpointUsecase { } conf := config.ReadConfig(env) cnf := cfg.NewConfig(conf) - repo := data.NewEndpointRepo(conf, transport.Log) + repo := data.NewEndpointRepo(conf, gateway.Log) return endpoint.NewEndpointUsecase(repo, nil, cnf) } func testAddEndpoint(t *testing.T) { endpointBiz := newEndpointBiz() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata", "helloworld.pb") pb, err := os.ReadFile(pbFile) c.Convey("Test Add Endpoint", t, func() { @@ -217,7 +217,7 @@ func testPatchEndpoint(t *testing.T) { func testListEndpoints(t *testing.T) { endpointBiz := newEndpointBiz() _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filepath.Dir(filename)))), "testdata", "helloworld.pb") pb, err := os.ReadFile(pbFile) if err != nil { t.Error(err) @@ -298,18 +298,18 @@ func testWatcherUpdate(t *testing.T) { return } val, _ := json.Marshal(value) - opts := &transport.GrpcServerOptions{ - Middlewares: make([]transport.GrpcProxyMiddleware, 0), + opts := &gateway.GrpcServerOptions{ + Middlewares: make([]gateway.GrpcProxyMiddleware, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), HttpHandlers: make([]func(http.Handler) http.Handler, 0), } - gwCnf := &transport.GatewayConfig{ + gwCnf := &gateway.GatewayConfig{ GatewayAddr: "127.0.0.1:9527", GrpcProxyAddr: "127.0.0.1:12148", } - transport.New(gwCnf, opts) + gateway.New(gwCnf, opts) routers.NewHttpURIRouteToSrvMethod() c.Convey("Test Watcher Update", t, func() { @@ -335,7 +335,7 @@ func testWatcherUpdate(t *testing.T) { c.So(err.Error(), c.ShouldContainSubstring, "Unknown load balance type") patch.Reset() - patch2 := gomonkey.ApplyFuncReturn((*transport.GatewayServer).RegisterService, fmt.Errorf("register error")) + patch2 := gomonkey.ApplyFuncReturn((*gateway.GatewayServer).RegisterService, fmt.Errorf("register error")) defer patch2.Reset() err = watcher.Handle(context.TODO(), mvccpb.PUT, cnf.GetServiceKey(epId), string(val)) @@ -374,7 +374,7 @@ func testWatcherDel(t *testing.T) { err = watcher.Handle(context.TODO(), mvccpb.DELETE, cnf.GetServiceKey(epId), "{}") c.So(err, c.ShouldNotBeNil) - patch := gomonkey.ApplyFuncReturn((*transport.HttpEndpointImpl).DeleteEndpoint, fmt.Errorf("unregister error")) + patch := gomonkey.ApplyFuncReturn((*gateway.HttpEndpointImpl).DeleteEndpoint, fmt.Errorf("unregister error")) defer patch.Reset() err = watcher.Handle(context.TODO(), mvccpb.DELETE, cnf.GetServiceKey(epId), string(val)) c.So(err, c.ShouldNotBeNil) diff --git a/internal/biz/endpoint/utils.go b/internal/biz/endpoint/utils.go index 8c63e98..bcd5d4e 100644 --- a/internal/biz/endpoint/utils.go +++ b/internal/biz/endpoint/utils.go @@ -5,29 +5,29 @@ import ( "path/filepath" "strings" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/errors" "github.com/begonia-org/begonia/internal/pkg/routers" - "github.com/begonia-org/begonia/transport" + gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "google.golang.org/grpc/codes" ) -func deleteAll(ctx context.Context, pd transport.ProtobufDescription) error { +func deleteAll(ctx context.Context, pd gateway.ProtobufDescription) error { routersList := routers.Get() routersList.DeleteRouters(pd) - gw := transport.Get() + gw := gateway.Get() gw.DeleteLoadBalance(pd) - err := transport.Get().DeleteHandlerClient(ctx, pd) + err := gateway.Get().DeleteHandlerClient(ctx, pd) return err } -func getDescriptorSet(config *config.Config, key string, value []byte) (transport.ProtobufDescription, error) { +func getDescriptorSet(config *config.Config, key string, value []byte) (gateway.ProtobufDescription, error) { key = getEndpointId(config, key) outDir := config.GetGatewayDescriptionOut() - pd, err := transport.NewDescriptionFromBinary(value, filepath.Join(outDir, key)) + pd, err := gateway.NewDescriptionFromBinary(value, filepath.Join(outDir, key)) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "new_description_from_binary") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "new_description_from_binary") } return pd, nil } diff --git a/internal/biz/endpoint/utils_test.go b/internal/biz/endpoint/utils_test.go index 973986a..5b381cd 100644 --- a/internal/biz/endpoint/utils_test.go +++ b/internal/biz/endpoint/utils_test.go @@ -6,7 +6,7 @@ package endpoint // "reflect" // "testing" -// "github.com/begonia-org/begonia/transport" +// "github.com/begonia-org/begonia/gateway" // loadbalance "github.com/begonia-org/go-loadbalancer" // api "github.com/begonia-org/go-sdk/api/endpoint/v1" // gwRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -15,18 +15,18 @@ package endpoint // ) // func TestNewEndpoint(t *testing.T) { -// opts := &transport.GrpcServerOptions{ -// Middlewares: make([]transport.GrpcProxyMiddleware, 0), +// opts := &gateway.GrpcServerOptions{ +// Middlewares: make([]gateway.GrpcProxyMiddleware, 0), // Options: make([]grpc.ServerOption, 0), // PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), // HttpMiddlewares: make([]gwRuntime.ServeMuxOption, 0), // HttpHandlers: make([]func(http.Handler) http.Handler, 0), // } -// gwCnf := &transport.GatewayConfig{ +// gwCnf := &gateway.GatewayConfig{ // GatewayAddr: "127.0.0.1:9527", // GrpcProxyAddr: "127.0.0.1:12148", // } -// transport.New(gwCnf, opts) +// gateway.New(gwCnf, opts) // meta := []*api.EndpointMeta{{ // Addr: "127.0.0.1:12138", // Weight: 0, @@ -50,7 +50,7 @@ package endpoint // name: string(loadbalance.RRBalanceType), // endpoints: meta, // exceptErr: nil, -// exceptEndpointType: reflect.TypeOf(transport.NewGrpcEndpoint("", nil)).Elem(), +// exceptEndpointType: reflect.TypeOf(gateway.NewGrpcEndpoint("", nil)).Elem(), // }, // { // name: string(loadbalance.WRRBalanceType), @@ -68,7 +68,7 @@ package endpoint // name: string(loadbalance.ConsistentHashBalanceType), // endpoints: meta, // exceptErr: nil, -// exceptEndpointType: reflect.TypeOf(transport.NewGrpcEndpoint("", nil)).Elem(), +// exceptEndpointType: reflect.TypeOf(gateway.NewGrpcEndpoint("", nil)).Elem(), // }, // { // name: string(loadbalance.LCBalanceType), diff --git a/internal/biz/endpoint/watcher.go b/internal/biz/endpoint/watcher.go index 5bc4dfe..e407c02 100644 --- a/internal/biz/endpoint/watcher.go +++ b/internal/biz/endpoint/watcher.go @@ -6,12 +6,13 @@ import ( "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/routers" + gosdk "github.com/begonia-org/go-sdk" "go.etcd.io/etcd/api/v3/mvccpb" "encoding/json" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/errors" - "github.com/begonia-org/begonia/transport" loadbalance "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -33,33 +34,33 @@ func (g *EndpointWatcher) Update(ctx context.Context, key string, value string) routersList := routers.NewHttpURIRouteToSrvMethod() err := json.Unmarshal([]byte(value), endpoint) if err != nil { - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_endpoint") } pd, err := getDescriptorSet(g.config, key, endpoint.DescriptorSet) if err != nil { - transport.Log.Errorf(ctx,"get descriptor set error: %s", err.Error()) - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_descriptor_set") + gateway.Log.Errorf(ctx, "get descriptor set error: %s", err.Error()) + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_descriptor_set") } err = deleteAll(ctx, pd) if err != nil { - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete_descriptor") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete_descriptor") } - eps, err := transport.NewLoadBalanceEndpoint(loadbalance.BalanceType(endpoint.Balance), endpoint.GetEndpoints()) + eps, err := gateway.NewLoadBalanceEndpoint(loadbalance.BalanceType(endpoint.Balance), endpoint.GetEndpoints()) if err != nil { - return errors.New(errors.ErrUnknownLoadBalancer, int32(api.EndpointSvrStatus_NOT_SUPPORT_BALANCE), codes.InvalidArgument, "new_endpoint") + return gosdk.NewError(errors.ErrUnknownLoadBalancer, int32(api.EndpointSvrStatus_NOT_SUPPORT_BALANCE), codes.InvalidArgument, "new_endpoint") } lb, err := loadbalance.New(loadbalance.BalanceType(endpoint.Balance), eps) if err != nil { - return errors.New(fmt.Errorf("new loadbalance error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "new_loadbalance") + return gosdk.NewError(fmt.Errorf("new loadbalance error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "new_loadbalance") } // register routers // log.Print("register router") routersList.LoadAllRouters(pd) // register service to gateway - gw := transport.Get() + gw := gateway.Get() err = gw.RegisterService(ctx, pd, lb) if err != nil { - return errors.New(fmt.Errorf("register service error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "register_service") + return gosdk.NewError(fmt.Errorf("register service error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "register_service") } // err = g.repo.PutTags(ctx, endpoint.Key, endpoint.Tags) @@ -69,15 +70,15 @@ func (g *EndpointWatcher) del(ctx context.Context, key string, value string) err endpoint := &api.Endpoints{} err := json.Unmarshal([]byte(value), endpoint) if err != nil { - return errors.New(err, int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "unmarshal_endpoint") + return gosdk.NewError(err, int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "unmarshal_endpoint") } pd, err := getDescriptorSet(g.config, key, endpoint.DescriptorSet) if err != nil { - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_descriptor_set") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_descriptor_set") } err = deleteAll(ctx, pd) if err != nil { - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete_descriptor") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "delete_descriptor") } return nil } @@ -89,7 +90,7 @@ func (g *EndpointWatcher) Handle(ctx context.Context, op mvccpb.Event_EventType, case mvccpb.DELETE: return g.del(ctx, key, value) default: - return errors.New(fmt.Errorf("unknown operation"), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unknown_operation") + return gosdk.NewError(fmt.Errorf("unknown operation"), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unknown_operation") } } diff --git a/internal/biz/endpoint/watcher_test.go b/internal/biz/endpoint/watcher_test.go index 187da37..e6334cc 100644 --- a/internal/biz/endpoint/watcher_test.go +++ b/internal/biz/endpoint/watcher_test.go @@ -3,10 +3,10 @@ package endpoint_test import ( "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz/endpoint" "github.com/begonia-org/begonia/internal/data" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/transport" ) func newWatcher() *endpoint.EndpointWatcher { @@ -16,8 +16,6 @@ func newWatcher() *endpoint.EndpointWatcher { } conf := config.ReadConfig(env) cnf := cfg.NewConfig(conf) - repo := data.NewEndpointRepo(conf, transport.Log) - return endpoint.NewWatcher(cnf,repo) + repo := data.NewEndpointRepo(conf, gateway.Log) + return endpoint.NewWatcher(cnf, repo) } - - diff --git a/internal/biz/file/file.go b/internal/biz/file/file.go index d361a66..d00281a 100644 --- a/internal/biz/file/file.go +++ b/internal/biz/file/file.go @@ -15,6 +15,7 @@ import ( "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/errors" + gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/file/v1" user "github.com/begonia-org/go-sdk/api/user/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -49,7 +50,7 @@ func (f *FileUsecase) getPartsDir(key string) string { } func (f *FileUsecase) spiltKey(key string) (string, string, error) { if strings.HasPrefix(key, "/") { - return "", "", errors.New(errors.ErrInvalidFileKey, int32(api.FileSvrStatus_FILE_INVALIDATE_KEY_ERR), codes.InvalidArgument, "invalid_key") + return "", "", gosdk.NewError(errors.ErrInvalidFileKey, int32(api.FileSvrStatus_FILE_INVALIDATE_KEY_ERR), codes.InvalidArgument, "invalid_key") } if strings.Contains(key, "/") { name := filepath.Base(key) @@ -60,7 +61,7 @@ func (f *FileUsecase) spiltKey(key string) (string, string, error) { } func (f *FileUsecase) InitiateUploadFile(ctx context.Context, in *api.InitiateMultipartUploadRequest) (*api.InitiateMultipartUploadResponse, error) { if in.Key == "" || strings.HasPrefix(in.Key, "/") { - return nil, errors.New(errors.ErrInvalidFileKey, int32(api.FileSvrStatus_FILE_INVALIDATE_KEY_ERR), codes.InvalidArgument, "invalid_key") + return nil, gosdk.NewError(errors.ErrInvalidFileKey, int32(api.FileSvrStatus_FILE_INVALIDATE_KEY_ERR), codes.InvalidArgument, "invalid_key") } uploadId := f.snowflake.GenerateIDString() _, _, err := f.spiltKey(in.Key) @@ -69,7 +70,7 @@ func (f *FileUsecase) InitiateUploadFile(ctx context.Context, in *api.InitiateMu } saveDir := f.getPartsDir(uploadId) if err := os.MkdirAll(saveDir, 0755); err != nil { - err = errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_upload_dir") + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_upload_dir") return nil, err } return &api.InitiateMultipartUploadResponse{ @@ -175,10 +176,10 @@ func (f *FileUsecase) getSaveDir(key string) string { // The key is not allow start with '/'. func (f *FileUsecase) checkIn(key string) (string, error) { if key == "" || strings.HasPrefix(key, "/") { - return "", errors.New(errors.ErrInvalidFileKey, int32(api.FileSvrStatus_FILE_INVALIDATE_KEY_ERR), codes.InvalidArgument, "invalid_key") + return "", gosdk.NewError(errors.ErrInvalidFileKey, int32(api.FileSvrStatus_FILE_INVALIDATE_KEY_ERR), codes.InvalidArgument, "invalid_key") } // if authorId == "" { - // return "", errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + // return "", gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") // } // if !strings.HasPrefix(key, authorId) { // key = authorId + "/" + key @@ -193,7 +194,7 @@ func (f *FileUsecase) checkIn(key string) (string, error) { // The authorId is used to determine the directory which is as user's home dir where the file is saved. func (f *FileUsecase) Upload(ctx context.Context, in *api.UploadFileRequest, authorId string) (*api.UploadFileResponse, error) { if authorId == "" { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } key, err := f.checkIn(in.Key) @@ -207,12 +208,12 @@ func (f *FileUsecase) Upload(ctx context.Context, in *api.UploadFileRequest, aut saveDir := f.getSaveDir(in.Key) if err := os.MkdirAll(saveDir, 0755); err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_upload_dir") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_upload_dir") } filePath := filepath.Join(saveDir, filename) file, err := os.Create(filePath) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_file") } defer file.Close() defer func() { @@ -222,7 +223,7 @@ func (f *FileUsecase) Upload(ctx context.Context, in *api.UploadFileRequest, aut }() _, err = file.Write(in.Content) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "write_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "write_file") } uri, err := f.getUri(filePath) if err != nil { @@ -231,7 +232,7 @@ func (f *FileUsecase) Upload(ctx context.Context, in *api.UploadFileRequest, aut sha256Hash := getSHA256(in.Content) if sha256Hash != in.Sha256 { os.Remove(filePath) - err = errors.New(errors.ErrSHA256NotMatch, int32(api.FileSvrStatus_FILE_SHA256_NOT_MATCH_ERR), codes.InvalidArgument, "sha256_not_match") + err = gosdk.NewError(errors.ErrSHA256NotMatch, int32(api.FileSvrStatus_FILE_SHA256_NOT_MATCH_ERR), codes.InvalidArgument, "sha256_not_match") return nil, err } @@ -239,7 +240,7 @@ func (f *FileUsecase) Upload(ctx context.Context, in *api.UploadFileRequest, aut if in.UseVersion { commitId, err = f.commitFile(saveDir, filename, authorId, "fs@begonia.com") if err != nil { - err = errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "commit_file") + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "commit_file") return nil, err } @@ -254,10 +255,10 @@ func (f *FileUsecase) Upload(ctx context.Context, in *api.UploadFileRequest, aut func (f *FileUsecase) UploadMultipartFileFile(ctx context.Context, in *api.UploadMultipartFileRequest) (*api.UploadMultipartFileResponse, error) { if in.UploadId == "" { - return nil, errors.New(errors.ErrUploadIdMissing, int32(api.FileSvrStatus_FILE_UPLOADID_MISSING_ERR), codes.InvalidArgument, "upload_id_not_found") + return nil, gosdk.NewError(errors.ErrUploadIdMissing, int32(api.FileSvrStatus_FILE_UPLOADID_MISSING_ERR), codes.InvalidArgument, "upload_id_not_found") } if in.PartNumber <= 0 { - return nil, errors.New(errors.ErrPartNumberMissing, int32(api.FileSvrStatus_FILE_PARTNUMBER_MISSING_ERR), codes.InvalidArgument, "part_number_not_found") + return nil, gosdk.NewError(errors.ErrPartNumberMissing, int32(api.FileSvrStatus_FILE_PARTNUMBER_MISSING_ERR), codes.InvalidArgument, "part_number_not_found") } @@ -265,7 +266,7 @@ func (f *FileUsecase) UploadMultipartFileFile(ctx context.Context, in *api.Uploa // get upload dir by uploadId saveDir := f.getPartsDir(uploadId) if !pathExists(saveDir) { - err := errors.New(errors.ErrUploadNotInitiate, int32(api.FileSvrStatus_FILE_UPLOAD_NOT_INITIATE_ERR), codes.NotFound, "upload_dir_not_found") + err := gosdk.NewError(errors.ErrUploadNotInitiate, int32(api.FileSvrStatus_FILE_UPLOAD_NOT_INITIATE_ERR), codes.NotFound, "upload_dir_not_found") return nil, err } @@ -273,18 +274,18 @@ func (f *FileUsecase) UploadMultipartFileFile(ctx context.Context, in *api.Uploa file, err := os.Create(filePath) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_file") } defer file.Close() _, err = file.Write(in.Content) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "write_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "write_file") } sha256Hash := getSHA256(in.Content) if sha256Hash != in.Sha256 { os.Remove(filePath) - err := errors.New(errors.ErrSHA256NotMatch, int32(api.FileSvrStatus_FILE_SHA256_NOT_MATCH_ERR), codes.InvalidArgument, "sha256_not_match") + err := gosdk.NewError(errors.ErrSHA256NotMatch, int32(api.FileSvrStatus_FILE_SHA256_NOT_MATCH_ERR), codes.InvalidArgument, "sha256_not_match") return nil, err } @@ -365,7 +366,7 @@ func (f *FileUsecase) getUri(filePath string) (string, error) { log.Printf("uploadRootDir:%s,filePath:%s", uploadRootDir, filePath) uri, err := filepath.Rel(uploadRootDir, filePath) if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_file_uri") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_file_uri") } return uri, nil @@ -373,19 +374,19 @@ func (f *FileUsecase) getUri(filePath string) (string, error) { func (f *FileUsecase) AbortMultipartUpload(ctx context.Context, in *api.AbortMultipartUploadRequest) (*api.AbortMultipartUploadResponse, error) { partsDir := f.getPartsDir(in.UploadId) if !pathExists(partsDir) { - err := errors.New(errors.ErrUploadIdNotFound, int32(api.FileSvrStatus_FILE_NOT_FOUND_UPLOADID_ERR), codes.NotFound, "upload_id_not_found") + err := gosdk.NewError(errors.ErrUploadIdNotFound, int32(api.FileSvrStatus_FILE_NOT_FOUND_UPLOADID_ERR), codes.NotFound, "upload_id_not_found") return nil, err } err := os.RemoveAll(partsDir) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "remove_parts_dir") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "remove_parts_dir") } return &api.AbortMultipartUploadResponse{}, nil } func (f *FileUsecase) CompleteMultipartUploadFile(ctx context.Context, in *api.CompleteMultipartUploadRequest, authorId string) (*api.CompleteMultipartUploadResponse, error) { if authorId == "" { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } key, err := f.checkIn(in.Key) if err != nil { @@ -394,13 +395,13 @@ func (f *FileUsecase) CompleteMultipartUploadFile(ctx context.Context, in *api.C in.Key = filepath.Join(authorId, key) partsDir := f.getPartsDir(in.UploadId) if !pathExists(partsDir) { - err := errors.New(errors.ErrUploadIdNotFound, int32(api.FileSvrStatus_FILE_NOT_FOUND_UPLOADID_ERR), codes.NotFound, "upload_id_not_found") + err := gosdk.NewError(errors.ErrUploadIdNotFound, int32(api.FileSvrStatus_FILE_NOT_FOUND_UPLOADID_ERR), codes.NotFound, "upload_id_not_found") return nil, err } files, err := f.getSortedFiles(partsDir) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_sorted_files") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_sorted_files") } @@ -410,16 +411,16 @@ func (f *FileUsecase) CompleteMultipartUploadFile(ctx context.Context, in *api.C // merge files to uploadDir/key err = f.mergeFiles(files, filePath) if err != nil { - return nil, errors.New(fmt.Errorf("merge file error"), int32(common.Code_INTERNAL_ERROR), codes.Internal, "merge_files") + return nil, gosdk.NewError(fmt.Errorf("merge file error"), int32(common.Code_INTERNAL_ERROR), codes.Internal, "merge_files") } // the parts file has been merged, remove the parts dir to uploadDir/parts/key keyParts := f.getPersistenceKeyParts(in.Key) if err = os.MkdirAll(keyParts, 0755); err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_parts_dir") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "create_parts_dir") } err = f.mvDir(partsDir, keyParts) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "mv_dir") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "mv_dir") } uri, err := f.getUri(filePath) @@ -431,7 +432,7 @@ func (f *FileUsecase) CompleteMultipartUploadFile(ctx context.Context, in *api.C if in.UseVersion { commit, err = f.commitFile(saveDir, filename, authorId, "begonia@begonia.com") if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "commit_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "commit_file") } } @@ -450,17 +451,17 @@ func (f *FileUsecase) DownloadForRange(ctx context.Context, in *api.DownloadRequ } in.Key = key if start > end { - err := errors.New(errors.ErrInvalidRange, int32(api.FileSvrStatus_FILE_INVALIDATE_RANGE_ERR), codes.InvalidArgument, "invalid_range") + err := gosdk.NewError(errors.ErrInvalidRange, int32(api.FileSvrStatus_FILE_INVALIDATE_RANGE_ERR), codes.InvalidArgument, "invalid_range") return nil, 0, err } file, err := f.getReader(in.Key, in.Version) if err == git.ErrRepositoryNotExists || os.IsNotExist(err) { - return nil, 0, errors.New(err, int32(common.Code_NOT_FOUND), codes.NotFound, "file_not_found") + return nil, 0, gosdk.NewError(err, int32(common.Code_NOT_FOUND), codes.NotFound, "file_not_found") } if err != nil { - return nil, 0, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "open_file") + return nil, 0, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "open_file") } defer file.Close() @@ -473,7 +474,7 @@ func (f *FileUsecase) DownloadForRange(ctx context.Context, in *api.DownloadRequ // log.Printf("start:%d,end:%d,bufsize:%d", start, end, len(buf)) _, err = file.ReadAt(buf, start) if err != nil && err != io.EOF { - err = errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") return nil, 0, err } return buf, file.Size(), nil @@ -488,19 +489,19 @@ func (f *FileUsecase) Metadata(ctx context.Context, in *api.FileMetadataRequest, in.Key = key file, err := f.getReader(in.Key, in.Version) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "open_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "open_file") } hasher := sha256.New() reader, err := file.Reader() if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") } defer reader.Close() defer file.Close() // 以流式传输的方式将文件内容写入哈希对象 if _, err := io.Copy(hasher, reader); err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "hash_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "hash_file") } // 计算最终的哈希值 @@ -564,10 +565,10 @@ func (f *FileUsecase) Version(ctx context.Context, key, authorId string) (string // fileDir := filepath.Join(f.config.GetUploadDir(), in.Key) file, err := f.getReader(key, "latest") if err == git.ErrRepositoryNotExists { - return "", errors.New(err, int32(common.Code_NOT_FOUND), codes.NotFound, "file_not_found") + return "", gosdk.NewError(err, int32(common.Code_NOT_FOUND), codes.NotFound, "file_not_found") } if err != nil { - return "", errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "open_file") + return "", gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "open_file") } defer file.Close() return file.(FileVersionReader).Version(), nil @@ -591,18 +592,18 @@ func (f *FileUsecase) Download(ctx context.Context, in *api.DownloadRequest, aut file, err := f.getReader(in.Key, in.Version) if err != nil { code, httpCode := f.checkStatusCode(err) - return nil, errors.New(err, code, httpCode, "open_file") + return nil, gosdk.NewError(err, code, httpCode, "open_file") } buf := make([]byte, file.Size()) reader, err := file.Reader() if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") } defer reader.Close() defer file.Close() _, err = reader.Read(buf) if err != nil && err != io.EOF { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "read_file") } return buf, nil @@ -617,7 +618,7 @@ func (f *FileUsecase) Delete(ctx context.Context, in *api.DeleteRequest, authorI file, err := f.getReader(in.Key, "") if err != nil && !os.IsNotExist(err) { // log.Printf("err:%v", err) - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "remove_file") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "remove_file") } if file != nil { @@ -628,7 +629,7 @@ func (f *FileUsecase) Delete(ctx context.Context, in *api.DeleteRequest, authorI if err != nil { // log.Printf("version err:%v", err) code, rpcCode := f.checkStatusCode(err) - return nil, errors.New(err, code, rpcCode, "remove_file") + return nil, gosdk.NewError(err, code, rpcCode, "remove_file") } defer versionFile.Close() os.Remove(versionFile.Name()) @@ -636,7 +637,7 @@ func (f *FileUsecase) Delete(ctx context.Context, in *api.DeleteRequest, authorI keyParts := f.getPersistenceKeyParts(in.Key) err = os.RemoveAll(keyParts) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "remove_parts_dir") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "remove_parts_dir") } return &api.DeleteResponse{}, nil } diff --git a/internal/biz/user.go b/internal/biz/user.go index ff516aa..a362422 100644 --- a/internal/biz/user.go +++ b/internal/biz/user.go @@ -6,7 +6,7 @@ import ( "time" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/internal/pkg/errors" + gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/user/v1" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/redis/go-redis/v9" @@ -40,9 +40,9 @@ func (u *UserUsecase) Add(ctx context.Context, users *api.Users) (err error) { if err != nil { // log.Println(err) if strings.Contains(err.Error(), "Duplicate entry") { - err = errors.New(err, int32(api.UserSvrCode_USER_USERNAME_DUPLICATE_ERR), codes.AlreadyExists, "commit_app") + err = gosdk.NewError(err, int32(api.UserSvrCode_USER_USERNAME_DUPLICATE_ERR), codes.AlreadyExists, "commit_app") } else { - err = errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "cache_apps") + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "cache_apps") } } @@ -55,7 +55,7 @@ func (u *UserUsecase) Add(ctx context.Context, users *api.Users) (err error) { func (u *UserUsecase) Get(ctx context.Context, key string) (*api.Users, error) { user, err := u.repo.Get(ctx, key) if err != nil { - return nil, errors.New(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") + return nil, gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") } return user, nil } @@ -63,12 +63,12 @@ func (u *UserUsecase) Update(ctx context.Context, model *api.Users) error { err := u.repo.Patch(ctx, model) if err != nil { if strings.Contains(err.Error(), "not found") { - return errors.New(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") + return gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") } if strings.Contains(err.Error(), "Duplicate entry") { - return errors.New(err, int32(api.UserSvrCode_USER_USERNAME_DUPLICATE_ERR), codes.AlreadyExists, "patch_app") + return gosdk.NewError(err, int32(api.UserSvrCode_USER_USERNAME_DUPLICATE_ERR), codes.AlreadyExists, "patch_app") } - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") } return nil } @@ -76,10 +76,10 @@ func (u *UserUsecase) Delete(ctx context.Context, uid string) error { err := u.repo.Del(ctx, uid) if err != nil { if strings.Contains(err.Error(), "not found") { - return errors.New(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") + return gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "get_user") } - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_user") } return nil } @@ -88,10 +88,9 @@ func (u *UserUsecase) List(ctx context.Context, dept []string, status []api.USER users, err := u.repo.List(ctx, dept, status, page, pageSize) if err != nil { if strings.Contains(err.Error(), "not found") { - return nil, errors.New(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "list_users") + return nil, gosdk.NewError(err, int32(api.UserSvrCode_USER_NOT_FOUND_ERR), codes.NotFound, "list_users") } - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "list_user") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "list_user") } return users, nil } - diff --git a/internal/biz/user_test.go b/internal/biz/user_test.go index 20753af..0456be8 100644 --- a/internal/biz/user_test.go +++ b/internal/biz/user_test.go @@ -9,12 +9,13 @@ import ( "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/biz" "github.com/begonia-org/begonia/internal/data" cfg "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/transport" api "github.com/begonia-org/go-sdk/api/user/v1" c "github.com/smartystreets/goconvey/convey" + "github.com/spark-lence/tiga" "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -31,7 +32,7 @@ func newUserBiz() *biz.UserUsecase { env = begonia.Env } config := config.ReadConfig(env) - repo := data.NewUserRepo(config, transport.Log) + repo := data.NewUserRepo(config, gateway.Log) cnf := cfg.NewConfig(config) return biz.NewUserUsecase(repo, cnf) } diff --git a/internal/data/app_test.go b/internal/data/app_test.go index c4ff689..e2da3d2 100644 --- a/internal/data/app_test.go +++ b/internal/data/app_test.go @@ -12,10 +12,9 @@ import ( cfg "github.com/begonia-org/begonia/config" "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/utils" - "github.com/begonia-org/begonia/transport" api "github.com/begonia-org/go-sdk/api/app/v1" - "github.com/cockroachdb/errors" c "github.com/smartystreets/goconvey/convey" + "github.com/begonia-org/begonia/gateway" "github.com/spark-lence/tiga" "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -47,7 +46,7 @@ func addTest(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) access, _ := generateRandomString(32) accessKey = access @@ -76,7 +75,7 @@ func addTest(t *testing.T) { value, err = layered.GetFromLocal(context.Background(), cacheKey) c.So(err, c.ShouldBeNil) c.So(string(value), c.ShouldEqual, secret) - patch := gomonkey.ApplyFuncReturn((*LayeredCache).Get, nil, errors.New("error")) + patch := gomonkey.ApplyFuncReturn((*LayeredCache).Get, nil, fmt.Errorf("error")) defer patch.Reset() val, err := repo.GetSecret(context.Background(), access) @@ -92,7 +91,7 @@ func getTest(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) app, err := repo.Get(context.TODO(), appid) c.So(err, c.ShouldBeNil) c.So(app.Appid, c.ShouldEqual, appid) @@ -112,7 +111,7 @@ func duplicateNameTest(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) // snk, _ := tiga.NewSnowflake(1) access, _ := generateRandomString(32) accessKey = access @@ -141,7 +140,7 @@ func patchTest(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) err := repo.Patch(context.TODO(), &api.Apps{ Appid: appid, AccessKey: accessKey, @@ -178,7 +177,7 @@ func testListApp(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) rand := rand.New(rand.NewSource(time.Now().UnixNano())) status := []api.APPStatus{api.APPStatus_APP_ENABLED, api.APPStatus_APP_DISABLED} @@ -235,7 +234,7 @@ func delTest(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) err := repo.Del(context.TODO(), appid) c.So(err, c.ShouldBeNil) _, err = repo.Get(context.TODO(), appid) diff --git a/internal/data/authz_test.go b/internal/data/authz_test.go index 23a5d06..a684119 100644 --- a/internal/data/authz_test.go +++ b/internal/data/authz_test.go @@ -7,9 +7,9 @@ import ( "github.com/begonia-org/begonia" cfg "github.com/begonia-org/begonia/config" - "github.com/begonia-org/begonia/transport" c "github.com/smartystreets/goconvey/convey" "github.com/spark-lence/tiga" + "github.com/begonia-org/begonia/gateway" ) var token = "" @@ -21,7 +21,7 @@ func testCacheToken(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAuthzRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAuthzRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) token = snk.GenerateIDString() err := repo.CacheToken(context.TODO(), "test:token", token, 5*time.Second) @@ -33,7 +33,7 @@ func testGetToken(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAuthzRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAuthzRepo(cfg.ReadConfig(env), gateway.Log) c.Convey("test get token", t, func() { tk := repo.GetToken(context.TODO(), "test:token") @@ -52,7 +52,7 @@ func deleteToken(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAuthzRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAuthzRepo(cfg.ReadConfig(env), gateway.Log) c.Convey("test delete exp token", t, func() { err := repo.DelToken(context.TODO(), "test:token") c.So(err, c.ShouldBeNil) @@ -73,7 +73,7 @@ func testPutBlacklist(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAuthzRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAuthzRepo(cfg.ReadConfig(env), gateway.Log) c.Convey("test put blacklist", t, func() { err := repo.PutBlackList(context.TODO(), token) c.So(err, c.ShouldBeNil) @@ -84,7 +84,7 @@ func testCheckInBlackList(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAuthzRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAuthzRepo(cfg.ReadConfig(env), gateway.Log) c.Convey("test check in blacklist", t, func() { b, err := repo.CheckInBlackList(context.TODO(), token) diff --git a/internal/data/data.go b/internal/data/data.go index 4beaabe..5c4aaa3 100644 --- a/internal/data/data.go +++ b/internal/data/data.go @@ -160,7 +160,7 @@ func (d *Data) BatchUpdates(ctx context.Context, models []SourceType) error { if size == 1 { if err != nil { - return errors.New("获取第一个元素失败") + return fmt.Errorf("获取第一个元素失败") } return d.db.Update(ctx, model, model) } @@ -265,7 +265,6 @@ func (d *Data) DelCacheByTx(ctx context.Context, keys ...string) redis.Pipeliner return pipe } - func (d *Data) BatchEtcdDelete(models []SourceType) error { if len(models) == 0 { return nil diff --git a/internal/data/endpoint_test.go b/internal/data/endpoint_test.go index e7f9b6c..708586a 100644 --- a/internal/data/endpoint_test.go +++ b/internal/data/endpoint_test.go @@ -13,9 +13,9 @@ import ( "github.com/begonia-org/begonia" cfg "github.com/begonia-org/begonia/config" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/transport" goloadbalancer "github.com/begonia-org/go-loadbalancer" api "github.com/begonia-org/go-sdk/api/endpoint/v1" + "github.com/begonia-org/begonia/gateway" c "github.com/smartystreets/goconvey/convey" "github.com/spark-lence/tiga" "google.golang.org/protobuf/types/known/timestamppb" @@ -33,10 +33,10 @@ func putTest(t *testing.T) { env = begonia.Env } _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata", "helloworld.pb") pb, _ := os.ReadFile(pbFile) conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) snk, _ := tiga.NewSnowflake(1) endpointId = snk.GenerateIDString() err := repo.Put(context.Background(), &api.Endpoints{ @@ -75,7 +75,7 @@ func getEndpointTest(t *testing.T) { env = begonia.Env } conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) cnf := config.NewConfig(conf) endpointKey := cnf.GetServiceKey(endpointId) data, err := repo.Get(context.Background(), endpointKey) @@ -96,7 +96,7 @@ func getKeysByTagsTest(t *testing.T) { env = begonia.Env } conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) keys, err := repo.GetKeysByTags(context.Background(), []string{tag}) c.So(err, c.ShouldBeNil) c.So(keys, c.ShouldNotBeEmpty) @@ -111,10 +111,10 @@ func testList(t *testing.T) { env = begonia.Env } _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filepath.Dir(filename)), "integration", "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata", "helloworld.pb") pb, _ := os.ReadFile(pbFile) conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) snk, _ := tiga.NewSnowflake(1) enps := make([]string, 0) c.Convey("test list", t, func() { @@ -168,7 +168,7 @@ func patchEndpointTest(t *testing.T) { env = begonia.Env } conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) cnf := config.NewConfig(conf) endpointKey := cnf.GetServiceKey(endpointId) tag1 := fmt.Sprintf("test-data-patch-%s", time.Now().Format("20060102150405")) @@ -208,7 +208,7 @@ func delEndpointTest(t *testing.T) { env = begonia.Env } conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) cnf := config.NewConfig(conf) endpointKey := cnf.GetServiceKey(endpointId) err := repo.Del(context.Background(), endpointId) @@ -233,7 +233,7 @@ func putTagsTest(t *testing.T) { env = begonia.Env } conf := cfg.ReadConfig(env) - repo := NewEndpointRepo(conf, transport.Log) + repo := NewEndpointRepo(conf, gateway.Log) cnf := config.NewConfig(conf) endpointKey := cnf.GetServiceKey(endpointId) tag1 := fmt.Sprintf("test1-data-%s", time.Now().Format("20060102150405")) diff --git a/internal/data/operator_test.go b/internal/data/operator_test.go index 328fdaf..c280af2 100644 --- a/internal/data/operator_test.go +++ b/internal/data/operator_test.go @@ -9,7 +9,7 @@ import ( "github.com/begonia-org/begonia" cfg "github.com/begonia-org/begonia/config" - "github.com/begonia-org/begonia/transport" + "github.com/begonia-org/begonia/gateway" appApi "github.com/begonia-org/go-sdk/api/app/v1" api "github.com/begonia-org/go-sdk/api/user/v1" c "github.com/smartystreets/goconvey/convey" @@ -28,7 +28,7 @@ func testGetAllForbiddenUsers(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewUserRepo(cfg.ReadConfig(env), transport.Log) + repo := NewUserRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) // rand.Seed(time.Now().UnixNano()) rand := rand.New(rand.NewSource(time.Now().UnixNano())) @@ -52,7 +52,7 @@ func testGetAllForbiddenUsers(t *testing.T) { t.Errorf("add user error:%v", err) } } - operator := NewOperator(cfg.ReadConfig(env), transport.Log) + operator := NewOperator(cfg.ReadConfig(env), gateway.Log) var err error users, err = operator.GetAllForbiddenUsers(context.Background()) c.So(err, c.ShouldBeNil) @@ -73,7 +73,7 @@ func testFlashUsersCache(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - operator := NewOperator(cfg.ReadConfig(env), transport.Log) + operator := NewOperator(cfg.ReadConfig(env), gateway.Log) // operator.(*dataOperatorRepo).local.OnStart() operator.OnStart() lock, err := operator.Locker(context.Background(), "test:user:blacklist:lock", 3*time.Second) @@ -109,7 +109,7 @@ func testGetAllApp(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewAppRepo(cfg.ReadConfig(env), transport.Log) + repo := NewAppRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) for i := 0; i < 10; i++ { appAccess, _ := generateRandomString(32) @@ -128,7 +128,7 @@ func testGetAllApp(t *testing.T) { c.So(err, c.ShouldBeNil) } - operator := NewOperator(cfg.ReadConfig(env), transport.Log) + operator := NewOperator(cfg.ReadConfig(env), gateway.Log) var err error apps, err = operator.GetAllApps(context.Background()) c.So(err, c.ShouldBeNil) @@ -142,7 +142,7 @@ func testFlashAppsCache(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - operator := NewOperator(cfg.ReadConfig(env), transport.Log) + operator := NewOperator(cfg.ReadConfig(env), gateway.Log) err := operator.FlashAppsCache(context.Background(), "test:app:blacklist", apps, 10*time.Second) c.So(err, c.ShouldBeNil) isOK := true @@ -164,7 +164,7 @@ func testLastUpdated(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - operator := NewOperator(cfg.ReadConfig(env), transport.Log) + operator := NewOperator(cfg.ReadConfig(env), gateway.Log) t, err := operator.LastUpdated(context.Background(), "test:user:blacklist") c.So(err, c.ShouldBeNil) @@ -177,7 +177,7 @@ func testWatcher(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - operator := NewOperator(cfg.ReadConfig(env), transport.Log) + operator := NewOperator(cfg.ReadConfig(env), gateway.Log) updated := "" deleted := "" go func() { diff --git a/internal/data/user_test.go b/internal/data/user_test.go index 6560758..7ff7a1b 100644 --- a/internal/data/user_test.go +++ b/internal/data/user_test.go @@ -9,8 +9,8 @@ import ( "github.com/begonia-org/begonia" cfg "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/config" - "github.com/begonia-org/begonia/transport" api "github.com/begonia-org/go-sdk/api/user/v1" c "github.com/smartystreets/goconvey/convey" "github.com/spark-lence/tiga" @@ -32,7 +32,7 @@ func testAddUser(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewUserRepo(cfg.ReadConfig(env), transport.Log) + repo := NewUserRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) uid = snk.GenerateIDString() err := repo.Add(context.TODO(), &api.Users{ @@ -92,7 +92,7 @@ func testGetUser(t *testing.T) { env = begonia.Env } conf := cfg.ReadConfig(env) - repo := NewUserRepo(conf, transport.Log) + repo := NewUserRepo(conf, gateway.Log) user, err := repo.Get(context.TODO(), uid) c.So(err, c.ShouldBeNil) c.So(user, c.ShouldNotBeNil) @@ -121,7 +121,7 @@ func testUpdateUser(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewUserRepo(cfg.ReadConfig(env), transport.Log) + repo := NewUserRepo(cfg.ReadConfig(env), gateway.Log) user, err := repo.Get(context.TODO(), uid) c.So(err, c.ShouldBeNil) c.So(user, c.ShouldNotBeNil) @@ -163,7 +163,7 @@ func testDelUser(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewUserRepo(cfg.ReadConfig(env), transport.Log) + repo := NewUserRepo(cfg.ReadConfig(env), gateway.Log) err := repo.Del(context.TODO(), uid) c.So(err, c.ShouldBeNil) _, err = repo.Get(context.TODO(), uid) @@ -178,7 +178,7 @@ func testListUser(t *testing.T) { if begonia.Env != "" { env = begonia.Env } - repo := NewUserRepo(cfg.ReadConfig(env), transport.Log) + repo := NewUserRepo(cfg.ReadConfig(env), gateway.Log) snk, _ := tiga.NewSnowflake(1) // rand.Seed(time.Now().UnixNano()) rand := rand.New(rand.NewSource(time.Now().UnixNano())) diff --git a/internal/integration/user_test.go b/internal/integration/user_test.go deleted file mode 100644 index 01a28fe..0000000 --- a/internal/integration/user_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package integration_test - -import ( - "context" - "fmt" - "testing" - "time" - - api "github.com/begonia-org/go-sdk/api/user/v1" - "github.com/begonia-org/go-sdk/client" - common "github.com/begonia-org/go-sdk/common/api/v1" - c "github.com/smartystreets/goconvey/convey" -) - -var uid = "" - -func addUser(t *testing.T) { - c.Convey( - "test add user", - t, - func() { - apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) - name:=fmt.Sprintf("user-%s", time.Now().Format("20060102150405")) - rsp, err := apiClient.PostUser(context.Background(), &api.PostUserRequest{ - Name: name, - Password: "123456", - Email: fmt.Sprintf("%s@example.com",name), - Role: api.Role_ADMIN, - Dept: "development", - Avatar: "https://www.example.com/avatar.jpg", - Owner: "test-user-01", - Phone: time.Now().Format("20060102150405"), - }) - c.So(err, c.ShouldBeNil) - c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) - c.So(rsp.Uid, c.ShouldNotBeEmpty) - uid = rsp.Uid - - }) -} - -func getUser(t *testing.T) { - c.Convey( - "test get user", - t, - func() { - apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) - rsp, err := apiClient.GetUser(context.Background(), uid) - c.So(err, c.ShouldBeNil) - c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) - }) -} -func deleteUser(t *testing.T) { - c.Convey( - "test delete user", - t, - func() { - apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) - rsp, err := apiClient.DeleteUser(context.Background(), uid) - c.So(err, c.ShouldBeNil) - c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) - - _, err = apiClient.GetUser(context.Background(), uid) - c.So(err, c.ShouldNotBeNil) - - }) -} -func patchUser(t *testing.T) { - c.Convey( - "test patch user", - t, - func() { - apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) - rsp, err := apiClient.PatchUser(context.Background(), uid, map[string]interface{}{ - "password": "123456ecfasddccddd", - "email": fmt.Sprintf("%s@example.com",time.Now().Format("20060102150405"))}) - c.So(err, c.ShouldBeNil) - c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) - }) -} -func TestUser(t *testing.T) { - t.Run("add user", addUser) - t.Run("get user", getUser) - // uid = "442210231930327040" - t.Run("patch user", patchUser) - t.Run("delete user", deleteUser) -} diff --git a/internal/pkg/errors/constant.go b/internal/pkg/errors/constant.go index 6e3be13..465e00b 100644 --- a/internal/pkg/errors/constant.go +++ b/internal/pkg/errors/constant.go @@ -3,89 +3,83 @@ package errors import ( "errors" "fmt" - "runtime" - - common "github.com/begonia-org/go-sdk/common/api/v1" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/anypb" ) -type SrvError struct { - Err error - ErrCode int32 - GrpcCode codes.Code - Action string -} - -func Is(err error, target error) bool { - return errors.Is(err, target) -} -func As(err error, target interface{}) bool { - return errors.As(err, target) - -} - -type Options func(*common.Errors) - -func WithClientMessage(msg string) Options { - return func(e *common.Errors) { - e.ToClientMessage = msg - } - -} -func New(err error, code int32, grpcCode codes.Code, action string, opts ...Options) error { - pc, file, line, ok := runtime.Caller(1) - if !ok { - file = "unknown" - line = 0 - } - fn := runtime.FuncForPC(pc) - funcName := "unknown" - if fn != nil { - funcName = fn.Name() - } - - srvErr := &common.Errors{ - Code: code, - Message: err.Error(), - Action: action, - File: file, - Line: int32(line), - Fn: funcName, - } - // details,_:=structpb.NewStruct(map[string]interface{}{ - // "code":code, - // "message":err.Error(), - // "action":action, - // "file":file, - // "line":line, - // "fn":funcName, - // }) - for _, opt := range opts { - opt(srvErr) - } - st := status.New(grpcCode, err.Error()) - detailProto, err := anypb.New(srvErr) - if err != nil { - return status.Errorf(codes.Internal, "failed to marshal error details: %v", err) - } - st, err = st.WithDetails(detailProto) - if err != nil { - return status.Errorf(codes.Internal, "failed to marshal error details: %v", err) - - } - return st.Err() - // srvErr:=status.New(0,err.Error()) - // srvErr.WithDetails() - // return nil -} -func (s *SrvError) Error() string { - return fmt.Sprintf("%s|%d", s.Err.Error(), s.ErrCode) -} -func (s *SrvError) Code() int32 { - return s.ErrCode -} +// type SrvError struct { +// Err error +// ErrCode int32 +// GrpcCode codes.Code +// Action string +// } + +// func Is(err error, target error) bool { +// return errors.Is(err, target) +// } +// func As(err error, target interface{}) bool { +// return errors.As(err, target) + +// } + +// type Options func(*common.Errors) + +// func WithClientMessage(msg string) Options { +// return func(e *common.Errors) { +// e.ToClientMessage = msg +// } + +// } +// func New(err error, code int32, grpcCode codes.Code, action string, opts ...Options) error { +// pc, file, line, ok := runtime.Caller(1) +// if !ok { +// file = "unknown" +// line = 0 +// } +// fn := runtime.FuncForPC(pc) +// funcName := "unknown" +// if fn != nil { +// funcName = fn.Name() +// } + +// srvErr := &common.Errors{ +// Code: code, +// Message: err.Error(), +// Action: action, +// File: file, +// Line: int32(line), +// Fn: funcName, +// } +// // details,_:=structpb.NewStruct(map[string]interface{}{ +// // "code":code, +// // "message":err.Error(), +// // "action":action, +// // "file":file, +// // "line":line, +// // "fn":funcName, +// // }) +// for _, opt := range opts { +// opt(srvErr) +// } +// st := status.New(grpcCode, err.Error()) +// detailProto, err := anypb.New(srvErr) +// if err != nil { +// return status.Errorf(codes.Internal, "failed to marshal error details: %v", err) +// } +// st, err = st.WithDetails(detailProto) +// if err != nil { +// return status.Errorf(codes.Internal, "failed to marshal error details: %v", err) + +// } +// return st.Err() +// // srvErr:=status.New(0,err.Error()) +// // srvErr.WithDetails() +// // return nil +// } +// func (s *SrvError) Error() string { +// return fmt.Sprintf("%s|%d", s.Err.Error(), s.ErrCode) +// } +// func (s *SrvError) Code() int32 { +// return s.ErrCode +// } var ( ErrUserNotFound = errors.New("用户不存在") diff --git a/internal/pkg/middleware/auth.go b/internal/pkg/middleware/auth.go index ce7cfaa..e653f75 100644 --- a/internal/pkg/middleware/auth.go +++ b/internal/pkg/middleware/auth.go @@ -39,14 +39,14 @@ func (a *Auth) UnaryInterceptor(ctx context.Context, req any, info *grpc.UnarySe if !ok { return nil, status.Errorf(codes.Unauthenticated, "metadata not exists in context") } - xApiKey:=md.Get("x-api-key") + xApiKey := md.Get("x-api-key") if len(xApiKey) != 0 { return a.apikey.UnaryInterceptor(ctx, req, info, handler) } authorization := a.jwt.GetAuthorizationFromMetadata(md) if authorization == "" { - return nil, errors.New(errors.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + return nil, gosdk.NewError(errors.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } if strings.Contains(authorization, "Bearer") { @@ -71,14 +71,14 @@ func (a *Auth) StreamInterceptor(srv any, ss grpc.ServerStream, info *grpc.Strea if !ok { return status.Errorf(codes.Unauthenticated, "metadata not exists in context") } - xApiKey:=md.Get("x-api-key") + xApiKey := md.Get("x-api-key") if len(xApiKey) != 0 { return a.apikey.StreamInterceptor(srv, ss, info, handler) } authorization := a.jwt.GetAuthorizationFromMetadata(md) if authorization == "" { - return errors.New(errors.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + return gosdk.NewError(errors.ErrTokenMissing, int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } var err error if strings.Contains(authorization, "Bearer") { diff --git a/internal/pkg/middleware/auth/apikey.go b/internal/pkg/middleware/auth/apikey.go index a5f4083..fe50b65 100644 --- a/internal/pkg/middleware/auth/apikey.go +++ b/internal/pkg/middleware/auth/apikey.go @@ -13,7 +13,7 @@ import ( "google.golang.org/grpc/status" ) -type ApiKeyAuth interface{ +type ApiKeyAuth interface { gosdk.LocalPlugin } @@ -53,16 +53,16 @@ func NewApiKeyAuth(config *config.Config) ApiKeyAuth { func (a *ApiKeyAuthImpl) check(ctx context.Context) error { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return errors.New(status.Errorf(codes.Unauthenticated, "metadata not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + return gosdk.NewError(status.Errorf(codes.Unauthenticated, "metadata not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } // authorization := a.GetAuthorizationFromMetadata(md) apikeys := md.Get("x-api-key") if len(apikeys) == 0 { - return errors.New(status.Errorf(codes.Unauthenticated, "apikey not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") + return gosdk.NewError(status.Errorf(codes.Unauthenticated, "apikey not exists in context"), int32(api.UserSvrCode_USER_AUTH_MISSING_ERR), codes.Unauthenticated, "authorization_check") } apikey := apikeys[0] if apikey != a.config.GetAdminAPIKey() { - return errors.New(errors.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") + return gosdk.NewError(errors.ErrAPIKeyNotMatch, int32(api.UserSvrCode_USER_APIKEY_NOT_MATCH_ERR), codes.Unauthenticated, "authorization_check") } return nil diff --git a/internal/pkg/middleware/exception.go b/internal/pkg/middleware/exception.go index 32cee22..832cca2 100644 --- a/internal/pkg/middleware/exception.go +++ b/internal/pkg/middleware/exception.go @@ -5,7 +5,7 @@ import ( "fmt" sysRuntime "runtime" - "github.com/begonia-org/begonia/internal/pkg/errors" + gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/begonia-org/go-sdk/logger" "google.golang.org/grpc" @@ -25,7 +25,7 @@ func (e *Exception) UnaryInterceptor(ctx context.Context, req interface{}, info n := sysRuntime.Stack(buf, false) // false 表示不需要所有goroutine的调用栈 stackTrace := string(buf[:n]) err = fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - err = errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") } }() resp, err = handler(ctx, req) @@ -42,7 +42,7 @@ func (e *Exception) StreamInterceptor(srv interface{}, ss grpc.ServerStream, inf stackTrace := string(buf[:n]) err := fmt.Errorf("panic: %v\nStack trace: %s", p, stackTrace) - err = errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") + err = gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "panic") _ = ss.SendMsg(err) } }() diff --git a/internal/pkg/middleware/http.go b/internal/pkg/middleware/http.go index bee5a9f..6f6a083 100644 --- a/internal/pkg/middleware/http.go +++ b/internal/pkg/middleware/http.go @@ -7,7 +7,6 @@ import ( "strconv" "strings" - "github.com/begonia-org/begonia/internal/pkg/errors" "github.com/begonia-org/begonia/internal/pkg/routers" gosdk "github.com/begonia-org/go-sdk" _ "github.com/begonia-org/go-sdk/api/app/v1" @@ -84,16 +83,6 @@ func getClientMessageMap() map[int32]string { } -func clientMessageFromCode(code codes.Code) string { - switch code { - case codes.ResourceExhausted: - return "The requested resource size exceeds the server limit." - default: - return "Unknown error" - - } -} - // func isValidContentType(ct string) bool { // mimeType, _, err := mime.ParseMediaType(ct) // return err == nil && mimeType != "" @@ -137,7 +126,7 @@ func HttpResponseBodyModify(ctx context.Context, w http.ResponseWriter, msg prot codeStr := value[0] code, err := strconv.ParseInt(codeStr, 10, 32) if err != nil { - return errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "internal_error") + return gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "internal_error") } httpCode = int(code) @@ -238,7 +227,7 @@ func grpcToHttpResponse(rsp interface{}, err error) (*common.HttpResponse, error data := rsp.(protoreflect.ProtoMessage) anyData, err = toStructMessage(data) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "internal_error") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "internal_error") } } diff --git a/internal/pkg/middleware/metadata.go b/internal/pkg/middleware/metadata.go deleted file mode 100644 index 2c0508c..0000000 --- a/internal/pkg/middleware/metadata.go +++ /dev/null @@ -1,13 +0,0 @@ -package middleware - -import ( - "context" - "net/http" - - "google.golang.org/grpc/metadata" -) - -func HTTPAnnotator(ctx context.Context, r *http.Request) metadata.MD{ - md := metadata.Pairs("x-request-id", r.Header.Get("x-request-id")) - return md -} diff --git a/internal/pkg/middleware/middleware.go b/internal/pkg/middleware/middleware.go index c57c39a..e423ddb 100644 --- a/internal/pkg/middleware/middleware.go +++ b/internal/pkg/middleware/middleware.go @@ -9,12 +9,13 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/middleware/auth" - "github.com/begonia-org/begonia/transport" goloadbalancer "github.com/begonia-org/go-loadbalancer" gosdk "github.com/begonia-org/go-sdk" "github.com/begonia-org/go-sdk/logger" "github.com/spark-lence/tiga" "google.golang.org/grpc" + "github.com/begonia-org/begonia/gateway" + ) // var Plugins = map[string]gosdk.GrpcPlugin{ @@ -33,7 +34,7 @@ func New(config *config.Config, plugins := map[string]gosdk.LocalPlugin{ "onlyJWT": jwt, "onlyAK": ak, - "logger": transport.NewLoggerMiddleware(log), + "logger": gateway.NewLoggerMiddleware(log), "exception": NewException(log), "http": NewHttp(), "auth": NewAuth(ak, jwt, apiKey), @@ -55,7 +56,7 @@ func New(config *config.Config, rpcPlugins, err := config.GetRPCPlugins() if err != nil { - log.Errorf(context.TODO(),"get rpc plugins error:%v", err) + log.Errorf(context.TODO(), "get rpc plugins error:%v", err) return pluginsApply } for _, rpc := range rpcPlugins { diff --git a/internal/pkg/middleware/request.go b/internal/pkg/middleware/request.go index 421fcb8..c870d7c 100644 --- a/internal/pkg/middleware/request.go +++ b/internal/pkg/middleware/request.go @@ -1,106 +1 @@ package middleware - -import ( - "context" - "net/http" - "strings" - - gosdk "github.com/begonia-org/go-sdk" - "github.com/google/uuid" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "google.golang.org/grpc/metadata" -) - -func preflightHandler(w http.ResponseWriter, _ *http.Request) { - // headers := []string{"Content-Type", "Accept", "Authorization", "X-Token", "x-date", "x-access-key"} - w.Header().Set("Access-Control-Allow-Headers", "*") - // methods := []string{"GET", "HEAD", "POST", "PUT", "DELETE"} - w.Header().Set("Access-Control-Allow-Methods", "*") - w.Header().Set("Access-Control-Expose-Headers", "*") -} - -type CorsMiddleware struct { - Cors []string -} - -func (cors *CorsMiddleware) Handle(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if clientOrigin := r.Header.Get("Origin"); clientOrigin != "" { - var isAllowed bool - - for _, origin := range cors.Cors { - if origin == "*" || strings.HasSuffix(clientOrigin, origin) { - isAllowed = true - break - } - } - if isAllowed { - w.Header().Set("Access-Control-Allow-Origin", clientOrigin) - if r.Method == "OPTIONS" { - preflightHandler(w, r) - return - } - } - } - h.ServeHTTP(w, r) - }) -} -func RequestIDMiddleware(ctx context.Context, r *http.Request) metadata.MD { - md, ok := runtime.ServerMetadataFromContext(ctx) - if !ok { - return nil - } - if val := md.HeaderMD.Get("x-request-id"); len(val) > 0 { - r.Header.Set("x-request-id", val[0]) - } - return md.HeaderMD - -} - -func IncomingHeadersToMetadata(ctx context.Context, req *http.Request) metadata.MD { - // 创建一个新的 metadata.MD 实例 - md := metadata.MD{} - invalidHeaders := []string{ - "Connection", "Keep-Alive", "Proxy-Connection", - "Transfer-Encoding", "Upgrade", "TE", - } - for _, h := range invalidHeaders { - req.Header.Del(h) - - } - for k, v := range req.Header { - if strings.HasPrefix(strings.ToLower(k), "sec-") { - continue - } - if strings.ToLower(k) == "pragma" { - continue - } - - md.Set(strings.ToLower(k), v...) - } - // 设置一些默认的元数据 - reqID := uuid.New().String() - md.Set("x-request-id", reqID) - md.Set("uri", req.RequestURI) - md.Set("x-http-method", req.Method) - md.Set("remote_addr", req.RemoteAddr) - md.Set("protocol", req.Proto) - md.Set(gosdk.GetMetadataKey("x-request-id"), reqID) - - xuid := md.Get("x-uid") - accessKey := md.Get("x-access-key") - author := "" - - if len(xuid) > 0 { - author = xuid[0] - } - if author == "" && len(accessKey) > 0 { - author = accessKey[0] - } - if author == "" { - return md - } - md.Set("x-identity", author) - - return md -} diff --git a/internal/pkg/middleware/rpc.go b/internal/pkg/middleware/rpc.go index 6b32794..ea7b900 100644 --- a/internal/pkg/middleware/rpc.go +++ b/internal/pkg/middleware/rpc.go @@ -6,7 +6,6 @@ import ( "strings" "time" - "github.com/begonia-org/begonia/internal/pkg/errors" goloadbalancer "github.com/begonia-org/go-loadbalancer" lb "github.com/begonia-org/go-loadbalancer" gosdk "github.com/begonia-org/go-sdk" @@ -59,14 +58,14 @@ func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.U } cn, err := endpoint.Get(ctx) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") } defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) plugin := api.NewPluginServiceClient(conn) anyReq, err := anypb.New(req.(proto.Message)) if err != nil { - return nil, errors.New(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") + return nil, gosdk.NewError(fmt.Errorf("new any to plugin error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") } rsp, err := plugin.Call(ctx, &api.PluginRequest{ @@ -74,7 +73,7 @@ func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.U FullMethodName: info.FullMethod, }) if err != nil { - return nil, errors.New(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") + return nil, gosdk.NewError(fmt.Errorf("call plugin error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } md, ok := metadata.FromIncomingContext(ctx) if !ok { @@ -88,7 +87,7 @@ func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.U if newRequest != nil { err = newRequest.UnmarshalTo(req.(proto.Message)) if err != nil { - return nil, errors.New(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") + return nil, gosdk.NewError(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") } } @@ -98,7 +97,7 @@ func (p *pluginImpl) UnaryInterceptor(ctx context.Context, req any, info *grpc.U func (p *pluginImpl) getEndpoint(ctx context.Context) (lb.Endpoint, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return nil, errors.New(fmt.Errorf("get metadata from context error"), int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_metadata") + return nil, gosdk.NewError(fmt.Errorf("get metadata from context error"), int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_metadata") } xforwardeds := md.Get("X-Forwarded-For") clientIP := "" @@ -111,7 +110,7 @@ func (p *pluginImpl) getEndpoint(ctx context.Context) (lb.Endpoint, error) { } endpoint, err := p.lb.Select(clientIP) if err != nil { - return nil, errors.New(fmt.Errorf("select endpoint error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "select_endpoint") + return nil, gosdk.NewError(fmt.Errorf("select endpoint error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "select_endpoint") } return endpoint, nil @@ -124,7 +123,7 @@ func (p *pluginImpl) Call(ctx context.Context, in *api.PluginRequest, opts ...gr } cn, err := endpoint.Get(ctx) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") } defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) @@ -139,7 +138,7 @@ func (p *pluginImpl) GetPluginInfo(ctx context.Context, in *emptypb.Empty, opts } cn, err := endpoint.Get(ctx) if err != nil { - return nil, errors.New(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") + return nil, gosdk.NewError(err, int32(common.Code_INTERNAL_ERROR), codes.Internal, "get_connection") } defer endpoint.AfterTransform(ctx, cn.((goloadbalancer.Connection))) conn := cn.(goloadbalancer.Connection).ConnInstance().(*grpc.ClientConn) diff --git a/internal/pkg/middleware/stream.go b/internal/pkg/middleware/stream.go index 8ff1e0e..3fee971 100644 --- a/internal/pkg/middleware/stream.go +++ b/internal/pkg/middleware/stream.go @@ -5,7 +5,6 @@ import ( "fmt" "sync" - "github.com/begonia-org/begonia/internal/pkg/errors" gosdk "github.com/begonia-org/go-sdk" api "github.com/begonia-org/go-sdk/api/plugin/v1" common "github.com/begonia-org/go-sdk/common/api/v1" @@ -57,7 +56,7 @@ func (s *grpcPluginStream) RecvMsg(m interface{}) error { anyReq, err := anypb.New(m.(protoreflect.ProtoMessage)) if err != nil { - return errors.New(fmt.Errorf("new any error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") + return gosdk.NewError(fmt.Errorf("new any error: %w", err), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "new_any") } rsp, err := s.plugin.Call(s.Context(), &api.PluginRequest{ @@ -65,7 +64,7 @@ func (s *grpcPluginStream) RecvMsg(m interface{}) error { Request: anyReq, }) if err != nil { - return errors.New(fmt.Errorf("call %s plugin error: %w", s.plugin.Name(), err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") + return gosdk.NewError(fmt.Errorf("call %s plugin error: %w", s.plugin.Name(), err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "call_plugin") } md, ok := metadata.FromIncomingContext(s.ctx) @@ -79,7 +78,7 @@ func (s *grpcPluginStream) RecvMsg(m interface{}) error { if newRequest != nil { err = newRequest.UnmarshalTo(m.(proto.Message)) if err != nil { - return errors.New(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") + return gosdk.NewError(fmt.Errorf("unmarshal to request error: %w", err), int32(common.Code_INTERNAL_ERROR), codes.Internal, "unmarshal_to_request") } } s.ctx = metadata.NewIncomingContext(s.ctx, md) diff --git a/internal/pkg/middleware/vaildator.go b/internal/pkg/middleware/vaildator.go index 3fbc5d4..14a2add 100644 --- a/internal/pkg/middleware/vaildator.go +++ b/internal/pkg/middleware/vaildator.go @@ -5,7 +5,6 @@ import ( "fmt" "sync" - "github.com/begonia-org/begonia/internal/pkg/errors" gosdk "github.com/begonia-org/go-sdk" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/go-playground/validator/v10" @@ -92,7 +91,7 @@ func (p *ParamsValidatorImpl) ValidateParams(v interface{}) error { } if errs, ok := err.(validator.ValidationErrors); ok { clientMsg := fmt.Sprintf("params %s validation failed with %v,except %s", errs[0].Field(), errs[0].Value(), errs[0].ActualTag()) - return errors.New(fmt.Errorf("params %s validation failed: %v", errs[0].Field(), errs[0].Value()), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", errors.WithClientMessage(clientMsg)) + return gosdk.NewError(fmt.Errorf("params %s validation failed: %v", errs[0].Field(), errs[0].Value()), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "params_validation", gosdk.WithClientMessage(clientMsg)) } return nil } diff --git a/internal/pkg/pkg.go b/internal/pkg/pkg.go index 2613b7b..7d7e06c 100644 --- a/internal/pkg/pkg.go +++ b/internal/pkg/pkg.go @@ -3,12 +3,12 @@ package pkg import ( "context" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/crypto" "github.com/begonia-org/begonia/internal/pkg/middleware" "github.com/begonia-org/begonia/internal/pkg/middleware/auth" "github.com/begonia-org/begonia/internal/pkg/migrate" - "github.com/begonia-org/begonia/transport" "github.com/google/wire" ) @@ -28,6 +28,6 @@ var ProviderSet = wire.NewSet( migrate.NewAPPOperator, auth.NewAccessKeyAuth, - transport.NewLoggerMiddleware, + gateway.NewLoggerMiddleware, middleware.New, ) diff --git a/internal/pkg/routers/routers.go b/internal/pkg/routers/routers.go index dcf1453..2bfeadc 100644 --- a/internal/pkg/routers/routers.go +++ b/internal/pkg/routers/routers.go @@ -5,7 +5,7 @@ import ( "strings" "sync" - "github.com/begonia-org/begonia/transport" + "github.com/begonia-org/begonia/gateway" _ "github.com/begonia-org/go-sdk/api/app/v1" _ "github.com/begonia-org/go-sdk/api/endpoint/v1" _ "github.com/begonia-org/go-sdk/api/example/v1" @@ -169,7 +169,7 @@ func (r *HttpURIRouteToSrvMethod) addRouterDetails(serviceName string, authRequi } } -func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd transport.ProtobufDescription) { +func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd gateway.ProtobufDescription) { fds := pd.GetFileDescriptorSet() r.mux.Lock() defer r.mux.Unlock() @@ -195,7 +195,7 @@ func (r *HttpURIRouteToSrvMethod) LoadAllRouters(pd transport.ProtobufDescriptio } -func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd transport.ProtobufDescription) { +func (h *HttpURIRouteToSrvMethod) DeleteRouters(pd gateway.ProtobufDescription) { // h.mux.Lock() // defer h.mux.Unlock() fds := pd.GetFileDescriptorSet() diff --git a/internal/pkg/web/response.go b/internal/pkg/web/response.go deleted file mode 100644 index 66a6b42..0000000 --- a/internal/pkg/web/response.go +++ /dev/null @@ -1,63 +0,0 @@ -package web - -import ( - "fmt" - - "github.com/begonia-org/begonia/internal/pkg/config" - _ "github.com/begonia-org/go-sdk/api/app/v1" - _ "github.com/begonia-org/go-sdk/api/example/v1" - _ "github.com/begonia-org/go-sdk/api/plugin/v1" - _ "github.com/begonia-org/go-sdk/api/endpoint/v1" - _ "github.com/begonia-org/go-sdk/api/user/v1" - _ "github.com/begonia-org/go-sdk/api/iam/v1" - _ "github.com/begonia-org/go-sdk/api/sys/v1" - _ "github.com/begonia-org/go-sdk/common/api/v1" - - common "github.com/begonia-org/go-sdk/common/api/v1" - "github.com/cockroachdb/errors" - "github.com/spark-lence/tiga" - srvErr "github.com/spark-lence/tiga/errors" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" -) - -func unwrap(err error) *srvErr.Errors { - var se *srvErr.Errors - if errors.As(err, &se) { - return se - } - return nil - -} -func MakeResponse(data protoreflect.ProtoMessage, srcErr error) (*common.APIResponse, error) { - message := "Internal Error" - code := int32(common.Code_INTERNAL_ERROR.Number()) - if srcErr != nil { - se := unwrap(srcErr) - if se != nil { - message = se.ClientMessage() - code = se.Code() - } - } else { - code = int32(common.Code_OK.Number()) - message = "ok" - } - rsp, err := tiga.MakeResponse(code, data, srcErr, message, fmt.Sprintf("%s.APIResponse", config.APIPkg)) - if rsp != nil { - // 序列化 *dynamicapi.Message - serializedMsg, mErr := proto.Marshal(rsp) - if mErr != nil { - return nil, fmt.Errorf("序列化响应失败,%w", mErr) // 处理错误 - // 处理错误 - } - // 反序列化为 common.APIResponse - var apiResponse *common.APIResponse = &common.APIResponse{} - mErr = proto.Unmarshal(serializedMsg, apiResponse) - if mErr != nil { - return nil, fmt.Errorf("反序列化响应失败,%w", mErr) // 处理错误 - // 处理错误 - } - return apiResponse, err - } - return nil, err -} diff --git a/internal/server/coverprofile.cov b/internal/server/coverprofile.cov index 39fe2c4..6168ac0 100644 --- a/internal/server/coverprofile.cov +++ b/internal/server/coverprofile.cov @@ -1,13 +1,17 @@ mode: set -github.com/begonia-org/begonia/internal/server/server.go:26.52,33.2 3 1 -github.com/begonia-org/begonia/internal/server/server.go:34.145,62.16 16 1 -github.com/begonia-org/begonia/internal/server/server.go:62.16,63.13 1 0 -github.com/begonia-org/begonia/internal/server/server.go:65.2,66.16 2 1 -github.com/begonia-org/begonia/internal/server/server.go:66.16,67.13 1 0 -github.com/begonia-org/begonia/internal/server/server.go:69.2,70.31 2 1 -github.com/begonia-org/begonia/internal/server/server.go:70.31,72.17 2 1 -github.com/begonia-org/begonia/internal/server/server.go:72.17,73.14 1 0 -github.com/begonia-org/begonia/internal/server/server.go:75.3,75.45 1 1 -github.com/begonia-org/begonia/internal/server/server.go:75.45,77.4 1 1 -github.com/begonia-org/begonia/internal/server/server.go:80.2,82.11 2 1 -github.com/begonia-org/begonia/internal/server/wire_gen.go:26.104,53.2 26 1 +github.com/begonia-org/begonia/internal/server/server.go:27.57,34.2 3 1 +github.com/begonia-org/begonia/internal/server/server.go:35.73,38.16 3 1 +github.com/begonia-org/begonia/internal/server/server.go:38.16,40.3 1 0 +github.com/begonia-org/begonia/internal/server/server.go:41.2,42.16 2 1 +github.com/begonia-org/begonia/internal/server/server.go:42.16,44.3 1 0 +github.com/begonia-org/begonia/internal/server/server.go:45.2,46.16 2 1 +github.com/begonia-org/begonia/internal/server/server.go:46.16,48.3 1 0 +github.com/begonia-org/begonia/internal/server/server.go:49.2,49.16 1 1 +github.com/begonia-org/begonia/internal/server/server.go:51.155,84.16 18 1 +github.com/begonia-org/begonia/internal/server/server.go:84.16,85.13 1 0 +github.com/begonia-org/begonia/internal/server/server.go:87.2,88.31 2 1 +github.com/begonia-org/begonia/internal/server/server.go:88.31,90.17 2 1 +github.com/begonia-org/begonia/internal/server/server.go:90.17,91.14 1 0 +github.com/begonia-org/begonia/internal/server/server.go:93.3,93.45 1 1 +github.com/begonia-org/begonia/internal/server/server.go:93.45,95.4 1 1 +github.com/begonia-org/begonia/internal/server/server.go:98.2,100.11 2 1 diff --git a/internal/server/server.go b/internal/server/server.go index 507e171..5306848 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,12 +9,12 @@ import ( "path/filepath" "strconv" + "github.com/begonia-org/begonia/gateway" + "github.com/begonia-org/begonia/gateway/serialization" "github.com/begonia-org/begonia/internal/pkg/config" "github.com/begonia-org/begonia/internal/pkg/middleware" "github.com/begonia-org/begonia/internal/pkg/routers" "github.com/begonia-org/begonia/internal/service" - "github.com/begonia-org/begonia/transport" - "github.com/begonia-org/begonia/transport/serialization" loadbalance "github.com/begonia-org/go-loadbalancer" common "github.com/begonia-org/go-sdk/common/api/v1" "github.com/google/wire" @@ -24,21 +24,21 @@ import ( var ProviderSet = wire.NewSet(NewGatewayConfig, NewGateway) -func NewGatewayConfig(gw string) *transport.GatewayConfig { +func NewGatewayConfig(gw string) *gateway.GatewayConfig { _, port, _ := net.SplitHostPort(gw) p, _ := strconv.Atoi(port) - return &transport.GatewayConfig{ + return &gateway.GatewayConfig{ GrpcProxyAddr: fmt.Sprintf(":%d", p+1), GatewayAddr: gw, } } -func readDesc(conf *config.Config) (transport.ProtobufDescription, error) { +func readDesc(conf *config.Config) (gateway.ProtobufDescription, error) { desc := conf.GetLocalAPIDesc() bin, err := os.ReadFile(desc) if err != nil { return nil, fmt.Errorf("read desc file error:%w", err) } - pd, err := transport.NewDescriptionFromBinary(bin, filepath.Dir(desc)) + pd, err := gateway.NewDescriptionFromBinary(bin, filepath.Dir(desc)) if err != nil { return nil, err } @@ -48,10 +48,10 @@ func readDesc(conf *config.Config) (transport.ProtobufDescription, error) { } return pd, nil } -func NewGateway(cfg *transport.GatewayConfig, conf *config.Config, services []service.Service, pluginApply *middleware.PluginsApply) *transport.GatewayServer { +func NewGateway(cfg *gateway.GatewayConfig, conf *config.Config, services []service.Service, pluginApply *middleware.PluginsApply) *gateway.GatewayServer { // 参数选项 - opts := &transport.GrpcServerOptions{ - Middlewares: make([]transport.GrpcProxyMiddleware, 0), + opts := &gateway.GrpcServerOptions{ + Middlewares: make([]gateway.GrpcProxyMiddleware, 0), Options: make([]grpc.ServerOption, 0), PoolOptions: make([]loadbalance.PoolOptionsBuildOption, 0), HttpMiddlewares: make([]runtime.ServeMuxOption, 0), @@ -63,9 +63,9 @@ func NewGateway(cfg *transport.GatewayConfig, conf *config.Config, services []se opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption(runtime.MIMEWildcard, serialization.NewRawBinaryUnmarshaler())) opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMarshalerOption("application/octet-stream", serialization.NewRawBinaryUnmarshaler())) - opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMetadata(transport.IncomingHeadersToMetadata)) - opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithErrorHandler(transport.HandleErrorWithLogger(transport.Log))) - opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithForwardResponseOption(transport.HttpResponseBodyModify)) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithMetadata(gateway.IncomingHeadersToMetadata)) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithErrorHandler(gateway.HandleErrorWithLogger(gateway.Log))) + opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithForwardResponseOption(gateway.HttpResponseBodyModify)) // opts.HttpMiddlewares = append(opts.HttpMiddlewares, runtime.WithRoutingErrorHandler(middleware.HandleRoutingError)) // 连接池配置 opts.PoolOptions = append(opts.PoolOptions, loadbalance.WithMaxActiveConns(100)) @@ -74,12 +74,11 @@ func NewGateway(cfg *transport.GatewayConfig, conf *config.Config, services []se opts.Options = append(opts.Options, grpc.ChainUnaryInterceptor(pluginApply.UnaryInterceptorChains()...)) opts.Options = append(opts.Options, grpc.ChainStreamInterceptor(pluginApply.StreamInterceptorChains()...)) - cors := &middleware.CorsMiddleware{ + cors := &gateway.CorsHandler{ Cors: conf.GetCorsConfig(), } opts.HttpHandlers = append(opts.HttpHandlers, cors.Handle) - runtime.WithMetadata(middleware.IncomingHeadersToMetadata) - gw := transport.New(cfg, opts) + gw := gateway.New(cfg, opts) pd, err := readDesc(conf) if err != nil { diff --git a/internal/server/wire.go b/internal/server/wire.go index bf6000b..20d4cfd 100644 --- a/internal/server/wire.go +++ b/internal/server/wire.go @@ -8,14 +8,13 @@ import ( "github.com/begonia-org/begonia/internal/data" "github.com/begonia-org/begonia/internal/pkg" "github.com/begonia-org/begonia/internal/service" - "github.com/begonia-org/begonia/transport" "github.com/begonia-org/go-sdk/logger" "github.com/google/wire" "github.com/spark-lence/tiga" ) -func New(config *tiga.Configuration, log logger.Logger, endpoint string) *transport.GatewayServer { +func New(config *tiga.Configuration, log logger.Logger, endpoint string) *gateway.GatewayServer { panic(wire.Build(biz.ProviderSet, pkg.ProviderSet, data.ProviderSet, service.ProviderSet, ProviderSet)) diff --git a/internal/integration/app_test.go b/internal/service/app_test.go similarity index 91% rename from internal/integration/app_test.go rename to internal/service/app_test.go index a193af5..1e6e21b 100644 --- a/internal/integration/app_test.go +++ b/internal/service/app_test.go @@ -1,4 +1,4 @@ -package integration_test +package service_test import ( "context" @@ -20,7 +20,7 @@ func addApp(t *testing.T) { t, func() { apiClient := client.NewAppAPI(apiAddr, accessKey, secret) - rsp, err := apiClient.PostAppConfig(context.Background(), &api.AppsRequest{Name: fmt.Sprintf("app-%s",time.Now().Format("20060102150405")), Description: "test"}) + rsp, err := apiClient.PostAppConfig(context.Background(), &api.AppsRequest{Name: fmt.Sprintf("app-%s", time.Now().Format("20060102150405")), Description: "test"}) c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) c.So(rsp.Appid, c.ShouldNotBeEmpty) diff --git a/internal/integration/authz_test.go b/internal/service/authz_test.go similarity index 92% rename from internal/integration/authz_test.go rename to internal/service/authz_test.go index 5bc90cb..e4cf5c7 100644 --- a/internal/integration/authz_test.go +++ b/internal/service/authz_test.go @@ -1,4 +1,4 @@ -package integration_test +package service_test import ( "context" @@ -13,9 +13,9 @@ import ( "runtime" "testing" + sys "github.com/begonia-org/go-sdk/api/sys/v1" "github.com/begonia-org/go-sdk/client" common "github.com/begonia-org/go-sdk/common/api/v1" - sys "github.com/begonia-org/go-sdk/api/sys/v1" c "github.com/smartystreets/goconvey/convey" "google.golang.org/protobuf/encoding/protojson" ) @@ -77,12 +77,12 @@ func loginTest(t *testing.T) { c.So(err, c.ShouldBeNil) c.So(apiRsp.Code, c.ShouldEqual, common.Code_OK) - bData,err:=apiRsp.Data.MarshalJSON() - c.So(err,c.ShouldBeNil) - info:= &sys.InfoResponse{} - err=protojson.Unmarshal(bData,info) - c.So(err,c.ShouldBeNil) - t.Log(info.Version,info.BuildTime,info.Commit) + bData, err := apiRsp.Data.MarshalJSON() + c.So(err, c.ShouldBeNil) + info := &sys.InfoResponse{} + err = protojson.Unmarshal(bData, info) + c.So(err, c.ShouldBeNil) + t.Log(info.Version, info.BuildTime, info.Commit) // c.So(info.Name,c.ShouldEqual,"gateway") }) diff --git a/internal/integration/base_test.go b/internal/service/base_test.go similarity index 93% rename from internal/integration/base_test.go rename to internal/service/base_test.go index f8c98af..d33599f 100644 --- a/internal/integration/base_test.go +++ b/internal/service/base_test.go @@ -1,4 +1,4 @@ -package integration_test +package service_test import ( "encoding/json" @@ -11,8 +11,8 @@ import ( "github.com/begonia-org/begonia" "github.com/begonia-org/begonia/config" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal" - "github.com/begonia-org/begonia/transport" api "github.com/begonia-org/go-sdk/api/app/v1" example "github.com/begonia-org/go-sdk/example" ) @@ -78,7 +78,7 @@ func RunTestServer() { config := config.ReadConfig(env) go func() { - worker := internal.New(config, transport.Log, "0.0.0.0:12140") + worker := internal.New(config, gateway.Log, "0.0.0.0:12140") worker.Start() }() diff --git a/internal/service/file.go b/internal/service/file.go index 5123274..ee2fbd5 100644 --- a/internal/service/file.go +++ b/internal/service/file.go @@ -37,11 +37,11 @@ func NewFileService(biz *file.FileUsecase, config *config.Config) *FileService { func (f *FileService) Upload(ctx context.Context, in *api.UploadFileRequest) (*api.UploadFileResponse, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return nil, errors.New(fmt.Errorf("not found metadata"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "not_found_metadata") + return nil, gosdk.NewError(fmt.Errorf("not found metadata"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "not_found_metadata") } identity := md.Get("x-identity") if len(identity) == 0 { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } // in.Key = identity[0] + "/" + in.Key return f.biz.Upload(ctx, in, identity[0]) @@ -56,11 +56,11 @@ func (f *FileService) UploadMultipartFile(ctx context.Context, in *api.UploadMul func (f *FileService) CompleteMultipartUpload(ctx context.Context, in *api.CompleteMultipartUploadRequest) (*api.CompleteMultipartUploadResponse, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_metadata") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_metadata") } identity := md.Get("x-identity") if len(identity) == 0 { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } return f.biz.CompleteMultipartUploadFile(ctx, in, identity[0]) } @@ -70,15 +70,15 @@ func (f *FileService) AbortMultipartUpload(ctx context.Context, in *api.AbortMul func (f *FileService) Download(ctx context.Context, in *api.DownloadRequest) (*httpbody.HttpBody, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { - return nil, errors.New(fmt.Errorf("not found metadata"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "not_found_metadata") + return nil, gosdk.NewError(fmt.Errorf("not found metadata"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "not_found_metadata") } identity := md.Get("x-identity") if len(identity) == 0 { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } newKey, err := url.PathUnescape(in.Key) if err != nil { - return nil, errors.New(err, int32(common.Code_UNKNOWN), codes.InvalidArgument, "url_unescape") + return nil, gosdk.NewError(err, int32(common.Code_UNKNOWN), codes.InvalidArgument, "url_unescape") } in.Key = newKey buf, err := f.biz.Download(ctx, in, identity[0]) @@ -94,7 +94,7 @@ func (f *FileService) Download(ctx context.Context, in *api.DownloadRequest) (*h ) err = grpc.SendHeader(ctx, rspMd) if err != nil { - return nil, errors.New(err, int32(common.Code_UNKNOWN), codes.Internal, "send_header") + return nil, gosdk.NewError(err, int32(common.Code_UNKNOWN), codes.Internal, "send_header") } rsp := &httpbody.HttpBody{ @@ -155,17 +155,17 @@ func (f *FileService) DownloadForRange(ctx context.Context, in *api.DownloadRequ var err error if ok { if _, ok := md["range"]; !ok { - return nil, errors.New(fmt.Errorf("range header not found"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "range_header_not_found") + return nil, gosdk.NewError(fmt.Errorf("range header not found"), int32(common.Code_PARAMS_ERROR), codes.InvalidArgument, "range_header_not_found") } rangeStr = md.Get("range")[0] start, end, err = parseRangeHeader(rangeStr) if err != nil { - return nil, errors.New(err, int32(common.Code_UNKNOWN), codes.InvalidArgument, "parse_range_header") + return nil, gosdk.NewError(err, int32(common.Code_UNKNOWN), codes.InvalidArgument, "parse_range_header") } } identity := GetIdentity(ctx) if identity == "" { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } data, fileSize, err := f.biz.DownloadForRange(ctx, in, start, end, identity) @@ -184,7 +184,7 @@ func (f *FileService) DownloadForRange(ctx context.Context, in *api.DownloadRequ ) err = grpc.SendHeader(ctx, rspMd) if err != nil { - return nil, errors.New(err, int32(common.Code_UNKNOWN), codes.Internal, "send_header") + return nil, gosdk.NewError(err, int32(common.Code_UNKNOWN), codes.Internal, "send_header") } return &httpbody.HttpBody{ ContentType: "application/octet-stream", @@ -194,14 +194,14 @@ func (f *FileService) DownloadForRange(ctx context.Context, in *api.DownloadRequ func (f *FileService) Delete(ctx context.Context, in *api.DeleteRequest) (*api.DeleteResponse, error) { identity := GetIdentity(ctx) if identity == "" { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } return f.biz.Delete(ctx, in, identity) } func (f *FileService) Metadata(ctx context.Context, in *api.FileMetadataRequest) (*api.FileMetadataResponse, error) { identity := GetIdentity(ctx) if identity == "" { - return nil, errors.New(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") + return nil, gosdk.NewError(errors.ErrIdentityMissing, int32(user.UserSvrCode_USER_IDENTITY_MISSING_ERR), codes.InvalidArgument, "not_found_identity") } rsp, err := f.biz.Metadata(ctx, in, identity) if err != nil { @@ -226,7 +226,7 @@ func (f *FileService) Metadata(ctx context.Context, in *api.FileMetadataRequest) err = grpc.SendHeader(ctx, rspMd) if err != nil { - return nil, errors.New(fmt.Errorf("非法的响应头,%w", err), int32(common.Code_UNKNOWN), codes.Internal, "send_header") + return nil, gosdk.NewError(fmt.Errorf("非法的响应头,%w", err), int32(common.Code_UNKNOWN), codes.Internal, "send_header") } if md, ok := metadata.FromIncomingContext(ctx); ok { diff --git a/internal/integration/file_test.go b/internal/service/file_test.go similarity index 94% rename from internal/integration/file_test.go rename to internal/service/file_test.go index 702849c..3371b70 100644 --- a/internal/integration/file_test.go +++ b/internal/service/file_test.go @@ -1,4 +1,4 @@ -package integration_test +package service_test import ( "context" @@ -37,16 +37,16 @@ func sumFileSha256(src string) (string, error) { } func upload(t *testing.T) { - env:=begonia.Env - if env==""{ - env="dev" + env := begonia.Env + if env == "" { + env = "dev" } c.Convey("test upload file", t, func() { // test upload file apiClient := client.NewFilesAPI(apiAddr, accessKey, secret) _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filename), "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata", "helloworld.pb") srcSha256, _ := sumFileSha256(pbFile) rsp, err := apiClient.UploadFile(context.Background(), pbFile, "test/helloworld.pb", true) c.So(err, c.ShouldBeNil) @@ -130,9 +130,9 @@ func uploadParts(t *testing.T) { c.So(rsp.Uri, c.ShouldNotBeEmpty) c.So(rsp.Version, c.ShouldNotBeEmpty) // c.So(rsp.Sha256,c.ShouldEqual,tmp.sha256) - env:="dev" - if begonia.Env!=""{ - env="test" + env := "dev" + if begonia.Env != "" { + env = "test" } conf := cfg.NewConfig(config.ReadConfig(env)) @@ -173,9 +173,9 @@ func download(t *testing.T) { c.So(err, c.ShouldBeNil) defer tmp.Close() defer os.Remove(tmp.Name()) - sha256Str, err := apiClient.DownloadFile(context.Background(), sdkAPPID + "/test/helloworld.pb", tmp.Name(), "") + sha256Str, err := apiClient.DownloadFile(context.Background(), sdkAPPID+"/test/helloworld.pb", tmp.Name(), "") c.So(err, c.ShouldBeNil) - _,err=os.Stat(tmp.Name()) + _, err = os.Stat(tmp.Name()) c.So(err, c.ShouldBeNil) downloadedSha256, err := sumFileSha256(tmp.Name()) c.So(err, c.ShouldBeNil) @@ -194,7 +194,7 @@ func downloadParts(t *testing.T) { c.So(err, c.ShouldBeNil) // c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) downloadedSha256, _ := sumFileSha256(tmp.Name()) - + c.So(rsp.Sha256, c.ShouldEqual, downloadedSha256) }) @@ -206,7 +206,7 @@ func deleteFile(t *testing.T) { } c.Convey("test delete file", t, func() { apiClient := client.NewFilesAPI(apiAddr, accessKey, secret) - rsp, err := apiClient.DeleteFile(context.Background(), sdkAPPID +"/test/helloworld.pb") + rsp, err := apiClient.DeleteFile(context.Background(), sdkAPPID+"/test/helloworld.pb") c.So(err, c.ShouldBeNil) c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) conf := cfg.NewConfig(config.ReadConfig(env)) diff --git a/internal/integration/gateway_test.go b/internal/service/gateway_test.go similarity index 96% rename from internal/integration/gateway_test.go rename to internal/service/gateway_test.go index 2328bea..b5ee326 100644 --- a/internal/integration/gateway_test.go +++ b/internal/service/gateway_test.go @@ -1,4 +1,4 @@ -package integration_test +package service_test import ( "context" @@ -22,7 +22,7 @@ func postEndpoint(t *testing.T) { c.Convey("test create endpoint api", t, func() { _, filename, _, _ := runtime.Caller(0) - pbFile := filepath.Join(filepath.Dir(filename), "testdata", "helloworld.pb") + pbFile := filepath.Join(filepath.Dir(filepath.Dir(filepath.Dir(filename))), "testdata", "helloworld.pb") pb, err := os.ReadFile(pbFile) c.So(err, c.ShouldBeNil) endpoint := &api.EndpointSrvConfig{ diff --git a/internal/service/user_test.go b/internal/service/user_test.go index 7c3c083..42a474a 100644 --- a/internal/service/user_test.go +++ b/internal/service/user_test.go @@ -1,13 +1,87 @@ -package service +package service_test -// func newServer() { +import ( + "context" + "fmt" + "testing" + "time" -// config := config.ReadConfig("dev") -// server := New(config, logger.Logger, "12138") -// go func() { -// _ = server.Start() -// }() -// } -// func TestUserLog(t *testing.T) { + api "github.com/begonia-org/go-sdk/api/user/v1" + "github.com/begonia-org/go-sdk/client" + common "github.com/begonia-org/go-sdk/common/api/v1" + c "github.com/smartystreets/goconvey/convey" +) -// } +var uid = "" + +func addUser(t *testing.T) { + c.Convey( + "test add user", + t, + func() { + apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) + name := fmt.Sprintf("user-%s", time.Now().Format("20060102150405")) + rsp, err := apiClient.PostUser(context.Background(), &api.PostUserRequest{ + Name: name, + Password: "123456", + Email: fmt.Sprintf("%s@example.com", name), + Role: api.Role_ADMIN, + Dept: "development", + Avatar: "https://www.example.com/avatar.jpg", + Owner: "test-user-01", + Phone: time.Now().Format("20060102150405"), + }) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + c.So(rsp.Uid, c.ShouldNotBeEmpty) + uid = rsp.Uid + + }) +} + +func getUser(t *testing.T) { + c.Convey( + "test get user", + t, + func() { + apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) + rsp, err := apiClient.GetUser(context.Background(), uid) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + }) +} +func deleteUser(t *testing.T) { + c.Convey( + "test delete user", + t, + func() { + apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) + rsp, err := apiClient.DeleteUser(context.Background(), uid) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + + _, err = apiClient.GetUser(context.Background(), uid) + c.So(err, c.ShouldNotBeNil) + + }) +} +func patchUser(t *testing.T) { + c.Convey( + "test patch user", + t, + func() { + apiClient := client.NewUsersAPI(apiAddr, accessKey, secret) + rsp, err := apiClient.PatchUser(context.Background(), uid, map[string]interface{}{ + "password": "123456ecfasddccddd", + "email": fmt.Sprintf("%s@example.com", time.Now().Format("20060102150405"))}) + c.So(err, c.ShouldBeNil) + c.So(rsp.StatusCode, c.ShouldEqual, common.Code_OK) + }) +} +func TestUser(t *testing.T) { + t.Run("add user", addUser) + t.Run("get user", getUser) + // uid = "442210231930327040" + t.Run("patch user", patchUser) + t.Run("delete user", deleteUser) +} diff --git a/internal/wire.go b/internal/wire.go index a45f4a8..cab37b7 100644 --- a/internal/wire.go +++ b/internal/wire.go @@ -11,7 +11,6 @@ import ( "github.com/begonia-org/begonia/internal/pkg/migrate" "github.com/begonia-org/begonia/internal/server" "github.com/begonia-org/begonia/internal/service" - "github.com/begonia-org/begonia/transport" "github.com/google/wire" @@ -24,7 +23,7 @@ func InitOperatorApp(config *tiga.Configuration) *migrate.InitOperator { } -func New(config *tiga.Configuration, log transport.Logger, endpoint string) GatewayWorker { +func New(config *tiga.Configuration, log gateway.Logger, endpoint string) GatewayWorker { panic(wire.Build(biz.ProviderSet, pkg.ProviderSet, data.ProviderSet, service.ProviderSet, daemon.ProviderSet, server.ProviderSet, NewGatewayWorkerImpl)) diff --git a/internal/worker.go b/internal/worker.go index ec5d6a3..e42ec52 100644 --- a/internal/worker.go +++ b/internal/worker.go @@ -4,8 +4,8 @@ import ( "context" "time" + "github.com/begonia-org/begonia/gateway" "github.com/begonia-org/begonia/internal/daemon" - "github.com/begonia-org/begonia/transport" ) type GatewayWorker interface { @@ -17,10 +17,10 @@ type GatewayWorkerImpl struct { // data daemon daemon.Daemon - server *transport.GatewayServer + server *gateway.GatewayServer } -func NewGatewayWorkerImpl(daemon daemon.Daemon, server *transport.GatewayServer) GatewayWorker { +func NewGatewayWorkerImpl(daemon daemon.Daemon, server *gateway.GatewayServer) GatewayWorker { return &GatewayWorkerImpl{ daemon: daemon, server: server, diff --git a/internal/integration/testdata/desc.pb b/testdata/desc.pb similarity index 98% rename from internal/integration/testdata/desc.pb rename to testdata/desc.pb index f9de8ea..c9c20ed 100644 Binary files a/internal/integration/testdata/desc.pb and b/testdata/desc.pb differ diff --git a/internal/integration/testdata/gateway.json b/testdata/gateway.json similarity index 100% rename from internal/integration/testdata/gateway.json rename to testdata/gateway.json diff --git a/transport/protos/google/api/annotations.proto b/testdata/google/api/annotations.proto similarity index 100% rename from transport/protos/google/api/annotations.proto rename to testdata/google/api/annotations.proto diff --git a/transport/protos/google/api/http.proto b/testdata/google/api/http.proto similarity index 100% rename from transport/protos/google/api/http.proto rename to testdata/google/api/http.proto diff --git a/internal/integration/testdata/google/api/httpbody.proto b/testdata/google/api/httpbody.proto similarity index 100% rename from internal/integration/testdata/google/api/httpbody.proto rename to testdata/google/api/httpbody.proto diff --git a/transport/protos/google/protobuf/any.proto b/testdata/google/protobuf/any.proto similarity index 100% rename from transport/protos/google/protobuf/any.proto rename to testdata/google/protobuf/any.proto diff --git a/transport/protos/google/protobuf/descriptor.proto b/testdata/google/protobuf/descriptor.proto similarity index 99% rename from transport/protos/google/protobuf/descriptor.proto rename to testdata/google/protobuf/descriptor.proto index 3b38675..5154e5a 100755 --- a/transport/protos/google/protobuf/descriptor.proto +++ b/testdata/google/protobuf/descriptor.proto @@ -305,7 +305,7 @@ message MethodDescriptorProto { // Identifies if client streams multiple client messages optional bool client_streaming = 5 [default = false]; // Identifies if server streams multiple server messages - optional bool server_streaming = 6 [default = false]; + optional bool service_streaming = 6 [default = false]; } // =================================================================== diff --git a/internal/integration/testdata/google/protobuf/empty.proto b/testdata/google/protobuf/empty.proto similarity index 100% rename from internal/integration/testdata/google/protobuf/empty.proto rename to testdata/google/protobuf/empty.proto diff --git a/transport/protos/google/protobuf/field_mask.proto b/testdata/google/protobuf/field_mask.proto similarity index 100% rename from transport/protos/google/protobuf/field_mask.proto rename to testdata/google/protobuf/field_mask.proto diff --git a/internal/integration/testdata/google/protobuf/struct.proto b/testdata/google/protobuf/struct.proto similarity index 100% rename from internal/integration/testdata/google/protobuf/struct.proto rename to testdata/google/protobuf/struct.proto diff --git a/transport/protos/google/protobuf/timestamp.proto b/testdata/google/protobuf/timestamp.proto similarity index 100% rename from transport/protos/google/protobuf/timestamp.proto rename to testdata/google/protobuf/timestamp.proto diff --git a/internal/integration/testdata/helloworld.pb b/testdata/helloworld.pb similarity index 98% rename from internal/integration/testdata/helloworld.pb rename to testdata/helloworld.pb index 25de36f..c37596f 100644 Binary files a/internal/integration/testdata/helloworld.pb and b/testdata/helloworld.pb differ diff --git a/internal/integration/testdata/options.proto b/testdata/options.proto similarity index 100% rename from internal/integration/testdata/options.proto rename to testdata/options.proto diff --git a/internal/integration/testdata/test.proto b/testdata/test.proto similarity index 100% rename from internal/integration/testdata/test.proto rename to testdata/test.proto diff --git a/transport/plugin.go b/transport/plugin.go deleted file mode 100644 index d11d0be..0000000 --- a/transport/plugin.go +++ /dev/null @@ -1 +0,0 @@ -package transport diff --git a/transport/request_test.go b/transport/request_test.go deleted file mode 100644 index bfc87b5..0000000 --- a/transport/request_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package transport_test - -import ( - "context" - "net/http" - "testing" - - "github.com/begonia-org/begonia/transport" - "github.com/begonia-org/begonia/transport/serialization" - hello "github.com/begonia-org/go-sdk/api/example/v1" - c "github.com/smartystreets/goconvey/convey" - "google.golang.org/grpc" -) - - -func TestBuildGrpcRequest(t *testing.T) { - c.Convey("TestBuildGrpcRequest", t, func() { - in:=&hello.HelloRequest{} - out:=&hello.HelloReply{} - httpReq,_:=http.NewRequest("GET","http://127.0.0.1:8080",nil) - - req:=transport.NewGrpcRequest(context.Background(), - in.ProtoReflect().Descriptor(), - out.ProtoReflect().Descriptor(), - "helloworld.Greeter/SayHello", - transport.WithGatewayCallOptions(grpc.CompressorCallOption{}), - transport.WithGatewayMarshaler(serialization.NewJSONMarshaler()), - transport.WithGatewayPathParams(map[string]string{"key":"value"}), - transport.WithGatewayReq(httpReq), - transport.WithIn(in), - transport.WithOut(out), - ) - c.So(req.GetFullMethodName(),c.ShouldEqual,"helloworld.Greeter/SayHello") - c.So(len(req.GetCallOptions()),c.ShouldEqual,1) - c.So(req.GetMarshaler(),c.ShouldHaveSameTypeAs,serialization.NewJSONMarshaler()) - c.So(req.GetPathParams(),c.ShouldResemble,map[string]string{"key":"value"}) - c.So(req.GetReq().URL.String(),c.ShouldEqual,httpReq.URL.String()) - c.So(req.GetIn(),c.ShouldHaveSameTypeAs,in) - c.So(req.GetOut(),c.ShouldHaveSameTypeAs,out) - c.So(req.GetInType(),c.ShouldEqual,in.ProtoReflect().Descriptor()) - c.So(req.GetOutType(),c.ShouldEqual,out.ProtoReflect().Descriptor()) - }) -}