Skip to content

Commit

Permalink
Update Faiss ANN to support IVF strings without number of cells, closes
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Nov 4, 2023
1 parent 826a248 commit 7176802
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
26 changes: 24 additions & 2 deletions src/python/txtai/ann/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def configure(self, count, train):
components = self.setting("components")

if components:
return components
# Format and return components string
return self.components(components, train)

# Derive quantization. Prefer backend-specific setting. Fallback to root-level parameter.
quantize = self.setting("quantize", self.config.get("quantize"))
Expand Down Expand Up @@ -171,7 +172,28 @@ def cells(self, count):

# Calculate number of IVF cells where x = min(4 * sqrt(embeddings count), embeddings count / 39)
# Faiss requires at least 39 * x data points
return min(round(4 * math.sqrt(count)), int(count / 39))
return max(min(round(4 * math.sqrt(count)), int(count / 39)), 1)

def components(self, components, train):
    """
    Builds the final components string. Any bare "IVF" entry, i.e. one given without a number of cells,
    is expanded to "IVF<cells>", where the cell count comes from self.cells(train). All other entries
    are passed through unchanged.

    Args:
        components: input components string, comma separated
        train: number of rows selected for model training

    Returns:
        formatted components string
    """

    # Number of IVF cells derived from the training row count
    cells = self.cells(train)

    # Walk each comma separated entry, filling in the cell count only for an exact "IVF" match
    formatted = []
    for name in components.split(","):
        if name == "IVF":
            formatted.append(f"IVF{cells}")
        else:
            formatted.append(name)

    # Reassemble using the same delimiter
    return ",".join(formatted)

def nprobe(self):
"""
Expand Down
1 change: 1 addition & 0 deletions test/python/testann.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def testFaissCustom(self):

# Test with custom settings
self.runTests("faiss", {"faiss": {"nprobe": 2, "components": "PCA16,IDMap,SQ8", "sample": 1.0}}, False)
self.runTests("faiss", {"faiss": {"components": "IVF,SQ8"}}, False)

@unittest.skipIf(os.name == "nt", "mmap not supported on Windows")
def testFaissMmap(self):
Expand Down

0 comments on commit 7176802

Please sign in to comment.