Skip to content

Commit

Permalink
[web] feat: layout based on sublayout
Browse files Browse the repository at this point in the history
  • Loading branch information
sroussey committed Apr 11, 2024
1 parent a046ef3 commit e438e9c
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 194 deletions.
3 changes: 2 additions & 1 deletion packages/web/src/RunGraphFlow.css
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@
.react-flow__node-compound .outside {
display: flex;
flex-direction: column;
width: 100%;
}

.react-flow__node-compound .wrapper {
width: 300px;
width: 100%;
}

.gradient:before {
Expand Down
139 changes: 72 additions & 67 deletions packages/web/src/RunGraphFlow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ import {
useNodesInitialized,
useReactFlow,
Edge,
Position,
} from "@xyflow/react";

import "@xyflow/react/dist/base.css";
import "./RunGraphFlow.css";
import { TurboNodeData, SingleNode, CompoundNode } from "./TurboNode";
import TurboEdge from "./TurboEdge";
import { FiFileText, FiClipboard, FiDownload, FiUpload } from "react-icons/fi";
Expand All @@ -22,7 +20,9 @@ import {
registerMediaPipeTfJsLocalInMemory,
} from "ellmers-core/browser";
import { GraphPipelineCenteredLayout, GraphPipelineLayout, computeLayout } from "./layout";
import { NodeChange } from "./changes";

import "@xyflow/react/dist/base.css";
import "./RunGraphFlow.css";

registerHuggingfaceLocalTasksInMemory();
registerMediaPipeTfJsLocalInMemory();
Expand Down Expand Up @@ -66,27 +66,30 @@ function sortNodes(nodes: Node<TurboNodeData>[]): Node<TurboNodeData>[] {

function convertGraphToNodes(graph: TaskGraph): Node<TurboNodeData>[] {
const tasks = graph.getNodes();
const nodes = tasks.flatMap((node, index) => {
const nodes = tasks.flatMap((task, index) => {
let n: Node<TurboNodeData>[] = [
{
id: node.config.id as string,
id: task.config.id as string,
position: { x: 0, y: 0 },
data: {
icon: categoryIcons[(node.constructor as any).category],
title: (node.constructor as any).type,
subline: node.config.name,
icon: categoryIcons[(task.constructor as any).category],
title: (task.constructor as any).type,
subline: task.config.name,
},
type: node.isCompound ? "compound" : "single",
type: task.isCompound ? "compound" : "single",
selectable: true,
connectable: false,
draggable: false,
sourcePosition: Position.Right,
targetPosition: Position.Left,
},
];
if (node.isCompound) {
const subNodes = convertGraphToNodes(node.subGraph).map((n) => {
if (task.isCompound) {
const subNodes = convertGraphToNodes(task.subGraph).map((n) => {
return {
...n,
parentId: node.config.id as string,
parentId: task.config.id as string,
extent: "parent",
selectable: false,
connectable: false,
} as Node<TurboNodeData>;
});
n = [...n, ...subNodes];
Expand All @@ -95,10 +98,46 @@ function convertGraphToNodes(graph: TaskGraph): Node<TurboNodeData>[] {
});
return nodes;
}

function doNodeLayout(
setNodes: Dispatch<SetStateAction<Node[]>>,
setEdges: Dispatch<SetStateAction<Edge[]>>
) {
let edges = [];
setEdges((es) => {
edges = es.map((n) => {
return {
...n,
style: { opacity: 1 },
};
});
setNodes((nodes) => {
const computedNodes = computeLayout(
nodes,
edges,
new GraphPipelineCenteredLayout<Node<TurboNodeData>>(),
new GraphPipelineLayout<Node<TurboNodeData>>({ startTop: 100, startLeft: 20 })
) as Node<TurboNodeData>[];
const sortedNodes = sortNodes(computedNodes);
sortedNodes.map((n) => {
n.style = { opacity: 1 };
return n;
});
return sortedNodes;
});
return edges;
});
}

// TODO: unlisten to tasks
function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>) {
function listenToTask(
task: Task,
setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>,
edges: Edge[],
setEdges: Dispatch<SetStateAction<Edge[]>>
) {
task.on("progress", (progress, progressText) => {
setNodes((nds) =>
setNodes((nds) => {
nds.map((nd) => {
if (nd.id === task.config.id) {
return {
Expand All @@ -112,8 +151,10 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
};
}
return nd;
})
);
});
doNodeLayout(setNodes, setEdges);
return nds;
});
});
task.on("start", () => {
setNodes((nds) =>
Expand Down Expand Up @@ -168,7 +209,7 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
);
});
if (task.isCompound) {
listenToGraphTasks(task.subGraph, setNodes);
listenToGraphTasks(task.subGraph, setNodes, edges, setEdges);
task.on("regenerate", () => {
// console.log("Node regenerated", task.config.id);
setNodes((nodes: Node<TurboNodeData>[]) => {
Expand All @@ -182,7 +223,7 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
connectable: false,
}) as Node<TurboNodeData>
);
listenToGraphTasks(task.subGraph, setNodes);
listenToGraphTasks(task.subGraph, setNodes, edges, setEdges);
let returnNodes = nodes.filter((n) => n.parentId !== task.config.id); // remove old children
returnNodes = [...returnNodes, ...children]; // add new children
returnNodes = sortNodes(returnNodes); // sort all nodes (parent, children, parent, children, ...)
Expand All @@ -194,11 +235,13 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo

function listenToGraphTasks(
graph: TaskGraph,
setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>
setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>,
edges: Edge[],
setEdges: Dispatch<SetStateAction<Edge[]>>
) {
const nodes = graph.getNodes();
for (const node of nodes) {
listenToTask(node, setNodes);
listenToTask(node, setNodes, edges, setEdges);
}
}

Expand All @@ -219,63 +262,25 @@ const defaultEdgeOptions = {
export const RunGraphFlow: React.FC<{
graph: TaskGraph;
}> = ({ graph }) => {
const [nodes, setNodes, onNodesChangeTheirs] = useNodesState<Node<TurboNodeData>>([]);
const [nodes, setNodes, onNodesChange] = useNodesState<Node<TurboNodeData>>([]);
const [edges, setEdges, onEdgesChange] = useEdgesState<Edge>([]);
const graphRef = useRef<TaskGraph | null>(null);

const onNodesChange = useCallback(
(changes: NodeChange<Node<TurboNodeData>>[]) => {
console.log("Nodes changed", changes);
onNodesChangeTheirs(changes);
// const computedNodes = computeLayout(
// nodes,
// edges,
// new GraphPipelineCenteredLayout(),
// new GraphPipelineLayout({ startTop: 100, startLeft: 20 })
// ) as Node<TurboNodeData>[];
// const sortedNodes = sortNodes(computedNodes);
// console.log(nodes, sortedNodes);
},
[onNodesChangeTheirs, nodes, edges]
);

const initialized = useNodesInitialized() && !nodes.some((n) => !n.measured);
const shouldLayout = useNodesInitialized() && !nodes.some((n) => !n.measured);
const { fitView } = useReactFlow();

useEffect(() => {
const id: Timer | null = null;
if (initialized) {
console.log("Nodes initialized", nodes);
const computedNodes = computeLayout(
nodes,
edges,
new GraphPipelineCenteredLayout(),
new GraphPipelineLayout({ startTop: 100, startLeft: 20 })
) as Node<TurboNodeData>[];
const sortedNodes = sortNodes(computedNodes);
setNodes(
sortedNodes.map((n) => {
n.style = { opacity: 1 };
return n;
})
);
setEdges(
edges.map((n) => {
return {
...n,
style: { opacity: 1 },
};
})
);
if (shouldLayout) {
doNodeLayout(setNodes, setEdges);
setTimeout(() => {
fitView();
}, 5);
}
}, [initialized, setNodes, setEdges, fitView]);
}, [shouldLayout, setNodes, setEdges, fitView]);

useEffect(() => {
if (graph !== graphRef.current) {
console.log("Graph changed", graph);
graphRef.current = graph;
console.log("Graph changed", graph);
const nodes = sortNodes(convertGraphToNodes(graph));
Expand All @@ -298,7 +303,7 @@ export const RunGraphFlow: React.FC<{
};
})
);
listenToGraphTasks(graph, setNodes);
listenToGraphTasks(graph, setNodes, edges, setEdges);
}
}, [graph, setNodes, setEdges, graphRef.current]);

Expand Down
60 changes: 0 additions & 60 deletions packages/web/src/changes.ts

This file was deleted.

Loading

0 comments on commit e438e9c

Please sign in to comment.