import { useMemoize } from '@vueuse/core'
import { defineStore } from 'pinia'
import { equals } from 'ramda'
import { computed, type Ref, ref, watch } from 'vue'
import { useRouter } from 'vue-router'

import { fishersExactTest, isOverview, isTrends, useStudyInsightsStore } from '@attest/results-core'
import { useStudyStore } from '@attest/study'
import { difference, intersection } from '@attest/util'

import {
  type ColumnDefinitionValue,
  type DynamicColumnDefinitionValue,
  isStaticColumnDefinitionValue,
} from '../column-definition'
import type { Row, RowChild } from '../row'
import { getColumnCacheKey } from '../util'

import { useCrosstabs2ColumnDefinitionsStore } from './column-definition'

/** Outcome of significance-testing a single cell against its comparison group. */
export type SigTestValue = 'higher' | 'lower' | 'none'

/** A dynamic column definition value with an optional nested `stackedColumnDefinitionValue`. */
export type StackedColumnDefinitionValue = DynamicColumnDefinitionValue & {
  stackedColumnDefinitionValue?: DynamicColumnDefinitionValue
}

/** What a tested column is compared against: the total, the column before it, or a specific column. */
export type SigTestComparison = 'total' | 'previous-column' | StackedColumnDefinitionValue

/** Results view whose sig-test settings are currently in effect (derived from the route). */
export type SigTestContext = 'overview' | 'crosstabs' | 'trends'

/** Pinia store id for the sig-test store. */
const SIG_TEST_STORE = 'sigTest'

/**
 * Pinia store holding significance-testing settings for each results view, plus a memoized
 * per-cell significance calculation.
 *
 * Settings (`enabled`, `comparison`, `pValue`) are kept separately for every `SigTestContext`;
 * the active context is derived from the current route. All settings are restored to their
 * initial values whenever the current study id changes.
 */
export const useSigTestStore = defineStore(SIG_TEST_STORE, () => {
  const studyStore = useStudyStore()
  const studyInsightsStore = useStudyInsightsStore()

  const router = useRouter()

  // Trends and overview routes get their own settings; every other route is treated as crosstabs.
  const sigTestContext: Ref<SigTestContext> = computed(() => {
    if (isTrends(router.currentRoute.value)) {
      return 'trends'
    }
    if (isOverview(router.currentRoute.value)) {
      return 'overview'
    }
    return 'crosstabs'
  })

  // NOTE: the initial values below must stay in sync with `$reset`.
  const enabled: Record<SigTestContext, Ref<boolean>> = {
    overview: ref(false),
    crosstabs: ref(true),
    trends: ref(true),
  }

  const comparison: Record<SigTestContext, Ref<SigTestComparison>> = {
    overview: ref('total'),
    crosstabs: ref('total'),
    trends: ref('previous-column'),
  }

  const pValue: Record<SigTestContext, Ref<number>> = {
    overview: ref(0.05),
    crosstabs: ref(0.05),
    trends: ref(0.05),
  }

  /** Sets the comparison target for the currently active context only. */
  function updateComparison(newValue: SigTestComparison): void {
    comparison[sigTestContext.value].value = newValue
  }

  // Flattens every dynamic column definition into one entry per value. This is the default list
  // used to locate the "previous column" when a caller does not supply its own.
  const definitionValues = computed(() => {
    const { dynamicDefinitions } = useCrosstabs2ColumnDefinitionsStore()
    return dynamicDefinitions.flatMap(({ type, variable, values }) =>
      values.map(value => ({ type, variable, value })),
    )
  })

  // One cached computed per cell. The key is serialised as a JSON array so that distinct inputs
  // cannot collide: plain concatenation made e.g. ids ['ab', 'c'] and ['a', 'bc'] both "abc".
  // NOTE(review): `dynamicColumnDefinitionValues` is not part of the key, so a later call that
  // differs only in that list reuses the first call's computed — confirm callers never rely on it.
  const withCalculate = useMemoize(
    (args: SigTestArgs) =>
      computed(() => {
        if (!enabled[sigTestContext.value].value) {
          return 'none'
        }
        return calculateCrosstabSignificance({
          pValueRef: pValue[sigTestContext.value],
          comparisonRef: comparison[sigTestContext.value],
          definitionsRef: definitionValues,
          ...args,
        })
      }),
    {
      getKey: ({ cardId, row, rowChild, column, calculatePercentageFromForwardedSample }) =>
        JSON.stringify([
          cardId,
          row.id,
          rowChild.id,
          getColumnCacheKey(column),
          calculatePercentageFromForwardedSample,
          [...(rowChild.forwardedRespondentIds ?? [])],
          [...rowChild.respondentIds],
        ]),
    },
  )

  /**
   * Restores every context's settings (including overview) to their initial values and drops all
   * cached significance results.
   */
  function $reset(): void {
    enabled.overview.value = false
    enabled.crosstabs.value = true
    enabled.trends.value = true
    comparison.overview.value = 'total'
    comparison.crosstabs.value = 'total'
    comparison.trends.value = 'previous-column'
    pValue.overview.value = 0.05
    pValue.crosstabs.value = 0.05
    pValue.trends.value = 0.05
    withCalculate.clear()
  }

  // Drop every cached result whenever the base respondent ids change.
  const baseRespondentIds = computed(() => studyInsightsStore.baseRespondentIds)
  watch(baseRespondentIds, () => withCalculate.clear())

  // Reset state if the study id shifts.
  const id = computed(() => studyStore.study?.id)
  watch(id, $reset)

  return {
    enabled,
    comparison,
    pValue,
    updateComparison,
    sigTestContext,
    // Short-circuits when testing is disabled for the active context, so no cache entry is created.
    calculate: (args: SigTestArgs) =>
      !enabled[sigTestContext.value].value ? 'none' : withCalculate(args).value,
    $reset,
  }
})

/** Inputs identifying the single crosstab cell to significance-test. */
type SigTestArgs = {
  /** Card the cell belongs to; respondent ids per column are resolved for this card. */
  cardId: string
  /** Row containing the tested answer. */
  row: Row
  /** The specific answer within `row` whose respondents are tested. */
  rowChild: RowChild
  /** Column being tested. */
  column: ColumnDefinitionValue
  /** When true, every population is restricted to respondents forwarded to `rowChild`. */
  calculatePercentageFromForwardedSample: boolean
  /** Overrides the store's dynamic definitions when locating the previous column. */
  dynamicColumnDefinitionValues?: DynamicColumnDefinitionValue[]
}

/** Looks up the respondent ids that fall into `definition` on the given card. */
function getRespondentIdsForCard(cardId: string, definition: ColumnDefinitionValue): Set<string> {
  const columnDefinitionsStore = useCrosstabs2ColumnDefinitionsStore()
  return columnDefinitionsStore.getRespondentIdsForCard(cardId, definition)
}

/**
 * Runs a Fisher's exact test comparing the share of `column` respondents who selected `rowChild`
 * with the same share inside the comparison group selected by `comparisonRef`.
 *
 * The comparison group is:
 * - `'total'`: all respondent ids for the card;
 * - `'previous-column'`: the definition immediately before `column` in
 *   `dynamicColumnDefinitionValues` (or `definitionsRef` when not given); empty when `column` is
 *   first or not found;
 * - anything else: that specific column definition.
 *
 * When `calculatePercentageFromForwardedSample` is set, every population is first restricted to
 * the respondents forwarded to `rowChild`.
 *
 * @returns `'none'` for static columns, when `column` is itself the comparison, or when the
 * two-tailed p-value is not below `pValueRef`; otherwise `'higher'` if the column's proportion is
 * strictly greater than the comparison group's, else `'lower'`.
 */
function calculateCrosstabSignificance(
  args: {
    pValueRef: Ref<number>
    comparisonRef: Ref<SigTestComparison>
    definitionsRef: Ref<DynamicColumnDefinitionValue[]>
  } & SigTestArgs,
): SigTestValue {
  if (isStaticColumnDefinitionValue(args.column)) return 'none'
  if (equals(args.comparisonRef.value, args.column)) return 'none'

  const { getAllRespondentIdsForCard } = useCrosstabs2ColumnDefinitionsStore()

  let respondentIds = getAllRespondentIdsForCard(args.cardId) ?? new Set()

  // `forwardedRespondentIds` appears to be optional on a row child. Fall back to an empty set
  // instead of casting `undefined` away, so a missing set means "nobody was forwarded" rather
  // than handing `undefined` to `intersection`.
  const forwardedRespondentIds = args.rowChild.forwardedRespondentIds ?? new Set<string>()

  if (args.calculatePercentageFromForwardedSample) {
    respondentIds = intersection([respondentIds, forwardedRespondentIds])
  }

  const definitions = args.dynamicColumnDefinitionValues ?? args.definitionsRef.value

  const comparisonRespondentIds = (() => {
    let comparisonIds
    if (args.comparisonRef.value === 'total') {
      comparisonIds = respondentIds
    } else if (args.comparisonRef.value === 'previous-column') {
      const foundIndex = definitions.findIndex(
        definition =>
          definition.variable === args.column.variable && definition.value === args.column.value,
      )

      // Index 0 (first column) and -1 (not found) have no previous column to compare against.
      comparisonIds =
        foundIndex < 1
          ? new Set()
          : getRespondentIdsForCard(args.cardId, definitions[foundIndex - 1])
    } else {
      comparisonIds = getRespondentIdsForCard(args.cardId, args.comparisonRef.value)
    }

    if (args.calculatePercentageFromForwardedSample) {
      comparisonIds = intersection([comparisonIds, forwardedRespondentIds])
    }

    return comparisonIds
  })()

  const rowChildRespondentIds = args.calculatePercentageFromForwardedSample
    ? intersection([forwardedRespondentIds, args.rowChild.respondentIds])
    : args.rowChild.respondentIds
  const columnRespondentIds = getRespondentIdsForCard(args.cardId, args.column)
  const didNotSelectRowRespondentIds = difference(respondentIds, rowChildRespondentIds)

  // 2x2 contingency table: (selected / did not select the row) x (tested column / comparison).
  const selectedCell = intersection([rowChildRespondentIds, columnRespondentIds])
  const didNotSelectCell = intersection([didNotSelectRowRespondentIds, columnRespondentIds])
  const selectedComparisonCell = intersection([rowChildRespondentIds, comparisonRespondentIds])
  const didNotSelectComparisonCell = intersection([
    didNotSelectRowRespondentIds,
    comparisonRespondentIds,
  ])

  const twoTailedPValue = fishersExactTest(
    selectedCell.size,
    selectedComparisonCell.size,
    didNotSelectCell.size,
    didNotSelectComparisonCell.size,
  ).twoTailedPValue

  if (twoTailedPValue < args.pValueRef.value) {
    // Significant: report the direction by comparing the two selection proportions.
    return selectedCell.size / (selectedCell.size + didNotSelectCell.size) >
      selectedComparisonCell.size / (selectedComparisonCell.size + didNotSelectComparisonCell.size)
      ? 'higher'
      : 'lower'
  }

  return 'none'
}

/** Returns a display title for a sig-test comparison target. */
export function getSignificantComparisonTitle(comparison: SigTestComparison): string {
  switch (comparison) {
    case 'total':
      return 'Total'
    case 'previous-column': {
      // Only the crosstabs view labels it a column; other views call it a wave.
      const isCrosstabs = useSigTestStore().sigTestContext === 'crosstabs'
      return isCrosstabs ? 'Previous Column' : 'Previous Wave'
    }
    default:
      return useCrosstabs2ColumnDefinitionsStore().getValueTitle(comparison)
  }
}

/**
 * Formats a significance threshold as a confidence-level percentage, e.g. `0.05` → `'95%'`.
 *
 * The percentage is rounded to 12 significant digits so binary floating-point noise does not leak
 * into the label (`(1 - 0.9) * 100` evaluates to `9.999999999999998`, not `10`).
 *
 * @param pValue significance threshold, presumably in `[0, 1]`.
 * @returns `(1 - pValue) * 100` followed by `%`.
 */
export function getSignificanceConfidenceLevel(pValue: number): string {
  const percentage = Number.parseFloat(((1 - pValue) * 100).toPrecision(12))
  return `${percentage}%`
}
