cosmosdbnosql: Added Cosmos DB NoSQL Semantic Cache Integration with tests and jupyter notebook (#24424)

* Added Cosmos DB NoSQL Semantic Cache Integration with tests and
jupyter notebook

---------

Co-authored-by: Aayush Kataria <[email protected]>
Co-authored-by: Chester Curme <[email protected]>
3 people authored Dec 17, 2024
1 parent 27a9056 commit cdf6202
Showing 6 changed files with 495 additions and 81 deletions.
155 changes: 149 additions & 6 deletions docs/docs/integrations/llm_caching.ipynb
@@ -14,8 +14,8 @@
},
{
"cell_type": "code",
"execution_count": 2,
"id": "88486f6f",
"execution_count": null,
"id": "f938e881",
"metadata": {},
"outputs": [],
"source": [
@@ -30,12 +30,12 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "10ad9224",
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-12T02:05:57.319706Z",
"start_time": "2024-04-12T02:05:57.303868Z"
"end_time": "2024-12-06T00:54:06.474593Z",
"start_time": "2024-12-06T00:53:58.727138Z"
}
},
"outputs": [],
@@ -1820,7 +1820,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": null,
"id": "bc1570a2a77b58c8",
"metadata": {
"ExecuteTime": {
@@ -1848,12 +1848,155 @@
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm.invoke(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "235ff73bf7143f13",
"metadata": {},
"source": [
"## Azure CosmosDB NoSql Semantic Cache\n",
"\n",
"You can use this integrated [vector database](https://learn.microsoft.com/en-us/azure/cosmos-db/vector-database) for caching."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41fea5aa7b2153ca",
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-06T00:55:38.648972Z",
"start_time": "2024-12-06T00:55:38.290541Z"
}
},
"outputs": [],
"source": [
"from typing import Any, Dict\n",
"\n",
"from azure.cosmos import CosmosClient, PartitionKey\n",
"from langchain_community.cache import AzureCosmosDBNoSqlSemanticCache\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"HOST = \"COSMOS_DB_URI\"\n",
"KEY = \"COSMOS_DB_KEY\"\n",
"\n",
"cosmos_client = CosmosClient(HOST, KEY)\n",
"\n",
"\n",
"def get_vector_indexing_policy() -> dict:\n",
" return {\n",
" \"indexingMode\": \"consistent\",\n",
" \"includedPaths\": [{\"path\": \"/*\"}],\n",
" \"excludedPaths\": [{\"path\": '/\"_etag\"/?'}],\n",
" \"vectorIndexes\": [{\"path\": \"/embedding\", \"type\": \"diskANN\"}],\n",
" }\n",
"\n",
"\n",
"def get_vector_embedding_policy() -> dict:\n",
" return {\n",
" \"vectorEmbeddings\": [\n",
" {\n",
" \"path\": \"/embedding\",\n",
" \"dataType\": \"float32\",\n",
" \"dimensions\": 1536,\n",
" \"distanceFunction\": \"cosine\",\n",
" }\n",
" ]\n",
" }\n",
"\n",
"\n",
"cosmos_container_properties_test = {\"partition_key\": PartitionKey(path=\"/id\")}\n",
"cosmos_database_properties_test: Dict[str, Any] = {}\n",
"\n",
"set_llm_cache(\n",
" AzureCosmosDBNoSqlSemanticCache(\n",
" cosmos_client=cosmos_client,\n",
" embedding=OpenAIEmbeddings(),\n",
" vector_embedding_policy=get_vector_embedding_policy(),\n",
" indexing_policy=get_vector_indexing_policy(),\n",
" cosmos_container_properties=cosmos_container_properties_test,\n",
" cosmos_database_properties=cosmos_database_properties_test,\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1e1cd93819921bf6",
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-06T00:55:44.513080Z",
"start_time": "2024-12-06T00:55:41.353843Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 374 ms, sys: 34.2 ms, total: 408 ms\n",
"Wall time: 3.15 s\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The first time, it is not yet in cache, so it should take longer\n",
"llm.invoke(\"Tell me a joke\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "576ce24c1244812a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-06T00:55:50.925865Z",
"start_time": "2024-12-06T00:55:50.548520Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 17.7 ms, sys: 2.88 ms, total: 20.6 ms\n",
"Wall time: 373 ms\n"
]
},
{
"data": {
"text/plain": [
"\"\\n\\nWhy couldn't the bicycle stand up by itself? Because it was two-tired!\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"# The second time it is, so it goes faster\n",
"llm.invoke(\"Tell me a joke\")"
]
},
{
"cell_type": "markdown",
"id": "306ff47b",
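A note on the notebook cell above: the "dimensions" value in get_vector_embedding_policy must match the embedding model that backs the cache. Below is a minimal sketch of parameterizing that policy; the dimensions argument and the alternative model shown are illustrative, not part of this commit.

from langchain_openai import OpenAIEmbeddings


def get_vector_embedding_policy(dimensions: int = 1536) -> dict:
    # 1536 matches the default OpenAIEmbeddings model (text-embedding-ada-002);
    # other models emit vectors of a different size.
    return {
        "vectorEmbeddings": [
            {
                "path": "/embedding",
                "dataType": "float32",
                "dimensions": dimensions,
                "distanceFunction": "cosine",
            }
        ]
    }


# text-embedding-3-large produces 3072-dimensional vectors by default,
# so the policy has to be built with the matching dimension count.
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
policy = get_vector_embedding_policy(dimensions=3072)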
114 changes: 111 additions & 3 deletions libs/community/langchain_community/cache.py
@@ -80,7 +80,10 @@
from langchain_community.utilities.astradb import (
_AstraDBCollectionEnvironment,
)
from langchain_community.vectorstores import AzureCosmosDBVectorSearch
from langchain_community.vectorstores import (
AzureCosmosDBNoSqlVectorSearch,
AzureCosmosDBVectorSearch,
)
from langchain_community.vectorstores import (
OpenSearchVectorSearch as OpenSearchVectorStore,
)
@@ -93,6 +96,7 @@
import momento
import pymemcache
from astrapy.db import AstraDB, AsyncAstraDB
from azure.cosmos.cosmos_client import CosmosClient
from cassandra.cluster import Session as CassandraSession


@@ -2103,7 +2107,7 @@ def __init__(
ef_construction: int = 64,
ef_search: int = 40,
score_threshold: Optional[float] = None,
application_name: str = "LANGCHAIN_CACHING_PYTHON",
application_name: str = "LangChain-CDBMongoVCore-SemanticCache-Python",
):
"""
Args:
@@ -2268,14 +2272,118 @@ def clear(self, **kwargs: Any) -> None:
index_name = self._index_name(kwargs["llm_string"])
if index_name in self._cache_dict:
self._cache_dict[index_name].get_collection().delete_many({})
# self._cache_dict[index_name].clear_collection()

@staticmethod
def _validate_enum_value(value: Any, enum_type: Type[Enum]) -> None:
if not isinstance(value, enum_type):
raise ValueError(f"Invalid enum value: {value}. Expected {enum_type}.")


class AzureCosmosDBNoSqlSemanticCache(BaseCache):
"""Cache that uses Cosmos DB NoSQL backend"""

def __init__(
self,
embedding: Embeddings,
cosmos_client: CosmosClient,
database_name: str = "CosmosNoSqlCacheDB",
container_name: str = "CosmosNoSqlCacheContainer",
*,
vector_embedding_policy: Dict[str, Any],
indexing_policy: Dict[str, Any],
cosmos_container_properties: Dict[str, Any],
cosmos_database_properties: Dict[str, Any],
create_container: bool = True,
):
self.cosmos_client = cosmos_client
self.database_name = database_name
self.container_name = container_name
self.embedding = embedding
self.vector_embedding_policy = vector_embedding_policy
self.indexing_policy = indexing_policy
self.cosmos_container_properties = cosmos_container_properties
self.cosmos_database_properties = cosmos_database_properties
self.create_container = create_container
self._cache_dict: Dict[str, AzureCosmosDBNoSqlVectorSearch] = {}

def _cache_name(self, llm_string: str) -> str:
hashed_index = _hash(llm_string)
return f"cache:{hashed_index}"

def _get_llm_cache(self, llm_string: str) -> AzureCosmosDBNoSqlVectorSearch:
cache_name = self._cache_name(llm_string)

# return vectorstore client for the specific llm string
if cache_name in self._cache_dict:
return self._cache_dict[cache_name]

# create new vectorstore client to create the cache
if self.cosmos_client:
self._cache_dict[cache_name] = AzureCosmosDBNoSqlVectorSearch(
cosmos_client=self.cosmos_client,
embedding=self.embedding,
vector_embedding_policy=self.vector_embedding_policy,
indexing_policy=self.indexing_policy,
cosmos_container_properties=self.cosmos_container_properties,
cosmos_database_properties=self.cosmos_database_properties,
database_name=self.database_name,
container_name=self.container_name,
create_container=self.create_container,
)

return self._cache_dict[cache_name]

def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
"""Look up based on prompt."""
llm_cache = self._get_llm_cache(llm_string)
generations: List = []
        # k=1 similarity search against the prompts cached for this llm_string
results = llm_cache.similarity_search(
query=prompt,
k=1,
)
if results:
for document in results:
try:
generations.extend(loads(document.metadata["return_val"]))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
"properly. This is likely due to the cache being in an "
"older format. Please recreate your cache to avoid this "
"error."
)

generations.extend(
_load_generations_from_json(document.metadata["return_val"])
)
return generations if generations else None

def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
"""Update cache based on prompt and llm_string."""
for gen in return_val:
if not isinstance(gen, Generation):
raise ValueError(
"CosmosDBNoSqlSemanticCache only supports caching of "
f"normal LLM generations, got {type(gen)}"
)
llm_cache = self._get_llm_cache(llm_string)
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": dumps([g for g in return_val]),
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

def clear(self, **kwargs: Any) -> None:
"""Clear semantic cache for a given llm_string."""
        cache_name = self._cache_name(llm_string=kwargs["llm_string"])
if cache_name in self._cache_dict:
            container = self._cache_dict[cache_name].get_container()
for item in container.read_all_items():
                # delete_item requires the partition key; the examples partition on "/id"
                container.delete_item(item, partition_key=item["id"])


class OpenSearchSemanticCache(BaseCache):
"""Cache that uses OpenSearch vector store backend"""

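To make the round trip in the new AzureCosmosDBNoSqlSemanticCache concrete, here is a minimal usage sketch. The client, policy helpers, and partition key are taken from the notebook cell above; the llm_string value is a placeholder, since it is normally the serialized LLM configuration supplied by LangChain's caching layer.

from azure.cosmos import PartitionKey
from langchain_community.cache import AzureCosmosDBNoSqlSemanticCache
from langchain_core.outputs import Generation
from langchain_openai import OpenAIEmbeddings

# Reuses cosmos_client and the two policy helpers defined in the notebook cell above.
cache = AzureCosmosDBNoSqlSemanticCache(
    cosmos_client=cosmos_client,
    embedding=OpenAIEmbeddings(),
    vector_embedding_policy=get_vector_embedding_policy(),
    indexing_policy=get_vector_indexing_policy(),
    cosmos_container_properties={"partition_key": PartitionKey(path="/id")},
    cosmos_database_properties={},
)

llm_string = "<serialized LLM configuration>"  # placeholder

# update() embeds the prompt and stores the serialized generations as item metadata
cache.update(
    "Tell me a joke",
    llm_string,
    [Generation(text="Why couldn't the bicycle stand up by itself? Because it was two-tired!")],
)

# lookup() runs a k=1 similarity search, so a semantically similar prompt
# can return the cached generations
hit = cache.lookup("Please tell me a joke", llm_string)

Note that lookup() takes the single nearest neighbor with no score threshold, so any prior entry stored under the same llm_string can be returned as a semantic hit.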
@@ -82,7 +82,7 @@ def __init__(
index_name: str = "vectorSearchIndex",
text_key: str = "textContent",
embedding_key: str = "vectorContent",
application_name: str = "LANGCHAIN_PYTHON",
application_name: str = "LangChain-CDBMongoVCore-VectorStore-Python",
):
"""Constructor for AzureCosmosDBVectorSearch
@@ -121,7 +121,7 @@ def from_connection_string(
connection_string: str,
namespace: str,
embedding: Embeddings,
application_name: str = "LANGCHAIN_PYTHON",
application_name: str = "LangChain-CDBMongoVCore-VectorStore-Python",
**kwargs: Any,
) -> AzureCosmosDBVectorSearch:
"""Creates an Instance of AzureCosmosDBVectorSearch
@@ -14,7 +14,7 @@
from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:
from azure.cosmos import CosmosClient
from azure.cosmos import ContainerProxy, CosmosClient
from azure.identity import DefaultAzureCredential

USER_AGENT = ("LangChain-CDBNoSql-VectorStore-Python",)
@@ -859,3 +859,6 @@ def _where_clause_operator_map(self) -> Dict[str, str]:
"$full_text_contains_any": "FullTextContainsAny",
}
return operator_map

def get_container(self) -> ContainerProxy:
return self._container
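The new get_container accessor is what lets the semantic cache's clear() reach the underlying Cosmos container and delete cached items directly. A minimal sketch of that usage, where vector_store stands for an AzureCosmosDBNoSqlVectorSearch instance and the "/id" partition key from the examples is assumed:

container = vector_store.get_container()  # azure.cosmos ContainerProxy
for item in container.read_all_items():
    # delete_item requires the partition key value; the examples partition on "/id"
    container.delete_item(item, partition_key=item["id"])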
