import { useLayoutEffect, useMemo, useRef, useState } from 'react';

import * as d3 from 'd3';

import { Datum, StringValue, useChart } from './context';

/**
 * Tracks the content rect of an element using a `ResizeObserver`.
 *
 * Attach the returned ref to an element. `bounds` is replaced whenever the
 * observer reports a content rect whose top/left/width/height differ from the
 * previous one. Until the first observation it is an empty `DOMRect`.
 *
 * @returns a tuple of the ref to attach and the latest observed content rect.
 */
export const useElementBounds = <T extends HTMLElement = HTMLDivElement>(): [
  React.MutableRefObject<T | null>,
  DOMRectReadOnly,
] => {
  const ref = useRef<T | null>(null),
    // Lazy initializer: the placeholder rect is only built on the first render.
    [bounds, setBounds] = useState<DOMRectReadOnly>(() => new DOMRect()),
    // The element currently observed plus its observer, so we only
    // (re)subscribe when the ref actually points at a different element.
    subscription = useRef<{ target: T; observer: ResizeObserver } | null>(
      null,
    );

  // No dependency list: ref.current may change between commits (e.g. a
  // conditionally rendered element), so check after every commit — but only
  // do work when the observed element has changed.
  useLayoutEffect(() => {
    const target = ref.current;
    if (subscription.current?.target === target) return;

    subscription.current?.observer.disconnect();
    subscription.current = null;
    if (target === null) return;

    const observer = new ResizeObserver((entries) => {
      const cur = entries.at(0)?.contentRect;
      if (cur === undefined) return;
      // Keep the previous object when nothing changed so React can bail out
      // of the re-render.
      setBounds((prev) =>
        prev.top === cur.top &&
        prev.left === cur.left &&
        prev.width === cur.width &&
        prev.height === cur.height
          ? prev
          : cur,
      );
    });
    observer.observe(target);
    subscription.current = { target, observer };
  });

  // Disconnect whatever is being observed when the component unmounts.
  useLayoutEffect(
    () => () => {
      subscription.current?.observer.disconnect();
      subscription.current = null;
    },
    [],
  );

  return [ref, bounds];
};

/**
 * Creates a linear scale whose output range spans the chart's horizontal
 * drawing area: from `padding[3]` to `width - padding[1]` (the left/right
 * insets if padding is in CSS-style [top, right, bottom, left] order).
 * A new scale is returned on every call; the caller sets the domain.
 */
export const useLinearScale = () => {
  const chart = useChart();
  const xStart = chart.padding[3];
  const xEnd = chart.width - chart.padding[1];
  return d3.scaleLinear().range([xStart, xEnd]);
};

/**
 * Creates a band scale over the chart's horizontal drawing area, from
 * `padding[3]` to `width - padding[1]`. A new scale is returned on every
 * call; the caller supplies the band domain.
 */
export const useBandScale = <Domain extends StringValue = StringValue>() => {
  const chart = useChart();
  const bandStart = chart.padding[3];
  const bandEnd = chart.width - chart.padding[1];
  const scale = d3.scaleBand<Domain>();
  return scale.range([bandStart, bandEnd]);
};

/**
 * Creates a linear scale for the vertical axis. The output range runs from
 * `height - padding[2]` down to `padding[0]`, so the low end of the domain
 * maps to the larger screen y (lower on screen) and the high end to the top.
 * A new scale is returned on every call; the caller sets the domain.
 */
export const useRangeScale = () => {
  const { height: chartHeight, padding: inset } = useChart();
  const yBottom = chartHeight - inset[2];
  const yTop = inset[0];
  return d3.scaleLinear().range([yBottom, yTop]);
};

/**
 * Stacks the chart's data into d3 series, one layer per distinct `series` key.
 *
 * Data is indexed by `domain` and then by `series`. Each layer's value at a
 * domain position is `value(datum)`. `datum` is `undefined` when no datum
 * exists for that (domain, series) pair.
 *
 * @param domain - key of the position (e.g. band) each datum belongs to.
 * @param value - numeric value to stack; receives `undefined` for gaps.
 * @param series - key of the layer each datum belongs to.
 * @param order - stack order; defaults to `d3.stackOrderDescending`.
 * @param offset - stack offset; defaults to `d3.stackOffsetNone`.
 * @returns the stacked series. The result is memoized on `data` and on the
 *   callback identities, so pass stable (e.g. `useCallback`) callbacks to
 *   avoid recomputing on every render.
 */
export const useStack = <
  D extends Datum,
  Domain extends StringValue = StringValue,
>(
  domain: (d: D) => Domain,
  value: (d: D | undefined) => number,
  series: (d: D) => string,
  order?: (
    series: d3.Series<[number, d3.InternMap<string, D>], string>[],
  ) => Iterable<number>,
  offset?: (
    series: d3.Series<[number, d3.InternMap<string, D>], string>[],
    order: number[],
  ) => void,
) => {
  const { data } = useChart<D>();
  return useMemo(() => {
    // Build the index inside the memo. d3.index returns a new InternMap on
    // every call, so building it during render and listing it as a dependency
    // made the memo recompute on every render.
    // NOTE(review): d3.index requires each (domain, series) pair to be unique.
    const index = d3.index(data, domain, series);
    return d3
      .stack<[Domain, d3.InternMap<string, D>], string>()
      .order(order ?? d3.stackOrderDescending)
      .offset(offset ?? d3.stackOffsetNone)
      .keys(d3.union(data.map(series)))
      .value(([, group], key) => value(group.get(key)))(index);
  }, [data, domain, series, value, order, offset]);
};
