import * as d3 from 'd3';

import { Chart } from './Chart';
import { ChartContextProvider, Datum, useChart } from './context';

/**
 * Props for {@link TreeChart}.
 *
 * `children` (inherited from `React.PropsWithChildren`) is rendered inside the
 * chart, ahead of the tree's own SVG group.
 */
interface Props<D extends Datum> extends React.PropsWithChildren {
  // --- Data accessors -------------------------------------------------------

  /** Returns the id of a datum; used as the key when indexing the chart data. */
  identifier: (d: D) => string;
  /**
   * Returns the ids of a datum's children, or `undefined` when it has none.
   * Also receives `undefined` for the virtual root that is used when no
   * `rootId` datum is available.
   */
  descendants: (d: D | undefined) => string[] | undefined;
  /**
   * Id of the datum to use as the tree root. When omitted, or when no datum
   * has this id, the tree starts from a virtual root whose datum is `undefined`.
   */
  rootId?: string;

  // --- Layout ---------------------------------------------------------------

  /** Lay depth out along the vertical axis (root at the top unless `reverse`). */
  vertical?: boolean;
  /** Reverse the layout direction by flipping the scale domains. */
  reverse?: boolean;

  // --- Rendering ------------------------------------------------------------

  /**
   * Renders one laid-out node. The result is placed inside an SVG `<g>`
   * translated to the node's position; when omitted, a circle of radius 5 is
   * drawn instead.
   */
  node?: (
    n: d3.HierarchyPointNode<D | undefined>,
  ) => React.ReactElement<React.SVGProps<SVGElement>>;
  /** Nodes whose datum fails this predicate are not rendered. */
  visible?: (d: D) => boolean;
}

/**
 * Tree (node-link) chart laid out with `d3.tree`.
 *
 * The hierarchy is assembled from the chart context's `data`: every datum is
 * indexed by `identifier`, and `descendants` returns the ids of a datum's
 * children. When `rootId` is omitted or matches no datum, the hierarchy starts
 * at a virtual root whose datum is `undefined`. In that case
 * `descendants(undefined)` is expected to return the ids of the top-level
 * items. Child ids that match no datum are ignored.
 *
 * `d3.tree` lays nodes out in the unit square (its default size is [1, 1]).
 * The linear scales below map that square into the padded chart area. Depth
 * runs along the horizontal axis, or along the vertical axis when `vertical`
 * is set.
 *
 * When `visible` is given, nodes whose datum fails it are not rendered, and a
 * link is rendered only when both of its endpoints are. A node with a falsy
 * datum (the virtual root) always counts as visible.
 *
 * NOTE(review): `descendants` must describe an acyclic relation. A cycle
 * makes `d3.hierarchy` expand forever.
 */
export const TreeChart = <D extends Datum>({
  rootId,
  identifier,
  descendants,
  vertical,
  reverse,
  node,
  visible,
  children,
}: Props<D>) => {
  const { padding, width, height, data, ...ctx } = useChart<D>(),
    index = d3.index(data, identifier),
    root = rootId ? index.get(rootId) : undefined,
    tree = d3.tree<D | undefined>()(
      d3.hierarchy(root, (parent) =>
        // Drop ids that match no datum. Keeping them would create children
        // whose datum is `undefined`. `descendants` would then treat each one
        // like the virtual root, producing phantom nodes. If
        // `descendants(undefined)` returns the top-level ids, the hierarchy
        // would also repeat without end.
        descendants(parent)
          ?.map((id) => index.get(id))
          .filter((child): child is D => child !== undefined),
      ),
    ),
    nodes = Array.from(tree),
    // Horizontal scale. Padding appears to follow CSS order
    // [top, right, bottom, left]. TODO: confirm against the context type.
    x = d3
      .scaleLinear()
      .range([padding[3], width - padding[1]])
      .domain(reverse && !vertical ? [1, 0] : [0, 1]),
    // Vertical scale. The range runs from the bottom edge up to the top edge,
    // so domain value 0 is drawn at the bottom (SVG y grows downward). The
    // domain is flipped when exactly one of `reverse` / `vertical` is set.
    // NOTE(review): in horizontal mode `reverse` flips both axes, while in
    // vertical mode it flips only the depth axis. Confirm this is intended.
    y = d3
      .scaleLinear()
      .range([height - padding[2], padding[0]])
      .domain(
        (reverse && !vertical) || (!reverse && vertical) ? [1, 0] : [0, 1],
      ),
    // Screen coordinates of a laid-out node. `d3.tree` stores breadth in `x`
    // and depth in `y`, so the two are swapped for the horizontal layout.
    px = (p: d3.HierarchyPointNode<D | undefined>) => x(vertical ? p.x : p.y),
    py = (p: d3.HierarchyPointNode<D | undefined>) => y(vertical ? p.y : p.x),
    link = d3
      .link<
        d3.HierarchyPointLink<D | undefined>,
        d3.HierarchyPointNode<D | undefined>
      >(vertical ? d3.curveBumpY : d3.curveBumpX)
      .x(px)
      .y(py),
    // Single visibility rule shared by nodes and links. A falsy datum (the
    // virtual root) is always shown.
    isShown = (d: D | undefined) => !visible || !d || visible(d),
    renderNode = node ?? (() => <circle r={5} stroke='currentColor' />);

  return (
    <ChartContextProvider
      value={{
        ...ctx,
        data,
        padding,
        width,
        height,
      }}
    >
      <Chart>
        {children}
        <g data-chart-component='chart' data-chart-type='tree'>
          <g data-chart-component='links'>
            {/* Map before filtering, so each key is the item's index in the
                full list. Toggling visibility then never reassigns a key to
                a different link or node. */}
            {tree.links().map((l, i) =>
              isShown(l.source.data) && isShown(l.target.data) ? (
                <path
                  key={i}
                  d={link(l) ?? undefined}
                  stroke='currentColor'
                  fill='none'
                />
              ) : null,
            )}
          </g>
          <g data-chart-component='nodes'>
            {nodes.map((n, i) =>
              isShown(n.data) ? (
                <g key={i} transform={`translate(${px(n)} ${py(n)})`}>
                  {renderNode(n)}
                </g>
              ) : null,
            )}
          </g>
        </g>
      </Chart>
    </ChartContextProvider>
  );
};
