Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add and fix unit tests #8

Merged
merged 7 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
name: Run Unit Tests

on:
push:
branches:
- main
- master
pull_request:
branches:
- main
- master

jobs:
test:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11

- name: Install Poetry
run: |
python -m pip install --upgrade pip
curl -sSL https://install.python-poetry.org | python3 -
export PATH="$HOME/.local/bin:$PATH"

- name: Install dependencies
run: |
poetry install

- name: Run tests
run: |
poetry run pytest
2 changes: 1 addition & 1 deletion memonto/core/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def configure_model(model_provider: str, **config) -> LLMModel:
raise ValueError(f"LLM model {model_provider} not found")


def configure(
def _configure(
config: dict,
) -> Tuple[Optional[TripleStoreModel], Optional[LLMModel], Optional[VectorStoreModel]]:
triple_store, vector_store, llm = None, None, None
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions memonto/memonto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from rdflib import Graph, Namespace, URIRef
from typing import Optional, Union

from memonto.core.configure import configure
from memonto.core.configure import _configure
from memonto.core.init import init
from memonto.core.forget import _forget
from memonto.core.query import _retrieve
from memonto.core.retrieve import _retrieve
from memonto.core.recall import _recall
from memonto.core.remember import _remember
from memonto.core.render import _render
Expand Down Expand Up @@ -75,7 +75,7 @@ def configure(self, config: dict) -> None:

:return: None
"""
self.triple_store, self.vector_store, self.llm = configure(config=config)
self.triple_store, self.vector_store, self.llm = _configure(config=config)

@require_config("llm", "triple_store")
def retain(self, message: str) -> None:
Expand Down
120 changes: 52 additions & 68 deletions tests/core/test_configure.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
from pydantic import ValidationError
from unittest.mock import MagicMock

from memonto.core.configure import configure
from memonto.core.configure import _configure
from memonto.llms.openai import OpenAI
from memonto.llms.anthropic import Anthropic
from memonto.stores.jena import ApacheJena
from memonto.stores.triple.jena import ApacheJena
from memonto.stores.vector.chroma import Chroma


@pytest.fixture
Expand Down Expand Up @@ -34,7 +34,7 @@ def api_key():


@pytest.fixture
def store_provider():
def triple_store_provider():
return "apache_jena"


Expand All @@ -44,36 +44,26 @@ def jena_url():


@pytest.fixture
def mock_memonto():
class MockMemonto:
def __init__(self):
self.store = None
self.llm = None
def vector_store_provider():
return "chroma"

return MagicMock(spec=MockMemonto)


def test_configure_with_unsupported_provider(mock_memonto, api_key):
def test_configure_with_unsupported_provider(api_key):
config = {
"model": {
"provider": "random_model_provider",
"config": {
"model": "randome_model_name",
"model": "random_model_name",
"api_key": api_key,
},
}
}

with pytest.raises(ValueError):
configure(self=mock_memonto, config=config)
_configure(config)


def test_configure_with_openai_config(
mock_memonto,
openai_provider,
openai_model,
api_key,
):
def test_configure_with_openai_config(openai_provider, openai_model, api_key):
config = {
"model": {
"provider": openai_provider,
Expand All @@ -84,40 +74,28 @@ def test_configure_with_openai_config(
}
}

configure(self=mock_memonto, config=config)

assert isinstance(mock_memonto.llm, OpenAI)
assert mock_memonto.llm.model == openai_model
assert mock_memonto.llm.api_key == api_key
ts, vs, llm = _configure(config)

assert isinstance(llm, OpenAI)
assert ts is None
assert vs is None

def test_configure_with_bad_openai_config(
mock_memonto,
openai_provider,
openai_model,
api_key,
):

def test_configure_with_bad_openai_config(openai_provider, api_key):
config = {
"model": {
"provider": openai_provider,
"model": openai_model,
"config": {
"api_key": api_key,
},
}
}

with pytest.raises(ValidationError):
configure(self=mock_memonto, config=config)
_configure(config)


def test_configure_with_anthropic_config(
mock_memonto,
anthropic_provider,
anthropic_model,
api_key,
):
def test_configure_with_anthropic_config(anthropic_provider, anthropic_model, api_key):
config = {
"model": {
"provider": anthropic_provider,
Expand All @@ -128,18 +106,14 @@ def test_configure_with_anthropic_config(
}
}

configure(self=mock_memonto, config=config)
ts, vs, llm = _configure(config)

assert isinstance(mock_memonto.llm, Anthropic)
assert mock_memonto.llm.model == anthropic_model
assert mock_memonto.llm.api_key == api_key
assert isinstance(llm, Anthropic)
assert ts is None
assert vs is None


def test_configure_with_bad_anthropic_config(
mock_memonto,
anthropic_provider,
anthropic_model,
):
def test_configure_with_bad_anthropic_config(anthropic_provider, anthropic_model):
config = {
"model": {
"provider": anthropic_provider,
Expand All @@ -150,43 +124,53 @@ def test_configure_with_bad_anthropic_config(
}

with pytest.raises(ValidationError):
configure(self=mock_memonto, config=config)
_configure(config)


def test_configure_with_apache_jena_config(
mock_memonto,
store_provider,
jena_url,
):
def test_configure_with_apache_jena_config(triple_store_provider, jena_url):
config = {
"store": {
"provider": store_provider,
"triple_store": {
"provider": triple_store_provider,
"config": {
"connection_url": jena_url,
},
},
}

configure(self=mock_memonto, config=config)
ts, vs, llm = _configure(config)

assert isinstance(mock_memonto.store, ApacheJena)
assert mock_memonto.store.name == store_provider
assert mock_memonto.store.connection_url == jena_url
assert isinstance(ts, ApacheJena)
assert vs is None
assert llm is None


def test_configure_with_bad_apache_jena_config(
mock_memonto,
store_provider,
jena_url,
):
def test_configure_with_bad_apache_jena_config(triple_store_provider, jena_url):
config = {
"store": {
"provider": store_provider,
"triple_store": {
"provider": triple_store_provider,
"config": {
"url": jena_url,
},
},
}

with pytest.raises(ValidationError):
configure(self=mock_memonto, config=config)
_configure(config)


def test_configure_with_chroma_config(vector_store_provider, jena_url):
config = {
"vector_store": {
"provider": vector_store_provider,
"config": {
"model": "local",
"path": ".local",
},
},
}

ts, vs, llm = _configure(config)

assert isinstance(vs, Chroma)
assert ts is None
assert llm is None
80 changes: 80 additions & 0 deletions tests/core/test_recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
from rdflib import Graph
from unittest.mock import ANY, MagicMock, patch

from memonto.core.recall import _recall


@pytest.fixture
def graph():
return Graph()


@pytest.fixture
def user_query():
return "some user query about Bismark"


@pytest.fixture
def id():
return "test-id-123"


@pytest.fixture
def mock_llm():
mock_llm = MagicMock()
mock_llm.prompt = MagicMock(return_value="some summary")
return mock_llm


@pytest.fixture
def mock_store():
mock_store = MagicMock()
return mock_store


@patch("memonto.core.recall._find_all")
def test_fetch_all_memory(mock_find_all, mock_llm, mock_store, id):
all_memory = "all memory"
mock_find_all.return_value = all_memory

_recall(
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
message=None,
id=id,
)

mock_llm.prompt.assert_called_once_with(
prompt_name="summarize_memory",
memory=all_memory,
)


@patch("memonto.core.recall._find_adjacent_triples")
@patch("memonto.core.recall._hydrate_triples")
def test_fetch_some_memory(
mock_hydrate_triples,
mock_find_adjacent_triples,
mock_llm,
mock_store,
user_query,
id,
):
some_memory = "some memory"
mock_find_adjacent_triples.return_value = some_memory
mock_hydrate_triples.return_value = []

_recall(
llm=mock_llm,
vector_store=mock_store,
triple_store=mock_store,
message=user_query,
id=id,
)

mock_llm.prompt.assert_called_once_with(
prompt_name="summarize_memory",
memory=some_memory,
)
25 changes: 0 additions & 25 deletions tests/core/test_remember.py

This file was deleted.

Loading
Loading