From 02f5962cf17459350d0b546c3fba5a458e587848 Mon Sep 17 00:00:00 2001
From: Eugene Yurtsev
Date: Thu, 26 Sep 2024 15:21:10 -0400
Subject: [PATCH] docs: add api references to langgraph (#26877)

Add api references to langgraph
---
 docs/scripts/generate_api_reference_links.py | 324 +++++++++++++++----
 docs/vercel_requirements.txt | 1 +
 2 files changed, 262 insertions(+), 63 deletions(-)

diff --git a/docs/scripts/generate_api_reference_links.py b/docs/scripts/generate_api_reference_links.py
index 39fe3329fac01..2f94390838bc1 100644
--- a/docs/scripts/generate_api_reference_links.py
+++ b/docs/scripts/generate_api_reference_links.py
@@ -6,21 +6,143 @@
 import os
 import re
 from pathlib import Path
+from typing import List, Literal, Optional
+
+from typing_extensions import TypedDict

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 # Base URL for all class documentation
-_BASE_URL = "https://python.langchain.com/api_reference/"
+_LANGCHAIN_API_REFERENCE = "https://python.langchain.com/api_reference/"
+_LANGGRAPH_API_REFERENCE = "https://langchain-ai.github.io/langgraph/reference/"

 # Regular expression to match Python code blocks
 code_block_re = re.compile(r"^(```\s?python\n)(.*?)(```)", re.DOTALL | re.MULTILINE)
+
+
+MANUAL_API_REFERENCES_LANGGRAPH = [
+    ("langgraph.prebuilt", "create_react_agent"),
+    (
+        "langgraph.prebuilt",
+        "ToolNode",
+    ),
+    (
+        "langgraph.prebuilt",
+        "ToolExecutor",
+    ),
+    (
+        "langgraph.prebuilt",
+        "ToolInvocation",
+    ),
+    ("langgraph.prebuilt", "tools_condition"),
+    (
+        "langgraph.prebuilt",
+        "ValidationNode",
+    ),
+    (
+        "langgraph.prebuilt",
+        "InjectedState",
+    ),
+    # Graph
+    (
+        "langgraph.graph",
+        "StateGraph",
+    ),
+    (
+        "langgraph.graph.message",
+        "MessageGraph",
+    ),
+    ("langgraph.graph.message", "add_messages"),
+    (
+        "langgraph.graph.graph",
+        "CompiledGraph",
+    ),
+    (
+        "langgraph.types",
+        "StreamMode",
+    ),
+    (
+        "langgraph.graph",
+        "START",
+    ),
+    (
+        "langgraph.graph",
+        "END",
+    ),
+    (
+        
"langgraph.types", + "Send", + ), + ( + "langgraph.types", + "Interrupt", + ), + ( + "langgraph.types", + "RetryPolicy", + ), + ( + "langgraph.checkpoint.base", + "Checkpoint", + ), + ( + "langgraph.checkpoint.base", + "CheckpointMetadata", + ), + ( + "langgraph.checkpoint.base", + "BaseCheckpointSaver", + ), + ( + "langgraph.checkpoint.base", + "SerializerProtocol", + ), + ( + "langgraph.checkpoint.serde.jsonplus", + "JsonPlusSerializer", + ), + ( + "langgraph.checkpoint.memory", + "MemorySaver", + ), + ( + "langgraph.checkpoint.sqlite.aio", + "AsyncSqliteSaver", + ), + ( + "langgraph.checkpoint.sqlite", + "SqliteSaver", + ), + ( + "langgraph.checkpoint.postgres.aio", + "AsyncPostgresSaver", + ), + ( + "langgraph.checkpoint.postgres", + "PostgresSaver", + ), +] + +WELL_KNOWN_LANGGRAPH_OBJECTS = { + (module_, class_) for module_, class_ in MANUAL_API_REFERENCES_LANGGRAPH +} + + +def _make_regular_expression(pkg_prefix: str) -> re.Pattern: + if not pkg_prefix.isidentifier(): + raise ValueError(f"Invalid package prefix: {pkg_prefix}") + return re.compile( + r"from\s+(" + pkg_prefix + "(?:_\w+)?(?:\.\w+)*?)\s+import\s+" + r"((?:\w+(?:,\s*)?)*" # Match zero or more words separated by a comma+optional ws + r"(?:\s*\(.*?\))?)", # Match optional parentheses block + re.DOTALL, # Match newlines as well + ) + + # Regular expression to match langchain import lines -_IMPORT_RE = re.compile( - r"from\s+(langchain(?:_\w+)?(?:\.\w+)*?)\s+import\s+" - r"((?:\w+(?:,\s*)?)*" # Match zero or more words separated by a comma+optional ws - r"(?:\s*\(.*?\))?)", # Match optional parentheses block - re.DOTALL, # Match newlines as well -) +_IMPORT_LANGCHAIN_RE = _make_regular_expression("langchain") +_IMPORT_LANGGRAPH_RE = _make_regular_expression("langgraph") + _CURRENT_PATH = Path(__file__).parent.absolute() # Directory where generated markdown files are stored @@ -44,14 +166,22 @@ def find_files(path): yield full -def get_full_module_name(module_path, class_name): +def 
get_full_module_name(module_path, class_name) -> Optional[str]: """Get full module name using inspect""" - module = importlib.import_module(module_path) - class_ = getattr(module, class_name) - return inspect.getmodule(class_).__name__ + try: + module = importlib.import_module(module_path) + class_ = getattr(module, class_name) + return inspect.getmodule(class_).__name__ + except AttributeError as e: + logger.warning(f"Could not find module for {class_name}, {e}") + return None + except ImportError as e: + logger.warning(f"Failed to load for class {class_name}, {e}") + return None -def get_args(): +def get_args() -> argparse.Namespace: + """Get command line arguments""" parser = argparse.ArgumentParser() parser.add_argument( "--docs_dir", @@ -68,7 +198,7 @@ def get_args(): return parser.parse_args() -def main(): +def main() -> None: """Main function""" args = get_args() global_imports = {} @@ -112,9 +242,123 @@ def _get_doc_title(data: str, file_name: str) -> str: return file_name -def replace_imports(file): +class ImportInformation(TypedDict): + imported: str # imported class name + source: str # module path + docs: str # URL to the documentation + title: str # Title of the document + + +def _get_imports( + code: str, doc_title: str, package_ecosystem: Literal["langchain", "langgraph"] +) -> List[ImportInformation]: + """Get imports from the given code block. + + Args: + code: Python code block from which to extract imports + doc_title: Title of the document + package_ecosystem: "langchain" or "langgraph". The two live in different + repositories and have separate documentation sites. 
+ + Returns: + List of import information for the given code block + """ + imports = [] + + if package_ecosystem == "langchain": + pattern = _IMPORT_LANGCHAIN_RE + elif package_ecosystem == "langgraph": + pattern = _IMPORT_LANGGRAPH_RE + else: + raise ValueError(f"Invalid package ecosystem: {package_ecosystem}") + + for import_match in pattern.finditer(code): + module = import_match.group(1) + if "pydantic_v1" in module: + continue + imports_str = ( + import_match.group(2).replace("(\n", "").replace("\n)", "") + ) # Handle newlines within parentheses + # remove any newline and spaces, then split by comma + imported_classes = [ + imp.strip() + for imp in re.split(r",\s*", imports_str.replace("\n", "")) + if imp.strip() + ] + for class_name in imported_classes: + module_path = get_full_module_name(module, class_name) + if not module_path: + continue + if len(module_path.split(".")) < 2: + continue + + if package_ecosystem == "langchain": + pkg = module_path.split(".")[0].replace("langchain_", "") + top_level_mod = module_path.split(".")[1] + + url = ( + _LANGCHAIN_API_REFERENCE + + pkg + + "/" + + top_level_mod + + "/" + + module_path + + "." + + class_name + + ".html" + ) + elif package_ecosystem == "langgraph": + if module.startswith("langgraph.checkpoint"): + namespace = "checkpoints" + elif module.startswith("langgraph.graph"): + namespace = "graphs" + elif module.startswith("langgraph.prebuilt"): + namespace = "prebuilt" + elif module.startswith("langgraph.errors"): + namespace = "errors" + else: + # Likely not documented yet + # Unable to determine the namespace + continue + + if module.startswith("langgraph.errors"): + # Has different URL structure than other modules + url = ( + _LANGGRAPH_API_REFERENCE + + namespace + + "/#langgraph.errors." + + class_name # Uses the actual class name here. 
+ ) + else: + if (module, class_name) not in WELL_KNOWN_LANGGRAPH_OBJECTS: + # Likely not documented yet + continue + url = ( + _LANGGRAPH_API_REFERENCE + namespace + "/#" + class_name.lower() + ) + else: + raise ValueError(f"Invalid package ecosystem: {package_ecosystem}") + + # Add the import information to our list + imports.append( + { + "imported": class_name, + "source": module, + "docs": url, + "title": doc_title, + } + ) + + return imports + + +def replace_imports(file) -> List[ImportInformation]: """Replace imports in each Python code block with links to their - documentation and append the import info in a comment""" + documentation and append the import info in a comment + + Returns: + list of import information for the given file + """ all_imports = [] with open(file, "r") as f: data = f.read() @@ -133,53 +377,9 @@ def replacer(match): # Process imports in the code block imports = [] - for import_match in _IMPORT_RE.finditer(code): - module = import_match.group(1) - if "pydantic_v1" in module: - continue - imports_str = ( - import_match.group(2).replace("(\n", "").replace("\n)", "") - ) # Handle newlines within parentheses - # remove any newline and spaces, then split by comma - imported_classes = [ - imp.strip() - for imp in re.split(r",\s*", imports_str.replace("\n", "")) - if imp.strip() - ] - for class_name in imported_classes: - try: - module_path = get_full_module_name(module, class_name) - except AttributeError as e: - logger.warning(f"Could not find module for {class_name}, {e}") - continue - except ImportError as e: - logger.warning(f"Failed to load for class {class_name}, {e}") - continue - if len(module_path.split(".")) < 2: - continue - pkg = module_path.split(".")[0].replace("langchain_", "") - top_level_mod = module_path.split(".")[1] - url = ( - _BASE_URL - + pkg - + "/" - + top_level_mod - + "/" - + module_path - + "." 
- + class_name - + ".html" - ) - # Add the import information to our list - imports.append( - { - "imported": class_name, - "source": module, - "docs": url, - "title": _DOC_TITLE, - } - ) + imports.extend(_get_imports(code, _DOC_TITLE, "langchain")) + imports.extend(_get_imports(code, _DOC_TITLE, "langgraph")) if imports: all_imports.extend(imports) @@ -194,8 +394,6 @@ def replacer(match): # Use re.sub to replace each Python code block data = code_block_re.sub(replacer, data) - # if all_imports: - # print(f"Adding {len(all_imports)} links for imports in {file}") with open(file, "w") as f: f.write(data) return all_imports diff --git a/docs/vercel_requirements.txt b/docs/vercel_requirements.txt index 3c3f6fe124f8c..365d23a43527e 100644 --- a/docs/vercel_requirements.txt +++ b/docs/vercel_requirements.txt @@ -2,6 +2,7 @@ -e ../libs/langchain -e ../libs/community -e ../libs/text-splitters +langgraph langchain-cohere urllib3==1.26.19 nbconvert==7.16.4