From 67814fd416a745da5584c1c5006738c16255db9f Mon Sep 17 00:00:00 2001 From: bilalcaliskan Date: Sat, 10 Feb 2024 13:01:25 +0300 Subject: [PATCH] refactor: handle multiple cases on both daemon and cli --- cmd/cli/add/add.go | 5 ++++- cmd/cli/list/list.go | 11 +++++++++-- cmd/cli/remove/remove.go | 16 +++++++++++++--- internal/constants/infos.go | 1 + internal/constants/keys.go | 7 ------- internal/constants/others.go | 7 +++++++ internal/ipc/ipc.go | 13 +++++++++---- internal/state/state.go | 26 ++++++++++++++------------ internal/utils/utils.go | 20 +++++++++++++++++++- 9 files changed, 76 insertions(+), 30 deletions(-) delete mode 100644 internal/constants/keys.go create mode 100644 internal/constants/others.go diff --git a/cmd/cli/add/add.go b/cmd/cli/add/add.go index 4646887..426c950 100644 --- a/cmd/cli/add/add.go +++ b/cmd/cli/add/add.go @@ -31,7 +31,10 @@ to quickly create a Cobra application.`, RunE: func(cmd *cobra.Command, args []string) error { logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger) - logger.Debug().Any("args", args).Msg("add command called") + logger.Info(). + Str("operation", cmd.Name()). + Any("args", args). + Msg(constants.ProcessCommand) for _, arg := range args { req := fmt.Sprintf("%s %s", cmd.Name(), arg) diff --git a/cmd/cli/list/list.go b/cmd/cli/list/list.go index 57fd0a4..cb0cdcd 100644 --- a/cmd/cli/list/list.go +++ b/cmd/cli/list/list.go @@ -1,6 +1,8 @@ package list import ( + "fmt" + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" "github.com/bilalcaliskan/split-the-tunnel/internal/constants" "github.com/rs/zerolog" @@ -27,7 +29,9 @@ to quickly create a Cobra application.`, RunE: func(cmd *cobra.Command, args []string) error { logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger) - logger.Debug().Msg("list command called") + logger.Info(). + Str("operation", cmd.Name()). + Msg(constants.ProcessCommand) res, err := utils.SendCommandToDaemon(utils.SocketPath, cmd.Name()) if err != nil { @@ -36,7 +40,10 @@ to quickly create a Cobra application.`, return &utils.CommandError{Err: err, Code: 10} } - logger.Info().Str("command", cmd.Name()).Str("response", res).Msg(constants.SuccessfullyProcessed) + logger.Info().Str("command", cmd.Name()).Msg(constants.SuccessfullyProcessed) + + fmt.Println("here is your state:") + fmt.Print(res) return nil }, diff --git a/cmd/cli/remove/remove.go b/cmd/cli/remove/remove.go index e4098fc..3b14e71 100644 --- a/cmd/cli/remove/remove.go +++ b/cmd/cli/remove/remove.go @@ -31,17 +31,27 @@ to quickly create a Cobra application.`, RunE: func(cmd *cobra.Command, args []string) error { logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger) - logger.Info().Any("args", args).Msg("remove command called") + logger.Info(). + Str("operation", cmd.Name()). + Any("args", args). + Msg(constants.ProcessCommand) for _, arg := range args { req := fmt.Sprintf("%s %s", cmd.Name(), arg) res, err := utils.SendCommandToDaemon(utils.SocketPath, req) if err != nil { - logger.Error().Str("command", req).Err(err).Msg(constants.FailedToProcessCommand) + logger.Error(). + Str("command", req). + Err(err). + Msg(constants.FailedToProcessCommand) + continue } - logger.Info().Str("command", req).Str("response", res).Msg(constants.SuccessfullyProcessed) + logger.Info(). + Str("command", req). + Str("response", res). + Msg(constants.SuccessfullyProcessed) } return nil diff --git a/internal/constants/infos.go b/internal/constants/infos.go index fe8f01b..299bf6c 100644 --- a/internal/constants/infos.go +++ b/internal/constants/infos.go @@ -7,4 +7,5 @@ const ( TermSignalReceived = "termination signal received" ShuttingDownDaemon = "shutting down daemon..." AppStarted = "split-the-tunnel is started!" + ProcessCommand = "processing command" ) diff --git a/internal/constants/keys.go b/internal/constants/keys.go deleted file mode 100644 index 810d22a..0000000 --- a/internal/constants/keys.go +++ /dev/null @@ -1,7 +0,0 @@ -package constants - -//const ( -// LoggerKey = "logger" -//) - -type LoggerKey struct{} diff --git a/internal/constants/others.go b/internal/constants/others.go new file mode 100644 index 0000000..fcc7d8b --- /dev/null +++ b/internal/constants/others.go @@ -0,0 +1,7 @@ +package constants + +const ( + StateFilePath = "/tmp/state.json" +) + +type LoggerKey struct{} diff --git a/internal/ipc/ipc.go b/internal/ipc/ipc.go index b244165..f100cb0 100644 --- a/internal/ipc/ipc.go +++ b/internal/ipc/ipc.go @@ -68,7 +68,7 @@ func handleConnection(conn net.Conn, logger zerolog.Logger) { logger.Info().Str("command", command).Msg("received command") st := state.NewState() - if err := st.Read("/tmp/state.json"); err != nil { + if err := st.Read(constants.StateFilePath); err != nil { logger.Error().Err(err).Msg(constants.FailedToReadState) continue } @@ -111,7 +111,7 @@ func handleAddCommand(logger zerolog.Logger, gw string, domains []string, conn n logger = logger.With().Str("operation", "add").Logger() for _, domain := range domains { - ip, err := utils.ResolveDomain(domain) + ips, err := utils.ResolveDomain(domain) if err != nil { if err := writeResponse(&DaemonResponse{ Success: false, @@ -127,7 +127,7 @@ func handleAddCommand(logger zerolog.Logger, gw string, domains []string, conn n continue } - re := state.NewRouteEntry(domain, ip[0], gw) + re := state.NewRouteEntry(domain, gw, ips) if err := st.AddEntry(re); err != nil { if err := writeResponse(&DaemonResponse{ @@ -212,7 +212,12 @@ func handleListCommand(logger zerolog.Logger, conn net.Conn, st *state.State) { response.Success = false response.Response = "" - response.Error = "a dummy error list command" + response.Error = "" + + // print the state + for _, entry := range st.Entries { + response.Response += fmt.Sprintf("Domain: %s, Gateway: %s, IPs: %v\n", entry.Domain, entry.Gateway, entry.ResolvedIPs) + } responseJson, err := json.Marshal(response) if err != nil { diff --git a/internal/state/state.go b/internal/state/state.go index c0c9d0c..4770910 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -4,6 +4,8 @@ import ( "encoding/json" "os" + "github.com/bilalcaliskan/split-the-tunnel/internal/utils" + "github.com/bilalcaliskan/split-the-tunnel/internal/constants" "github.com/pkg/errors" @@ -22,16 +24,16 @@ func NewState() *State { // RouteEntry is the struct that holds the state of a single route entry type RouteEntry struct { - Domain string `json:"domain"` - ResolvedIP string `json:"resolvedIP"` - Gateway string `json:"gateway"` + Domain string `json:"domain"` + Gateway string `json:"gateway"` + ResolvedIPs []string `json:"resolvedIPs"` } -func NewRouteEntry(domain, resolvedIP, gateway string) *RouteEntry { +func NewRouteEntry(domain, gateway string, resolvedIPs []string) *RouteEntry { return &RouteEntry{ - Domain: domain, - ResolvedIP: resolvedIP, - Gateway: gateway, + Domain: domain, + Gateway: gateway, + ResolvedIPs: resolvedIPs, } } @@ -39,18 +41,18 @@ func NewRouteEntry(domain, resolvedIP, gateway string) *RouteEntry { func (s *State) AddEntry(entry *RouteEntry) error { for _, e := range s.Entries { if e.Domain == entry.Domain { - if e.ResolvedIP == entry.ResolvedIP { + if utils.SlicesEqual(e.ResolvedIPs, entry.ResolvedIPs) { return errors.New(constants.EntryAlreadyExists) } - e.ResolvedIP = entry.ResolvedIP - return s.Write("/tmp/state.json") + e.ResolvedIPs = entry.ResolvedIPs + return s.Write(constants.StateFilePath) } } s.Entries = append(s.Entries, entry) - return s.Write("/tmp/state.json") + return s.Write(constants.StateFilePath) } // RemoveEntry removes a route entry from the state. @@ -58,7 +60,7 @@ func (s *State) RemoveEntry(domain string) error { for i, entry := range s.Entries { if entry.Domain == domain { s.Entries = append(s.Entries[:i], s.Entries[i+1:]...) - return s.Write("/tmp/state.json") + return s.Write(constants.StateFilePath) } } diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 9bdc6b8..a1e394d 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "os" + "sort" "strconv" "strings" @@ -22,7 +23,9 @@ func ResolveDomain(domain string) ([]string, error) { var ipStrings []string for _, ip := range ips { - ipStrings = append(ipStrings, ip.String()) + if ip.To4() != nil { + ipStrings = append(ipStrings, ip.String()) + } } return ipStrings, nil @@ -88,6 +91,21 @@ func parseHexIP(hexStr string) (string, error) { return fmt.Sprintf("%d.%d.%d.%d", ipBytes[0], ipBytes[1], ipBytes[2], ipBytes[3]), nil } +// SlicesEqual checks if two string slices are equal +func SlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + sort.Strings(a) + sort.Strings(b) + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + /*func addRoute(ip, gateway string) error { cmd := exec.Command("sudo", "ip", "route", "add", ip, "via", gateway) err := cmd.Run()