Add quantization option for pgvector with support for halfvec
lucagiac81 authored and alwayslove2013 committed Sep 2, 2024
1 parent 753c46d commit b364fe3
Showing 5 changed files with 135 additions and 36 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -117,6 +117,8 @@ Options:
--m INTEGER hnsw m
--ef-construction INTEGER hnsw ef-construction
--ef-search INTEGER hnsw ef-search
--quantization-type [none|halfvec]
quantization type for vectors
--help Show this message and exit.
```
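As a rough usage sketch, the new flag slots in next to the existing HNSW options. The entry point, subcommand name, and connection flags below are assumptions for illustration (check `--help` on your install); only `--m`, `--ef-construction`, `--ef-search`, and the new `--quantization-type` come from the help output above.

```python
# Hypothetical invocation sketch: the "vectordbbench" entry point, "pgvectorhnsw"
# subcommand, and connection flags are assumed; --quantization-type is the new option.
import subprocess

subprocess.run(
    [
        "vectordbbench", "pgvectorhnsw",   # assumed CLI entry point and subcommand
        "--host", "localhost",             # assumed connection flags
        "--db-name", "vectordb",
        "--m", "16",
        "--ef-construction", "64",
        "--ef-search", "100",
        "--quantization-type", "halfvec",  # store/index vectors as halfvec
    ],
    check=True,
)
```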
#### Using a configuration file.
16 changes: 14 additions & 2 deletions vectordb_bench/backend/clients/pgvector/cli.py
@@ -56,7 +56,15 @@ class PgVectorTypedDict(CommonTypedDict):
required=False,
),
]

quantization_type: Annotated[
Optional[str],
click.option(
"--quantization-type",
type=click.Choice(["none", "halfvec"]),
help="quantization type for vectors",
required=False,
),
]

class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
...
@@ -79,7 +87,10 @@ def PgVectorIVFFlat(
db_name=parameters["db_name"],
),
db_case_config=PgVectorIVFFlatConfig(
metric_type=None, lists=parameters["lists"], probes=parameters["probes"]
metric_type=None,
lists=parameters["lists"],
probes=parameters["probes"],
quantization_type=parameters["quantization_type"],
),
**parameters,
)
@@ -111,6 +122,7 @@ def PgVectorHNSW(
ef_search=parameters["ef_search"],
maintenance_work_mem=parameters["maintenance_work_mem"],
max_parallel_workers=parameters["max_parallel_workers"],
quantization_type=parameters["quantization_type"],
),
**parameters,
)
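Because the option is not required and has no default, omitting it yields `None`, while passing `--quantization-type none` yields the string `"none"`; the config layer below normalizes that string back to `None`. A standalone sketch of the click behavior (an illustration, not the project's actual command):

```python
# Standalone sketch of click.Choice behavior for an optional option such as
# --quantization-type; this is not the project's CLI.
import click
from click.testing import CliRunner

@click.command()
@click.option(
    "--quantization-type",
    type=click.Choice(["none", "halfvec"]),
    required=False,
)
def demo(quantization_type):
    click.echo(f"quantization_type={quantization_type!r}")

runner = CliRunner()
print(runner.invoke(demo, []).output)                                  # None when omitted
print(runner.invoke(demo, ["--quantization-type", "none"]).output)     # the string 'none'
print(runner.invoke(demo, ["--quantization-type", "halfvec"]).output)  # 'halfvec'
```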
25 changes: 20 additions & 5 deletions vectordb_bench/backend/clients/pgvector/config.py
@@ -59,11 +59,18 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
create_index_after_load: bool = True

def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"
if self.quantization_type == "halfvec":
if self.metric_type == MetricType.L2:
return "halfvec_l2_ops"
elif self.metric_type == MetricType.IP:
return "halfvec_ip_ops"
return "halfvec_cosine_ops"
else:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
elif self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"

def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
@@ -143,9 +150,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
index: IndexType = IndexType.ES_IVFFlat
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"lists": self.lists}
if self.quantization_type == "none":
self.quantization_type = None
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
@@ -154,6 +164,7 @@ def index_param(self) -> PgVectorIndexParam:
),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
@@ -183,9 +194,12 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
index: IndexType = IndexType.ES_HNSW
maintenance_work_mem: Optional[str] = None
max_parallel_workers: Optional[int] = None
quantization_type: Optional[str] = None

def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
if self.quantization_type == "none":
self.quantization_type = None
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
@@ -194,6 +208,7 @@ def index_param(self) -> PgVectorIndexParam:
),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
}

def search_param(self) -> PgVectorSearchParam:
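A minimal sketch of how the new field changes what the config reports; the constructor arguments other than `quantization_type` are assumptions based on the fields visible in this diff:

```python
# Sketch only: field names follow this diff; everything else is assumed.
from vectordb_bench.backend.clients.pgvector.config import PgVectorHNSWConfig

cfg = PgVectorHNSWConfig(
    metric_type=None,              # no L2/IP match, so the cosine opclass is used
    m=16,
    ef_construction=64,
    ef_search=100,
    quantization_type="halfvec",
)
print(cfg.parse_metric())                      # expected: halfvec_cosine_ops
print(cfg.index_param()["quantization_type"])  # expected: halfvec
```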
113 changes: 84 additions & 29 deletions vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -112,25 +112,63 @@ def init(self) -> Generator[None, None, None]:
self.cursor.execute(command)
self.conn.commit()

self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
index_param = self.case_config.index_param()
# The following sections assume that the quantization_type value matches the quantization function name
if index_param["quantization_type"] != None:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._filtered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
sql.Identifier(self.table_name)
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)
if index_param["quantization_type"] != None:
self._unfiltered_search = sql.Composed(
[
sql.SQL(
"SELECT id FROM public.{table_name} ORDER BY embedding::{quantization_type}({dim}) "
).format(
table_name=sql.Identifier(self.table_name),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::{quantization_type}({dim}) LIMIT %s::int").format(
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
),
]
)
else:
self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
sql.Identifier(self.table_name)
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
]
)

try:
yield
@@ -265,17 +303,34 @@ def _create_index(self):
else:
with_clause = sql.Composed(())

index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"]),
embedding_metric=sql.Identifier(index_param["metric"]),
)
if index_param["quantization_type"] != None:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} ((embedding::{quantization_type}({dim})) {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"]),
# This assumes that the quantization_type value matches the quantization function name
quantization_type=sql.SQL(index_param["quantization_type"]),
                dim=sql.Literal(self.dim),
embedding_metric=sql.Identifier(index_param["metric"]),
)
else:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
"""
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"]),
embedding_metric=sql.Identifier(index_param["metric"]),
)

index_create_sql_with_with_clause = (
index_create_sql + with_clause
).join(" ")
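For reference, the halfvec branches above render to roughly the following statements. This is an illustrative plain-string sketch: the table name, dimension, and distance operator are assumed, and the real code composes the queries safely with psycopg's `sql` module.

```python
# Illustrative only: approximate SQL produced by the halfvec code paths above.
# Table name, dimension, and the cosine distance operator are assumptions.
dim = 768
table = "pg_vector_collection"
op = "<=>"  # cosine distance operator, matching halfvec_cosine_ops

unfiltered_search = (
    f"SELECT id FROM public.{table} "
    f"ORDER BY embedding::halfvec({dim}) {op} %s::halfvec({dim}) LIMIT %s::int"
)
create_index = (
    f"CREATE INDEX IF NOT EXISTS {table}_embedding_idx ON public.{table} "
    f"USING hnsw ((embedding::halfvec({dim})) halfvec_cosine_ops)"
)
print(unfiltered_search)
print(create_index)
```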
15 changes: 15 additions & 0 deletions vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -738,6 +738,19 @@ class CaseConfigInput(BaseModel):
],
)

CaseConfigParamInput_QuantizationType_PgVector = CaseConfigInput(
label=CaseConfigParamType.quantizationType,
inputType=InputType.Option,
inputConfig={
"options": ["none", "halfvec"],
},
isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
IndexType.IVFFlat.value,
],
)

CaseConfigParamInput_QuantizationRatio_PgVectoRS = CaseConfigInput(
label=CaseConfigParamType.quantizationRatio,
inputType=InputType.Option,
@@ -831,6 +844,7 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_Lists_PgVector,
CaseConfigParamInput_m,
CaseConfigParamInput_EFConstruction_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
]
@@ -841,6 +855,7 @@
CaseConfigParamInput_EFSearch_PgVector,
CaseConfigParamInput_Lists_PgVector,
CaseConfigParamInput_Probes_PgVector,
CaseConfigParamInput_QuantizationType_PgVector,
CaseConfigParamInput_maintenance_work_mem_PgVector,
CaseConfigParamInput_max_parallel_workers_PgVector,
]
