langgraph/how-tos/persistence_redis/ #1123
Replies: 18 comments 21 replies
-
I ran into an issue where my graph was not always grabbing the most recent checkpointed state. Upon investigation, I think there is an issue with the logic that grabs the latest key. Wouldn't we want to compare the entire timestamp string?
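(The commenter's snippet was lost in the page export. As a sketch of the idea, assuming keys of the form checkpoint$<thread_id>$<ISO timestamp>, the delimiter proposed later in this thread:)

```python
# Hypothetical example keys for one thread.
all_keys = [
    b"checkpoint$125$2024-07-26T11:49:46.662715+00:00",
    b"checkpoint$125$2024-07-26T11:50:02.101001+00:00",
]

# Comparing only a fragment of the key (e.g. whatever follows the last ":")
# picks an effectively arbitrary checkpoint. Comparing the entire timestamp
# works because ISO-8601 timestamps sort lexicographically.
latest_key = max(all_keys, key=lambda k: k.decode().split("$")[-1])
assert latest_key == all_keys[1]
```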
-
The one given in the above guide has a bug in the put/aput functions. I have changed the code a little bit, with the key delimiter changed from ":" to "$" (ISO-8601 timestamps contain ":" characters, which broke the original split logic):

```python
"""Implementation of a langgraph checkpoint saver using Redis."""
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator, List, Union, Tuple, Optional
import redis
from redis.asyncio import Redis as AsyncRedis, ConnectionPool as AsyncConnectionPool
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint import BaseCheckpointSaver
from langgraph.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.base import Checkpoint, CheckpointMetadata, CheckpointTuple
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class JsonAndBinarySerializer(JsonPlusSerializer):
def _default(self, obj: Any) -> Any:
if isinstance(obj, (bytes, bytearray)):
return self._encode_constructor_args(
obj.__class__, method="fromhex", args=[obj.hex()]
)
return super()._default(obj)
def dumps(self, obj: Any) -> str:
try:
if isinstance(obj, (bytes, bytearray)):
return obj.hex()
return super().dumps(obj)
except Exception as e:
logger.error(f"Serialization error: {e}")
raise
def loads(self, s: str, is_binary: bool = False) -> Any:
try:
if is_binary:
return bytes.fromhex(s)
return super().loads(s)
except Exception as e:
logger.error(f"Deserialization error: {e}")
raise
def initialize_sync_pool(
host: str = "localhost", port: int = 6379, db: int = 0, **kwargs
) -> redis.ConnectionPool:
"""Initialize a synchronous Redis connection pool."""
try:
pool = redis.ConnectionPool(host=host, port=port, db=db, **kwargs)
logger.info(
f"Synchronous Redis pool initialized with host={host}, port={port}, db={db}"
)
return pool
except Exception as e:
logger.error(f"Error initializing sync pool: {e}")
raise
def initialize_async_pool(
url: str = "redis://localhost", **kwargs
) -> AsyncConnectionPool:
"""Initialize an asynchronous Redis connection pool."""
try:
pool = AsyncConnectionPool.from_url(url, **kwargs)
logger.info(f"Asynchronous Redis pool initialized with url={url}")
return pool
except Exception as e:
logger.error(f"Error initializing async pool: {e}")
raise
@contextmanager
def _get_sync_connection(
connection: Union[redis.Redis, redis.ConnectionPool, None]
) -> Generator[redis.Redis, None, None]:
conn = None
try:
if isinstance(connection, redis.Redis):
yield connection
elif isinstance(connection, redis.ConnectionPool):
conn = redis.Redis(connection_pool=connection)
yield conn
else:
raise ValueError("Invalid sync connection object.")
except redis.ConnectionError as e:
logger.error(f"Sync connection error: {e}")
raise
finally:
if conn:
conn.close()
@asynccontextmanager
async def _get_async_connection(
connection: Union[AsyncRedis, AsyncConnectionPool, None]
) -> AsyncGenerator[AsyncRedis, None]:
conn = None
try:
if isinstance(connection, AsyncRedis):
yield connection
elif isinstance(connection, AsyncConnectionPool):
conn = AsyncRedis(connection_pool=connection)
yield conn
else:
raise ValueError("Invalid async connection object.")
except redis.ConnectionError as e:
logger.error(f"Async connection error: {e}")
raise
finally:
if conn:
await conn.aclose()
class RedisSaver(BaseCheckpointSaver):
sync_connection: Optional[Union[redis.Redis, redis.ConnectionPool]] = None
async_connection: Optional[Union[AsyncRedis, AsyncConnectionPool]] = None
def __init__(
self,
sync_connection: Optional[Union[redis.Redis, redis.ConnectionPool]] = None,
async_connection: Optional[Union[AsyncRedis, AsyncConnectionPool]] = None,
):
super().__init__(serde=JsonAndBinarySerializer())
self.sync_connection = sync_connection
self.async_connection = async_connection
def put(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
) -> RunnableConfig:
thread_id = config["configurable"]["thread_id"]
parent_ts = config["configurable"].get("thread_ts")
key = f"checkpoint${thread_id}${checkpoint['ts']}"
try:
with _get_sync_connection(self.sync_connection) as conn:
conn.hset(
key,
mapping={
"checkpoint": self.serde.dumps(checkpoint),
"metadata": self.serde.dumps(metadata),
"parent_ts": parent_ts if parent_ts else "",
},
)
logger.info(
f"Checkpoint stored successfully for thread_id: {thread_id}, ts: {checkpoint['ts']}"
)
except Exception as e:
logger.error(f"Failed to put checkpoint: {e}")
raise
return {
"configurable": {
"thread_id": thread_id,
"thread_ts": checkpoint["ts"],
},
}
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
) -> RunnableConfig:
thread_id = config["configurable"]["thread_id"]
parent_ts = config["configurable"].get("thread_ts")
key = f"checkpoint${thread_id}${checkpoint['ts']}"
try:
async with _get_async_connection(self.async_connection) as conn:
await conn.hset(
key,
mapping={
"checkpoint": self.serde.dumps(checkpoint),
"metadata": self.serde.dumps(metadata),
"parent_ts": parent_ts if parent_ts else "",
},
)
logger.info(
f"Checkpoint stored successfully for thread_id: {thread_id}, ts: {checkpoint['ts']}"
)
except Exception as e:
logger.error(f"Failed to aput checkpoint: {e}")
raise
return {
"configurable": {
"thread_id": thread_id,
"thread_ts": checkpoint["ts"],
},
}
def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
thread_id = config["configurable"]["thread_id"]
thread_ts = config["configurable"].get("thread_ts", None)
try:
with _get_sync_connection(self.sync_connection) as conn:
if thread_ts:
key = f"checkpoint${thread_id}${thread_ts}"
else:
all_keys = conn.keys(f"checkpoint${thread_id}$*")
if not all_keys:
logger.info(f"No checkpoints found for thread_id: {thread_id}")
return None
                    # Keys look like checkpoint$125$2024-07-26T11:49:46.662715+00:00;
                    # get the one with the latest timestamp (ISO-8601 sorts lexicographically)
latest_key = max(all_keys, key=lambda k: k.decode().split("$")[-1])
key = latest_key.decode()
checkpoint_data = conn.hgetall(key)
if not checkpoint_data:
logger.info(f"No valid checkpoint data found for key: {key}")
return None
checkpoint = self.serde.loads(checkpoint_data[b"checkpoint"].decode())
metadata = self.serde.loads(checkpoint_data[b"metadata"].decode())
parent_ts = checkpoint_data.get(b"parent_ts", b"").decode()
parent_config = (
{"configurable": {"thread_id": thread_id, "thread_ts": parent_ts}}
if parent_ts
else None
)
logger.info(
f"Checkpoint retrieved successfully for thread_id: {thread_id}, ts: {thread_ts}"
)
return CheckpointTuple(
config=config,
checkpoint=checkpoint,
metadata=metadata,
parent_config=parent_config,
)
except Exception as e:
logger.error(f"Failed to get checkpoint tuple: {e}")
raise
async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
thread_id = config["configurable"]["thread_id"]
thread_ts = config["configurable"].get("thread_ts", None)
try:
async with _get_async_connection(self.async_connection) as conn:
if thread_ts:
key = f"checkpoint${thread_id}${thread_ts}"
else:
all_keys = await conn.keys(f"checkpoint${thread_id}$*")
if not all_keys:
logger.info(f"No checkpoints found for thread_id: {thread_id}")
return None
latest_key = max(all_keys, key=lambda k: k.decode().split("$")[-1])
key = latest_key.decode()
checkpoint_data = await conn.hgetall(key)
if not checkpoint_data:
logger.info(f"No valid checkpoint data found for key: {key}")
return None
checkpoint = self.serde.loads(checkpoint_data[b"checkpoint"].decode())
metadata = self.serde.loads(checkpoint_data[b"metadata"].decode())
parent_ts = checkpoint_data.get(b"parent_ts", b"").decode()
parent_config = (
{"configurable": {"thread_id": thread_id, "thread_ts": parent_ts}}
if parent_ts
else None
)
logger.info(
f"Checkpoint retrieved successfully for thread_id: {thread_id}, ts: {thread_ts}"
)
return CheckpointTuple(
config=config,
checkpoint=checkpoint,
metadata=metadata,
parent_config=parent_config,
)
except Exception as e:
logger.error(f"Failed to get checkpoint tuple: {e}")
raise
def list(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> Generator[CheckpointTuple, None, None]:
thread_id = config["configurable"]["thread_id"] if config else "*"
pattern = f"checkpoint:{thread_id}:*"
try:
with _get_sync_connection(self.sync_connection) as conn:
keys = conn.keys(pattern)
if before:
keys = [
k
for k in keys
if k.decode().split("$")[-1]
< before["configurable"]["thread_ts"]
]
keys = sorted(
keys, key=lambda k: k.decode().split("$")[-1], reverse=True
)
if limit:
keys = keys[:limit]
for key in keys:
data = conn.hgetall(key)
if data and "checkpoint" in data and "metadata" in data:
thread_ts = key.decode().split("$")[-1]
yield CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"thread_ts": thread_ts,
}
},
                            checkpoint=self.serde.loads(data[b"checkpoint"].decode()),
                            metadata=self.serde.loads(data[b"metadata"].decode()),
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"thread_ts": data.get(
"parent_ts", b""
).decode(),
}
}
if data.get("parent_ts")
else None
),
)
logger.info(
f"Checkpoint listed for thread_id: {thread_id}, ts: {thread_ts}"
)
except Exception as e:
logger.error(f"Failed to list checkpoints: {e}")
raise
async def alist(
self,
config: Optional[RunnableConfig],
*,
filter: Optional[dict[str, Any]] = None,
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
) -> AsyncGenerator[CheckpointTuple, None]:
thread_id = config["configurable"]["thread_id"] if config else "*"
pattern = f"checkpoint:{thread_id}:*"
try:
async with _get_async_connection(self.async_connection) as conn:
keys = await conn.keys(pattern)
if before:
keys = [
k
for k in keys
if k.decode().split("$")[-1]
< before["configurable"]["thread_ts"]
]
keys = sorted(
keys, key=lambda k: k.decode().split("$")[-1], reverse=True
)
if limit:
keys = keys[:limit]
for key in keys:
data = await conn.hgetall(key)
if data and "checkpoint" in data and "metadata" in data:
thread_ts = key.decode().split("$")[-1]
yield CheckpointTuple(
config={
"configurable": {
"thread_id": thread_id,
"thread_ts": thread_ts,
}
},
                            checkpoint=self.serde.loads(data[b"checkpoint"].decode()),
                            metadata=self.serde.loads(data[b"metadata"].decode()),
parent_config=(
{
"configurable": {
"thread_id": thread_id,
"thread_ts": data.get(
"parent_ts", b""
).decode(),
}
}
if data.get("parent_ts")
else None
),
)
logger.info(
f"Checkpoint listed for thread_id: {thread_id}, ts: {thread_ts}"
)
except Exception as e:
logger.error(f"Failed to list checkpoints: {e}")
            raise
```
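Not part of the comment above, but for anyone wiring this in: a minimal usage sketch, assuming the class and pool helper are importable as written (the toy graph is a placeholder):

```python
from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class State(TypedDict):
    value: str

def node(state: State) -> dict:
    return {"value": state["value"] + "!"}

builder = StateGraph(State)
builder.add_node("node", node)
builder.add_edge(START, "node")
builder.add_edge("node", END)

# Reuse one pool across invocations; the saver opens a connection per call.
pool = initialize_sync_pool(host="localhost", port=6379, db=0)
graph = builder.compile(checkpointer=RedisSaver(sync_connection=pool))

config = {"configurable": {"thread_id": "thread-1"}}
print(graph.invoke({"value": "hi"}, config))
```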
-
The implementation is missing the put_writes and aput_writes functions.
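Not from the comment, but for anyone hitting this: a minimal sketch of what these methods could look like on the RedisSaver above. The signatures follow BaseCheckpointSaver.put_writes; the "writes$..." key scheme is my own assumption, not from the guide (add `from typing import Sequence` to the imports):

```python
# Sketch: methods to add to the RedisSaver class above.
def put_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[Tuple[str, Any]],
    task_id: str,
) -> None:
    thread_id = config["configurable"]["thread_id"]
    thread_ts = config["configurable"]["thread_ts"]
    key = f"writes${thread_id}${thread_ts}${task_id}"
    with _get_sync_connection(self.sync_connection) as conn:
        conn.hset(
            key,
            mapping={
                # One hash field per pending write: "<index>$<channel>".
                f"{idx}${channel}": self.serde.dumps(value)
                for idx, (channel, value) in enumerate(writes)
            },
        )

async def aput_writes(
    self,
    config: RunnableConfig,
    writes: Sequence[Tuple[str, Any]],
    task_id: str,
) -> None:
    thread_id = config["configurable"]["thread_id"]
    thread_ts = config["configurable"]["thread_ts"]
    key = f"writes${thread_id}${thread_ts}${task_id}"
    async with _get_async_connection(self.async_connection) as conn:
        await conn.hset(
            key,
            mapping={
                f"{idx}${channel}": self.serde.dumps(value)
                for idx, (channel, value) in enumerate(writes)
            },
        )
```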
-
Hi all! The guide has been updated with the correct implementations.
-
Is it possible to use custom serializable objects in the state, or do they have to be specific LangChain objects? I.e., I made an Action class like this:
Obviously the namespace doesn't exist, but is there a way to make it work in my graph state memory?
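(Not from the comment, but as a sketch of what usually works: the JsonPlusSerializer the saver above subclasses can round-trip plain dataclasses and pydantic models, as long as the class is importable at the same module path when loads() runs, which is likely what the "namespace doesn't exist" error is about. The Action class here is a hypothetical stand-in; behavior may vary by langgraph version.)

```python
from dataclasses import dataclass
from langgraph.serde.jsonplus import JsonPlusSerializer

# Hypothetical stand-in for the Action class in the question. Defining it
# at module level matters: the serializer records the class's import path
# and re-imports it on loads(), so a class defined inside a function or a
# throwaway notebook cell cannot be reconstructed.
@dataclass
class Action:
    name: str
    payload: dict

serde = JsonPlusSerializer()
restored = serde.loads(serde.dumps(Action(name="search", payload={"q": "redis"})))
assert isinstance(restored, Action)
```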
-
When using langserve and langgraph, I have the following code:
-
I'm having issues with the size of the checkpoint exceeding the allowed size in Firestore. Is there any way of limiting or splitting the updates?
-
Are there any issues with creating a custom persistence solution outside the checkpoint protocol, where I manually populate the state from some data source at the start node and save it at the end node or some other node?
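(Not from the comment, but a sketch of the pattern being described; the key name, state shape, and node names are assumptions for illustration:)

```python
import json
from typing import TypedDict

import redis
from langgraph.graph import StateGraph, START, END

r = redis.Redis()

class State(TypedDict):
    session_id: str
    messages: list

def load_state(state: State) -> dict:
    # Hypothetical start node: hydrate state from an external store.
    raw = r.get(f"session${state['session_id']}")
    return {"messages": json.loads(raw) if raw else []}

def save_state(state: State) -> dict:
    # Hypothetical end node: persist state back to the external store.
    r.set(f"session${state['session_id']}", json.dumps(state["messages"]))
    return {}

builder = StateGraph(State)
builder.add_node("load", load_state)
builder.add_node("save", save_state)
builder.add_edge(START, "load")
# ... application nodes would sit between load and save ...
builder.add_edge("load", "save")
builder.add_edge("save", END)
graph = builder.compile()  # no checkpointer: persistence handled manually
```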
-
It's not clear to me whether the memory keeps the whole graph or whether I need to rebuild it from scratch for a stateless application. If I keep only the thread, do I have to recompile the whole graph, or can I retrieve the whole workflow from memory?
-
Can anyone please add a memory checkpointer to the agentic RAG tutorial in langgraph?
-
The code works well, but I have a question about the following part:
From the context of the function, it seems like this newly assigned config isn't used afterward. Is this code redundant, and can it be safely removed? Or am I missing something?
-
I literally copied the tutorial code (https://langchain-ai.github.io/langgraph/how-tos/human_in_the_loop/time-travel/#build-the-agent) into a Jupyter notebook recently, just to check whether it was Streamlit interfering with MemorySaver() and preventing the agent from storing memory. But when I run it, it loses all memory even with the exact same code, and has no previous context. Does MemorySaver not work, or am I doing something wrong?
-
Is there any way we can define that one as a global checkpointer? async with AsyncRedisSaver.from_conn_info(
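(Not an official pattern, just a sketch: since AsyncRedisSaver.from_conn_info in the updated guide is an async context manager, one way to get a single long-lived instance is to enter it once at application startup and close it at shutdown, e.g. via contextlib.AsyncExitStack. The startup/shutdown hooks here are placeholders for your framework's lifespan events.)

```python
from contextlib import AsyncExitStack

_stack = AsyncExitStack()
checkpointer = None  # module-level, shared by all graph invocations

async def startup() -> None:
    # Enter the saver's context once and keep it open for the app's lifetime.
    global checkpointer
    checkpointer = await _stack.enter_async_context(
        AsyncRedisSaver.from_conn_info(host="localhost", port=6379, db=0)
    )

async def shutdown() -> None:
    # Release the Redis connections when the app stops.
    await _stack.aclose()
```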
-
I am getting:
-
Can someone please help with how to add delete-checkpoint functionality to the above code?
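(Not from the comment, but a minimal sketch against the key scheme used in this thread, checkpoint$<thread_id>$<ts>; the method name delete_thread is my own:)

```python
def delete_thread(self, thread_id: str) -> int:
    """Sketch: delete every checkpoint stored for a thread.

    Returns the number of keys removed. Uses scan_iter rather than
    KEYS to avoid blocking Redis on large datasets.
    """
    deleted = 0
    with _get_sync_connection(self.sync_connection) as conn:
        for key in conn.scan_iter(match=f"checkpoint${thread_id}$*"):
            deleted += conn.delete(key)
    return deleted
```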
-
How do I update the checkpoint with Redis? E.g., the history messages are [mes1, mes2, mes3, mes4, mes5, mes6, mes7, mes8, mes9, mes10, mes11, mes12], and my state_modifier keeps the last 10 messages [mes3, mes4, mes5, mes6, mes7, mes8, mes9, mes10, mes11, mes12]. After agent.invoke, agent.get_state shows the whole history [mes1, mes2, mes3, mes4, mes5, mes6, mes7, mes8, mes9, mes10, mes11, mes12], but I want [mes3, mes4, mes5, mes6, mes7, mes8, mes9, mes10, mes11, mes12]. I think it is because the checkpoint has not been updated.
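(Not from the comment, but for context: a state_modifier only trims what the model sees, not what is checkpointed. One documented way to actually shrink the stored state is to return RemoveMessage objects from a node; the add_messages reducer treats them as deletions. A sketch, with a hypothetical node name:)

```python
from langchain_core.messages import RemoveMessage

def trim_history(state):
    # Emit RemoveMessage for all but the last 10 messages. The add_messages
    # reducer deletes them from the checkpointed state, so the stored
    # history shrinks instead of just the prompt.
    return {"messages": [RemoveMessage(id=m.id) for m in state["messages"][:-10]]}
```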
-
When the graph is executed, the MemorySaver keeps growing. What are some solutions for implementing a memory eviction mechanism?
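(Not from the comment: with the Redis saver in this thread, one simple eviction approach is a TTL on each checkpoint key so Redis expires old threads automatically. A sketch with plain redis-py; the one-day retention window is an assumption:)

```python
import redis

r = redis.Redis()
ttl_seconds = 60 * 60 * 24  # hypothetical retention window: one day

# Same hash layout as the saver's put() above.
key = "checkpoint$thread-1$2024-07-26T11:49:46.662715+00:00"
r.hset(key, mapping={"checkpoint": "...", "metadata": "...", "parent_ts": ""})
r.expire(key, ttl_seconds)  # Redis deletes the key after the TTL elapses
```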
-
These checkpointers don't work with subgraph state. I took the RedisSaver here, put it at the top of the subgraph state management tutorial, replaced the MemorySaver there with this RedisSaver, and I get the following error:
From the "Resuming from breakpoints" section, the third cell: state = graph.get_state(config, subgraphs=True)
state.tasks[0] got 5 values. It looks like handling subgraphs correctly requires handling a sequence of checkpoint IDs instead of a single one?
-
langgraph/how-tos/persistence_redis/
https://langchain-ai.github.io/langgraph/how-tos/persistence_redis/