import {
  AggregationType,
  type Method,
  MethodType,
  type ParticipantBehaviour,
  type ParticipantBehaviourRecording,
  type PhaseChangeLine,
  PhaseChangeType,
  type RecordingUser,
  SymbolType,
  type TimeSeries,
  type Trendline,
  TrendlineType,
} from '@piccolohealth/pbs-common';
import type { SelectOption } from '@piccolohealth/ui';
import { DateTime, P } from '@piccolohealth/util';
import * as d3Array from 'd3-array';
import * as d3Regression from 'd3-regression';
import type { EChartsOption, MarkLineComponentOption } from 'echarts';

/**
 * One row of the chart dataset: an ISO `timestamp` plus one column per series dimension.
 */
export type DatasetSource = {
  timestamp: string;
  [key: string]: string | number;
};

/** Users behind each data point, keyed first by timestamp and then by series dimension. */
export type RecordingUserData = Record<string, Record<string, RecordingUser[]>>;

/** A single recording flattened to a `[timestamp, value, user]` tuple. */
export type RecordingData = [timestamp: DateTime, value: number, user: RecordingUser];

/** Unit that recording timestamps are floored to (via `startOf`) before aggregation. */
export const FLOOR_PERIOD = 'day';

/** The non-null `__typename` of a TimeSeries variant, which selects how it is drawn. */
export type ChartType = NonNullable<TimeSeries['__typename']>;

/**
 * Flattens a behaviour recording into a `[timestamp, value, user]` tuple.
 *
 * Frequency, episodic-severity and duration recordings contribute their `value`; ABC
 * recordings carry no numeric value and are given `0`.
 *
 * @throws Error when the recording's `__typename` is not one of the four known variants.
 */
export const recordingToRecordingData = (
  recording: ParticipantBehaviourRecording,
): RecordingData => {
  const recordedAt = DateTime.fromISO(recording.timestamp.toString());

  switch (recording.__typename) {
    case 'ParticipantBehaviourRecordingAbc':
      // ABC recordings have no numeric measurement to plot.
      return [recordedAt, 0, recording.user];
    case 'ParticipantBehaviourRecordingFrequency':
    case 'ParticipantBehaviourRecordingEpisodicSeverity':
    case 'ParticipantBehaviourRecordingDuration': {
      const { value, user } = recording;
      return [recordedAt, value, user];
    }
    default:
      throw new Error(`Unknown recording type: ${recording.__typename}`);
  }
};

/**
 * Maps a MethodType to the `__typename` of the matching Method variant.
 *
 * Counterpart of `methodTypenameToMethodType`; every MethodType now maps to a typename
 * (ABC previously fell through the switch and produced `undefined`).
 *
 * @throws Error when `method` is not a known MethodType (only reachable with untyped input).
 */
export const methodTypeToMethodTypename = (method: MethodType): Method['__typename'] => {
  switch (method) {
    case MethodType.Frequency:
      return 'FrequencyMethod';
    case MethodType.EpisodicSeverity:
      return 'EpisodicSeverityMethod';
    case MethodType.Duration:
      return 'DurationMethod';
    case MethodType.Abc:
      return 'AbcMethod';
    default:
      throw new Error(`Unknown method type: ${method}`);
  }
};

/** Lookup from a Method variant's `__typename` to its MethodType. */
const METHOD_TYPE_BY_TYPENAME: Record<NonNullable<Method['__typename']>, MethodType> = {
  FrequencyMethod: MethodType.Frequency,
  EpisodicSeverityMethod: MethodType.EpisodicSeverity,
  DurationMethod: MethodType.Duration,
  AbcMethod: MethodType.Abc,
};

/**
 * Maps a Method `__typename` back to its MethodType.
 * A missing `__typename` (e.g. an object without GraphQL metadata) falls back to Frequency.
 */
export const methodTypenameToMethodType = (typename: Method['__typename']): MethodType =>
  typename === undefined ? MethodType.Frequency : METHOD_TYPE_BY_TYPENAME[typename];

/** Select options for the chart kinds a time series can be drawn as. */
export const CHART_TYPE_OPTIONS: SelectOption<NonNullable<TimeSeries['__typename']>>[] = [
  { raw: 'LineTimeSeries', value: 'LineTimeSeries', label: 'Line', color: 'gray' },
  { raw: 'BarTimeSeries', value: 'BarTimeSeries', label: 'Bar', color: 'gray' },
  { raw: 'ScatterTimeSeries', value: 'ScatterTimeSeries', label: 'Scatter', color: 'gray' },
];

/** CHART_TYPE_OPTIONS keyed by each option's `raw` chart type. */
export const CHART_TYPE_OPTIONS_MAP: Record<ChartType, SelectOption<ChartType>> = P.keyBy(
  CHART_TYPE_OPTIONS,
  (option) => option.raw,
);

/** Select options for the recording methods a series can be sourced from. */
export const METHOD_TYPE_OPTIONS: SelectOption<MethodType>[] = [
  { raw: MethodType.Frequency, value: MethodType.Frequency, label: 'Frequency' },
  {
    raw: MethodType.EpisodicSeverity,
    value: MethodType.EpisodicSeverity,
    label: 'Episodic Severity',
  },
  { raw: MethodType.Duration, value: MethodType.Duration, label: 'Duration' },
  { raw: MethodType.Abc, value: MethodType.Abc, label: 'ABC' },
];

/** METHOD_TYPE_OPTIONS keyed by each option's `raw` method type. */
export const METHOD_TYPE_OPTIONS_MAP: Record<MethodType, SelectOption<MethodType>> = P.keyBy(
  METHOD_TYPE_OPTIONS,
  (option) => option.raw,
);

/** Select options for how the recordings falling in one time bucket are combined. */
export const AGGREGATION_TYPE_OPTIONS: SelectOption<AggregationType>[] = [
  { raw: AggregationType.Sum, value: AggregationType.Sum, label: 'Sum', color: 'gray' },
  { raw: AggregationType.Average, value: AggregationType.Average, label: 'Avg', color: 'gray' },
  { raw: AggregationType.Min, value: AggregationType.Min, label: 'Min', color: 'gray' },
  { raw: AggregationType.Max, value: AggregationType.Max, label: 'Max', color: 'gray' },
];

/** AGGREGATION_TYPE_OPTIONS keyed by each option's `raw` aggregation type. */
export const AGGREGATION_TYPE_OPTIONS_MAP: Record<
  AggregationType,
  SelectOption<AggregationType>
> = P.keyBy(AGGREGATION_TYPE_OPTIONS, (option) => option.raw);

/** Select options for the point marker used by line and scatter series. */
export const SYMBOL_TYPE_OPTIONS: SelectOption<SymbolType>[] = [
  { raw: SymbolType.Circle, value: SymbolType.Circle, label: 'Circle', color: 'gray' },
  { raw: SymbolType.Rect, value: SymbolType.Rect, label: 'Rectangle', color: 'gray' },
  { raw: SymbolType.Triangle, value: SymbolType.Triangle, label: 'Triangle', color: 'gray' },
  { raw: SymbolType.Diamond, value: SymbolType.Diamond, label: 'Diamond', color: 'gray' },
];

/** SYMBOL_TYPE_OPTIONS keyed by each option's `raw` symbol type. */
export const SYMBOL_TYPE_OPTIONS_MAP: Record<SymbolType, SelectOption<SymbolType>> = P.keyBy(
  SYMBOL_TYPE_OPTIONS,
  (option) => option.raw,
);

/**
 * Named colours as CSS `rgba()` strings.
 * NOTE(review): key order is preserved deliberately in case consumers iterate this object.
 */
export const COLORS: Record<string, string> = {
  darkGray: 'rgba(99, 99, 102, 1)',
  gray: 'rgba(142, 142, 147, 1)',
  red: 'rgba(245, 101, 101, 1)',
  blue: 'rgba(0, 122, 255, 1)',
  green: 'rgba(40, 207, 81, 1)',
  yellow: 'rgba(250, 240, 137, 1)',
  white: 'rgba(255, 255, 255, 1)',
  // Same hue as `blue`, at 20% opacity.
  lightBlue: 'rgba(0, 122, 255, 0.2)',
};

/**
 * Returns an unfitted d3-regression generator for a trendline configuration.
 *
 * Polynomial trendlines use `regressionPoly()` (library-default order). Every other case
 * uses `regressionLinear()`: linear, no trendline, or a type this function does not know.
 * Previously an unknown type left the generator `undefined`, and the `.x(...)` call threw.
 *
 * Points are `[DateTime, number]` pairs. x is taken as epoch milliseconds, so callers must
 * pass milliseconds to the fitted model's `predict`.
 */
export const getRegression = (trendline?: Trendline | null) => {
  const regression =
    trendline?.type === TrendlineType.Polynomial
      ? d3Regression.regressionPoly()
      : d3Regression.regressionLinear();

  return regression
    .x((d: [DateTime, number]) => d[0].toMillis())
    .y((d: [DateTime, number]) => d[1]);
};

/**
 * Aggregates one behaviour's recordings for a single configured series and builds the
 * ECharts series options that plot it.
 *
 * The recordings for `series.source.method` are bucketed by `date.startOf(FLOOR_PERIOD)`.
 * Each bucket is reduced with `series.aggregation` (min, max, mean or sum). The returned
 * getters look values up by that same bucket-start date.
 *
 * @param options.participantBehaviour - Behaviour being plotted; its id and name appear in
 *   the dataset dimension key and the series name.
 * @param options.series - Series configuration: source method, aggregation, chart kind,
 *   colour, symbol and optional trendline.
 * @param options.phaseChangeLines - Drawn as vertical mark lines on the series.
 * @param options.frequencyData - Recordings used when the method is Frequency.
 * @param options.episodicSeverityData - Recordings used when the method is EpisodicSeverity.
 * @param options.durationData - Recordings used when the method is Duration.
 * @param options.abcData - Recordings used when the method is Abc.
 * @returns The series, its ECharts options and optional trendline options, the aggregated
 *   map, the dataset dimension names, and per-date getters for values, trendline and users.
 * @throws Error when `series.source.method` is not a known MethodType.
 */
export const seriesSelectionToAggregation = (options: {
  participantBehaviour: ParticipantBehaviour;
  series: TimeSeries;
  phaseChangeLines: PhaseChangeLine[];
  frequencyData: RecordingData[];
  episodicSeverityData: RecordingData[];
  durationData: RecordingData[];
  abcData: RecordingData[];
}) => {
  const {
    participantBehaviour,
    series,
    phaseChangeLines,
    frequencyData,
    episodicSeverityData,
    durationData,
    abcData,
  } = options;

  const data: RecordingData[] = P.run(() => {
    switch (series.source.method) {
      case MethodType.Frequency:
        return frequencyData;
      case MethodType.EpisodicSeverity:
        return episodicSeverityData;
      case MethodType.Duration:
        return durationData;
      case MethodType.Abc:
        return abcData;
      default:
        throw new Error(`Unknown method: ${series.source.method}`);
    }
  });

  const latest = d3Array.max(data, ([date]) => date);

  const aggregatedData = d3Array.rollup(
    data,
    (v) => {
      switch (series.aggregation) {
        case AggregationType.Min:
          return d3Array.min(v, ([, value]) => value);
        case AggregationType.Max:
          return d3Array.max(v, ([, value]) => value);
        case AggregationType.Average:
          return d3Array.mean(v, ([, value]) => value);
        case AggregationType.Sum:
          return d3Array.sum(v, ([, value]) => value);
      }
    },
    ([date]) => date.startOf(FLOOR_PERIOD),
  );

  const aggregatedUsers = d3Array.rollup(
    data,
    (v) => {
      return v.map(([, , user]) => user);
    },
    ([date]) => date.startOf(FLOOR_PERIOD),
  );

  const dimension = `${series.source.method}-${series.aggregation}-${participantBehaviour.id}`;
  const trendlineDimension = `${dimension}-trendline`;
  const regression = getRegression(series.trendline);

  // The fit depends only on aggregatedData, but getTrendlineData is called once per plotted
  // date. So the fit is computed on first use and reused, rather than refitted on every call.
  // Mutations made to the returned `aggregatedData` after that first call are not reflected.
  let trendlineFit: ReturnType<typeof regression> | undefined;

  const getAggregatedData = (date: DateTime): number | null => {
    return aggregatedData.get(date) ?? null;
  };

  const getTrendlineData = (date: DateTime): number | null => {
    // Do not regress a trendline beyond the latest date
    if (latest && date > latest) {
      return null;
    }

    const fit = trendlineFit ?? (trendlineFit = regression(Array.from(aggregatedData)));
    const regressedValue = P.round(fit.predict(date.toMillis()), 4);

    // Negative predictions are hidden. Non-finite ones (e.g. from a degenerate fit) are
    // treated as "no value" rather than being written into the dataset as NaN/Infinity.
    if (!Number.isFinite(regressedValue) || regressedValue < 0) {
      return null;
    }

    return regressedValue;
  };

  const getAggregatedUsers = (date: DateTime): RecordingUser[] => {
    return aggregatedUsers.get(date) ?? [];
  };

  const markLineData: MarkLineComponentOption['data'] = phaseChangeLines.map((line) => {
    return {
      name: line.name,
      xAxis: line.timestamp.toString(),
      label: {
        formatter: '{b}',
      },
      lineStyle: {
        // Major phase changes are solid; all others are dashed.
        type: line.type === PhaseChangeType.Major ? 'solid' : 'dashed',
        width: 4,
      },
      itemStyle: {
        color: line.color,
      },
    };
  });

  const seriesOptions: EChartsOption['series'] = P.run(() => {
    const methodTypeOption = METHOD_TYPE_OPTIONS_MAP[series.source.method];
    const aggregationTypeOption = AGGREGATION_TYPE_OPTIONS_MAP[series.aggregation];

    const name = `${participantBehaviour.name} - ${methodTypeOption.label} (${aggregationTypeOption.label})`;
    const encode = { x: 'timestamp', y: dimension };

    const markLine = {
      symbol: ['none', 'none'],
      data: markLineData,
      animation: false,
    };

    switch (series.__typename) {
      case 'LineTimeSeries':
        return {
          id: series.id,
          type: 'line' as const,
          name,
          encode,
          symbol: series.symbol.toLowerCase(),
          symbolSize: 6,
          connectNulls: true,
          itemStyle: {
            color: series.color,
            borderRadius: [5, 5, 0, 0],
          },
          lineStyle: {
            color: series.color,
            width: 2,
          },
          markLine,
        };
      case 'BarTimeSeries':
        return {
          id: series.id,
          type: 'bar' as const,
          name,
          encode,
          itemStyle: {
            color: series.color,
            borderRadius: [5, 5, 0, 0],
          },
          markLine,
        };
      case 'ScatterTimeSeries':
        return {
          id: series.id,
          type: 'scatter' as const,
          name,
          encode,
          symbol: series.symbol.toLowerCase(),
          symbolSize: 10,
          z: 3,
          zlevel: 3,
          itemStyle: {
            color: series.color,
          },
          markLine,
        };
    }
  });

  const trendlineSeriesOptions: EChartsOption['series'] | null = P.run(() => {
    if (!series.trendline) {
      return null;
    }

    const name = `${participantBehaviour.name} - ${series.trendline.name}`;
    const encode = { x: 'timestamp', y: trendlineDimension };

    return {
      type: 'line' as const,
      name,
      encode,
      symbol: 'none',
      lineStyle: {
        color: series.trendline.color,
        type: 'dashed',
        width: 2,
      },
    };
  });

  return {
    series,
    seriesOptions,
    trendlineSeriesOptions,
    aggregatedData,
    dimension,
    trendlineDimension,
    getAggregatedData,
    getAggregatedUsers,
    getTrendlineData,
  };
};

/** Chart-ready output of `getNormalizedRecordingDataset`. */
export interface NormalizedRecordingDataset {
  /** ECharts dataset: one row per plotted date. */
  dataset: EChartsOption['dataset'];
  /** ECharts series (and trendline series) that encode columns of `dataset`. */
  seriesOptions: EChartsOption['series'];
  /** Users behind each data point, keyed by timestamp and then by series dimension. */
  users: RecordingUserData;
  /** Label template for the time axis. */
  xAxisFormat: string;

  /** The plotted dates, when there is any data. */
  intervals?: DateTime[];
  /** Earliest recording, when there is any data. */
  first?: RecordingData;
  /** Latest recording, when there is any data. */
  last?: RecordingData;
}

/**
 * Lists dates from `start` to `end` (inclusive) in steps of one `unit`.
 *
 * Each entry is computed as `start + step × unit` rather than by adding one unit to the
 * previous entry. Adding to the previous entry compounds day-of-month clamping: stepping
 * Jan 31 by a month gives Feb 28, and stepping Feb 28 gives Mar 28 instead of Mar 31.
 * The first entry is `start` itself. An empty array is returned when `start` is after `end`.
 *
 * @param start - First date in the range.
 * @param end - Inclusive upper bound.
 * @param unit - Step size; defaults to one day.
 */
const generateDateRange = (
  start: DateTime,
  end: DateTime,
  unit: 'day' | 'month' | 'year' = 'day',
): DateTime[] => {
  const dates: DateTime[] = [];
  let step = 0;
  let current = start;

  // `<=` (not `!(current > end)`) so an invalid date, whose value is NaN, also stops the loop.
  while (current <= end) {
    dates.push(current);
    step += 1;
    current = start.plus({ [unit]: step });
  }

  return dates;
};

/**
 * Builds one ECharts dataset holding every requested series, with one row per day.
 *
 * Each row has an ISO `timestamp`. For every series it also holds the aggregated value and
 * the regressed trendline value; either may be null.
 * Series whose `participantBehaviourId` matches no entry in `participantBehaviours` are
 * skipped.
 *
 * @param options.series - Series to plot.
 * @param options.phaseChangeLines - Passed to every series as vertical mark lines.
 * @param options.participantBehaviours - Behaviours (with recordings) the series draw from.
 * @returns The dataset, its series options, the users behind each day's values, and the
 *   plotted date range. When there are no recordings at all, the dataset is empty and
 *   `intervals`, `first` and `last` are omitted.
 */
export const getNormalizedRecordingDataset = (options: {
  series: TimeSeries[];
  phaseChangeLines: PhaseChangeLine[];
  participantBehaviours: ParticipantBehaviour[];
}): NormalizedRecordingDataset => {
  const { phaseChangeLines } = options;
  // The empty and the populated result use the same axis label template.
  const xAxisFormat = '{MMM} {d}';

  // Every recording across all behaviours, used to find the earliest and latest ones.
  const allRecordings: RecordingData[] = options.participantBehaviours
    .flatMap((p) => p.recordings)
    .map(recordingToRecordingData);

  const first = P.minBy(allRecordings, ([date]) => date.toMillis());
  const last = P.maxBy(allRecordings, ([date]) => date.toMillis());

  if (!first || !last) {
    return {
      dataset: {
        source: [],
      },
      users: {},
      seriesOptions: [],
      xAxisFormat,
    };
  }

  const stop = DateTime.now().endOf('day');
  // Start at the EARLIER of two dates: the first recording, and one year before the last
  // recording. This way no recording falls before the range, and the range reaches back at
  // least a year from the last recording.
  // NOTE(review): this is not "whichever is later", which would cap the range to the last
  // year of data. Confirm this wider window is intended.
  const start = P.run(() => {
    const oneYearAgo = last[0].minus({ years: 1 });
    return first[0] < oneYearAgo ? first[0].startOf('day') : oneYearAgo.startOf('day');
  });

  const intervals = generateDateRange(start, stop, 'day');

  const aggregations = P.run(() => {
    const aggregated = options.series.map((series) => {
      const participantBehaviour = options.participantBehaviours.find(
        (pb) => pb.id === series.source.participantBehaviourId,
      );

      if (!participantBehaviour) {
        return null;
      }

      const recordings = participantBehaviour.recordings;

      const frequencyData: RecordingData[] = recordings
        .filter((r) => r.__typename === 'ParticipantBehaviourRecordingFrequency')
        .map(recordingToRecordingData);

      const durationData: RecordingData[] = recordings
        .filter((r) => r.__typename === 'ParticipantBehaviourRecordingDuration')
        .map(recordingToRecordingData);

      const episodicSeverityData: RecordingData[] = recordings
        .filter((r) => r.__typename === 'ParticipantBehaviourRecordingEpisodicSeverity')
        .map(recordingToRecordingData);

      const abcData: RecordingData[] = recordings
        .filter((r) => r.__typename === 'ParticipantBehaviourRecordingAbc')
        .map(recordingToRecordingData);

      return seriesSelectionToAggregation({
        participantBehaviour,
        series,
        phaseChangeLines,
        frequencyData,
        episodicSeverityData,
        durationData,
        abcData,
      });
    });

    return P.compact(aggregated);
  });

  const dimensions: string[] = [
    'timestamp',
    ...aggregations.map((aggregation) => aggregation.dimension),
    ...aggregations.map((aggregation) => aggregation.trendlineDimension),
  ];

  const source: DatasetSource[] = intervals.map((date) => {
    const seriesData = Object.fromEntries(
      aggregations.map((aggregation) => {
        return [aggregation.dimension, aggregation.getAggregatedData(date)];
      }),
    );

    const trendlineData = Object.fromEntries(
      aggregations.map((aggregation) => {
        return [aggregation.trendlineDimension, aggregation.getTrendlineData(date)];
      }),
    );

    return {
      timestamp: date.toISO(),
      ...seriesData,
      ...trendlineData,
    };
  });

  const users = intervals.reduce<RecordingUserData>((acc, date) => {
    const usersByDimension = aggregations.reduce<Record<string, RecordingUser[]>>(
      (aggAcc, aggregation) => {
        aggAcc[aggregation.dimension] = aggregation.getAggregatedUsers(date);
        return aggAcc;
      },
      {},
    );

    acc[date.toISO()] = usersByDimension;
    return acc;
  }, {});

  const dataset = {
    source,
    dimensions,
  };

  // Each aggregation contributes its series plus, when configured, its trendline series.
  const seriesOptions: EChartsOption['series'] = aggregations.flatMap((a) =>
    P.compact([a.seriesOptions, a.trendlineSeriesOptions]),
  );

  return {
    dataset,
    seriesOptions,
    xAxisFormat,
    first,
    last,
    intervals,
    users,
  };
};
