diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..378eac2 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +build diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a0cfa3b --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +.PHONY: tailws +all: build + +build: tailws + +install: install_tailws + +tailws: + @echo "Building tail-ws" + go build -o build/tail-ws ./cmd/tail-ws + +install_tailws: + @echo "Installing tail-ws" + go install ./cmd/tail-ws + +clean: + rm -rf build diff --git a/README.md b/README.md index 9583a00..03d327d 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,31 @@ -Broadcaster for appended file lines -=================================== +WebSocket broadcaster for appended file lines +============================================= +[![Go Report Card](https://goreportcard.com/badge/github.com/jeronimoalbi/tail-ws)](https://goreportcard.com/report/github.com/jeronimoalbi/tail-ws) [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) + +Installation +------------ + +Install the binary by running: + +``` +go install github.com/jeronimoalbi/tail-ws/cmd/tail-ws@latest +``` + +or alternatively: + +``` +make install +``` + +Run +--- + +To start broadcasting appended lines run: + +``` +tail-ws FILE +``` + +New lines are broadcasted by default from the address `ws://127.0.0.1:8080`. diff --git a/broadcast/connections.go b/broadcast/connections.go new file mode 100644 index 0000000..e3300f7 --- /dev/null +++ b/broadcast/connections.go @@ -0,0 +1,65 @@ +package broadcast + +import ( + "sync" + + "github.com/gorilla/websocket" +) + +// NewConnections create a new Websocket connections registry. +func NewConnections() *Connections { + return &Connections{ + registry: make(map[*websocket.Conn]struct{}), + } +} + +// Connections keeps track of active Websocket connections. +type Connections struct { + mu sync.RWMutex + registry map[*websocket.Conn]struct{} +} + +// IsEmpty checks if there are registered connections. +func (c *Connections) IsEmpty() bool { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.registry) == 0 +} + +// Add adds a new Websocket connection to the registry. +func (c *Connections) Add(ws *websocket.Conn) { + c.mu.Lock() + c.registry[ws] = struct{}{} + c.mu.Unlock() +} + +// Delete removes a Websocket connection from the registry. +// Connections are closed after being removed. +func (c *Connections) Delete(ws *websocket.Conn) error { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.registry, ws) + return ws.Close() +} + +// Close closes all connections. +func (c *Connections) Close() { + c.Iter(func(ws *websocket.Conn) bool { + ws.Close() + return true + }) +} + +// Iter allows iterating the current connections. +// Iteration stops when when false is returned. +func (c *Connections) Iter(fn func(*websocket.Conn) bool) { + c.mu.RLock() + for ws := range c.registry { + if !fn(ws) { + return + } + } + c.mu.RUnlock() +} diff --git a/broadcast/server.go b/broadcast/server.go new file mode 100644 index 0000000..2dd5f44 --- /dev/null +++ b/broadcast/server.go @@ -0,0 +1,214 @@ +package broadcast + +import ( + "bufio" + "context" + "errors" + "log" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" + + "github.com/jeronimoalbi/tail-ws/watch" +) + +var ( + // DefaultAddr defines the default Websocket server address. + DefaultAddr = "127.0.0.1:8080" + + maxMessageSize int64 = 1024 + pingPeriod = (pongWait * 9) / 10 + pongWait = 60 * time.Second + writeWait = 8 * time.Second +) + +// Option configures transaction broadcast servers. +type Option func(*Server) + +// Address sets the server address. +func Address(addr string) Option { + return func(s *Server) { + s.addr = addr + } +} + +// Origin sets the allowed origin for incoming requests. +func Origin(origin string) Option { + return func(s *Server) { + s.origin = origin + } +} + +// Secure enables secure Websockets (WSS). +func Secure(certFile, keyFile string) Option { + return func(s *Server) { + s.certFile = certFile + s.keyFile = keyFile + } +} + +// NewServer creates a new transactions broadcast server. +func NewServer(options ...Option) *Server { + s := Server{ + addr: DefaultAddr, + connections: NewConnections(), + } + + for _, apply := range options { + apply(&s) + } + + s.upgrader.CheckOrigin = func(r *http.Request) bool { + if s.origin != "" { + return s.origin == r.Header.Get("Origin") + } + return true + } + + return &s +} + +// Server handles Websocket connections and broadcasts new transactions. +// It watches the transactions head file and when new transactions are indexed +// it pushes the new entries to the connected clients. +type Server struct { + addr, origin, certFile, keyFile string + reader watch.Reader + connections *Connections + upgrader websocket.Upgrader +} + +// HandleWS is an HTTP handler that upgrades incoming connections to WS or WSS. +func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { + // TODO: Add authentication support + log.Printf("connection stablished with %s", r.RemoteAddr) + + ws, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + // Upgrade already returns the error to the client on failure + log.Printf("connection from %s failed: %v", r.RemoteAddr, err) + return + } + + ws.SetReadLimit(maxMessageSize) + + // Prepare keep alive protocol for the new connection + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + + // Launch a gopher to keep connection alive + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)) + if err != nil { + log.Printf("error sending ping: %v", err) + ws.Close() + } + case <-ctx.Done(): + return + } + } + }() + + // Make sure to cleanup connection when closed + ws.SetCloseHandler(func(int, string) error { + log.Printf("closing connextion %s", ws.RemoteAddr()) + cancel() + return s.connections.Delete(ws) + }) + + s.connections.Add(ws) +} + +// Start starts a new HTTP server to listen for incoming WS or WSS connections. +func (s *Server) Start(ctx context.Context) error { + g, ctx := errgroup.WithContext(ctx) + server := &http.Server{ + Addr: s.addr, + Handler: http.HandlerFunc(s.HandleWS), + BaseContext: func(l net.Listener) context.Context { + return ctx + }, + } + + g.Go(func() error { + <-ctx.Done() + s.connections.Close() + return server.Close() + }) + + g.Go(func() error { + var err error + if s.certFile != "" && s.keyFile != "" { + log.Printf("listening for connections -> wss://%s", s.addr) + err = server.ListenAndServeTLS(s.certFile, s.keyFile) + } else { + log.Printf("listening for connections -> ws://%s", s.addr) + err = server.ListenAndServe() + } + + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err + }) + + return g.Wait() +} + +// Watch starts watching a transaction head file and broadcasts +// the newly indexed transactions to all connected peers. +func (s *Server) Watch(ctx context.Context, name string) error { + r := watch.NewReader(watch.SeekEnd()) + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + scanner := bufio.NewScanner(r) + for scanner.Scan() { + s.broadcast(scanner.Bytes()) + } + + return scanner.Err() + }) + + g.Go(func() error { + defer r.Close() + + for { + // Keep watching when the file is ovewritten + if err := r.Watch(ctx, name); err != watch.ErrFileOverwritten { + return err + } + } + }) + + return g.Wait() +} + +func (s Server) broadcast(tx []byte) { + s.connections.Iter(func(ws *websocket.Conn) bool { + go func() { + ws.SetWriteDeadline(time.Now().Add(writeWait)) + + if err := ws.WriteMessage(websocket.BinaryMessage, tx); err != nil { + log.Printf("tx broadcast failed: %v", err) + ws.Close() + } + }() + + return true + }) +} diff --git a/cmd/tail-ws/main.go b/cmd/tail-ws/main.go new file mode 100644 index 0000000..2cd0fb7 --- /dev/null +++ b/cmd/tail-ws/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "os" + + "golang.org/x/sync/errgroup" + + "github.com/jeronimoalbi/tail-ws/broadcast" +) + +const usage = `Usage: tail-ws [OPTION]... FILE + +WebSocket broadcaster for appended file lines. + +Options:` + +func main() { + var ( + addr, origin, certFile, keyFile string + verbose bool + ) + + flag.StringVar(&addr, "address", "127.0.0.1:8080", "server address") + flag.StringVar(&origin, "allow-origin", "", "address of the origin allowed to connect") + flag.BoolVar(&verbose, "verbose", false, "verbose output") + flag.StringVar(&certFile, "cert-file", "", "certificate file for WSS server") + flag.StringVar(&keyFile, "key-file", "", "private key file for WSS server") + flag.Parse() + + fileName := flag.Arg(0) + if fileName == "" { + printUsage() + os.Exit(1) + } + + if !verbose { + log.SetOutput(io.Discard) + } + + server := broadcast.NewServer( + broadcast.Address(addr), + broadcast.Origin(origin), + broadcast.Secure(certFile, keyFile), + ) + g, ctx := errgroup.WithContext(context.Background()) + + g.Go(func() error { + return server.Watch(ctx, fileName) + }) + + g.Go(func() error { + return server.Start(ctx) + }) + + if err := g.Wait(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func printUsage() { + fmt.Fprintln(os.Stderr, usage) + flag.PrintDefaults() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f744de8 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/jeronimoalbi/tail-ws + +go 1.19 + +require ( + github.com/fsnotify/fsnotify v1.6.0 + github.com/gorilla/websocket v1.5.0 + golang.org/x/sync v0.1.0 +) + +require golang.org/x/sys v0.6.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3694877 --- /dev/null +++ b/go.sum @@ -0,0 +1,9 @@ +github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= +github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/watch/watch.go b/watch/watch.go new file mode 100644 index 0000000..6b344c5 --- /dev/null +++ b/watch/watch.go @@ -0,0 +1,222 @@ +package watch + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/fsnotify/fsnotify" + "golang.org/x/sync/errgroup" +) + +var ( + ErrFileDeleted = errors.New("watched file has been deleted") + ErrFileOverwritten = errors.New("watched file has been overwritten") + ErrFileRenamed = errors.New("watched file has been renamed") +) + +// Watcher defines the interface for trasaction head file watchers. +type Watcher interface { + // Watch starts watching a transaction head file for changer. + Watch(ctx context.Context, name string) error +} + +// OffsetGetter defines the interface for transaction head file offset getters. +type OffsetGetter interface { + // GetOffset returns the current transaction head file offset. + GetOffset() int64 +} + +// Reader defines the interface for transaction head file readers. +type Reader interface { + io.ReadCloser + Watcher + OffsetGetter +} + +// Option configures watch readers. +type Option func(*reader) + +// StartOffset sets the initial offset for the reader. +func StartOffset(offset int64) Option { + return func(r *reader) { + r.offset = offset + } +} + +// SeekEnd points the reader to the end of the contents. +func SeekEnd() Option { + return func(r *reader) { + r.seekEnd = true + } +} + +// NewReader creates a new transaction head file reader. +func NewReader(options ...Option) Reader { + r := &reader{ + read: make(chan struct{}), + } + + for _, apply := range options { + apply(r) + } + + return r +} + +type reader struct { + offset int64 + seekEnd bool + file io.ReadSeekCloser + read chan struct{} +} + +func (r *reader) Read(b []byte) (n int, err error) { + // Block until watch is called + if _, ok := <-r.read; !ok { + // Channel closed before the first read + return 0, io.EOF + } + + for { + n, err := r.file.Read(b) + if err != nil && err != io.EOF { + return n, err + } + + // On successful read increment offset + if n > 0 { + r.offset += int64(n) + return n, nil + } + + // When read fails with EOF wait until more content is available + if _, ok := <-r.read; !ok { + // Finish reading when read channel is closed + return 0, io.EOF + } + } +} + +func (r *reader) Close() error { + close(r.read) + return nil +} + +func (r *reader) GetOffset() int64 { + return r.offset +} + +func (r *reader) Watch(ctx context.Context, name string) error { + name, err := filepath.Abs(name) + if err != nil { + return err + } + + r.file, err = r.openFile(name) + if err != nil { + return err + } + + defer func() { + r.file.Close() + }() + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + + defer watcher.Close() + + dir := filepath.Dir(name) + if err := watcher.Add(dir); err != nil { + return err + } + + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + for { + select { + case e := <-watcher.Events: + if e.Name != name { + continue + } + + if err := r.handleEvent(e); err != nil { + return err + } + case err := <-watcher.Errors: + if err != nil { + return err + } + case <-ctx.Done(): + return nil + } + } + }) + + return g.Wait() +} + +func (r *reader) openFile(name string) (io.ReadSeekCloser, error) { + f, err := os.Open(name) + if err != nil { + return nil, err + } + + if r.offset > 0 { + if _, err := f.Seek(r.offset, io.SeekStart); err != nil { + return nil, fmt.Errorf("error seeking file offset %d: %w", r.offset, err) + } + } else if r.seekEnd { + if _, err := f.Seek(0, io.SeekEnd); err != nil { + return nil, fmt.Errorf("error seeking file end: %w", err) + } + } + + return f, nil +} + +func (r *reader) handleEvent(e fsnotify.Event) error { + if e.Has(fsnotify.Remove) { + if _, err := os.Stat(e.Name); os.IsNotExist(err) { + return ErrFileDeleted + } + return ErrFileOverwritten + } + + if e.Has(fsnotify.Rename) { + return ErrFileRenamed + } + + // On write keep reading the file + if e.Has(fsnotify.Write) { + info, err := os.Stat(e.Name) + if err != nil { + return err + } + + // Reset offset when file is truncated + if size := info.Size(); size <= r.offset { + if _, err := r.file.Seek(0, io.SeekStart); err != nil { + return err + } + + r.offset = 0 + + // Ignore empty truncated file events + if size == 0 { + return nil + } + } + + r.read <- struct{}{} + } + + return nil +}