diff --git a/CHANGELOG.md b/CHANGELOG.md index fe7b7707b..b1be35689 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Exception when calling `Structure.to_json()` after it has run. - `Agent` unintentionally modifying `stream` for all Prompt Drivers. +- `StructureVisualizer.base_url` for setting the base URL on the url generated by `StructureVisualizer.to_url()`. +- `StructureVisualizer.query_params` for setting query parameters on the url generated by `StructureVisualizer.to_url()`. ## [1.0.0] - 2024-12-09 diff --git a/griptape/utils/structure_visualizer.py b/griptape/utils/structure_visualizer.py index d6616710a..e4d578c01 100644 --- a/griptape/utils/structure_visualizer.py +++ b/griptape/utils/structure_visualizer.py @@ -1,9 +1,11 @@ from __future__ import annotations import base64 +import urllib.parse from typing import TYPE_CHECKING, Callable from attrs import define, field +from requests.models import PreparedRequest if TYPE_CHECKING: from griptape.structures import Structure @@ -17,6 +19,8 @@ class StructureVisualizer: structure: Structure = field() header: str = field(default="graph TD;", kw_only=True) build_node_id: Callable[[BaseTask], str] = field(default=lambda task: task.id.title(), kw_only=True) + query_params: dict[str, str] = field(factory=dict, kw_only=True) + base_url: str = field(default="https://mermaid.ink", kw_only=True) def to_url(self) -> str: """Generates a url that renders the Workflow structure as a Mermaid flowchart. @@ -34,7 +38,14 @@ def to_url(self) -> str: graph_bytes = graph.encode("utf-8") base64_string = base64.b64encode(graph_bytes).decode("utf-8") - return f"https://mermaid.ink/svg/{base64_string}" + url = urllib.parse.urljoin(self.base_url, f"svg/{base64_string}") + req = PreparedRequest() + req.prepare_url(url, self.query_params) + + if req.url is None: + raise ValueError("Failed to generate the URL") + + return req.url def __render_tasks(self, tasks: list[BaseTask]) -> str: return "\n\t" + "\n\t".join([self.__render_task(task) for task in tasks]) diff --git a/tests/unit/utils/test_structure_visualizer.py b/tests/unit/utils/test_structure_visualizer.py index 565e33e28..e9fe0d02e 100644 --- a/tests/unit/utils/test_structure_visualizer.py +++ b/tests/unit/utils/test_structure_visualizer.py @@ -103,3 +103,22 @@ def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]: result == "https://mermaid.ink/svg/Z3JhcGggVEQ7CgkxLS0+IEJyYW5jaDsKCUJyYW5jaHsgQnJhbmNoIH0tLi0+IDIgJiAzOwoJMi0tPiA0OwoJMy0tPiA0OwoJNDs=" ) + + def test_query_params(self): + visualizer = StructureVisualizer( + Pipeline( + tasks=[ + PromptTask("test1", id="task1"), + PromptTask("test2", id="task2", parent_ids=["task1"]), + PromptTask("test3", id="task3", parent_ids=["task1"]), + PromptTask("test4", id="task4", parent_ids=["task2", "task3"]), + ], + ), + query_params={"theme": "dark", "bgColor": "2b2b2b"}, + ) + result = visualizer.to_url() + + assert ( + result + == "https://mermaid.ink/svg/Z3JhcGggVEQ7CglUYXNrMS0tPiBUYXNrMiAmIFRhc2szOwoJVGFzazItLT4gVGFzazMgJiBUYXNrNDsKCVRhc2szLS0+IFRhc2s0OwoJVGFzazQ7?theme=dark&bgColor=2b2b2b" + )