import {
  createContext,
  useCallback,
  useContext,
  useEffect,
  useMemo,
  useRef,
  useState,
} from 'react';
import { createPortal } from 'react-dom';

import * as d3 from 'd3';

import { Datum, useChart } from './context';
import { isScaleBand } from './types';

/**
 * Tooltip state shared through `TooltipContext`.
 *
 * Every field is optional because the context default is an empty object.
 *
 * @typeParam D - the chart's datum type
 */
interface TooltipValue<D extends Datum> {
  /** Datum resolved for the hovered element by the tooltip provider. */
  datum?: D;
  /** Flat index of the hovered element, or -1 when nothing is hovered. */
  i?: number;
  /** The chart's data, as returned by `useChart`. */
  data?: D[];
  /** Tooltip anchor point in viewport (client) pixels: `[x, y]`. */
  pos?: [number, number];
}

/** Carries the hovered-element state from `TidyTooltip` to its descendants. */
const TooltipContext = createContext<TooltipValue<Datum>>({});

/**
 * Read the tooltip state published by the nearest enclosing `TidyTooltip`.
 *
 * Outside a `TidyTooltip` this returns the context default, `{}`.
 *
 * @typeParam D - the chart's datum type. It is narrowed from `Datum` with a
 *   single type assertion and is not checked at runtime.
 */
export const useTidyTooltip = <D extends Datum>() =>
  // One checked assertion: TooltipValue<D> is assignable to TooltipValue<Datum>,
  // so no `as unknown as` escape hatch is needed.
  useContext(TooltipContext) as TooltipValue<D>;

/**
 * Props for `TidyTooltip`.
 *
 * `children` (from `PropsWithChildren`) are rendered inside the tooltip box
 * whenever it is open.
 */
interface Props<D extends Datum> extends React.PropsWithChildren {
  /**
   * Renders the tooltip body while an element is hovered. It receives the same
   * value that is provided to `useTidyTooltip`. When omitted, a built-in debug
   * view is shown instead.
   *
   * Typed as `React.ReactNode` rather than the global `JSX.Element`: that global
   * namespace is deprecated in newer React type definitions, and `ReactNode`
   * also lets a renderer return `null` to show nothing.
   */
  content?: (d: TooltipValue<D>) => React.ReactNode;
}

/** Hit area for one stacked bar, tagged with where it lives in `stack`. */
interface Hitbox {
  /** Corner points of the bar in scale (chart) coordinates. */
  polygon: [number, number][];
  /** Index of the bar's series within `stack`. */
  seriesIndex: number;
  /** Index of the bar within its series. */
  pointIndex: number;
}

// TODO: make calculateOffset configurable
/**
 * Compute the `[left, top]` offset, in pixels, to add to the anchor `pos` so
 * the tooltip box stays inside `container` where it can.
 *
 * By default the tooltip is centred horizontally above `pos` with an 8px gap.
 * When that overflows, the tooltip is moved to the left or right of `pos`,
 * below or beside it, or centred within the container as a last resort.
 *
 * @param container - bounding rect of the chart svg, in viewport coordinates
 * @param tooltip - bounding rect of the rendered tooltip element
 * @param pos - anchor point in viewport coordinates
 * @param bandwidth - band width of the x scale (0 for non-band scales). It widens
 *   the horizontal gap so a tooltip placed beside `pos` clears a bar centred on it.
 * @returns `[xOffset, yOffset]` relative to `pos`
 */
const calculateOffset = (
  container: DOMRect,
  tooltip: DOMRect,
  pos: [number, number],
  bandwidth: number,
): [number, number] => {
  const margin = 8,
    xMargin = margin + bandwidth / 2,
    defaultXOffset = -tooltip.width / 2,
    defaultYOffset = -tooltip.height - margin,
    tooltipLeft = pos[0] + defaultXOffset - xMargin,
    tooltipRight = pos[0] + tooltip.width + defaultXOffset + xMargin,
    tooltipTop = pos[1] + defaultYOffset,
    tooltipBottom = pos[1] + tooltip.height + defaultYOffset - margin,
    // vertical offset that centres the tooltip within the container
    centeredYOffset =
      -pos[1] +
      container.top +
      container.height / 2 -
      tooltip.height / 2 +
      margin;

  const offset: [number, number] = [defaultXOffset, defaultYOffset];

  // After the tooltip has been moved vertically, prefer placing it beside `pos`
  // (right first, then left) when it is still horizontally centred on `pos`
  // and fits on that side.
  const moveBesidePos = () => {
    if (offset[0] !== defaultXOffset) return;
    if (tooltipRight + tooltip.width / 2 < container.right)
      offset[0] = xMargin;
    else if (tooltipLeft - tooltip.width / 2 > container.left)
      offset[0] = -tooltip.width - xMargin;
  };

  if (tooltip.width > container.width) {
    // wider than the container: centre it horizontally within the container
    offset[0] =
      -pos[0] +
      container.left +
      container.width / 2 -
      tooltip.width / 2 +
      xMargin;
  } else if (tooltipRight > container.right) {
    // overflows the right edge: place it left of pos
    offset[0] = -tooltip.width - xMargin;
  } else if (tooltipLeft < container.left) {
    // overflows the left edge: place it right of pos
    offset[0] = xMargin;
  }

  const topOverflow = container.top - tooltipTop,
    bottomOverflow = tooltipBottom - container.bottom;

  if (tooltip.height > container.height) {
    // taller than the container: centre it vertically within the container
    offset[1] = centeredYOffset;
    moveBesidePos();
  } else if (topOverflow > 0) {
    if (topOverflow < tooltip.height / 2 - margin)
      // fits when vertically centred on pos
      offset[1] = -tooltip.height / 2;
    else if (pos[1] + tooltip.height + 2 * margin < container.bottom)
      // fits below pos
      offset[1] = margin;
    else offset[1] = centeredYOffset;
    moveBesidePos();
  } else if (bottomOverflow > 0) {
    // overflows the container bottom: lift it back inside
    offset[1] -= bottomOverflow + margin;
  }

  return offset;
};

/**
 * Hover tooltip for a stacked chart.
 *
 * The component listens for pointer events on the chart svg (from `useChart`)
 * and hit-tests the pointer against each stacked bar. While the pointer is over
 * the svg, it renders a `position: fixed` tooltip into the chart's `ref` element
 * through a portal. The hovered state is published through `TooltipContext`,
 * so `children` and `content` can read it with `useTidyTooltip`.
 *
 * @param content - renders the tooltip body for the hovered bar. When omitted,
 *   a debug view (coordinates, index, datum fields) is shown.
 * @param children - rendered inside the tooltip box whenever it is open
 */
export const TidyTooltip = <D extends Datum>({
  content,
  children,
}: Props<D>) => {
  const tooltipRef = useRef<HTMLDivElement | null>(null),
    [open, setOpen] = useState(false),
    [pos, setPos] = useState<[number, number]>([0, 0]),
    [cursor, setCursor] = useState<[number, number]>([0, 0]),
    [offset, setOffset] = useState<[number, number]>([0, 0]),
    [datum, setDatum] = useState<D>(),
    [index, setIndex] = useState(-1),
    { ref, svg, data, x, y, stack } = useChart<D>();

  const bandwidth = isScaleBand(x) ? x.bandwidth() : 0;

  // One hit area per stacked bar, flattened series by series. Each entry keeps
  // its own series/point indices, so mapping a hit back into `stack` does not
  // assume that every series has exactly `x.domain().length` points.
  const hitboxes = useMemo(
    () =>
      stack?.flatMap((series, seriesIndex) =>
        series.map((point, pointIndex): Hitbox => {
          const x0 = x(point.data[0] as any) ?? 0,
            x1 = x0 + bandwidth,
            y0 = y(point[1]),
            y1 = y(point[0]);
          return {
            polygon: [
              [x0, y0],
              [x1, y0],
              [x1, y1],
              [x0, y1],
            ],
            seriesIndex,
            pointIndex,
          };
        }),
      ),
    [bandwidth, x, y, stack],
  );

  const onPointerMove = useCallback(
    (e: PointerEvent) => {
      const container = svg?.current?.getBoundingClientRect();
      if (container === undefined) return;

      // Pointer position relative to the svg, tested against the scale-space hitboxes.
      const [cx, cy] = d3.pointer(e),
        i =
          hitboxes?.findIndex(({ polygon }) =>
            d3.polygonContains(polygon, [cx, cy]),
          ) ?? -1,
        hit = hitboxes?.[i];

      setIndex(i);
      setCursor([cx, cy]);

      if (hit === undefined) {
        // Nothing is hovered, so clear the datum. Do not resolve indices here:
        // `.at(-1)` would silently return the last bar. `pos` and `offset` keep
        // their last values, so they stay consistent with each other.
        setDatum(undefined);
        return;
      }

      const series = stack?.at(hit.seriesIndex),
        point = series?.at(hit.pointIndex),
        anchor: [number, number] = [
          container.left + (x(point?.data[0] as any) ?? 0) + bandwidth / 2,
          container.top + y(point?.[1] ?? 0),
        ];

      setDatum(point?.data[1].get(series?.key ?? ''));
      setPos(anchor);

      // The tooltip is measured as it was last rendered, so its size may lag
      // one move behind when the content changes.
      const tooltip = tooltipRef.current?.getBoundingClientRect();
      if (tooltip !== undefined)
        setOffset(calculateOffset(container, tooltip, anchor, bandwidth));
    },
    [bandwidth, hitboxes, stack, svg, x, y],
  );

  useEffect(() => {
    if (!svg?.current) return;
    const s = svg.current,
      onPointerEnter = () => void setOpen(true),
      onPointerLeave = () => void setOpen(false);

    s.addEventListener('pointerenter', onPointerEnter);
    s.addEventListener('pointerleave', onPointerLeave);
    s.addEventListener('pointermove', onPointerMove);
    return () => {
      s.removeEventListener('pointerenter', onPointerEnter);
      s.removeEventListener('pointerleave', onPointerLeave);
      s.removeEventListener('pointermove', onPointerMove);
    };
  }, [svg, onPointerMove]);

  // Memoized so context consumers only re-render when the tooltip state changes.
  const value = useMemo(
    () => ({ i: index, datum, pos, data }),
    [index, datum, pos, data],
  );

  // Debug view shown when no `content` renderer is supplied.
  const renderDefaultContent = () => {
    const svgRect = svg?.current?.getBoundingClientRect();
    return (
      <div
        style={{
          width: '100%',
          height: '100%',
          color: 'black',
          padding: '0 8px',
          backgroundColor: 'white',
          border: '1px solid black',
          borderRadius: '4px',
        }}
      >
        <p>
          data point: [{Math.floor(pos[0] - (svgRect?.left ?? 0))},{' '}
          {Math.floor(pos[1] - (svgRect?.top ?? 0))}]
        </p>
        <p>cursor: [{cursor.map((c) => Math.round(c)).toString()}]</p>
        <p>index: {index}</p>
        <hr />
        {datum &&
          Object.entries(datum).map(([key, field]) => (
            <p key={key}>
              {key}: {field?.toString() ?? 'N/A'}
            </p>
          ))}
      </div>
    );
  };

  return (
    <TooltipContext.Provider value={value}>
      {open &&
        ref?.current &&
        createPortal(
          <div
            ref={tooltipRef}
            style={{
              pointerEvents: 'none',
              transition: 'top 0.3s ease-out, left 0.3s ease-out',
              position: 'fixed',
              top: pos[1] + offset[1],
              left: pos[0] + offset[0],
            }}
          >
            {index !== -1 &&
              (content !== undefined ? content(value) : renderDefaultContent())}
            {children}
          </div>,
          ref.current,
        )}
    </TooltipContext.Provider>
  );
};
