diff --git a/bun.lockb b/bun.lockb index 1703648..693d0b6 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/packages/web/package.json b/packages/web/package.json index b25a061..f95dcaf 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -9,7 +9,7 @@ "preview": "vite preview" }, "dependencies": { - "@xyflow/react": "^12.0.0-next.13", + "@xyflow/react": "^12.0.0-next.14", "react": "^18.2.0", "react-dom": "^18.2.0", "@uiw/react-codemirror": "^4.21.25", diff --git a/packages/web/src/RunGraphFlow.css b/packages/web/src/RunGraphFlow.css index 769b946..b3f124d 100644 --- a/packages/web/src/RunGraphFlow.css +++ b/packages/web/src/RunGraphFlow.css @@ -66,12 +66,18 @@ width: 250px; } -.react-flow__node-compound .wrapper { - width: 300px; +.react-flow__node-compound { + min-height: 150px; } -.react-flow__node-compound { - min-height: 450px; +.react-flow__node-compound .outside { + display: flex; + flex-direction: column; + width: 100%; +} + +.react-flow__node-compound .wrapper { + width: 100%; } .gradient:before { diff --git a/packages/web/src/RunGraphFlow.tsx b/packages/web/src/RunGraphFlow.tsx index 5355455..80e00df 100644 --- a/packages/web/src/RunGraphFlow.tsx +++ b/packages/web/src/RunGraphFlow.tsx @@ -8,10 +8,8 @@ import { useNodesInitialized, useReactFlow, Edge, + Position, } from "@xyflow/react"; - -import "@xyflow/react/dist/base.css"; -import "./RunGraphFlow.css"; import { TurboNodeData, SingleNode, CompoundNode } from "./TurboNode"; import TurboEdge from "./TurboEdge"; import { FiFileText, FiClipboard, FiDownload, FiUpload } from "react-icons/fi"; @@ -23,6 +21,9 @@ import { } from "ellmers-core/browser"; import { GraphPipelineCenteredLayout, GraphPipelineLayout, computeLayout } from "./layout"; +import "@xyflow/react/dist/base.css"; +import "./RunGraphFlow.css"; + registerHuggingfaceLocalTasksInMemory(); registerMediaPipeTfJsLocalInMemory(); @@ -65,27 +66,30 @@ function sortNodes(nodes: Node[]): Node[] { function convertGraphToNodes(graph: TaskGraph): Node[] { const tasks = graph.getNodes(); - const nodes = tasks.flatMap((node, index) => { + const nodes = tasks.flatMap((task, index) => { let n: Node[] = [ { - id: node.config.id as string, + id: task.config.id as string, position: { x: 0, y: 0 }, data: { - icon: categoryIcons[(node.constructor as any).category], - title: (node.constructor as any).type, - subline: node.config.name, + icon: categoryIcons[(task.constructor as any).category], + title: (task.constructor as any).type, + subline: task.config.name, }, - type: node.isCompound ? "compound" : "single", + type: task.isCompound ? "compound" : "single", + selectable: true, + connectable: false, + draggable: false, + sourcePosition: Position.Right, + targetPosition: Position.Left, }, ]; - if (node.isCompound) { - const subNodes = convertGraphToNodes(node.subGraph).map((n) => { + if (task.isCompound) { + const subNodes = convertGraphToNodes(task.subGraph).map((n) => { return { ...n, - parentId: node.config.id as string, + parentId: task.config.id as string, extent: "parent", - selectable: false, - connectable: false, } as Node; }); n = [...n, ...subNodes]; @@ -95,9 +99,45 @@ function convertGraphToNodes(graph: TaskGraph): Node[] { return nodes; } -function listenToTask(task: Task, setNodes: Dispatch[]>>) { +function doNodeLayout( + setNodes: Dispatch>, + setEdges: Dispatch> +) { + let edges = []; + setEdges((es) => { + edges = es.map((n) => { + return { + ...n, + style: { opacity: 1 }, + }; + }); + setNodes((nodes) => { + const computedNodes = computeLayout( + nodes, + edges, + new GraphPipelineCenteredLayout>(), + new GraphPipelineLayout>({ startTop: 100, startLeft: 20 }) + ) as Node[]; + const sortedNodes = sortNodes(computedNodes); + sortedNodes.map((n) => { + n.style = { opacity: 1 }; + return n; + }); + return sortedNodes; + }); + return edges; + }); +} + +// TODO: unlisten to tasks +function listenToTask( + task: Task, + setNodes: Dispatch[]>>, + edges: Edge[], + setEdges: Dispatch> +) { task.on("progress", (progress, progressText) => { - setNodes((nds) => + setNodes((nds) => { nds.map((nd) => { if (nd.id === task.config.id) { return { @@ -111,8 +151,10 @@ function listenToTask(task: Task, setNodes: Dispatch { setNodes((nds) => @@ -167,7 +209,7 @@ function listenToTask(task: Task, setNodes: Dispatch { // console.log("Node regenerated", task.config.id); setNodes((nodes: Node[]) => { @@ -181,7 +223,7 @@ function listenToTask(task: Task, setNodes: Dispatch ); - listenToGraphNodes(task.subGraph, setNodes); + listenToGraphTasks(task.subGraph, setNodes, edges, setEdges); let returnNodes = nodes.filter((n) => n.parentId !== task.config.id); // remove old children returnNodes = [...returnNodes, ...children]; // add new children returnNodes = sortNodes(returnNodes); // sort all nodes (parent, children, parent, children, ...) @@ -191,13 +233,15 @@ function listenToTask(task: Task, setNodes: Dispatch[]>> + setNodes: Dispatch[]>>, + edges: Edge[], + setEdges: Dispatch> ) { const nodes = graph.getNodes(); for (const node of nodes) { - listenToTask(node, setNodes); + listenToTask(node, setNodes, edges, setEdges); } } @@ -218,53 +262,27 @@ const defaultEdgeOptions = { export const RunGraphFlow: React.FC<{ graph: TaskGraph; }> = ({ graph }) => { - const [nodes, setNodes, onNodesChangeTheirs] = useNodesState>([]); + const [nodes, setNodes, onNodesChange] = useNodesState>([]); const [edges, setEdges, onEdgesChange] = useEdgesState([]); const graphRef = useRef(null); - const onNodesChange = useCallback( - (changes: any) => { - console.log("Nodes changed", changes); - onNodesChangeTheirs(changes); - }, - [onNodesChangeTheirs, nodes, edges] - ); - - const initialized = useNodesInitialized() && !nodes.some((n) => !n.measured); + const shouldLayout = useNodesInitialized() && !nodes.some((n) => !n.measured); const { fitView } = useReactFlow(); useEffect(() => { - if (initialized) { - const computedNodes = computeLayout( - nodes, - edges, - new GraphPipelineCenteredLayout(), - new GraphPipelineLayout({ startTop: 100, startLeft: 20 }) - ) as Node[]; - const sortedNodes = sortNodes(computedNodes); - setNodes( - sortedNodes.map((n) => { - n.style = { opacity: 1 }; - return n; - }) - ); - setEdges( - edges.map((n) => { - return { - ...n, - style: { opacity: 1 }, - }; - }) - ); + const id: Timer | null = null; + if (shouldLayout) { + doNodeLayout(setNodes, setEdges); setTimeout(() => { fitView(); }, 5); } - }, [initialized]); + }, [shouldLayout, setNodes, setEdges, fitView]); useEffect(() => { if (graph !== graphRef.current) { graphRef.current = graph; + console.log("Graph changed", graph); const nodes = sortNodes(convertGraphToNodes(graph)); setNodes( nodes.map((n) => { @@ -285,9 +303,9 @@ export const RunGraphFlow: React.FC<{ }; }) ); - listenToGraphNodes(graph, setNodes); + listenToGraphTasks(graph, setNodes, edges, setEdges); } - }, [graph]); + }, [graph, setNodes, setEdges, graphRef.current]); // const onConnect = useCallback( // (params: any) => setEdges((els) => addEdge(params, els)), diff --git a/packages/web/src/layout.ts b/packages/web/src/layout.ts index 60adcf2..35e1fbe 100644 --- a/packages/web/src/layout.ts +++ b/packages/web/src/layout.ts @@ -16,12 +16,12 @@ interface LayoutOptions { } export class GraphPipelineLayout implements LayoutOptions { - protected dag: DirectedAcyclicGraph; + protected dataflowDAG: DirectedAcyclicGraph; protected positions: Map = new Map(); protected layerHeight: number[] = []; - protected layers: Map = new Map(); + public layers: Map = new Map(); public nodeWidthMin: number = 190; - public nodeHeightMin: number = 150; + public nodeHeightMin: number = 50; public horizontalSpacing = 80; // Horizontal spacing between layers public verticalSpacing = 20; // Vertical spacing between nodes within a layer public startTop = 50; // Starting position of the top layer @@ -32,20 +32,18 @@ export class GraphPipelineLayout implements LayoutOptions { } public setGraph(dag: DirectedAcyclicGraph) { - this.dag = dag; - this.positions = new Map(); + this.dataflowDAG = dag; this.layers = new Map(); this.layerHeight = []; } public layoutGraph() { - const sortedNodes = this.dag.topologicallySortedNodes(); + const sortedNodes = this.dataflowDAG.topologicallySortedNodes(); this.assignLayers(sortedNodes); this.positionNodes(); - // Optionally, you can include edge drawing logic here or handle it separately } - private assignLayers(sortedNodes: T[]) { + public assignLayers(sortedNodes: T[]) { this.layers = new Map(); const nodeToLayer = new Map(); @@ -56,7 +54,7 @@ export class GraphPipelineLayout implements LayoutOptions { let maxLayer = -1; // Get all incoming edges (dependencies) of the node - const incomingEdges = this.dag.inEdges(node.id).map(([from]) => from); + const incomingEdges = this.dataflowDAG.inEdges(node.id).map(([from]) => from); incomingEdges.forEach((from) => { // Find the layer of the dependency @@ -85,10 +83,7 @@ export class GraphPipelineLayout implements LayoutOptions { let nodeWidth = this.nodeWidthMin; let currentY = this.startTop; nodes.forEach((node) => { - this.positions.set(node.id, { - x: currentX, - y: currentY, - }); + node.position = { x: currentX, y: currentY }; const nodeHeight = this.getNodeHeight(node); @@ -104,16 +99,14 @@ export class GraphPipelineLayout implements LayoutOptions { }); } - protected getNodeHeight(node: T): number { - return Math.max(node.measured?.height, this.nodeHeightMin); - } - - protected getNodeWidth(node: T): number { - return Math.max(node.measured?.width, this.nodeWidthMin); + public getNodeHeight(node: T): number { + const baseHeight = node.height || node.measured?.height || this.nodeHeightMin; + return Math.max(baseHeight, this.nodeHeightMin); } - public getNodePosition(nodeIdentity: string): PositionXY | undefined { - return this.positions.get(nodeIdentity); + public getNodeWidth(node: T): number { + const baseWidth = node.width || node.measured?.width || this.nodeWidthMin; + return Math.max(baseWidth, this.nodeWidthMin); } } @@ -126,10 +119,10 @@ export class GraphPipelineCenteredLayout extends GraphPipelineLa nodes.forEach((node) => { const nodeHeight = this.getNodeHeight(node); - this.positions.set(node.id, { - x: this.positions.get(node.id)!.x, + node.position = { + x: node.position.x, y: currentY, - }); + }; currentY += nodeHeight + this.verticalSpacing; }); @@ -141,61 +134,87 @@ export class GraphPipelineCenteredLayout extends GraphPipelineLa return Math.max(...this.layerHeight); } } +const groupBy = (items: T[], key: keyof T) => + items.reduce( + (result: Record, item: T) => ({ + ...result, + [String(item[key])]: [...(result[String(item[key])] || []), item], + }), + {} + ); export function computeLayout( nodes: Node[], edges: Edge[], layout: GraphPipelineLayout, - subFlowLayout?: GraphPipelineLayout, - parentId?: string + subFlowLayout?: GraphPipelineLayout ): Node[] { - const g = new DirectedAcyclicGraph((node) => node.id); - + // before we bother with anything, ignore hidden nodes nodes = nodes.filter((node) => !node.hidden); - const topLevelNodes = nodes.filter( - (node) => node.parentId === undefined || node.parentId === parentId - ); + const subgraphSize = new Map(); + const subgraphDAG = new DirectedAcyclicGraph((node) => node.id); - topLevelNodes.forEach((node) => { - g.insert(node); + nodes.forEach((node) => { + subgraphDAG.insert(node); }); - edges.forEach((edge) => { - try { - g.addEdge(edge.source, edge.target); - } catch (e) { - // might be an edge to a hidden node + nodes.forEach((node) => { + if (node.parentId) { + subgraphDAG.addEdge(node.parentId, node.id); } }); - layout.setGraph(g); - layout.layoutGraph(); - - const returnNodes: Node[] = topLevelNodes.map((node) => { - const nodePosition = layout.getNodePosition(node.id)!; + const subgraphDepthLayout = new GraphPipelineLayout(); + subgraphDepthLayout.setGraph(subgraphDAG); + const sortedNodes = subgraphDAG.topologicallySortedNodes(); + subgraphDepthLayout.assignLayers(sortedNodes); + const allgraphs = Array.from(subgraphDepthLayout.layers.values()); + + const returnNodes: Node[] = []; + for (let i = allgraphs.length - 1; i >= 0; i--) { + const graphs = groupBy(allgraphs[i], "parentId"); + for (const parentId in graphs) { + // This loop goes from innermost graph to outermost graph + // and lays out the nodes in each graph. We do innermost + // first so that the outer nodes can be positioned based on + // the size and layout of the inner nodes (the parent needs + // to expand to fit the children). + + const subgraphNodes = graphs[parentId]; + + const dataflowDAG = new DirectedAcyclicGraph( + (node) => node.id + ); - return { - ...node, - targetPosition: Position.Left, - sourcePosition: Position.Right, - position: { x: nodePosition.x, y: nodePosition.y }, - }; - }); + subgraphNodes.forEach((node) => { + dataflowDAG.insert(node); + }); - for (const node of topLevelNodes) { - const children = nodes.filter((n) => n.parentId === node.id); + edges.forEach((edge) => { + if (dataflowDAG.hasNode(edge.source) && dataflowDAG.hasNode(edge.target)) + dataflowDAG.addEdge(edge.source, edge.target); + }); + subgraphNodes.forEach((node) => { + if (subgraphSize.has(node.id)) { + const sizes = subgraphSize.get(node.id); + node.height = sizes.height; + node.width = sizes.width; + } + }); + const last = subgraphNodes[subgraphNodes.length - 1]; + const l = parentId === "undefined" ? layout : subFlowLayout; + l.setGraph(dataflowDAG); + l.layoutGraph(); + + subgraphSize.set(parentId, { + width: l.startLeft + (last.position?.x || 0) + l.getNodeWidth(last), + height: l.startTop / 2 + (last.position?.y || 0) + l.getNodeHeight(last), + }); - if (children.length > 0) { - const childNodes = computeLayout( - children, - edges, - subFlowLayout ?? layout, - subFlowLayout ?? layout, - node.id - ); - returnNodes.push(...childNodes); + returnNodes.push(...subgraphNodes); } } - return returnNodes; + + return returnNodes.toReversed().map((node) => ({ ...node })); }