From 49f9c14f737ad8e05a27b4a487ccf3b18bec8b12 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 19 Dec 2023 23:56:33 +0200 Subject: [PATCH] [ENH]: SHA256 sum check of Chroma's onnx model. (#1493) Refs: #883 ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Verify ONNX all-MiniLM-L6 model model download from s3 with static SHA256 (within the python code) ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python ## Documentation Changes N/A --- chromadb/test/ef/test_default_ef.py | 32 +++++++++++++++++++++++++-- chromadb/utils/embedding_functions.py | 29 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/chromadb/test/ef/test_default_ef.py b/chromadb/test/ef/test_default_ef.py index e1ed5520660..6d8fb623698 100644 --- a/chromadb/test/ef/test_default_ef.py +++ b/chromadb/test/ef/test_default_ef.py @@ -1,17 +1,20 @@ +import shutil +import os from typing import List, Hashable import hypothesis.strategies as st import onnxruntime import pytest -from hypothesis import given +from hypothesis import given, settings -from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2 +from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256 def unique_by(x: Hashable) -> Hashable: return x +@settings(deadline=None) @given( providers=st.lists( st.sampled_from(onnxruntime.get_all_providers()).filter( @@ -60,3 +63,28 @@ def test_provider_repeating(providers: List[str]) -> None: ef = ONNXMiniLM_L6_V2(preferred_providers=providers) ef(["test"]) assert "Preferred providers must be unique" in str(e.value) + + +def test_invalid_sha256() -> None: + ef = ONNXMiniLM_L6_V2() + shutil.rmtree(ef.DOWNLOAD_PATH) # clean up any existing models + with pytest.raises(ValueError) as e: + ef._MODEL_SHA256 = "invalid" + ef(["test"]) + assert "does not match expected SHA256 hash" in str(e.value) + + +def test_partial_download() -> None: + ef = ONNXMiniLM_L6_V2() + shutil.rmtree(ef.DOWNLOAD_PATH, ignore_errors=True) # clean up any existing models + os.makedirs(ef.DOWNLOAD_PATH, exist_ok=True) + path = os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME) + with open(path, "wb") as f: # create invalid file to simulate partial download + f.write(b"invalid") + ef._download_model_if_not_exists() # re-download model + assert os.path.exists(path) + assert _verify_sha256( + str(os.path.join(ef.DOWNLOAD_PATH, ef.ARCHIVE_FILENAME)), + ef._MODEL_SHA256, + ) + assert len(ef(["test"])) == 1 diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 05a4d03887f..dcb2e6f410a 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -1,5 +1,8 @@ +import hashlib import logging +from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception + from chromadb.api.types import ( Document, Documents, @@ -31,6 +34,16 @@ logger = logging.getLogger(__name__) +def _verify_sha256(fname: str, expected_sha256: str) -> bool: + sha256_hash = hashlib.sha256() + with open(fname, "rb") as f: + # Read and update hash in chunks to avoid using too much memory + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + + return sha256_hash.hexdigest() == expected_sha256 + + class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): # Since we do dynamic imports we have to type this as Any models: Dict[str, Any] = {} @@ -346,6 +359,7 @@ class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): MODEL_DOWNLOAD_URL = ( "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" ) + _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3" tokenizer = None model = None @@ -389,6 +403,12 @@ def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 # Download with tqdm to preserve the sentence-transformers experience + @retry( + reraise=True, + stop=stop_after_attempt(3), + wait=wait_random(min=1, max=3), + retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)), + ) def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) @@ -402,6 +422,12 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: for data in resp.iter_content(chunk_size=chunk_size): size = file.write(data) bar.update(size) + if not _verify_sha256(fname, self._MODEL_SHA256): + # if the integrity of the file is not verified, remove it + os.remove(fname) + raise ValueError( + f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file." + ) # Use pytorches default epsilon for division by zero # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html @@ -503,6 +529,9 @@ def _download_model_if_not_exists(self) -> None: os.makedirs(self.DOWNLOAD_PATH, exist_ok=True) if not os.path.exists( os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME) + ) or not _verify_sha256( + os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + self._MODEL_SHA256, ): self._download( url=self.MODEL_DOWNLOAD_URL,