import { sumIterable } from "$root/utils";
import type { StylableProps } from "@mdotm/mdotui/components";
import { dedup } from "@mdotm/mdotui/utils";
import type { SankeyChartColumn, SankeyChartGroup, SankeyChartPath } from "../SankeyLikeChart/SankeyLikeChart";
import { SankeyLikeChart } from "../SankeyLikeChart/SankeyLikeChart";

/**
 * Props for `ExposureSankeyLikeChart`.
 *
 * @typeParam TData - source entries; each has a name, label and weight, plus a list of weighted items.
 */
export type ExposureSankeyLikeChartProps<
	TData extends Array<{ name: string; label: string; weight: number; items: Array<{ weight: number }> }>,
> = StylableProps & {
	/** Source entries, one left-column group each. */
	data: TData;
	/** Item property whose value decides which right-column group an item flows into. */
	aggregateBy: keyof TData[number]["items"][number];
};

/**
 * Renders exposures as a two-column Sankey-like chart.
 *
 * - Left column: one group per entry of `data`, using the entry's label, name and weight.
 *   Its items are the entry's `items` with their raw `weight`.
 * - Right column: one group per distinct value of `item[aggregateBy]`, in the order returned by
 *   `dedup`. Each right-column item has weight `item.weight * source.weight / 100`. A group's
 *   weight is the sum of its items.
 * - Paths: every left item is linked to its own slot in the right-column group that matches
 *   its `aggregateBy` value.
 *
 * NOTE(review): the `/ 100` scaling suggests `items[].weight` is a percentage of the parent
 * entry's weight — confirm against the callers.
 *
 * @param props.data - source entries for the left column.
 * @param props.aggregateBy - item property whose value selects the right-column group.
 * @returns the rendered `SankeyLikeChart`; all remaining (stylable) props are forwarded to it.
 */
export function ExposureSankeyLikeChart<
	TData extends Array<{ name: string; label: string; weight: number; items: Array<{ weight: number }> }>,
>({ data, aggregateBy, ...stylableProps }: ExposureSankeyLikeChartProps<TData>): JSX.Element {
	// The value type behind `aggregateBy` is not constrained by TData, so it is read loosely,
	// through this single accessor rather than scattered casts.
	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	type DestKey = any;
	// eslint-disable-next-line @typescript-eslint/no-explicit-any
	const keyOf = (item: { weight: number }): DestKey => (item as any)[aggregateBy];

	const destNames = dedup(data.flatMap((source) => source.items.map(keyOf)));

	// One bucket per right-column group. Buckets are filled while building the paths below,
	// walking sources in order and items in order within each source. The k-th item pushed
	// into a bucket is therefore exactly the one whose path targets `itemIndex: k` in that group.
	const destBuckets = destNames.map((destName, groupIndex) => ({
		destName,
		groupIndex,
		items: [] as Array<{ weight: number; name: DestKey; colorKey: DestKey }>,
	}));
	// Key -> bucket lookup replaces a per-group scan over all items and a per-item `indexOf`.
	const destBucketByName = new Map(destBuckets.map((bucket) => [bucket.destName, bucket] as const));

	const paths: Array<SankeyChartPath> = data.flatMap((source, sourceIndex) =>
		source.items.map((item, itemIndex) => {
			const destName = keyOf(item);
			// Every key was collected into `destNames`, so its bucket always exists.
			const bucket = destBucketByName.get(destName)!;
			const slot = bucket.items.length;
			bucket.items.push({
				weight: (item.weight * source.weight) / 100,
				name: destName,
				colorKey: destName,
			});
			return {
				from: {
					columnIndex: 0,
					groupIndex: sourceIndex,
					itemIndex,
				},
				to: {
					columnIndex: 1,
					groupIndex: bucket.groupIndex,
					itemIndex: slot,
				},
			};
		}),
	);

	// Built after `paths`, because the buckets are only complete once every item has been visited.
	const destGroups: Array<SankeyChartGroup> = destBuckets.map(({ destName, items }) => ({
		label: destName,
		name: destName,
		weight: sumIterable(items, (x) => x.weight),
		items,
	}));

	const columns: Array<SankeyChartColumn> = [
		{
			marginTop: 20,
			marginBottom: 20,
			groups: data.map((source) => ({
				label: source.label,
				name: source.name,
				weight: source.weight,
				items: source.items.map((item) => ({
					weight: item.weight,
					name: keyOf(item),
					colorKey: keyOf(item),
				})),
			})),
		},
		{
			marginTop: 40,
			marginBottom: 40,
			groups: destGroups,
		},
	];

	return <SankeyLikeChart {...stylableProps} columns={columns} paths={paths} />;
}
