Skip to content

Commit

Permalink
[ENH]: SHA256 sum check of Chroma's onnx model. (#1493)
Browse files Browse the repository at this point in the history
Refs: #883

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Verify ONNX all-MiniLM-L6 model model download from s3 with static
SHA256 (within the python code)

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python

## Documentation Changes
N/A
  • Loading branch information
tazarov authored Dec 19, 2023
1 parent 1ba6eac commit 49f9c14
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
32 changes: 30 additions & 2 deletions chromadb/test/ef/test_default_ef.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import shutil
import os
from typing import List, Hashable

import hypothesis.strategies as st
import onnxruntime
import pytest
from hypothesis import given
from hypothesis import given, settings

from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256


def unique_by(x: Hashable) -> Hashable:
return x


@settings(deadline=None)
@given(
providers=st.lists(
st.sampled_from(onnxruntime.get_all_providers()).filter(
Expand Down Expand Up @@ -60,3 +63,28 @@ def test_provider_repeating(providers: List[str]) -> None:
ef = ONNXMiniLM_L6_V2(preferred_providers=providers)
ef(["test"])
assert "Preferred providers must be unique" in str(e.value)


def test_invalid_sha256() -> None:
ef = ONNXMiniLM_L6_V2()
shutil.rmtree(ef.DOWNLOAD_PATH) # clean up any existing models
with pytest.raises(ValueError) as e:
ef._MODEL_SHA256 = "invalid"
ef(["test"])
assert "does not match expected SHA256 hash" in str(e.value)


def test_partial_download() -> None:
ef = ONNXMiniLM_L6_V2()
shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True) # clean up any existing models
os.makedirs(ef.DOWNLOAD_PATH, exist_ok=True)
path = os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)
with open(path, "wb") as f: # create invalid file to simulate partial download
f.write(b"invalid")
ef._download_model_if_not_exists() # re-download model
assert os.path.exists(path)
assert _verify_sha256(
str(os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)),
ef._MODEL_SHA256,
)
assert len(ef(["test"])) == 1
29 changes: 29 additions & 0 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import hashlib
import logging

from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception

from chromadb.api.types import (
Document,
Documents,
Expand Down Expand Up @@ -31,6 +34,16 @@
logger = logging.getLogger(__name__)


def _verify_sha256(fname: str, expected_sha256: str) -> bool:
sha256_hash = hashlib.sha256()
with open(fname, "rb") as f:
# Read and update hash in chunks to avoid using too much memory
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)

return sha256_hash.hexdigest() == expected_sha256


class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]):
# Since we do dynamic imports we have to type this as Any
models: Dict[str, Any] = {}
Expand Down Expand Up @@ -346,6 +359,7 @@ class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]):
MODEL_DOWNLOAD_URL = (
"https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
)
_MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3"
tokenizer = None
model = None

Expand Down Expand Up @@ -389,6 +403,12 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None:

# Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
# Download with tqdm to preserve the sentence-transformers experience
@retry(
reraise=True,
stop=stop_after_attempt(3),
wait=wait_random(min=1, max=3),
retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)),
)
def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
Expand All @@ -402,6 +422,12 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None:
for data in resp.iter_content(chunk_size=chunk_size):
size = file.write(data)
bar.update(size)
if not _verify_sha256(fname, self._MODEL_SHA256):
# if the integrity of the file is not verified, remove it
os.remove(fname)
raise ValueError(
f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file."
)

# Use pytorches default epsilon for division by zero
# https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
Expand Down Expand Up @@ -503,6 +529,9 @@ def _download_model_if_not_exists(self) -> None:
os.makedirs(self.DOWNLOAD_PATH, exist_ok=True)
if not os.path.exists(
os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME)
) or not _verify_sha256(
os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME),
self._MODEL_SHA256,
):
self._download(
url=self.MODEL_DOWNLOAD_URL,
Expand Down

0 comments on commit 49f9c14

Please sign in to comment.