diff --git a/.gitignore b/.gitignore index 92c85e4c..a81cd81e 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ Dockerfile* docker-compose.yml golangci.yml cover*.out +.vscode/settings.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 19f1d9af..c32e9daa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # Change History +## September 13 2024: v7.7.0 + + Minor improvement release. + +- **Improvements** + - [CLIENT-3112] Correctly handle new error messages/error codes returned by AS 7.2. + - [CLIENT-3102] Add "XDR key busy" error code 32. + - [CLIENT-3119] Use Generics For a General Code Clean Up + Uses several new generic containers to simplify concurrent access in the client. + Uses a Guard as a monitor for the tend connection. This encapsulates synchronized tend connection management using said Guard. + - Add documentation about client.WarmUp to the client initialization API. + +- **Fixes** + - [CLIENT-3082] BatchGet with empty Keys raises gRPC EOF error. + ## August 12 2024: v7.6.1 Minor improvement release. diff --git a/client.go b/client.go index 54f4b6f0..60062cbb 100644 --- a/client.go +++ b/client.go @@ -72,12 +72,22 @@ func clientFinalizer(f *Client) { //------------------------------------------------------- // NewClient generates a new Client instance. +// The connection pool after connecting to the database is initially empty, +// and connections are established on a per need basis, which can be slow and +// time out some initial commands. +// It is recommended to call the client.WarmUp() method right after connecting to the database +// to fill up the connection pool to the required service level. func NewClient(hostname string, port int) (*Client, Error) { return NewClientWithPolicyAndHost(NewClientPolicy(), NewHost(hostname, port)) } // NewClientWithPolicy generates a new Client using the specified ClientPolicy. // If the policy is nil, the default relevant policy will be used. +// The connection pool after connecting to the database is initially empty, +// and connections are established on a per need basis, which can be slow and +// time out some initial commands. +// It is recommended to call the client.WarmUp() method right after connecting to the database +// to fill up the connection pool to the required service level. func NewClientWithPolicy(policy *ClientPolicy, hostname string, port int) (*Client, Error) { return NewClientWithPolicyAndHost(policy, NewHost(hostname, port)) } @@ -85,6 +95,11 @@ func NewClientWithPolicy(policy *ClientPolicy, hostname string, port int) (*Clie // NewClientWithPolicyAndHost generates a new Client with the specified ClientPolicy and // sets up the cluster using the provided hosts. // If the policy is nil, the default relevant policy will be used. +// The connection pool after connecting to the database is initially empty, +// and connections are established on a per need basis, which can be slow and +// time out some initial commands. +// It is recommended to call the client.WarmUp() method right after connecting to the database +// to fill up the connection pool to the required service level. func NewClientWithPolicyAndHost(policy *ClientPolicy, hosts ...*Host) (*Client, Error) { if policy == nil { policy = NewClientPolicy() @@ -846,23 +861,31 @@ func (clnt *Client) RegisterUDF(policy *WritePolicy, udfBody []byte, serverPath } response := responseMap[strCmd.String()] + if strings.EqualFold(response, "ok") { + return NewRegisterTask(clnt.cluster, serverPath), nil + } + + err = parseInfoErrorCode(response) + res := make(map[string]string) - vals := strings.Split(response, ";") + vals := strings.Split("error="+err.Error(), ";") for _, pair := range vals { t := strings.SplitN(pair, "=", 2) if len(t) == 2 { - res[t[0]] = t[1] + res[strings.ToLower(t[0])] = t[1] } else if len(t) == 1 { - res[t[0]] = "" + res[strings.ToLower(t[0])] = "" } } if _, exists := res["error"]; exists { msg, _ := base64.StdEncoding.DecodeString(res["message"]) - return nil, newError(types.COMMAND_REJECTED, fmt.Sprintf("Registration failed: %s\nFile: %s\nLine: %s\nMessage: %s", + return nil, newError(err.resultCode(), fmt.Sprintf("Registration failed: %s\nFile: %s\nLine: %s\nMessage: %s", res["error"], res["file"], res["line"], msg)) } - return NewRegisterTask(clnt.cluster, serverPath), nil + + // if message was not parsable + return nil, parseInfoErrorCode(response) } // RemoveUDF removes a package containing user defined functions in the server. @@ -888,10 +911,10 @@ func (clnt *Client) RemoveUDF(policy *WritePolicy, udfName string) (*RemoveTask, } response := responseMap[strCmd.String()] - if response == "ok" { + if strings.EqualFold(response, "ok") { return NewRemoveTask(clnt.cluster, udfName), nil } - return nil, newError(types.SERVER_ERROR, response) + return nil, parseInfoErrorCode(response) } // ListUDF lists all packages containing user defined functions in the server. @@ -1119,36 +1142,34 @@ func (clnt *Client) SetXDRFilter(policy *InfoPolicy, datacenter string, namespac return nil } - return parseIndexErrorCode(response) + return parseInfoErrorCode(response) } -var indexErrRegexp = regexp.MustCompile(`(?i)(fail|error)(:[0-9]+)?(:.+)?`) +var infoErrRegexp = regexp.MustCompile(`(?i)(fail|error)((:|=)(?P[0-9]+))?((:|=)(?P.+))?`) + +func parseInfoErrorCode(response string) Error { + match := infoErrRegexp.FindStringSubmatch(response) -func parseIndexErrorCode(response string) Error { var code = types.SERVER_ERROR var message = response - match := indexErrRegexp.FindStringSubmatch(response) - - // invalid response - if len(match) != 4 { - return newError(types.PARSE_ERROR, response) - } - - // error code - if len(match[2]) > 0 { - i, err := strconv.ParseInt(string(match[2][1:]), 10, 64) - if err == nil { - code = types.ResultCode(i) - message = types.ResultCodeToString(code) + if len(match) > 0 { + for i, name := range infoErrRegexp.SubexpNames() { + if i != 0 && name != "" && len(match[i]) > 0 { + switch name { + case "code": + i, err := strconv.ParseInt(match[i], 10, 64) + if err == nil { + code = types.ResultCode(i) + message = types.ResultCodeToString(code) + } + case "msg": + message = match[i] + } + } } } - // message - if len(match[3]) > 0 { - message = string(match[3][1:]) - } - return newError(code, message) } @@ -1310,7 +1331,7 @@ func (clnt *Client) CreateComplexIndex( return NewIndexTask(clnt.cluster, namespace, indexName), nil } - return nil, parseIndexErrorCode(response) + return nil, parseInfoErrorCode(response) } // DropIndex deletes a secondary index. It will block until index is dropped on all nodes. @@ -1348,7 +1369,7 @@ func (clnt *Client) DropIndex( return <-task.OnComplete() } - err = parseIndexErrorCode(response) + err = parseInfoErrorCode(response) if err.Matches(types.INDEX_NOTFOUND) { // Index did not previously exist. Return without error. return nil @@ -1366,18 +1387,6 @@ func (clnt *Client) DropIndex( func (clnt *Client) Truncate(policy *InfoPolicy, namespace, set string, beforeLastUpdate *time.Time) Error { policy = clnt.getUsableInfoPolicy(policy) - node, err := clnt.cluster.GetRandomNode() - if err != nil { - return err - } - - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err = node.initTendConn(policy.Timeout); err != nil { - return err - } - var strCmd bytes.Buffer if len(set) > 0 { strCmd.WriteString("truncate:namespace=") @@ -1393,9 +1402,8 @@ func (clnt *Client) Truncate(policy *InfoPolicy, namespace, set string, beforeLa strCmd.WriteString(strconv.FormatInt(beforeLastUpdate.UnixNano(), 10)) } - responseMap, err := node.tendConn.RequestInfo(strCmd.String()) + responseMap, err := clnt.sendInfoCommand(policy.Timeout, strCmd.String()) if err != nil { - node.tendConn.Close() return err } @@ -1404,7 +1412,7 @@ func (clnt *Client) Truncate(policy *InfoPolicy, namespace, set string, beforeLa return nil } - return newError(types.SERVER_ERROR, "Truncate failed: "+response) + return parseInfoErrorCode(response) } //------------------------------------------------------- @@ -1426,15 +1434,13 @@ func (clnt *Client) CreateUser(policy *AdminPolicy, user string, password string if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.createUser(conn, policy, user, hash, roles) + }) - command := NewAdminCommand(nil) - return command.createUser(node.tendConn, policy, user, hash, roles) + return err } // DropUser removes a user from the cluster. @@ -1446,15 +1452,12 @@ func (clnt *Client) DropUser(policy *AdminPolicy, user string) Error { if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } - - command := NewAdminCommand(nil) - return command.dropUser(node.tendConn, policy, user) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.dropUser(conn, policy, user) + }) + return err } // ChangePassword changes a user's password. Clear-text password will be hashed using bcrypt before sending to server. @@ -1475,30 +1478,24 @@ func (clnt *Client) ChangePassword(policy *AdminPolicy, user string, password st if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return err - } - command := NewAdminCommand(nil) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) - if user == clnt.cluster.user { - // Change own password. - if err := command.changePassword(node.tendConn, policy, user, clnt.cluster.Password(), hash); err != nil { - return err + if user == clnt.cluster.user { + // Change own password. + err = command.changePassword(conn, policy, user, clnt.cluster.Password(), hash) + } else { + // Change other user's password by user admin. + err = command.setPassword(conn, policy, user, hash) } - } else { - // Change other user's password by user admin. - if err := command.setPassword(node.tendConn, policy, user, hash); err != nil { - return err - } - } + }) - clnt.cluster.changePassword(user, password, hash) + if err == nil { + clnt.cluster.changePassword(user, password, hash) + } - return nil + return err } // GrantRoles adds roles to user's list of roles. @@ -1510,15 +1507,12 @@ func (clnt *Client) GrantRoles(policy *AdminPolicy, user string, roles []string) if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } - - command := NewAdminCommand(nil) - return command.grantRoles(node.tendConn, policy, user, roles) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.grantRoles(conn, policy, user, roles) + }) + return err } // RevokeRoles removes roles from user's list of roles. @@ -1530,19 +1524,17 @@ func (clnt *Client) RevokeRoles(policy *AdminPolicy, user string, roles []string if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.revokeRoles(conn, policy, user, roles) + }) - command := NewAdminCommand(nil) - return command.revokeRoles(node.tendConn, policy, user, roles) + return err } // QueryUser retrieves roles for a given user. -func (clnt *Client) QueryUser(policy *AdminPolicy, user string) (*UserRoles, Error) { +func (clnt *Client) QueryUser(policy *AdminPolicy, user string) (res *UserRoles, err Error) { policy = clnt.getUsableAdminPolicy(policy) // prepare the node.tendConn @@ -1550,19 +1542,16 @@ func (clnt *Client) QueryUser(policy *AdminPolicy, user string) (*UserRoles, Err if err != nil { return nil, err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return nil, err - } - command := NewAdminCommand(nil) - return command.QueryUser(node.tendConn, policy, user) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + res, err = command.QueryUser(conn, policy, user) + }) + return res, err } // QueryUsers retrieves all users and their roles. -func (clnt *Client) QueryUsers(policy *AdminPolicy) ([]*UserRoles, Error) { +func (clnt *Client) QueryUsers(policy *AdminPolicy) (res []*UserRoles, err Error) { policy = clnt.getUsableAdminPolicy(policy) // prepare the node.tendConn @@ -1570,19 +1559,16 @@ func (clnt *Client) QueryUsers(policy *AdminPolicy) ([]*UserRoles, Error) { if err != nil { return nil, err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return nil, err - } - command := NewAdminCommand(nil) - return command.QueryUsers(node.tendConn, policy) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + res, err = command.QueryUsers(conn, policy) + }) + return res, err } // QueryRole retrieves privileges for a given role. -func (clnt *Client) QueryRole(policy *AdminPolicy, role string) (*Role, Error) { +func (clnt *Client) QueryRole(policy *AdminPolicy, role string) (res *Role, err Error) { policy = clnt.getUsableAdminPolicy(policy) // prepare the node.tendConn @@ -1590,19 +1576,16 @@ func (clnt *Client) QueryRole(policy *AdminPolicy, role string) (*Role, Error) { if err != nil { return nil, err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return nil, err - } - command := NewAdminCommand(nil) - return command.QueryRole(node.tendConn, policy, role) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + res, err = command.QueryRole(conn, policy, role) + }) + return res, err } // QueryRoles retrieves all roles and their privileges. -func (clnt *Client) QueryRoles(policy *AdminPolicy) ([]*Role, Error) { +func (clnt *Client) QueryRoles(policy *AdminPolicy) (res []*Role, err Error) { policy = clnt.getUsableAdminPolicy(policy) // prepare the node.tendConn @@ -1610,15 +1593,12 @@ func (clnt *Client) QueryRoles(policy *AdminPolicy) ([]*Role, Error) { if err != nil { return nil, err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return nil, err - } - - command := NewAdminCommand(nil) - return command.QueryRoles(node.tendConn, policy) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + res, err = command.QueryRoles(conn, policy) + }) + return res, err } // CreateRole creates a user-defined role. @@ -1632,15 +1612,12 @@ func (clnt *Client) CreateRole(policy *AdminPolicy, roleName string, privileges if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return err - } - command := NewAdminCommand(nil) - return command.createRole(node.tendConn, policy, roleName, privileges, whitelist, readQuota, writeQuota) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.createRole(conn, policy, roleName, privileges, whitelist, readQuota, writeQuota) + }) + return err } // DropRole removes a user-defined role. @@ -1652,15 +1629,12 @@ func (clnt *Client) DropRole(policy *AdminPolicy, roleName string) Error { if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } - - command := NewAdminCommand(nil) - return command.dropRole(node.tendConn, policy, roleName) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.dropRole(conn, policy, roleName) + }) + return err } // GrantPrivileges grant privileges to a user-defined role. @@ -1672,15 +1646,12 @@ func (clnt *Client) GrantPrivileges(policy *AdminPolicy, roleName string, privil if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } - - command := NewAdminCommand(nil) - return command.grantPrivileges(node.tendConn, policy, roleName, privileges) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.grantPrivileges(conn, policy, roleName, privileges) + }) + return err } // RevokePrivileges revokes privileges from a user-defined role. @@ -1692,15 +1663,12 @@ func (clnt *Client) RevokePrivileges(policy *AdminPolicy, roleName string, privi if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return err - } - command := NewAdminCommand(nil) - return command.revokePrivileges(node.tendConn, policy, roleName, privileges) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.revokePrivileges(conn, policy, roleName, privileges) + }) + return err } // SetWhitelist sets IP address whitelist for a role. If whitelist is nil or empty, it removes existing whitelist from role. @@ -1712,15 +1680,12 @@ func (clnt *Client) SetWhitelist(policy *AdminPolicy, roleName string, whitelist if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - if err := node.initTendConn(time.Second); err != nil { - return err - } - - command := NewAdminCommand(nil) - return command.setWhitelist(node.tendConn, policy, roleName, whitelist) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.setWhitelist(conn, policy, roleName, whitelist) + }) + return err } // SetQuotas sets maximum reads/writes per second limits for a role. If a quota is zero, the limit is removed. @@ -1734,15 +1699,12 @@ func (clnt *Client) SetQuotas(policy *AdminPolicy, roleName string, readQuota, w if err != nil { return err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err := node.initTendConn(time.Second); err != nil { - return err - } - command := NewAdminCommand(nil) - return command.setQuotas(node.tendConn, policy, roleName, readQuota, writeQuota) + node.usingTendConn(policy.Timeout, func(conn *Connection) { + command := NewAdminCommand(nil) + err = command.setQuotas(conn, policy, roleName, readQuota, writeQuota) + }) + return err } //------------------------------------------------------- @@ -1830,20 +1792,8 @@ func (clnt *Client) sendInfoCommand(timeout time.Duration, command string) (map[ return nil, err } - node.tendConnLock.Lock() - defer node.tendConnLock.Unlock() - - if err = node.initTendConn(timeout); err != nil { - return nil, err - } - - results, err := node.tendConn.RequestInfo(command) - if err != nil { - node.tendConn.Close() - return nil, err - } - - return results, nil + policy := InfoPolicy{Timeout: timeout} + return node.RequestInfo(&policy, command) } //------------------------------------------------------- diff --git a/client_test.go b/client_test.go index 93f4b6c5..3b04ba79 100644 --- a/client_test.go +++ b/client_test.go @@ -63,9 +63,9 @@ var _ = gg.Describe("Aerospike", func() { var actualClusterName string - gg.Describe("Client IndexErrorParser", func() { + gg.Describe("Client InfoErrorParser", func() { - gg.It("must parse IndexError response strings", func() { + gg.It("must parse InfoError response strings", func() { type t struct { r string code types.ResultCode @@ -73,7 +73,7 @@ var _ = gg.Describe("Aerospike", func() { } responses := []t{ - {"invalid", types.PARSE_ERROR, "invalid"}, + {"invalid", types.SERVER_ERROR, "invalid"}, {"FAIL", types.SERVER_ERROR, "FAIL"}, {"FAiL", types.SERVER_ERROR, "FAiL"}, {"Error", types.SERVER_ERROR, "Error"}, @@ -83,10 +83,12 @@ var _ = gg.Describe("Aerospike", func() { {"ERROR:200", types.INDEX_FOUND, "Index already exists"}, {"FAIL:201", types.INDEX_NOTFOUND, "Index not found"}, {"FAIL:201:some message from the server", types.INDEX_NOTFOUND, "some message from the server"}, + {"FAIL:some message from the server", types.SERVER_ERROR, "some message from the server"}, + {"error:some message from the server", types.SERVER_ERROR, "some message from the server"}, } for _, r := range responses { - err := as.ParseIndexErrorCode(r.r) + err := as.ParseInfoErrorCode(r.r) gm.Expect(err).To(gm.HaveOccurred()) gm.Expect(err.(*as.AerospikeError).Msg()).To(gm.Equal(r.err)) gm.Expect(err.Matches(r.code)).To(gm.BeTrue()) diff --git a/cluster.go b/cluster.go index 6de03013..c98b8a98 100644 --- a/cluster.go +++ b/cluster.go @@ -25,6 +25,8 @@ import ( "golang.org/x/sync/errgroup" iatomic "github.com/aerospike/aerospike-client-go/v7/internal/atomic" + sm "github.com/aerospike/aerospike-client-go/v7/internal/atomic/map" + "github.com/aerospike/aerospike-client-go/v7/internal/seq" "github.com/aerospike/aerospike-client-go/v7/logger" "github.com/aerospike/aerospike-client-go/v7/types" ) @@ -33,27 +35,27 @@ import ( // them. type Cluster struct { // Initial host nodes specified by user. - seeds iatomic.SyncVal //[]*Host + seeds iatomic.SyncVal[[]*Host] // All aliases for all nodes in cluster. // Only accessed within cluster tend goroutine. - aliases iatomic.SyncVal //map[Host]*Node + aliases sm.Map[Host, *Node] // Map of active nodes in cluster. // Only accessed within cluster tend goroutine. - nodesMap iatomic.SyncVal //map[string]*Node + nodesMap sm.Map[string, *Node] // Active nodes in cluster. - nodes iatomic.SyncVal //[]*Node + nodes iatomic.SyncVal[[]*Node] stats map[string]*nodeStats //host => stats statsLock sync.Mutex // enable performance metrics - metricsEnabled atomic.Bool // bool - metricsPolicy atomic.Value // *MetricsPolicy + metricsEnabled atomic.Bool // bool + metricsPolicy iatomic.TypedVal[*MetricsPolicy] // Hints for best node for a partition - partitionWriteMap atomic.Value //partitionMap + partitionWriteMap iatomic.TypedVal[partitionMap] //partitionMap clientPolicy ClientPolicy infoPolicy InfoPolicy @@ -76,7 +78,7 @@ type Cluster struct { user string // Password in hashed format in bytes. - password iatomic.SyncVal // []byte + password iatomic.SyncVal[[]byte] } // NewCluster generates a Cluster instance. @@ -118,17 +120,17 @@ func NewCluster(policy *ClientPolicy, hosts []*Host) (*Cluster, Error) { tendChannel: make(chan struct{}), seeds: *iatomic.NewSyncVal(hosts), - aliases: *iatomic.NewSyncVal(make(map[Host]*Node)), - nodesMap: *iatomic.NewSyncVal(make(map[string]*Node)), + aliases: *sm.New[Host, *Node](16), + nodesMap: *sm.New[string, *Node](16), nodes: *iatomic.NewSyncVal([]*Node{}), stats: map[string]*nodeStats{}, - password: *iatomic.NewSyncVal(nil), + password: *iatomic.NewSyncVal[[]byte](nil), supportsPartitionQuery: *iatomic.NewBool(false), } - newCluster.partitionWriteMap.Store(make(partitionMap)) + newCluster.partitionWriteMap.Set(make(partitionMap)) // setup auth info for cluster if policy.RequiresAuthentication() { @@ -224,8 +226,7 @@ Loop: // AddSeeds adds new hosts to the cluster. // They will be added to the cluster on next tend call. func (clstr *Cluster) AddSeeds(hosts []*Host) { - clstr.seeds.Update(func(val interface{}) (interface{}, error) { - seeds := val.([]*Host) + clstr.seeds.Update(func(seeds []*Host) ([]*Host, error) { seeds = append(seeds, hosts...) return seeds, nil }) @@ -260,139 +261,93 @@ func (clstr *Cluster) tend() Error { peers := newPeers(len(nodes)+16, 16) - for _, node := range nodes { - // Clear node reference counts. - node.referenceCount.Set(0) - node.partitionChanged.Set(false) - } - - wg := sync.WaitGroup{} - wg.Add(len(nodes)) - for _, node := range nodes { - go func(node *Node) { - defer wg.Done() - if err := node.Refresh(peers); err != nil { - logger.Logger.Debug("Error occurred while refreshing node: %s", node.String()) - } - }(node) - } - wg.Wait() + seq.ParDo(nodes, func(node *Node) { + if err := node.Refresh(peers); err != nil { + logger.Logger.Debug("Error occurred while refreshing node: %s", node.String()) + } + }) // Refresh peers when necessary. if peers.genChanged.Get() || len(peers.peers()) != nodeCountBeforeTend { // Refresh peers for all nodes that responded the first time even if only one node's peers changed. peers.refreshCount.Set(0) - wg.Add(len(nodes)) - for _, node := range nodes { - go func(node *Node) { - defer wg.Done() - node.refreshPeers(peers) - }(node) - } - wg.Wait() + seq.ParDo(nodes, func(node *Node) { + node.refreshPeers(peers) + }) } - var partMap partitionMap - - // Use the following function to allocate memory for the partitionMap on demand. - // This will prevent the allocation when the cluster is stable, and make tend a bit faster. - pmlock := new(sync.Mutex) - setPartitionMap := func(l *sync.Mutex) { - l.Lock() - defer l.Unlock() - if partMap == nil { - partMap = clstr.getPartitions().clone() - } - } + var partMap iatomic.Guard[partitionMap] // find the first host that connects - for _, _peer := range peers.peers() { + seq.ParDo(peers.peers(), func(_peer *peer) { if clstr.peerExists(peers, _peer.nodeName) { // Node already exists. Do not even try to connect to hosts. - continue + return } - wg.Add(1) - go func(__peer *peer) { - defer wg.Done() - for _, host := range __peer.hosts { - // attempt connection to the host - nv := nodeValidator{seedOnlyCluster: clstr.clientPolicy.SeedOnlyCluster} - if err := nv.validateNode(clstr, host); err != nil { - logger.Logger.Warn("Add node `%s` failed: `%s`", host, err) - continue - } + seq.Do(_peer.hosts, func(host *Host) error { + // attempt connection to the host + nv := nodeValidator{seedOnlyCluster: clstr.clientPolicy.SeedOnlyCluster} + if err := nv.validateNode(clstr, host); err != nil { + logger.Logger.Warn("Add node `%s` failed: `%s`", host, err) + return nil + } - // Must look for new node name in the unlikely event that node names do not agree. - if __peer.nodeName != nv.name { - logger.Logger.Warn("Peer node `%s` is different than actual node `%s` for host `%s`", __peer.nodeName, nv.name, host) - } + // Must look for new node name in the unlikely event that node names do not agree. + if _peer.nodeName != nv.name { + logger.Logger.Warn("Peer node `%s` is different than actual node `%s` for host `%s`", _peer.nodeName, nv.name, host) + } - if clstr.peerExists(peers, nv.name) { - // Node already exists. Do not even try to connect to hosts. - break - } + if clstr.peerExists(peers, nv.name) { + // Node already exists. Do not even try to connect to hosts. + return seq.Break + } - // Create new node. - node := clstr.createNode(&nv) - peers.addNode(nv.name, node) - setPartitionMap(pmlock) + // Create new node. + node := clstr.createNode(&nv) + peers.addNode(nv.name, node) + partMap.InitDoVal(clstr.getPartitions().clone, func(partMap partitionMap) { node.refreshPartitions(peers, partMap, true) - break - } - }(_peer) - } + }) + return seq.Break + }) + }) // Refresh partition map when necessary. - wg.Add(len(nodes)) - for _, node := range nodes { - go func(node *Node) { - defer wg.Done() - if node.partitionChanged.Get() { - setPartitionMap(pmlock) + seq.ParDo(nodes, func(node *Node) { + if node.partitionChanged.Get() { + partMap.InitDoVal(clstr.getPartitions().clone, func(partMap partitionMap) { node.refreshPartitions(peers, partMap, false) - } - }(node) - } - - // This waits for the both steps above - wg.Wait() + }) + } + }) if peers.genChanged.Get() { // Handle nodes changes determined from refreshes. removeList := clstr.findNodesToRemove(peers.refreshCount.Get()) // Remove nodes in a batch. - if len(removeList) > 0 { - for _, n := range removeList { - logger.Logger.Debug("The following nodes will be removed: %s", n) - } - clstr.removeNodes(removeList) + for i := range removeList { + logger.Logger.Debug("The following nodes will be removed: %s", removeList[i]) } - + clstr.removeNodes(removeList) clstr.aggregateNodeStats(removeList) } // Add nodes in a batch. - if len(peers.nodes()) > 0 { - clstr.addNodes(peers.nodes()) - } + clstr.addNodes(peers.nodes()) // add to the number of successful tends clstr.tendCount++ // update all partitions in one go - updatePartitionMap := false - for _, node := range clstr.GetNodes() { - if node.partitionChanged.Get() { - updatePartitionMap = true - break - } - } + updatePartitionMap := seq.Any(clstr.GetNodes(), func(node *Node) bool { + return node.partitionChanged.Get() + }) if updatePartitionMap { - clstr.setPartitions(partMap) + clstr.setPartitions(*partMap.Release()) } if err := clstr.getPartitions().validate(); err != nil { @@ -529,12 +484,7 @@ func (clstr *Cluster) waitTillStabilized() Error { } func (clstr *Cluster) findAlias(alias *Host) *Node { - res, _ := clstr.aliases.GetSyncedVia(func(val interface{}) (interface{}, error) { - aliases := val.(map[Host]*Node) - return aliases[*alias], nil - }) - - return res.(*Node) + return clstr.aliases.Get(*alias) } func (clstr *Cluster) setPartitions(partMap partitionMap) { @@ -542,11 +492,11 @@ func (clstr *Cluster) setPartitions(partMap partitionMap) { logger.Logger.Error("Partition map error: %s.", err.Error()) } - clstr.partitionWriteMap.Store(partMap) + clstr.partitionWriteMap.Set(partMap) } func (clstr *Cluster) getPartitions() partitionMap { - return clstr.partitionWriteMap.Load().(partitionMap) + return clstr.partitionWriteMap.Get() } // discoverSeeds will lookup the seed hosts and convert seed hosts @@ -571,8 +521,7 @@ func discoverSeedIPs(seeds []*Host) (res []*Host) { // Adds seeds to the cluster func (clstr *Cluster) seedNodes() (newSeedsFound bool, errChain Error) { // Must copy array reference for copy on write semantics to work. - seedArrayIfc, _ := clstr.seeds.GetSyncedVia(func(val interface{}) (interface{}, error) { - seeds := val.([]*Host) + seedArrayCopy, _ := clstr.seeds.GetSyncedVia(func(seeds []*Host) ([]*Host, error) { seedsCopy := make([]*Host, len(seeds)) copy(seedsCopy, seeds) @@ -580,7 +529,7 @@ func (clstr *Cluster) seedNodes() (newSeedsFound bool, errChain Error) { }) // discover seed IPs from DNS or Load Balancers - seedArray := discoverSeedIPs(seedArrayIfc.([]*Host)) + seedArray := discoverSeedIPs(seedArrayCopy) successChan := make(chan struct{}, len(seedArray)) errChan := make(chan Error, len(seedArray)) @@ -648,11 +597,7 @@ func (clstr *Cluster) findNodeName(list []*Node, name string) bool { func (clstr *Cluster) addAlias(host *Host, node *Node) { if host != nil && node != nil { - clstr.aliases.Update(func(val interface{}) (interface{}, error) { - aliases := val.(map[Host]*Node) - aliases[*host] = node - return aliases, nil - }) + clstr.aliases.Set(*host, node) } } @@ -731,11 +676,14 @@ func (clstr *Cluster) updateClusterFeatures() { } func (clstr *Cluster) addNodes(nodesToAdd map[string]*Node) { + if len(nodesToAdd) == 0 { + return + } + // update features for all nodes defer clstr.updateClusterFeatures() - clstr.nodes.Update(func(val interface{}) (interface{}, error) { - nodes := val.([]*Node) + clstr.nodes.Update(func(nodes []*Node) ([]*Node, error) { if clstr.clientPolicy.SeedOnlyCluster && clstr.GetSeedCount() == len(nodes) { // Don't add new nodes. return nodes, nil @@ -758,14 +706,18 @@ func (clstr *Cluster) addNodes(nodesToAdd map[string]*Node) { } } - clstr.nodesMap.Set(nodesMap) - clstr.aliases.Set(nodesAliases) + clstr.nodesMap.Replace(nodesMap) + clstr.aliases.Replace(nodesAliases) return nodes, nil }) } func (clstr *Cluster) removeNodes(nodesToRemove []*Node) { + if len(nodesToRemove) == 0 { + return + } + // update features for all nodes defer clstr.updateClusterFeatures() @@ -776,26 +728,13 @@ func (clstr *Cluster) removeNodes(nodesToRemove []*Node) { for _, node := range nodesToRemove { // Remove node's aliases from cluster alias set. // Aliases are only used in tend goroutine, so synchronization is not necessary. - clstr.aliases.Update(func(val interface{}) (interface{}, error) { - aliases := val.(map[Host]*Node) - for _, alias := range node.GetAliases() { - delete(aliases, *alias) - } - return aliases, nil - }) - - clstr.nodesMap.Update(func(val interface{}) (interface{}, error) { - nodesMap := val.(map[string]*Node) - delete(nodesMap, node.name) - return nodesMap, nil - }) - + clstr.aliases.DeleteAllDeref(node.GetAliases()...) + clstr.nodesMap.Delete(node.name) node.Close() } // Remove all nodes at once to avoid copying entire array multiple times. - clstr.nodes.Update(func(val interface{}) (interface{}, error) { - nodes := val.([]*Node) + clstr.nodes.Update(func(nodes []*Node) ([]*Node, error) { nlist := make([]*Node, 0, len(nodes)) nlist = append(nlist, nodes...) for i, n := range nlist { @@ -851,23 +790,21 @@ func (clstr *Cluster) GetRandomNode() (*Node, Error) { // GetNodes returns a list of all nodes in the cluster func (clstr *Cluster) GetNodes() []*Node { // Must copy array reference for copy on write semantics to work. - return clstr.nodes.Get().([]*Node) + return clstr.nodes.Get() } // GetSeedCount is the count of seed nodes func (clstr *Cluster) GetSeedCount() int { - res, _ := clstr.seeds.GetSyncedVia(func(val interface{}) (interface{}, error) { - seeds := val.([]*Host) + res, _ := iatomic.MapSyncValue(&clstr.seeds, func(seeds []*Host) (int, error) { return len(seeds), nil }) - return res.(int) + return res } // GetSeeds returns a list of all seed nodes in the cluster func (clstr *Cluster) GetSeeds() []Host { - res, _ := clstr.seeds.GetSyncedVia(func(val interface{}) (interface{}, error) { - seeds := val.([]*Host) + res, _ := iatomic.MapSyncValue(&clstr.seeds, func(seeds []*Host) ([]Host, error) { res := make([]Host, 0, len(seeds)) for _, seed := range seeds { res = append(res, *seed) @@ -876,22 +813,12 @@ func (clstr *Cluster) GetSeeds() []Host { return res, nil }) - return res.([]Host) + return res } // GetAliases returns a list of all node aliases in the cluster func (clstr *Cluster) GetAliases() map[Host]*Node { - res, _ := clstr.aliases.GetSyncedVia(func(val interface{}) (interface{}, error) { - aliases := val.(map[Host]*Node) - res := make(map[Host]*Node, len(aliases)) - for h, n := range aliases { - res[h] = n - } - - return res, nil - }) - - return res.(map[Host]*Node) + return clstr.aliases.Clone() } // GetNodeByName finds a node by name and returns an @@ -999,7 +926,7 @@ func (clstr *Cluster) WaitUntillMigrationIsFinished(timeout time.Duration) Error func (clstr *Cluster) Password() (res []byte) { pass := clstr.password.Get() if pass != nil { - return pass.([]byte) + return pass } return nil } @@ -1045,11 +972,7 @@ func (clstr *Cluster) WarmUp(count int) (int, Error) { // MetricsEnabled returns true if metrics are enabled for the cluster. func (clstr *Cluster) MetricsPolicy() *MetricsPolicy { - res := clstr.metricsPolicy.Load() - if res != nil { - return res.(*MetricsPolicy) - } - return nil + return clstr.metricsPolicy.Get() } // MetricsEnabled returns true if metrics are enabled for the cluster. @@ -1065,7 +988,7 @@ func (clstr *Cluster) EnableMetrics(policy *MetricsPolicy) { policy = DefaultMetricsPolicy() } - clstr.metricsPolicy.Store(policy) + clstr.metricsPolicy.Set(policy) clstr.metricsEnabled.Store(true) clstr.statsLock.Lock() diff --git a/connection.go b/connection.go index fed97fc7..3b977a1d 100644 --- a/connection.go +++ b/connection.go @@ -484,7 +484,7 @@ func (ctn *Connection) login(policy *ClientPolicy, hashedPassword []byte, sessio si := command.sessionInfo() if ctn.node != nil && si.isValid() { - ctn.node.sessionInfo.Store(si) + ctn.node.sessionInfo.Set(si) } } diff --git a/helper_test.go b/helper_test.go index 50e312b6..30af0ba1 100644 --- a/helper_test.go +++ b/helper_test.go @@ -14,8 +14,8 @@ package aerospike -func ParseIndexErrorCode(response string) Error { - return parseIndexErrorCode(response) +func ParseInfoErrorCode(response string) Error { + return parseInfoErrorCode(response) } func (e *AerospikeError) Msg() string { diff --git a/internal/atomic/guard.go b/internal/atomic/guard.go new file mode 100644 index 00000000..3e944f1f --- /dev/null +++ b/internal/atomic/guard.go @@ -0,0 +1,82 @@ +// Copyright 2014-2022 Aerospike, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomic + +import "sync" + +// Guard allows synchronized access to a value +type Guard[T any] struct { + val *T + m sync.Mutex +} + +// NewGuard creates a new instance of Guard +func NewGuard[T any](val *T) *Guard[T] { + return &Guard[T]{val: val} +} + +// Do calls the passed closure. +func (g *Guard[T]) Do(f func(*T)) { + g.m.Lock() + defer g.m.Unlock() + f(g.val) +} + +// DoVal calls the passed closure with a dereferenced internal value. +func (g *Guard[T]) DoVal(f func(T)) { + g.m.Lock() + defer g.m.Unlock() + f(*g.val) +} + +// Call the passed closure allowing to replace the content. +func (g *Guard[T]) Update(f func(**T)) { + g.m.Lock() + defer g.m.Unlock() + f(&g.val) +} + +// Calls the passed closure allowing to replace the content. +// It will call the init func if the internal values is nil. +func (g *Guard[T]) InitDo(init func() *T, f func(*T)) { + g.m.Lock() + defer g.m.Unlock() + if g.val == nil { + g.val = init() + } + f(g.val) +} + +// Calls the passed closure allowing to replace the content. +// It will call the init func if the internal values is nil. +// It is used for reference values like slices and maps. +func (g *Guard[T]) InitDoVal(init func() T, f func(T)) { + g.m.Lock() + defer g.m.Unlock() + if g.val == nil { + t := init() + g.val = &t + } + f(*g.val) +} + +// Release returns the internal value and sets it to nil +func (g *Guard[T]) Release() *T { + g.m.Lock() + defer g.m.Unlock() + res := g.val + g.val = nil + return res +} diff --git a/internal/atomic/guard_test.go b/internal/atomic/guard_test.go new file mode 100644 index 00000000..dc868d79 --- /dev/null +++ b/internal/atomic/guard_test.go @@ -0,0 +1,118 @@ +// Copyright 2014-2022 Aerospike, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomic_test + +import ( + "runtime" + + "github.com/aerospike/aerospike-client-go/v7/internal/atomic" + + gg "github.com/onsi/ginkgo/v2" + gm "github.com/onsi/gomega" +) + +var _ = gg.Describe("Atomic Guard", func() { + // atomic tests require actual parallelism + runtime.GOMAXPROCS(runtime.NumCPU()) + + type S struct { + a int + b bool + } + + var grd *atomic.Guard[S] + + gg.BeforeEach(func() { + grd = atomic.NewGuard[S](&S{a: 1, b: true}) + }) + + gg.It("must pass internal value correctly", func() { + grd.Do(func(s *S) { + gm.Expect(*s).To(gm.Equal(S{a: 1, b: true})) + }) + + }) + + gg.It("must assign/copy internal value correctly", func() { + local := S{a: 99, b: false} + grd.Do(func(s *S) { + *s = local + }) + + grd.Do(func(s *S) { + gm.Expect(*s).To(gm.Equal(S{a: 99, b: false})) + }) + + }) + + gg.It("must initialize and assign internal value correctly", func() { + flocal := func() *S { return &S{a: 99, b: false} } + + var grd atomic.Guard[S] + grd.Do(func(s *S) { + gm.Expect(s).To(gm.BeNil()) + }) + + grd.InitDo(flocal, func(s *S) { + gm.Expect(*s).To(gm.Equal(S{a: 99, b: false})) + s.a++ + s.b = true + }) + + grd.InitDo(flocal, func(s *S) { + gm.Expect(*s).To(gm.Equal(S{a: 100, b: true})) + }) + + grd.Do(func(s *S) { + gm.Expect(*s).To(gm.Equal(S{a: 100, b: true})) + }) + }) + + gg.It("must initialize and assign internal value correctly", func() { + flocal := func() map[int]int { return map[int]int{1: 1, 2: 2, 3: 3} } + + var grd atomic.Guard[map[int]int] + grd.Do(func(s *map[int]int) { + gm.Expect(s).To(gm.BeNil()) + }) + + grd.InitDoVal(flocal, func(s map[int]int) { + gm.Expect(s).To(gm.Equal(map[int]int{1: 1, 2: 2, 3: 3})) + }) + + grd.InitDoVal(flocal, func(s map[int]int) { + gm.Expect(s).To(gm.Equal(map[int]int{1: 1, 2: 2, 3: 3})) + for i := 4; i < 100; i++ { + s[i] = i + } + }) + + grd.DoVal(func(s map[int]int) { + gm.Expect(len(s)).To(gm.Equal(99)) + }) + }) + + gg.It("must replace internal value's reference correctly", func() { + local := S{a: 99, b: false} + grd.Update(func(s **S) { + *s = &local + }) + + grd.Do(func(s *S) { + gm.Expect(s == &local).To(gm.BeTrue()) + }) + + }) +}) diff --git a/internal/atomic/map/map.go b/internal/atomic/map/map.go new file mode 100644 index 00000000..200c2ef3 --- /dev/null +++ b/internal/atomic/map/map.go @@ -0,0 +1,118 @@ +// Copyright 2014-2022 Aerospike, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomic + +import ( + "sync" +) + +// Map implements a Map with atomic semantics. +type Map[K comparable, V any] struct { + m map[K]V + mutex sync.RWMutex +} + +// New generates a new Map instance. +func New[K comparable, V any](length int) *Map[K, V] { + return &Map[K, V]{ + m: make(map[K]V, length), + } +} + +// Get atomically retrieves an element from the Map. +func (m *Map[K, V]) Get(k K) V { + m.mutex.RLock() + res := m.m[k] + m.mutex.RUnlock() + return res +} + +// Set atomically sets an element in the Map. +// If idx is out of range, it will return an error. +func (m *Map[K, V]) Set(k K, v V) { + m.mutex.Lock() + m.m[k] = v + m.mutex.Unlock() +} + +// Replace replaces the internal map with the provided one. +func (m *Map[K, V]) Replace(nm map[K]V) { + m.mutex.Lock() + m.m = nm + m.mutex.Unlock() +} + +// Length returns the Map size. +func (m *Map[K, V]) Length() int { + m.mutex.RLock() + res := len(m.m) + m.mutex.RUnlock() + + return res +} + +// Length returns the Map size. +func (m *Map[K, V]) Clone() map[K]V { + m.mutex.RLock() + res := make(map[K]V, len(m.m)) + for k, v := range m.m { + res[k] = v + } + m.mutex.RUnlock() + + return res +} + +// Delete will remove the key and return its value. +func (m *Map[K, V]) Delete(k K) V { + m.mutex.Lock() + res := m.m[k] + delete(m.m, k) + m.mutex.Unlock() + return res +} + +// DeleteDeref will dereference and remove the key and return its value. +func (m *Map[K, V]) DeleteDeref(k *K) V { + m.mutex.Lock() + res := m.m[*k] + delete(m.m, *k) + m.mutex.Unlock() + return res +} + +// DeleteAllDeref will dereferences and removes the keys. +func (m *Map[K, V]) DeleteAll(ks ...K) { + m.mutex.Lock() + for i := range ks { + delete(m.m, ks[i]) + } + m.mutex.Unlock() +} + +// DeleteAll will remove the keys. +func (m *Map[K, V]) DeleteAllDeref(ks ...*K) { + m.mutex.Lock() + for i := range ks { + delete(m.m, *ks[i]) + } + m.mutex.Unlock() +} + +func MapAllF[K comparable, V any, U any](m *Map[K, V], f func(map[K]V) U) U { + m.mutex.RLock() + defer m.mutex.RUnlock() + return f(m.m) +} diff --git a/internal/atomic/sync_val.go b/internal/atomic/sync_val.go index e5b7ac4a..c7d072bb 100644 --- a/internal/atomic/sync_val.go +++ b/internal/atomic/sync_val.go @@ -3,25 +3,25 @@ package atomic import "sync" // SyncVal allows synchronized access to a value -type SyncVal struct { - val interface{} +type SyncVal[T any] struct { + val T lock sync.RWMutex } // NewSyncVal creates a new instance of SyncVal -func NewSyncVal(val interface{}) *SyncVal { - return &SyncVal{val: val} +func NewSyncVal[T any](val T) *SyncVal[T] { + return &SyncVal[T]{val: val} } // Set updates the value of SyncVal with the passed argument -func (sv *SyncVal) Set(val interface{}) { +func (sv *SyncVal[T]) Set(val T) { sv.lock.Lock() sv.val = val sv.lock.Unlock() } // Get returns the value inside the SyncVal -func (sv *SyncVal) Get() interface{} { +func (sv *SyncVal[T]) Get() T { sv.lock.RLock() val := sv.val sv.lock.RUnlock() @@ -29,7 +29,7 @@ func (sv *SyncVal) Get() interface{} { } // GetSyncedVia returns the value returned by the function f. -func (sv *SyncVal) GetSyncedVia(f func(interface{}) (interface{}, error)) (interface{}, error) { +func (sv *SyncVal[T]) GetSyncedVia(f func(T) (T, error)) (T, error) { sv.lock.RLock() defer sv.lock.RUnlock() @@ -40,7 +40,7 @@ func (sv *SyncVal) GetSyncedVia(f func(interface{}) (interface{}, error)) (inter // Update gets a function and passes the value of SyncVal to it. // If the resulting err is nil, it will update the value of SyncVal. // It will return the resulting error to the caller. -func (sv *SyncVal) Update(f func(interface{}) (interface{}, error)) error { +func (sv *SyncVal[T]) Update(f func(T) (T, error)) error { sv.lock.Lock() defer sv.lock.Unlock() @@ -50,3 +50,12 @@ func (sv *SyncVal) Update(f func(interface{}) (interface{}, error)) error { } return err } + +// MapSyncValue returns the value returned by the function f. +func MapSyncValue[T any, U any](sv *SyncVal[T], f func(T) (U, error)) (U, error) { + sv.lock.RLock() + defer sv.lock.RUnlock() + + val, err := f(sv.val) + return val, err +} diff --git a/internal/atomic/typed_val.go b/internal/atomic/typed_val.go new file mode 100644 index 00000000..7c56a3ab --- /dev/null +++ b/internal/atomic/typed_val.go @@ -0,0 +1,23 @@ +package atomic + +import "sync/atomic" + +// TypedVal allows synchronized access to a value +type TypedVal[T any] atomic.Value + +// Set updates the value of TypedVal with the passed argument +func (sv *TypedVal[T]) Set(val T) { + (*atomic.Value)(sv).Store(&val) +} + +// Get returns the value inside the TypedVal +func (sv *TypedVal[T]) Get() T { + res := (*atomic.Value)(sv).Load() + if res != nil { + return *res.(*T) + } + + // return zero value; for pointers, it will be nil + var t T + return t +} diff --git a/internal/atomic/typed_val_test.go b/internal/atomic/typed_val_test.go new file mode 100644 index 00000000..e45e14a2 --- /dev/null +++ b/internal/atomic/typed_val_test.go @@ -0,0 +1,126 @@ +// Copyright 2014-2022 Aerospike, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package atomic_test + +import ( + "github.com/aerospike/aerospike-client-go/v7/internal/atomic" + + gg "github.com/onsi/ginkgo/v2" + gm "github.com/onsi/gomega" +) + +var _ = gg.Describe("TypedVal", func() { + + gg.Context("Storage must support", func() { + + gg.Context("Primitives", func() { + + gg.It("int", func() { + var t int = 5 + var tv atomic.TypedVal[int] + tv.Set(t) + gm.Expect(tv.Get()).To(gm.Equal(t)) + }) + + gg.It("string", func() { + var t string = "Hello!" + var tv atomic.TypedVal[string] + tv.Set(t) + gm.Expect(tv.Get()).To(gm.Equal(t)) + }) + + gg.It("slice", func() { + var t = []int{1, 2, 3} + var tv atomic.TypedVal[[]int] + tv.Set(t) + gm.Expect(tv.Get()).To(gm.Equal(t)) + + tv.Set(nil) + var tt []int + gm.Expect(tv.Get()).To(gm.Equal(tt)) + }) + + gg.It("map", func() { + var t = map[string]int{"a": 1, "b": 2, "c": 3} + var tv atomic.TypedVal[map[string]int] + tv.Set(t) + gm.Expect(tv.Get()).To(gm.Equal(t)) + + tv.Set(nil) + var tt map[string]int + gm.Expect(tv.Get()).To(gm.Equal(tt)) + }) + + }) + + gg.Context("Pointers", func() { + + gg.It("*int", func() { + var t int = 5 + var tv atomic.TypedVal[*int] + tv.Set(&t) + gm.Expect(tv.Get()).To(gm.Equal(&t)) + + tv.Set(nil) + var tt *int + gm.Expect(tv.Get()).To(gm.Equal(tt)) + }) + + gg.It("*string", func() { + var t string = "Hello!" + var tv atomic.TypedVal[*string] + tv.Set(&t) + gm.Expect(tv.Get()).To(gm.Equal(&t)) + + tv.Set(nil) + var tt *string + gm.Expect(tv.Get()).To(gm.Equal(tt)) + }) + + gg.It("slice", func() { + var t = []int{1, 2, 3} + var tv atomic.TypedVal[*[]int] + tv.Set(&t) + gm.Expect(tv.Get()).To(gm.Equal(&t)) + + t = nil + tv.Set(&t) + gm.Expect(tv.Get()).To(gm.Equal(&t)) + + tv.Set(nil) + var tt *[]int + gm.Expect(tv.Get()).To(gm.Equal(tt)) + }) + + gg.It("map", func() { + var t = map[string]int{"a": 1, "b": 2, "c": 3} + var tv atomic.TypedVal[*map[string]int] + tv.Set(&t) + gm.Expect(tv.Get()).To(gm.Equal(&t)) + + t = nil + tv.Set(&t) + gm.Expect(tv.Get()).To(gm.Equal(&t)) + + tv.Set(nil) + var tt *map[string]int + gm.Expect(tv.Get()).To(gm.Equal(tt)) + }) + + }) + + }) + +}) diff --git a/internal/seq/seq.go b/internal/seq/seq.go new file mode 100644 index 00000000..1b6da158 --- /dev/null +++ b/internal/seq/seq.go @@ -0,0 +1,84 @@ +// Copyright 2014-2022 Aerospike, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seq + +import ( + "errors" + "sync" +) + +var Break = errors.New("Break") + +func Do[T any](seq []T, f func(T) error) { + for i := range seq { + if err := f(seq[i]); err == Break { + break + } + } +} + +func ParDo[T any](seq []T, f func(T)) { + if len(seq) == 0 { + return + } + + wg := new(sync.WaitGroup) + wg.Add(len(seq)) + for i := range seq { + go func() { + defer wg.Done() + f(seq[i]) + }() + } + wg.Wait() +} + +func Any[T any](seq []T, f func(T) bool) bool { + for i := range seq { + if f(seq[i]) { + return true + } + } + return false +} + +func All[T any](seq []T, f func(T) bool) bool { + if len(seq) == 0 { + return false + } + + for i := range seq { + if !f(seq[i]) { + return false + } + } + return true +} + +func Clone[T any](seq []T) []T { + if seq == nil { + return nil + } + + if len(seq) == 0 { + return []T{} + } + + res := make([]T, len(seq)) + for i := range seq { + res[i] = seq[i] + } + return res +} diff --git a/node.go b/node.go index dddfb700..6dc85ddc 100644 --- a/node.go +++ b/node.go @@ -20,8 +20,6 @@ import ( "io" "strconv" "strings" - "sync" - "sync/atomic" "time" "golang.org/x/sync/errgroup" @@ -47,17 +45,16 @@ type Node struct { cluster *Cluster name string host *Host - aliases atomic.Value //[]*Host + aliases iatomic.TypedVal[[]*Host] stats nodeStats - sessionInfo atomic.Value //*sessionInfo + sessionInfo iatomic.TypedVal[*sessionInfo] - racks atomic.Value //map[string]int + racks iatomic.TypedVal[map[string]int] // tendConn reserves a connection for tend so that it won't have to // wait in queue for connections, since that will cause starvation // and the node being dropped under load. - tendConn *Connection - tendConnLock sync.Mutex // All uses of tend connection should be synchronized + tendConn iatomic.Guard[Connection] peersGeneration iatomic.Int peersCount iatomic.Int @@ -102,9 +99,9 @@ func newNode(cluster *Cluster, nv *nodeValidator) *Node { rebalanceGeneration: *iatomic.NewInt(-1), } - newNode.aliases.Store(nv.aliases) - newNode.sessionInfo.Store(nv.sessionInfo) - newNode.racks.Store(map[string]int{}) + newNode.aliases.Set(nv.aliases) + newNode.sessionInfo.Set(nv.sessionInfo) + newNode.racks.Set(make(map[string]int)) // this will reset to zero on first aggregation on the cluster, // therefore will only be counted once. @@ -139,7 +136,9 @@ func (nd *Node) Refresh(peers *peers) Error { // Close idleConnections defer nd.dropIdleConnections() + // Clear node reference counts. nd.referenceCount.Set(0) + nd.partitionChanged.Set(false) var infoMap map[string]string commands := []string{"node", "peers-generation", "partition-generation"} @@ -195,13 +194,13 @@ func (nd *Node) Refresh(peers *peers) Error { } // refreshSessionToken refreshes the session token if it has been expired -func (nd *Node) refreshSessionToken() Error { +func (nd *Node) refreshSessionToken() (err Error) { // no session token to refresh if !nd.cluster.clientPolicy.RequiresAuthentication() { return nil } - st := nd.sessionInfo.Load().(*sessionInfo) + st := nd.sessionInfo.Get() // Consider when the next tend will be in this calculation. If the next tend will be too late, // refresh the sessionInfo now. @@ -209,24 +208,19 @@ func (nd *Node) refreshSessionToken() Error { return nil } - nd.tendConnLock.Lock() - defer nd.tendConnLock.Unlock() - - if err := nd.initTendConn(nd.cluster.clientPolicy.LoginTimeout); err != nil { - return err - } - - command := newLoginCommand(nd.tendConn.dataBuffer) - if err := command.login(&nd.cluster.clientPolicy, nd.tendConn, nd.cluster.Password()); err != nil { - // force new connections to use default creds until a new valid session token is acquired - nd.resetSessionInfo() - // Socket not authenticated. Do not put back into pool. - nd.tendConn.Close() - return err - } + nd.usingTendConn(nd.cluster.clientPolicy.LoginTimeout, func(conn *Connection) { + command := newLoginCommand(conn.dataBuffer) + if err = command.login(&nd.cluster.clientPolicy, conn, nd.cluster.Password()); err != nil { + // force new connections to use default creds until a new valid session token is acquired + nd.resetSessionInfo() + // Socket not authenticated. Do not put back into pool. + conn.Close() + } else { + nd.sessionInfo.Set(command.sessionInfo()) + } + }) - nd.sessionInfo.Store(command.sessionInfo()) - return nil + return err } func (nd *Node) updateRackInfo(infoMap map[string]string) Error { @@ -287,7 +281,7 @@ func (nd *Node) updateRackInfo(infoMap map[string]string) Error { } } - nd.racks.Store(racks) + nd.racks.Set(racks) return nil } @@ -511,7 +505,7 @@ func (nd *Node) newConnection(overrideThreshold bool) (*Connection, Error) { } conn.node = nd - sessionInfo := nd.sessionInfo.Load().(*sessionInfo) + sessionInfo := nd.sessionInfo.Get() // need to authenticate if err = conn.login(&nd.cluster.clientPolicy, nd.cluster.Password(), sessionInfo); err != nil { // increment node errors if authentication hit a network error @@ -625,12 +619,12 @@ func (nd *Node) GetName() string { // GetAliases returns node aliases. func (nd *Node) GetAliases() []*Host { - return nd.aliases.Load().([]*Host) + return nd.aliases.Get() } // Sets node aliases func (nd *Node) setAliases(aliases []*Host) { - nd.aliases.Store(aliases) + nd.aliases.Set(aliases) } // AddAlias adds an alias for the node @@ -670,11 +664,11 @@ func (nd *Node) closeConnections() { } // close the tend connection - nd.tendConnLock.Lock() - defer nd.tendConnLock.Unlock() - if nd.tendConn != nil { - nd.tendConn.Close() - } + nd.tendConn.Do(func(conn *Connection) { + if conn != nil { + conn.Close() + } + }) } // Equals compares equality of two nodes based on their names. @@ -725,33 +719,40 @@ func (nd *Node) WaitUntillMigrationIsFinished(timeout time.Duration) Error { } } -// initTendConn sets up a connection to be used for info requests. -// The same connection will be used for tend. -func (nd *Node) initTendConn(timeout time.Duration) Error { - if timeout <= 0 { - timeout = _DEFAULT_TIMEOUT - } - deadline := time.Now().Add(timeout) +// usingTendConn allows the tend connection to be used in a monitor without race conditions. +// If the connection is not valid, it establishes a valid connection first. +func (nd *Node) usingTendConn(timeout time.Duration, f func(conn *Connection)) (err Error) { + nd.tendConn.Update(func(conn **Connection) { + if timeout <= 0 { + timeout = _DEFAULT_TIMEOUT + } + deadline := time.Now().Add(timeout) + + // if the tend connection is invalid, establish a new connection first + if *conn == nil || !(*conn).IsConnected() { + if nd.connectionCount.Get() == 0 { + // if there are no connections in the pool, create a new connection synchronously. + // this will make sure the initial tend will get a connection without multiple retries. + *conn, err = nd.newConnection(true) + } else { + *conn, err = nd.GetConnection(timeout) + } - if nd.tendConn == nil || !nd.tendConn.IsConnected() { - var tendConn *Connection - var err Error - if nd.connectionCount.Get() == 0 { - // if there are no connections in the pool, create a new connection synchronously. - // this will make sure the initial tend will get a connection without multiple retries. - tendConn, err = nd.newConnection(true) - } else { - tendConn, err = nd.GetConnection(timeout) + // if no connection could be established, exit fast + if err != nil { + return + } } - if err != nil { - return err + // Set timeout for tend conn + if err = (*conn).SetTimeout(deadline, timeout); err != nil { + return } - nd.tendConn = tendConn - } - // Set timeout for tend conn - return nd.tendConn.SetTimeout(deadline, timeout) + // if all went well, call the closure + f(*conn) + }) + return err } // requestInfoWithRetry gets info values by name from the specified database server node. @@ -776,37 +777,26 @@ func (nd *Node) RequestInfo(policy *InfoPolicy, name ...string) (map[string]stri } // RequestInfo gets info values by name from the specified database server node. -func (nd *Node) requestInfo(timeout time.Duration, name ...string) (map[string]string, Error) { - nd.tendConnLock.Lock() - defer nd.tendConnLock.Unlock() - - if err := nd.initTendConn(timeout); err != nil { - return nil, err - } +func (nd *Node) requestInfo(timeout time.Duration, name ...string) (response map[string]string, err Error) { + nd.usingTendConn(timeout, func(conn *Connection) { + response, err = conn.RequestInfo(name...) + if err != nil { + conn.Close() + } + }) - response, err := nd.tendConn.RequestInfo(name...) - if err != nil { - nd.tendConn.Close() - return nil, err - } - return response, nil + return response, err } // requestRawInfo gets info values by name from the specified database server node. // It won't parse the results. -func (nd *Node) requestRawInfo(policy *InfoPolicy, name ...string) (*info, Error) { - nd.tendConnLock.Lock() - defer nd.tendConnLock.Unlock() - - if err := nd.initTendConn(policy.Timeout); err != nil { - return nil, err - } - - response, err := newInfo(nd.tendConn, name...) - if err != nil { - nd.tendConn.Close() - return nil, err - } +func (nd *Node) requestRawInfo(policy *InfoPolicy, name ...string) (response *info, err Error) { + nd.usingTendConn(policy.Timeout, func(conn *Connection) { + response, err = newInfo(conn, name...) + if err != nil { + conn.Close() + } + }) return response, nil } @@ -839,13 +829,13 @@ func (nd *Node) RequestStats(policy *InfoPolicy) (map[string]string, Error) { // unsuccessful authentication with token func (nd *Node) resetSessionInfo() { si := &sessionInfo{} - nd.sessionInfo.Store(si) + nd.sessionInfo.Set(si) } // sessionToken returns the session token for the node. // It will return nil if the session has expired. func (nd *Node) sessionToken() []byte { - si := nd.sessionInfo.Load().(*sessionInfo) + si := nd.sessionInfo.Get() if !si.isValid() { return nil } @@ -855,7 +845,7 @@ func (nd *Node) sessionToken() []byte { // Rack returns the rack number for the namespace. func (nd *Node) Rack(namespace string) (int, Error) { - racks := nd.racks.Load().(map[string]int) + racks := nd.racks.Get() v, exists := racks[namespace] if exists { @@ -867,7 +857,7 @@ func (nd *Node) Rack(namespace string) (int, Error) { // Rack returns the rack number for the namespace. func (nd *Node) hasRack(namespace string, rack int) bool { - racks := nd.racks.Load().(map[string]int) + racks := nd.racks.Get() v, exists := racks[namespace] if !exists { diff --git a/proxy_client.go b/proxy_client.go index 627cbe4e..15257ae5 100644 --- a/proxy_client.go +++ b/proxy_client.go @@ -21,7 +21,6 @@ import ( "math/rand" "runtime" "sync" - "sync/atomic" "time" "google.golang.org/grpc" @@ -44,7 +43,7 @@ type ProxyClient struct { grpcHost *Host dialOptions []grpc.DialOption - authToken atomic.Value + authToken iatomic.TypedVal[string] authInterceptor *authInterceptor active iatomic.Bool @@ -89,6 +88,11 @@ func grpcClientFinalizer(f *ProxyClient) { // If the policy is nil, the default relevant policy will be used. // Pass "dns:///
:" (note the 3 slashes) for dns load balancing, // automatically supported internally by grpc-go. +// The connection pool after connecting to the database is initially empty, +// and connections are established on a per need basis, which can be slow and +// time out some initial commands. +// It is recommended to call the client.WarmUp() method right after connecting to the database +// to fill up the connection pool to the required service level. func NewProxyClientWithPolicyAndHost(policy *ClientPolicy, host *Host, dialOptions ...grpc.DialOption) (*ProxyClient, Error) { if policy == nil { policy = NewClientPolicy() @@ -255,11 +259,11 @@ func (clnt *ProxyClient) SetDefaultInfoPolicy(policy *InfoPolicy) { //------------------------------------------------------- func (clnt *ProxyClient) token() string { - return clnt.authToken.Load().(string) + return clnt.authToken.Get() } func (clnt *ProxyClient) setAuthToken(token string) { - clnt.authToken.Store(token) + clnt.authToken.Set(token) } func (clnt *ProxyClient) grpcConn() (*grpc.ClientConn, Error) { @@ -601,6 +605,11 @@ func (clnt *ProxyClient) GetHeader(policy *BasePolicy, key *Key) (*Record, Error // If the policy is nil, the default relevant policy will be used. func (clnt *ProxyClient) BatchGet(policy *BatchPolicy, keys []*Key, binNames ...string) ([]*Record, Error) { policy = clnt.getUsableBatchPolicy(policy) + + if len(keys) == 0 { + return []*Record{}, nil + } + batchRecordsIfc := make([]BatchRecordIfc, 0, len(keys)) batchRecords := make([]*BatchRecord, 0, len(keys)) for _, key := range keys { @@ -629,6 +638,11 @@ func (clnt *ProxyClient) BatchGet(policy *BatchPolicy, keys []*Key, binNames ... // If a batch request to a node fails, the entire batch is cancelled. func (clnt *ProxyClient) BatchGetOperate(policy *BatchPolicy, keys []*Key, ops ...*Operation) ([]*Record, Error) { policy = clnt.getUsableBatchPolicy(policy) + + if len(keys) == 0 { + return []*Record{}, nil + } + batchRecordsIfc := make([]BatchRecordIfc, 0, len(keys)) batchRecords := make([]*BatchRecord, 0, len(keys)) for _, key := range keys { @@ -678,6 +692,11 @@ func (clnt *ProxyClient) BatchGetComplex(policy *BatchPolicy, records []*BatchRe // If the policy is nil, the default relevant policy will be used. func (clnt *ProxyClient) BatchGetHeader(policy *BatchPolicy, keys []*Key) ([]*Record, Error) { policy = clnt.getUsableBatchPolicy(policy) + + if len(keys) == 0 { + return []*Record{}, nil + } + batchRecordsIfc := make([]BatchRecordIfc, 0, len(keys)) for _, key := range keys { batchRecordsIfc = append(batchRecordsIfc, NewBatchReadHeader(clnt.DefaultBatchReadPolicy, key)) @@ -702,6 +721,11 @@ func (clnt *ProxyClient) BatchGetHeader(policy *BatchPolicy, keys []*Key) ([]*Re // Requires server version 6.0+ func (clnt *ProxyClient) BatchDelete(policy *BatchPolicy, deletePolicy *BatchDeletePolicy, keys []*Key) ([]*BatchRecord, Error) { policy = clnt.getUsableBatchPolicy(policy) + + if len(keys) == 0 { + return []*BatchRecord{}, nil + } + deletePolicy = clnt.getUsableBatchDeletePolicy(deletePolicy) batchRecordsIfc := make([]BatchRecordIfc, 0, len(keys)) @@ -751,6 +775,10 @@ func (clnt *ProxyClient) BatchOperate(policy *BatchPolicy, records []BatchRecord // // Requires server version 6.0+ func (clnt *ProxyClient) BatchExecute(policy *BatchPolicy, udfPolicy *BatchUDFPolicy, keys []*Key, packageName string, functionName string, args ...Value) ([]*BatchRecord, Error) { + if len(keys) == 0 { + return []*BatchRecord{}, nil + } + batchRecordsIfc := make([]BatchRecordIfc, 0, len(keys)) batchRecords := make([]*BatchRecord, 0, len(keys)) for _, key := range keys { diff --git a/types/result_code.go b/types/result_code.go index 2347780e..df79029e 100644 --- a/types/result_code.go +++ b/types/result_code.go @@ -172,6 +172,9 @@ const ( // LOST_CONFLICT defines write command loses conflict to XDR. LOST_CONFLICT = 28 + // Write can't complete until XDR finishes shipping. + XDR_KEY_BUSY = 32 + // QUERY_END defines there are no more records left for query. QUERY_END ResultCode = 50 @@ -454,6 +457,9 @@ func ResultCodeToString(resultCode ResultCode) string { case LOST_CONFLICT: return "Write command loses conflict to XDR." + case XDR_KEY_BUSY: + return "Write can't complete until XDR finishes shipping." + case QUERY_END: return "Query end" @@ -691,6 +697,8 @@ func (rc ResultCode) String() string { return "FILTERED_OUT" case LOST_CONFLICT: return "LOST_CONFLICT" + case XDR_KEY_BUSY: + return "XDR_KEY_BUSY" case QUERY_END: return "QUERY_END" case SECURITY_NOT_SUPPORTED: diff --git a/udf_test.go b/udf_test.go index 856b795a..29067377 100644 --- a/udf_test.go +++ b/udf_test.go @@ -65,6 +65,11 @@ function getRecordKeyValue(rec) end ` +const invalidUdfBody = `function testFunc1(rec, div) + asdf + returned ret -- Return the Return value and/or status +end` + // ALL tests are isolated by SetName and Key, which are 50 random characters var _ = gg.Describe("UDF/Query tests", func() { @@ -96,6 +101,15 @@ var _ = gg.Describe("UDF/Query tests", func() { gm.Expect(<-regTask.OnComplete()).NotTo(gm.HaveOccurred()) }) + gg.It("must parse invalid UDF error", func() { + _, err := client.RegisterUDF(wpolicy, []byte(invalidUdfBody), "invalid_udf1.lua", as.LUA) + gm.Expect(err).To(gm.HaveOccurred()) + gm.Expect(err.Error()).To(gm.HaveSuffix(`compile_error +File: invalid_udf1.lua +Line: 3 +Message: syntax error near 'returned'`)) + }) + gg.It("must run a UDF on a single record", func() { registerUDF(udfBody, "udf1.lua")