diff --git a/go.mod b/go.mod index a7849eb01..7dd25ac1c 100644 --- a/go.mod +++ b/go.mod @@ -5,22 +5,18 @@ go 1.18 require ( github.com/creasty/defaults v1.7.0 github.com/go-redis/redis/v8 v8.11.5 - github.com/go-sql-driver/mysql v1.7.1 github.com/goccy/go-yaml v1.11.2 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.4.0 github.com/icinga/icinga-go-library v0.0.0-20231121130315-330ba126e732 github.com/jessevdk/go-flags v1.5.0 github.com/jmoiron/sqlx v1.3.5 - github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.18 github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd github.com/pkg/errors v0.9.1 - github.com/ssgreg/journald v1.0.0 github.com/stretchr/testify v1.8.4 github.com/vbauerster/mpb/v6 v6.0.4 go.uber.org/zap v1.26.0 - golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa golang.org/x/sync v0.5.0 ) @@ -31,12 +27,16 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fatih/color v1.16.0 // indirect + github.com/go-sql-driver/mysql v1.7.1 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.4 // indirect + github.com/ssgreg/journald v1.0.0 // indirect go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect golang.org/x/sys v0.14.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/pkg/backoff/backoff.go b/pkg/backoff/backoff.go deleted file mode 100644 index 6ce7beef8..000000000 --- a/pkg/backoff/backoff.go +++ /dev/null @@ -1,43 +0,0 @@ -package backoff - -import ( - "math/rand" - "time" -) - -// Backoff returns the backoff duration for a specific retry attempt. -type Backoff func(uint64) time.Duration - -// NewExponentialWithJitter returns a backoff implementation that -// exponentially increases the backoff duration for each retry from min, -// never exceeding max. Some randomization is added to the backoff duration. -// It panics if min >= max. -func NewExponentialWithJitter(min, max time.Duration) Backoff { - if min <= 0 { - min = time.Millisecond * 100 - } - if max <= 0 { - max = time.Second * 10 - } - if min >= max { - panic("max must be larger than min") - } - - return func(attempt uint64) time.Duration { - e := min << attempt - if e <= 0 || e > max { - e = max - } - - return time.Duration(jitter(int64(e))) - } -} - -// jitter returns a random integer distributed in the range [n/2..n). -func jitter(n int64) int64 { - if n == 0 { - return 0 - } - - return n/2 + rand.Int63n(n/2) -} diff --git a/pkg/com/atomic.go b/pkg/com/atomic.go deleted file mode 100644 index 316413dfd..000000000 --- a/pkg/com/atomic.go +++ /dev/null @@ -1,38 +0,0 @@ -package com - -import "sync/atomic" - -// Atomic is a type-safe wrapper around atomic.Value. -type Atomic[T any] struct { - v atomic.Value -} - -func (a *Atomic[T]) Load() (_ T, ok bool) { - if v, ok := a.v.Load().(box[T]); ok { - return v.v, true - } - - return -} - -func (a *Atomic[T]) Store(v T) { - a.v.Store(box[T]{v}) -} - -func (a *Atomic[T]) Swap(new T) (old T, ok bool) { - if old, ok := a.v.Swap(box[T]{new}).(box[T]); ok { - return old.v, true - } - - return -} - -func (a *Atomic[T]) CompareAndSwap(old, new T) (swapped bool) { - return a.v.CompareAndSwap(box[T]{old}, box[T]{new}) -} - -// box allows, for the case T is an interface, nil values and values of different specific types implementing T -// to be stored in Atomic[T]#v (bypassing atomic.Value#Store()'s policy) by wrapping it (into a non-interface). -type box[T any] struct { - v T -} diff --git a/pkg/com/bulker.go b/pkg/com/bulker.go deleted file mode 100644 index e4fe7aa9e..000000000 --- a/pkg/com/bulker.go +++ /dev/null @@ -1,166 +0,0 @@ -package com - -import ( - "context" - "golang.org/x/sync/errgroup" - "sync" - "time" -) - -// BulkChunkSplitPolicy is a state machine which tracks the items of a chunk a bulker assembles. -// A call takes an item for the current chunk into account. -// Output true indicates that the state machine was reset first and the bulker -// shall finish the current chunk now (not e.g. once $size is reached) without the given item. -type BulkChunkSplitPolicy[T any] func(T) bool - -type BulkChunkSplitPolicyFactory[T any] func() BulkChunkSplitPolicy[T] - -// NeverSplit returns a pseudo state machine which never demands splitting. -func NeverSplit[T any]() BulkChunkSplitPolicy[T] { - return neverSplit[T] -} - -func neverSplit[T any](T) bool { - return false -} - -// Bulker reads all values from a channel and streams them in chunks into a Bulk channel. -type Bulker[T any] struct { - ch chan []T - ctx context.Context - mu sync.Mutex -} - -// NewBulker returns a new Bulker and starts streaming. -func NewBulker[T any]( - ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T], -) *Bulker[T] { - b := &Bulker[T]{ - ch: make(chan []T), - ctx: ctx, - mu: sync.Mutex{}, - } - - go b.run(ch, count, splitPolicyFactory) - - return b -} - -// Bulk returns the channel on which the bulks are delivered. -func (b *Bulker[T]) Bulk() <-chan []T { - return b.ch -} - -func (b *Bulker[T]) run(ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T]) { - defer close(b.ch) - - bufCh := make(chan T, count) - splitPolicy := splitPolicyFactory() - g, ctx := errgroup.WithContext(b.ctx) - - g.Go(func() error { - defer close(bufCh) - - for { - select { - case v, ok := <-ch: - if !ok { - return nil - } - - bufCh <- v - case <-ctx.Done(): - return ctx.Err() - } - } - }) - - g.Go(func() error { - for done := false; !done; { - buf := make([]T, 0, count) - timeout := time.After(256 * time.Millisecond) - - for drain := true; drain && len(buf) < count; { - select { - case v, ok := <-bufCh: - if !ok { - drain = false - done = true - - break - } - - if splitPolicy(v) { - if len(buf) > 0 { - b.ch <- buf - buf = make([]T, 0, count) - } - - timeout = time.After(256 * time.Millisecond) - } - - buf = append(buf, v) - case <-timeout: - drain = false - case <-ctx.Done(): - return ctx.Err() - } - } - - if len(buf) > 0 { - b.ch <- buf - } - - splitPolicy = splitPolicyFactory() - } - - return nil - }) - - // We don't expect an error here. - // We only use errgroup for the encapsulated use of sync.WaitGroup. - _ = g.Wait() -} - -// Bulk reads all values from a channel and streams them in chunks into a returned channel. -func Bulk[T any]( - ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T], -) <-chan []T { - if count <= 1 { - return oneBulk(ctx, ch) - } - - return NewBulker(ctx, ch, count, splitPolicyFactory).Bulk() -} - -// oneBulk operates just as NewBulker(ctx, ch, 1, splitPolicy).Bulk(), -// but without the overhead of the actual bulk creation with a buffer channel, timeout and BulkChunkSplitPolicy. -func oneBulk[T any](ctx context.Context, ch <-chan T) <-chan []T { - out := make(chan []T) - go func() { - defer close(out) - - for { - select { - case item, ok := <-ch: - if !ok { - return - } - - select { - case out <- []T{item}: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } - } - }() - - return out -} - -var ( - _ BulkChunkSplitPolicyFactory[struct{}] = NeverSplit[struct{}] -) diff --git a/pkg/com/com.go b/pkg/com/com.go deleted file mode 100644 index a28bf8c93..000000000 --- a/pkg/com/com.go +++ /dev/null @@ -1,85 +0,0 @@ -package com - -import ( - "context" - "github.com/icinga/icingadb/pkg/types" - "github.com/pkg/errors" - "golang.org/x/sync/errgroup" -) - -// WaitAsync calls Wait() on the passed Waiter in a new goroutine and -// sends the first non-nil error (if any) to the returned channel. -// The returned channel is always closed when the Waiter is done. -func WaitAsync(ctx context.Context, w Waiter) <-chan error { - errs := make(chan error, 1) - - go func() { - defer close(errs) - - if e := w.Wait(); e != nil { - select { - case errs <- e: - case <-ctx.Done(): - } - } - }() - - return errs -} - -// ErrgroupReceive adds a goroutine to the specified group that -// returns the first non-nil error (if any) from the specified channel. -// If the channel is closed, it will return nil. -func ErrgroupReceive(ctx context.Context, g *errgroup.Group, err <-chan error) { - g.Go(func() error { - select { - case e, closed := <-err: - if closed { - return nil - } - - return e - case <-ctx.Done(): - return ctx.Err() - } - }) -} - -// CopyFirst asynchronously forwards all items from input to forward and synchronously returns the first item. -func CopyFirst[T any](ctx context.Context, input <-chan T) (T, <-chan T, error) { - select { - case first, ok := <-input: - if !ok { - return types.Zero[T](), nil, errors.New("can't read from closed channel") - } - - // Buffer of one because we receive an entity and send it back immediately. - forward := make(chan T, 1) - forward <- first - - go func() { - defer close(forward) - - for { - select { - case e, ok := <-input: - if !ok { - return - } - - select { - case forward <- e: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } - } - }() - - return first, forward, nil - case <-ctx.Done(): - return types.Zero[T](), nil, ctx.Err() - } -} diff --git a/pkg/com/cond.go b/pkg/com/cond.go deleted file mode 100644 index 72ba347c5..000000000 --- a/pkg/com/cond.go +++ /dev/null @@ -1,90 +0,0 @@ -package com - -import ( - "context" - "github.com/pkg/errors" -) - -// Cond implements a channel-based synchronization for goroutines that wait for signals or send them. -// Internally based on a controller loop that handles the synchronization of new listeners and signal propagation, -// which is only started when NewCond is called. Thus the zero value cannot be used. -type Cond struct { - broadcast chan struct{} - done chan struct{} - cancel context.CancelFunc - listeners chan chan struct{} -} - -// NewCond returns a new Cond and starts the controller loop. -func NewCond(ctx context.Context) *Cond { - ctx, cancel := context.WithCancel(ctx) - - c := &Cond{ - broadcast: make(chan struct{}), - cancel: cancel, - done: make(chan struct{}), - listeners: make(chan chan struct{}), - } - - go c.controller(ctx) - - return c -} - -// Broadcast sends a signal to all current listeners by closing the previously returned channel from Wait. -// Panics if the controller loop has already ended. -func (c *Cond) Broadcast() { - select { - case c.broadcast <- struct{}{}: - case <-c.done: - panic(errors.New("condition closed")) - } -} - -// Close stops the controller loop, waits for it to finish, and returns an error if any. -// Implements the io.Closer interface. -func (c *Cond) Close() error { - c.cancel() - <-c.done - - return nil -} - -// Done returns a channel that will be closed when the controller loop has ended. -func (c *Cond) Done() <-chan struct{} { - return c.done -} - -// Wait returns a channel that is closed with the next signal. -// Panics if the controller loop has already ended. -func (c *Cond) Wait() <-chan struct{} { - select { - case l := <-c.listeners: - return l - case <-c.done: - panic(errors.New("condition closed")) - } -} - -// controller loop. -func (c *Cond) controller(ctx context.Context) { - defer close(c.done) - - // Note that the notify channel does not close when the controller loop ends - // in order not to notify pending listeners. - notify := make(chan struct{}) - - for { - select { - case <-c.broadcast: - // Close channel to notify all current listeners. - close(notify) - // Create a new channel for the next listeners. - notify = make(chan struct{}) - case c.listeners <- notify: - // A new listener received the channel. - case <-ctx.Done(): - return - } - } -} diff --git a/pkg/com/contracts.go b/pkg/com/contracts.go deleted file mode 100644 index ce40061cb..000000000 --- a/pkg/com/contracts.go +++ /dev/null @@ -1,16 +0,0 @@ -package com - -// Waiter implements the Wait method, -// which blocks until execution is complete. -type Waiter interface { - Wait() error // Wait waits for execution to complete. -} - -// The WaiterFunc type is an adapter to allow the use of ordinary functions as Waiter. -// If f is a function with the appropriate signature, WaiterFunc(f) is a Waiter that calls f. -type WaiterFunc func() error - -// Wait implements the Waiter interface. -func (f WaiterFunc) Wait() error { - return f() -} diff --git a/pkg/com/counter.go b/pkg/com/counter.go deleted file mode 100644 index 52f9f7ff2..000000000 --- a/pkg/com/counter.go +++ /dev/null @@ -1,48 +0,0 @@ -package com - -import ( - "sync" - "sync/atomic" -) - -// Counter implements an atomic counter. -type Counter struct { - value uint64 - mu sync.Mutex // Protects total. - total uint64 -} - -// Add adds the given delta to the counter. -func (c *Counter) Add(delta uint64) { - atomic.AddUint64(&c.value, delta) -} - -// Inc increments the counter by one. -func (c *Counter) Inc() { - c.Add(1) -} - -// Reset resets the counter to 0 and returns its previous value. -// Does not reset the total value returned from Total. -func (c *Counter) Reset() uint64 { - c.mu.Lock() - defer c.mu.Unlock() - - v := atomic.SwapUint64(&c.value, 0) - c.total += v - - return v -} - -// Total returns the total counter value. -func (c *Counter) Total() uint64 { - c.mu.Lock() - defer c.mu.Unlock() - - return c.total + c.Val() -} - -// Val returns the current counter value. -func (c *Counter) Val() uint64 { - return atomic.LoadUint64(&c.value) -} diff --git a/pkg/config/config.go b/pkg/config/config.go deleted file mode 100644 index 700cf9ed7..000000000 --- a/pkg/config/config.go +++ /dev/null @@ -1,116 +0,0 @@ -package config - -import ( - "crypto/tls" - "crypto/x509" - "github.com/creasty/defaults" - "github.com/goccy/go-yaml" - "github.com/jessevdk/go-flags" - "github.com/pkg/errors" - "io/ioutil" - "os" -) - -type Initer interface { - Init() -} - -type Validator interface { - Validate() error -} - -// FromYAMLFile returns a new value of type T created from the given YAML config file. -func FromYAMLFile[T any, P interface { - *T - Validator -}](name string) (*T, error) { - f, err := os.Open(name) - if err != nil { - return nil, errors.Wrap(err, "can't open YAML file "+name) - } - defer f.Close() - - c := P(new(T)) - - if initer, ok := any(c).(Initer); ok { - initer.Init() - } - - if err := defaults.Set(c); err != nil { - return nil, errors.Wrap(err, "can't set config defaults") - } - - d := yaml.NewDecoder(f, yaml.DisallowUnknownField()) - if err := d.Decode(c); err != nil { - return nil, errors.Wrap(err, "can't parse YAML file "+name) - } - - if err := c.Validate(); err != nil { - return nil, errors.Wrap(err, "invalid configuration") - } - - return c, nil -} - -// ParseFlags parses CLI flags and -// returns a Flags value created from them. -func ParseFlags[T any, P interface{ *T }]() (*T, error) { - f := P(new(T)) - parser := flags.NewParser(f, flags.Default) - - if _, err := parser.Parse(); err != nil { - return nil, errors.Wrap(err, "can't parse CLI flags") - } - - return f, nil -} - -// TLS provides TLS configuration options. -type TLS struct { - Enable bool `yaml:"tls"` - Cert string `yaml:"cert"` - Key string `yaml:"key"` - Ca string `yaml:"ca"` - Insecure bool `yaml:"insecure"` -} - -// MakeConfig assembles a tls.Config from t and serverName. -func (t *TLS) MakeConfig(serverName string) (*tls.Config, error) { - if !t.Enable { - return nil, nil - } - - tlsConfig := &tls.Config{} - if t.Cert == "" { - if t.Key != "" { - return nil, errors.New("private key given, but client certificate missing") - } - } else if t.Key == "" { - return nil, errors.New("client certificate given, but private key missing") - } else { - crt, err := tls.LoadX509KeyPair(t.Cert, t.Key) - if err != nil { - return nil, errors.Wrap(err, "can't load X.509 key pair") - } - - tlsConfig.Certificates = []tls.Certificate{crt} - } - - if t.Insecure { - tlsConfig.InsecureSkipVerify = true - } else if t.Ca != "" { - raw, err := ioutil.ReadFile(t.Ca) - if err != nil { - return nil, errors.Wrap(err, "can't read CA file") - } - - tlsConfig.RootCAs = x509.NewCertPool() - if !tlsConfig.RootCAs.AppendCertsFromPEM(raw) { - return nil, errors.New("can't parse CA file") - } - } - - tlsConfig.ServerName = serverName - - return tlsConfig, nil -} diff --git a/pkg/database/cleanup.go b/pkg/database/cleanup.go deleted file mode 100644 index 3f485b0ef..000000000 --- a/pkg/database/cleanup.go +++ /dev/null @@ -1,79 +0,0 @@ -package database - -import ( - "context" - "fmt" - "github.com/icinga/icingadb/pkg/com" - "github.com/icinga/icingadb/pkg/driver" - "github.com/icinga/icingadb/pkg/types" - "time" -) - -// CleanupStmt defines information needed to compose cleanup statements. -type CleanupStmt struct { - Table string - PK string - Column string -} - -// Build assembles the cleanup statement for the specified database driver with the given limit. -func (stmt *CleanupStmt) Build(driverName string, limit uint64) string { - switch driverName { - case driver.MySQL, "mysql": - return fmt.Sprintf(`DELETE FROM %[1]s WHERE environment_id = :environment_id AND %[2]s < :time -ORDER BY %[2]s LIMIT %[3]d`, stmt.Table, stmt.Column, limit) - case driver.PostgreSQL, "postgres": - return fmt.Sprintf(`WITH rows AS ( -SELECT %[1]s FROM %[2]s WHERE environment_id = :environment_id AND %[3]s < :time ORDER BY %[3]s LIMIT %[4]d -) -DELETE FROM %[2]s WHERE %[1]s IN (SELECT %[1]s FROM rows)`, stmt.PK, stmt.Table, stmt.Column, limit) - default: - panic(fmt.Sprintf("invalid database type %s", driverName)) - } -} - -// CleanupOlderThan deletes all rows with the specified statement that are older than the given time. -// Deletes a maximum of as many rows per round as defined in count. Actually deleted rows will be passed to onSuccess. -// Returns the total number of rows deleted. -func (db *DB) CleanupOlderThan( - ctx context.Context, stmt CleanupStmt, envId types.Binary, - count uint64, olderThan time.Time, onSuccess ...OnSuccess[struct{}], -) (uint64, error) { - var counter com.Counter - defer db.log(ctx, stmt.Build(db.DriverName(), 0), &counter).Stop() - - for { - q := db.Rebind(stmt.Build(db.DriverName(), count)) - rs, err := db.NamedExecContext(ctx, q, cleanupWhere{ - EnvironmentId: envId, - Time: types.UnixMilli(olderThan), - }) - if err != nil { - return 0, CantPerformQuery(err, q) - } - - n, err := rs.RowsAffected() - if err != nil { - return 0, err - } - - counter.Add(uint64(n)) - - for _, onSuccess := range onSuccess { - if err := onSuccess(ctx, make([]struct{}, n)); err != nil { - return 0, err - } - } - - if n < int64(count) { - break - } - } - - return counter.Total(), nil -} - -type cleanupWhere struct { - EnvironmentId types.Binary - Time types.UnixMilli -} diff --git a/pkg/database/config.go b/pkg/database/config.go deleted file mode 100644 index c27b803fe..000000000 --- a/pkg/database/config.go +++ /dev/null @@ -1,45 +0,0 @@ -package database - -import ( - "github.com/icinga/icingadb/pkg/config" - "github.com/pkg/errors" -) - -// Config defines database client configuration. -type Config struct { - Type string `yaml:"type" default:"mysql"` - Host string `yaml:"host"` - Port int `yaml:"port"` - Database string `yaml:"database"` - User string `yaml:"user"` - Password string `yaml:"password"` - TlsOptions config.TLS `yaml:",inline"` - Options Options `yaml:"options"` -} - -// Validate checks constraints in the supplied database configuration and returns an error if they are violated. -func (c *Config) Validate() error { - switch c.Type { - case "mysql", "pgsql": - default: - return unknownDbType(c.Type) - } - - if c.Host == "" { - return errors.New("database host missing") - } - - if c.User == "" { - return errors.New("database user missing") - } - - if c.Database == "" { - return errors.New("database name missing") - } - - return c.Options.Validate() -} - -func unknownDbType(t string) error { - return errors.Errorf(`unknown database type %q, must be one of: "mysql", "pgsql"`, t) -} diff --git a/pkg/database/contracts.go b/pkg/database/contracts.go deleted file mode 100644 index 7ab8cd192..000000000 --- a/pkg/database/contracts.go +++ /dev/null @@ -1,49 +0,0 @@ -package database - -// ID is a unique identifier of an entity. -type ID interface { - // String returns the string representation form of the ID. - // The String method is used to use the ID in functions - // where it needs to be compared or hashed. - String() string -} - -// IDer is implemented by every entity that uniquely identifies itself. -type IDer interface { - ID() ID // ID returns the ID. - SetID(ID) // SetID sets the ID. -} - -// Fingerprinter is implemented by every entity that uniquely identifies itself. -type Fingerprinter interface { - // Fingerprint returns the value that uniquely identifies the entity. - Fingerprint() Fingerprinter -} - -// Entity is implemented by each type that works with the database package. -type Entity interface { - Fingerprinter - IDer -} - -// EntityFactoryFunc knows how to create an Entity. -type EntityFactoryFunc func() Entity - -// Upserter implements the Upsert method, -// which returns a part of the object for ON DUPLICATE KEY UPDATE. -type Upserter interface { - Upsert() any // Upsert partitions the object. -} - -// TableNamer implements the TableName method, -// which returns the table of the object. -type TableNamer interface { - TableName() string // TableName tells the table. -} - -// Scoper implements the Scope method, -// which returns a struct specifying the WHERE conditions that -// entities must satisfy in order to be SELECTed. -type Scoper interface { - Scope() any -} diff --git a/pkg/database/db.go b/pkg/database/db.go deleted file mode 100644 index a41ab782d..000000000 --- a/pkg/database/db.go +++ /dev/null @@ -1,745 +0,0 @@ -package database - -import ( - "context" - "fmt" - "github.com/go-sql-driver/mysql" - "github.com/icinga/icingadb/pkg/backoff" - "github.com/icinga/icingadb/pkg/com" - "github.com/icinga/icingadb/pkg/driver" - "github.com/icinga/icingadb/pkg/logging" - "github.com/icinga/icingadb/pkg/periodic" - "github.com/icinga/icingadb/pkg/retry" - "github.com/icinga/icingadb/pkg/strcase" - "github.com/icinga/icingadb/pkg/utils" - "github.com/jmoiron/sqlx" - "github.com/jmoiron/sqlx/reflectx" - "github.com/pkg/errors" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/semaphore" - "net/url" - "reflect" - "strconv" - "strings" - "sync" - "time" -) - -// DB is a wrapper around sqlx.DB with bulk execution, -// statement building, streaming and logging capabilities. -type DB struct { - *sqlx.DB - - Options *Options - - logger *logging.Logger - tableSemaphores map[string]*semaphore.Weighted - tableSemaphoresMu sync.Mutex -} - -// Options define user configurable database options. -type Options struct { - // Maximum number of open connections to the database. - MaxConnections int `yaml:"max_connections" default:"16"` - - // Maximum number of connections per table, - // regardless of what the connection is actually doing, - // e.g. INSERT, UPDATE, DELETE. - MaxConnectionsPerTable int `yaml:"max_connections_per_table" default:"8"` - - // MaxPlaceholdersPerStatement defines the maximum number of placeholders in an - // INSERT, UPDATE or DELETE statement. Theoretically, MySQL can handle up to 2^16-1 placeholders, - // but this increases the execution time of queries and thus reduces the number of queries - // that can be executed in parallel in a given time. - // The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism. - MaxPlaceholdersPerStatement int `yaml:"max_placeholders_per_statement" default:"8192"` - - // MaxRowsPerTransaction defines the maximum number of rows per transaction. - // The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism. - MaxRowsPerTransaction int `yaml:"max_rows_per_transaction" default:"8192"` -} - -// Validate checks constraints in the supplied database options and returns an error if they are violated. -func (o *Options) Validate() error { - if o.MaxConnections == 0 { - return errors.New("max_connections cannot be 0. Configure a value greater than zero, or use -1 for no connection limit") - } - if o.MaxConnectionsPerTable < 1 { - return errors.New("max_connections_per_table must be at least 1") - } - if o.MaxPlaceholdersPerStatement < 1 { - return errors.New("max_placeholders_per_statement must be at least 1") - } - if o.MaxRowsPerTransaction < 1 { - return errors.New("max_rows_per_transaction must be at least 1") - } - - return nil -} - -// NewDb returns a new DB wrapper for a pre-existing sqlx.DB. -func NewDb(db *sqlx.DB, logger *logging.Logger, options *Options) *DB { - return &DB{ - DB: db, - logger: logger, - Options: options, - tableSemaphores: make(map[string]*semaphore.Weighted), - } -} - -// NewDbFromConfig returns a new DB from Config. -func NewDbFromConfig(c *Config, logger *logging.Logger) (*DB, error) { - var dsn string - switch c.Type { - case "mysql": - config := mysql.NewConfig() - - config.User = c.User - config.Passwd = c.Password - - if utils.IsUnixAddr(c.Host) { - config.Net = "unix" - config.Addr = c.Host - } else { - config.Net = "tcp" - port := c.Port - if port == 0 { - port = 3306 - } - config.Addr = utils.JoinHostPort(c.Host, port) - } - - config.DBName = c.Database - config.Timeout = time.Minute - config.Params = map[string]string{"sql_mode": "ANSI_QUOTES"} - - tlsConfig, err := c.TlsOptions.MakeConfig(c.Host) - if err != nil { - return nil, err - } - - if tlsConfig != nil { - config.TLSConfig = "icingadb" - if err := mysql.RegisterTLSConfig(config.TLSConfig, tlsConfig); err != nil { - return nil, errors.Wrap(err, "can't register TLS config") - } - } - - dsn = config.FormatDSN() - case "pgsql": - uri := &url.URL{ - Scheme: "postgres", - User: url.UserPassword(c.User, c.Password), - Path: "/" + url.PathEscape(c.Database), - } - - query := url.Values{ - "connect_timeout": {"60"}, - "binary_parameters": {"yes"}, - - // Host and port can alternatively be specified in the query string. lib/pq can't parse the connection URI - // if a Unix domain socket path is specified in the host part of the URI, therefore always use the query - // string. See also https://github.com/lib/pq/issues/796 - "host": {c.Host}, - } - if c.Port != 0 { - query["port"] = []string{strconv.FormatInt(int64(c.Port), 10)} - } - - if _, err := c.TlsOptions.MakeConfig(c.Host); err != nil { - return nil, err - } - - if c.TlsOptions.Enable { - if c.TlsOptions.Insecure { - query["sslmode"] = []string{"require"} - } else { - query["sslmode"] = []string{"verify-full"} - } - - if c.TlsOptions.Cert != "" { - query["sslcert"] = []string{c.TlsOptions.Cert} - } - - if c.TlsOptions.Key != "" { - query["sslkey"] = []string{c.TlsOptions.Key} - } - - if c.TlsOptions.Ca != "" { - query["sslrootcert"] = []string{c.TlsOptions.Ca} - } - } else { - query["sslmode"] = []string{"disable"} - } - - uri.RawQuery = query.Encode() - dsn = uri.String() - default: - return nil, unknownDbType(c.Type) - } - - db, err := sqlx.Open("icingadb-"+c.Type, dsn) - if err != nil { - return nil, errors.Wrap(err, "can't open database") - } - - db.SetMaxIdleConns(c.Options.MaxConnections / 3) - db.SetMaxOpenConns(c.Options.MaxConnections) - - db.Mapper = reflectx.NewMapperFunc("db", strcase.Snake) - - return NewDb(db, logger, &c.Options), nil -} - -// BuildColumns returns all columns of the given struct. -func (db *DB) BuildColumns(subject interface{}) []string { - fields := db.Mapper.TypeMap(reflect.TypeOf(subject)).Names - columns := make([]string, 0, len(fields)) - for _, f := range fields { - if f.Field.Tag == "" { - continue - } - columns = append(columns, f.Name) - } - - return columns -} - -// BuildDeleteStmt returns a DELETE statement for the given struct. -func (db *DB) BuildDeleteStmt(from interface{}) string { - return fmt.Sprintf( - `DELETE FROM "%s" WHERE id IN (?)`, - TableName(from), - ) -} - -// BuildInsertStmt returns an INSERT INTO statement for the given struct. -func (db *DB) BuildInsertStmt(into interface{}) (string, int) { - columns := db.BuildColumns(into) - - return fmt.Sprintf( - `INSERT INTO "%s" ("%s") VALUES (%s)`, - TableName(into), - strings.Join(columns, `", "`), - fmt.Sprintf(":%s", strings.Join(columns, ", :")), - ), len(columns) -} - -// BuildInsertIgnoreStmt returns an INSERT statement for the specified struct for -// which the database ignores rows that have already been inserted. -func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) { - table := TableName(into) - columns := db.BuildColumns(into) - var clause string - - switch db.DriverName() { - case driver.MySQL: - // MySQL treats UPDATE id = id as a no-op. - clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0]) - case driver.PostgreSQL: - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO NOTHING", table) - } - - return fmt.Sprintf( - `INSERT INTO "%s" ("%s") VALUES (%s) %s`, - table, - strings.Join(columns, `", "`), - fmt.Sprintf(":%s", strings.Join(columns, ", :")), - clause, - ), len(columns) -} - -// BuildSelectStmt returns a SELECT query that creates the FROM part from the given table struct -// and the column list from the specified columns struct. -func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string { - q := fmt.Sprintf( - `SELECT "%s" FROM "%s"`, - strings.Join(db.BuildColumns(columns), `", "`), - TableName(table), - ) - - if scoper, ok := table.(Scoper); ok { - where, _ := db.BuildWhere(scoper.Scope()) - q += ` WHERE ` + where - } - - return q -} - -// BuildUpdateStmt returns an UPDATE statement for the given struct. -func (db *DB) BuildUpdateStmt(update interface{}) (string, int) { - columns := db.BuildColumns(update) - set := make([]string, 0, len(columns)) - - for _, col := range columns { - set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col)) - } - - return fmt.Sprintf( - `UPDATE "%s" SET %s WHERE id = :id`, - TableName(update), - strings.Join(set, ", "), - ), len(columns) + 1 // +1 because of WHERE id = :id -} - -// BuildUpsertStmt returns an upsert statement for the given struct. -func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders int) { - insertColumns := db.BuildColumns(subject) - table := TableName(subject) - var updateColumns []string - - if upserter, ok := subject.(Upserter); ok { - updateColumns = db.BuildColumns(upserter.Upsert()) - } else { - updateColumns = insertColumns - } - - var clause, setFormat string - switch db.DriverName() { - case driver.MySQL: - clause = "ON DUPLICATE KEY UPDATE" - setFormat = `"%[1]s" = VALUES("%[1]s")` - case driver.PostgreSQL: - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO UPDATE SET", table) - setFormat = `"%[1]s" = EXCLUDED."%[1]s"` - } - - set := make([]string, 0, len(updateColumns)) - - for _, col := range updateColumns { - set = append(set, fmt.Sprintf(setFormat, col)) - } - - return fmt.Sprintf( - `INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, - table, - strings.Join(insertColumns, `", "`), - fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")), - clause, - strings.Join(set, ","), - ), len(insertColumns) -} - -// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct -// combined with the AND operator. -func (db *DB) BuildWhere(subject interface{}) (string, int) { - columns := db.BuildColumns(subject) - where := make([]string, 0, len(columns)) - for _, col := range columns { - where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col)) - } - - return strings.Join(where, ` AND `), len(columns) -} - -// OnSuccess is a callback for successful (bulk) DML operations. -type OnSuccess[T any] func(ctx context.Context, affectedRows []T) (err error) - -func OnSuccessIncrement[T any](counter *com.Counter) OnSuccess[T] { - return func(_ context.Context, rows []T) error { - counter.Add(uint64(len(rows))) - return nil - } -} - -func OnSuccessSendTo[T any](ch chan<- T) OnSuccess[T] { - return func(ctx context.Context, rows []T) error { - for _, row := range rows { - select { - case ch <- row: - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - } -} - -// BulkExec bulk executes queries with a single slice placeholder in the form of `IN (?)`. -// Takes in up to the number of arguments specified in count from the arg stream, -// derives and expands a query and executes it with this set of arguments until the arg stream has been processed. -// The derived queries are executed in a separate goroutine with a weighting of 1 -// and can be executed concurrently to the extent allowed by the semaphore passed in sem. -// Arguments for which the query ran successfully will be passed to onSuccess. -func (db *DB) BulkExec( - ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any], -) error { - var counter com.Counter - defer db.log(ctx, query, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - // Use context from group. - bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any]) - - g.Go(func() error { - g, ctx := errgroup.WithContext(ctx) - - for b := range bulk { - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - g.Go(func(b []interface{}) func() error { - return func() error { - defer sem.Release(1) - - return retry.WithBackoff( - ctx, - func(context.Context) error { - stmt, args, err := sqlx.In(query, b) - if err != nil { - return errors.Wrapf(err, "can't build placeholders for %q", query) - } - - stmt = db.Rebind(stmt) - _, err = db.ExecContext(ctx, stmt, args...) - if err != nil { - return CantPerformQuery(err, query) - } - - counter.Add(uint64(len(b))) - - for _, onSuccess := range onSuccess { - if err := onSuccess(ctx, b); err != nil { - return err - } - } - - return nil - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - retry.Settings{}, - ) - } - }(b)) - } - - return g.Wait() - }) - - return g.Wait() -} - -// NamedBulkExec bulk executes queries with named placeholders in a VALUES clause most likely -// in the format INSERT ... VALUES. Takes in up to the number of entities specified in count -// from the arg stream, derives and executes a new query with the VALUES clause expanded to -// this set of arguments, until the arg stream has been processed. -// The queries are executed in a separate goroutine with a weighting of 1 -// and can be executed concurrently to the extent allowed by the semaphore passed in sem. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) NamedBulkExec( - ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity, - splitPolicyFactory com.BulkChunkSplitPolicyFactory[Entity], onSuccess ...OnSuccess[Entity], -) error { - var counter com.Counter - defer db.log(ctx, query, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - bulk := com.Bulk(ctx, arg, count, splitPolicyFactory) - - g.Go(func() error { - for { - select { - case b, ok := <-bulk: - if !ok { - return nil - } - - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - g.Go(func(b []Entity) func() error { - return func() error { - defer sem.Release(1) - - return retry.WithBackoff( - ctx, - func(ctx context.Context) error { - _, err := db.NamedExecContext(ctx, query, b) - if err != nil { - return CantPerformQuery(err, query) - } - - counter.Add(uint64(len(b))) - - for _, onSuccess := range onSuccess { - if err := onSuccess(ctx, b); err != nil { - return err - } - } - - return nil - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - retry.Settings{}, - ) - } - }(b)) - case <-ctx.Done(): - return ctx.Err() - } - } - }) - - return g.Wait() -} - -// NamedBulkExecTx bulk executes queries with named placeholders in separate transactions. -// Takes in up to the number of entities specified in count from the arg stream and -// executes a new transaction that runs a new query for each entity in this set of arguments, -// until the arg stream has been processed. -// The transactions are executed in a separate goroutine with a weighting of 1 -// and can be executed concurrently to the extent allowed by the semaphore passed in sem. -func (db *DB) NamedBulkExecTx( - ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity, -) error { - var counter com.Counter - defer db.log(ctx, query, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - bulk := com.Bulk(ctx, arg, count, com.NeverSplit[Entity]) - - g.Go(func() error { - for { - select { - case b, ok := <-bulk: - if !ok { - return nil - } - - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - g.Go(func(b []Entity) func() error { - return func() error { - defer sem.Release(1) - - return retry.WithBackoff( - ctx, - func(ctx context.Context) error { - tx, err := db.BeginTxx(ctx, nil) - if err != nil { - return errors.Wrap(err, "can't start transaction") - } - - stmt, err := tx.PrepareNamedContext(ctx, query) - if err != nil { - return errors.Wrap(err, "can't prepare named statement with context in transaction") - } - - for _, arg := range b { - if _, err := stmt.ExecContext(ctx, arg); err != nil { - return errors.Wrap(err, "can't execute statement in transaction") - } - } - - if err := tx.Commit(); err != nil { - return errors.Wrap(err, "can't commit transaction") - } - - counter.Add(uint64(len(b))) - - return nil - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - retry.Settings{}, - ) - } - }(b)) - case <-ctx.Done(): - return ctx.Err() - } - } - }) - - return g.Wait() -} - -// BatchSizeByPlaceholders returns how often the specified number of placeholders fits -// into Options.MaxPlaceholdersPerStatement, but at least 1. -func (db *DB) BatchSizeByPlaceholders(n int) int { - s := db.Options.MaxPlaceholdersPerStatement / n - if s > 0 { - return s - } - - return 1 -} - -// YieldAll executes the query with the supplied scope, -// scans each resulting row into an entity returned by the factory function, -// and streams them into a returned channel. -func (db *DB) YieldAll(ctx context.Context, factoryFunc EntityFactoryFunc, query string, scope interface{}) (<-chan Entity, <-chan error) { - entities := make(chan Entity, 1) - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - var counter com.Counter - defer db.log(ctx, query, &counter).Stop() - defer close(entities) - - rows, err := db.NamedQueryContext(ctx, query, scope) - if err != nil { - return CantPerformQuery(err, query) - } - defer rows.Close() - - for rows.Next() { - e := factoryFunc() - - if err := rows.StructScan(e); err != nil { - return errors.Wrapf(err, "can't store query result into a %T: %s", e, query) - } - - select { - case entities <- e: - counter.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - - return entities, com.WaitAsync(ctx, g) -} - -// CreateStreamed bulk creates the specified entities via NamedBulkExec. -// The insert statement is created using BuildInsertStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) CreateStreamed( - ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity], -) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, placeholders := db.BuildInsertStmt(first) - - return db.NamedBulkExec( - ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem, - forward, com.NeverSplit[Entity], onSuccess..., - ) -} - -// CreateIgnoreStreamed bulk creates the specified entities via NamedBulkExec. -// The insert statement is created using BuildInsertIgnoreStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) CreateIgnoreStreamed( - ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity], -) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, placeholders := db.BuildInsertIgnoreStmt(first) - - return db.NamedBulkExec( - ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem, - forward, SplitOnDupId[Entity], onSuccess..., - ) -} - -// UpsertStreamed bulk upserts the specified entities via NamedBulkExec. -// The upsert statement is created using BuildUpsertStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// Entities for which the query ran successfully will be passed to onSuccess. -func (db *DB) UpsertStreamed( - ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity], -) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, placeholders := db.BuildUpsertStmt(first) - - return db.NamedBulkExec( - ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem, - forward, SplitOnDupId[Entity], onSuccess..., - ) -} - -// UpdateStreamed bulk updates the specified entities via NamedBulkExecTx. -// The update statement is created using BuildUpdateStmt with the first entity from the entities stream. -// Bulk size is controlled via Options.MaxRowsPerTransaction and -// concurrency is controlled via Options.MaxConnectionsPerTable. -func (db *DB) UpdateStreamed(ctx context.Context, entities <-chan Entity) error { - first, forward, err := com.CopyFirst(ctx, entities) - if err != nil { - return errors.Wrap(err, "can't copy first entity") - } - sem := db.GetSemaphoreForTable(TableName(first)) - stmt, _ := db.BuildUpdateStmt(first) - - return db.NamedBulkExecTx(ctx, stmt, db.Options.MaxRowsPerTransaction, sem, forward) -} - -// DeleteStreamed bulk deletes the specified ids via BulkExec. -// The delete statement is created using BuildDeleteStmt with the passed entityType. -// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and -// concurrency is controlled via Options.MaxConnectionsPerTable. -// IDs for which the query ran successfully will be passed to onSuccess. -func (db *DB) DeleteStreamed( - ctx context.Context, entityType Entity, ids <-chan interface{}, onSuccess ...OnSuccess[any], -) error { - sem := db.GetSemaphoreForTable(TableName(entityType)) - return db.BulkExec( - ctx, db.BuildDeleteStmt(entityType), db.Options.MaxPlaceholdersPerStatement, sem, ids, onSuccess..., - ) -} - -// Delete creates a channel from the specified ids and -// bulk deletes them by passing the channel along with the entityType to DeleteStreamed. -// IDs for which the query ran successfully will be passed to onSuccess. -func (db *DB) Delete( - ctx context.Context, entityType Entity, ids []interface{}, onSuccess ...OnSuccess[any], -) error { - idsCh := make(chan interface{}, len(ids)) - for _, id := range ids { - idsCh <- id - } - close(idsCh) - - return db.DeleteStreamed(ctx, entityType, idsCh, onSuccess...) -} - -func (db *DB) GetSemaphoreForTable(table string) *semaphore.Weighted { - db.tableSemaphoresMu.Lock() - defer db.tableSemaphoresMu.Unlock() - - if sem, ok := db.tableSemaphores[table]; ok { - return sem - } else { - sem = semaphore.NewWeighted(int64(db.Options.MaxConnectionsPerTable)) - db.tableSemaphores[table] = sem - return sem - } -} - -func (db *DB) log(ctx context.Context, query string, counter *com.Counter) periodic.Stopper { - return periodic.Start(ctx, db.logger.Interval(), func(tick periodic.Tick) { - if count := counter.Reset(); count > 0 { - db.logger.Debugf("Executed %q with %d rows", query, count) - } - }, periodic.OnStop(func(tick periodic.Tick) { - db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed) - })) -} diff --git a/pkg/database/utils.go b/pkg/database/utils.go deleted file mode 100644 index e9eb681db..000000000 --- a/pkg/database/utils.go +++ /dev/null @@ -1,45 +0,0 @@ -package database - -import ( - "github.com/icinga/icingadb/pkg/com" - "github.com/icinga/icingadb/pkg/strcase" - "github.com/icinga/icingadb/pkg/types" - "github.com/pkg/errors" -) - -// CantPerformQuery wraps the given error with the specified query that cannot be executed. -func CantPerformQuery(err error, q string) error { - return errors.Wrapf(err, "can't perform %q", q) -} - -// TableName returns the table of t. -func TableName(t interface{}) string { - if tn, ok := t.(TableNamer); ok { - return tn.TableName() - } else { - return strcase.Snake(types.Name(t)) - } -} - -// SplitOnDupId returns a state machine which tracks the inputs' IDs. -// Once an already seen input arrives, it demands splitting. -func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] { - seenIds := map[string]struct{}{} - - return func(ider T) bool { - id := ider.ID().String() - - _, ok := seenIds[id] - if ok { - seenIds = map[string]struct{}{id: {}} - } else { - seenIds[id] = struct{}{} - } - - return ok - } -} - -var ( - _ com.BulkChunkSplitPolicyFactory[Entity] = SplitOnDupId[Entity] -) diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go deleted file mode 100644 index 3ddf04694..000000000 --- a/pkg/driver/driver.go +++ /dev/null @@ -1,141 +0,0 @@ -package driver - -import ( - "context" - "database/sql" - "database/sql/driver" - "github.com/go-sql-driver/mysql" - "github.com/icinga/icingadb/pkg/backoff" - "github.com/icinga/icingadb/pkg/logging" - "github.com/icinga/icingadb/pkg/retry" - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - "go.uber.org/zap" - "time" -) - -const MySQL = "icingadb-mysql" -const PostgreSQL = "icingadb-pgsql" - -var timeout = time.Minute * 5 - -// RetryConnector wraps driver.Connector with retry logic. -type RetryConnector struct { - driver.Connector - driver Driver - onError retry.OnErrorFunc - onSuccess retry.OnSuccessFunc -} - -// Connect implements part of the driver.Connector interface. -func (c RetryConnector) Connect(ctx context.Context) (driver.Conn, error) { - var conn driver.Conn - err := errors.Wrap(retry.WithBackoff( - ctx, - func(ctx context.Context) (err error) { - conn, err = c.Connector.Connect(ctx) - return - }, - shouldRetry, - backoff.NewExponentialWithJitter(time.Millisecond*128, time.Minute*1), - retry.Settings{ - Timeout: timeout, - OnError: func(elapsed time.Duration, attempt uint64, err, lastErr error) { - if c.onError != nil { - c.onError(elapsed, attempt, err, lastErr) - } - - if lastErr == nil || err.Error() != lastErr.Error() { - c.driver.Logger.Warnw("Can't connect to database. Retrying", zap.Error(err)) - } - }, - OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) { - if c.onSuccess != nil { - c.onSuccess(elapsed, attempt, lastErr) - } - - if attempt > 0 { - c.driver.Logger.Infow("Reconnected to database", - zap.Duration("after", elapsed), zap.Uint64("attempts", attempt+1)) - } - }, - }, - ), "can't connect to database") - return conn, err -} - -// Driver implements part of the driver.Connector interface. -func (c RetryConnector) Driver() driver.Driver { - return c.driver -} - -type Option func(*RetryConnector) - -func WithOnError(f retry.OnErrorFunc) Option { - return func(connector *RetryConnector) { - connector.onError = f - } -} - -func WithOnSuccess(f retry.OnSuccessFunc) Option { - return func(connector *RetryConnector) { - connector.onSuccess = f - } -} - -// Driver wraps a driver.Driver that also must implement driver.DriverContext with logging capabilities and -// provides our RetryConnector. -type Driver struct { - ctxDriver - options []Option - Logger *logging.Logger -} - -// OpenConnector implements the DriverContext interface. -func (d Driver) OpenConnector(name string) (driver.Connector, error) { - c, err := d.ctxDriver.OpenConnector(name) - if err != nil { - return nil, err - } - - rc := &RetryConnector{ - driver: d, - Connector: c, - } - - for _, option := range d.options { - option(rc) - } - - return rc, nil -} - -// Register makes our database Driver available under the name "icingadb-*sql". -func Register(logger *logging.Logger, options ...Option) { - sql.Register(MySQL, &Driver{ctxDriver: &mysql.MySQLDriver{}, Logger: logger}) - sql.Register(PostgreSQL, &Driver{ctxDriver: PgSQLDriver{}, Logger: logger}) - _ = mysql.SetLogger(mysqlLogger(func(v ...interface{}) { logger.Debug(v...) })) - sqlx.BindDriver(PostgreSQL, sqlx.DOLLAR) -} - -// ctxDriver helps ensure that we only support drivers that implement driver.Driver and driver.DriverContext. -type ctxDriver interface { - driver.Driver - driver.DriverContext -} - -// mysqlLogger is an adapter that allows ordinary functions to be used as a logger for mysql.SetLogger. -type mysqlLogger func(v ...interface{}) - -// Print implements the mysql.Logger interface. -func (log mysqlLogger) Print(v ...interface{}) { - log(v) -} - -func shouldRetry(err error) bool { - if errors.Is(err, driver.ErrBadConn) { - return true - } - - return retry.Retryable(err) -} diff --git a/pkg/driver/pgsql.go b/pkg/driver/pgsql.go deleted file mode 100644 index 3c88fe05a..000000000 --- a/pkg/driver/pgsql.go +++ /dev/null @@ -1,22 +0,0 @@ -package driver - -import ( - "database/sql/driver" - "github.com/lib/pq" -) - -// PgSQLDriver extends pq.Driver with driver.DriverContext compliance. -type PgSQLDriver struct { - pq.Driver -} - -// Assert interface compliance. -var ( - _ driver.Driver = PgSQLDriver{} - _ driver.DriverContext = PgSQLDriver{} -) - -// OpenConnector implements the driver.DriverContext interface. -func (PgSQLDriver) OpenConnector(name string) (driver.Connector, error) { - return pq.NewConnector(name) -} diff --git a/pkg/flatten/flatten.go b/pkg/flatten/flatten.go deleted file mode 100644 index 94a6e7ebc..000000000 --- a/pkg/flatten/flatten.go +++ /dev/null @@ -1,47 +0,0 @@ -package flatten - -import ( - "database/sql" - "fmt" - "github.com/icinga/icingadb/pkg/types" - "strconv" -) - -// Flatten creates flat, one-dimensional maps from arbitrarily nested values, e.g. JSON. -func Flatten(value interface{}, prefix string) map[string]types.String { - var flatten func(string, interface{}) - flattened := make(map[string]types.String) - - flatten = func(key string, value interface{}) { - switch value := value.(type) { - case map[string]interface{}: - if len(value) == 0 { - flattened[key] = types.String{} - break - } - - for k, v := range value { - flatten(key+"."+k, v) - } - case []interface{}: - if len(value) == 0 { - flattened[key] = types.String{} - break - } - - for i, v := range value { - flatten(key+"["+strconv.Itoa(i)+"]", v) - } - default: - val := "null" - if value != nil { - val = fmt.Sprintf("%v", value) - } - flattened[key] = types.String{NullString: sql.NullString{String: val, Valid: true}} - } - } - - flatten(prefix, value) - - return flattened -} diff --git a/pkg/logging/config.go b/pkg/logging/config.go deleted file mode 100644 index 63b865ce6..000000000 --- a/pkg/logging/config.go +++ /dev/null @@ -1,43 +0,0 @@ -package logging - -import ( - "github.com/pkg/errors" - "go.uber.org/zap/zapcore" - "os" - "time" -) - -// Config defines Logger configuration. -type Config struct { - // zapcore.Level at 0 is for info level. - Level zapcore.Level `yaml:"level" default:"0"` - Output string `yaml:"output"` - // Interval for periodic logging. - Interval time.Duration `yaml:"interval" default:"20s"` - - Options `yaml:"options"` -} - -// Validate checks constraints in the supplied Config configuration and returns an error if they are violated. -// Also configures the log output if it is not configured: -// systemd-journald is used when Icinga DB is running under systemd, otherwise stderr. -func (l *Config) Validate() error { - if l.Interval <= 0 { - return errors.New("periodic logging interval must be positive") - } - - if l.Output == "" { - if _, ok := os.LookupEnv("NOTIFY_SOCKET"); ok { - // When started by systemd, NOTIFY_SOCKET is set by systemd for Type=notify supervised services, - // which is the default setting for the Icinga DB service. - // This assumes that Icinga DB is running under systemd, so set output to systemd-journald. - l.Output = JOURNAL - } else { - // Otherwise set it to console, i.e. write log messages to stderr. - l.Output = CONSOLE - } - } - - // To be on the safe side, always call AssertOutput. - return AssertOutput(l.Output) -} diff --git a/pkg/logging/journald_core.go b/pkg/logging/journald_core.go deleted file mode 100644 index ac67cc60c..000000000 --- a/pkg/logging/journald_core.go +++ /dev/null @@ -1,84 +0,0 @@ -package logging - -import ( - "github.com/icinga/icingadb/pkg/strcase" - "github.com/pkg/errors" - "github.com/ssgreg/journald" - "go.uber.org/zap/zapcore" - "strings" -) - -// priorities maps zapcore.Level to journal.Priority. -var priorities = map[zapcore.Level]journald.Priority{ - zapcore.DebugLevel: journald.PriorityDebug, - zapcore.InfoLevel: journald.PriorityInfo, - zapcore.WarnLevel: journald.PriorityWarning, - zapcore.ErrorLevel: journald.PriorityErr, - zapcore.FatalLevel: journald.PriorityCrit, - zapcore.PanicLevel: journald.PriorityCrit, - zapcore.DPanicLevel: journald.PriorityCrit, -} - -// NewJournaldCore returns a zapcore.Core that sends log entries to systemd-journald and -// uses the given identifier as a prefix for structured logging context that is sent as journal fields. -func NewJournaldCore(identifier string, enab zapcore.LevelEnabler) zapcore.Core { - return &journaldCore{ - LevelEnabler: enab, - identifier: identifier, - identifierU: strings.ToUpper(identifier), - } -} - -type journaldCore struct { - zapcore.LevelEnabler - context []zapcore.Field - identifier string - identifierU string -} - -func (c *journaldCore) Check(ent zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { - if c.Enabled(ent.Level) { - return ce.AddCore(ent, c) - } - - return ce -} - -func (c *journaldCore) Sync() error { - return nil -} - -func (c *journaldCore) With(fields []zapcore.Field) zapcore.Core { - cc := *c - cc.context = append(cc.context[:len(cc.context):len(cc.context)], fields...) - - return &cc -} - -func (c *journaldCore) Write(ent zapcore.Entry, fields []zapcore.Field) error { - pri, ok := priorities[ent.Level] - if !ok { - return errors.Errorf("unknown log level %q", ent.Level) - } - - enc := zapcore.NewMapObjectEncoder() - c.addFields(enc, fields) - c.addFields(enc, c.context) - enc.Fields["SYSLOG_IDENTIFIER"] = c.identifier - - message := ent.Message - if ent.LoggerName != c.identifier { - message = ent.LoggerName + ": " + message - } - - return journald.Send(message, pri, enc.Fields) -} - -func (c *journaldCore) addFields(enc zapcore.ObjectEncoder, fields []zapcore.Field) { - for _, field := range fields { - field.Key = c.identifierU + - "_" + - strcase.ScreamingSnake(field.Key) - field.AddTo(enc) - } -} diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go deleted file mode 100644 index 490445e17..000000000 --- a/pkg/logging/logger.go +++ /dev/null @@ -1,26 +0,0 @@ -package logging - -import ( - "go.uber.org/zap" - "time" -) - -// Logger wraps zap.SugaredLogger and -// allows to get the interval for periodic logging. -type Logger struct { - *zap.SugaredLogger - interval time.Duration -} - -// NewLogger returns a new Logger. -func NewLogger(base *zap.SugaredLogger, interval time.Duration) *Logger { - return &Logger{ - SugaredLogger: base, - interval: interval, - } -} - -// Interval returns the interval for periodic logging. -func (l *Logger) Interval() time.Duration { - return l.interval -} diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go deleted file mode 100644 index e3106956a..000000000 --- a/pkg/logging/logging.go +++ /dev/null @@ -1,131 +0,0 @@ -package logging - -import ( - "fmt" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "os" - "sync" - "time" -) - -const ( - CONSOLE = "console" - JOURNAL = "systemd-journald" -) - -// defaultEncConfig defines the default zapcore.EncoderConfig for the logging package. -var defaultEncConfig = zapcore.EncoderConfig{ - TimeKey: "ts", - LevelKey: "level", - NameKey: "logger", - CallerKey: "caller", - MessageKey: "msg", - StacktraceKey: "stacktrace", - LineEnding: zapcore.DefaultLineEnding, - EncodeLevel: zapcore.CapitalLevelEncoder, - EncodeTime: zapcore.ISO8601TimeEncoder, - EncodeDuration: zapcore.StringDurationEncoder, - EncodeCaller: zapcore.ShortCallerEncoder, -} - -// Options define child loggers with their desired log level. -type Options map[string]zapcore.Level - -// Logging implements access to a default logger and named child loggers. -// Log levels can be configured per named child via Options which, if not configured, -// fall back on a default log level. -// Logs either to the console or to systemd-journald. -type Logging struct { - logger *Logger - output string - verbosity zap.AtomicLevel - interval time.Duration - - // coreFactory creates zapcore.Core based on the log level and the log output. - coreFactory func(zap.AtomicLevel) zapcore.Core - - mu sync.Mutex - loggers map[string]*Logger - - options Options -} - -// NewLogging takes the name and log level for the default logger, -// output where log messages are written to, -// options having log levels for named child loggers -// and returns a new Logging. -func NewLogging(name string, level zapcore.Level, output string, options Options, interval time.Duration) (*Logging, error) { - verbosity := zap.NewAtomicLevelAt(level) - - var coreFactory func(zap.AtomicLevel) zapcore.Core - switch output { - case CONSOLE: - enc := zapcore.NewConsoleEncoder(defaultEncConfig) - ws := zapcore.Lock(os.Stderr) - coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core { - return zapcore.NewCore(enc, ws, verbosity) - } - case JOURNAL: - coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core { - return NewJournaldCore(name, verbosity) - } - default: - return nil, invalidOutput(output) - } - - logger := NewLogger(zap.New(coreFactory(verbosity)).Named(name).Sugar(), interval) - - return &Logging{ - logger: logger, - output: output, - verbosity: verbosity, - interval: interval, - coreFactory: coreFactory, - loggers: make(map[string]*Logger), - options: options, - }, - nil -} - -// GetChildLogger returns a named child logger. -// Log levels for named child loggers are obtained from the logging options and, if not found, -// set to the default log level. -func (l *Logging) GetChildLogger(name string) *Logger { - l.mu.Lock() - defer l.mu.Unlock() - - if logger, ok := l.loggers[name]; ok { - return logger - } - - var verbosity zap.AtomicLevel - if level, found := l.options[name]; found { - verbosity = zap.NewAtomicLevelAt(level) - } else { - verbosity = l.verbosity - } - - logger := NewLogger(zap.New(l.coreFactory(verbosity)).Named(name).Sugar(), l.interval) - l.loggers[name] = logger - - return logger -} - -// GetLogger returns the default logger. -func (l *Logging) GetLogger() *Logger { - return l.logger -} - -// AssertOutput returns an error if output is not a valid logger output. -func AssertOutput(o string) error { - if o == CONSOLE || o == JOURNAL { - return nil - } - - return invalidOutput(o) -} - -func invalidOutput(o string) error { - return fmt.Errorf("%s is not a valid logger output. Must be either %q or %q", o, CONSOLE, JOURNAL) -} diff --git a/pkg/objectpacker/objectpacker.go b/pkg/objectpacker/objectpacker.go deleted file mode 100644 index 9ddfdc8b3..000000000 --- a/pkg/objectpacker/objectpacker.go +++ /dev/null @@ -1,214 +0,0 @@ -package objectpacker - -import ( - "bytes" - "encoding/binary" - "fmt" - "github.com/pkg/errors" - "io" - "io/ioutil" - "reflect" - "sort" -) - -// MustPackSlice calls PackAny using items and panics if there was an error. -func MustPackSlice(items ...interface{}) []byte { - var buf bytes.Buffer - - if err := PackAny(items, &buf); err != nil { - panic(err) - } - - return buf.Bytes() -} - -// PackAny packs any JSON-encodable value (ex. structs, also ignores interfaces like encoding.TextMarshaler) -// to a BSON-similar format suitable for consistent hashing. Spec: -// -// PackAny(nil) => 0x0 -// PackAny(false) => 0x1 -// PackAny(true) => 0x2 -// PackAny(float64(42)) => 0x3 ieee754_binary64_bigendian(42) -// PackAny("exämple") => 0x4 uint64_bigendian(len([]byte("exämple"))) []byte("exämple") -// PackAny([]uint8{0x42}) => 0x4 uint64_bigendian(len([]uint8{0x42})) []uint8{0x42} -// PackAny([1]uint8{0x42}) => 0x4 uint64_bigendian(len([1]uint8{0x42})) [1]uint8{0x42} -// PackAny([]T{x,y}) => 0x5 uint64_bigendian(len([]T{x,y})) PackAny(x) PackAny(y) -// PackAny(map[K]V{x:y}) => 0x6 uint64_bigendian(len(map[K]V{x:y})) len(map_key(x)) map_key(x) PackAny(y) -// PackAny((*T)(nil)) => 0x0 -// PackAny((*T)(0x42)) => PackAny(*(*T)(0x42)) -// PackAny(x) => panic() -// -// map_key([1]uint8{0x42}) => [1]uint8{0x42} -// map_key(x) => []byte(fmt.Sprint(x)) -func PackAny(in interface{}, out io.Writer) error { - return errors.Wrapf(packValue(reflect.ValueOf(in), out), "can't pack %#v", in) -} - -var tByte = reflect.TypeOf(byte(0)) -var tBytes = reflect.TypeOf([]uint8(nil)) - -// packValue does the actual job of packAny and just exists for recursion w/o unnecessary reflect.ValueOf calls. -func packValue(in reflect.Value, out io.Writer) error { - switch kind := in.Kind(); kind { - case reflect.Invalid: // nil - _, err := out.Write([]byte{0}) - return err - case reflect.Bool: - if in.Bool() { - _, err := out.Write([]byte{2}) - return err - } else { - _, err := out.Write([]byte{1}) - return err - } - case reflect.Float64: - if _, err := out.Write([]byte{3}); err != nil { - return err - } - - return binary.Write(out, binary.BigEndian, in.Float()) - case reflect.Array, reflect.Slice: - if typ := in.Type(); typ.Elem() == tByte { - if kind == reflect.Array { - if !in.CanAddr() { - vNewElem := reflect.New(typ).Elem() - vNewElem.Set(in) - in = vNewElem - } - - in = in.Slice(0, in.Len()) - } - - // Pack []byte as string, not array of numbers. - return packString(in.Convert(tBytes). // Support types.Binary - Interface().([]uint8), out) - } - - if _, err := out.Write([]byte{5}); err != nil { - return err - } - - l := in.Len() - if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil { - return err - } - - for i := 0; i < l; i++ { - if err := packValue(in.Index(i), out); err != nil { - return err - } - } - - // If there aren't any values to pack, ... - if l < 1 { - // ... create one and pack it - panics on disallowed type. - _ = packValue(reflect.Zero(in.Type().Elem()), ioutil.Discard) - } - - return nil - case reflect.Interface: - return packValue(in.Elem(), out) - case reflect.Map: - type kv struct { - key []byte - value reflect.Value - } - - if _, err := out.Write([]byte{6}); err != nil { - return err - } - - l := in.Len() - if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil { - return err - } - - sorted := make([]kv, 0, l) - - { - iter := in.MapRange() - for iter.Next() { - var packedKey []byte - if key := iter.Key(); key.Kind() == reflect.Array { - if typ := key.Type(); typ.Elem() == tByte { - if !key.CanAddr() { - vNewElem := reflect.New(typ).Elem() - vNewElem.Set(key) - key = vNewElem - } - - packedKey = key.Slice(0, key.Len()).Interface().([]byte) - } else { - // Not just stringify the key (below), but also pack it (here) - panics on disallowed type. - _ = packValue(iter.Key(), ioutil.Discard) - - packedKey = []byte(fmt.Sprint(key.Interface())) - } - } else { - // Not just stringify the key (below), but also pack it (here) - panics on disallowed type. - _ = packValue(iter.Key(), ioutil.Discard) - - packedKey = []byte(fmt.Sprint(key.Interface())) - } - - sorted = append(sorted, kv{packedKey, iter.Value()}) - } - } - - sort.Slice(sorted, func(i, j int) bool { return bytes.Compare(sorted[i].key, sorted[j].key) < 0 }) - - for _, kv := range sorted { - if err := binary.Write(out, binary.BigEndian, uint64(len(kv.key))); err != nil { - return err - } - - if _, err := out.Write(kv.key); err != nil { - return err - } - - if err := packValue(kv.value, out); err != nil { - return err - } - } - - // If there aren't any key-value pairs to pack, ... - if l < 1 { - typ := in.Type() - - // ... create one and pack it - panics on disallowed type. - _ = packValue(reflect.Zero(typ.Key()), ioutil.Discard) - _ = packValue(reflect.Zero(typ.Elem()), ioutil.Discard) - } - - return nil - case reflect.Ptr: - if in.IsNil() { - err := packValue(reflect.Value{}, out) - - // Create a fictive referenced value and pack it - panics on disallowed type. - _ = packValue(reflect.Zero(in.Type().Elem()), ioutil.Discard) - - return err - } else { - return packValue(in.Elem(), out) - } - case reflect.String: - return packString([]byte(in.String()), out) - default: - panic("bad type: " + in.Kind().String()) - } -} - -// packString deduplicates string packing of multiple locations in packValue. -func packString(in []byte, out io.Writer) error { - if _, err := out.Write([]byte{4}); err != nil { - return err - } - - if err := binary.Write(out, binary.BigEndian, uint64(len(in))); err != nil { - return err - } - - _, err := out.Write(in) - return err -} diff --git a/pkg/objectpacker/objectpacker_test.go b/pkg/objectpacker/objectpacker_test.go deleted file mode 100644 index e377d7736..000000000 --- a/pkg/objectpacker/objectpacker_test.go +++ /dev/null @@ -1,195 +0,0 @@ -package objectpacker - -import ( - "bytes" - "github.com/icinga/icingadb/pkg/types" - "github.com/pkg/errors" - "io" - "testing" -) - -// limitedWriter allows writing a specific amount of data. -type limitedWriter struct { - // limit specifies how many bytes to allow to write. - limit int -} - -var _ io.Writer = (*limitedWriter)(nil) - -// Write returns io.EOF once lw.limit is exceeded, nil otherwise. -func (lw *limitedWriter) Write(p []byte) (n int, err error) { - if len(p) <= lw.limit { - lw.limit -= len(p) - return len(p), nil - } - - n = lw.limit - err = io.EOF - - lw.limit = 0 - return -} - -func TestLimitedWriter_Write(t *testing.T) { - assertLimitedWriter_Write(t, 3, []byte{1, 2}, 2, nil, 1) - assertLimitedWriter_Write(t, 3, []byte{1, 2, 3}, 3, nil, 0) - assertLimitedWriter_Write(t, 3, []byte{1, 2, 3, 4}, 3, io.EOF, 0) - assertLimitedWriter_Write(t, 0, []byte{1}, 0, io.EOF, 0) - assertLimitedWriter_Write(t, 0, nil, 0, nil, 0) -} - -func assertLimitedWriter_Write(t *testing.T, limitBefore int, p []byte, n int, err error, limitAfter int) { - t.Helper() - - lw := limitedWriter{limitBefore} - actualN, actualErr := lw.Write(p) - - if !errors.Is(actualErr, err) { - t.Errorf("_, err := (&limitedWriter{%d}).Write(%#v); err != %#v", limitBefore, p, err) - } - - if actualN != n { - t.Errorf("n, _ := (&limitedWriter{%d}).Write(%#v); n != %d", limitBefore, p, n) - } - - if lw.limit != limitAfter { - t.Errorf("lw := limitedWriter{%d}; lw.Write(%#v); lw.limit != %d", limitBefore, p, limitAfter) - } -} - -func TestPackAny(t *testing.T) { - assertPackAny(t, nil, []byte{0}) - assertPackAny(t, false, []byte{1}) - assertPackAny(t, true, []byte{2}) - - assertPackAnyPanic(t, -42, 0) - assertPackAnyPanic(t, int8(-42), 0) - assertPackAnyPanic(t, int16(-42), 0) - assertPackAnyPanic(t, int32(-42), 0) - assertPackAnyPanic(t, int64(-42), 0) - - assertPackAnyPanic(t, uint(42), 0) - assertPackAnyPanic(t, uint8(42), 0) - assertPackAnyPanic(t, uint16(42), 0) - assertPackAnyPanic(t, uint32(42), 0) - assertPackAnyPanic(t, uint64(42), 0) - assertPackAnyPanic(t, uintptr(42), 0) - - assertPackAnyPanic(t, float32(-42.5), 0) - assertPackAny(t, -42.5, []byte{3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0}) - - assertPackAnyPanic(t, []struct{}(nil), 9) - assertPackAnyPanic(t, []struct{}{}, 9) - - assertPackAny(t, []interface{}{nil, true, -42.5}, []byte{ - 5, 0, 0, 0, 0, 0, 0, 0, 3, - 0, - 2, - 3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0, - }) - - assertPackAny(t, []string{"", "a"}, []byte{ - 5, 0, 0, 0, 0, 0, 0, 0, 2, - 4, 0, 0, 0, 0, 0, 0, 0, 0, - 4, 0, 0, 0, 0, 0, 0, 0, 1, 'a', - }) - - assertPackAnyPanic(t, []interface{}{0 + 0i}, 9) - - assertPackAnyPanic(t, map[struct{}]struct{}(nil), 9) - assertPackAnyPanic(t, map[struct{}]struct{}{}, 9) - - assertPackAny(t, map[interface{}]interface{}{true: "", "nil": -42.5}, []byte{ - 6, 0, 0, 0, 0, 0, 0, 0, 2, - 0, 0, 0, 0, 0, 0, 0, 3, 'n', 'i', 'l', - 3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 4, 't', 'r', 'u', 'e', - 4, 0, 0, 0, 0, 0, 0, 0, 0, - }) - - assertPackAny(t, map[string]float64{"": 42}, []byte{ - 6, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 0, - 3, 0x40, 0x45, 0, 0, 0, 0, 0, 0, - }) - - assertPackAny(t, map[[1]byte]bool{{42}: true}, []byte{ - 6, 0, 0, 0, 0, 0, 0, 0, 1, - 0, 0, 0, 0, 0, 0, 0, 1, 42, - 2, - }) - - assertPackAnyPanic(t, map[struct{}]struct{}{{}: {}}, 9) - - assertPackAny(t, (*string)(nil), []byte{0}) - assertPackAnyPanic(t, (*int)(nil), 0) - assertPackAny(t, new(float64), []byte{3, 0, 0, 0, 0, 0, 0, 0, 0}) - - assertPackAny(t, "", []byte{4, 0, 0, 0, 0, 0, 0, 0, 0}) - assertPackAny(t, "a", []byte{4, 0, 0, 0, 0, 0, 0, 0, 1, 'a'}) - assertPackAny(t, "ä", []byte{4, 0, 0, 0, 0, 0, 0, 0, 2, 0xc3, 0xa4}) - - { - var binary [256]byte - for i := range binary { - binary[i] = byte(i) - } - - assertPackAny(t, binary, append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...)) - assertPackAny(t, binary[:], append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...)) - assertPackAny(t, types.Binary(binary[:]), append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...)) - } - - { - type myByte byte - assertPackAnyPanic(t, []myByte(nil), 9) - } - - assertPackAnyPanic(t, complex64(0+0i), 0) - assertPackAnyPanic(t, 0+0i, 0) - assertPackAnyPanic(t, make(chan struct{}), 0) - assertPackAnyPanic(t, func() {}, 0) - assertPackAnyPanic(t, struct{}{}, 0) - assertPackAnyPanic(t, uintptr(0), 0) -} - -func assertPackAny(t *testing.T, in interface{}, out []byte) { - t.Helper() - - { - buf := &bytes.Buffer{} - if err := PackAny(in, buf); err == nil { - if !bytes.Equal(buf.Bytes(), out) { - t.Errorf("buf := &bytes.Buffer{}; packAny(%#v, buf); !bytes.Equal(buf.Bytes(), %#v)", in, out) - } - } else { - t.Errorf("packAny(%#v, &bytes.Buffer{}) != nil", in) - } - } - - for i := 0; i < len(out); i++ { - if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) { - t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i) - } - } -} - -func assertPackAnyPanic(t *testing.T, in interface{}, allowToWrite int) { - t.Helper() - - for i := 0; i < allowToWrite; i++ { - if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) { - t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i) - } - } - - defer func() { - t.Helper() - - if r := recover(); r == nil { - t.Errorf("packAny(%#v, &limitedWriter{%d}) didn't panic", in, allowToWrite) - } - }() - - _ = PackAny(in, &limitedWriter{allowToWrite}) -} diff --git a/pkg/periodic/periodic.go b/pkg/periodic/periodic.go deleted file mode 100644 index 6ef5ceb87..000000000 --- a/pkg/periodic/periodic.go +++ /dev/null @@ -1,123 +0,0 @@ -package periodic - -import ( - "context" - "sync" - "time" -) - -// Option configures Start. -type Option interface { - apply(*periodic) -} - -// Stopper implements the Stop method, -// which stops a periodic task from Start(). -type Stopper interface { - Stop() // Stops a periodic task. -} - -// Tick is the value for periodic task callbacks that -// contains the time of the tick and -// the time elapsed since the start of the periodic task. -type Tick struct { - Elapsed time.Duration - Time time.Time -} - -// Immediate starts the periodic task immediately instead of after the first tick. -func Immediate() Option { - return optionFunc(func(p *periodic) { - p.immediate = true - }) -} - -// OnStop configures a callback that is executed when a periodic task is stopped or canceled. -func OnStop(f func(Tick)) Option { - return optionFunc(func(p *periodic) { - p.onStop = f - }) -} - -// Start starts a periodic task with a ticker at the specified interval, -// which executes the given callback after each tick. -// Pending tasks do not overlap, but could start immediately if -// the previous task(s) takes longer than the interval. -// Call Stop() on the return value in order to stop the ticker and to release associated resources. -// The interval must be greater than zero. -func Start(ctx context.Context, interval time.Duration, callback func(Tick), options ...Option) Stopper { - t := &periodic{ - interval: interval, - callback: callback, - } - - for _, option := range options { - option.apply(t) - } - - ctx, cancelCtx := context.WithCancel(ctx) - - start := time.Now() - - go func() { - done := false - - if !t.immediate { - select { - case <-time.After(interval): - case <-ctx.Done(): - done = true - } - } - - if !done { - ticker := time.NewTicker(t.interval) - defer ticker.Stop() - - for tickTime := time.Now(); !done; { - t.callback(Tick{ - Elapsed: tickTime.Sub(start), - Time: tickTime, - }) - - select { - case tickTime = <-ticker.C: - case <-ctx.Done(): - done = true - } - } - } - - if t.onStop != nil { - now := time.Now() - t.onStop(Tick{ - Elapsed: now.Sub(start), - Time: now, - }) - } - }() - - return stoperFunc(func() { - t.stop.Do(cancelCtx) - }) -} - -type optionFunc func(*periodic) - -func (f optionFunc) apply(p *periodic) { - f(p) -} - -type stoperFunc func() - -func (f stoperFunc) Stop() { - f() -} - -type periodic struct { - interval time.Duration - callback func(Tick) - immediate bool - stop sync.Once - onStop func(Tick) -} diff --git a/pkg/redis/client.go b/pkg/redis/client.go deleted file mode 100644 index 83f4c21a7..000000000 --- a/pkg/redis/client.go +++ /dev/null @@ -1,310 +0,0 @@ -package redis - -import ( - "context" - "crypto/tls" - "github.com/go-redis/redis/v8" - "github.com/icinga/icingadb/pkg/backoff" - "github.com/icinga/icingadb/pkg/com" - "github.com/icinga/icingadb/pkg/logging" - "github.com/icinga/icingadb/pkg/periodic" - "github.com/icinga/icingadb/pkg/retry" - "github.com/icinga/icingadb/pkg/utils" - "github.com/pkg/errors" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/semaphore" - "net" - "time" -) - -// Client is a wrapper around redis.Client with -// streaming and logging capabilities. -type Client struct { - *redis.Client - - Options *Options - - logger *logging.Logger -} - -// Options define user configurable Redis options. -type Options struct { - BlockTimeout time.Duration `yaml:"block_timeout" default:"1s"` - HMGetCount int `yaml:"hmget_count" default:"4096"` - HScanCount int `yaml:"hscan_count" default:"4096"` - MaxHMGetConnections int `yaml:"max_hmget_connections" default:"8"` - MinPoolSize int `yaml:"min_pool_size" default:"32"` - Timeout time.Duration `yaml:"timeout" default:"30s"` - XReadCount int `yaml:"xread_count" default:"4096"` -} - -// Validate checks constraints in the supplied Redis options and returns an error if they are violated. -func (o *Options) Validate() error { - if o.BlockTimeout <= 0 { - return errors.New("block_timeout must be positive") - } - if o.HMGetCount < 1 { - return errors.New("hmget_count must be at least 1") - } - if o.HScanCount < 1 { - return errors.New("hscan_count must be at least 1") - } - if o.MinPoolSize < 1 { - return errors.New("min_pool_size must be at least 1") - } - if o.MaxHMGetConnections < 1 { - return errors.New("max_hmget_connections must be at least 1") - } - if o.Timeout == 0 { - return errors.New("timeout cannot be 0. Configure a value greater than zero, or use -1 for no timeout") - } - if o.XReadCount < 1 { - return errors.New("xread_count must be at least 1") - } - - return nil -} - -// NewClient returns a new Client wrapper for a pre-existing redis.Client. -func NewClient(client *redis.Client, logger *logging.Logger, options *Options) *Client { - return &Client{Client: client, logger: logger, Options: options} -} - -// NewClientFromConfig returns a new Client from Config. -func NewClientFromConfig(c *Config, logger *logging.Logger) (*Client, error) { - tlsConfig, err := c.TlsOptions.MakeConfig(c.Host) - if err != nil { - return nil, err - } - - var dialer ctxDialerFunc - dl := &net.Dialer{Timeout: 15 * time.Second} - - if tlsConfig == nil { - dialer = dl.DialContext - } else { - dialer = (&tls.Dialer{NetDialer: dl, Config: tlsConfig}).DialContext - } - - options := &redis.Options{ - Dialer: dialWithLogging(dialer, logger), - Password: c.Password, - DB: c.Database, - ReadTimeout: c.Options.Timeout, - TLSConfig: tlsConfig, - } - - if utils.IsUnixAddr(c.Host) { - options.Network = "unix" - options.Addr = c.Host - } else { - port := c.Port - if port == 0 { - port = 6379 - } - options.Network = "tcp" - options.Addr = utils.JoinHostPort(c.Host, c.Port) - } - - client := redis.NewClient(options) - options = client.Options() - options.PoolSize = utils.MaxInt(c.Options.MinPoolSize, options.PoolSize) - options.MaxRetries = options.PoolSize + 1 // https://github.com/go-redis/redis/issues/1737 - client = redis.NewClient(options) - - return NewClient(client, logger, &c.Options), nil -} - -// HPair defines Redis hashes field-value pairs. -type HPair struct { - Field string - Value string -} - -// HYield yields HPair field-value pairs for all fields in the hash stored at key. -func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) { - pairs := make(chan HPair, c.Options.HScanCount) - - return pairs, com.WaitAsync(ctx, com.WaiterFunc(func() error { - var counter com.Counter - defer c.log(ctx, key, &counter).Stop() - defer close(pairs) - - seen := make(map[string]struct{}) - - var cursor uint64 - var err error - var page []string - - for { - cmd := c.HScan(ctx, key, cursor, "", int64(c.Options.HScanCount)) - page, cursor, err = cmd.Result() - - if err != nil { - return WrapCmdErr(cmd) - } - - for i := 0; i < len(page); i += 2 { - if _, ok := seen[page[i]]; ok { - // Ignore duplicate returned by HSCAN. - continue - } - - seen[page[i]] = struct{}{} - - select { - case pairs <- HPair{ - Field: page[i], - Value: page[i+1], - }: - counter.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - if cursor == 0 { - break - } - } - - return nil - })) -} - -// HMYield yields HPair field-value pairs for the specified fields in the hash stored at key. -func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) { - pairs := make(chan HPair) - - return pairs, com.WaitAsync(ctx, com.WaiterFunc(func() error { - var counter com.Counter - defer c.log(ctx, key, &counter).Stop() - - g, ctx := errgroup.WithContext(ctx) - - defer func() { - // Wait until the group is done so that we can safely close the pairs channel, - // because on error, sem.Acquire will return before calling g.Wait(), - // which can result in goroutines working on a closed channel. - _ = g.Wait() - close(pairs) - }() - - // Use context from group. - batches := utils.BatchSliceOfStrings(ctx, fields, c.Options.HMGetCount) - - sem := semaphore.NewWeighted(int64(c.Options.MaxHMGetConnections)) - - for batch := range batches { - if err := sem.Acquire(ctx, 1); err != nil { - return errors.Wrap(err, "can't acquire semaphore") - } - - batch := batch - g.Go(func() error { - defer sem.Release(1) - - cmd := c.HMGet(ctx, key, batch...) - vals, err := cmd.Result() - - if err != nil { - return WrapCmdErr(cmd) - } - - for i, v := range vals { - if v == nil { - c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) - continue - } - - select { - case pairs <- HPair{ - Field: batch[i], - Value: v.(string), - }: - counter.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - } - - return g.Wait() - })) -} - -// XReadUntilResult (repeatedly) calls XREAD with the specified arguments until a result is returned. -// Each call blocks at most for the duration specified in Options.BlockTimeout until data -// is available before it times out and the next call is made. -// This also means that an already set block timeout is overridden. -func (c *Client) XReadUntilResult(ctx context.Context, a *redis.XReadArgs) ([]redis.XStream, error) { - a.Block = c.Options.BlockTimeout - - for { - cmd := c.XRead(ctx, a) - streams, err := cmd.Result() - if err != nil { - if errors.Is(err, redis.Nil) { - continue - } - - return streams, WrapCmdErr(cmd) - } - - return streams, nil - } -} - -func (c *Client) log(ctx context.Context, key string, counter *com.Counter) periodic.Stopper { - return periodic.Start(ctx, c.logger.Interval(), func(tick periodic.Tick) { - // We may never get to progress logging here, - // as fetching should be completed before the interval expires, - // but if it does, it is good to have this log message. - if count := counter.Reset(); count > 0 { - c.logger.Debugf("Fetched %d items from %s", count, key) - } - }, periodic.OnStop(func(tick periodic.Tick) { - c.logger.Debugf("Finished fetching from %s with %d items in %s", key, counter.Total(), tick.Elapsed) - })) -} - -type ctxDialerFunc = func(ctx context.Context, network, addr string) (net.Conn, error) - -// dialWithLogging returns a Config Dialer with logging capabilities. -func dialWithLogging(dialer ctxDialerFunc, logger *logging.Logger) ctxDialerFunc { - // dial behaves like net.Dialer#DialContext, - // but re-tries on common errors that are considered retryable. - return func(ctx context.Context, network, addr string) (conn net.Conn, err error) { - err = retry.WithBackoff( - ctx, - func(ctx context.Context) (err error) { - conn, err = dialer(ctx, network, addr) - return - }, - retry.Retryable, - backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), - retry.Settings{ - Timeout: 5 * time.Minute, - OnError: func(_ time.Duration, _ uint64, err, lastErr error) { - if lastErr == nil || err.Error() != lastErr.Error() { - logger.Warnw("Can't connect to Redis. Retrying", zap.Error(err)) - } - }, - OnSuccess: func(elapsed time.Duration, attempt uint64, _ error) { - if attempt > 0 { - logger.Infow("Reconnected to Redis", - zap.Duration("after", elapsed), zap.Uint64("attempts", attempt+1)) - } - }, - }, - ) - - err = errors.Wrap(err, "can't connect to Redis") - - return - } -} diff --git a/pkg/redis/config.go b/pkg/redis/config.go deleted file mode 100644 index 4ca730a14..000000000 --- a/pkg/redis/config.go +++ /dev/null @@ -1,25 +0,0 @@ -package redis - -import ( - "github.com/icinga/icingadb/pkg/config" - "github.com/pkg/errors" -) - -// Config defines Config client configuration. -type Config struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - Database int `yaml:"database" default:"0"` - Password string `yaml:"password"` - TlsOptions config.TLS `yaml:",inline"` - Options Options `yaml:"options"` -} - -// Validate checks constraints in the supplied Config configuration and returns an error if they are violated. -func (r *Config) Validate() error { - if r.Host == "" { - return errors.New("Config host missing") - } - - return r.Options.Validate() -} diff --git a/pkg/redis/streams.go b/pkg/redis/streams.go deleted file mode 100644 index 737954369..000000000 --- a/pkg/redis/streams.go +++ /dev/null @@ -1,20 +0,0 @@ -package redis - -// Streams represents a Redis stream key to ID mapping. -type Streams map[string]string - -// Option returns the Redis stream key to ID mapping -// as a slice of stream keys followed by their IDs -// that is compatible for the Redis STREAMS option. -func (s Streams) Option() []string { - // len*2 because we're appending the IDs later. - streams := make([]string, 0, len(s)*2) - ids := make([]string, 0, len(s)) - - for key, id := range s { - streams = append(streams, key) - ids = append(ids, id) - } - - return append(streams, ids...) -} diff --git a/pkg/redis/utils.go b/pkg/redis/utils.go deleted file mode 100644 index 0a84a0af9..000000000 --- a/pkg/redis/utils.go +++ /dev/null @@ -1,22 +0,0 @@ -package redis - -import ( - "context" - "github.com/go-redis/redis/v8" - "github.com/icinga/icingadb/pkg/utils" - "github.com/pkg/errors" -) - -// WrapCmdErr adds the command itself and -// the stack of the current goroutine to the command's error if any. -func WrapCmdErr(cmd redis.Cmder) error { - err := cmd.Err() - if err != nil { - err = errors.Wrapf(err, "can't perform %q", utils.Ellipsize( - redis.NewCmd(context.Background(), cmd.Args()).String(), // Omits error in opposite to cmd.String() - 100, - )) - } - - return err -} diff --git a/pkg/retry/retry.go b/pkg/retry/retry.go deleted file mode 100644 index f319feed8..000000000 --- a/pkg/retry/retry.go +++ /dev/null @@ -1,190 +0,0 @@ -package retry - -import ( - "context" - "database/sql/driver" - "github.com/go-sql-driver/mysql" - "github.com/icinga/icingadb/pkg/backoff" - "github.com/lib/pq" - "github.com/pkg/errors" - "net" - "strings" - "syscall" - "time" -) - -// RetryableFunc is a retryable function. -type RetryableFunc func(context.Context) error - -// IsRetryable checks whether a new attempt can be started based on the error passed. -type IsRetryable func(error) bool - -// OnErrorFunc is called if an error occurs. -type OnErrorFunc func(elapsed time.Duration, attempt uint64, err, lastErr error) - -// OnSuccessFunc is called once the operation succeeds. -type OnSuccessFunc func(elapsed time.Duration, attempt uint64, lastErr error) - -// Settings aggregates optional settings for WithBackoff. -type Settings struct { - // Timeout lets WithBackoff give up once elapsed (if >0). - Timeout time.Duration - // OnError is called if an error occurs. - OnError OnErrorFunc - // OnSuccess is called once the operation succeeds. - OnSuccess OnSuccessFunc -} - -// WithBackoff retries the passed function if it fails and the error allows it to retry. -// The specified backoff policy is used to determine how long to sleep between attempts. -func WithBackoff( - ctx context.Context, retryableFunc RetryableFunc, retryable IsRetryable, b backoff.Backoff, settings Settings, -) (err error) { - parentCtx := ctx - - if settings.Timeout > 0 { - var cancelCtx context.CancelFunc - ctx, cancelCtx = context.WithTimeout(ctx, settings.Timeout) - defer cancelCtx() - } - - start := time.Now() - for attempt := uint64(0); ; /* true */ attempt++ { - prevErr := err - - if err = retryableFunc(ctx); err == nil { - if settings.OnSuccess != nil { - settings.OnSuccess(time.Since(start), attempt, prevErr) - } - - return - } - - if settings.OnError != nil { - settings.OnError(time.Since(start), attempt, err, prevErr) - } - - isRetryable := retryable(err) - - if prevErr != nil && (errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) { - err = prevErr - } - - if !isRetryable { - err = errors.Wrap(err, "can't retry") - - return - } - - sleep := b(attempt) - select { - case <-ctx.Done(): - if outerErr := parentCtx.Err(); outerErr != nil { - err = errors.Wrap(outerErr, "outer context canceled") - } else { - if err == nil { - err = ctx.Err() - } - err = errors.Wrap(err, "can't retry") - } - - return - case <-time.After(sleep): - } - } -} - -// Retryable returns true for common errors that are considered retryable, -// i.e. temporary, timeout, DNS, connection refused and reset, host down and unreachable and -// network down and unreachable errors. -func Retryable(err error) bool { - var temporary interface { - Temporary() bool - } - if errors.As(err, &temporary) && temporary.Temporary() { - return true - } - - var timeout interface { - Timeout() bool - } - if errors.As(err, &timeout) && timeout.Timeout() { - return true - } - - var dnsError *net.DNSError - if errors.As(err, &dnsError) { - return true - } - - var opError *net.OpError - if errors.As(err, &opError) { - // OpError provides Temporary() and Timeout(), but not Unwrap(), - // so we have to extract the underlying error ourselves to also check for ECONNREFUSED, - // which is not considered temporary or timed out by Go. - err = opError.Err - } - if errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) { - // syscall errors provide Temporary() and Timeout(), - // which do not include ECONNREFUSED or ENOENT, so we check these ourselves. - return true - } - if errors.Is(err, syscall.ECONNRESET) { - // ECONNRESET is treated as a temporary error by Go only if it comes from calling accept. - return true - } - if errors.Is(err, syscall.EHOSTDOWN) || errors.Is(err, syscall.EHOSTUNREACH) { - return true - } - if errors.Is(err, syscall.ENETDOWN) || errors.Is(err, syscall.ENETUNREACH) { - return true - } - - if errors.Is(err, driver.ErrBadConn) { - return true - } - if errors.Is(err, mysql.ErrInvalidConn) { - return true - } - - var e *mysql.MySQLError - if errors.As(err, &e) { - switch e.Number { - case 1053, 1205, 1213, 2006: - // 1053: Server shutdown in progress - // 1205: Lock wait timeout - // 1213: Deadlock found when trying to get lock - // 2006: MySQL server has gone away - return true - default: - return false - } - } - - var pe *pq.Error - if errors.As(err, &pe) { - switch pe.Code { - case "08000", // connection_exception - "08006", // connection_failure - "08001", // sqlclient_unable_to_establish_sqlconnection - "08004", // sqlserver_rejected_establishment_of_sqlconnection - "40001", // serialization_failure - "40P01", // deadlock_detected - "54000", // program_limit_exceeded - "55006", // object_in_use - "55P03", // lock_not_available - "57P01", // admin_shutdown - "57P02", // crash_shutdown - "57P03", // cannot_connect_now - "58000", // system_error - "58030", // io_error - "XX000": // internal_error - return true - default: - // Class 53 - Insufficient Resources - return strings.HasPrefix(string(pe.Code), "53") - } - } - - return false -} diff --git a/pkg/strcase/strcase.go b/pkg/strcase/strcase.go deleted file mode 100644 index 211442de8..000000000 --- a/pkg/strcase/strcase.go +++ /dev/null @@ -1,70 +0,0 @@ -package strcase - -import ( - "strings" - "unicode" -) - -var delimiters = map[rune]any{' ': struct{}{}, '_': struct{}{}, '-': struct{}{}, '.': struct{}{}} - -// Delimited converts a string to delimited.lower.case, here using `.` as delimiter. -func Delimited(s string, d rune) string { - return convert(s, unicode.LowerCase, d) -} - -// ScreamingDelimited converts a string to DELIMITED.UPPER.CASE, here using `.` as delimiter. -func ScreamingDelimited(s string, d rune) string { - return convert(s, unicode.UpperCase, d) -} - -// Snake converts a string to snake_case. -func Snake(s string) string { - return Delimited(s, '_') -} - -// ScreamingSnake converts a string to SCREAMING_SNAKE_CASE. -func ScreamingSnake(s string) string { - return ScreamingDelimited(s, '_') -} - -// convert converts a camelCase or space/underscore/hyphen/dot delimited string into various cases. -// _case must be unicode.Lower or unicode.Upper. -func convert(s string, _case int, d rune) string { - if s == "" { - return s - } - - var wasLower bool - var wasNumber bool - - s = strings.TrimSpace(s) - - n := strings.Builder{} - n.Grow(len(s) + 2) // Allow adding at least 2 delimiters without another allocation. - - for _, r := range s { - var isNumber bool - delimiter, isDelimiter := delimiters[r] - if !isDelimiter { - isNumber = unicode.IsNumber(r) - if wasNumber { - if !isNumber { - n.WriteRune(d) - } - } else if isNumber || wasLower && unicode.IsUpper(r) { - n.WriteRune(d) - } - - n.WriteRune(unicode.To(_case, r)) - } else if delimiter != d { - n.WriteRune(d) - } - - wasLower = unicode.IsLower(r) - if !wasLower { - wasNumber = isNumber - } - } - - return n.String() -} diff --git a/pkg/strcase/strcase_test.go b/pkg/strcase/strcase_test.go deleted file mode 100644 index a72133d5b..000000000 --- a/pkg/strcase/strcase_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package strcase - -import "testing" - -func TestSnake(t *testing.T) { snake(t) } - -func BenchmarkSnake(b *testing.B) { - for n := 0; n < b.N; n++ { - snake(b) - } -} - -func TestScreamingSnake(t *testing.T) { screamingSnake(t) } - -func BenchmarkScreamingSnake(b *testing.B) { - for n := 0; n < b.N; n++ { - screamingSnake(b) - } -} - -func snake(tb testing.TB) { - cases := [][]string{ - {"", ""}, - {"AnyKind of_string", "any_kind_of_string"}, - {" Test Case ", "test_case"}, - {" Test Case", "test_case"}, - {"testCase", "test_case"}, - {"test_case", "test_case"}, - {"Test Case ", "test_case"}, - {"Test Case", "test_case"}, - {"TestCase", "test_case"}, - {"Test", "test"}, - {"test", "test"}, - {"ID", "id"}, - {"ManyManyWords", "many_many_words"}, - {"manyManyWords", "many_many_words"}, - {" some string", "some_string"}, - {"some string", "some_string"}, - {"userID", "user_id"}, - {"icinga2", "icinga_2"}, - } - for _, c := range cases { - s, expected := c[0], c[1] - actual := Snake(s) - if actual != expected { - tb.Errorf("%q: %q != %q", s, actual, expected) - } - } -} - -func screamingSnake(tb testing.TB) { - cases := [][]string{ - {"", ""}, - {"AnyKind of_string", "ANY_KIND_OF_STRING"}, - {" Test Case ", "TEST_CASE"}, - {" Test Case", "TEST_CASE"}, - {"testCase", "TEST_CASE"}, - {"test_case", "TEST_CASE"}, - {"Test Case ", "TEST_CASE"}, - {"Test Case", "TEST_CASE"}, - {"TestCase", "TEST_CASE"}, - {"Test", "TEST"}, - {"test", "TEST"}, - {"ID", "ID"}, - {"ManyManyWords", "MANY_MANY_WORDS"}, - {"manyManyWords", "MANY_MANY_WORDS"}, - {" some string", "SOME_STRING"}, - {"some string", "SOME_STRING"}, - {"userID", "USER_ID"}, - {"icinga2", "ICINGA_2"}, - } - for _, c := range cases { - s, expected := c[0], c[1] - actual := ScreamingSnake(s) - if actual != expected { - tb.Errorf("%q: %q != %q", s, actual, expected) - } - } -} diff --git a/pkg/structify/structify.go b/pkg/structify/structify.go deleted file mode 100644 index aab22ae52..000000000 --- a/pkg/structify/structify.go +++ /dev/null @@ -1,174 +0,0 @@ -package structify - -import ( - "encoding" - "fmt" - "github.com/pkg/errors" - "golang.org/x/exp/constraints" - "reflect" - "strconv" - "strings" - "unsafe" -) - -// structBranch represents either a leaf or a subTree. -type structBranch struct { - // field specifies the struct field index. - field int - // leaf specifies the map key to parse the struct field from. - leaf string - // subTree specifies the struct field's inner tree. - subTree []structBranch -} - -type MapStructifier = func(map[string]interface{}) (interface{}, error) - -// MakeMapStructifier builds a function which parses a map's string values into a new struct of type t -// and returns a pointer to it. tag specifies which tag connects struct fields to map keys. -// MakeMapStructifier panics if it detects an unsupported type (suitable for usage in init() or global vars). -func MakeMapStructifier(t reflect.Type, tag string, initer func(any)) MapStructifier { - tree := buildStructTree(t, tag) - - return func(kv map[string]interface{}) (interface{}, error) { - vPtr := reflect.New(t) - ptr := vPtr.Interface() - initer(ptr) - vPtrElem := vPtr.Elem() - err := errors.Wrapf(structifyMapByTree(kv, tree, vPtrElem, vPtrElem, new([]int)), "can't structify map %#v by tree %#v", kv, tree) - - return ptr, err - } -} - -// buildStructTree assembles a tree which represents the struct t based on tag. -func buildStructTree(t reflect.Type, tag string) []structBranch { - var tree []structBranch - numFields := t.NumField() - - for i := 0; i < numFields; i++ { - if field := t.Field(i); field.PkgPath == "" { - switch tagValue := field.Tag.Get(tag); tagValue { - case "", "-": - case ",inline": - if subTree := buildStructTree(field.Type, tag); subTree != nil { - tree = append(tree, structBranch{i, "", subTree}) - } - default: - // If parseString doesn't support *T, it'll panic. - _ = parseString("", reflect.New(field.Type).Interface()) - - tree = append(tree, structBranch{i, tagValue, nil}) - } - } - } - - return tree -} - -// structifyMapByTree parses src's string values into the struct dest according to tree's specification. -func structifyMapByTree(src map[string]interface{}, tree []structBranch, dest, root reflect.Value, stack *[]int) error { - *stack = append(*stack, 0) - defer func() { - *stack = (*stack)[:len(*stack)-1] - }() - - for _, branch := range tree { - (*stack)[len(*stack)-1] = branch.field - - if branch.subTree == nil { - if v, ok := src[branch.leaf]; ok { - if vs, ok := v.(string); ok { - if err := parseString(vs, dest.Field(branch.field).Addr().Interface()); err != nil { - rt := root.Type() - typ := rt - var path []string - - for _, i := range *stack { - f := typ.Field(i) - path = append(path, f.Name) - typ = f.Type - } - - return errors.Wrapf(err, "can't parse %s into the %s %s#%s: %s", - branch.leaf, typ.Name(), rt.Name(), strings.Join(path, "."), vs) - } - } - } - } else if err := structifyMapByTree(src, branch.subTree, dest.Field(branch.field), root, stack); err != nil { - return err - } - } - - return nil -} - -// parseString parses src into *dest. -func parseString(src string, dest interface{}) error { - switch ptr := dest.(type) { - case encoding.TextUnmarshaler: - return ptr.UnmarshalText([]byte(src)) - case *string: - *ptr = src - return nil - case **string: - *ptr = &src - return nil - case *uint8: - return parseUint(src, ptr) - case *uint16: - return parseUint(src, ptr) - case *uint32: - return parseUint(src, ptr) - case *uint64: - return parseUint(src, ptr) - case *int8: - return parseInt(src, ptr) - case *int16: - return parseInt(src, ptr) - case *int32: - return parseInt(src, ptr) - case *int64: - return parseInt(src, ptr) - case *float32: - return parseFloat(src, ptr) - case *float64: - return parseFloat(src, ptr) - default: - panic(fmt.Sprintf("unsupported type: %T", dest)) - } -} - -// parseUint parses src into *dest. -func parseUint[T constraints.Unsigned](src string, dest *T) error { - i, err := strconv.ParseUint(src, 10, bitSizeOf[T]()) - if err == nil { - *dest = T(i) - } - - return err -} - -// parseInt parses src into *dest. -func parseInt[T constraints.Signed](src string, dest *T) error { - i, err := strconv.ParseInt(src, 10, bitSizeOf[T]()) - if err == nil { - *dest = T(i) - } - - return err -} - -// parseFloat parses src into *dest. -func parseFloat[T constraints.Float](src string, dest *T) error { - f, err := strconv.ParseFloat(src, bitSizeOf[T]()) - if err == nil { - *dest = T(f) - } - - return err -} - -func bitSizeOf[T any]() int { - var x T - return int(unsafe.Sizeof(x) * 8) -} diff --git a/pkg/types/binary.go b/pkg/types/binary.go deleted file mode 100644 index 2a786afef..000000000 --- a/pkg/types/binary.go +++ /dev/null @@ -1,125 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/hex" - "encoding/json" - "fmt" - "github.com/pkg/errors" -) - -// Binary nullable byte string. Hex as JSON. -type Binary []byte - -// nullBinary for validating whether a Binary is valid. -var nullBinary Binary - -// Valid returns whether the Binary is valid. -func (binary Binary) Valid() bool { - return !bytes.Equal(binary, nullBinary) -} - -// String returns the hex string representation form of the Binary. -func (binary Binary) String() string { - return hex.EncodeToString(binary) -} - -// MarshalText implements a custom marhsal function to encode -// the Binary as hex. MarshalText implements the -// encoding.TextMarshaler interface. -func (binary Binary) MarshalText() ([]byte, error) { - return []byte(binary.String()), nil -} - -// UnmarshalText implements a custom unmarshal function to decode -// hex into a Binary. UnmarshalText implements the -// encoding.TextUnmarshaler interface. -func (binary *Binary) UnmarshalText(text []byte) error { - b := make([]byte, hex.DecodedLen(len(text))) - _, err := hex.Decode(b, text) - if err != nil { - return CantDecodeHex(err, string(text)) - } - *binary = b - - return nil -} - -// MarshalJSON implements a custom marshal function to encode the Binary -// as a hex string. MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (binary Binary) MarshalJSON() ([]byte, error) { - if !binary.Valid() { - return []byte("null"), nil - } - - return MarshalJSON(binary.String()) -} - -// UnmarshalJSON implements a custom unmarshal function to decode -// a JSON hex string into a Binary. UnmarshalJSON implements the -// json.Unmarshaler interface. Supports JSON null. -func (binary *Binary) UnmarshalJSON(data []byte) error { - if string(data) == "null" || len(data) == 0 { - return nil - } - - var s string - if err := UnmarshalJSON(data, &s); err != nil { - return err - } - b, err := hex.DecodeString(s) - if err != nil { - return CantDecodeHex(err, s) - } - *binary = b - - return nil -} - -// Scan implements the sql.Scanner interface. -// Supports SQL NULL. -func (binary *Binary) Scan(src interface{}) error { - switch src := src.(type) { - case nil: - return nil - - case []byte: - if len(src) == 0 { - return nil - } - - b := make([]byte, len(src)) - copy(b, src) - *binary = b - - default: - return errors.Errorf("unable to scan type %T into Binary", src) - } - - return nil -} - -// Value implements the driver.Valuer interface. -// Supports SQL NULL. -func (binary Binary) Value() (driver.Value, error) { - if !binary.Valid() { - return nil, nil - } - - return []byte(binary), nil -} - -// Assert interface compliance. -var ( - _ fmt.Stringer = (*Binary)(nil) - _ encoding.TextMarshaler = (*Binary)(nil) - _ encoding.TextUnmarshaler = (*Binary)(nil) - _ json.Marshaler = (*Binary)(nil) - _ json.Unmarshaler = (*Binary)(nil) - _ sql.Scanner = (*Binary)(nil) - _ driver.Valuer = (*Binary)(nil) -) diff --git a/pkg/types/binary_test.go b/pkg/types/binary_test.go deleted file mode 100644 index 2a4f82920..000000000 --- a/pkg/types/binary_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package types - -import ( - "github.com/stretchr/testify/require" - "testing" - "unicode/utf8" -) - -func TestBinary_MarshalJSON(t *testing.T) { - subtests := []struct { - name string - input Binary - output string - }{ - {"nil", nil, `null`}, - {"empty", make(Binary, 0, 1), `null`}, - {"space", Binary(" "), `"20"`}, - } - - for _, st := range subtests { - t.Run(st.name, func(t *testing.T) { - actual, err := st.input.MarshalJSON() - - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, st.output, string(actual)) - }) - } -} diff --git a/pkg/types/bool.go b/pkg/types/bool.go deleted file mode 100644 index cd0af2cea..000000000 --- a/pkg/types/bool.go +++ /dev/null @@ -1,104 +0,0 @@ -package types - -import ( - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "github.com/pkg/errors" - "strconv" -) - -var ( - enum = map[bool]string{ - true: "y", - false: "n", - } -) - -// Bool represents a bool for ENUM ('y', 'n'), which can be NULL. -type Bool struct { - Bool bool - Valid bool // Valid is true if Bool is not NULL -} - -// MarshalJSON implements the json.Marshaler interface. -func (b Bool) MarshalJSON() ([]byte, error) { - if !b.Valid { - return []byte("null"), nil - } - - return MarshalJSON(b.Bool) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (b *Bool) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseUint(string(text), 10, 64) - if err != nil { - return CantParseUint64(err, string(text)) - } - - *b = Bool{parsed != 0, true} - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -func (b *Bool) UnmarshalJSON(data []byte) error { - if string(data) == "null" || len(data) == 0 { - return nil - } - - if err := UnmarshalJSON(data, &b.Bool); err != nil { - return err - } - - b.Valid = true - - return nil -} - -// Scan implements the sql.Scanner interface. -// Supports SQL NULL. -func (b *Bool) Scan(src interface{}) error { - if src == nil { - b.Bool, b.Valid = false, false - return nil - } - - v, ok := src.([]byte) - if !ok { - return errors.Errorf("bad []byte type assertion from %#v", src) - } - - switch string(v) { - case "y": - b.Bool = true - case "n": - b.Bool = false - default: - return errors.Errorf("bad bool %#v", v) - } - - b.Valid = true - - return nil -} - -// Value implements the driver.Valuer interface. -// Supports SQL NULL. -func (b Bool) Value() (driver.Value, error) { - if !b.Valid { - return nil, nil - } - - return enum[b.Bool], nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = (*Bool)(nil) - _ encoding.TextUnmarshaler = (*Bool)(nil) - _ json.Unmarshaler = (*Bool)(nil) - _ sql.Scanner = (*Bool)(nil) - _ driver.Valuer = (*Bool)(nil) -) diff --git a/pkg/types/bool_test.go b/pkg/types/bool_test.go deleted file mode 100644 index fe49588c8..000000000 --- a/pkg/types/bool_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package types - -import ( - "fmt" - "github.com/stretchr/testify/require" - "testing" - "unicode/utf8" -) - -func TestBool_MarshalJSON(t *testing.T) { - subtests := []struct { - input Bool - output string - }{ - {Bool{Bool: false, Valid: false}, `null`}, - {Bool{Bool: false, Valid: true}, `false`}, - {Bool{Bool: true, Valid: false}, `null`}, - {Bool{Bool: true, Valid: true}, `true`}, - } - - for _, st := range subtests { - t.Run(fmt.Sprintf("Bool-%#v_Valid-%#v", st.input.Bool, st.input.Valid), func(t *testing.T) { - actual, err := st.input.MarshalJSON() - - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, st.output, string(actual)) - }) - } -} diff --git a/pkg/types/float.go b/pkg/types/float.go deleted file mode 100644 index dcd51c6f0..000000000 --- a/pkg/types/float.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "strconv" -) - -// Float adds JSON support to sql.NullFloat64. -type Float struct { - sql.NullFloat64 -} - -// MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (f Float) MarshalJSON() ([]byte, error) { - var v interface{} - if f.Valid { - v = f.Float64 - } - - return MarshalJSON(v) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (f *Float) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseFloat(string(text), 64) - if err != nil { - return CantParseFloat64(err, string(text)) - } - - *f = Float{sql.NullFloat64{ - Float64: parsed, - Valid: true, - }} - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Supports JSON null. -func (f *Float) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if bytes.HasPrefix(data, []byte{'n'}) { - return nil - } - - if err := UnmarshalJSON(data, &f.Float64); err != nil { - return err - } - - f.Valid = true - - return nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = Float{} - _ encoding.TextUnmarshaler = (*Float)(nil) - _ json.Unmarshaler = (*Float)(nil) - _ driver.Valuer = Float{} - _ sql.Scanner = (*Float)(nil) -) diff --git a/pkg/types/int.go b/pkg/types/int.go deleted file mode 100644 index 448180f16..000000000 --- a/pkg/types/int.go +++ /dev/null @@ -1,67 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "strconv" -) - -// Int adds JSON support to sql.NullInt64. -type Int struct { - sql.NullInt64 -} - -// MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (i Int) MarshalJSON() ([]byte, error) { - var v interface{} - if i.Valid { - v = i.Int64 - } - - return MarshalJSON(v) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (i *Int) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseInt(string(text), 10, 64) - if err != nil { - return CantParseInt64(err, string(text)) - } - - *i = Int{sql.NullInt64{ - Int64: parsed, - Valid: true, - }} - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Supports JSON null. -func (i *Int) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if bytes.HasPrefix(data, []byte{'n'}) { - return nil - } - - if err := UnmarshalJSON(data, &i.Int64); err != nil { - return err - } - - i.Valid = true - - return nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = Int{} - _ json.Unmarshaler = (*Int)(nil) - _ encoding.TextUnmarshaler = (*Int)(nil) - _ driver.Valuer = Int{} - _ sql.Scanner = (*Int)(nil) -) diff --git a/pkg/types/string.go b/pkg/types/string.go deleted file mode 100644 index 45acb4860..000000000 --- a/pkg/types/string.go +++ /dev/null @@ -1,73 +0,0 @@ -package types - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "strings" -) - -// String adds JSON support to sql.NullString. -type String struct { - sql.NullString -} - -// MarshalJSON implements the json.Marshaler interface. -// Supports JSON null. -func (s String) MarshalJSON() ([]byte, error) { - var v interface{} - if s.Valid { - v = s.String - } - - return MarshalJSON(v) -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (s *String) UnmarshalText(text []byte) error { - *s = String{sql.NullString{ - String: string(text), - Valid: true, - }} - - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Supports JSON null. -func (s *String) UnmarshalJSON(data []byte) error { - // Ignore null, like in the main JSON package. - if bytes.HasPrefix(data, []byte{'n'}) { - return nil - } - - if err := UnmarshalJSON(data, &s.String); err != nil { - return err - } - - s.Valid = true - - return nil -} - -// Value implements the driver.Valuer interface. -// Supports SQL NULL. -func (s String) Value() (driver.Value, error) { - if !s.Valid { - return nil, nil - } - - // PostgreSQL does not allow null bytes in varchar, char and text fields. - return strings.ReplaceAll(s.String, "\x00", ""), nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = String{} - _ encoding.TextUnmarshaler = (*String)(nil) - _ json.Unmarshaler = (*String)(nil) - _ driver.Valuer = String{} - _ sql.Scanner = (*String)(nil) -) diff --git a/pkg/types/unix_milli.go b/pkg/types/unix_milli.go deleted file mode 100644 index d910a3ae5..000000000 --- a/pkg/types/unix_milli.go +++ /dev/null @@ -1,94 +0,0 @@ -package types - -import ( - "database/sql" - "database/sql/driver" - "encoding" - "encoding/json" - "github.com/icinga/icingadb/pkg/utils" - "github.com/pkg/errors" - "strconv" - "time" -) - -// UnixMilli is a nullable millisecond UNIX timestamp in databases and JSON. -type UnixMilli time.Time - -// Time returns the time.Time conversion of UnixMilli. -func (t UnixMilli) Time() time.Time { - return time.Time(t) -} - -// MarshalJSON implements the json.Marshaler interface. -// Marshals to milliseconds. Supports JSON null. -func (t UnixMilli) MarshalJSON() ([]byte, error) { - if time.Time(t).IsZero() { - return []byte("null"), nil - } - - return []byte(strconv.FormatInt(time.Time(t).UnixMilli(), 10)), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -func (t *UnixMilli) UnmarshalText(text []byte) error { - parsed, err := strconv.ParseFloat(string(text), 64) - if err != nil { - return CantParseFloat64(err, string(text)) - } - - *t = UnixMilli(utils.FromUnixMilli(int64(parsed))) - return nil -} - -// UnmarshalJSON implements the json.Unmarshaler interface. -// Unmarshals from milliseconds. Supports JSON null. -func (t *UnixMilli) UnmarshalJSON(data []byte) error { - if string(data) == "null" || len(data) == 0 { - return nil - } - - ms, err := strconv.ParseFloat(string(data), 64) - if err != nil { - return CantParseFloat64(err, string(data)) - } - tt := utils.FromUnixMilli(int64(ms)) - *t = UnixMilli(tt) - - return nil -} - -// Scan implements the sql.Scanner interface. -// Scans from milliseconds. Supports SQL NULL. -func (t *UnixMilli) Scan(src interface{}) error { - if src == nil { - return nil - } - - v, ok := src.(int64) - if !ok { - return errors.Errorf("bad int64 type assertion from %#v", src) - } - tt := utils.FromUnixMilli(v) - *t = UnixMilli(tt) - - return nil -} - -// Value implements the driver.Valuer interface. -// Returns milliseconds. Supports SQL NULL. -func (t UnixMilli) Value() (driver.Value, error) { - if t.Time().IsZero() { - return nil, nil - } - - return t.Time().UnixMilli(), nil -} - -// Assert interface compliance. -var ( - _ json.Marshaler = (*UnixMilli)(nil) - _ encoding.TextUnmarshaler = (*UnixMilli)(nil) - _ json.Unmarshaler = (*UnixMilli)(nil) - _ sql.Scanner = (*UnixMilli)(nil) - _ driver.Valuer = (*UnixMilli)(nil) -) diff --git a/pkg/types/unix_milli_test.go b/pkg/types/unix_milli_test.go deleted file mode 100644 index 985fa2ac1..000000000 --- a/pkg/types/unix_milli_test.go +++ /dev/null @@ -1,30 +0,0 @@ -package types - -import ( - "github.com/stretchr/testify/require" - "testing" - "time" - "unicode/utf8" -) - -func TestUnixMilli_MarshalJSON(t *testing.T) { - subtests := []struct { - name string - input UnixMilli - output string - }{ - {"zero", UnixMilli{}, `null`}, - {"epoch", UnixMilli(time.Unix(0, 0)), `0`}, - {"nonzero", UnixMilli(time.Unix(1234567890, 62500000)), `1234567890062`}, - } - - for _, st := range subtests { - t.Run(st.name, func(t *testing.T) { - actual, err := st.input.MarshalJSON() - - require.NoError(t, err) - require.True(t, utf8.Valid(actual)) - require.Equal(t, st.output, string(actual)) - }) - } -} diff --git a/pkg/types/utils.go b/pkg/types/utils.go deleted file mode 100644 index 467eae312..000000000 --- a/pkg/types/utils.go +++ /dev/null @@ -1,57 +0,0 @@ -package types - -import ( - "encoding/json" - "fmt" - "github.com/pkg/errors" - "strings" -) - -// Name returns the declared name of type t. -func Name(t any) string { - s := strings.TrimLeft(fmt.Sprintf("%T", t), "*") - - return s[strings.LastIndex(s, ".")+1:] -} - -// CantDecodeHex wraps the given error with the given string that cannot be hex-decoded. -func CantDecodeHex(err error, s string) error { - return errors.Wrapf(err, "can't decode hex %q", s) -} - -// CantParseFloat64 wraps the given error with the specified string that cannot be parsed into float64. -func CantParseFloat64(err error, s string) error { - return errors.Wrapf(err, "can't parse %q into float64", s) -} - -// CantParseInt64 wraps the given error with the specified string that cannot be parsed into int64. -func CantParseInt64(err error, s string) error { - return errors.Wrapf(err, "can't parse %q into int64", s) -} - -// CantParseUint64 wraps the given error with the specified string that cannot be parsed into uint64. -func CantParseUint64(err error, s string) error { - return errors.Wrapf(err, "can't parse %q into uint64", s) -} - -// CantUnmarshalYAML wraps the given error with the designated value, which cannot be unmarshalled into. -func CantUnmarshalYAML(err error, v interface{}) error { - return errors.Wrapf(err, "can't unmarshal YAML into %T", v) -} - -// MarshalJSON calls json.Marshal and wraps any resulting errors. -func MarshalJSON(v interface{}) ([]byte, error) { - b, err := json.Marshal(v) - - return b, errors.Wrapf(err, "can't marshal JSON from %T", v) -} - -// UnmarshalJSON calls json.Unmarshal and wraps any resulting errors. -func UnmarshalJSON(data []byte, v interface{}) error { - return errors.Wrapf(json.Unmarshal(data, v), "can't unmarshal JSON into %T", v) -} - -// Zero returns the zero value for type T. -func Zero[T any]() T { - return *new(T) -} diff --git a/pkg/types/uuid.go b/pkg/types/uuid.go deleted file mode 100644 index 02acbcdb1..000000000 --- a/pkg/types/uuid.go +++ /dev/null @@ -1,24 +0,0 @@ -package types - -import ( - "database/sql/driver" - "encoding" - "github.com/google/uuid" -) - -// UUID is like uuid.UUID, but marshals itself binarily (not like xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx) in SQL context. -type UUID struct { - uuid.UUID -} - -// Value implements driver.Valuer. -func (uuid UUID) Value() (driver.Value, error) { - return uuid.UUID[:], nil -} - -// Assert interface compliance. -var ( - _ encoding.TextUnmarshaler = (*UUID)(nil) - _ driver.Valuer = UUID{} - _ driver.Valuer = (*UUID)(nil) -) diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go deleted file mode 100644 index 22c80a0a8..000000000 --- a/pkg/utils/utils.go +++ /dev/null @@ -1,172 +0,0 @@ -package utils - -import ( - "context" - "crypto/sha1" - "fmt" - "github.com/go-sql-driver/mysql" - "github.com/lib/pq" - "github.com/pkg/errors" - "golang.org/x/exp/utf8string" - "math" - "net" - "os" - "path/filepath" - "strings" - "time" -) - -// FromUnixMilli creates and returns a time.Time value -// from the given milliseconds since the Unix epoch ms. -func FromUnixMilli(ms int64) time.Time { - sec, dec := math.Modf(float64(ms) / 1e3) - - return time.Unix(int64(sec), int64(dec*(1e9))) -} - -// Timed calls the given callback with the time that has elapsed since the start. -// -// Timed should be installed by defer: -// -// func TimedExample(logger *zap.SugaredLogger) { -// defer utils.Timed(time.Now(), func(elapsed time.Duration) { -// logger.Debugf("Executed job in %s", elapsed) -// }) -// job() -// } -func Timed(start time.Time, callback func(elapsed time.Duration)) { - callback(time.Since(start)) -} - -// BatchSliceOfStrings groups the given keys into chunks of size count and streams them into a returned channel. -func BatchSliceOfStrings(ctx context.Context, keys []string, count int) <-chan []string { - batches := make(chan []string) - - go func() { - defer close(batches) - - for i := 0; i < len(keys); i += count { - end := i + count - if end > len(keys) { - end = len(keys) - } - - select { - case batches <- keys[i:end]: - case <-ctx.Done(): - return - } - } - }() - - return batches -} - -// IsContextCanceled returns whether the given error is context.Canceled. -func IsContextCanceled(err error) bool { - return errors.Is(err, context.Canceled) -} - -// Checksum returns the SHA-1 checksum of the data. -func Checksum(data interface{}) []byte { - var chksm [sha1.Size]byte - - switch data := data.(type) { - case string: - chksm = sha1.Sum([]byte(data)) - case []byte: - chksm = sha1.Sum(data) - default: - panic(fmt.Sprintf("Unable to create checksum for type %T", data)) - } - - return chksm[:] -} - -// Fatal panics with the given error. -func Fatal(err error) { - panic(err) -} - -// IsDeadlock returns whether the given error signals serialization failure. -func IsDeadlock(err error) bool { - var e *mysql.MySQLError - if errors.As(err, &e) { - switch e.Number { - case 1205, 1213: - return true - default: - return false - } - } - - var pe *pq.Error - if errors.As(err, &pe) { - switch pe.Code { - case "40001", "40P01": - return true - } - } - - return false -} - -var ellipsis = utf8string.NewString("...") - -// Ellipsize shortens s to <=limit runes and indicates shortening by "...". -func Ellipsize(s string, limit int) string { - utf8 := utf8string.NewString(s) - switch { - case utf8.RuneCount() <= limit: - return s - case utf8.RuneCount() <= ellipsis.RuneCount(): - return ellipsis.String() - default: - return utf8.Slice(0, limit-ellipsis.RuneCount()) + ellipsis.String() - } -} - -// AppName returns the name of the executable that started this program (process). -func AppName() string { - exe, err := os.Executable() - if err != nil { - exe = os.Args[0] - } - - return filepath.Base(exe) -} - -// MaxInt returns the larger of the given integers. -func MaxInt(x, y int) int { - if x > y { - return x - } - - return y -} - -func IsUnixAddr(host string) bool { - return strings.HasPrefix(host, "/") -} - -// JoinHostPort is like its equivalent in net., but handles UNIX sockets as well. -func JoinHostPort(host string, port int) string { - if IsUnixAddr(host) { - return host - } - - return net.JoinHostPort(host, fmt.Sprint(port)) -} - -// ChanFromSlice takes a slice of values and returns a channel from which these values can be received. -// This channel is closed after the last value was sent. -func ChanFromSlice[T any](values []T) <-chan T { - ch := make(chan T, len(values)) - for _, value := range values { - ch <- value - } - - close(ch) - - return ch -} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go deleted file mode 100644 index b0ea54b8f..000000000 --- a/pkg/utils/utils_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package utils - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestChanFromSlice(t *testing.T) { - t.Run("Nil", func(t *testing.T) { - ch := ChanFromSlice[int](nil) - require.NotNil(t, ch) - requireClosedEmpty(t, ch) - }) - - t.Run("Empty", func(t *testing.T) { - ch := ChanFromSlice([]int{}) - require.NotNil(t, ch) - requireClosedEmpty(t, ch) - }) - - t.Run("NonEmpty", func(t *testing.T) { - ch := ChanFromSlice([]int{42, 23, 1337}) - require.NotNil(t, ch) - requireReceive(t, ch, 42) - requireReceive(t, ch, 23) - requireReceive(t, ch, 1337) - requireClosedEmpty(t, ch) - }) -} - -// requireReceive is a helper function to check if a value can immediately be received from a channel. -func requireReceive(t *testing.T, ch <-chan int, expected int) { - t.Helper() - - select { - case v, ok := <-ch: - require.True(t, ok, "receiving should return a value") - require.Equal(t, expected, v) - default: - require.Fail(t, "receiving should not block") - } -} - -// requireReceive is a helper function to check if the channel is closed and empty. -func requireClosedEmpty(t *testing.T, ch <-chan int) { - t.Helper() - - select { - case _, ok := <-ch: - require.False(t, ok, "receiving from channel should not return anything") - default: - require.Fail(t, "receiving should not block") - } -} diff --git a/pkg/version/version.go b/pkg/version/version.go deleted file mode 100644 index 250318c7e..000000000 --- a/pkg/version/version.go +++ /dev/null @@ -1,180 +0,0 @@ -package version - -import ( - "bufio" - "errors" - "fmt" - "os" - "runtime" - "runtime/debug" - "strconv" - "strings" -) - -type VersionInfo struct { - Version string - Commit string -} - -// Version determines version and commit information based on multiple data sources: -// - Version information dynamically added by `git archive` in the remaining to parameters. -// - A hardcoded version number passed as first parameter. -// - Commit information added to the binary by `go build`. -// -// It's supposed to be called like this in combination with setting the `export-subst` attribute for the corresponding -// file in .gitattributes: -// -// var Version = version.Version("1.0.0-rc2", "$Format:%(describe)$", "$Format:%H$") -// -// When exported using `git archive`, the placeholders are replaced in the file and this version information is -// preferred. Otherwise the hardcoded version is used and augmented with commit information from the build metadata. -func Version(version, gitDescribe, gitHash string) *VersionInfo { - const hashLen = 7 // Same truncation length for the commit hash as used by git describe. - - if !strings.HasPrefix(gitDescribe, "$") && !strings.HasPrefix(gitHash, "$") { - if strings.HasPrefix(gitDescribe, "%") { - // Only Git 2.32+ supports %(describe), older versions don't expand it but keep it as-is. - // Fall back to the hardcoded version augmented with the commit hash. - gitDescribe = version - - if len(gitHash) >= hashLen { - gitDescribe += "-g" + gitHash[:hashLen] - } - } - - return &VersionInfo{ - Version: gitDescribe, - Commit: gitHash, - } - } else { - commit := "" - - if info, ok := debug.ReadBuildInfo(); ok { - modified := false - - for _, setting := range info.Settings { - switch setting.Key { - case "vcs.revision": - commit = setting.Value - case "vcs.modified": - modified, _ = strconv.ParseBool(setting.Value) - } - } - - if len(commit) >= hashLen { - version += "-g" + commit[:hashLen] - - if modified { - version += "-dirty" - commit += " (modified)" - } - } - } - - return &VersionInfo{ - Version: version, - Commit: commit, - } - } -} - -// Print writes verbose version output to stdout. -func (v *VersionInfo) Print() { - fmt.Println("Icinga DB version:", v.Version) - fmt.Println() - - fmt.Println("Build information:") - fmt.Printf(" Go version: %s (%s, %s)\n", runtime.Version(), runtime.GOOS, runtime.GOARCH) - if v.Commit != "" { - fmt.Println(" Git commit:", v.Commit) - } - - if r, err := readOsRelease(); err == nil { - fmt.Println() - fmt.Println("System information:") - fmt.Println(" Platform:", r.Name) - fmt.Println(" Platform version:", r.DisplayVersion()) - } -} - -// osRelease contains the information obtained from the os-release file. -type osRelease struct { - Name string - Version string - VersionId string - BuildId string -} - -// DisplayVersion returns the most suitable version information for display purposes. -func (o *osRelease) DisplayVersion() string { - if o.Version != "" { - // Most distributions set VERSION - return o.Version - } else if o.VersionId != "" { - // Some only set VERSION_ID (Alpine Linux for example) - return o.VersionId - } else if o.BuildId != "" { - // Others only set BUILD_ID (Arch Linux for example) - return o.BuildId - } else { - return "(unknown)" - } -} - -// readOsRelease reads and parses the os-release file. -func readOsRelease() (*osRelease, error) { - for _, path := range []string{"/etc/os-release", "/usr/lib/os-release"} { - f, err := os.Open(path) - if err != nil { - if os.IsNotExist(err) { - continue // Try next path. - } else { - return nil, err - } - } - - o := &osRelease{ - Name: "Linux", // Suggested default as per os-release(5) man page. - } - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "#") { - continue // Ignore comment. - } - - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue // Ignore empty or possibly malformed line. - } - - key := parts[0] - val := parts[1] - - // Unquote strings. This isn't fully compliant with the specification which allows using some shell escape - // sequences. However, typically quotes are only used to allow whitespace within the value. - if len(val) >= 2 && (val[0] == '"' || val[0] == '\'') && val[0] == val[len(val)-1] { - val = val[1 : len(val)-1] - } - - switch key { - case "NAME": - o.Name = val - case "VERSION": - o.Version = val - case "VERSION_ID": - o.VersionId = val - case "BUILD_ID": - o.BuildId = val - } - } - if err := scanner.Err(); err != nil { - return nil, err - } - - return o, nil - } - - return nil, errors.New("os-release file not found") -}