diff --git a/cmd/cli/add/.gitkeep b/cmd/cli/add/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/cmd/cli/add/add.go b/cmd/cli/add/add.go new file mode 100644 index 0000000..d9b7e0b --- /dev/null +++ b/cmd/cli/add/add.go @@ -0,0 +1,46 @@ +package add + +import ( + "fmt" + + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" + "github.com/bilalcaliskan/split-the-tunnel/internal/logging" + "github.com/spf13/cobra" +) + +// AddCmd represents the add command +var AddCmd = &cobra.Command{ + Use: "add", + Short: "A brief description of your command", + Long: `A longer description that spans multiple lines and likely contains examples +and usage of using your command. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.`, + PreRunE: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return utils.ErrNoArgs + } + + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + logger := logging.GetLogger() + + logger.Info().Any("args", args).Msg("add called") + + for _, arg := range args { + req := fmt.Sprintf("%s %s", cmd.Name(), arg) + res, err := utils.SendCommandToDaemon(utils.SocketPath, req) + if err != nil { + logger.Error().Str("command", req).Err(err).Msg("error sending command to daemon") + continue + } + + logger.Info().Str("command", req).Str("response", res).Msg("successfully processed command") + } + + return nil + }, +} diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 42927e6..1438b99 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -3,42 +3,32 @@ package main import ( "os" + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/add" + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/list" + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/remove" "github.com/bilalcaliskan/split-the-tunnel/internal/version" + "github.com/spf13/cobra" ) -var ver = version.Get() - -// rootCmd represents the base command when called without any subcommands -var rootCmd = &cobra.Command{ - Use: "stt-cli", - Short: "", - Long: ``, - Version: ver.GitVersion, -} +var ( + ver = version.Get() + cliCmd = &cobra.Command{ + Use: "stt-cli", + Short: "", + Long: ``, + Version: ver.GitVersion, + } +) func main() { - if err := rootCmd.Execute(); err != nil { + if err := cliCmd.Execute(); err != nil { os.Exit(1) } } -//const socketPath = "/tmp/mydaemon.sock" - -//func sendCommandToDaemon(command string) (string, error) { -// conn, err := net.Dial("unix", socketPath) -// if err != nil { -// return "", err -// } -// defer conn.Close() -// -// _, err = conn.Write([]byte(command + "\n")) -// if err != nil { -// return "", err -// } -// -// // If you expect a response from the daemon, read it here -// // For example, using bufio.NewReader(conn).ReadString('\n') -// -// return "Command sent successfully", nil -//} +func init() { + cliCmd.AddCommand(add.AddCmd) + cliCmd.AddCommand(list.ListCmd) + cliCmd.AddCommand(remove.RemoveCmd) +} diff --git a/cmd/cli/list/.gitkeep b/cmd/cli/list/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/cmd/cli/list/list.go b/cmd/cli/list/list.go new file mode 100644 index 0000000..b0bf3ce --- /dev/null +++ b/cmd/cli/list/list.go @@ -0,0 +1,42 @@ +package list + +import ( + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" + "github.com/bilalcaliskan/split-the-tunnel/internal/logging" + + "github.com/spf13/cobra" +) + +// ListCmd represents the list command +var ListCmd = &cobra.Command{ + Use: "list", + Short: "A brief description of your command", + Long: `A longer description that spans multiple lines and likely contains examples +and usage of using your command. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.`, + PreRunE: func(cmd *cobra.Command, args []string) error { + if len(args) > 0 { + return utils.ErrTooManyArgs + } + + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + logger := logging.GetLogger() + + logger.Info().Msg("list called") + + req := cmd.Name() + res, err := utils.SendCommandToDaemon(utils.SocketPath, req) + if err != nil { + logger.Error().Err(err).Msg("error sending command to daemon") + return err + } + + logger.Info().Str("command", req).Str("response", res).Msg("successfully processed command") + return nil + }, +} diff --git a/cmd/cli/remove/.gitkeep b/cmd/cli/remove/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/cmd/cli/remove/remove.go b/cmd/cli/remove/remove.go new file mode 100644 index 0000000..fed0b4d --- /dev/null +++ b/cmd/cli/remove/remove.go @@ -0,0 +1,47 @@ +package remove + +import ( + "fmt" + "strings" + + "github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils" + "github.com/bilalcaliskan/split-the-tunnel/internal/logging" + + "github.com/spf13/cobra" +) + +// RemoveCmd represents the remove command +var RemoveCmd = &cobra.Command{ + Use: "remove", + Short: "A brief description of your command", + Long: `A longer description that spans multiple lines and likely contains examples +and usage of using your command. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.`, + PreRunE: func(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return utils.ErrNoArgs + } + + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + logger := logging.GetLogger() + + argsStr := strings.Join(args, " ") + + logger.Info().Str("args", argsStr).Msg("remove called") + + req := fmt.Sprintf("%s %s", cmd.Name(), argsStr) + res, err := utils.SendCommandToDaemon(utils.SocketPath, req) + if err != nil { + logger.Error().Err(err).Msg("error sending command to daemon") + return err + } + + logger.Info().Str("command", req).Str("response", res).Msg("successfully processed command") + return nil + }, +} diff --git a/cmd/cli/utils/utils.go b/cmd/cli/utils/utils.go new file mode 100644 index 0000000..3345eaa --- /dev/null +++ b/cmd/cli/utils/utils.go @@ -0,0 +1,52 @@ +package utils + +import ( + "encoding/json" + "net" + + "github.com/pkg/errors" +) + +const SocketPath = "/tmp/mydaemon.sock" + +var ( + ErrNoArgs = errors.New("no arguments provided") + ErrTooManyArgs = errors.New("too many arguments provided") +) + +type DaemonResponse struct { + Success bool `json:"success"` + Response string `json:"response"` + Error string `json:"error"` +} + +func SendCommandToDaemon(socketPath, command string) (string, error) { + conn, err := net.Dial("unix", socketPath) + if err != nil { + return "", errors.Wrap(err, "failed to connect to unix domain socket") + } + defer conn.Close() + + _, err = conn.Write([]byte(command + "\n")) + if err != nil { + return "", errors.Wrap(err, "failed to write to unix domain socket") + } + + buf := make([]byte, 1024) + n, err := conn.Read(buf[:]) + if err != nil { + return "", err + } + + var response DaemonResponse + if err := json.Unmarshal(buf[:n], &response); err != nil { + return "", err + } + + var respErr error + if response.Error != "" { + respErr = errors.New(response.Error) + } + + return response.Response, respErr +} diff --git a/go.mod b/go.mod index b0ca91f..59d664d 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/bilalcaliskan/split-the-tunnel go 1.21 require ( + github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.31.0 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index 0da5c92..02e9035 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= diff --git a/internal/ipc/ipc.go b/internal/ipc/ipc.go index 7b94049..39c44ea 100644 --- a/internal/ipc/ipc.go +++ b/internal/ipc/ipc.go @@ -2,16 +2,26 @@ package ipc import ( "bufio" - "errors" + "encoding/json" + "fmt" "io" "net" "os" "strings" + "github.com/bilalcaliskan/split-the-tunnel/internal/utils" + "github.com/bilalcaliskan/split-the-tunnel/internal/state" + "github.com/pkg/errors" "github.com/rs/zerolog" ) +type DaemonResponse struct { + Success bool `json:"success"` + Response string `json:"response"` + Error string `json:"error"` +} + func InitIPC(path string, logger zerolog.Logger) error { // Check and remove the socket file if it already exists if _, err := os.Stat(path); err == nil { @@ -63,105 +73,176 @@ func handleConnection(conn net.Conn, logger zerolog.Logger) { st := new(state.State) if err := st.Read("/tmp/state.json"); err != nil { - logger.Error().Str("path", "/tmp/state.json").Err(err).Msg("failed to read state") + logger.Error().Err(err).Msg("failed to read state") continue } - logger.Info().Any("state", st).Msg("read state") - - if err := processCommand(command, conn); err != nil { - logger.Error().Str("command", command).Err(err).Msg("error processing command") + // get default gateway + gw, err := utils.GetDefaultNonVPNGateway() + if err != nil { + logger.Error().Err(err).Msg("failed to get default gateway") continue } - logger.Info().Str("command", command).Msg("command processed successfully") + processCommand(logger, command, gw, conn, st) } } -func processCommand(command string, conn net.Conn) error { +func processCommand(logger zerolog.Logger, command, gateway string, conn net.Conn, st *state.State) { parts := strings.Fields(command) if len(parts) == 0 { - return errors.New("empty command received") + logger.Error().Msg("empty command received") + return } switch parts[0] { case "add": - if len(parts) < 2 { - _, err := conn.Write([]byte("'add' command requires at least a domain name\n")) - if err != nil { - return err - } + logger = logger.With().Str("operation", "add").Logger() + + //if len(parts) < 2 { + // errMsg := fmt.Sprintf("'%s' command requires at least a domain name", parts[0]) + // + // res := &DaemonResponse{ + // Success: false, + // Response: "", + // Error: errMsg, + // } + // + // responseJson, err := json.Marshal(res) + // if err != nil { + // logger.Error().Err(err).Msg("failed to marshal response") + // return + // } + // + // if _, err := conn.Write(responseJson); err != nil { + // logger.Error().Err(err).Msg("failed to write response to unix domain socket") + // return + // } + // + // logger.Error().Msg(errMsg) + //} + + handleAddCommand(logger, gateway, parts[1:], conn, st) + //case "remove": + // logger = logger.With().Str("operation", "remove").Logger() + // + // handleRemoveCommand(parts[1:], conn) + //case "list": + // logger = logger.With().Str("operation", "remove").Logger() + // + // handleListCommand(conn) + } +} - return errors.New("'add' command requires at least a domain name") - } +func handleAddCommand(logger zerolog.Logger, gw string, domains []string, conn net.Conn, st *state.State) { + logger = logger.With().Str("operation", "add").Logger() + + for _, domain := range domains { + response := new(DaemonResponse) - return handleAddCommand(parts[1:], conn) - case "remove": - if len(parts) < 2 { - _, err := conn.Write([]byte("'remove' command requires at least a domain name\n")) + ip, err := utils.ResolveDomain(domain) + if err != nil { + response.Success = false + response.Response = "" + response.Error = errors.Wrap(err, "failed to resolve domain").Error() + + responseJson, err := json.Marshal(response) if err != nil { - return err + logger.Error(). + Err(err). + Str("domain", domain). + Msg("failed to marshal response object") + continue + } + + if _, err := conn.Write(responseJson); err != nil { + logger.Error(). + Err(err). + Str("domain", domain). + Msg("failed to write response to unix domain socket") + continue } - return errors.New("'remove' command requires at least a domain name") + continue + } + + re := &state.RouteEntry{ + Domain: domain, + ResolvedIP: ip[0], + Gateway: gw, } - return handleRemoveCommand(parts[1:], conn) - case "list": - if len(parts) != 1 { - _, err := conn.Write([]byte("'list' command does not accept any arguments\n")) + if err := st.AddEntry(re); err != nil { + response.Success = false + response.Response = "" + response.Error = errors.Wrap(err, "failed to write RouteEntry to state").Error() + + responseJson, err := json.Marshal(response) if err != nil { - return err + logger.Error(). + Err(err). + Str("domain", domain). + Msg("failed to marshal response object") + continue } - return errors.New("'list' command does not accept any arguments") + if _, err := conn.Write(responseJson); err != nil { + logger.Error(). + Err(err). + Str("domain", domain). + Msg("failed to write response to unix domain socket") + continue + } } - return handleListCommand(conn) - default: - _, err := conn.Write([]byte("unknown command received\n")) - return err - } -} - -func handleAddCommand(domains []string, conn net.Conn) error { - // Add the domain to the routing table - // ... + response.Success = false + response.Response = fmt.Sprintf("added route for " + domain) + response.Error = "" - for _, domain := range domains { - // Send a response to the client - _, err := conn.Write([]byte("added route for " + domain + "\n")) + responseJson, err := json.Marshal(response) if err != nil { - return err + logger.Error(). + Err(err). + Str("domain", domain). + Msg("failed to marshal response object") + continue } - } - return nil -} - -func handleRemoveCommand(domains []string, conn net.Conn) error { - // Remove the domain from the routing table - // ... - - for _, domain := range domains { // Send a response to the client - _, err := conn.Write([]byte("removed route for " + domain + "\n")) + _, err = conn.Write(responseJson) if err != nil { - return err + logger.Error(). + Err(err). + Str("domain", domain). + Msg("failed to write response to unix domain socket") + continue } } - - return nil } -func handleListCommand(conn net.Conn) error { - // List the domains that we manage from the routing table - // ... - - // Send a response to the client - _, err := conn.Write([]byte("listing routes\n")) - return err -} +//func handleRemoveCommand(domains []string, conn net.Conn) error { +// // Remove the domain from the routing table +// // ... +// +// for _, domain := range domains { +// // Send a response to the client +// _, err := conn.Write([]byte("removed route for " + domain + "\n")) +// if err != nil { +// return err +// } +// } +// +// return nil +//} + +//func handleListCommand(conn net.Conn) error { +// // List the domains that we manage from the routing table +// // ... +// +// // Send a response to the client +// _, err := conn.Write([]byte("listing routes\n")) +// return err +//} func Cleanup(path string) error { // Perform any cleanup and shutdown tasks here diff --git a/internal/state/state.go b/internal/state/state.go index 8ad8419..4efa929 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -3,40 +3,58 @@ package state import ( "encoding/json" "os" + + "github.com/pkg/errors" ) -type RouteEntry struct { - Domain string `json:"domain"` - ResolvedIP string `json:"resolvedIP"` - OriginalGateway string `json:"originalGateway"` - // Add other fields as necessary +// State is the struct that holds the state of the application +type State struct { + Entries []*RouteEntry `json:"entries"` } -type State struct { - Entries []RouteEntry `json:"entries"` +// RouteEntry is the struct that holds the state of a single route entry +type RouteEntry struct { + Domain string `json:"domain"` + ResolvedIP string `json:"resolvedIP"` + Gateway string `json:"gateway"` } -func (s *State) AddEntry(entry RouteEntry) { +// AddEntry adds a new route entry to the state. If the entry already exists, it updates the ResolvedIP. +func (s *State) AddEntry(entry *RouteEntry) error { + for _, e := range s.Entries { + if e.Domain == entry.Domain { + if e.ResolvedIP != entry.ResolvedIP { + e.ResolvedIP = entry.ResolvedIP + } + return nil + } + } + s.Entries = append(s.Entries, entry) + return nil } -func (s *State) RemoveEntry(domain string) { +// RemoveEntry removes a route entry from the state. +func (s *State) RemoveEntry(domain string) error { for i, entry := range s.Entries { if entry.Domain == domain { s.Entries = append(s.Entries[:i], s.Entries[i+1:]...) - break + return nil } } + + // target entry not found + return errors.New("entry not found") } -func (s *State) GetEntry(domain string) *RouteEntry { - for _, entry := range s.Entries { - if entry.Domain == domain { - return &entry +func (s *State) GetEntry(domain string) (*RouteEntry, error) { + for i := range s.Entries { + if s.Entries[i].Domain == domain { + return s.Entries[i], nil } } - return nil + return nil, errors.New("entry not found") } func (s *State) Read(path string) error { diff --git a/internal/utils/utils.go b/internal/utils/utils.go index c81776e..0476165 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -4,15 +4,32 @@ import ( "bufio" "encoding/hex" "fmt" + "net" "os" "strconv" "strings" + + "github.com/pkg/errors" ) +func ResolveDomain(domain string) ([]string, error) { + ips, err := net.LookupIP(domain) + if err != nil { + return nil, err + } + + var ipStrings []string + for _, ip := range ips { + ipStrings = append(ipStrings, ip.String()) + } + + return ipStrings, nil +} + func GetDefaultNonVPNGateway() (string, error) { file, err := os.Open("/proc/net/route") if err != nil { - return "", fmt.Errorf("failed to open routing info: %w", err) + return "", errors.Wrap(err, "failed to open routing info file") } defer file.Close() @@ -41,7 +58,7 @@ func GetDefaultNonVPNGateway() (string, error) { } if err := scanner.Err(); err != nil { - return "", fmt.Errorf("error reading file: %w", err) + return "", errors.Wrap(err, "error reading file") } if bestGateway == "" { @@ -54,7 +71,7 @@ func GetDefaultNonVPNGateway() (string, error) { func parseHexIP(hexStr string) (string, error) { ipBytes, err := hex.DecodeString(hexStr) if err != nil { - return "", fmt.Errorf("failed to decode hex string: %w", err) + return "", errors.Wrap(err, "failed to decode hex string") } if len(ipBytes) != 4 { @@ -69,21 +86,7 @@ func parseHexIP(hexStr string) (string, error) { return fmt.Sprintf("%d.%d.%d.%d", ipBytes[0], ipBytes[1], ipBytes[2], ipBytes[3]), nil } -/*func resolveDomain(domain string) ([]string, error) { - ips, err := net.LookupIP(domain) - if err != nil { - return nil, err - } - - var ipStrings []string - for _, ip := range ips { - ipStrings = append(ipStrings, ip.String()) - } - - return ipStrings, nil -} - -func addRoute(ip, gateway string) error { +/*func addRoute(ip, gateway string) error { cmd := exec.Command("sudo", "ip", "route", "add", ip, "via", gateway) err := cmd.Run() if err != nil {