Skip to content

Commit

Permalink
Add tests for Static Embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Oct 17, 2024
1 parent 66e3e07 commit 4fec4ca
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions tests/models/test_static_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from tokenizers import Tokenizer

from sentence_transformers.models.StaticEmbedding import StaticEmbedding


@pytest.fixture
def tokenizer() -> Tokenizer:
    """Provide the pretrained ``bert-base-uncased`` tokenizer to a test."""
    pretrained_name = "bert-base-uncased"
    return Tokenizer.from_pretrained(pretrained_name)


@pytest.fixture
def embedding_weights() -> np.ndarray:
    """Provide a reproducible random ``(30522, 768)`` embedding matrix.

    The values come from a fixed-seed generator rather than the global NumPy
    random state, so every run (and every test using this fixture) sees the
    same weights and a failure can be reproduced exactly.

    Returns:
        A float array of shape ``(30522, 768)`` with values in ``[0, 1)``.
    """
    rng = np.random.default_rng(seed=0)
    return rng.random((30522, 768))


@pytest.fixture
def static_embedding(tokenizer: Tokenizer, embedding_weights) -> StaticEmbedding:
    """Build a ``StaticEmbedding`` from the ``tokenizer`` and ``embedding_weights`` fixtures."""
    module = StaticEmbedding(tokenizer, embedding_weights=embedding_weights)
    return module


def test_initialization_with_embedding_weights(tokenizer: Tokenizer, embedding_weights) -> None:
    """Constructing from explicit weights yields an embedding of the expected shape."""
    expected_shape = (30522, 768)
    module = StaticEmbedding(tokenizer, embedding_weights=embedding_weights)
    actual_shape = module.embedding.weight.shape
    assert actual_shape == expected_shape


def test_initialization_with_embedding_dim(tokenizer: Tokenizer) -> None:
    """Constructing from only ``embedding_dim`` yields an embedding of the expected shape."""
    expected_shape = (30522, 768)
    module = StaticEmbedding(tokenizer, embedding_dim=768)
    actual_shape = module.embedding.weight.shape
    assert actual_shape == expected_shape


def test_tokenize(static_embedding: StaticEmbedding) -> None:
    """Tokenizing a batch of texts returns both ``input_ids`` and ``offsets``."""
    batch = ["Hello world!", "How are you?"]
    features = static_embedding.tokenize(batch)
    # Same two membership checks, in the same order, expressed as a loop
    for required_key in ("input_ids", "offsets"):
        assert required_key in features


def test_forward(static_embedding: StaticEmbedding) -> None:
    """Calling the module on tokenized texts produces a ``sentence_embedding`` entry."""
    batch = ["Hello world!", "How are you?"]
    features = static_embedding.tokenize(batch)
    result = static_embedding(features)
    assert "sentence_embedding" in result


def test_save_and_load(tmp_path: Path, static_embedding: StaticEmbedding) -> None:
    """Saving a model and loading it back preserves the embedding weights.

    Args:
        tmp_path: pytest-provided temporary directory, unique to this test.
        static_embedding: The model under test, built by the fixture of the same name.
    """
    save_dir = tmp_path / "model"
    save_dir.mkdir()
    static_embedding.save(str(save_dir))

    loaded_model = StaticEmbedding.load(str(save_dir))
    original_weight = static_embedding.embedding.weight
    loaded_weight = loaded_model.embedding.weight
    assert loaded_weight.shape == original_weight.shape
    # A shape check alone would still pass if loading re-initialised the
    # weights, so also compare the values themselves.
    # NOTE(review): assumes `embedding.weight` is a torch tensor (``Tensor.equal``) - confirm.
    assert loaded_weight.equal(original_weight)


def test_from_distillation() -> None:
    """Distilling the tiny test model with ``pca_dims=32`` gives the expected weight shape."""
    source_model = "sentence-transformers-testing/stsb-bert-tiny-safetensors"
    distilled = StaticEmbedding.from_distillation(source_model, pca_dims=32)
    expected_shape = (29528, 32)
    assert distilled.embedding.weight.shape == expected_shape


def test_from_model2vec() -> None:
    """Loading the ``minishlab/M2V_base_output`` model2vec checkpoint gives the expected weight shape."""
    checkpoint = "minishlab/M2V_base_output"
    converted = StaticEmbedding.from_model2vec(checkpoint)
    expected_shape = (29528, 256)
    assert converted.embedding.weight.shape == expected_shape

0 comments on commit 4fec4ca

Please sign in to comment.