Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: handle multiple cases on both daemon and cli #12

Merged
merged 1 commit into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cmd/cli/add/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ to quickly create a Cobra application.`,
RunE: func(cmd *cobra.Command, args []string) error {
logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger)

logger.Debug().Any("args", args).Msg("add command called")
logger.Info().
Str("operation", cmd.Name()).
Any("args", args).
Msg(constants.ProcessCommand)

for _, arg := range args {
req := fmt.Sprintf("%s %s", cmd.Name(), arg)
Expand Down
11 changes: 9 additions & 2 deletions cmd/cli/list/list.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package list

import (
"fmt"

"github.com/bilalcaliskan/split-the-tunnel/cmd/cli/utils"
"github.com/bilalcaliskan/split-the-tunnel/internal/constants"
"github.com/rs/zerolog"
Expand All @@ -27,7 +29,9 @@ to quickly create a Cobra application.`,
RunE: func(cmd *cobra.Command, args []string) error {
logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger)

logger.Debug().Msg("list command called")
logger.Info().
Str("operation", cmd.Name()).
Msg(constants.ProcessCommand)

res, err := utils.SendCommandToDaemon(utils.SocketPath, cmd.Name())
if err != nil {
Expand All @@ -36,7 +40,10 @@ to quickly create a Cobra application.`,
return &utils.CommandError{Err: err, Code: 10}
}

logger.Info().Str("command", cmd.Name()).Str("response", res).Msg(constants.SuccessfullyProcessed)
logger.Info().Str("command", cmd.Name()).Msg(constants.SuccessfullyProcessed)

fmt.Println("here is your state:")
fmt.Print(res)

return nil
},
Expand Down
16 changes: 13 additions & 3 deletions cmd/cli/remove/remove.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,27 @@ to quickly create a Cobra application.`,
RunE: func(cmd *cobra.Command, args []string) error {
logger := cmd.Context().Value(constants.LoggerKey{}).(zerolog.Logger)

logger.Info().Any("args", args).Msg("remove command called")
logger.Info().
Str("operation", cmd.Name()).
Any("args", args).
Msg(constants.ProcessCommand)

for _, arg := range args {
req := fmt.Sprintf("%s %s", cmd.Name(), arg)
res, err := utils.SendCommandToDaemon(utils.SocketPath, req)
if err != nil {
logger.Error().Str("command", req).Err(err).Msg(constants.FailedToProcessCommand)
logger.Error().
Str("command", req).
Err(err).
Msg(constants.FailedToProcessCommand)

continue
}

logger.Info().Str("command", req).Str("response", res).Msg(constants.SuccessfullyProcessed)
logger.Info().
Str("command", req).
Str("response", res).
Msg(constants.SuccessfullyProcessed)
}

return nil
Expand Down
1 change: 1 addition & 0 deletions internal/constants/infos.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ const (
TermSignalReceived = "termination signal received"
ShuttingDownDaemon = "shutting down daemon..."
AppStarted = "split-the-tunnel is started!"
ProcessCommand = "processing command"
)
7 changes: 0 additions & 7 deletions internal/constants/keys.go

This file was deleted.

7 changes: 7 additions & 0 deletions internal/constants/others.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package constants

const (
StateFilePath = "/tmp/state.json"
)

type LoggerKey struct{}
13 changes: 9 additions & 4 deletions internal/ipc/ipc.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func handleConnection(conn net.Conn, logger zerolog.Logger) {
logger.Info().Str("command", command).Msg("received command")

st := state.NewState()
if err := st.Read("/tmp/state.json"); err != nil {
if err := st.Read(constants.StateFilePath); err != nil {
logger.Error().Err(err).Msg(constants.FailedToReadState)
continue
}
Expand Down Expand Up @@ -111,7 +111,7 @@ func handleAddCommand(logger zerolog.Logger, gw string, domains []string, conn n
logger = logger.With().Str("operation", "add").Logger()

for _, domain := range domains {
ip, err := utils.ResolveDomain(domain)
ips, err := utils.ResolveDomain(domain)
if err != nil {
if err := writeResponse(&DaemonResponse{
Success: false,
Expand All @@ -127,7 +127,7 @@ func handleAddCommand(logger zerolog.Logger, gw string, domains []string, conn n
continue
}

re := state.NewRouteEntry(domain, ip[0], gw)
re := state.NewRouteEntry(domain, gw, ips)

if err := st.AddEntry(re); err != nil {
if err := writeResponse(&DaemonResponse{
Expand Down Expand Up @@ -212,7 +212,12 @@ func handleListCommand(logger zerolog.Logger, conn net.Conn, st *state.State) {

response.Success = false
response.Response = ""
response.Error = "a dummy error list command"
response.Error = ""

// print the state
for _, entry := range st.Entries {
response.Response += fmt.Sprintf("Domain: %s, Gateway: %s, IPs: %v\n", entry.Domain, entry.Gateway, entry.ResolvedIPs)
}

responseJson, err := json.Marshal(response)
if err != nil {
Expand Down
26 changes: 14 additions & 12 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"os"

"github.com/bilalcaliskan/split-the-tunnel/internal/utils"

"github.com/bilalcaliskan/split-the-tunnel/internal/constants"

"github.com/pkg/errors"
Expand All @@ -22,43 +24,43 @@ func NewState() *State {

// RouteEntry is the struct that holds the state of a single route entry
type RouteEntry struct {
Domain string `json:"domain"`
ResolvedIP string `json:"resolvedIP"`
Gateway string `json:"gateway"`
Domain string `json:"domain"`
Gateway string `json:"gateway"`
ResolvedIPs []string `json:"resolvedIPs"`
}

func NewRouteEntry(domain, resolvedIP, gateway string) *RouteEntry {
func NewRouteEntry(domain, gateway string, resolvedIPs []string) *RouteEntry {
return &RouteEntry{
Domain: domain,
ResolvedIP: resolvedIP,
Gateway: gateway,
Domain: domain,
Gateway: gateway,
ResolvedIPs: resolvedIPs,
}
}

// AddEntry adds a new route entry to the state. If the entry already exists, it updates the ResolvedIP.
func (s *State) AddEntry(entry *RouteEntry) error {
for _, e := range s.Entries {
if e.Domain == entry.Domain {
if e.ResolvedIP == entry.ResolvedIP {
if utils.SlicesEqual(e.ResolvedIPs, entry.ResolvedIPs) {
return errors.New(constants.EntryAlreadyExists)
}

e.ResolvedIP = entry.ResolvedIP
return s.Write("/tmp/state.json")
e.ResolvedIPs = entry.ResolvedIPs
return s.Write(constants.StateFilePath)
}
}

s.Entries = append(s.Entries, entry)

return s.Write("/tmp/state.json")
return s.Write(constants.StateFilePath)
}

// RemoveEntry removes a route entry from the state.
func (s *State) RemoveEntry(domain string) error {
for i, entry := range s.Entries {
if entry.Domain == domain {
s.Entries = append(s.Entries[:i], s.Entries[i+1:]...)
return s.Write("/tmp/state.json")
return s.Write(constants.StateFilePath)
}
}

Expand Down
20 changes: 19 additions & 1 deletion internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"os"
"sort"
"strconv"
"strings"

Expand All @@ -22,7 +23,9 @@ func ResolveDomain(domain string) ([]string, error) {

var ipStrings []string
for _, ip := range ips {
ipStrings = append(ipStrings, ip.String())
if ip.To4() != nil {
ipStrings = append(ipStrings, ip.String())
}
}

return ipStrings, nil
Expand Down Expand Up @@ -88,6 +91,21 @@ func parseHexIP(hexStr string) (string, error) {
return fmt.Sprintf("%d.%d.%d.%d", ipBytes[0], ipBytes[1], ipBytes[2], ipBytes[3]), nil
}

// SlicesEqual checks if two string slices are equal
func SlicesEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
sort.Strings(a)
sort.Strings(b)
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}

/*func addRoute(ip, gateway string) error {
cmd := exec.Command("sudo", "ip", "route", "add", ip, "via", gateway)
err := cmd.Run()
Expand Down
Loading