diff --git a/cache/cache.go b/cache/cache.go index 74b7bf08..829b4d0d 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -17,6 +17,7 @@ import ( "github.com/ovn-org/libovsdb/mapper" "github.com/ovn-org/libovsdb/model" "github.com/ovn-org/libovsdb/ovsdb" + "github.com/ovn-org/libovsdb/util" ) const ( @@ -718,6 +719,42 @@ func (t *TableCache) Purge(dbModel model.DatabaseModel) { } } +// PurgeTable drops all data in the given table's cache and reinitializes it using the +// provided database model +func (t *TableCache) PurgeTable(dbModel model.DatabaseModel, name string) error { + return t.PurgeTableRows(dbModel, name, nil) +} + +// PurgeTableRows drops all rows in the given table's cache that match the given conditions +func (t *TableCache) PurgeTableRows(dbModel model.DatabaseModel, name string, conditions []ovsdb.Condition) error { + t.mutex.Lock() + defer t.mutex.Unlock() + t.dbModel = dbModel + tableTypes := t.dbModel.Types() + dataType, ok := tableTypes[name] + if !ok { + return fmt.Errorf("table %s not found", name) + } + if len(conditions) == 0 { + t.cache[name] = newRowCache(name, t.dbModel, dataType) + return nil + } + + r := t.cache[name] + rows, err := r.RowsByCondition(conditions) + if err != nil { + return err + } + delErrors := []error{} + for uuid := range rows { + if err := r.Delete(uuid); err != nil { + delErrors = append(delErrors, fmt.Errorf("failed to delete %s: %w", uuid, err)) + } + } + + return util.CombineErrors(delErrors, "failed to delete rows") +} + // AddEventHandler registers the supplied EventHandler to receive cache events func (t *TableCache) AddEventHandler(handler EventHandler) { t.eventProcessor.AddEventHandler(handler) diff --git a/client/client.go b/client/client.go index 1672c311..fc6f6e1b 100644 --- a/client/client.go +++ b/client/client.go @@ -24,6 +24,7 @@ import ( "github.com/ovn-org/libovsdb/model" "github.com/ovn-org/libovsdb/ovsdb" "github.com/ovn-org/libovsdb/ovsdb/serverdb" + "github.com/ovn-org/libovsdb/util" ) // Constants defined for libovsdb @@ -260,15 +261,7 @@ func (o *ovsdbClient) connect(ctx context.Context, reconnect bool) error { } if !connected { - if len(connectErrors) == 1 { - return connectErrors[0] - } - var combined []string - for _, e := range connectErrors { - combined = append(combined, e.Error()) - } - - return fmt.Errorf("unable to connect to any endpoints: %s", strings.Join(combined, ". ")) + return util.CombineErrors(connectErrors, "unable to connect to any endpoints") } // if we're reconnecting, re-start all the monitors @@ -371,8 +364,6 @@ func (o *ovsdbClient) tryEndpoint(ctx context.Context, u *url.URL) (string, erro return "", err } db.api = newAPI(db.cache, o.logger) - } else { - db.cache.Purge(db.model) } db.cacheMutex.Unlock() } @@ -809,6 +800,51 @@ func (o *ovsdbClient) Monitor(ctx context.Context, monitor *Monitor) (MonitorCoo return cookie, o.monitor(ctx, cookie, false, monitor) } +func (db *database) getMonitorTableConditions(tm TableMonitor) (*ovsdb.Condition, error) { + model, err := db.model.NewModel(tm.Table) + if err != nil { + return nil, err + } + info, err := db.model.NewModelInfo(model) + if err != nil { + return nil, err + } + return db.model.Mapper.NewCondition(info, tm.Condition.Field, tm.Condition.Function, tm.Condition.Value) +} + +// purge removes all rows from the row cache that match the monitor +func (o *ovsdbClient) purge(db *database, monitor *Monitor) { + if len(monitor.Tables) == 0 { + db.cache.Purge(db.model) + return + } + + var err error + for _, tm := range monitor.Tables { + if monitor.Method == ovsdb.ConditionalMonitorSinceRPC { + var cond *ovsdb.Condition + cond, err = db.getMonitorTableConditions(tm) + if err != nil { + break + } + err = db.cache.PurgeTableRows(db.model, tm.Table, []ovsdb.Condition{*cond}) + if err != nil { + break + } + } else { + err = db.cache.PurgeTable(db.model, tm.Table) + if err != nil { + break + } + } + } + + if err != nil { + o.logger.V(3).Error(err, "failed to purge database") + db.cache.Purge(db.model) + } +} + //gocyclo:ignore // monitor must only be called with a lock on monitorsMutex func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconnecting bool, monitor *Monitor) error { @@ -859,12 +895,7 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne var args []interface{} if monitor.Method == ovsdb.ConditionalMonitorSinceRPC { - // FIXME: We should pass the monitor.LastTransactionID here - // But that would require delaying clearing the cache until - // after the monitors have been re-established - the logic - // would also need to be different for monitor and monitor_cond - // as we must always clear the cache in that instance - args = ovsdb.NewMonitorCondSinceArgs(dbName, cookie, requests, emptyUUID) + args = ovsdb.NewMonitorCondSinceArgs(dbName, cookie, requests, monitor.LastTransactionID) } else { args = ovsdb.NewMonitorArgs(dbName, cookie, requests) } @@ -873,18 +904,24 @@ func (o *ovsdbClient) monitor(ctx context.Context, cookie MonitorCookie, reconne switch monitor.Method { case ovsdb.MonitorRPC: + o.purge(db, monitor) var reply ovsdb.TableUpdates err = o.rpcClient.CallWithContext(ctx, monitor.Method, args, &reply) tableUpdates = reply case ovsdb.ConditionalMonitorRPC: + o.purge(db, monitor) var reply ovsdb.TableUpdates2 err = o.rpcClient.CallWithContext(ctx, monitor.Method, args, &reply) tableUpdates = reply case ovsdb.ConditionalMonitorSinceRPC: var reply ovsdb.MonitorCondSinceReply err = o.rpcClient.CallWithContext(ctx, monitor.Method, args, &reply) - if err == nil && reply.Found { - monitor.LastTransactionID = reply.LastTransactionID + if err == nil { + if reply.Found { + monitor.LastTransactionID = reply.LastTransactionID + } else { + o.purge(db, monitor) + } } tableUpdates = reply.Updates default: diff --git a/util/errors.go b/util/errors.go new file mode 100644 index 00000000..a7d3a7ba --- /dev/null +++ b/util/errors.go @@ -0,0 +1,20 @@ +package util + +import ( + "fmt" + "strings" +) + +func CombineErrors(errors []error, msg string) error { + if len(errors) == 0 { + return nil + } else if len(errors) == 1 { + return errors[0] + } + + var combined []string + for _, e := range errors { + combined = append(combined, e.Error()) + } + return fmt.Errorf("%s: %s", msg, strings.Join(combined, ". ")) +}