import { useLocaleFormatters } from "$root/localization/hooks";
import { customObjectEntriesFn, stableColorGenerator, sumIterable } from "$root/utils";
import type { StylableProps } from "@mdotm/mdotui/components";
import { ComputedSizeContainer, FloatingContent, Svg, Text } from "@mdotm/mdotui/components";
import { ForEach } from "@mdotm/mdotui/react-extensions";
import type { Rect2D } from "@mdotm/mdotui/utils";
import { mapBetweenRanges } from "@mdotm/mdotui/utils";
import fastDeepEqual from "fast-deep-equal";
import { useRef, useState } from "react";
/**
 * A single rectangle drawn inside a group.
 */
export type SankeyChartItem = {
	/**
	 * Size of the rect: its height is proportional to this value divided by the sum of the item
	 * weights of its group. Also printed (one decimal) on top of the rect.
	 */
	weight: number;
	/** Key given to `stableColorGenerator` to pick the fill color of the rect. */
	colorKey: string;
	/** Not read by the chart rendering code in this file. */
	name: string;
};
/**
 * A labelled group of items, stacked vertically inside a column.
 */
export type SankeyChartGroup = {
	/** Text drawn beside the group (first and last column only), wrapped on at most two lines. */
	label: string;
	/** Not read by the chart rendering code in this file. */
	name: string;
	/**
	 * Size of the group: its height is proportional to this value divided by the sum of the group
	 * weights of its column. It is independent of the sum of the item weights. Printed with two
	 * decimals and a trailing `%`, so presumably a percentage.
	 */
	weight: number;
	/** Items, stacked top to bottom in array order. */
	items: SankeyChartItem[];
};
/**
 * A vertical stack of groups.
 */
export type SankeyChartColumn = {
	/** Groups, stacked top to bottom in array order. */
	groups: SankeyChartGroup[];
	/** Empty space (SVG user units) above the first group. */
	marginTop: number;
	/** Empty space (SVG user units) below the last group; it reduces the height available to groups. */
	marginBottom: number;
};
/**
 * Location of an item in the chart: index of the column, of the group inside that column, and of
 * the item inside that group.
 */
export type SankeyChartPathEndpoint = {
	columnIndex: number;
	groupIndex: number;
	itemIndex: number;
};
/**
 * A link between two items, drawn as a filled band from the right edge of the `from` item to the
 * left edge of the `to` item. Both endpoints index into the chart's `columns`.
 */
export type SankeyChartPath = {
	from: SankeyChartPathEndpoint;
	to: SankeyChartPathEndpoint;
};

/**
 * Props of `SankeyLikeChart`: the chart data, plus styling props that are forwarded to the outer
 * size-measuring container.
 */
export type SankeyLikeChartProps = StylableProps & {
	/** Columns, laid out left to right in array order. */
	columns: SankeyChartColumn[];
	/** Links between items of the columns above. */
	paths: SankeyChartPath[];
};

/**
 * Sankey-like chart that fills a `ComputedSizeContainer`.
 *
 * Every prop other than `columns` and `paths` is forwarded to the container. The rect the container
 * hands to its render function (without `htmlEl`) becomes the SVG viewBox of the chart.
 */
export function SankeyLikeChart(props: SankeyLikeChartProps): JSX.Element {
	const { columns, paths, ...stylableProps } = props;
	return (
		<ComputedSizeContainer {...stylableProps}>
			{({ htmlEl: _unusedHtmlEl, ...measuredRect }) => (
				<Inner viewBox={measuredRect} columns={columns} paths={paths} />
			)}
		</ComputedSizeContainer>
	);
}

// Horizontal layout of the chart, in SVG user units, from left to right:
// marginLeft | label area | first column ... last column | label area | marginRight

/** Padding before the left label area. */
const marginLeft = 10;
/** Padding after the right label area. */
const marginRight = 10;
/** Width of every column of item rects. */
const columnWidth = 25;

/** Width reserved for group labels, on both the far left and the far right. */
const groupLabelWidth = 100;
/** Space before each label area. */
const groupLabelMarginLeft = 10;
/** Space after each label area. */
const groupLabelMarginRight = 10;

/**
 * Draws the chart into an SVG whose viewBox is `viewBox`.
 *
 * Layout (SVG user units):
 * - Columns go left to right, separated by an equal horizontal gap. A label area of
 *   `groupLabelWidth` (plus its margins) is reserved on both the far left and the far right.
 * - Inside a column, each group's height is its `weight` relative to the sum of the group weights of
 *   that column (at least 1), with `marginVBetweenGroups` between groups.
 * - Inside a group, each item's height is its `weight` relative to the sum of the item weights of
 *   that group (at least 1), with `marginVBetweenItems` between items.
 * - Each path is a filled band from (just right of) its `from` item to (just left of) its `to` item.
 *
 * Hovering an item highlights the first path that starts at it (first column) or ends at it (other
 * columns). Hovering a path highlights that path. The highlighted path shows a tooltip. The highlight
 * is cleared 350 ms after the pointer leaves, unless another element is entered in the meantime.
 *
 * Group labels are only drawn for the first and last column (see `getGroupLabelX`).
 */
function Inner({
	viewBox,
	columns,
	paths,
}: {
	viewBox: Rect2D;
	columns: Array<SankeyChartColumn>;
	paths: Array<SankeyChartPath>;
}) {
	// Despite its name, this holds the hovered *path* (plus its index in `paths`), or null.
	const [hoveredGroupIndex, setHoveredGroupIndex] = useState<(SankeyChartPath & { mainIndex: number }) | null>(
		null,
	);
	// Timer scheduled by onMouseLeave to clear the highlight; null when no timer is pending.
	const hideTimeoutRef = useRef<NodeJS.Timeout | null>(null);

	/** Cancels the pending "clear highlight" timer, if any. */
	function cancelPendingHide() {
		if (hideTimeoutRef.current !== null) {
			clearTimeout(hideTimeoutRef.current);
			hideTimeoutRef.current = null;
		}
	}

	/** Highlights `data` (or clears the highlight when it is null) and cancels any pending clear. */
	function onMouseEnter(data: typeof hoveredGroupIndex) {
		cancelPendingHide();
		setHoveredGroupIndex(data);
	}

	/**
	 * Clears the highlight after a delay (presumably so that moving between adjacent elements does
	 * not make it flicker).
	 */
	function onMouseLeave() {
		// Only the latest timer id is tracked, so drop any previous one first; otherwise an orphaned
		// timer could not be cancelled by the next onMouseEnter and would clear a newer highlight.
		cancelPendingHide();
		hideTimeoutRef.current = setTimeout(() => {
			hideTimeoutRef.current = null;
			setHoveredGroupIndex(null);
		}, 350);
	}

	// Horizontal gap between adjacent columns: the width left after margins, both label areas and the
	// columns themselves, split evenly. With fewer than two columns there is no gap (and dividing by
	// `columns.length - 1` would yield ±Infinity or NaN).
	const marginHBetweenColumns =
		columns.length > 1
			? (viewBox.width -
					groupLabelWidth * 2 -
					(groupLabelMarginLeft + groupLabelMarginRight) * 2 -
					marginLeft -
					marginRight -
					columnWidth * columns.length) /
			  (columns.length - 1)
			: 0;
	const marginVBetweenGroups = 10;
	const marginVBetweenItems = 5;
	/** Groups/items whose weight is below this are hidden while nothing is hovered. */
	const minWeightToShowData = 5;
	/** Maximum number of characters per label line; longer lines are truncated with "...". */
	const maxLabelLength = 15;

	function getItemWeight(item: SankeyChartItem) {
		return item.weight;
	}
	function getItemWeightSumInGroup(group: SankeyChartGroup) {
		return sumIterable(group.items, (x) => x.weight);
	}
	function getGroupWeight(group: SankeyChartGroup) {
		return group.weight;
	}
	function getColumnWeight(column: SankeyChartColumn) {
		return sumIterable(column.groups, getGroupWeight);
	}
	/** Vertical space available to a column once its top and bottom margins are removed. */
	function getColumnHeight(columnIndex: number) {
		return viewBox.height - columns[columnIndex].marginBottom - columns[columnIndex].marginTop;
	}
	/** Group height, proportional to its weight within the column (gaps excluded), at least 1. */
	function getGroupHeight(columnIndex: number, groupIndex: number) {
		return Math.max(
			1,
			mapBetweenRanges(
				getGroupWeight(columns[columnIndex].groups[groupIndex]),
				0,
				getColumnWeight(columns[columnIndex]),
				0,
				getColumnHeight(columnIndex) - marginVBetweenGroups * (columns[columnIndex].groups.length - 1),
			),
		);
	}
	/** Left edge of a column's rects. */
	function getColumnX(columnIndex: number) {
		let startColX = marginLeft + groupLabelWidth + groupLabelMarginLeft + groupLabelMarginRight;
		for (let i = 0; i < columnIndex; i++) {
			startColX += columnWidth + marginHBetweenColumns;
		}
		return startColX;
	}
	// TODO: memo all these functions to drastically improve performance
	/**
	 * Left edge of the label area of a column.
	 * @throws RangeError for any column other than the first or the last.
	 */
	function getGroupLabelX(columnIndex: number) {
		if (columnIndex === 0) {
			return marginLeft + groupLabelMarginLeft;
		} else if (columnIndex === columns.length - 1) {
			return viewBox.width - marginRight - groupLabelMarginRight - groupLabelWidth;
		} else {
			throw new RangeError("column index must be 0 or (columns.length - 1)"); // TODO: support more than 2 cols, how do we print groups in the middle without making a mess?
		}
	}
	function getColumnY(columnIndex: number) {
		return columns[columnIndex].marginTop;
	}
	/** Top edge of a group: column top plus the heights and gaps of the groups above it. */
	function getGroupY(columnIndex: number, groupIndex: number) {
		let startGroupY = getColumnY(columnIndex);
		for (let i = 0; i < groupIndex; i++) {
			const groupHeight = getGroupHeight(columnIndex, i);
			startGroupY += groupHeight + marginVBetweenGroups;
		}
		return startGroupY;
	}
	/** Item height, proportional to its weight within the group (gaps excluded), at least 1. */
	function getItemHeight(columnIndex: number, groupIndex: number, itemIndex: number) {
		return Math.max(
			1,
			mapBetweenRanges(
				getItemWeight(columns[columnIndex].groups[groupIndex].items[itemIndex]),
				0,
				getItemWeightSumInGroup(columns[columnIndex].groups[groupIndex]),
				0,
				getGroupHeight(columnIndex, groupIndex) -
					marginVBetweenItems * (columns[columnIndex].groups[groupIndex].items.length - 1),
			),
		);
	}
	/** Top edge of an item: group top plus the heights and gaps of the items above it. */
	function getItemY(columnIndex: number, groupIndex: number, itemIndex: number) {
		let startItemY = getGroupY(columnIndex, groupIndex);
		for (let i = 0; i < itemIndex; i++) {
			const itemHeight = getItemHeight(columnIndex, groupIndex, i);
			startItemY += itemHeight + marginVBetweenItems;
		}
		return startItemY;
	}

	/**
	 * Finds the first path whose `startingPoint` endpoint equals `point`.
	 * @returns the path with its index in `paths`, or null when no path matches.
	 */
	function getPath<TStartingPoint extends keyof SankeyChartPath>(
		startingPoint: TStartingPoint,
		point: SankeyChartPath[TStartingPoint],
	): (SankeyChartPath & { mainIndex: number }) | null {
		const mainIndex = paths.findIndex((x) => fastDeepEqual(startingPoint === "from" ? x.from : x.to, point));
		const path = mainIndex > -1 ? paths[mainIndex] : undefined;
		if (path === undefined) {
			return null;
		}
		return { mainIndex, to: path.to, from: path.from };
	}

	/** Whether the hovered path's `startingPoint` endpoint equals `point` (false when nothing is hovered). */
	function isPathMatching<TStartingPoint extends keyof SankeyChartPath>(
		startingPoint: TStartingPoint,
		point: SankeyChartPath[TStartingPoint],
	) {
		if (!hoveredGroupIndex) {
			return false;
		}
		const hoveredPoint = hoveredGroupIndex[startingPoint];
		return customObjectEntriesFn(hoveredPoint).every(([level, value]) => point[level] === value);
	}

	/**
	 * Splits a label into two lines on spaces and commas (the separators themselves are dropped).
	 * Words fill the first line while it stays within `length` characters. From the first word that
	 * does not fit, every remaining word goes to the second line, which may exceed `length`. Words
	 * inside a line are joined by a single space, and neither line has a leading or trailing space.
	 * The first line is empty when the first word alone is longer than `length`.
	 */
	function breakLabel(name: string, length: number): [string, string] {
		const words = name.split(/[ ,]+/).filter((word) => word.length > 0);
		let firstLine = "";
		let secondLine = "";
		for (const word of words) {
			const candidate = firstLine.length === 0 ? word : `${firstLine} ${word}`;
			if (secondLine.length === 0 && candidate.length <= length) {
				firstLine = candidate;
			} else {
				secondLine = secondLine.length === 0 ? word : `${secondLine} ${word}`;
			}
		}
		return [firstLine, secondLine];
	}

	/** Wraps a column with lazy accessors for its groups and their items. */
	function getColumn(currentColumns: SankeyChartColumn[], index: number) {
		const column = currentColumns[index];
		function getGroup(groupIndex: number) {
			const group = column.groups[groupIndex];
			function getItem(itemIndex: number) {
				const item = group.items[itemIndex];

				return item;
			}

			return { ...group, getItem };
		}

		return { ...column, getGroup };
	}

	const { formatNumber } = useLocaleFormatters();
	return (
		<Svg viewBox={viewBox}>
			<ForEach collection={columns}>
				{({ item: column, index: columnIndex }) => (
					<ForEach collection={column.groups}>
						{({ item: group, index: groupIndex }) => {
							// Items of the first column are matched against path origins, all others against path targets.
							const startingPoint = columnIndex === 0 ? "from" : "to";
							const groupOpacity =
								hoveredGroupIndex?.[startingPoint].groupIndex === groupIndex
									? 1
									: hoveredGroupIndex !== null
									  ? 0
									  : getGroupWeight(group) < minWeightToShowData
									    ? 0
									    : 1;
							return (
								<>
									{(columnIndex === 0 || columnIndex === columns.length - 1) && (
										<>
											<ForEach
												collection={breakLabel(group.label, maxLabelLength)
													.filter((line) => line)
													.reverse()}
											>
												{({ item, index }) => (
													<text
														fill="black"
														fontSize={12}
														x={
															getGroupLabelX(columnIndex) +
															(columnIndex === 0
																? groupLabelWidth - 4 /* extra margin to balance the percent sign underneath */
																: 0)
														}
														opacity={groupOpacity}
														y={
															getGroupY(columnIndex, groupIndex) +
															getGroupHeight(columnIndex, groupIndex) / 2 -
															8 -
															index * 20
														}
														textAnchor={columnIndex === 0 ? "end" : "start"}
														className="whitespace-pre-line"
														height={40}
													>
														{item.length > maxLabelLength ? item.substring(0, maxLabelLength).concat("...") : item}
													</text>
												)}
											</ForEach>
											<text
												fill="black"
												fontSize={20}
												fontWeight={500}
												opacity={groupOpacity}
												x={getGroupLabelX(columnIndex) + (columnIndex === 0 ? groupLabelWidth : 0)}
												y={getGroupY(columnIndex, groupIndex) + getGroupHeight(columnIndex, groupIndex) / 2 + 14}
												textAnchor={columnIndex === 0 ? "end" : "start"}
												className="transition-opacity cursor-default"
											>
												{formatNumber(getGroupWeight(group), 2)}%
											</text>
										</>
									)}
									<ForEach collection={group.items}>
										{({ item, index: itemIndex }) => {
											const point = { columnIndex, groupIndex, itemIndex };
											const path = getPath(startingPoint, point);
											const match = isPathMatching(startingPoint, point);

											const itemOpacity = match
												? 1
												: hoveredGroupIndex !== null
												  ? 0
												  : item.weight < minWeightToShowData || getGroupWeight(group) < minWeightToShowData
												    ? 0
												    : 1;
											return (
												<>
													<rect
														fill={stableColorGenerator(item.colorKey)}
														x={getColumnX(columnIndex)}
														width={columnWidth}
														y={getItemY(columnIndex, groupIndex, itemIndex)}
														height={getItemHeight(columnIndex, groupIndex, itemIndex)}
														className="transition-opacity cursor-default"
														opacity={hoveredGroupIndex === null ? 0.7 : match ? 1 : 0.1}
														onMouseEnter={() => onMouseEnter(path)}
														onMouseLeave={onMouseLeave}
													/>
													<text
														fontSize={10}
														x={getColumnX(columnIndex) + columnWidth / 2}
														y={
															getItemY(columnIndex, groupIndex, itemIndex) +
															getItemHeight(columnIndex, groupIndex, itemIndex) / 2 +
															4 /* offset to center font size of 10 */
														}
														onMouseEnter={() => onMouseEnter(path)}
														onMouseLeave={onMouseLeave}
														opacity={itemOpacity}
														className="transition-opacity cursor-default"
														textAnchor="middle"
													>
														{formatNumber(item.weight, 1)}
													</text>
												</>
											);
										}}
									</ForEach>
								</>
							);
						}}
					</ForEach>
				)}
			</ForEach>
			<ForEach collection={paths}>
				{({ item: { from, to }, index }) => {
					const isHovered = hoveredGroupIndex !== null && hoveredGroupIndex.mainIndex === index;
					// Non-hovered bands keep a 2-unit gap from the rects; the hovered one touches them.
					const extraSpacing = isHovered ? 0 : 2;
					const fromX = getColumnX(from.columnIndex) + columnWidth + extraSpacing;
					const fromY = getItemY(from.columnIndex, from.groupIndex, from.itemIndex);
					const fromHeight = getItemHeight(from.columnIndex, from.groupIndex, from.itemIndex);
					const toX = getColumnX(to.columnIndex) - extraSpacing;
					const toY = getItemY(to.columnIndex, to.groupIndex, to.itemIndex);
					const toHeight = getItemHeight(to.columnIndex, to.groupIndex, to.itemIndex);
					return (
						<FloatingContent
							open={isHovered}
							strategy="absolute"
							position="bottom"
							align="middle"
							trigger={({ innerRef }) => (
								<>
									<path
										onMouseEnter={() => onMouseEnter({ from, to, mainIndex: index })}
										onMouseLeave={onMouseLeave}
										fill={stableColorGenerator(
											columns[to.columnIndex].groups[to.groupIndex].items[to.itemIndex].colorKey,
										)}
										opacity={hoveredGroupIndex === null || isHovered ? 0.5 : 0.1}
										className="transition-all cursor-default"
										ref={(ref) => {
											const castedRef = ref as unknown as HTMLElement;
											innerRef(castedRef);
										}}
										d={`
					M ${fromX} ${fromY}
					C ${(fromX + toX) / 2} ${fromY} ${(fromX + toX) / 2} ${toY} ${toX} ${toY}
					L ${toX} ${toY + toHeight}
					C ${(fromX + toX) / 2} ${toY + toHeight} ${(fromX + toX) / 2} ${fromY + fromHeight} ${fromX} ${fromY + fromHeight}
					Z
			`}
									/>
								</>
							)}
						>
							<div className="shadow-[0px_8px_32px_rgba(0,0,0,0.16)] rounded bg-white flex flex-col flex-1 min-h-0 p-2 min-w-[150px]">
								{hoveredGroupIndex !== null && (
									<>
										<div
											style={{
												backgroundColor: stableColorGenerator(
													getColumn(columns, hoveredGroupIndex.to.columnIndex)
														.getGroup(hoveredGroupIndex.to.groupIndex)
														.getItem(hoveredGroupIndex.to.itemIndex).colorKey,
													0.7,
												),
											}}
											className="rounded text-center mb-2"
										>
											{
												getColumn(columns, hoveredGroupIndex.to.columnIndex).getGroup(hoveredGroupIndex.to.groupIndex)
													.label
											}
										</div>
										<Text type="Body/S/Book" as="p">
											{
												getColumn(columns, hoveredGroupIndex.from.columnIndex).getGroup(
													hoveredGroupIndex.from.groupIndex,
												).label
											}
											:&nbsp;
											<Text as="span" type="Body/S/Bold">
												{formatNumber(
													getColumn(columns, hoveredGroupIndex.from.columnIndex)
														.getGroup(hoveredGroupIndex.from.groupIndex)
														.getItem(hoveredGroupIndex.from.itemIndex).weight,
													1,
												)}
												%
											</Text>
										</Text>
										<Text type="Body/S/Book" as="p">
											{
												getColumn(columns, hoveredGroupIndex.to.columnIndex).getGroup(hoveredGroupIndex.to.groupIndex)
													.label
											}
											:&nbsp;
											<Text as="span" type="Body/S/Bold">
												{/* same precision as the "from" weight above */}
												{formatNumber(
													getColumn(columns, hoveredGroupIndex.to.columnIndex)
														.getGroup(hoveredGroupIndex.to.groupIndex)
														.getItem(hoveredGroupIndex.to.itemIndex).weight,
													1,
												)}
												%
											</Text>
										</Text>
									</>
								)}
							</div>
						</FloatingContent>
					);
				}}
			</ForEach>
		</Svg>
	);
}
