diff --git a/cmd/root.go b/cmd/root.go index 7680cdc..c9fe897 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -22,7 +22,7 @@ var ( Use: "wg-sync", Short: "syncs peers from a central url ", Long: `replaces all peers with those from a central url`, - RunE: sync, + RunE: syncPeers, } cfgFile string ) @@ -38,7 +38,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&cfgFile, "url", "", "config file (default is $HOME/.wg-sync.yaml)") } -func sync(cmd *cobra.Command, args []string) error { +func syncPeers(cmd *cobra.Command, args []string) error { resp, err := http.Get(cfgFile) if err != nil { diff --git a/cmd/serve.go b/cmd/serve.go index f34a159..5edfb73 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -4,17 +4,22 @@ import ( "context" "encoding/json" "errors" + "fmt" "log" "net" "net/http" "os" "os/signal" + "sync" "syscall" "time" + "github.com/paulgmiller/wg-sync/nethelpers" "github.com/paulgmiller/wg-sync/pretty" "github.com/paulgmiller/wg-sync/wghelpers" + "github.com/samber/lo" "github.com/spf13/cobra" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) const defaultJoinPort = ":5000" @@ -33,8 +38,9 @@ func init() { } func serve(cmd *cobra.Command, args []string) error { - http.HandleFunc("/peers", Peers) - srv := http.Server{Addr: ":8888"} + mux := http.NewServeMux() + mux.HandleFunc("/peers", Peers) + srv := http.Server{Addr: ":8888", Handler: mux} ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -46,7 +52,7 @@ func serve(cmd *cobra.Command, args []string) error { } }() - err := HaddleJoins(ctx) + err := HaddleJoins(ctx, cidrAllocatorImpl{}) if err != nil { log.Printf("udp handler exited with %s", err) } @@ -71,7 +77,19 @@ type joinResponse struct { Peers []pretty.Peer } -func HaddleJoins(ctx context.Context) error { +type cidrAllocator interface { + Allocate() (net.IP, error) +} + +type cidrAllocatorImpl struct{} + +func (c cidrAllocatorImpl) Allocate() (net.IP, error) { + return net.ParseIP("10.0.0.100"), nil +} + +var lock sync.Mutex + +func HaddleJoins(ctx context.Context, alloc cidrAllocator) error { udpaddr, err := net.ResolveUDPAddr("udp", "127.0.0.1"+defaultJoinPort) if err != nil { return err @@ -83,8 +101,8 @@ func HaddleJoins(ctx context.Context) error { log.Printf("Waiting for joins on %s", udpaddr.String()) go func() { for { - buf := make([]byte, 4096) //how big should we be? will we go over multiple packets? - n, remoteAddr, err := conn.ReadFromUDP(buf) + buf := make([]byte, 4096) //how big should we be? will we go over multiple packets? + n, remoteAddr, err := conn.ReadFromUDP(buf) //has to be this ratehr than desrialize because we need the remote addr or we get write: destination address required if err != nil { if !errors.Is(err, net.ErrClosed) { log.Printf("Failed to read from udp: %s", err) @@ -100,16 +118,21 @@ func HaddleJoins(ctx context.Context) error { continue } + //obviously bad. + if jreq.AuthToken != "HOKEYPOKEYSMOKEY" { + log.Printf("bad auth token from %v, %s", remoteAddr, jreq.PublicKey) + //ban them for a extended period? + continue + } + log.Printf("got join request from %v, %s", remoteAddr, jreq.PublicKey) - jResp := joinResponse{ - Assignedip: "10.0.0.100", - Peers: []pretty.Peer{ - { - PublicKey: "amMRWDvsLUmNHn52xer2yl/UaAkXnDrd/HxUTRkEGXc=", - AllowedIPs: "10.0.0.0/24", - }, - }, + jResp, err := GenerateResponse(jreq, alloc) + if err != nil { + log.Printf("Failed to generate response %s", err) + //ban them for a extended period? + continue } + respbuf, err := json.Marshal(jResp) if err != nil { log.Printf("Failed to enode: %s", err) @@ -130,6 +153,50 @@ func HaddleJoins(ctx context.Context) error { } +func GenerateResponse(jreq joinRequest, alloc cidrAllocator) (joinResponse, error) { + lock.Lock() + defer lock.Unlock() + + d0, err := wghelpers.GetDevice() + if err != nil { + return joinResponse{}, err + } + + var asssignedip string + existing, found := lo.Find(d0.Peers, func(p wgtypes.Peer) bool { return p.PublicKey.String() == jreq.PublicKey }) + if found { //should we also check that the ip is the same? + log.Printf("peer %s already exists", jreq.PublicKey) + asssignedip = existing.AllowedIPs[0].String() + } else { + ip, err := alloc.Allocate() + if err != nil { + //not nice to not tell them sorry? But then we need an error protocol + return joinResponse{}, err + } + asssignedip = ip.String() + } + + //ad the peer to us before we return anything + + cidr, err := nethelpers.GetWireGaurdCIDR(d0.Name) + if err != nil { + return joinResponse{}, err + } + + //ip, cinet.ParseCIDR(cidr.String()) + + return joinResponse{ + Assignedip: asssignedip, + Peers: []pretty.Peer{ + { + PublicKey: d0.PublicKey.String(), + AllowedIPs: cidr.String(), //too much throttle down to /32? + Endpoint: fmt.Sprintf("%s:%d", nethelpers.GetOutboundIP(), d0.ListenPort), //just pass this in instead of trying to detect it? + }, + }, + }, nil +} + func Peers(resp http.ResponseWriter, req *http.Request) { d0, err := wghelpers.GetDevice() if err != nil { diff --git a/nethelpers/helpers.go b/nethelpers/helpers.go index 88e491e..b6ff3e9 100644 --- a/nethelpers/helpers.go +++ b/nethelpers/helpers.go @@ -10,22 +10,22 @@ import ( "github.com/samber/lo" ) -func GetWireGaurdIP(interfacename string) string { +func GetWireGaurdCIDR(interfacename string) (net.Addr, error) { ifaces, err := net.Interfaces() if err != nil { - log.Fatalf("can't get interfaces: %v", err) + return nil, err } wginterface, found := lo.Find(ifaces, func(iface net.Interface) bool { return iface.Name == interfacename }) if !found { - log.Fatalf("can't get interfaces: %v", err) + return nil, err } addrs, err := wginterface.Addrs() if err != nil { - log.Fatalf("can't get interface addrs: %v", err) + return nil, err } - return addrs[0].String() + return addrs[0], nil } func GetOutboundIP() string {