import {
  createContext,
  useCallback,
  useContext,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react';
import { createPortal } from 'react-dom';

import * as d3 from 'd3';

import { Datum, useChart } from './context';
import { isScaleBand } from './types';

/**
 * Hover state published by a `Tooltip` to its content renderer and to
 * descendants through `useTooltip`. Every field is optional; consumers
 * rendered outside any `Tooltip` receive the context default `{}`.
 */
interface TooltipValue<D extends Datum> {
  /** The hovered datum. */
  datum?: D;
  /** Index of `datum` within `data` (`-1` until the pointer first moves over the chart). */
  i?: number;
  /** The chart's data array. */
  data?: D[];
  /** `[x, y]` of the hovered point in viewport coordinates, before the tooltip's placement offset. */
  pos?: [number, number];
}

/**
 * Context carrying the hover state of the nearest enclosing `Tooltip`.
 * Consumers rendered outside any `Tooltip` see the empty default `{}`.
 */
const TooltipContext = createContext<TooltipValue<Datum>>({});

/**
 * Reads the hover state of the nearest enclosing `Tooltip`.
 *
 * @typeParam D - datum type of that `Tooltip`. This is NOT checked: the
 *   context is stored as `TooltipValue<Datum>` and narrowed here, so callers
 *   must pass the same datum type the chart uses.
 * @returns the current hover state, or `{}` outside any `Tooltip`.
 */
export const useTooltip = <D extends Datum>(): TooltipValue<D> =>
  // A single downcast suffices: TooltipValue<D> is assignable to
  // TooltipValue<Datum>, so no detour through `unknown` is needed.
  useContext(TooltipContext) as TooltipValue<D>;

/**
 * Props for `Tooltip`. Any `children` are rendered inside the tooltip box,
 * after the body.
 *
 * @typeParam D - datum type of the surrounding chart.
 * @typeParam Domain - value type produced by the x accessor.
 */
interface Props<D extends Datum, Domain extends string | number = number>
  extends React.PropsWithChildren {
  /**
   * Custom renderer for the tooltip body. When omitted, a bordered box showing
   * `accessors[0](datum): accessors[1]?.(datum)` is rendered instead.
   */
  content?: (d: TooltipValue<D>) => JSX.Element;

  /**
   * `[xAccessor, yAccessor]`, used to project each datum through the chart's
   * x and y scales. When the y accessor is absent (or `y` maps its result to
   * a nullish value), the point falls back to `y(sum(y.domain()) / 2)`.
   */
  accessors:
    | [(d: D) => Domain, ((d: D) => number) | undefined]
    | ((d: D) => any)[];
}

/**
 * X offset that moves a horizontally-centered tooltip beside its anchor:
 * fully to the right if that fits inside `container`, otherwise fully to the
 * left if that fits, otherwise `undefined` (leave it where it is).
 */
const sideXOffset = (
  container: DOMRect,
  tooltip: DOMRect,
  tooltipLeft: number,
  tooltipRight: number,
  xMargin: number,
): number | undefined => {
  // tooltipRight + width / 2 is the right edge when offset by xMargin.
  if (tooltipRight + tooltip.width / 2 < container.right) return xMargin;
  // tooltipLeft - width / 2 is the left edge when offset by -width - xMargin.
  if (tooltipLeft - tooltip.width / 2 > container.left)
    return -tooltip.width - xMargin;
  return undefined;
};

/**
 * Chooses where to draw the tooltip relative to its anchor point.
 *
 * By default the tooltip is centered horizontally above the anchor. It is
 * then adjusted to stay inside `container`:
 * - wider than the container: placed around the container's horizontal center;
 * - overflowing right/left: placed entirely left/right of the anchor;
 * - taller than the container: placed around the container's vertical center;
 * - overflowing the top: vertically centered on the anchor if that fits,
 *   otherwise placed below it;
 * - overflowing the bottom: moved up.
 * In the "taller" and "top overflow" cases a still horizontally-centered
 * tooltip is additionally pushed beside the anchor when possible.
 *
 * TODO: make the placement strategy configurable.
 *
 * @param container - bounding rect of the chart svg (viewport coordinates).
 * @param tooltip - bounding rect of the rendered tooltip; only its size is used.
 * @param anchor - anchor point in viewport coordinates.
 * @param bandwidthPadding - extra horizontal gap (half a band for band scales).
 * @returns `[dx, dy]` to add to `anchor` to get the tooltip's top-left corner.
 */
const calculateTooltipOffset = (
  container: DOMRect,
  tooltip: DOMRect,
  anchor: readonly [number, number],
  bandwidthPadding: number,
): [number, number] => {
  const margin = 8,
    xMargin = margin + bandwidthPadding,
    defaultXOffset = -tooltip.width / 2,
    defaultYOffset = -tooltip.height - margin,
    tooltipLeft = anchor[0] + defaultXOffset - xMargin,
    tooltipRight = anchor[0] + tooltip.width + defaultXOffset + xMargin,
    tooltipTop = anchor[1] + defaultYOffset,
    tooltipBottom = anchor[1] + tooltip.height + defaultYOffset - margin;

  const offset: [number, number] = [defaultXOffset, defaultYOffset];

  // Horizontal placement.
  if (tooltip.width > container.width)
    // Relative to the anchor, so subtract it to land near the container center.
    offset[0] =
      -anchor[0] +
      container.left +
      container.width / 2 -
      tooltip.width / 2 +
      xMargin;
  else if (tooltipRight > container.right) offset[0] = -tooltip.width - xMargin;
  else if (tooltipLeft < container.left) offset[0] = xMargin;

  // Vertical placement.
  const topOverflow = container.top - tooltipTop,
    bottomOverflow = tooltipBottom - container.bottom;
  const pushBesideIfCentered = () => {
    if (offset[0] === defaultXOffset)
      offset[0] =
        sideXOffset(container, tooltip, tooltipLeft, tooltipRight, xMargin) ??
        offset[0];
  };

  if (tooltip.height > container.height) {
    offset[1] =
      -anchor[1] +
      container.top +
      container.height / 2 -
      tooltip.height / 2 +
      margin;
    pushBesideIfCentered();
  } else if (topOverflow > 0) {
    offset[1] =
      topOverflow < tooltip.height / 2 - margin
        ? -tooltip.height / 2 // fits when vertically centered on the anchor
        : margin; // otherwise drop below the anchor
    pushBesideIfCentered();
  } else if (bottomOverflow > 0) offset[1] -= bottomOverflow + margin;

  return offset;
};

/**
 * Pointer-following tooltip for the surrounding chart (from `useChart`).
 *
 * While the pointer is over the chart's svg, snaps to the datum whose
 * projected point is nearest the pointer (by x distance alone for band x
 * scales, Euclidean distance otherwise) and portals a `position: fixed` box
 * into the chart's `ref` element. The hovered index, datum and position are
 * exposed to `content` and to descendants through `useTooltip`.
 *
 * @param accessors - `[x, y?]` accessors used to project data to points and
 *   to render the default body.
 * @param content - optional custom renderer for the tooltip body.
 * @param children - rendered inside the tooltip box after the body.
 */
export const Tooltip = <
  D extends Datum,
  Domain extends string | number = number,
>({
  accessors,
  content,
  children,
}: Props<D, Domain>) => {
  const { ref, svg, data, x, y } = useChart<D>(),
    tooltipRef = useRef<HTMLDivElement | null>(null),
    [open, setOpen] = useState(false),
    [pos, setPos] = useState<[number, number]>([0, 0]),
    [offset, setOffset] = useState<[number, number]>([0, 0]),
    [datum, setDatum] = useState<D>(),
    [index, setIndex] = useState(-1);

  // Band scales return the band's start; shift points to the band center.
  const bandwidthPadding = (isScaleBand(x) ? x.bandwidth() : 0) / 2,
    points = useMemo(
      () =>
        data.map((d) => [
          (x(accessors[0](d)) ?? 0) + bandwidthPadding,
          y(accessors[1]?.(d)) ?? y(d3.sum(y.domain()) / 2),
          d,
        ]) as [number, number, D][],
      [accessors, bandwidthPadding, data, x, y],
    ),
    onPointerMove = useCallback(
      (e: PointerEvent) => {
        const container = svg?.current?.getBoundingClientRect();
        if (container === undefined) return;

        const [cx, cy] = d3.pointer(e),
          byXOnly = isScaleBand(x),
          i =
            d3.leastIndex(points, ([dx, dy]) =>
              byXOnly ? Math.abs(dx - cx) : Math.hypot(dx - cx, dy - cy),
            ) ?? -1,
          nearest = points[i];
        // No nearest point (e.g. `data` is empty): keep the previous hover
        // state instead of dereferencing points[-1].
        if (nearest === undefined) return;

        const [px, py, d] = nearest,
          // `?? 0` guards a scale returning undefined despite the tuple type.
          anchor: [number, number] = [
            container.left + (px ?? 0),
            container.top + (py ?? 0),
          ];

        setIndex(i);
        setDatum(d);
        setPos(anchor);

        // NOTE(review): with React 18 automatic batching the updates above
        // have not rendered yet, so this measures the tooltip's current
        // content (and is skipped while it is unmounted, e.g. on the first
        // move); the offset may lag a datum behind when the size changes.
        const tooltip = tooltipRef.current?.getBoundingClientRect();
        if (tooltip === undefined) return;

        setOffset(
          calculateTooltipOffset(container, tooltip, anchor, bandwidthPadding),
        );
      },
      [svg, points, bandwidthPadding, x],
    );

  // Open while the pointer is over the svg and track it while it moves there.
  useEffect(() => {
    const target = svg?.current;
    if (!target) return;
    const show = () => void setOpen(true),
      hide = () => void setOpen(false);

    target.addEventListener('pointerenter', show);
    target.addEventListener('pointerleave', hide);
    target.addEventListener('pointermove', onPointerMove);
    return () => {
      target.removeEventListener('pointerenter', show);
      target.removeEventListener('pointerleave', hide);
      target.removeEventListener('pointermove', onPointerMove);
    };
  }, [svg, onPointerMove]);

  // Stable identity so context consumers only re-render when hover changes.
  const value = useMemo(
    () => ({ i: index, datum, pos, data }),
    [index, datum, pos, data],
  );

  return (
    <TooltipContext.Provider value={value}>
      {open &&
        ref?.current &&
        datum !== undefined &&
        createPortal(
          <div
            ref={tooltipRef}
            style={{
              pointerEvents: 'none',
              transition: 'top 0.3s ease-out, left 0.3s ease-out',
              position: 'fixed',
              top: pos[1] + offset[1],
              left: pos[0] + offset[0],
            }}
          >
            {content !== undefined ? (
              content(value)
            ) : (
              <div
                style={{
                  width: '100%',
                  height: '100%',
                  color: 'black',
                  padding: '0 8px',
                  backgroundColor: 'white',
                  border: '1px solid black',
                  borderRadius: '4px',
                }}
              >
                <p>
                  {accessors[0](datum)}: {accessors[1]?.(datum)}
                </p>
              </div>
            )}
            {children}
          </div>,
          ref.current,
        )}
    </TooltipContext.Provider>
  );
};
