import {
    BaseEdge,
    Edge as FlowEdge,
    getSmoothStepPath,
    Handle,
    Node as FlowNode,
    NodeProps,
    Position,
    ReactFlowInstance,
    SmoothStepEdgeProps
} from "reactflow";
import React, {memo, useCallback, useMemo} from "react";
import Graph, {BaseGraphProps} from "./Graph";
import {NFABranch, NFANode} from "../utils/regex";

/** Data carried by nodes rendered with `CustomNode`. */
interface CustomNodeData {
    /** Text shown inside the node. */
    label: string;
    /** Side handles: "input" omits the left target handle, "output" omits the right source handle. */
    type: "input" | "output" | "default";
    /** Bottom handles to render: "bottom-in" (target), "bottom-out" (source), or both. */
    bottomType: "input" | "output" | "both";
}

/**
 * Node renderer for NFA states that take part in jump transitions.
 *
 * Left/right handles depend on `data.type`; the bottom handles ("bottom-in" target,
 * "bottom-out" source) depend on `data.bottomType`. When both bottom handles are shown
 * they are moved apart horizontally (70% / 30%) so they do not overlap.
 */
const CustomNode = memo<NodeProps<CustomNodeData>>(({data, isConnectable}) => {
    const {label, type, bottomType} = data
    const showBoth = bottomType === "both"
    const showBottomIn = showBoth || bottomType === "input"
    const showBottomOut = showBoth || bottomType === "output"

    return (
        <>
            {type !== "input" && (
                <Handle type="target" position={Position.Left} isConnectable={isConnectable}/>
            )}

            <div>{label}</div>
            {type !== "output" && (
                <Handle type="source" position={Position.Right} isConnectable={isConnectable}/>
            )}

            {showBottomIn && (
                <Handle
                    id="bottom-in"
                    type="target"
                    position={Position.Bottom}
                    isConnectable={isConnectable}
                    style={showBoth ? {left: "70%", bottom: -2} : undefined}
                />
            )}
            {showBottomOut && (
                <Handle
                    id="bottom-out"
                    type="source"
                    position={Position.Bottom}
                    isConnectable={isConnectable}
                    style={showBoth ? {left: "30%", bottom: -2} : undefined}
                />
            )}
        </>
    )
})

/** Data carried by edges rendered with `CustomEdge`. */
interface CustomEdgeData {
    /** Passed to the smooth-step path helper as its `offset`. */
    spacing: number;
}

/**
 * Edge renderer for jump transitions.
 *
 * Draws a smooth-step path with sharp corners (`borderRadius: 0`) and uses
 * `data.spacing` as the path offset. The label sits halfway between the source and
 * target x coordinates, at the y coordinate reported by the path helper.
 */
const CustomEdge = memo<SmoothStepEdgeProps<CustomEdgeData>>(props => {
    const {sourceX, sourceY, targetX, data} = props

    // Only the path and the label's y coordinate are used from the helper's result
    const [path, , labelY] = getSmoothStepPath({
        ...props,
        // NOTE(review): 1px nudge, presumably to line up with the bottom handle — confirm
        sourceY: sourceY - 1,
        borderRadius: 0,
        offset: data?.spacing,
    })
    const labelX = (sourceX + targetX) / 2

    const {
        label, labelStyle, labelShowBg, labelBgStyle, labelBgPadding, labelBgBorderRadius,
        style, markerEnd, markerStart, interactionWidth,
    } = props

    return (
        <BaseEdge
            path={path}
            labelX={labelX}
            labelY={labelY}
            label={label}
            labelStyle={labelStyle}
            labelShowBg={labelShowBg}
            labelBgStyle={labelBgStyle}
            labelBgPadding={labelBgPadding}
            labelBgBorderRadius={labelBgBorderRadius}
            style={style}
            markerEnd={markerEnd}
            markerStart={markerStart}
            interactionWidth={interactionWidth}
        />
    )
})

// ReactFlow renderer registries: nodes and edges whose `type` is "custom" use the components above.
// Declared at module scope so the same objects are passed to ReactFlow on every render.
const nodeTypes = {
    custom: CustomNode,
}
const edgeTypes = {
    custom: CustomEdge,
}

/** Props of the NFA graph; everything besides `nfa` is forwarded to `Graph`. */
export interface NFAGraphProps extends BaseGraphProps {
    /** Top-level branch of the NFA to draw; nothing is drawn when missing or empty. */
    nfa?: NFABranch;
}

/** Horizontal distance between the left edges of consecutive nodes in a branch. */
const X_SPACING = 100
/** Vertical gap left between sibling branches of a split node. */
const Y_SPACING = 80
/** Extra clearance added below the nodes for each level of nested jump transitions. */
const TRANSITION_SPACING = 20

/** ReactFlow elements produced by `layoutNFA`. */
type NFALayout = {
    nodes: FlowNode<CustomNodeData>[],
    edges: FlowEdge<CustomEdgeData>[],
}

/**
 * Builds a 50x50 ReactFlow node for NFA state `id` with its top-left corner at (`x`, `y`).
 *
 * @param type   side-handle configuration; also used as the node's ReactFlow type when `bottom` is unset
 * @param bottom bottom handles to add; when set, the node is rendered by `CustomNode`
 */
function createNode(
    id: number, x: number, y: number, type: "input" | "output" | "default", bottom?: "input" | "output" | "both"
): FlowNode<CustomNodeData> {
    return {
        id: `${id}`,
        data: {
            label: `${id}`,
            type,
            bottomType: bottom ?? "input",
        },
        style: {
            width: 50,
            height: 50,
        },
        // Without bottom handles, ReactFlow's built-in input/output/default node types suffice
        type: bottom ? "custom" : type,
        position: {x, y},
        sourcePosition: Position.Right,
        targetPosition: Position.Left,
        className: `${type}-node`,
    }
}

/**
 * Lays out a (non-empty) NFA for ReactFlow.
 *
 * Branches are drawn left to right starting at x = 0. The branches of a split node are stacked
 * vertically, centered on the split node, and horizontally centered against the widest of them.
 * Jump transitions (`node.transition`) run below the nodes through their bottom handles.
 *
 * Side effect: annotates the NFA in place with layout data — `width`, `height` and `shiftUp` on
 * every branch, `height` on split nodes and `transitionSpacing` on jump sources.
 */
function layoutNFA(nfa: NFABranch): NFALayout {
    const nodes: FlowNode<CustomNodeData>[] = []
    const edges: FlowEdge<CustomEdgeData>[] = []
    let transitionID = 1
    // Ids of nodes targeted by a jump transition (Node.transition.to); they need a bottom input handle
    const needBottomInput = new Set<number>()

    /** Adds an edge; a truthy `spacing` makes it a jump transition drawn below the nodes by `CustomEdge`. */
    function pushTransition(label: string, from: number, to: number, spacing?: number) {
        edges.push({
            id: `transition-${transitionID++}`,
            type: spacing ? "custom" : "default",
            source: `${from}`,
            sourceHandle: spacing ? "bottom-out" : null,
            target: `${to}`,
            targetHandle: spacing ? "bottom-in" : null,
            label,
            labelStyle: {
                // Input symbols are bold, epsilon transitions are not
                fontWeight: label !== 'ε' ? "bold" : undefined,
            },
            data: spacing != null ? {spacing} : undefined,
        })
    }

    /**
     * First pass: computes branch widths/heights, split-node heights and jump-transition spacing
     * for `branch` and, recursively, all nested branches.
     *
     * @returns the branch height (also stored on `branch.height`)
     */
    function preprocessNodes(branch: NFABranch): number {
        // Jump transitions starting in this branch
        const jumps: { id: number, to: number }[] = []
        // Left-to-right index and node object of every node in this branch, by node id
        const nodesMap = new Map<number, { index: number, node: NFANode }>()
        // Width in node slots; a split contributes its widest branch
        let width = 0
        // Height of the tallest split node in this branch
        let maxSplitHeight = 0

        for (let i = 0; i < branch.length; i++) {
            const node = branch[i]
            nodesMap.set(node.id, {index: i, node})
            if (node.transition) {
                jumps.push({id: node.id, to: node.transition.to})
                needBottomInput.add(node.transition.to)
            }

            ++width
            if (!node.split) continue

            // A split with n branches is as tall as their heights plus the n-1 gaps between them,
            // and as wide as its widest branch
            let splitHeight = -Y_SPACING, widestBranch = 0
            for (const sub of node.split) {
                splitHeight += preprocessNodes(sub) + Y_SPACING
                widestBranch = Math.max(widestBranch, sub.width!)
            }
            node.height = splitHeight
            width += widestBranch
            maxSplitHeight = Math.max(maxSplitHeight, splitHeight)
        }
        branch.width = width

        // Each jump spans node indices [a, b] and must initially clear the tallest split it passes over
        const spans = jumps.map(({id, to}) => {
            const from = nodesMap.get(id)!
            let a = from.index, b = nodesMap.get(to)!.index
            if (a > b) [a, b] = [b, a]
            let tallest = 0
            for (let i = a; i < b; ++i) {
                const h = branch[i].height
                if (h != null && h > tallest) tallest = h
            }
            return {a, b, initialSpacing: tallest / 2, source: from.node, processed: false}
        }).sort((t1, t2) =>
            // By start index; on equal starts the wider (enclosing) jump must come first, otherwise
            // processTransition would treat it as nested inside the narrower one.
            // (A `??` here would never fall through: a number difference is never nullish.)
            t1.a - t2.a || t2.b - t1.b)

        // Depth below the nodes reached by the lowest jump transition
        let maxSpacing = 0

        /** Computes (once) and stores the spacing of jump `index`, which must clear every jump nested in it. */
        function processTransition(index: number): number {
            const current = spans[index]
            if (current.processed) return current.source.transitionSpacing!
            current.processed = true
            /* Jumps starting before `current` ends are nested inside it. Only directly nested ones
               are visited here (each skips the jumps nested inside itself), since the recursive
               call already accounts for those:
                   A --> B ---> C --> D
                   |     '------'     |     inner jump (B..C), processed recursively
                   '------------------'     current jump (A..D), drawn one level below */
            let spacing = current.initialSpacing
            for (let i = index + 1; i < spans.length && spans[i].a < current.b;) {
                spacing = Math.max(spacing, processTransition(i))
                const innerEnd = spans[i].b
                for (++i; i < spans.length && spans[i].a < innerEnd;) ++i
            }
            spacing += TRANSITION_SPACING
            current.source.transitionSpacing = spacing
            if (spacing > maxSpacing) maxSpacing = spacing
            return spacing
        }

        for (let i = 0; i < spans.length; i++) processTransition(i)

        // How far the jump transitions extend below the lowest split in the branch
        const spacingDelta = maxSpacing - maxSplitHeight / 2
        if (spacingDelta > 0) {
            branch.height = maxSplitHeight + spacingDelta
            branch.shiftUp = spacingDelta / 2
        } else {
            branch.height = maxSplitHeight
            branch.shiftUp = 0
        }
        return branch.height
    }

    /**
     * Second pass: emits nodes and edges for `branch` (and nested branches) at their coordinates.
     *
     * @param initial true only for the top-level branch, whose first/last nodes become input/output nodes
     * @returns the x coordinate just past the end of the branch
     */
    function processNodes(branch: NFABranch, startX: number, startY: number, initial?: boolean): number {
        let x = startX
        // Shift the branch up so jump transitions extending below its lowest node keep it centered
        const y = startY - branch.shiftUp!

        const first = branch[0], last = branch[branch.length - 1]
        let prev: NFANode | null = null
        for (const node of branch) {
            const bottomOut = node.transition != null, bottomIn = needBottomInput.has(node.id)
            const kind: "input" | "output" | "default" =
                !initial ? "default" : node === first ? "input" : node === last ? "output" : "default"
            nodes.push(createNode(node.id, x, y, kind,
                bottomOut && bottomIn ? "both" : bottomIn ? "input" : bottomOut ? "output" : undefined))
            x += X_SPACING

            // Jump transition, drawn below the nodes
            if (node.transition) {
                pushTransition(node.transition.char ?? 'ε', node.id, node.transition.to, node.transitionSpacing)
            }
            // Edge(s) into this node: from the previous node, or from the last node of each branch of a previous split
            if (prev) {
                if (!prev.split) pushTransition(node.char ?? 'ε', prev.id, node.id)
                else for (const sub of prev.split) pushTransition('ε', sub[sub.length - 1].id, node.id)
            }
            prev = node
            if (!node.split) continue

            // Stack the split's branches top to bottom; continue after the one ending furthest right
            const splitX = x
            const widest = Math.max(...node.split.map(sub => sub.width!))
            let yTop = y - node.height! / 2
            let endX = 0
            for (const sub of node.split) {
                const delta = (widest - sub.width!) * X_SPACING / 2
                endX = Math.max(endX, processNodes(sub, splitX + delta, yTop + sub.height! / 2))
                yTop += sub.height! + Y_SPACING
            }
            x = endX
            // Epsilon edges from the split node into each of its branches
            for (const sub of node.split) pushTransition('ε', node.id, sub[0].id)
        }
        return x
    }

    if (nfa.length > 0) {
        preprocessNodes(nfa)
        processNodes(nfa, 0, 0, true)
    }
    return {nodes, edges}
}

/**
 * Renders an NFA as a left-to-right ReactFlow graph.
 *
 * The layout is recomputed only when a different `nfa` object is passed; all other props are
 * forwarded to `Graph`.
 */
export default memo<NFAGraphProps>(function NFAGraph({nfa, ...rest}) {
    const {nodes, edges} = useMemo<NFALayout>(
        () => nfa != null && nfa.length > 0 ? layoutNFA(nfa) : {nodes: [], edges: []},
        [nfa],
    )

    const fitView = useCallback((flow: ReactFlowInstance) => {
        if (nfa == null || nfa.length === 0) return
        // Bounds of the whole drawing, from the width/height layoutNFA stored on `nfa`.
        // A node's position is its top-left corner; nodes are 50x50.
        const bWidth = (nfa.width! - 1) * X_SPACING + 50, bHeight = nfa.height! + 55 // extra 5 pixels for transition label
        flow.fitBounds({
            x: -20,
            y: -bHeight / 2 + 28,
            width: bWidth + 20,
            height: bHeight,
        })
    }, [nfa])

    return <Graph fitView={fitView} nodes={nodes} edges={edges} nodeTypes={nodeTypes} edgeTypes={edgeTypes} {...rest}/>
})