import React, { useEffect, useRef, useState } from "react";
import ReactFlow, {
  MiniMap,
  Controls,
  applyNodeChanges,
  applyEdgeChanges,
  Handle,
  Position,
  NodeResizer,
  MarkerType,
} from "reactflow";
import dagre from "dagre";
import {
  App,
  Card,
  Collapse,
  Flex,
  Layout,
  Menu,
  message,
  Switch,
  Tag,
} from "antd";
import TextArea from "antd/es/input/TextArea";
import axios from "axios";
import { EventSourcePolyfill } from "event-source-polyfill";
import { useLocation, useNavigate } from "react-router-dom";
import { EventComponent, mergeTokens } from "../utils/llm_output_parser";
import AgentNodeComponent from "./graph/graphAgentNode";
import FloatingEdge from "./graph/floatingEdge";
import {
  AiSystemsApi,
  AiTasksApi,
  Configuration,
} from "../../../lib/src/api_client";
import { apiBasePath } from "../config";
import { useAuth0 } from "@auth0/auth0-react";
import EndNodeComponent from "./graph/graphEndNode";
import "reactflow/dist/style.css";
import "./graph/style.css";
import StatusTag from "./misc/taskStatusTag";
import { StopOutlined } from "@ant-design/icons";

const CustomNodeComponent = ({ data }) => {
  const [isExpanded, setIsExpanded] = useState(false);
  const handleToggle = () => {
    setIsExpanded(!isExpanded);
  };

  const processedEvents = mergeTokens(data.llm); // Process to merge adjacent tokens and prepare events for display

  // Styles for the output panel
  const outputStyle = {
    height: "300px",
    overflowY: "auto",
    padding: "4px",
  };

  const outputRef = useRef(null); // Create a ref for the output div
  const lastItemRef = useRef(null);

  useEffect(() => {
    if (outputRef.current && data.autoScroll) {
      outputRef.current.scrollTop = outputRef.current.scrollHeight;
    }
  }, [processedEvents, data.autoScroll]); // Include autoScroll in dependencies

  return (
    <Card title={data.label} bordered={true} style={{ width: 500 }}>
      <Handle type="target" position={Position.Left} />
      <p>
        <strong>ID:</strong> {data.id}
      </p>
      <p>
        <strong>Type:</strong> {data.type}
      </p>
      <Collapse defaultActiveKey={["1"]} onChange={handleToggle}>
        <Collapse.Panel header="Output" key="1">
          <div ref={outputRef} style={outputStyle}>
            {processedEvents.map((event, index) => (
              <EventComponent
                key={index}
                event={event}
                ref={
                  index >= processedEvents.length - 3 ? lastItemRef : undefined
                }
              />
            ))}
          </div>
        </Collapse.Panel>
      </Collapse>
      <Handle type="source" position={Position.Right} />
    </Card>
  );
};

/**
 * Lays out React Flow nodes and edges as a layered graph with dagre.
 *
 * Every node is treated as a fixed-size box (700x650 by default, overridable
 * through `options`). Its `position` is set to the box's top-left corner.
 * Neither the input arrays nor the input node objects are mutated: positioned
 * shallow copies of the nodes are returned, and the edges array is returned
 * unchanged.
 *
 * @param {Array<{id: string}>} nodes - React Flow nodes to position.
 * @param {Array<{source: string, target: string, label?: string, style?: object}>} edges
 *   React Flow edges; only used to shape the layout.
 * @param {string} [direction="TB"] - dagre `rankdir`: "TB", "BT", "LR" or "RL".
 * @param {{nodeWidth?: number, nodeHeight?: number}} [options] - box size used
 *   for every node during layout.
 * @returns {{nodes: Array<object>, edges: Array<object>}} positioned node
 *   copies and the original edges.
 */
const getLayoutedElements = (
  nodes,
  edges,
  direction = "TB",
  { nodeWidth = 700, nodeHeight = 650 } = {},
) => {
  const dagreGraph = new dagre.graphlib.Graph();
  dagreGraph.setDefaultEdgeLabel(() => ({}));
  dagreGraph.setGraph({
    rankdir: direction,
    edgesep: 100, // separation between adjacent edges
    ranksep: 300, // separation between ranks (layers) of nodes
    marginx: 20, // margin around the laid-out graph on the x-axis
    marginy: 20, // margin around the laid-out graph on the y-axis
  });

  nodes.forEach((node) => {
    dagreGraph.setNode(node.id, { width: nodeWidth, height: nodeHeight });
  });

  edges.forEach((edge) => {
    dagreGraph.setEdge(edge.source, edge.target, {
      label: edge.label,
      style: edge.style,
    });
  });

  dagre.layout(dagreGraph);

  const layoutedNodes = nodes.map((node) => {
    const { x, y, width, height } = dagreGraph.node(node.id);
    // dagre reports the centre of each node; convert to the top-left corner.
    return {
      ...node,
      position: { x: x - width / 2, y: y - height / 2 },
    };
  });

  return { nodes: layoutedNodes, edges };
};

// React Flow node renderers, keyed by a node's `type`. The component below
// sets `type` from the system node's `node_type`, so LLM and EXECUTOR nodes
// share the agent card and END nodes get the end card. Kept at module level so
// the same object is passed to ReactFlow on every render.
const nodeTypes = {
  LLM: AgentNodeComponent,
  EXECUTOR: AgentNodeComponent,
  END: EndNodeComponent,
};

// Custom edge renderers are currently disabled; the imported FloatingEdge is
// unused while this stays commented out.
// const edgeTypes = {
//     floating: FloatingEdge,
// }

const AIComponent = ({ taskIdProp }) => {
  const [nodes, setNodes] = useState([]);
  const [edges, setEdges] = useState([]);
  const [system, setSystem] = useState(null);
  const [streamingData, setStreamingData] = useState({});
  const [streamingStatus, setStreamingStatus] = useState(false);
  const [autoScroll, setAutoScroll] = useState(true); // State to control auto-scroll
  const [nodeUsageData, setNodeUsageData] = useState({});
  const [accessToken, setAccessToken] = useState(null);
  const [totals, setTotals] = useState({
    totalCompletionTokens: 0,
    totalPromptTokens: 0,
    totalCompletionCost: 0,
    totalPromptCost: 0,
  });
  const location = useLocation();
  const aiSystemId = location.state.aiSystemId;
  const taskId = taskIdProp || location.state.taskId;
  const { getAccessTokenSilently } = useAuth0();
  const { message, notification } = App.useApp();
  const navigate = useNavigate();
  const [taskInfo, setTaskInfo] = useState(null);

  useEffect(() => {
    const fetchToken = async () => {
      const token = await getAccessTokenSilently();
      setAccessToken(token);
    };
    fetchToken();
  });

  function getTaskInfo() {
    const tasksApi = new AiTasksApi(
      new Configuration({
        accessToken: getAccessTokenSilently,
        basePath: apiBasePath,
      }),
    );

    tasksApi
      .getAiTaskV1AiTasksAiTaskIdGet(taskId)
      .then((response) => {
        setTaskInfo(response.data);
        if (
          response.data.current_status === "COMPLETED" ||
          response.data.current_status === "FAILED" ||
          response.data.current_status === "CANCELLED" ||
          response.data.current_status === "TERMINATED" ||
          response.data.current_status === "TIME_OUT"
        ) {
          setAutoScroll(false);
        }
      })
      .catch((error) => {
        message.error("Could not fetch task: " + error);
      });
  }

  const fetchTaskInfo = async () => {
    getTaskInfo();
  };

  const refreshTaskInfo = async () => {
    // skip if task loaded and current status is Done or Failed
    if (!taskId) return;
    if (
      taskInfo &&
      (taskInfo.current_status === "COMPLETED" ||
        taskInfo.current_status === "FAILED" ||
        taskInfo.current_status === "CANCELLED" ||
        taskInfo.current_status === "TERMINATED" ||
        taskInfo.current_status === "TIME_OUT")
    ) {
      return;
    } else {
      getTaskInfo();
    }
  };

  // every 5 seconds fetch taskinfo
  useEffect(() => {
    if (
      taskInfo &&
      (taskInfo.current_status === "COMPLETED" ||
        taskInfo.current_status === "FAILED" ||
        taskInfo.current_status === "CANCELLED" ||
        taskInfo.current_status === "TERMINATED" ||
        taskInfo.current_status === "TIME_OUT")
    ) {
      return;
    }
    const intervalId = setInterval(refreshTaskInfo, 5000);

    return () => clearInterval(intervalId);
  }, [taskId, taskInfo]);

  useEffect(() => {
    const fetchData = async () => {
      try {
        const api = new AiSystemsApi(
          new Configuration({
            accessToken: getAccessTokenSilently,
            basePath: apiBasePath,
          }),
        );

        const response =
          await api.getAiSystemCompleteV1AiSystemsAiSystemIdCompleteGet(
            aiSystemId,
          );

        const systemData = response.data.system;
        setSystem(response.data);
        const { nodes: layoutedNodes, edges: layoutedEdges } =
          getLayoutedElements(
            systemData.nodes.map((node) => ({
              id: node.name,
              type: node.node_type,
              data: { label: node.name },
              position: { x: 0, y: 0 },
              dragHandle: ".ant-card-head",
            })),
            systemData.edges.flatMap((edge) => {
              if (edge.edge_type === "ConditionalEdge") {
                return Object.entries(edge.mapping).flatMap(
                  ([conditionValue, targetNode]) => {
                    const startNodes = Array.isArray(edge.start_node)
                      ? edge.start_node
                      : [edge.start_node];
                    const targetNodes = Array.isArray(targetNode)
                      ? targetNode
                      : [targetNode];

                    return startNodes.flatMap((startNode) =>
                      targetNodes.map((target) => ({
                        id: `e${startNode}-${target}`,
                        source: startNode,
                        target: target,
                        animated: false,
                        markerEnd: {
                          type: MarkerType.ArrowClosed,
                          width: 50,
                          height: 50,
                        },
                        label: `If ${conditionValue}`,
                        style: {
                          strokeDasharray: "3",
                          // strokeWidth: 2,
                        },
                      })),
                    );
                  },
                );
              } else {
                const startNodes = Array.isArray(edge.start_node)
                  ? edge.start_node
                  : [edge.start_node];
                const endNodes = Array.isArray(edge.end_node)
                  ? edge.end_node
                  : [edge.end_node];

                return startNodes.flatMap((startNode) =>
                  endNodes.map((endNode) => ({
                    id: `e${startNode}-${endNode}`,
                    source: startNode,
                    target: endNode,
                    animated: true,
                    // type: "smoothstep",
                    markerEnd: {
                      type: MarkerType.ArrowClosed,
                      width: 50,
                      height: 50,
                    },
                    style: {
                      strokeDasharray: "0",
                      // strokeWidth: 2,
                    },
                  })),
                );
              }
            }),
          );
        setNodes(layoutedNodes);
        setEdges(layoutedEdges);
      } catch (error) {
        console.error("Failed to fetch system data:", error);
      }
    };

    fetchData();
    if (taskId) {
      fetchTaskInfo();
    }
  }, [aiSystemId]);

  useEffect(() => {
    if (!taskId) return;
    if (!accessToken) return;

    const headers = {
      headers: {
        Authorization: "Bearer " + accessToken,
      },
    };

    const eventSource = new EventSourcePolyfill(
      apiBasePath + `/v1/ai_tasks/${taskId}/stream`,
      headers,
    );

    eventSource.onopen = () => {
      setStreamingStatus(true);
    };

    eventSource.onmessage = (event) => {
      const { node, content } = JSON.parse(event.data);
      setStreamingData((prevData) => ({
        ...prevData,
        [node]: [...(prevData[node] || []), JSON.parse(event.data)], // Append new JSON content to the array
      }));
    };

    eventSource.onerror = (error) => {
      setStreamingStatus(false);
      eventSource.close();
      console.log("Failed to stream data: " + error);
      message.error("Streaming connection lost. Retrying in 5 seconds...");
      setTimeout(() => {
        eventSource.open();
      }, 5000);
    };

    return () => {
      eventSource.close();
    };
  }, [taskId, accessToken]);

  // Token usage streaming
  useEffect(() => {
    if (!taskId) return;
    if (!accessToken) return;

    const headers = {
      headers: {
        Authorization: "Bearer " + accessToken,
      },
    };

    const eventSource = new EventSourcePolyfill(
      apiBasePath + `/v1/ai_tasks/${taskId}/stream_usage`,
      headers,
    );

    // eventSource.onopen = () => {
    //     setStreamingStatus(true);
    // };

    eventSource.onmessage = (event) => {
      const data = JSON.parse(event.data);
      const node = data.node_name; // Assuming 'node_name' identifies the node
      setNodeUsageData((prevData) => {
        const existingData = prevData[node] || {
          completionTokens: 0,
          promptTokens: 0,
          completionCost: 0,
          promptCost: 0,
        };
        return {
          ...prevData,
          [node]: {
            completionTokens:
              existingData.completionTokens + data.completion_tokens,
            promptTokens: existingData.promptTokens + data.prompt_tokens,
            completionCost:
              existingData.completionCost + parseFloat(data.completion_cost),
            promptCost: existingData.promptCost + parseFloat(data.prompt_cost),
          },
        };
      });
      setTotals((totals) => ({
        totalCompletionTokens:
          totals.totalCompletionTokens + data.completion_tokens,
        totalPromptTokens: totals.totalPromptTokens + data.prompt_tokens,
        totalCompletionCost:
          totals.totalCompletionCost + parseFloat(data.completion_cost),
        totalPromptCost: totals.totalPromptCost + parseFloat(data.prompt_cost),
      }));
    };

    eventSource.onerror = () => {
      message.error(
        "Could not stream token usage data. Retrying in 5 seconds...",
      );
      eventSource.close();
      setTimeout(() => {
        eventSource.open();
      }, 5000);
    };

    return () => {
      eventSource.close();
    };
  }, [taskId, accessToken]);

  const handleAutoScrollChange = (checked) => {
    setAutoScroll(checked);
  };

  const handleCancelTask = (taskId) => {
    return async () => {
      const api = new AiTasksApi(
        new Configuration({
          accessToken: getAccessTokenSilently,
          basePath: apiBasePath,
        }),
      );
      try {
        await api.cancelTaskV1AiTasksTaskIdCancelPost(taskId);
        message.success("Task cancelled successfully");
        refreshTaskInfo();
      } catch (error) {
        message.error("Could not cancel task: " + error);
      }
    };
  };

  return (
    <Flex vertical style={{ height: "100%" }}>
      <Menu mode="horizontal" style={{ textAlign: "center" }}>
        {taskId && (
          <Menu.Item
            key="status"
            disabled
            style={{ opacity: 1, cursor: "default" }}
          >
            <StatusTag status={taskInfo?.current_status} />
          </Menu.Item>
        )}
        {taskId && (
          <Menu.Item
            key="info"
            onClick={() => {
              navigate("/tasks/view", { state: { taskId: taskId } });
            }}
          >
            <Tag>Task: {taskId}</Tag>
          </Menu.Item>
        )}
        <Menu.Item
          key="info2"
          disabled
          style={{ opacity: 1, cursor: "default" }}
          onClick={() => {
            navigate("/tasks/view", { state: { taskId: taskId } });
          }}
        >
          <Tag>AI System: {aiSystemId}</Tag>
        </Menu.Item>
        {taskInfo && (
          <Menu.Item
            key="cancel"
            disabled={
              taskInfo.current_status === "CANCELLED" ||
              taskInfo.current_status === "COMPLETED" ||
              taskInfo.current_status === "FAILED" ||
              taskInfo.current_status === "TERMINATED" ||
              taskInfo.current_status === "TIME_OUT"
            }
            onClick={handleCancelTask(taskId)}
            icon={<StopOutlined />}
          >
            Cancel task
          </Menu.Item>
        )}
        {/*<Menu.Item key="spacer" style={{ flexGrow: 1 }} disabled />*/}
        <Menu.Item
          key="autoscroll"
          disabled
          style={{ opacity: 1, cursor: "default" }}
        >
          <Switch
            checkedChildren="Scroll ON"
            unCheckedChildren="Scroll OFF"
            checked={autoScroll}
            onChange={handleAutoScrollChange}
          />
        </Menu.Item>
        <Menu.Item
          key="totalTokens"
          disabled
          style={{ opacity: 1, cursor: "default", textAlign: "right" }}
        >
          <Tag>
            Total Tokens:{" "}
            {(
              totals.totalCompletionTokens + totals.totalPromptTokens
            ).toLocaleString()}
          </Tag>
        </Menu.Item>
        <Menu.Item
          key="totalCost"
          disabled
          style={{ opacity: 1, cursor: "default", textAlign: "right" }}
        >
          <Tag>
            Total Credits:{" "}
            {(totals.totalCompletionCost + totals.totalPromptCost).toFixed(4)}
          </Tag>
        </Menu.Item>
      </Menu>
      <ReactFlow
        width="100%"
        height="100%"
        minZoom={0.1}
        nodes={nodes.map((node) => ({
          ...node,
          data: {
            ...node.data,
            nodeId: node.id,
            tools: system.tools,
            connections: system.connections,
            agents: system.agents,
            system: system.system,
            llm: streamingData[node.id],
            usage: nodeUsageData[node.id] || {},
            autoScroll: autoScroll,
          },
        }))}
        nodeTypes={nodeTypes}
        edges={edges}
        fitView
        // edgeTypes={edgeTypes}
        onNodesChange={(changes) => setNodes(applyNodeChanges(changes, nodes))}
        onEdgesChange={(changes) => setEdges(applyEdgeChanges(changes, edges))}
      >
        {/*<MiniMap />*/}
        <Controls />
        {streamingStatus && (
          <div
            style={{ position: "absolute", top: 10, right: 10, color: "green" }}
          >
            LIVE
          </div>
        )}
        {!streamingStatus && (
          <div
            style={{ position: "absolute", top: 10, right: 10, color: "red" }}
          >
            OFFLINE
          </div>
        )}
      </ReactFlow>
    </Flex>
  );
};

export default AIComponent;
