Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use matrix API to bootstrap graph view #216

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 29 additions & 4 deletions src/components/GraphVisualisation/GraphVisualisation.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import { deduplicatePoints, getSimilarPoints, initGraph } from '../../lib/graph-
import ForceGraph from 'force-graph';
import { useClient } from '../../context/client-context';
import { useSnackbar } from 'notistack';
import { debounce } from 'lodash';
import { resizeObserverWithCallback } from '../../lib/common-helpers';

const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) => {
const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef, sampleLinks }) => {
const graphRef = useRef(null);
const { client: qdrantClient } = useClient();
const { enqueueSnackbar } = useSnackbar();
Expand Down Expand Up @@ -53,7 +55,9 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
onDataDisplay(node);
})
.autoPauseRedraw(false)
.nodeCanvasObjectMode((node) => (node?.id === highlightedNode?.id ? 'before' : undefined))
.nodeCanvasObjectMode((node) => {
return node?.id === highlightedNode?.id ? 'before' : undefined;
})
.nodeCanvasObject((node, ctx) => {
if (!node) return;
// add ring for last hovered nodes
Expand All @@ -62,18 +66,33 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
ctx.fillStyle = node.id === highlightedNode?.id ? '#817' : 'transparent';
ctx.fill();
})
.linkLabel('score')
.linkColor(() => '#a6a6a6');

graphRef.current.d3Force('charge').strength(-10);
}, [initNode, options]);

useEffect(() => {
if (!wrapperRef) return;

const debouncedResizeCallback = debounce((width, height) => {
graphRef.current.width(width).height(height);
}, 500);

graphRef.current.width(wrapperRef?.clientWidth).height(wrapperRef?.clientHeight);
resizeObserverWithCallback(debouncedResizeCallback).observe(wrapperRef);

return () => {
resizeObserverWithCallback(debouncedResizeCallback).unobserve(wrapperRef);
};
}, [wrapperRef, initNode, options]);

useEffect(() => {
const initNewGraph = async () => {
const graphData = await initGraph(qdrantClient, {
...options,
initNode,
sampleLinks,
});
if (graphRef.current && options) {
const initialActiveNode = graphData.nodes[0];
Expand All @@ -83,9 +102,14 @@ const GraphVisualisation = ({ initNode, options, onDataDisplay, wrapperRef }) =>
}
};
initNewGraph().catch((e) => {
enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' });
console.error(e);
if (e.getActualType) {
enqueueSnackbar(JSON.stringify(e.getActualType()), { variant: 'error' });
} else {
enqueueSnackbar(e.message, { variant: 'error' });
}
});
}, [initNode, options]);
}, [initNode, options, sampleLinks]);

return <div id="graph"></div>;
};
Expand All @@ -95,6 +119,7 @@ GraphVisualisation.propTypes = {
options: PropTypes.object.isRequired,
onDataDisplay: PropTypes.func.isRequired,
wrapperRef: PropTypes.object,
sampleLinks: PropTypes.array,
};

export default GraphVisualisation;
9 changes: 9 additions & 0 deletions src/lib/common-helpers.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
export const resizeObserverWithCallback = (callback) => {
return new ResizeObserver((entries) => {
for (const entry of entries) {
const { target } = entry;
const { width, height } = target.getBoundingClientRect();
if (typeof callback === 'function') callback(width, height);
}
});
};
125 changes: 118 additions & 7 deletions src/lib/graph-visualization-helpers.js
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
export const initGraph = async (qdrantClient, { collectionName, initNode, limit, filter, using }) => {
if (!initNode) {
import { axiosInstance } from '../common/axios';

export const initGraph = async (
qdrantClient,
{ collectionName, initNode, limit, filter, using, sampleLinks, tree = false }
) => {
let nodes = [];
let links = [];

if (sampleLinks) {
const uniquePoints = new Set();

for (const link of sampleLinks) {
links.push({ source: link.a, target: link.b, score: link.score });
uniquePoints.add(link.a);
uniquePoints.add(link.b);
}

if (tree) {
// ToDo acs should depend on metric type
links = getMinimalSpanningTree(links, true);
}

nodes = await getPointsWithPayload(qdrantClient, { collectionName, pointIds: Array.from(uniquePoints) });
} else if (initNode) {
initNode.clicked = true;
nodes = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using });
links = nodes.map((point) => ({ source: initNode.id, target: point.id, score: point.score }));
nodes = [initNode, ...nodes];
} else {
return {
nodes: [],
links: [],
};
}
initNode.clicked = true;

const points = await getSimilarPoints(qdrantClient, { collectionName, pointId: initNode.id, limit, filter, using });

const graphData = {
nodes: [initNode, ...points],
links: points.map((point) => ({ source: initNode.id, target: point.id })),
nodes,
links,
};
// console.log(graphData);
return graphData;
};

Expand Down Expand Up @@ -44,9 +70,94 @@ export const getFirstPoint = async (qdrantClient, { collectionName, filter }) =>
return points[0];
};

const getPointsWithPayload = async (qdrantClient, { collectionName, pointIds }) => {
const points = await qdrantClient.retrieve(collectionName, {
ids: pointIds,
with_payload: true,
with_vector: false,
});

return points;
};

export const getSamplePoints = async ({ collectionName, filter, sample, using, limit }) => {
// ToDo: replace it with qdrantClient when it will be implemented

const response = await axiosInstance({
method: 'POST',
url: `collections/${collectionName}/points/search/matrix/pairs`,
data: {
filter,
sample,
using,
limit,
},
});

return response.data.result.pairs;
};

export const deduplicatePoints = (existingPoints, foundPoints) => {
// Returns array of found points that are not in existing points
// deduplication is done by id
const existingIds = new Set(existingPoints.map((point) => point.id));
return foundPoints.filter((point) => !existingIds.has(point.id));
};

export const getMinimalSpanningTree = (links, acs = true) => {
// Sort links by score (assuming each link has a score property)

let sortedLinks = [];
if (acs) {
sortedLinks = links.sort((a, b) => b.score - a.score);
} else {
sortedLinks = links.sort((a, b) => a.score - b.score);
}
// Helper function to find the root of a node
const findRoot = (parent, i) => {
if (parent[i] === i) {
return i;
}
return findRoot(parent, parent[i]);
};

// Helper function to perform union of two sets
const union = (parent, rank, x, y) => {
const rootX = findRoot(parent, x);
const rootY = findRoot(parent, y);

if (rank[rootX] < rank[rootY]) {
parent[rootX] = rootY;
} else if (rank[rootX] > rank[rootY]) {
parent[rootY] = rootX;
} else {
parent[rootY] = rootX;
rank[rootX]++;
}
};

const parent = {};
const rank = {};
const mstLinks = [];

// Initialize parent and rank arrays
links.forEach((link) => {
parent[link.source] = link.source;
parent[link.target] = link.target;
rank[link.source] = 0;
rank[link.target] = 0;
});

// Kruskal's algorithm
sortedLinks.forEach((link) => {
const sourceRoot = findRoot(parent, link.source);
const targetRoot = findRoot(parent, link.target);

if (sourceRoot !== targetRoot) {
mstLinks.push(link);
union(parent, rank, sourceRoot, targetRoot);
}
});

return mstLinks;
};
56 changes: 56 additions & 0 deletions src/lib/tests/graph-visualization-helpers.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import { describe, it, expect } from 'vitest';
import { getMinimalSpanningTree } from '../graph-visualization-helpers';

describe('getMinimalSpanningTree', () => {
it('should return the minimal spanning tree for a given set of links (ascending order)', () => {
const links = [
{ source: 'A', target: 'B', score: 1 },
{ source: 'B', target: 'C', score: 2 },
{ source: 'A', target: 'C', score: 3 },
{ source: 'C', target: 'D', score: 4 },
{ source: 'B', target: 'D', score: 5 },
];

const expectedMST = [
{ source: 'B', target: 'D', score: 5 },
{ source: 'C', target: 'D', score: 4 },
{ source: 'A', target: 'C', score: 3 },
];

const result = getMinimalSpanningTree(links, true);
expect(result).toEqual(expectedMST);
});

it('should return the minimal spanning tree for a given set of links (descending order)', () => {
const links = [
{ source: 'A', target: 'B', score: 1 },
{ source: 'B', target: 'C', score: 2 },
{ source: 'A', target: 'C', score: 3 },
{ source: 'C', target: 'D', score: 4 },
{ source: 'B', target: 'D', score: 5 },
];

const expectedMST = [
{ source: 'A', target: 'B', score: 1 },
{ source: 'B', target: 'C', score: 2 },
{ source: 'C', target: 'D', score: 4 },
];

const result = getMinimalSpanningTree(links, false);
expect(result).toEqual(expectedMST);
});

it('should return an empty array if no links are provided', () => {
const links = [];
const expectedMST = [];
const result = getMinimalSpanningTree(links, true);
expect(result).toEqual(expectedMST);
});

it('should handle a single link correctly', () => {
const links = [{ source: 'A', target: 'B', score: 1 }];
const expectedMST = [{ source: 'A', target: 'B', score: 1 }];
const result = getMinimalSpanningTree(links, true);
expect(result).toEqual(expectedMST);
});
});
Loading
Loading