From 4fec4ca9f5a802dc2d4cf8dc67fd35f881c786e8 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Thu, 17 Oct 2024 13:16:59 +0200 Subject: [PATCH] Add tests for Static Embeddings --- tests/models/test_static_embedding.py | 67 +++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/models/test_static_embedding.py diff --git a/tests/models/test_static_embedding.py b/tests/models/test_static_embedding.py new file mode 100644 index 000000000..983cb4a8d --- /dev/null +++ b/tests/models/test_static_embedding.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +from tokenizers import Tokenizer + +from sentence_transformers.models.StaticEmbedding import StaticEmbedding + + +@pytest.fixture +def tokenizer() -> Tokenizer: + return Tokenizer.from_pretrained("bert-base-uncased") + + +@pytest.fixture +def embedding_weights(): + return np.random.rand(30522, 768) + + +@pytest.fixture +def static_embedding(tokenizer: Tokenizer, embedding_weights) -> StaticEmbedding: + return StaticEmbedding(tokenizer, embedding_weights=embedding_weights) + + +def test_initialization_with_embedding_weights(tokenizer: Tokenizer, embedding_weights) -> None: + model = StaticEmbedding(tokenizer, embedding_weights=embedding_weights) + assert model.embedding.weight.shape == (30522, 768) + + +def test_initialization_with_embedding_dim(tokenizer: Tokenizer) -> None: + model = StaticEmbedding(tokenizer, embedding_dim=768) + assert model.embedding.weight.shape == (30522, 768) + + +def test_tokenize(static_embedding: StaticEmbedding) -> None: + texts = ["Hello world!", "How are you?"] + tokens = static_embedding.tokenize(texts) + assert "input_ids" in tokens + assert "offsets" in tokens + + +def test_forward(static_embedding: StaticEmbedding) -> None: + texts = ["Hello world!", "How are you?"] + tokens = static_embedding.tokenize(texts) + output = static_embedding(tokens) + assert "sentence_embedding" in output + + +def test_save_and_load(tmp_path: Path, static_embedding: StaticEmbedding) -> None: + save_dir = tmp_path / "model" + save_dir.mkdir() + static_embedding.save(str(save_dir)) + + loaded_model = StaticEmbedding.load(str(save_dir)) + assert loaded_model.embedding.weight.shape == static_embedding.embedding.weight.shape + + +def test_from_distillation() -> None: + model = StaticEmbedding.from_distillation("sentence-transformers-testing/stsb-bert-tiny-safetensors", pca_dims=32) + assert model.embedding.weight.shape == (29528, 32) + + +def test_from_model2vec() -> None: + model = StaticEmbedding.from_model2vec("minishlab/M2V_base_output") + assert model.embedding.weight.shape == (29528, 256)