Skip to content

Commit

Permalink
support vector index ttl
Browse files Browse the repository at this point in the history
  • Loading branch information
douchao.douchao authored and yangbodong22011 committed Aug 1, 2023
1 parent 0e6de33 commit 7f15861
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,4 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.DS_Store
211 changes: 118 additions & 93 deletions tair/tairvector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce
from typing import Dict, List, Sequence, Tuple, Union, Optional, Iterable
from tair.typing import AbsExpiryT, CommandsProtocol, ExpiryT, ResponseT

from redis.client import pairs_to_dict
from redis.utils import str_if_bytes


VectorType = Sequence[Union[int, float]]


Expand Down Expand Up @@ -123,11 +123,11 @@ def __init__(self, client, name, **index_params):

# bind methods
for method in (
"tvs_del",
"tvs_hdel",
"tvs_hgetall",
"tvs_hmget",
"tvs_scan",
"tvs_del",
"tvs_hdel",
"tvs_hgetall",
"tvs_hmget",
"tvs_scan",
):
attr = getattr(TairVectorCommands, method)
if callable(attr):
Expand All @@ -150,23 +150,23 @@ def tvs_hset(self, key: str, vector: Union[VectorType, str, None] = None, **kwar
return self.client.tvs_hset(self.name, key, vector, self.is_binary, **kwargs)

def tvs_knnsearch(
self,
k: int,
vector: Union[VectorType, str],
filter_str: Optional[str] = None,
**kwargs
self,
k: int,
vector: Union[VectorType, str],
filter_str: Optional[str] = None,
**kwargs
):
"""search for the top @k approximate nearest neighbors of @vector"""
return self.client.tvs_knnsearch(
self.name, k, vector, self.is_binary, filter_str, **kwargs
)

def tvs_mknnsearch(
self,
k: int,
vectors: Sequence[VectorType],
filter_str: Optional[str] = None,
**kwargs
self,
k: int,
vectors: Sequence[VectorType],
filter_str: Optional[str] = None,
**kwargs
):
"""batch approximate nearest neighbors search for a list of vectors"""
return self.client.tvs_mknnsearch(
Expand All @@ -180,7 +180,7 @@ def __repr__(self):
return str(self)


class TairVectorCommands:
class TairVectorCommands(CommandsProtocol):
encode_vector = TextVectorEncoder.encode
decode_vector = TextVectorEncoder.decode

Expand All @@ -190,13 +190,13 @@ class TairVectorCommands:
SCAN_INDEX_CMD = "TVS.SCANINDEX"

def tvs_create_index(
self,
name: str,
dim: int,
distance_type: str = DistanceMetric.L2,
index_type: str = IndexType.HNSW,
data_type: str = DataType.Float32,
**kwargs
self,
name: str,
dim: int,
distance_type: str = DistanceMetric.L2,
index_type: str = IndexType.HNSW,
data_type: str = DataType.Float32,
**kwargs
):
"""
create a vector
Expand Down Expand Up @@ -231,7 +231,7 @@ def tvs_del_index(self, name: str):
return self.execute_command(self.DEL_INDEX_CMD, name)

def tvs_scan_index(
self, pattern: Optional[str] = None, batch: int = 10
self, pattern: Optional[str] = None, batch: int = 10
) -> TairVectorScanResult:
"""
scan all the indices
Expand All @@ -257,12 +257,12 @@ def tvs_index(self, name: str, **index_params) -> TairVectorIndex:
SCAN_CMD = "TVS.SCAN"

def tvs_hset(
self,
index: str,
key: str,
vector: Union[VectorType, str, None] = None,
is_binary=False,
**kwargs
self,
index: str,
key: str,
vector: Union[VectorType, str, None] = None,
is_binary=False,
**kwargs
):
"""
add/update a data entry to index
Expand Down Expand Up @@ -309,13 +309,13 @@ def tvs_hmget(self, index: str, key: str, *args):
return self.execute_command(self.HMGET_CMD, index, key, *args)

def tvs_scan(
self,
index: str,
pattern: Optional[str] = None,
batch: int = 10,
filter_str: Optional[str] = None,
vector: Optional[VectorType] = None,
max_dist: Optional[float] = None,
self,
index: str,
pattern: Optional[str] = None,
batch: int = 10,
filter_str: Optional[str] = None,
vector: Optional[VectorType] = None,
max_dist: Optional[float] = None,
):
"""
scan all data entries in an index
Expand All @@ -340,14 +340,14 @@ def get_batch(c):
return TairVectorScanResult(self, get_batch)

def _tvs_scan(
self,
index: str,
cursor: int = 0,
count: Optional[int] = None,
pattern: Optional[str] = None,
filter_str: Optional[str] = None,
vector: Union[VectorType, bytes, None] = None,
max_dist: Optional[float] = None,
self,
index: str,
cursor: int = 0,
count: Optional[int] = None,
pattern: Optional[str] = None,
filter_str: Optional[str] = None,
vector: Union[VectorType, bytes, None] = None,
max_dist: Optional[float] = None,
):
args = [] if pattern is None else ["MATCH", pattern]
if count is not None:
Expand All @@ -374,13 +374,13 @@ def _tvs_scan(
MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH"

def tvs_knnsearch(
self,
index: str,
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: str,
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in an index
Expand All @@ -395,13 +395,13 @@ def tvs_knnsearch(
)

def tvs_mknnsearch(
self,
index: str,
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: str,
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
Expand Down Expand Up @@ -430,13 +430,13 @@ def tvs_mknnsearch(
)

def tvs_mindexknnsearch(
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in indexs
Expand All @@ -453,13 +453,13 @@ def tvs_mindexknnsearch(
)

def tvs_mindexmknnsearch(
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
Expand Down Expand Up @@ -492,13 +492,13 @@ def tvs_mindexmknnsearch(
GETDISTANCE_CMD = "TVS.GETDISTANCE"

def _tvs_getdistance(
self,
index_name: str,
vector: VectorType,
keys: Iterable[str],
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
self,
index_name: str,
vector: VectorType,
keys: Iterable[str],
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
):
"""
low level interface for TVS.GETDISTANCE
Expand All @@ -520,15 +520,15 @@ def _tvs_getdistance(
)

def tvs_getdistance(
self,
index_name: str,
vector: Union[VectorType, str, bytes],
keys: Iterable[str],
batch_size: int = 100000,
parallelism: int = 1,
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
self,
index_name: str,
vector: Union[VectorType, str, bytes],
keys: Iterable[str],
batch_size: int = 100000,
parallelism: int = 1,
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
):
"""
wrapped interface for TVS.GETDISTANCE
Expand Down Expand Up @@ -562,7 +562,7 @@ def process_batch(batch):

with ThreadPoolExecutor(max_workers=parallelism) as executor:
batches = [
keys[i : i + batch_size] for i in range(0, len(keys), batch_size)
keys[i: i + batch_size] for i in range(0, len(keys), batch_size)
]

futures = [executor.submit(process_batch, batch) for batch in batches]
Expand All @@ -579,10 +579,10 @@ def process_batch(batch):
)
queue = itertools.islice(queue, k)
return [(key, score) for score, key in queue]

HINCRBY_CMD = "TVS.HINCRBY"
HINCRBYFLOAT_CMD = "TVS.HINCRBYFLOAT"

def tvs_hincrby(self, index: str, key: str, field: str, num: int):
"""
increment the long value of a tairvector field by the given amount, not support field VECTOR
Expand All @@ -595,6 +595,31 @@ def tvs_hincrbyfloat(self, index: str, key: str, field: str, num: float):
"""
return self.execute_command(self.HINCRBYFLOAT_CMD, index, key, field, num)

def tvs_hexpire(self, index: str, key: str, ex: ExpiryT) -> ResponseT:
return self.execute_command("TVS.HEXPIRE", index, key, ex)

def tvs_hexpireat(self, index: str, key: str, exat: AbsExpiryT) -> ResponseT:
return self.execute_command("TVS.HEXPIREAT", index, key, exat)

def tvs_hpexpire(self, index: str, key: str, px: ExpiryT) -> ResponseT:
return self.execute_command("TVS.HPEXPIRE", index, key, px)

def tvs_hpexpireat(self, index: str, key: str, pxat: AbsExpiryT) -> ResponseT:
return self.execute_command("TVS.HPEXPIREAT", index, key, pxat)

def tvs_httl(self, index: str, key: str) -> ResponseT:
return self.execute_command("TVS.HTTL", index, key)

def tvs_hpttl(self, index: str, key: str) -> ResponseT:
return self.execute_command("TVS.HPTTL", index, key)

def tvs_hexpiretime(self, index: str, key: str) -> ResponseT:
return self.execute_command("TVS.HEXPIRETIME", index, key)

def tvs_hpexpiretime(self, index: str, key: str) -> ResponseT:
return self.execute_command("TVS.HPEXPIRETIME", index, key)


def parse_tvs_get_index_result(resp) -> Union[Dict, None]:
if len(resp) == 0:
return None
Expand Down
Loading

0 comments on commit 7f15861

Please sign in to comment.