Skip to content

Commit

Permalink
refactor(D3 plugin): render scatter via d3 (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
korvin89 authored Sep 4, 2023
1 parent fff81f1 commit 2170c73
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 71 deletions.
48 changes: 31 additions & 17 deletions src/plugins/d3/renderer/hooks/useShapes/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ import React from 'react';
import {group} from 'd3';

import type {ScatterSeries} from '../../../../../types/widget-data';
import {getRandomCKId} from '../../../../../utils';

import {getOnlyVisibleSeries} from '../../utils';
import type {ChartOptions} from '../useChartOptions/types';
import type {ChartScale} from '../useAxisScales';
import type {PreparedBarXSeries, PreparedPieSeries, PreparedSeries} from '../';
import type {OnSeriesMouseMove, OnSeriesMouseLeave} from '../useTooltip/types';
import {BarXSeriesShapes} from './bar-x';
import {prepareScatterSeries} from './scatter';
import {ScatterSeriesShape} from './scatter';
import {PieSeriesComponent} from './pie';

import './styles.scss';
Expand Down Expand Up @@ -56,39 +57,50 @@ export const useShapes = (args: Args) => {
if (xScale && yScale) {
acc.push(
<BarXSeriesShapes
{...args}
key="bar-x"
series={chartSeries as PreparedBarXSeries[]}
xAxis={xAxis}
xScale={xScale}
yAxis={yAxis}
yScale={yScale}
top={top}
left={left}
svgContainer={svgContainer}
onSeriesMouseMove={onSeriesMouseMove}
onSeriesMouseLeave={onSeriesMouseLeave}
/>,
);
}
break;
}
case 'scatter': {
if (xScale && yScale) {
acc.push(
...prepareScatterSeries({
top,
left,
series: chartSeries as ScatterSeries[],
xAxis,
xScale,
yAxis,
yScale,
onSeriesMouseMove,
onSeriesMouseLeave,
svgContainer,
}),
);
const scatterShapes = chartSeries.map((scatterSeries, i) => {
const id = getRandomCKId();
return (
<ScatterSeriesShape
key={`${i}-${id}`}
top={top}
left={left}
series={scatterSeries as ScatterSeries}
xAxis={xAxis}
xScale={xScale}
yAxis={yAxis}
yScale={yScale}
onSeriesMouseMove={onSeriesMouseMove}
onSeriesMouseLeave={onSeriesMouseLeave}
svgContainer={svgContainer}
/>
);
});
acc.push(...scatterShapes);
}
break;
}
case 'pie': {
const groupedPieSeries = group(
chartSeries as PreparedPieSeries[],
(item) => item.stackId,
(pieSeries) => pieSeries.stackId,
);
acc.push(
...Array.from(groupedPieSeries).map(([key, pieSeries]) => {
Expand Down Expand Up @@ -118,6 +130,8 @@ export const useShapes = (args: Args) => {
yAxis,
yScale,
svgContainer,
left,
top,
onSeriesMouseMove,
onSeriesMouseLeave,
]);
Expand Down
128 changes: 74 additions & 54 deletions src/plugins/d3/renderer/hooks/useShapes/scatter.tsx
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import {pointer, ScaleBand, ScaleLinear, ScaleTime} from 'd3';
import {pointer, select} from 'd3';
import type {ScaleBand, ScaleLinear, ScaleTime} from 'd3';
import React from 'react';
import {ChartOptions} from '../useChartOptions/types';
import {ChartScale} from '../useAxisScales';
import {OnSeriesMouseLeave, OnSeriesMouseMove} from '../useTooltip/types';
import {ScatterSeries, ScatterSeriesData} from '../../../../../types/widget-data';
import {block} from '../../../../../utils/cn';
import {getRandomCKId} from '../../../../../utils';

type PrepareScatterSeriesArgs = {
type ScatterSeriesShapeProps = {
top: number;
left: number;
series: ScatterSeries[];
series: ScatterSeries;
xAxis: ChartOptions['xAxis'];
xScale: ChartScale;
yAxis: ChartOptions['yAxis'];
Expand All @@ -31,17 +31,14 @@ const prepareLinearScatterData = (data: ScatterSeriesData[]) => {
return data.filter((d) => typeof d.x === 'number' && typeof d.y === 'number');
};

const getPointProperties = (args: {
const getCxAttr = (args: {
point: ScatterSeriesData;
xAxis: ChartOptions['xAxis'];
xScale: ChartScale;
yAxis: ChartOptions['yAxis'];
yScale: ChartScale;
}) => {
const {point, xAxis, xScale, yAxis, yScale} = args;
const r = point.radius || DEFAULT_SCATTER_POINT_RADIUS;
let cx: string | number | undefined;
let cy: string | number | undefined;
const {point, xAxis, xScale} = args;

let cx: number;

if (xAxis.type === 'category') {
const xBandScale = xScale as ScaleBand<string>;
Expand All @@ -51,6 +48,18 @@ const getPointProperties = (args: {
cx = xLinearScale(point.x as number);
}

return cx;
};

const getCyAttr = (args: {
point: ScatterSeriesData;
yAxis: ChartOptions['yAxis'];
yScale: ChartScale;
}) => {
const {point, yAxis, yScale} = args;

let cy: number;

if (yAxis[0].type === 'category') {
const yBandScale = yScale as ScaleBand<string>;
cy = (yBandScale(point.category as string) || 0) + yBandScale.step() / 2;
Expand All @@ -59,62 +68,73 @@ const getPointProperties = (args: {
cy = yLinearScale(point.y as number);
}

return {r, cx, cy};
return cy;
};

export function prepareScatterSeries(args: PrepareScatterSeriesArgs) {
export function ScatterSeriesShape(props: ScatterSeriesShapeProps) {
const {
top,
left,
series,
xAxis,
xScale,
yAxis,
yScale,
svgContainer,
left,
top,
onSeriesMouseMove,
onSeriesMouseLeave,
svgContainer,
} = args;
} = props;
const ref = React.useRef<SVGGElement>(null);

return series.reduce<React.ReactElement[]>((result, s) => {
const randomKey = getRandomCKId();
React.useEffect(() => {
if (!ref.current) {
return;
}

const svgElement = select(ref.current);
svgElement.selectAll('*').remove();
const preparedData =
xAxis.type === 'category' || yAxis[0]?.type === 'category'
? prepareCategoricalScatterData(s.data)
: prepareLinearScatterData(s.data);

result.push(
...preparedData.map((point, i) => {
const pointProps = getPointProperties({
point,
xAxis,
xScale,
yAxis,
yScale,
? prepareCategoricalScatterData(series.data)
: prepareLinearScatterData(series.data);

svgElement
.selectAll('allPoints')
.data(preparedData)
.enter()
.append('circle')
.attr('class', b('point'))
.attr('fill', series.color || '')
.attr('r', (d) => d.radius || DEFAULT_SCATTER_POINT_RADIUS)
.attr('cx', (d) => getCxAttr({point: d, xAxis, xScale}))
.attr('cy', (d) => getCyAttr({point: d, yAxis, yScale}))
.on('mousemove', (e, d) => {
const [x, y] = pointer(e, svgContainer);
onSeriesMouseMove?.({
hovered: {
data: d,
series,
},
pointerPosition: [x - left, y - top],
});
})
.on('mouseleave', () => {
if (onSeriesMouseLeave) {
onSeriesMouseLeave();
}
});
}, [
series,
xAxis,
xScale,
yAxis,
yScale,
svgContainer,
left,
top,
onSeriesMouseMove,
onSeriesMouseLeave,
]);

return (
<circle
key={`${i}-${randomKey}`}
className={b('point')}
fill={s.color}
{...pointProps}
onMouseMove={function (e) {
const [x, y] = pointer(e, svgContainer);
onSeriesMouseMove?.({
hovered: {
data: point,
series: s,
},
pointerPosition: [x - left, y - top],
});
}}
onMouseLeave={onSeriesMouseLeave}
/>
);
}),
);

return result;
}, []);
return <g ref={ref} className={b()} />;
}

0 comments on commit 2170c73

Please sign in to comment.