Skip to content

Commit

Permalink
Merge branch 'change-size-based-on-children'
Browse files Browse the repository at this point in the history
  • Loading branch information
sroussey committed Apr 11, 2024
2 parents cf04a27 + e438e9c commit 841aad1
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 125 deletions.
Binary file modified bun.lockb
Binary file not shown.
2 changes: 1 addition & 1 deletion packages/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"preview": "vite preview"
},
"dependencies": {
"@xyflow/react": "^12.0.0-next.13",
"@xyflow/react": "^12.0.0-next.14",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"@uiw/react-codemirror": "^4.21.25",
Expand Down
14 changes: 10 additions & 4 deletions packages/web/src/RunGraphFlow.css
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,18 @@
width: 250px;
}

.react-flow__node-compound .wrapper {
width: 300px;
.react-flow__node-compound {
min-height: 150px;
}

.react-flow__node-compound {
min-height: 450px;
.react-flow__node-compound .outside {
display: flex;
flex-direction: column;
width: 100%;
}

.react-flow__node-compound .wrapper {
width: 100%;
}

.gradient:before {
Expand Down
134 changes: 76 additions & 58 deletions packages/web/src/RunGraphFlow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ import {
useNodesInitialized,
useReactFlow,
Edge,
Position,
} from "@xyflow/react";

import "@xyflow/react/dist/base.css";
import "./RunGraphFlow.css";
import { TurboNodeData, SingleNode, CompoundNode } from "./TurboNode";
import TurboEdge from "./TurboEdge";
import { FiFileText, FiClipboard, FiDownload, FiUpload } from "react-icons/fi";
Expand All @@ -23,6 +21,9 @@ import {
} from "ellmers-core/browser";
import { GraphPipelineCenteredLayout, GraphPipelineLayout, computeLayout } from "./layout";

import "@xyflow/react/dist/base.css";
import "./RunGraphFlow.css";

registerHuggingfaceLocalTasksInMemory();
registerMediaPipeTfJsLocalInMemory();

Expand Down Expand Up @@ -65,27 +66,30 @@ function sortNodes(nodes: Node<TurboNodeData>[]): Node<TurboNodeData>[] {

function convertGraphToNodes(graph: TaskGraph): Node<TurboNodeData>[] {
const tasks = graph.getNodes();
const nodes = tasks.flatMap((node, index) => {
const nodes = tasks.flatMap((task, index) => {
let n: Node<TurboNodeData>[] = [
{
id: node.config.id as string,
id: task.config.id as string,
position: { x: 0, y: 0 },
data: {
icon: categoryIcons[(node.constructor as any).category],
title: (node.constructor as any).type,
subline: node.config.name,
icon: categoryIcons[(task.constructor as any).category],
title: (task.constructor as any).type,
subline: task.config.name,
},
type: node.isCompound ? "compound" : "single",
type: task.isCompound ? "compound" : "single",
selectable: true,
connectable: false,
draggable: false,
sourcePosition: Position.Right,
targetPosition: Position.Left,
},
];
if (node.isCompound) {
const subNodes = convertGraphToNodes(node.subGraph).map((n) => {
if (task.isCompound) {
const subNodes = convertGraphToNodes(task.subGraph).map((n) => {
return {
...n,
parentId: node.config.id as string,
parentId: task.config.id as string,
extent: "parent",
selectable: false,
connectable: false,
} as Node<TurboNodeData>;
});
n = [...n, ...subNodes];
Expand All @@ -95,9 +99,45 @@ function convertGraphToNodes(graph: TaskGraph): Node<TurboNodeData>[] {
return nodes;
}

function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>) {
function doNodeLayout(
setNodes: Dispatch<SetStateAction<Node[]>>,
setEdges: Dispatch<SetStateAction<Edge[]>>
) {
let edges = [];
setEdges((es) => {
edges = es.map((n) => {
return {
...n,
style: { opacity: 1 },
};
});
setNodes((nodes) => {
const computedNodes = computeLayout(
nodes,
edges,
new GraphPipelineCenteredLayout<Node<TurboNodeData>>(),
new GraphPipelineLayout<Node<TurboNodeData>>({ startTop: 100, startLeft: 20 })
) as Node<TurboNodeData>[];
const sortedNodes = sortNodes(computedNodes);
sortedNodes.map((n) => {
n.style = { opacity: 1 };
return n;
});
return sortedNodes;
});
return edges;
});
}

// TODO: unlisten to tasks
function listenToTask(
task: Task,
setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>,
edges: Edge[],
setEdges: Dispatch<SetStateAction<Edge[]>>
) {
task.on("progress", (progress, progressText) => {
setNodes((nds) =>
setNodes((nds) => {
nds.map((nd) => {
if (nd.id === task.config.id) {
return {
Expand All @@ -111,8 +151,10 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
};
}
return nd;
})
);
});
doNodeLayout(setNodes, setEdges);
return nds;
});
});
task.on("start", () => {
setNodes((nds) =>
Expand Down Expand Up @@ -167,7 +209,7 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
);
});
if (task.isCompound) {
listenToGraphNodes(task.subGraph, setNodes);
listenToGraphTasks(task.subGraph, setNodes, edges, setEdges);
task.on("regenerate", () => {
// console.log("Node regenerated", task.config.id);
setNodes((nodes: Node<TurboNodeData>[]) => {
Expand All @@ -181,7 +223,7 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
connectable: false,
}) as Node<TurboNodeData>
);
listenToGraphNodes(task.subGraph, setNodes);
listenToGraphTasks(task.subGraph, setNodes, edges, setEdges);
let returnNodes = nodes.filter((n) => n.parentId !== task.config.id); // remove old children
returnNodes = [...returnNodes, ...children]; // add new children
returnNodes = sortNodes(returnNodes); // sort all nodes (parent, children, parent, children, ...)
Expand All @@ -191,13 +233,15 @@ function listenToTask(task: Task, setNodes: Dispatch<SetStateAction<Node<TurboNo
}
}

function listenToGraphNodes(
function listenToGraphTasks(
graph: TaskGraph,
setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>
setNodes: Dispatch<SetStateAction<Node<TurboNodeData>[]>>,
edges: Edge[],
setEdges: Dispatch<SetStateAction<Edge[]>>
) {
const nodes = graph.getNodes();
for (const node of nodes) {
listenToTask(node, setNodes);
listenToTask(node, setNodes, edges, setEdges);
}
}

Expand All @@ -218,53 +262,27 @@ const defaultEdgeOptions = {
export const RunGraphFlow: React.FC<{
graph: TaskGraph;
}> = ({ graph }) => {
const [nodes, setNodes, onNodesChangeTheirs] = useNodesState<Node<TurboNodeData>>([]);
const [nodes, setNodes, onNodesChange] = useNodesState<Node<TurboNodeData>>([]);
const [edges, setEdges, onEdgesChange] = useEdgesState<Edge>([]);
const graphRef = useRef<TaskGraph | null>(null);

const onNodesChange = useCallback(
(changes: any) => {
console.log("Nodes changed", changes);
onNodesChangeTheirs(changes);
},
[onNodesChangeTheirs, nodes, edges]
);

const initialized = useNodesInitialized() && !nodes.some((n) => !n.measured);
const shouldLayout = useNodesInitialized() && !nodes.some((n) => !n.measured);
const { fitView } = useReactFlow();

useEffect(() => {
if (initialized) {
const computedNodes = computeLayout(
nodes,
edges,
new GraphPipelineCenteredLayout(),
new GraphPipelineLayout({ startTop: 100, startLeft: 20 })
) as Node<TurboNodeData>[];
const sortedNodes = sortNodes(computedNodes);
setNodes(
sortedNodes.map((n) => {
n.style = { opacity: 1 };
return n;
})
);
setEdges(
edges.map((n) => {
return {
...n,
style: { opacity: 1 },
};
})
);
const id: Timer | null = null;
if (shouldLayout) {
doNodeLayout(setNodes, setEdges);
setTimeout(() => {
fitView();
}, 5);
}
}, [initialized]);
}, [shouldLayout, setNodes, setEdges, fitView]);

useEffect(() => {
if (graph !== graphRef.current) {
graphRef.current = graph;
console.log("Graph changed", graph);
const nodes = sortNodes(convertGraphToNodes(graph));
setNodes(
nodes.map((n) => {
Expand All @@ -285,9 +303,9 @@ export const RunGraphFlow: React.FC<{
};
})
);
listenToGraphNodes(graph, setNodes);
listenToGraphTasks(graph, setNodes, edges, setEdges);
}
}, [graph]);
}, [graph, setNodes, setEdges, graphRef.current]);

// const onConnect = useCallback(
// (params: any) => setEdges((els) => addEdge(params, els)),
Expand Down
Loading

0 comments on commit 841aad1

Please sign in to comment.