diff --git a/cmd/cli/add/add.go b/cmd/cli/add/add.go index c8b85bd..4646887 100644 --- a/cmd/cli/add/add.go +++ b/cmd/cli/add/add.go @@ -3,10 +3,11 @@ package add import ( "fmt" + "github.com/rs/zerolog" + "github.com/bilalcaliskan/split-the-tunnel/internal/constants" "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" - "github.com/bilalcaliskan/split-the-tunnel/internal/logging" "github.com/spf13/cobra" ) @@ -28,15 +29,15 @@ to quickly create a Cobra application.`, return nil }, RunE: func(cmd *cobra.Command, args []string) error { - logger := logging.GetLogger() + logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger) - logger.Info().Any("args", args).Msg("add called") + logger.Debug().Any("args", args).Msg("add command called") for _, arg := range args { req := fmt.Sprintf("%s %s", cmd.Name(), arg) res, err := utils.SendCommandToDaemon(utils.SocketPath, req) if err != nil { - logger.Error().Str("command", req).Err(err).Msg(constants.FailedToSendCommand) + logger.Error().Str("command", req).Err(err).Msg(constants.FailedToProcessCommand) continue } diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 3e1017e..1fd7ed8 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -1,8 +1,13 @@ package main import ( + "context" + "errors" "os" + "github.com/bilalcaliskan/split-the-tunnel/internal/constants" + "github.com/bilalcaliskan/split-the-tunnel/internal/logging" + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/add" "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/list" "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/remove" @@ -13,12 +18,26 @@ import ( ) var ( - ver = version.Get() - cliCmd = &cobra.Command{ + verbose bool + ver = version.Get() + cliCmd = &cobra.Command{ Use: "stt-cli", Short: "", Long: ``, Version: ver.GitVersion, + PersistentPreRun: func(cmd *cobra.Command, args []string) { + logger := logging.GetLogger() + logger.Info().Str("appVersion", ver.GitVersion).Str("goVersion", ver.GoVersion).Str("goOS", ver.GoOs). + Str("goArch", ver.GoArch).Str("gitCommit", ver.GitCommit).Str("buildDate", ver.BuildDate). + Msg("split-the-tunnel cli is started!") + + if verbose { + logger = logging.WithVerbose() + logger.Debug().Str("foo", "bar").Msg("this is a dummy log") + } + + cmd.SetContext(context.WithValue(cmd.Context(), constants.LoggerKey{}, logger)) + }, } ) @@ -26,7 +45,8 @@ func main() { if err := cliCmd.Execute(); err != nil { // extract the response code from the error var resCode int - if cmdErr, ok := err.(*utils.CommandError); ok { + var cmdErr *utils.CommandError + if errors.As(err, &cmdErr) { resCode = cmdErr.Code } @@ -39,6 +59,8 @@ func main() { } func init() { + cliCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "enable verbose mode") + cliCmd.AddCommand(add.AddCmd) cliCmd.AddCommand(list.ListCmd) cliCmd.AddCommand(remove.RemoveCmd) diff --git a/cmd/cli/list/list.go b/cmd/cli/list/list.go index 31eea9e..57fd0a4 100644 --- a/cmd/cli/list/list.go +++ b/cmd/cli/list/list.go @@ -3,7 +3,7 @@ package list import ( "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" "github.com/bilalcaliskan/split-the-tunnel/internal/constants" - "github.com/bilalcaliskan/split-the-tunnel/internal/logging" + "github.com/rs/zerolog" "github.com/spf13/cobra" ) @@ -25,12 +25,13 @@ to quickly create a Cobra application.`, return nil }, RunE: func(cmd *cobra.Command, args []string) error { - logger := logging.GetLogger() + logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger) + + logger.Debug().Msg("list command called") - logger.Info().Msg("list called") res, err := utils.SendCommandToDaemon(utils.SocketPath, cmd.Name()) if err != nil { - logger.Error().Str("command", cmd.Name()).Err(err).Msg(constants.FailedToSendCommand) + logger.Error().Str("command", cmd.Name()).Err(err).Msg(constants.FailedToProcessCommand) return &utils.CommandError{Err: err, Code: 10} } diff --git a/cmd/cli/remove/remove.go b/cmd/cli/remove/remove.go index 305f913..e4098fc 100644 --- a/cmd/cli/remove/remove.go +++ b/cmd/cli/remove/remove.go @@ -3,11 +3,11 @@ package remove import ( "fmt" + "github.com/rs/zerolog" + "github.com/bilalcaliskan/split-the-tunnel/internal/constants" "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" - "github.com/bilalcaliskan/split-the-tunnel/internal/logging" - "github.com/spf13/cobra" ) @@ -29,15 +29,15 @@ to quickly create a Cobra application.`, return nil }, RunE: func(cmd *cobra.Command, args []string) error { - logger := logging.GetLogger() + logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger) - logger.Info().Any("args", args).Msg("add called") + logger.Info().Any("args", args).Msg("remove command called") for _, arg := range args { req := fmt.Sprintf("%s %s", cmd.Name(), arg) res, err := utils.SendCommandToDaemon(utils.SocketPath, req) if err != nil { - logger.Error().Str("command", req).Err(err).Msg(constants.FailedToSendCommand) + logger.Error().Str("command", req).Err(err).Msg(constants.FailedToProcessCommand) continue } diff --git a/cmd/cli/utils/utils.go b/cmd/cli/utils/utils.go index 3345eaa..22d562e 100644 --- a/cmd/cli/utils/utils.go +++ b/cmd/cli/utils/utils.go @@ -4,6 +4,8 @@ import ( "encoding/json" "net" + "github.com/bilalcaliskan/split-the-tunnel/internal/constants" + "github.com/pkg/errors" ) @@ -23,13 +25,13 @@ type DaemonResponse struct { func SendCommandToDaemon(socketPath, command string) (string, error) { conn, err := net.Dial("unix", socketPath) if err != nil { - return "", errors.Wrap(err, "failed to connect to unix domain socket") + return "", errors.Wrap(err, constants.FailedToConnectToUnixDomainSocket) } defer conn.Close() _, err = conn.Write([]byte(command + "\n")) if err != nil { - return "", errors.Wrap(err, "failed to write to unix domain socket") + return "", errors.Wrap(err, constants.FailedToWriteToUnixDomainSocket) } buf := make([]byte, 1024) diff --git a/cmd/daemon/daemon.go b/cmd/daemon/daemon.go index dca4423..5579313 100644 --- a/cmd/daemon/daemon.go +++ b/cmd/daemon/daemon.go @@ -56,7 +56,7 @@ var daemonCmd = &cobra.Command{ } }() - logger.Info().Msg(constants.DaemonRunning) + logger.Info().Str("socket", socketPath).Msg(constants.DaemonRunning) // Wait for termination signal <-sigs diff --git a/cmd/daemon/options/options.go b/cmd/daemon/options/options.go index e04bcf2..38fcaf6 100644 --- a/cmd/daemon/options/options.go +++ b/cmd/daemon/options/options.go @@ -17,7 +17,7 @@ func GetRootOptions() *RootOptions { } func (opts *RootOptions) InitFlags(cmd *cobra.Command) { - cmd.Flags().BoolVarP(&opts.Verbose, "verbose", "", false, "verbose log") + cmd.Flags().BoolVarP(&opts.Verbose, "verbose", "", false, "verbose logging output") cmd.Flags().StringVarP(&opts.DnsServers, "dns-servers", "", "", "comma separated dns servers to be used for DNS resolving") - cmd.Flags().IntVarP(&opts.CheckIntervalMin, "check-interval-min", "", 5, "") + cmd.Flags().IntVarP(&opts.CheckIntervalMin, "check-interval-min", "", 5, "routing table check interval with collected state, in minutes") } diff --git a/go.mod b/go.mod index 59d664d..d052406 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/pkg/errors v0.9.1 - github.com/rs/zerolog v1.31.0 + github.com/rs/zerolog v1.32.0 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 02e9035..6156b0d 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= -github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= diff --git a/internal/constants/errors.go b/internal/constants/errors.go index d2d2d92..92b5d11 100644 --- a/internal/constants/errors.go +++ b/internal/constants/errors.go @@ -14,11 +14,12 @@ const ( FailedToResolveDomain = "failed to resolve domain" FailedToMarshalResponse = "failed to marshal response" FailedToWriteRouteEntry = "failed to write RouteEntry to state" - EntryNotFound = "entry not found" + EntryNotFound = "route entry not found in state" NonVPNGatewayNotFound = "non-VPN gateway not found" FailedToDecodeHex = "failed to decode hex string" InvalidIpLength = "invalid IP length: %d" - FailedToSendCommand = "failed to send command to daemon" + FailedToProcessCommand = "failed to process command" FailedToCleanupIPC = "failed to cleanup IPC" FailedToInitializeIPC = "failed to initialize IPC" + FailedToRemoveRouteEntry = "failed to remove RouteEntry from state" ) diff --git a/internal/constants/infos.go b/internal/constants/infos.go index fe7afff..fe8f01b 100644 --- a/internal/constants/infos.go +++ b/internal/constants/infos.go @@ -3,7 +3,7 @@ package constants const ( SuccessfullyProcessed = "successfully processed command" IPCInitialized = "IPC is initialized" - DaemonRunning = "daemon is running..." + DaemonRunning = "daemon is running, waiting for requests over unix domain socket..." TermSignalReceived = "termination signal received" ShuttingDownDaemon = "shutting down daemon..." AppStarted = "split-the-tunnel is started!" diff --git a/internal/constants/keys.go b/internal/constants/keys.go new file mode 100644 index 0000000..810d22a --- /dev/null +++ b/internal/constants/keys.go @@ -0,0 +1,7 @@ +package constants + +//const ( +// LoggerKey = "logger" +//) + +type LoggerKey struct{} diff --git a/internal/ipc/ipc.go b/internal/ipc/ipc.go index 45bcdc6..ec14d9f 100644 --- a/internal/ipc/ipc.go +++ b/internal/ipc/ipc.go @@ -67,7 +67,7 @@ func handleConnection(conn net.Conn, logger zerolog.Logger) { command := strings.TrimSpace(message) logger.Info().Str("command", command).Msg("received command") - st := new(state.State) + st := state.NewState() if err := st.Read("/tmp/state.json"); err != nil { logger.Error().Err(err).Msg(constants.FailedToReadState) continue @@ -111,114 +111,99 @@ func handleAddCommand(logger zerolog.Logger, gw string, domains []string, conn n logger = logger.With().Str("operation", "add").Logger() for _, domain := range domains { - response := new(DaemonResponse) - ip, err := utils.ResolveDomain(domain) if err != nil { - response.Success = false - response.Response = "" - response.Error = errors.Wrap(err, constants.FailedToResolveDomain).Error() - - responseJson, err := json.Marshal(response) - if err != nil { - logger.Error(). - Err(err). - Str("domain", domain). - Msg(constants.FailedToMarshalResponse) - continue - } - - if _, err := conn.Write(responseJson); err != nil { + if err := writeResponse(&DaemonResponse{ + Success: false, + Response: "", + Error: errors.Wrap(err, constants.FailedToResolveDomain).Error(), + }, conn); err != nil { logger.Error(). Err(err). Str("domain", domain). Msg(constants.FailedToWriteToUnixDomainSocket) - continue } continue } - re := &state.RouteEntry{ - Domain: domain, - ResolvedIP: ip[0], - Gateway: gw, - } + re := state.NewRouteEntry(domain, ip[0], gw) if err := st.AddEntry(re); err != nil { - response.Success = false - response.Response = "" - response.Error = errors.Wrap(err, constants.FailedToWriteRouteEntry).Error() - - responseJson, err := json.Marshal(response) - if err != nil { - logger.Error(). - Err(err). - Str("domain", domain). - Msg(constants.FailedToMarshalResponse) - continue - } - - if _, err := conn.Write(responseJson); err != nil { + if err := writeResponse(&DaemonResponse{ + Success: false, + Response: "", + Error: errors.Wrap(err, constants.FailedToWriteRouteEntry).Error(), + }, conn); err != nil { logger.Error(). Err(err). Str("domain", domain). Msg(constants.FailedToWriteToUnixDomainSocket) - continue } - } - response.Success = false - response.Response = fmt.Sprintf("added route for " + domain) - response.Error = "" - - responseJson, err := json.Marshal(response) - if err != nil { - logger.Error(). - Err(err). - Str("domain", domain). - Msg(constants.FailedToMarshalResponse) continue } - // Send a response to the client - _, err = conn.Write(responseJson) - if err != nil { + logger.Info().Str("domain", domain).Msg("successfully added route to routing table") + + if err := writeResponse(&DaemonResponse{ + Success: true, + Response: fmt.Sprintf("added route for " + domain), + Error: "", + }, conn); err != nil { logger.Error(). Err(err). Str("domain", domain). Msg(constants.FailedToWriteToUnixDomainSocket) - continue } } } func handleRemoveCommand(logger zerolog.Logger, gw string, domains []string, conn net.Conn, st *state.State) { - for _, domain := range domains { - response := new(DaemonResponse) + logger = logger.With().Str("operation", "remove").Logger() - response.Success = false - response.Response = "" - response.Error = fmt.Sprintf("a dummy error for domain %s", domain) + for _, domain := range domains { + entry := st.GetEntry(domain) + if entry == nil { + if err := writeResponse(&DaemonResponse{ + Success: false, + Response: "", + Error: errors.New(constants.EntryNotFound).Error(), + }, conn); err != nil { + logger.Error(). + Err(err). + Str("domain", domain). + Msg(constants.FailedToWriteToUnixDomainSocket) + } + continue + } - responseJson, err := json.Marshal(response) - if err != nil { - logger.Error(). - Err(err). - Str("domain", domain). - Msg(constants.FailedToMarshalResponse) + if err := st.RemoveEntry(domain); err != nil { + if err := writeResponse(&DaemonResponse{ + Success: false, + Response: "", + Error: errors.Wrap(err, constants.FailedToRemoveRouteEntry).Error(), + }, conn); err != nil { + logger.Error(). + Err(err). + Str("domain", domain). + Msg(constants.FailedToWriteToUnixDomainSocket) + } continue } - if _, err := conn.Write(responseJson); err != nil { + logger.Info().Str("domain", domain).Msg("successfully removed route from routing table") + + if err := writeResponse(&DaemonResponse{ + Success: true, + Response: fmt.Sprintf("removed route for " + domain), + Error: "", + }, conn); err != nil { logger.Error(). Err(err). Str("domain", domain). Msg(constants.FailedToWriteToUnixDomainSocket) - continue } - - continue } } @@ -244,9 +229,3 @@ func handleListCommand(logger zerolog.Logger, conn net.Conn, st *state.State) { return } } - -func Cleanup(path string) error { - // Perform any cleanup and shutdown tasks here - - return os.Remove(path) -} diff --git a/internal/ipc/utils.go b/internal/ipc/utils.go new file mode 100644 index 0000000..95dcbb2 --- /dev/null +++ b/internal/ipc/utils.go @@ -0,0 +1,27 @@ +package ipc + +import ( + "encoding/json" + "net" + "os" + + "github.com/bilalcaliskan/split-the-tunnel/internal/constants" + "github.com/pkg/errors" +) + +func Cleanup(path string) error { + // Perform any cleanup and shutdown tasks here + + return os.Remove(path) +} + +func writeResponse(response *DaemonResponse, conn net.Conn) error { + responseJson, err := json.Marshal(response) + if err != nil { + return errors.Wrap(err, constants.FailedToMarshalResponse) + } + + _, err = conn.Write(responseJson) + + return err +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go index 80fda86..030c01d 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -21,6 +21,10 @@ func GetLogger() zerolog.Logger { return logger } +func WithVerbose() zerolog.Logger { + return logger.Level(zerolog.DebugLevel) +} + func EnableDebugLogging() { logger = logger.Level(zerolog.DebugLevel) } diff --git a/internal/state/state.go b/internal/state/state.go index 643a237..df018ad 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -14,6 +14,12 @@ type State struct { Entries []*RouteEntry `json:"entries"` } +func NewState() *State { + return &State{ + Entries: []*RouteEntry{}, + } +} + // RouteEntry is the struct that holds the state of a single route entry type RouteEntry struct { Domain string `json:"domain"` @@ -21,6 +27,14 @@ type RouteEntry struct { Gateway string `json:"gateway"` } +func NewRouteEntry(domain, resolvedIP, gateway string) *RouteEntry { + return &RouteEntry{ + Domain: domain, + ResolvedIP: resolvedIP, + Gateway: gateway, + } +} + // AddEntry adds a new route entry to the state. If the entry already exists, it updates the ResolvedIP. func (s *State) AddEntry(entry *RouteEntry) error { for _, e := range s.Entries { @@ -33,7 +47,8 @@ func (s *State) AddEntry(entry *RouteEntry) error { } s.Entries = append(s.Entries, entry) - return nil + + return s.Write("/tmp/state.json") } // RemoveEntry removes a route entry from the state. @@ -49,14 +64,14 @@ func (s *State) RemoveEntry(domain string) error { return errors.New(constants.EntryNotFound) } -func (s *State) GetEntry(domain string) (*RouteEntry, error) { +func (s *State) GetEntry(domain string) *RouteEntry { for i := range s.Entries { if s.Entries[i].Domain == domain { - return s.Entries[i], nil + return s.Entries[i] } } - return nil, errors.New(constants.EntryNotFound) + return nil } func (s *State) Read(path string) error {