From 712f74018f3f5168421374bf39f168840fedd03e Mon Sep 17 00:00:00 2001 From: lms Date: Wed, 20 Mar 2024 15:39:34 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=84=8F=E5=A4=96=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E7=9A=84=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/server/grpc/grpc.go | 203 +++++++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 internal/server/grpc/grpc.go diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go new file mode 100644 index 0000000..a1a0d42 --- /dev/null +++ b/internal/server/grpc/grpc.go @@ -0,0 +1,203 @@ +package grpc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/cossim/hipush/api/grpc/v1" + push2 "github.com/cossim/hipush/api/push" + "github.com/cossim/hipush/config" + "github.com/cossim/hipush/internal/factory" + "github.com/cossim/hipush/pkg/consts" + "github.com/cossim/hipush/pkg/notify" + "github.com/cossim/hipush/pkg/push" + "github.com/cossim/hipush/pkg/status" + "github.com/go-logr/logr" + "github.com/golang/protobuf/jsonpb" + "google.golang.org/grpc" + "net" +) + +type Handler struct { + cfg *config.Config + logger logr.Logger + factory *factory.PushServiceFactory + v1.UnimplementedPushServiceServer +} + +func NewHandler(cfg *config.Config, logger logr.Logger, factory *factory.PushServiceFactory) *Handler { + return &Handler{ + cfg: cfg, + logger: logger.WithValues("server", "grpc"), + factory: factory, + } +} + +func (h *Handler) Start(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + lisAddr := fmt.Sprintf("%s", h.cfg.GRPC.Addr()) + lis, err := net.Listen("tcp", lisAddr) + if err != nil { + return err + } + + server := grpc.NewServer() + v1.RegisterPushServiceServer(server, h) + + serverShutdown := make(chan struct{}) + go func() { + <-ctx.Done() + h.logger.Info("Shutting down grpcServer", "addr", lisAddr) + server.GracefulStop() + close(serverShutdown) + }() + + h.logger.Info("Starting grpcServer", "addr", lisAddr) + if err := server.Serve(lis); err != nil { + if !errors.Is(err, grpc.ErrServerStopped) { + h.logger.Error(err, "failed to start grpcServer") + return err + } + } + + <-serverShutdown + return nil +} + +func (h *Handler) Push(ctx context.Context, req *v1.PushRequest) (*v1.PushResponse, error) { + resp := &v1.PushResponse{} + h.logger.Info("Received push request", "platform", req.Platform, "tokens", req.Tokens, "req", req) + + status.StatStorage.AddGrpcTotal(1) + + service, err := h.factory.GetPushService(req.Platform) + if err != nil { + status.StatStorage.AddGrpcFailed(1) + h.logger.Error(err, "failed to create push service") + return resp, err + } + + r, err := h.getPushRequest(req) + if err != nil { + status.StatStorage.AddGrpcFailed(1) + h.logger.Error(err, "failed to get push request") + return nil, err + } + + _, err = service.Send(ctx, r, &push2.SendOptions{ + DryRun: req.Option.DryRun, + Retry: int(req.Option.Retry), + }) + if err != nil { + status.StatStorage.AddGrpcFailed(1) + h.logger.Error(err, "failed to send push") + return resp, err + } + + status.StatStorage.AddGrpcSuccess(1) + + h.logger.Info("Push request processed success") + return resp, nil +} + +func (h *Handler) getPushRequest(req *v1.PushRequest) (push.PushRequest, error) { + badge := int(req.Badge) + + data := make(map[string]interface{}) + if req.Data != nil { + jsonStr, err := (&jsonpb.Marshaler{}).MarshalToString(req.Data) + if err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(jsonStr), &data); err != nil { + return nil, err + } + } + + alert := notify.Alert{} + if req.Alert != nil { + alert = notify.Alert{ + Action: req.Alert.Action, + ActionLocKey: req.Alert.ActionLocKey, + Body: req.Alert.Body, + LaunchImage: req.Alert.LaunchImage, + LocArgs: req.Alert.LocArgs, + LocKey: req.Alert.LocKey, + Title: req.Alert.Title, + Subtitle: req.Alert.Subtitle, + TitleLocArgs: req.Alert.TitleLocArgs, + TitleLocKey: req.Alert.TitleLocKey, + } + } + + return ¬ify.ApnsPushNotification{ + AppID: req.AppID, + ApnsID: req.AppID, + Tokens: req.Tokens, + Title: req.Title, + Content: req.Message, + Topic: req.Topic, + Category: req.Category, + Sound: req.Sound, + Alert: alert, + Badge: &badge, + ThreadID: req.ThreadID, + Data: data, + PushType: req.PushType, + Priority: string(req.Priority), + ContentAvailable: req.ContentAvailable, + MutableContent: req.MutableContent, + Development: req.Development, + }, nil +} + +func (h *Handler) validatePushRequest(req *v1.PushRequest) error { + if req == nil { + return errors.New("request is nil") + } + + if !consts.Platform(req.Platform).IsValid() { + return errors.New("invalid platform") + + } + + if len(req.Tokens) == 0 { + return errors.New("tokens are required") + } + + // 检查其他必填字段 + if req.Title == "" { + return errors.New("title is required") + } + + if req.Message == "" { + return errors.New("message is required") + } + + // 检查 Alert 字段 + if req.Alert != nil { + if err := h.validateAlert(req.Alert); err != nil { + return err + } + } + + // 检查 Data 字段 + if req.Data == nil { + //return errors.New("data is required") + } + + return nil +} + +func (h *Handler) validateAlert(alert *v1.Alert) error { + // TODO 检查 Alert 字段的必填参数 + return nil +} + +func (h *Handler) mustEmbedUnimplementedPushServiceServer() { + //TODO implement me + panic("implement me") +}