import * as styles from './TreeDiagram.module.scss'

import React, { useMemo } from 'react'
import * as d3 from 'd3'
import { sortBy } from 'lodash'
import classNames from 'classnames'

/**
 * A node of the tree rendered by `TreeDiagram`.
 */
export interface TreeNode {
  /** Node name. It is not read by the rendering code in this file. */
  name: string
  /**
   * Tri-state highlight flag that `TreeDiagram` uses to style the node:
   * - `true`: larger dot, and the link to its parent gets the highlight style;
   * - `false`: the link to its parent gets the de-highlight style, and a cross
   *   is drawn on the node when its parent is highlighted;
   * - omitted: small dot and a plain link.
   */
  highlight?: boolean
  /** Child nodes; omit for a leaf. */
  children?: TreeNode[]
}

/** Props for `TreeDiagram`. */
interface Props {
  /** Root of the tree to render. */
  data: TreeNode
  /**
   * Layout orientation (default `'up'`):
   * - `'up'`: the tree fans out upward from the bottom edge;
   * - `'right'`: depth runs left-to-right.
   */
  direction?: 'up' | 'right'
  /** Width of the SVG viewBox (default 400). Dot sizes also scale with it. */
  width?: number
  /** Height of the SVG viewBox (default 400). */
  height?: number
  /** Inset from the viewBox edges, in viewBox units (default 30). */
  padding?: number
}

/**
 * Builds a transform that lays the tree out as a fan growing upward from the
 * bottom edge of the viewBox.
 *
 * Inputs are normalized layout coordinates as produced by `d3.tree()` with its
 * default size: `x` is breadth and `y` is depth, both expected in `[0, 1]`.
 * Each output point is the 50/50 average of two mappings:
 * - polar: `x - 0.5` is an angle in radians from straight up, and
 *   `y * innerHeight` is the radius;
 * - rectangular: `x - 0.5` spreads across the inner width, and `y` rises across
 *   the inner height.
 * The result is anchored at `(width / 2, height - padding)`.
 *
 * @param width - viewBox width.
 * @param height - viewBox height.
 * @param padding - inset from the viewBox edges.
 * @returns a function mapping `(x, y)` to an `[svgX, svgY]` tuple.
 */
function semiRadialUp(
  width: number,
  height: number,
  padding: number,
): (x: number, y: number) => [number, number] {
  const innerWidth = width - 2 * padding
  const innerHeight = height - 2 * padding
  const originX = width / 2
  const originY = height - padding

  return (x, y) => {
    // Centre breadth on 0 so the fan is symmetric about the vertical axis.
    const offset = x - 0.5
    const radius = y * innerHeight

    // Polar mapping; subtracting PI/2 turns angle 0 into "straight up" (SVG y grows downward).
    const polarX = radius * Math.cos(offset - Math.PI / 2)
    const polarY = radius * Math.sin(offset - Math.PI / 2)

    // Rectangular mapping; depth moves upward, hence the negative sign.
    const rectX = offset * innerWidth
    const rectY = -y * innerHeight

    // Blend both mappings equally and anchor at the bottom centre.
    return [originX + (polarX + rectX) / 2, originY + (polarY + rectY) / 2]
  }
}

/**
 * Builds a transform that lays the tree out left-to-right.
 *
 * Depth (`y`) runs horizontally across the inner width, and breadth (`x`) runs
 * vertically across the inner height. Both are offset by `padding`. The inputs
 * are expected to be normalized to `[0, 1]`.
 *
 * @param width - viewBox width.
 * @param height - viewBox height.
 * @param padding - inset from the viewBox edges.
 * @returns a function mapping `(x, y)` to an `[svgX, svgY]` tuple.
 */
function linearRight(
  width: number,
  height: number,
  padding: number,
): (x: number, y: number) => [number, number] {
  const innerWidth = width - 2 * padding
  const innerHeight = height - 2 * padding

  // Axes are swapped on purpose: tree depth becomes horizontal position.
  return (x, y) => [y * innerWidth + padding, x * innerHeight + padding]
}

/**
 * Renders `data` as an SVG tree.
 *
 * Links are drawn first and node markers on top of them. Each node's
 * `highlight` flag selects its styling:
 * - `undefined`: small dot, plain link;
 * - `true`: larger dot, highlighted link;
 * - `false`: de-highlighted link, plus a cross when its parent is highlighted.
 */
export const TreeDiagram: React.FC<Props> = (props) => {
  const {
    data,
    direction = 'up',
    width = 400,
    height = 400,
    padding = 30,
  } = props

  const treeLayout = useMemo(() => {
    const root = d3.hierarchy(data)
    // Default tree size is [1, 1]: coordinates are normalized here and mapped
    // into viewBox space by the direction-specific transform below.
    const layout = d3.tree<TreeNode>().separation(() => 1)(root)

    const transform =
      direction === 'up'
        ? semiRadialUp(width, height, padding)
        : linearRight(width, height, padding)

    layout.each((node) => {
      const [x, y] = transform(node.x, node.y)
      node.x = x
      node.y = y
    })

    return layout
  }, [data, direction, height, padding, width])

  const components = useMemo(() => {
    // Endpoint tangents must follow the depth axis: horizontal for 'right',
    // vertical for the upward fan.
    const linkPath = d3.link(
      direction === 'right' ? d3.curveBumpX : d3.curveBumpY,
    )
    const dotRadius = width / 160
    const highlightRadius = width / 100
    // Half the arm length of the cross marker. It equals 4 at the default
    // width and scales with the viewBox like the dots do.
    const crossHalf = width / 100

    const shapes: { component: React.ReactNode; layer: number }[] = []

    treeLayout.each((node, index) => {
      const { highlight } = node.data

      if (node.parent != null) {
        shapes.push({
          component: (
            <path
              key={`link-${index}`}
              className={classNames(styles.link, {
                [styles.highlight]: highlight === true,
                [styles.deHighlight]: highlight === false,
              })}
              d={
                linkPath({
                  source: [node.parent.x, node.parent.y],
                  target: [node.x, node.y],
                }) ?? undefined
              }
            />
          ),
          layer: 1,
        })
      }

      if (highlight == null) {
        shapes.push({
          component: (
            <circle key={`node-${index}`} cx={node.x} cy={node.y} r={dotRadius} />
          ),
          layer: 2,
        })
      } else if (highlight === true) {
        shapes.push({
          component: (
            <circle
              key={`node-${index}`}
              cx={node.x}
              cy={node.y}
              r={highlightRadius}
            />
          ),
          layer: 2,
        })
      } else if (node.parent?.data.highlight) {
        // A de-highlighted child of a highlighted node is marked with a cross.
        shapes.push({
          component: (
            <g key={`cross-${index}`}>
              <line
                className={styles.cross}
                x1={node.x - crossHalf}
                y1={node.y - crossHalf}
                x2={node.x + crossHalf}
                y2={node.y + crossHalf}
                stroke="black"
              />
              <line
                className={styles.cross}
                x1={node.x + crossHalf}
                y1={node.y - crossHalf}
                x2={node.x - crossHalf}
                y2={node.y + crossHalf}
                stroke="black"
              />
            </g>
          ),
          layer: 2,
        })
      }
    })

    // sortBy is stable, so document order is kept within each layer.
    return sortBy(shapes, (shape) => shape.layer).map(
      (shape) => shape.component,
    )
  }, [treeLayout, direction, width])

  return (
    <svg
      // width={width}
      // height={height}
      viewBox={`0 0 ${width} ${height}`}
      className={styles.svg}
    >
      <g>{components}</g>
    </svg>
  )
}
