Skip to content

Commit

Permalink
testcases for community detection
Browse files Browse the repository at this point in the history
  • Loading branch information
JINO-ROHIT committed Jan 11, 2025
1 parent cccab83 commit 38d4db1
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import torch

from sentence_transformers import SentenceTransformer, util
from sentence_transformers.util import community_detection

import pytest

def test_normalize_embeddings() -> None:
"""Tests the correct computation of util.normalize_embeddings"""
Expand Down Expand Up @@ -145,3 +147,109 @@ def test_dot_score_cos_sim() -> None:

assert np.allclose(cosine_calculated, dot_and_cosine_expected)
assert np.allclose(dot_calculated, dot_and_cosine_expected)

def test_two_clear_communities():
    """Two well-separated groups of vectors are expected to form two communities."""
    # Points 0-2 lean towards the first axis, points 3-5 towards the second.
    first_group = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    second_group = [[0.1, 0.9, 0.0], [0.0, 1.0, 0.0], [0.2, 0.8, 0.0]]
    embeddings = torch.tensor(first_group + second_group)

    communities = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Compare order-insensitively: neither the order of the communities nor
    # the order of the indices inside each community matters to this test.
    found = sorted(sorted(members) for members in communities)
    wanted = sorted(sorted(members) for members in ([0, 1, 2], [3, 4, 5]))
    assert found == wanted

def test_no_communities_high_threshold():
    """Mutually orthogonal vectors and a 0.99 threshold are expected to yield no community."""
    # Three unit vectors along distinct axes: every pairwise cosine similarity is 0.
    axis_aligned = [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
    communities = community_detection(
        torch.tensor(axis_aligned),
        threshold=0.99,
        min_community_size=2,
    )
    assert communities == []

def test_all_points_in_one_community():
    """With a permissive 0.5 threshold, all three nearby points are expected in one community."""
    points = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    single_community = [[0, 1, 2]]

    detected = community_detection(torch.tensor(points), threshold=0.5, min_community_size=2)

    # Normalise both sides so ordering differences cannot fail the comparison.
    normalised_detected = sorted(sorted(group) for group in detected)
    normalised_expected = sorted(sorted(group) for group in single_community)
    assert normalised_detected == normalised_expected

def test_min_community_size_filtering():
    """A lone point must be dropped when ``min_community_size`` is 3."""
    clustered = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.8, 0.2, 0.0]]
    # Index 3 has no close neighbours, so it cannot reach the minimum size.
    isolated = [[0.1, 0.9, 0.0]]
    embeddings = torch.tensor(clustered + isolated)

    outcome = community_detection(embeddings, threshold=0.8, min_community_size=3)

    # Only the three clustered points are expected to survive the size filter.
    surviving = [[0, 1, 2]]
    assert sorted(map(sorted, outcome)) == sorted(map(sorted, surviving))

def test_overlapping_communities():
    """Each point is expected in exactly one community, even where candidate neighbourhoods overlap."""
    rows = [
        [1.0, 0.0, 0.0],  # index 0
        [0.9, 0.1, 0.0],  # index 1
        [0.8, 0.2, 0.0],  # index 2
        [0.7, 0.3, 0.0],  # index 3: outermost member of the first group
        [0.1, 0.9, 0.0],  # index 4
        [0.0, 1.0, 0.0],  # index 5
    ]
    embeddings = torch.tensor(rows)

    detected = community_detection(embeddings, threshold=0.8, min_community_size=2)

    # Expected partition: indices 0-3 together, indices 4-5 together.
    partition = [[0, 1, 2, 3], [4, 5]]
    canonical_detected = sorted([sorted(block) for block in detected])
    canonical_partition = sorted([sorted(block) for block in partition])
    assert canonical_detected == canonical_partition

def test_numpy_input():
    """``community_detection`` is expected to accept a numpy array as well as a torch tensor."""
    vectors = [
        [1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.8, 0.2, 0.0],
    ]
    # Deliberately a numpy array, not a tensor: this is the input type under test.
    embeddings = np.array(vectors)

    groups = community_detection(embeddings, threshold=0.8, min_community_size=2)

    wanted = [[0, 1, 2]]
    # Order-insensitive comparison of the detected and expected communities.
    assert sorted(sorted(g) for g in groups) == sorted(sorted(w) for w in wanted)

def test_large_batch_size() -> None:
    """Check structural invariants of the result on a larger input with a small ``batch_size``.

    The input has 1000 embeddings and ``batch_size=256``, so the input is
    larger than a single batch. The embeddings are drawn from a locally
    seeded generator so that the test is reproducible and does not touch
    the global torch RNG state used by other tests.
    """
    num_embeddings = 1000
    min_community_size = 10

    # A private, seeded generator keeps the "random" data identical on every run.
    generator = torch.Generator().manual_seed(0)
    embeddings = torch.rand(num_embeddings, 128, generator=generator)

    result = community_detection(
        embeddings,
        threshold=0.8,
        min_community_size=min_community_size,
        batch_size=256,
    )

    # Every returned community must meet the minimum size requirement.
    assert all(len(community) >= min_community_size for community in result)

    # Every member must be a valid row index of the input.
    members = [idx for community in result for idx in community]
    assert all(0 <= idx < num_embeddings for idx in members)

    # Communities are expected to be disjoint: no index may be reported twice.
    assert len(members) == len(set(members))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU not available")
def test_gpu_support():
    """Run the single-community scenario with CUDA-resident embeddings (skipped without a GPU)."""
    host_embeddings = torch.tensor(
        [
            [1.0, 0.0, 0.0],
            [0.9, 0.1, 0.0],
            [0.8, 0.2, 0.0],
        ]
    )
    # Move the data to the GPU before handing it to the function under test.
    device_embeddings = host_embeddings.cuda()

    detected = community_detection(device_embeddings, threshold=0.8, min_community_size=2)

    wanted = [[0, 1, 2]]
    detected_sorted = sorted([sorted(group) for group in detected])
    wanted_sorted = sorted([sorted(group) for group in wanted])
    assert detected_sorted == wanted_sorted

0 comments on commit 38d4db1

Please sign in to comment.