import React, {memo, useCallback, useEffect, useMemo, useRef, useState} from "react";
import Dagre from 'dagre';
import Graph, {BaseGraphProps} from "./Graph";
import {BaseEdge, Edge, EdgeProps, Node, ReactFlowInstance, useReactFlow} from "reactflow";
import {randomId, useScrollIntoView} from "@mantine/hooks";
import {DFA, DFAPath} from "../utils/regex";
import {clsx, createStyles} from "@mantine/core";

// Container styles for the DFA graph: React Flow's connection handles are hidden and given a
// -1px offset on each side, and nodes get a short background transition plus the `highlight`
// background that the path animation toggles.
const useStyles = createStyles(theme => {
    const highlightBackground = theme.colorScheme === "light"
        ? theme.colors.red[3]
        : theme.fn.darken(theme.colors.red[9], .4)

    return {
        root: {
            ".react-flow__handle": {visibility: "hidden"},
            ".react-flow__handle-bottom": {bottom: -1},
            ".react-flow__handle-top": {top: -1},
            ".react-flow__handle-left": {left: -1},
            ".react-flow__handle-right": {right: -1},
            ".react-flow__node": {
                transition: "background-color .1s",
                "&.highlight": {backgroundColor: highlightBackground},
            },
        },
    }
})

/**
 * Point where the line from the center of `intersectionNode` toward the center of `targetNode`
 * crosses `intersectionNode`'s outline, so an edge can start/end on the node border instead of
 * at a fixed handle. The outline is a w×h rectangle with rounded corners of radius 25px
 * (NOTE(review): presumably matches the node border radius — half the 50px node height; confirm).
 */
function getNodeIntersection(intersectionNode: Node, targetNode: Node): { x: number, y: number } {
    // Sizes are read as already measured by React Flow
    const w = intersectionNode.width!
    const h = intersectionNode.height!

    // Centers of both nodes
    const x1 = intersectionNode.position.x + w / 2
    const y1 = intersectionNode.position.y + h / 2
    const x2 = targetNode.position.x + targetNode.width! / 2
    const y2 = targetNode.position.y + targetNode.height! / 2

    // Map the center-to-center direction onto the rectangle boundary
    const xx1 = (x2 - x1) / w - (y2 - y1) / h
    const yy1 = (x2 - x1) / w + (y2 - y1) / h
    const norm = Math.abs(xx1) + Math.abs(yy1)
    // Coincident centers give no direction; return the center instead of NaN coordinates
    if (norm === 0) return {x: x1, y: y1}
    const a = 1 / norm
    const xx3 = a * xx1
    const yy3 = a * yy1
    const res = {
        x: w / 2 * (xx3 + yy3) + x1,
        y: h / 2 * (-xx3 + yy3) + y1,
    }

    // Account for rounded corners: when the rectangle hit falls inside a corner's r×r square,
    // pull it inward by the same amount `l` on both axes, toward the rounded corner
    const r = 25
    const dx = r - w / 2 + Math.abs(res.x - x1), dy = r - h / 2 + Math.abs(res.y - y1)
    if (dx > 0 && dy > 0) {
        const min = Math.min(dx, dy)
        const b = min + r
        const delta = r ** 2 - min ** 2 + 2 * r * min
        const l = (b - Math.sqrt(delta)) / 2
        if (res.x > x1) res.x -= l
        else res.x += l
        if (res.y > y1) res.y -= l
        else res.y += l
    }

    return res
}

/**
 * SVG path for a self-transition: a small loop hanging below (x, y).
 * @returns [path, labelX, labelY] — the label sits 33px below the loop's anchor point
 */
function selfPath(x: number, y: number): [string, number, number] {
    return [`M ${x} ${y} l -5 4 c -50 40, 60 40, 10 0 l -5 -4`, x, y + 33]
}

/**
 * Quadratic Bézier from (sx, sy) to (tx, ty), bowed sideways so that A→B and B→A edges
 * between the same pair of nodes curve apart instead of overlapping.
 * @returns [path, labelX, labelY] — the label is placed at the curve's midpoint (t = 0.5)
 */
function bezierPath(sx: number, sy: number, tx: number, ty: number): [string, number, number] {
    const dx = tx - sx, dy = ty - sy
    // Control point: chord midpoint offset by a quarter of the chord along the perpendicular
    // (dy, -dx). Reversing the edge negates the offset, so opposite edges bow to opposite sides.
    const px = (tx + sx) / 2 + dy / 4, py = (ty + sy) / 2 - dx / 4
    return [`M ${sx} ${sy} Q ${px} ${py} ${tx} ${ty}`,
        // B(0.5) = P0/4 + P1/2 + P2/4
        sx / 4 + px / 2 + tx / 4,
        sy / 4 + py / 2 + ty / 4,
    ]
}

/**
 * Edge whose endpoints float on the node borders (ignoring handles). Between two distinct
 * nodes it draws a curved line; for a self-transition it draws a loop under the node.
 * The computed label position is stored on `props.data.labelPos` so the graph's fit-to-bounds
 * logic can include edge labels.
 */
const FloatingEdge = memo<EdgeProps>(props => {
    const flow = useReactFlow()
    const sourceNode = flow.getNode(props.source)!
    const targetNode = flow.getNode(props.target)!

    let route: [string, number, number]
    if (props.source === props.target) {
        route = selfPath(
            sourceNode.position.x + sourceNode.width! / 2,
            sourceNode.position.y + sourceNode.height! + 1,
        )
    } else {
        const start = getNodeIntersection(sourceNode, targetNode)
        const end = getNodeIntersection(targetNode, sourceNode)
        route = bezierPath(start.x, start.y, end.x, end.y)
    }
    const [path, labelX, labelY] = route

    // Deliberate side channel: the edge's data object is shared with the flow's edge list,
    // which lets fitView read label positions back. Skip when no data object was provided.
    if (props.data) props.data.labelPos = {x: labelX, y: labelY}

    return (
        <BaseEdge
            path={path}
            labelX={labelX}
            labelY={labelY}
            label={props.label}
            labelStyle={props.labelStyle}
            labelShowBg={props.labelShowBg}
            labelBgStyle={props.labelBgStyle}
            labelBgPadding={props.labelBgPadding}
            labelBgBorderRadius={props.labelBgBorderRadius}
            style={props.style}
            markerEnd={props.markerEnd}
            markerStart={props.markerStart}
            interactionWidth={props.interactionWidth}
        />
    )
})

// Edge renderers keyed by the `type` string set on each flow edge ("floating").
// Declared once at module scope so the same object is passed on every render.
const edgeTypes = {
    floating: FloatingEdge,
}

/** Props of the DFA graph component exported by default from this module. */
export interface DFAGraphProps extends BaseGraphProps {
    /** Automaton to lay out and render; when omitted the graph has no nodes or edges. */
    dfa?: DFA,
    /** Whether the user may drag nodes; forwarded to `Graph`. */
    nodesDraggable?: boolean
    /** Path to animate: each entry's `label` names the node to highlight, in order. */
    animation?: DFAPath | null
    /**
     * Truthy value will start animation, changing to a different truthy value will restart animation.
     * Only truthiness and identity are inspected, so any value type is accepted.
     */
    animate?: unknown
}

/**
 * Renders a DFA as a React Flow graph laid out left-to-right by Dagre. When `animation` is a
 * non-empty path and `animate` is truthy, the graph is scrolled into view and each node on the
 * path is highlighted in turn (adding the `highlight` class) for about half a second.
 */
export default memo<DFAGraphProps>((
    {
        dfa, nodesDraggable,
        // NOTE(review): `style` is removed from `rest` but never used, so it is not forwarded to `Graph` — confirm intended
        className, animation, animate, style, ...rest
    }
) => {
    // Lay out the DFA with Dagre and convert it into React Flow nodes & edges
    const [flowNodes, flowEdges] = useMemo(() => {
        const graph = new Dagre.graphlib.Graph()
        graph.setDefaultEdgeLabel(() => ({}))
        graph.setGraph({rankdir: "LR", edgesep: 50, ranksep: 80})
        dfa?.forEach(node => {
            // Width grows with the label length; height is fixed at 50px
            graph.setNode(node.label, {width: node.label.length * 8.5 + 41.5, height: 50})
            node.transitions.forEach(({to}) => graph.setEdge(node.label, to))
        })
        Dagre.layout(graph)

        const flowNodes: Node[] = []
        const flowEdges: Edge[] = []
        let edgeID = 1
        dfa?.forEach(node => {
            const pos = graph.node(node.label)
            flowNodes.push({
                id: node.label,
                style: {
                    width: pos.width,
                    height: 50,
                },
                // Dagre reports node centers; React Flow positions are top-left corners
                position: {x: pos.x - pos.width / 2, y: pos.y - pos.height / 2},
                className: node.type,
                data: {
                    label: node.label,
                },
            })
            node.transitions.forEach(tr => flowEdges.push({
                id: `edge-${edgeID++}`,
                type: "floating",
                source: node.label,
                target: tr.to,
                label: tr.char,
                // The floating edge writes `labelPos` into this object when it renders
                data: {},
            }))
        })

        return [flowNodes, flowEdges]
    }, [dfa])

    // Fit the viewport to the bounds of all nodes plus the recorded edge label positions
    const fitView = useCallback((flow: ReactFlowInstance) => {
        const nodes = flow.getNodes(), edges = flow.getEdges()
        // Calculate graph bounds
        const bounds = {x1: 0, y1: 0, x2: 0, y2: 0}
        if (nodes.length > 0) {
            const node = nodes[0]
            bounds.x1 = node.position.x
            bounds.y1 = node.position.y
            bounds.x2 = node.position.x + node.width!
            bounds.y2 = node.position.y + node.height!
        }

        nodes.forEach(node => {
            let x1 = node.position.x, y1 = node.position.y, x2 = x1 + node.width!, y2 = y1 + node.height!
            if (node.className?.includes("input-node")) {
                // Account for the '>' labeling the input node
                x1 -= 20
            }
            if (x1 < bounds.x1) bounds.x1 = x1
            if (y1 < bounds.y1) bounds.y1 = y1
            if (x2 > bounds.x2) bounds.x2 = x2
            if (y2 > bounds.y2) bounds.y2 = y2
        })

        edges.forEach(edge => {
            // Label positions are only present once the edge has rendered (and has a data object)
            const labelPos: { x: number, y: number } | undefined = edge.data?.labelPos
            if (!labelPos) return
            const {x: labelX, y: labelY} = labelPos

            if (labelX < bounds.x1) bounds.x1 = labelX
            if (labelY < bounds.y1) bounds.y1 = labelY
            if (labelX > bounds.x2) bounds.x2 = labelX
            if (labelY > bounds.y2) bounds.y2 = labelY
        })

        flow.fitBounds({x: bounds.x1, y: bounds.y1, width: bounds.x2 - bounds.x1, height: bounds.y2 - bounds.y1})
    }, [])

    // Handle animation
    const [reactFlow, setReactFlow] = useState<ReactFlowInstance | null>(null)
    // ID of the animation currently allowed to run; any other running loop treats itself as cancelled
    const animationRef = useRef<string | null>(null)
    const {targetRef, scrollIntoView} = useScrollIntoView<HTMLDivElement>({duration: 500})

    useEffect(() => {
        // Cancel running animation
        animationRef.current = null
        if (animation == null || animation.length === 0 || reactFlow == null || !animate) return

        // Start new animation
        const animationID = randomId()
        animationRef.current = animationID

        function sleep(millis: number) {
            return new Promise(resolve => setTimeout(resolve, millis))
        }

        ;(async function () {
            scrollIntoView({alignment: "center"})
            await sleep(600)

            type NodeUpdate = { id: string, className?: string } | null
            // prev: node highlighted in the previous step (with its original className), to be restored
            let prev: NodeUpdate = null, current: NodeUpdate = null

            if (animationID !== animationRef.current) return
            for (const node of animation) {
                // 'as any' required because of TS issue #11498
                const shouldUseDelay = node.label === (prev as any)?.id

                // eslint-disable-next-line no-loop-func
                reactFlow.setNodes(nodes => nodes.map(flowNode => {
                    // Remove highlight from old node
                    if (flowNode.id === prev?.id) {
                        const newNode = {...flowNode, className: prev.className}
                        prev = null
                        return newNode
                    } else
                    // Stop if animation got cancelled / new animation started
                    if (shouldUseDelay || node.label !== flowNode.id || animationID !== animationRef.current) return flowNode
                    // Highlight new node
                    current = {id: flowNode.id, className: flowNode.className}
                    return {...flowNode, className: clsx(flowNode.className, "highlight")}
                }))
                // Stop only if no node was left highlighted
                if (current == null && animationID !== animationRef.current) break
                // If the same node was highlighted earlier, wait a bit before highlighting it again (blinking effect)
                if (shouldUseDelay) {
                    await sleep(150)
                    // eslint-disable-next-line no-loop-func
                    reactFlow.setNodes(nodes => nodes.map(flowNode => {
                        if (node.label !== flowNode.id || animationID !== animationRef.current) return flowNode
                        current = {id: flowNode.id, className: flowNode.className}
                        return {...flowNode, className: clsx(flowNode.className, "highlight")}
                    }))
                }

                prev = current
                current = null
                // Sleep .5s
                await sleep(500)
            }
            // Finally, remove highlight from last node
            if (prev != null) reactFlow.setNodes(reactFlow.getNodes().map(
                flowNode => flowNode.id === prev?.id ? {...flowNode, className: prev.className} : flowNode
            ))
        })()

        // Cancel this animation when the component unmounts (the async loop above checks the ref
        // and stops after removing its highlight); on re-runs the effect clears the ref anyway.
        return () => {
            animationRef.current = null
        }
    }, [dfa, animation, animate, reactFlow, scrollIntoView])

    const {classes} = useStyles()
    return <Graph nodes={flowNodes} edges={flowEdges} edgeTypes={edgeTypes} fitView={fitView} nodesDraggable={nodesDraggable}
                  setReactFlow={setReactFlow} ref={targetRef} className={clsx(className, classes.root)} {...rest} />
})