Skip to content

Commit

Permalink
Add StructureVisualizer params (#1434)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Dec 13, 2024
1 parent d69bd8c commit 7521927
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Exception when calling `Structure.to_json()` after it has run.
- `Agent` unintentionally modifying `stream` for all Prompt Drivers.
- `StructureVisualizer.base_url` for setting the base URL on the url generated by `StructureVisualizer.to_url()`.
- `StructureVisualizer.query_params` for setting query parameters on the url generated by `StructureVisualizer.to_url()`.

## [1.0.0] - 2024-12-09

Expand Down
13 changes: 12 additions & 1 deletion griptape/utils/structure_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import base64
import urllib.parse
from typing import TYPE_CHECKING, Callable

from attrs import define, field
from requests.models import PreparedRequest

if TYPE_CHECKING:
from griptape.structures import Structure
Expand All @@ -17,6 +19,8 @@ class StructureVisualizer:
structure: Structure = field()
header: str = field(default="graph TD;", kw_only=True)
build_node_id: Callable[[BaseTask], str] = field(default=lambda task: task.id.title(), kw_only=True)
query_params: dict[str, str] = field(factory=dict, kw_only=True)
base_url: str = field(default="https://mermaid.ink", kw_only=True)

def to_url(self) -> str:
"""Generates a url that renders the Workflow structure as a Mermaid flowchart.
Expand All @@ -34,7 +38,14 @@ def to_url(self) -> str:
graph_bytes = graph.encode("utf-8")
base64_string = base64.b64encode(graph_bytes).decode("utf-8")

return f"https://mermaid.ink/svg/{base64_string}"
url = urllib.parse.urljoin(self.base_url, f"svg/{base64_string}")
req = PreparedRequest()
req.prepare_url(url, self.query_params)

if req.url is None:
raise ValueError("Failed to generate the URL")

return req.url

def __render_tasks(self, tasks: list[BaseTask]) -> str:
return "\n\t" + "\n\t".join([self.__render_task(task) for task in tasks])
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/utils/test_structure_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,22 @@ def on_run(_: BranchTask) -> ListArtifact[InfoArtifact]:
result
== "https://mermaid.ink/svg/Z3JhcGggVEQ7CgkxLS0+IEJyYW5jaDsKCUJyYW5jaHsgQnJhbmNoIH0tLi0+IDIgJiAzOwoJMi0tPiA0OwoJMy0tPiA0OwoJNDs="
)

def test_query_params(self):
visualizer = StructureVisualizer(
Pipeline(
tasks=[
PromptTask("test1", id="task1"),
PromptTask("test2", id="task2", parent_ids=["task1"]),
PromptTask("test3", id="task3", parent_ids=["task1"]),
PromptTask("test4", id="task4", parent_ids=["task2", "task3"]),
],
),
query_params={"theme": "dark", "bgColor": "2b2b2b"},
)
result = visualizer.to_url()

assert (
result
== "https://mermaid.ink/svg/Z3JhcGggVEQ7CglUYXNrMS0tPiBUYXNrMiAmIFRhc2szOwoJVGFzazItLT4gVGFzazMgJiBUYXNrNDsKCVRhc2szLS0+IFRhc2s0OwoJVGFzazQ7?theme=dark&bgColor=2b2b2b"
)

0 comments on commit 7521927

Please sign in to comment.