import * as styles from './ScatterPlot.module.scss'

import React, { useState, useEffect, useRef, useMemo, useCallback } from 'react'
import * as d3 from 'd3'
import { flatMap } from 'lodash'
// import Measure from 'react-measure'
import classNames from 'classnames'

import { Size, EdgeInsets } from 'utils/geometry'
import { expandRangeByFraction } from './utils/ranges'
import { normalizeSeriesData } from './normalizeSeriesData'

/** A group of points with an optional `name` and fill `color`. */
type SeriesInput<Point> = {
  name?: string
  color?: string
  values: Point[]
}

/**
 * Props for `ScatterPlot`.
 *
 * Data can be specified in a few ways: either as a flat array of points, or as
 * an array of series, each with its own values and optionally a color and a
 * name. Points are either `[x, y]` coordinate tuples or objects with `x`, `y`
 * and any other properties that a tooltip may want to display.
 */
interface Props<Datum extends { x: number; y: number }> {
  data?:
    | [number, number][]
    | Datum[]
    | SeriesInput<[number, number]>[]
    | SeriesInput<Datum>[]
  /**
   * Label for the y axis. A string is drawn centred; a 3-tuple is drawn as
   * `[start, middle, end]` with arrows around the middle part. `null` draws
   * no label and leaves a narrower margin.
   */
  verticalAxisLabel: string | [string, string, string] | null
  /** Label for the x axis; same forms as `verticalAxisLabel`. */
  horizontalAxisLabel: string | [string, string, string] | null
  /** When true, the axes request a tick (d3 `ticks(1)`) with 10px padding. */
  showIntervals?: boolean
  /** Explicit axis domains; otherwise derived from the data extent. */
  domain?: { x: [number, number]; y: [number, number] }
  /** Circle radius for each point (defaults to 7). */
  pointSize?: number
  /** Pass `false` to mark series with the `noStroke` style. */
  strokePoints?: boolean
  /** SVG path `d` strings expressed in data coordinates, drawn as overlays. */
  paths?: string[]
  /** Renders the tooltip content for the currently hovered point. */
  renderTooltip?(datum: Datum): React.ReactNode
}

/**
 * Offset, in SVG user units, used to place the axis labels: the horizontal
 * label is drawn this far above the bottom edge, and the rotated vertical
 * label's baseline sits this far from the left edge.
 */
const labelHeight = 20

/**
 * Scatter plot drawn with d3 into an SVG element.
 *
 * React owns only the outer `<svg>` and the tooltip `<div>`. Everything inside
 * the SVG (axes, axis labels, points, overlay paths) is cleared and redrawn
 * imperatively by an effect whenever one of its inputs changes.
 *
 * The container size is currently fixed at 600×600; see the note at the end
 * of the component about measuring being disabled.
 */
const ScatterPlot = <Datum extends { x: number; y: number }>(
  props: Props<Datum>,
) => {
  const {
    data,

    verticalAxisLabel,
    horizontalAxisLabel,
    showIntervals,
    domain,
    pointSize,
    strokePoints,
    paths,
    renderTooltip,
  } = props

  // The setter is intentionally not destructured while measuring is disabled.
  const [containerSize] = useState<{
    width: number
    height: number
  }>({ width: 600, height: 600 })

  // Memoized so that `contentRect`, `scale` and the drawing effect keep stable
  // identities across renders. Building a fresh rect on every render made the
  // effect re-run on every re-render (including each hover state change),
  // tearing down and rebuilding the entire SVG.
  const bounds = useMemo(
    () => new Size(containerSize.width, containerSize.height).rect,
    [containerSize.width, containerSize.height],
  )

  const dataSets = useMemo(() => {
    return normalizeSeriesData<Datum>(data ?? null)
  }, [data])

  const [hoveredItem, setHoveredItem] = useState<Datum | null>(null)

  // Listener signature is `(event, datum)`, as passed by d3 v6+ `selection.on`;
  // only the datum is needed.
  const mouseOver = useCallback((_event: MouseEvent, datum: Datum) => {
    setHoveredItem(datum)
  }, [])

  const mouseLeave = useCallback(() => {
    setHoveredItem(null)
  }, [])

  // Reserve extra room for an axis label only when one is provided.
  // NOTE(review): assumes EdgeInsets(top, right, bottom, left) — confirm.
  const margins = useMemo(
    () =>
      new EdgeInsets(
        0,
        0,
        horizontalAxisLabel != null ? 40 : 20,
        verticalAxisLabel != null ? 40 : 20,
      ),
    [horizontalAxisLabel, verticalAxisLabel],
  )

  const svgElement = useRef<SVGSVGElement | null>(null)

  const contentRect = useMemo(() => {
    return bounds.inset(margins)
  }, [bounds, margins])

  const allPoints = useMemo(() => {
    return flatMap(dataSets, (set) => set.values)
  }, [dataSets])

  // Data → SVG coordinate scales. Without an explicit `domain`, each axis
  // spans the data extent passed through `expandRangeByFraction(…, 0.15)`,
  // starting from [0, 1] when there are no points. The y range is inverted so
  // larger values are drawn higher up.
  const scale = useMemo(() => {
    const x = d3
      .scaleLinear()
      .domain(
        domain?.x ??
          expandRangeByFraction(
            [
              d3.min(allPoints, (d) => d.x) ?? 0,
              d3.max(allPoints, (d) => d.x) ?? 1,
            ],
            0.15,
          ),
      )
      .range([contentRect.x, contentRect.maxX])
      .nice()

    const y = d3
      .scaleLinear()
      .domain(
        domain?.y ??
          expandRangeByFraction(
            [
              d3.min(allPoints, (d) => d.y) ?? 0,
              d3.max(allPoints, (d) => d.y) ?? 1,
            ],
            0.15,
          ),
      )
      .range([contentRect.maxY, contentRect.y])
      .nice()

    return { x, y }
  }, [allPoints, contentRect, domain])

  useEffect(() => {
    if (svgElement.current == null) {
      return
    }

    const svg = d3.select(svgElement.current)

    // Start from an empty SVG each time; everything below is redrawn.
    svg.selectAll('*').remove()

    // Axes are drawn without ticks unless `showIntervals` is set.
    const xAxis = d3.axisBottom(scale.x).tickSize(0).ticks(0)
    const yAxis = d3.axisLeft(scale.y).tickSize(0).ticks(0)

    if (showIntervals) {
      xAxis.ticks(1).tickPadding(10)
      yAxis.ticks(1).tickPadding(10)
    }

    svg
      .append('g')
      .attr('class', styles.axis)
      .attr('transform', `translate(0, ${contentRect.maxY})`)
      .call(xAxis)

    svg
      .append('g')
      .attr('class', styles.axis)
      .attr('transform', `translate(${contentRect.x}, 0)`)
      .call(yAxis)

    if (horizontalAxisLabel) {
      const group = svg.append('g')

      if (typeof horizontalAxisLabel === 'string') {
        group
          .append('text')
          .attr('class', styles.axisLabel)
          .attr('alignment-baseline', 'hanging')
          .attr('text-anchor', 'middle')
          .attr('x', contentRect.midX)
          .attr('y', bounds.height - labelHeight)
          .text(horizontalAxisLabel)
      } else {
        // Three-part label: [start, middle, end] laid out along the x axis.
        group
          .append('text')
          .attr(
            'class',
            classNames(styles.axisLabel, styles.axisLabelSecondary),
          )
          .attr('alignment-baseline', 'hanging')
          .attr('text-anchor', 'start')
          .attr('x', contentRect.x + 30)
          .attr('y', bounds.height - labelHeight)
          .text(horizontalAxisLabel[0])

        group
          .append('text')
          .attr('class', styles.axisLabel)
          .attr('alignment-baseline', 'hanging')
          .attr('text-anchor', 'middle')
          .attr('x', contentRect.midX)
          .attr('y', bounds.height - labelHeight)
          .text(`← ${horizontalAxisLabel[1]} →`)

        group
          .append('text')
          .attr(
            'class',
            classNames(styles.axisLabel, styles.axisLabelSecondary),
          )
          .attr('alignment-baseline', 'hanging')
          .attr('text-anchor', 'end')
          .attr('x', contentRect.maxX - 30)
          .attr('y', bounds.height - labelHeight)
          .text(horizontalAxisLabel[2])
      }
    }

    if (verticalAxisLabel) {
      const group = svg.append('g')

      // Vertical labels are rotated -90°, so their `x` is the negated screen y
      // and their `y` is the screen x.
      if (typeof verticalAxisLabel === 'string') {
        group
          .append('text')
          .attr('class', styles.axisLabel)
          .attr('text-anchor', 'middle')
          .attr('x', -contentRect.midY)
          .attr('y', labelHeight)
          .attr('transform', 'rotate(-90)')
          .text(verticalAxisLabel)
      } else {
        // Three-part label: [start (bottom), middle, end (top)].
        group
          .append('text')
          .attr(
            'class',
            classNames(styles.axisLabel, styles.axisLabelSecondary),
          )
          .attr('text-anchor', 'start')
          .attr('x', -contentRect.maxY + 30)
          .attr('y', labelHeight)
          .attr('transform', 'rotate(-90)')
          .text(verticalAxisLabel[0])

        group
          .append('text')
          .attr('class', styles.axisLabel)
          .attr('text-anchor', 'middle')
          .attr('x', -contentRect.midY)
          .attr('y', labelHeight)
          .attr('transform', 'rotate(-90)')
          .text(`← ${verticalAxisLabel[1]} →`)

        group
          .append('text')
          .attr(
            'class',
            classNames(styles.axisLabel, styles.axisLabelSecondary),
          )
          .attr('text-anchor', 'end')
          .attr('x', -contentRect.y - 30)
          .attr('y', labelHeight)
          .attr('transform', 'rotate(-90)')
          .text(verticalAxisLabel[2])
      }
    }

    // One <g> per series carries the fill colour; points are circles inside it.
    const series = svg
      .selectAll(`.${styles.series}`)
      .data(dataSets)
      .enter()
      .append('g')
      .attr('class', styles.series)
      .classed(styles.noStroke, strokePoints === false)
      .style('fill', (d) => d.color || 'var(--accent-color)')

    series
      .selectAll(`.${styles.dataPoint}`)
      .data((d) => d.values)
      .enter()
      .append('circle')
      .attr('class', styles.dataPoint)
      .attr('cx', (d) => scale.x(d.x))
      .attr('cy', (d) => scale.y(d.y))
      .attr('r', pointSize ?? 7)
      .on('mouseover', mouseOver)
      .on('mouseleave', mouseLeave)

    if (paths) {
      // Maps data coordinates to SVG coordinates so `paths` can be written in
      // data space; `non-scaling-stroke` keeps the stroke width unaffected.
      const pathsTransform = `translate(${scale.x(0)}, ${scale.y(0)}) scale(${
        (scale.x.range()[1] - scale.x.range()[0]) /
        (scale.x.domain()[1] - scale.x.domain()[0])
      }, ${
        (scale.y.range()[1] - scale.y.range()[0]) /
        (scale.y.domain()[1] - scale.y.domain()[0])
      })`

      svg
        .selectAll(`.${styles.path}`)
        .data(paths)
        .enter()
        .append('path')
        .attr('class', styles.path)
        .attr('transform', pathsTransform)
        .attr('vector-effect', 'non-scaling-stroke')
        .attr('d', (d) => d)
    }
  }, [
    bounds.height,
    contentRect,
    dataSets,
    horizontalAxisLabel,
    mouseLeave,
    mouseOver,
    paths,
    pointSize,
    scale,
    showIntervals,
    strokePoints,
    verticalAxisLabel,
  ])

  return (
    <div className={styles.container} /*ref={measureRef} */>
      <svg
        ref={svgElement}
        viewBox={`0 0 ${bounds.width} ${bounds.height}`}
        xmlns="http://www.w3.org/2000/svg"
      />

      {/* NOTE(review): positioned from SVG user-space coordinates, which only
          equal CSS pixels while the SVG renders at its viewBox size. */}
      {renderTooltip && hoveredItem && (
        <div
          className={styles.tooltip}
          style={{
            top: `${scale.y(hoveredItem.y) + 10}px`,
            left: `${scale.x(hoveredItem.x) + 10}px`,
          }}
        >
          {renderTooltip(hoveredItem)}
        </div>
      )}
    </div>
  )

  // Temporarily disable scaling since the measuring is causing infinite render
  // loops and crashing the article.

  // return (
  //   <Measure
  //     bounds
  //     onResize={(contentRect) => {
  //       if (contentRect.bounds) {
  //         setContainerSize({
  //           width: contentRect.bounds.width,
  //           height: contentRect.bounds.height
  //         })
  //       }
  //     }}
  //   >
  //     {({ measureRef }) => (
  //       // .....
  //     )}
  //   </Measure>
  // )
}

export default ScatterPlot
