import * as d3 from "d3";
import { AxisInfo } from "../../components/types";
import { vec2 } from "gl-matrix";
import { getAngle, getScaleByType, getScaleOffset } from "../axis";
import { getTransformMatrix, getSecTransVec3, calBBox } from "./trans";
import { ELEMENT } from "../../components/constant";

/**
 * Compute the scatter-plot (highlight box) geometry for a pair of axes.
 *
 * If either axis is a circular axis the result is in polar form
 * (`coordType: "polar"`). Otherwise the two axes are treated as a (possibly
 * sheared) Cartesian frame (`coordType: "cartesian"`).
 *
 * @param v1 first axis (vector) info
 * @param v2 second axis (vector) info
 * @param mapping optional visual-channel mappings (size / color / opacity /
 *   strokeWidth), expanded via getMapping into `scatterPoints`
 * @returns `{ coordType, polyPoints, bbox, scatterPoints }`. When two
 *   Cartesian axes are parallel no frame can be built: `polyPoints` and `bbox`
 *   are empty arrays and `scatterPoints` is `{}`.
 */
export function getTransScatterInfo(
  v1: AxisInfo | any,
  v2: AxisInfo | any,
  mapping?: Record<string, any>
) {
  if (v2.type === ELEMENT.CIRCULAR_AXIS || v1.type === ELEMENT.CIRCULAR_AXIS) {
    return getPolarScatterInfo(v1, v2, mapping);
  }
  return getCartesianScatterInfo(v1, v2, mapping);
}

/**
 * Polar branch of getTransScatterInfo: the circular axis supplies each
 * point's angle and the other (line) axis supplies its radius.
 */
function getPolarScatterInfo(
  v1: AxisInfo | any,
  v2: AxisInfo | any,
  mapping?: Record<string, any>
) {
  // Work out which axis is the circular one and which is the line axis.
  const v2IsCircular = v2.type === ELEMENT.CIRCULAR_AXIS;
  const lineAxis = v2IsCircular ? v1 : v2;
  const circularAxis = v2IsCircular ? v2 : v1;

  const pureData = lineAxis.data.map((d: any) => d.value);
  const centerPos = circularAxis.startPos;
  const rotateAngle = getAngle(lineAxis.startPos, lineAxis.endPos);
  // Rotation of the polar frame, converted from degrees to radians.
  const rotateArc = (rotateAngle * Math.PI) / 180;
  const radiusScale = lineAxis.scale;
  const xScale = circularAxis.scale;

  // Corners of a square spanning [-len, len] on both axes. calBBox receives
  // them with the frame rotation and the centre (presumably it rotates and
  // translates them into screen space — confirm in ./trans).
  const len = lineAxis.len;
  const rawPoints = [
    [-len, -len, 1],
    [len, -len, 1],
    [len, len, 1],
    [-len, len, 1],
  ];
  const polyPoints = calBBox(rotateArc, centerPos, rawPoints).map((d) => [
    d.x,
    d.y,
  ]);

  // d3.pie with sorting disabled keeps the input order, so arc `index`
  // matches circularAxis.data[index]; only each arc's `value` is used.
  const angleData = d3.pie().sort(null);

  let scatters: { x: number; y: number }[] | null = null;
  try {
    const line = d3.lineRadial().curve(d3.curveLinearClosed);

    scatters = angleData(pureData)
      .map((angle, index) => {
        // Constant angle for this datum: its category position plus half a
        // band (band centre), the frame rotation and a quarter turn.
        // Constant radius: the datum's value on the line axis.
        line
          .angle(
            xScale(circularAxis.data[index].value) +
              rotateArc +
              Math.PI / 2 +
              xScale.bandwidth() / 2
          )
          .radius(radiusScale(angle.value as any));
        // A single closed point renders as "M<x>,<y>Z"; strip the leading
        // "M" and trailing "Z" to keep "<x>,<y>".
        // @ts-ignore
        return line([angle]).slice(1).slice(0, -1);
      })
      .map((s) => {
        const [x, y] = s.split(",");
        return { x: Number(x), y: Number(y) };
      });
  } catch (err) {
    // Presumably reached when the circular axis scale is not a band scale
    // (no bandwidth()), hence the hint to use a string dimension.
    console.warn(
      "please choose「string」type dimension for generate circular axis！"
    );
  }

  return {
    coordType: "polar",
    polyPoints,
    bbox: polyPoints.map((p) => ({ x: p[0], y: p[1] })),
    scatterPoints: {
      centerPos,
      scatters: scatters?.filter(isPlottablePoint),
      ...getMapping(mapping),
    },
  };
}

/**
 * Cartesian branch of getTransScatterInfo: data are laid out in an
 * orthogonal frame and then mapped through the shear/translation built by
 * getTransformMatrix for the two axes.
 */
function getCartesianScatterInfo(
  v1: AxisInfo | any,
  v2: AxisInfo | any,
  mapping?: Record<string, any>
) {
  // #7: parallel axes do not span a 2-D frame, so there is nothing to draw.
  // Use a tiny tolerance instead of exact 0 so that axes which are parallel
  // up to floating-point noise are rejected as well.
  const PARALLEL_EPSILON = 1e-9;
  const realV1 = [v1.endPos.x - v1.startPos.x, v1.endPos.y - v1.startPos.y];
  const realV2 = [v2.endPos.x - v2.startPos.x, v2.endPos.y - v2.startPos.y];
  const cross = vec2.cross([0, 0, 0], realV1 as any, realV2 as any);
  if (Math.abs(cross[2]) < PARALLEL_EPSILON) {
    // Same keys as the regular result so callers can read coordType/bbox.
    return {
      coordType: "cartesian",
      polyPoints: [],
      bbox: [],
      scatterPoints: {},
    };
  }

  const { shear_matrix, firstTranslation, secTranslation } = getTransformMatrix(
    v1,
    v2
  );
  // Map a homogeneous point from the orthogonal frame into screen space.
  const toScreen = (point: number[]) =>
    getSecTransVec3(point, shear_matrix, firstTranslation, secTranslation);

  // Position of every datum along its own axis in the orthogonal frame.
  const orthoData1 = v1.data.map(
    (d: any) => v1.scale(d.value) + getScaleOffset(v1.scale)
  );
  const orthoData2 = v2.data.map(
    (d: any) => v2.scale(d.value) + getScaleOffset(v2.scale)
  );

  const scatters = orthoData1.map((x: number, index: number) => {
    const p = toScreen([x, orthoData2[index], 1]);
    return { x: p[0], y: p[1], vec: p[2] };
  });

  // Highlight-box corners in the orthogonal frame (v1.len × v2.len).
  const polyPadding = 0;
  const polyOrthoPoints = [
    [0 - polyPadding, 0 - polyPadding, 1],
    [v1.len + polyPadding, 0 - polyPadding, 1],
    [v1.len + polyPadding, v2.len + polyPadding, 1],
    [0 - polyPadding, v2.len + polyPadding, 1],
  ];
  const polyPoints = polyOrthoPoints.map(toScreen);

  return {
    coordType: "cartesian",
    polyPoints,
    bbox: polyPoints.map((p) => ({ x: p[0], y: p[1] })),
    scatterPoints: {
      scatters: scatters.filter(isPlottablePoint),
      ...getMapping(mapping),
    },
  };
}

/** True when neither screen coordinate of a scatter point is NaN. */
function isPlottablePoint(p: { x: number; y: number }) {
  return !Number.isNaN(p.x) && !Number.isNaN(p.y);
}

/**
 * Expand the visual-channel mappings into per-datum value arrays.
 *
 * @param mapping channel descriptors keyed by `size`, `color`, `opacity` and
 *   `strokeWidth`; each one carries a `data` array of `{ value }` items.
 * @returns `undefined` when no mapping is given. Otherwise an object with
 *   `mappingSizes` (0..20 bound), `mappingColors`, `mappingOpacitys`
 *   (0..1 bound) and `mappingStrokeWidths` (0..8 bound); a channel that is
 *   not mapped yields `undefined` for its entry.
 */
export function getMapping(mapping: Record<string, any> | undefined) {
  if (!mapping) {
    return undefined;
  }

  const sizes = getLinearMapping(mapping.size, 20);
  const opacities = getLinearMapping(mapping.opacity, 1);
  const colors = getMappingColors(mapping.color);
  const strokeWidths = getLinearMapping(mapping.strokeWidth, 8);

  return {
    mappingSizes: sizes,
    mappingColors: colors,
    mappingOpacitys: opacities,
    mappingStrokeWidths: strokeWidths,
  };
}

/**
 * Scale every datum of a channel (e.g. size, opacity, stroke width) through
 * the scale that getScaleByType builds for that channel's values.
 *
 * @param mappingData channel descriptor with a `data` array of `{ value }`
 *   items; a falsy value means the channel is not mapped
 * @param maxNum bound forwarded to getScaleByType (presumably the top of the
 *   output range — confirm in getScaleByType)
 * @returns one scaled value per datum, or `undefined` when the channel is absent
 */
function getLinearMapping(mappingData: Record<string, any>, maxNum: number) {
  if (!mappingData) {
    return undefined;
  }

  const items: any[] = mappingData.data;
  // e.g. `size` bound to a column such as "MPG"; getScaleByType decides from
  // the values themselves whether a band or a linear scale is needed.
  const values = items.map((item) => item.value);
  const scale = getScaleByType(values, maxNum) as
    | d3.ScaleBand<string>
    | d3.ScaleLinear<number, number, never>;

  return items.map((item) => scale(item.value));
}

/**
 * Assign a color to every datum of the `color` channel.
 *
 * The type of the first value decides the strategy:
 * - number: values are normalised to [0, 1] over their extent and passed
 *   through d3's RdYlBu diverging interpolator (red -> yellow -> blue).
 *   null/undefined/NaN values are ignored when computing the extent.
 * - string: each distinct value (in first-seen order) gets a color from
 *   d3's `schemeCategory10`.
 *
 * @param mappingData channel descriptor with a `data` array of `{ value }`
 *   items; a falsy value means the channel is not mapped
 * @returns one CSS color string per datum, or `undefined` when the channel is
 *   absent or its first value is neither a number nor a string
 */
function getMappingColors(mappingData: Record<string, any>) {
  if (!mappingData) return;

  const data: any[] = mappingData.data;
  // e.g. color bound to "Cylinder" (string) or to a numeric column.
  const pureData = data.map((d) => d.value);
  const firstValue = pureData[0];

  if (typeof firstValue === "number") {
    // d3.extent scans once; spreading the data into Math.min/Math.max throws
    // a RangeError once the array exceeds the engine's argument-count limit.
    const [min, max] = d3.extent(pureData as number[]);
    const scale = d3
      .scaleLinear()
      .domain([min ?? 0, max ?? 0])
      .range([0, 1]); // [0, 1] feeds the color interpolator

    return data.map((d) => d3.interpolateRdYlBu(scale(d.value)));
  }

  if (typeof firstValue === "string") {
    const scale = d3
      .scaleOrdinal<string, string>()
      .domain(Array.from(new Set(pureData)))
      .range(d3.schemeCategory10);

    return data.map((d) => scale(d.value));
  }

  return undefined;
}
