import React, {
  useEffect,
  forwardRef,
  useImperativeHandle,
  useState,
  useRef,
} from "react";
import * as RJD from "react-js-diagrams";
import PropTypes from "prop-types";
import { DTCustomNodeModel } from "components/decisionTrees/drawflow/CustomNode/DTCustomNodeModel";
import { DTCustomLinkModel } from "components/decisionTrees/drawflow/CustomNode/DTCustomLinkModel";
import { DecisionTreeContext } from "context/DecisionTreeContext";
import usePrevious from "utility/hooks/usePrevious";

const DTDiagramDraw = forwardRef(
  ({ engine, decisionTree, handleCloneChange, isRevision }, ref) => {
    const [zoom, setZoom] = useState(100);
    const canvas = useRef();
    const [offset, setOffset] = useState({});
    const [model, setModel] = useState(new RJD.DiagramModel());
    const [decisionTreeClone, setDecisionTreeClone] = useState(null);
    const [forceUpdate, setForceUpdate] = useState(false);
    const prevZoom = usePrevious(zoom);

    const topMargin = 50;
    const leftMargin = 20;

    const {
      setIsCanvasDragging,
      isCanvasDragging,
      setIsLoading,
      setIsUpdatedStep,
      setIsEdited,
      isEdited,
      isFirstLoad,
      setIsFirstLoad,
    } = React.useContext(DecisionTreeContext);

    model.setZoomLevel(zoom);

    useImperativeHandle(
      ref,
      () => ({
        zoomFit() {
          handleZoomToFit();
        },
        zoom() {
          handleZoom();
        },
        zoomOut() {
          handleZoomOut();
        },
      }),
      []
    );

    const triggerMouseEvent = (eventName, element) => {
      const event = document.createEvent("MouseEvents");
      event.initEvent(eventName, true, true);
      element.dispatchEvent(event);
    };

    useEffect(() => {
      if (Object.keys(offset).length > 0) {
        model.setOffset(offset.offsetX, offset.offsetY);
      }

      if (canvas.current && canvas.current.refs && canvas.current.refs.canvas) {
        canvas.current.refs.canvas.addEventListener(
          "wheel",
          function (e) {
            if (e.ctrlKey) {
              const diagramModel = engine.diagramModel;
              e.preventDefault();
              e.stopPropagation();
              if (
                diagramModel.getZoomLevel() - e.deltaY / 300 > 10 &&
                diagramModel.getZoomLevel() - e.deltaY < 300
              ) {
                diagramModel.setZoomLevel(
                  diagramModel.getZoomLevel() - e.deltaY / 300
                );
                setZoom(diagramModel.getZoomLevel() - e.deltaY / 300);
                engine.enableRepaintEntities([]);
                engine.forceUpdate();
              }
            }
          },
          {
            passive: false,
          }
        );
      }
      engine.setDiagramModel(model);
      engine.forceUpdate();
    }, [engine, model, offset]);

    useEffect(() => {
      if (decisionTree) {
        setDecisionTreeClone(decisionTree);
        processDecisionTreeData(decisionTree);
      }
    }, [decisionTree]);

    useEffect(() => {
      processDecisionTreeData(decisionTreeClone);
      engine.enableRepaintEntities([]);
      engine.forceUpdate();
      setTimeout(() => {
        engine.forceUpdate();
      });
      if (canvas.current && canvas.current.refs && canvas.current.refs.canvas) {
        triggerMouseEvent("mousedown", canvas.current.refs.canvas);
        triggerMouseEvent("mouseup", canvas.current.refs.canvas);
      }
    }, [decisionTreeClone]);

    const handleZoomToFit = () => {
      const xFactor = engine.canvas.clientWidth / engine.canvas.scrollWidth;
      const yFactor = engine.canvas.clientHeight / engine.canvas.scrollHeight;
      const zoomFactor = xFactor < yFactor ? xFactor : yFactor;
      const offsetX =
        (engine.canvas.clientWidth - engine.canvas.scrollWidth * zoomFactor) /
        2;
      const offsetY =
        (engine.canvas.clientHeight - engine.canvas.scrollHeight * zoomFactor) /
        2;
      setZoom(zoom * zoomFactor);
      model.setZoomLevel(zoom * zoomFactor);
      model.setOffset(offsetX, Math.max(offsetY, topMargin));
      engine.enableRepaintEntities([]);
      engine.forceUpdate();
    };

    const handleZoom = () => {
      const zoomSize = engine.diagramModel.getZoomLevel() + 10;
      engine.diagramModel.setZoomLevel(zoomSize);
      engine.enableRepaintEntities([]);
      setZoom(zoomSize);
      setTimeout(() => {
        engine.forceUpdate();
      });
    };

    const handleZoomOut = () => {
      if (engine.diagramModel.getZoomLevel() > 20) {
        const zoomSize = engine.diagramModel.getZoomLevel() - 10;
        engine.diagramModel.setZoomLevel(zoomSize);
        engine.enableRepaintEntities([]);
        setZoom(zoomSize);
        setTimeout(() => {
          engine.forceUpdate();
        });
      }
    };

    const processDecisionTreeData = (decisionTreeForProcess) => {
      if (!decisionTreeForProcess || !decisionTreeForProcess.nodes) return;

      const localModel = new RJD.DiagramModel();
      const MIN_HORIZONTAL_SPACING = 100;
      const VERTICAL_SPACING = 100;
      const childMap = {};
      const nodeModelMap = {};

      // Track max height per level
      const levelHeights = new Map();

      const updateLevelHeight = (level, y) => {
        if (!levelHeights.has(level)) {
          levelHeights.set(level, y);
        } else {
          levelHeights.set(level, Math.max(levelHeights.get(level), y));
        }
      };

      decisionTreeForProcess.nodes.forEach((node) => {
        nodeModelMap[node.decision_tree_node_id] = new DTCustomNodeModel({
          ...node,
          data: node,
        });
        if (!node.parent_decision_tree_node_id) {
          childMap[node.decision_tree_node_id] = [];
        } else {
          if (!childMap[node.parent_decision_tree_node_id]) {
            childMap[node.parent_decision_tree_node_id] = [];
          }
          childMap[node.parent_decision_tree_node_id].push(
            node.decision_tree_node_id
          );
        }
      });

      function getSubtreeBounds(nodeId, nodeModelMap, childMap) {
        const node = nodeModelMap[nodeId];
        const nodeElement = document.querySelector(`[data-nodeid="${nodeId}"]`);
        const nodeWidth = nodeElement ? nodeElement.offsetWidth : 200;

        let minX = node.x - nodeWidth / 2;
        let maxX = node.x + nodeWidth / 2;

        const children = childMap[nodeId] || [];
        children.forEach((childId) => {
          const childBounds = getSubtreeBounds(childId, nodeModelMap, childMap);
          minX = Math.min(minX, childBounds.minX);
          maxX = Math.max(maxX, childBounds.maxX);
        });

        return { minX, maxX, width: maxX - minX };
      }

      function separateSiblings(
        nodeId,
        nodeModelMap,
        childMap,
        horizontalSpacing = 20
      ) {
        const children = childMap[nodeId] || [];
        if (children.length < 2) return;

        const boundsArray = children.map((childId) =>
          getSubtreeBounds(childId, nodeModelMap, childMap)
        );

        for (let i = 1; i < children.length; i++) {
          const prevBounds = boundsArray[i - 1];
          const currBounds = boundsArray[i];

          if (currBounds.minX < prevBounds.maxX + horizontalSpacing) {
            const shiftAmount =
              prevBounds.maxX + horizontalSpacing - currBounds.minX;
            shiftSubtree(children[i], shiftAmount, nodeModelMap, childMap);

            const updatedBounds = getSubtreeBounds(
              children[i],
              nodeModelMap,
              childMap
            );
            boundsArray[i] = updatedBounds;
          }
        }
      }

      function shiftSubtree(nodeId, shiftAmount, nodeModelMap, childMap) {
        const node = nodeModelMap[nodeId];
        node.x += shiftAmount;

        const children = childMap[nodeId] || [];
        children.forEach((childId) => {
          shiftSubtree(childId, shiftAmount, nodeModelMap, childMap);
        });
      }

      const layoutTree = (nodeId, x, y, level = 0) => {
        const nodeModel = nodeModelMap[nodeId];
        const children = childMap[nodeId] || [];
        let subtreeWidth = 0;

        const nodeElement = document.querySelector(`[data-nodeid="${nodeId}"]`);
        const nodeWidth = nodeElement ? nodeElement.offsetWidth : 200;
        const nodeHeight = nodeElement ? nodeElement.offsetHeight : 100;

        // Calculate cumulative height for this level
        const cumulativeHeight = Array.from({ length: level }, (_, i) => {
          const levelNodes = decisionTreeForProcess.nodes.filter(
            (n) => getNodeLevel(n.decision_tree_node_id) === i
          );
          return Math.max(
            ...levelNodes.map((n) => {
              const el = document.querySelector(
                `[data-nodeid="${n.decision_tree_node_id}"]`
              );
              return el ? el.offsetHeight : 100;
            }),
            100
          );
        }).reduce((sum, height) => sum + height + VERTICAL_SPACING, 0);

        nodeModel.x = x + leftMargin;
        nodeModel.y = cumulativeHeight + topMargin;
        updateLevelHeight(level, nodeModel.y);

        localModel.addNode(nodeModel);

        if (children.length > 0) {
          let currentX = x;
          let lastChildEndX = currentX;

          children.forEach((childId, index) => {
            const childElement = document.querySelector(
              `[data-nodeid="${childId}"]`
            );
            const childWidth = childElement ? childElement.offsetWidth : 200;
            subtreeWidth += childWidth;

            if (index === 0) {
              currentX = x - childWidth / 2;
            } else {
              const prevChildId = children[index - 1];
              const prevChildElement = document.querySelector(
                `[data-nodeid="${prevChildId}"]`
              );
              const prevChildWidth = prevChildElement
                ? prevChildElement.offsetWidth
                : 200;

              currentX = Math.max(
                lastChildEndX + MIN_HORIZONTAL_SPACING,
                nodeModelMap[prevChildId].x +
                  prevChildWidth +
                  MIN_HORIZONTAL_SPACING
              );
            }

            lastChildEndX = currentX + childWidth;
            const verticalOffset = nodeHeight + VERTICAL_SPACING;
            layoutTree(childId, currentX, y + verticalOffset, level + 1);
          });

          // Center parent above children only if it has multiple children
          if (children.length > 1) {
            const firstChild = nodeModelMap[children[0]];
            const lastChild = nodeModelMap[children[children.length - 1]];
            nodeModel.x = (firstChild.x + lastChild.x) / 2;
          } else if (children.length === 1) {
            // Single child - align parent directly above
            nodeModel.x = nodeModelMap[children[0]].x;
          }
        }

        // After placing all children:
        if (children.length > 0) {
          // Separate siblings to remove overlaps
          separateSiblings(
            nodeId,
            nodeModelMap,
            childMap,
            MIN_HORIZONTAL_SPACING
          );

          // Re-center parent if multiple children
          if (children.length > 1) {
            const firstChildBounds = getSubtreeBounds(
              children[0],
              nodeModelMap,
              childMap
            );
            const lastChildBounds = getSubtreeBounds(
              children[children.length - 1],
              nodeModelMap,
              childMap
            );
            nodeModel.x = (firstChildBounds.minX + lastChildBounds.maxX) / 2;
          } else if (children.length === 1) {
            const singleChild = nodeModelMap[children[0]];
            nodeModel.x = singleChild.x;
          }
        }
        return Math.max(subtreeWidth, 100); // minimum width of a subtree
      };

      // Helper function to get node level
      const getNodeLevel = (nodeId) => {
        let level = 0;
        let currentNode = decisionTreeForProcess.nodes.find(
          (n) => n.decision_tree_node_id === nodeId
        );
        while (currentNode && currentNode.parent_decision_tree_node_id) {
          level++;
          currentNode = decisionTreeForProcess.nodes.find(
            (n) =>
              n.decision_tree_node_id ===
              currentNode.parent_decision_tree_node_id
          );
        }
        return level;
      };

      // Helper function to get node width
      const getNodeWidth = (nodeId) => {
        const nodeElement = document.querySelector(`[data-nodeid="${nodeId}"]`);
        return nodeElement ? nodeElement.offsetWidth : 200;
      };

      const rootNodes = decisionTreeForProcess.nodes.filter(
        (node) => !node.parent_decision_tree_node_id
      );
      let x = 0;
      rootNodes.forEach((rootNode) => {
        const subtreeWidth = layoutTree(rootNode.decision_tree_node_id, x, 0);
        x += subtreeWidth + 100; // space between trees
      });

      // Function to adjust parent node positions
      const adjustParents = (nodeId) => {
        const children = childMap[nodeId];
        if (children && children.length > 0) {
          let minX = Infinity;
          let maxX = -Infinity;
          children.forEach((childId) => {
            adjustParents(childId);
            const childModel = nodeModelMap[childId];
            minX = Math.min(minX, childModel.x);
            maxX = Math.max(maxX, childModel.x);
          });
          const parentNodeModel = nodeModelMap[nodeId];
          parentNodeModel.x = (minX + maxX) / 2;
        }
      };

      rootNodes.forEach((rootNode) => {
        adjustParents(rootNode.decision_tree_node_id);
      });

      const addLinks = (decisionTreeForProcess, nodeModelMap, localModel) => {
        decisionTreeForProcess.nodes.forEach((node) => {
          if (node.parent_decision_tree_node_id) {
            const parentNodeModel =
              nodeModelMap[node.parent_decision_tree_node_id];
            const childNodeModel = nodeModelMap[node.decision_tree_node_id];
            let portOut = parentNodeModel.getPort("right");
            let portIn = childNodeModel.getPort("left");

            if (portOut && portIn) {
              const link = new DTCustomLinkModel();
              link.setSourcePort(portOut);
              link.setTargetPort(portIn);
              link.linkId = `${node.decision_tree_node_id}-${node.parent_decision_tree_node_id}`;
              localModel.addLink(link);
            }
          }
        });
      };

      // After initial layout, normalize Y coordinates to ensure all nodes on same level have the same Y
      const normalizeYCoordinates = () => {
        const processNode = (nodeId, level = 0) => {
          const node = nodeModelMap[nodeId];
          node.y = levelHeights.get(level);
          const children = childMap[nodeId] || [];
          children.forEach((childId) => {
            processNode(childId, level + 1);
          });
        };

        rootNodes.forEach((rootNode) => {
          processNode(rootNode.decision_tree_node_id, 0);
        });
      };
      engine.forceUpdate();
      normalizeYCoordinates();

      // Add links after all nodes are properly positioned
      addLinks(decisionTreeForProcess, nodeModelMap, localModel);

      setModel(localModel);
      engine.setDiagramModel(localModel);
      model.setZoomLevel(zoom);
      engine.forceUpdate();
      engine.enableRepaintEntities([]);

      setTimeout(() => {
        engine.forceUpdate();
        if (isFirstLoad) {
          handleZoomToFit();
          setIsFirstLoad(false);
        }
      });

      if (isRevision) {
        setIsCanvasDragging(false);
      }
    };

    return (
      <RJD.DiagramWidget
        ref={canvas}
        diagramEngine={engine}
        actions={{
          zoom: false,
          canvasDrag: { isCanvasDragging },
          copy: false,
          selectAll: false,
          deleteItems: false,
          multiselect: false,
          multiselectDrag: false,
          moveItems: false,
        }}
      />
    );
  }
);

// Static metadata for the forwardRef component: a readable name for React
// DevTools and runtime prop validation.
Object.assign(DTDiagramDraw, {
  displayName: "DTDiagramDraw",
  propTypes: {
    // Diagram engine the widget renders with (required).
    engine: PropTypes.any.isRequired,
    // Tree to draw; its `nodes` array is what gets laid out.
    decisionTree: PropTypes.object,
    // Declared for callers; not read by the component body.
    handleCloneChange: PropTypes.func,
    // When true, canvas dragging is turned off after each layout.
    isRevision: PropTypes.bool,
  },
});

export default DTDiagramDraw;
