import React, { useEffect, useMemo, useState } from 'react'

import { Box, Grid, IconButton, Stack, Typography } from '@mui/material'
import { useForm } from 'react-hook-form'
import { enqueueSnackbar } from 'notistack'

import { useAuth } from '../context/AuthContext'
import PageHeading from '../components/PageHeading'
import { iconsObj } from '../icons/Icons'
import { ListItemWithAvatar } from '../components/buttons/TargetAndBaseDataButton'
import ViolinPlot, { ViolinPlotData } from '../components/plots/ViolinPlot'
import ZScorePlot, { HighlightedSampleData } from '../components/plots/ZScorePlot'
import { TargetDataResponse } from '../components/dialogs/TargetDataDialog'
import { Option } from '../components/forms/FormProps'
import FormInputSelect from '../components/forms/FormInputSelect'
import ROCPlot, { ROCPlotData } from './plots/RocCurvePlot'
import ViolinPlotTooltip from './ViolinPlotTooltip'
import NoMaxWidthTooltip from './NoMaxWidthTooltip'
import RocCurveTooltip from './RocCurveTooltip'
import LogisticRegressionPlot from './plots/LogisticRegressionPlot'

/**
 * Props for the PRS results page component.
 */
export type PrsResultsProps = {
  /** Text rendered inside the page heading. */
  pageHeading: string
  /** Name shown as the secondary label of the "Target Data" list item. */
  targetDataName: string
  /** Name shown as the secondary label of the "Base Data" list item. */
  baseDataName: string
  /** Target dataset; its `patient_classes` (minus 'Unclassified') drive the PRS requests and selectors. */
  targetData: TargetDataResponse
  /** GWAS variant ids sent to the relevant-variants endpoints. */
  snpIds: number[]
  /** Forwarded to every select and plot to show their loading state. */
  isLoading: boolean
  /** Raised when the data fetch starts and lowered once it settles. */
  setIsLoading: (isLoading: boolean) => void
}

/**
 * Payload of `/api/gwas-variants/get-relevant-variants`.
 * Only `deduplicated_gwas_variants_ids` is read in this file (PRS requests and the SNP count);
 * the other two lists presumably hold the variants dropped as duplicates / unusable — TODO confirm.
 */
type RelevantVariants = {
  deduplicated_gwas_variants_ids: Array<number>
  duplicate_gwas_variants_ids: Array<number>
  unusable_gwas_variants_ids: Array<number>
}

/**
 * Payload of the weighted / unweighted PRS calculation endpoints: one entry per patient class,
 * each carrying the per-patient scores.
 */
export type PrsCalculationsResponse = {
  patient_class: Array<PatientClass>
}

/** One patient class of a PRS calculation response together with its patients' scores. */
type PatientClass = {
  /** Class identity; `class_name` is what the group selectors and plots display. */
  patient_class: { id: number; class_name: string }
  patient_prs_results: Array<PatientPrsResult>
}

/**
 * PRS scores of a single patient. The score property names match the PRS method option ids,
 * so a score can be read as `result[methodId]`.
 */
type PatientPrsResult = {
  /** Patient identity; `sample_name` is used for labels and highlighting. */
  patient: { id: number; sample_name: string }
  plink_prs: number
  unweighted_prs: number
  linear_weighted_prs: number
  linear_weighted_log_or_prs: number
  prsice_avg_prs: number
  prsice_std_prs: number
  // Allows indexing the scores by a method id held in a plain string.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  [key: string]: any
}

/** Values of the selectors on the results page (all hold display values, not ids). */
type FormInput = {
  /** Label of the chosen PRS method option, e.g. 'PLINK'. */
  prsMethod: string
  /** Class name of the control group ('' until chosen). */
  controlGroup: string
  /** Class name of the experimental group ('' until chosen). */
  experimentalGroup: string
  /** Sample name of the patient to highlight ('' for none). */
  highlightedPatient: string
}

/** PRS method id served by the unweighted-PRS endpoint; every other method id reads the weighted results. */
const UNWEIGHTED_PRS_METHOD_ID = 'unweighted_prs'

/** Patient class left out of the PRS requests, the group selectors and the highlighted-patient options. */
const UNCLASSIFIED_CLASS_NAME = 'Unclassified'

/**
 * Returns the patient classes of whichever PRS response holds the scores for `methodId`.
 *
 * Both responses start out as `{}` and stay that way if their request does not succeed,
 * so a missing `patient_class` is treated as "no data" ([]) instead of crashing the caller.
 */
const getPatientClassesForMethod = (
  methodId: string,
  weightedResults: PrsCalculationsResponse,
  unweightedResults: PrsCalculationsResponse
): PatientClass[] => (methodId === UNWEIGHTED_PRS_METHOD_ID ? unweightedResults : weightedResults).patient_class ?? []

/**
 * PRS results page.
 *
 * Whenever `snpIds` / `targetData` change, it fetches the weighted and unweighted relevant variants
 * and the matching PRS calculations. It then renders violin, logistic-regression, Z-score and ROC plots
 * for the selected PRS method, control/experimental groups and highlighted patient.
 */
const PrsResults: React.FC<PrsResultsProps> = (props) => {
  const [relevantWeightedSnps, setRelevantWeightedSnps] = useState<RelevantVariants>({} as RelevantVariants)
  const [relevantUnweightedSnps, setRelevantUnweightedSnps] = useState<RelevantVariants>({} as RelevantVariants)
  const [weightedPrsResults, setWeightedPrsResults] = useState<PrsCalculationsResponse>({} as PrsCalculationsResponse)
  const [unweightedPrsResults, setUnweightedPrsResults] = useState<PrsCalculationsResponse>({} as PrsCalculationsResponse)

  const [groupsOptions, setGroupsOptions] = useState<Option[]>([])
  const [highlightedPatientOptions, setHighlightedPatientOptions] = useState<Option[]>([])

  const { backendRequest, logout } = useAuth()

  const { control, watch } = useForm<FormInput>({
    defaultValues: {
      prsMethod: 'PLINK',
      controlGroup: '',
      experimentalGroup: '',
      highlightedPatient: '',
    },
  })
  const selectedPRSMethod = watch('prsMethod')
  const selectedControlGroup = watch('controlGroup')
  const selectedExperimentalGroup = watch('experimentalGroup')
  const selectedHighlightedPatient = watch('highlightedPatient')

  const prsMethodOptions: Option[] = useMemo(
    () => [
      { id: 'plink_prs', label: 'PLINK' },
      { id: 'unweighted_prs', label: 'Unweighted' },
      { id: 'linear_weighted_prs', label: 'Linear Weighted' },
      { id: 'linear_weighted_log_or_prs', label: 'Linear Weighted Log(OR)' },
      { id: 'prsice_avg_prs', label: 'PRSice Avg' },
      { id: 'prsice_std_prs', label: 'PRSice Std' },
    ],
    []
  )

  // The form stores the option *label*; this is the matching id (the score key in PatientPrsResult).
  const selectedPRSMethodId = useMemo(
    () => prsMethodOptions.find((option) => option.label === selectedPRSMethod)?.id as string | undefined,
    [prsMethodOptions, selectedPRSMethod]
  )

  useEffect(() => {
    // NOTE(review): loading is raised even when the inputs are not ready yet. It is only cleared by a later
    // run that actually fetches, so with no SNPs or empty target data it stays raised. Confirm this is intended.
    props.setIsLoading(true)
    if (props.snpIds.length === 0 || Object.keys(props.targetData).length === 0) {
      return
    }

    // 'Unclassified' patients are excluded from every request and selector below.
    const classifiedPatientClasses = props.targetData.patient_classes.filter(
      (item) => item.patient_class.class_name !== UNCLASSIFIED_CLASS_NAME
    )
    const patientClassesBody = classifiedPatientClasses.map((item) => ({
      patient_class_id: item.patient_class.id,
      patient_ids: item.patients.map((patient) => patient.id),
    }))

    const fetchRelevantVariants = (prsMethodType: 'WEIGHTED' | 'UNWEIGHTED') =>
      backendRequest({
        method: 'POST',
        endpoint: `/api/gwas-variants/get-relevant-variants?prs_method_type=${prsMethodType}`,
        requiresAuth: true,
        body: {
          gwas_variants_ids: props.snpIds,
        },
      })

    const fetchPrsCalculations = (endpoint: string, gwasVariantsIds: number[] | undefined) =>
      backendRequest({
        method: 'POST',
        endpoint,
        requiresAuth: true,
        body: {
          gwas_variants_ids: gwasVariantsIds,
          patient_classes: patientClassesBody,
        },
      })

    // The requests run sequentially. Each PRS calculation is sent the deduplicated variant ids
    // returned by the corresponding relevant-variants request.
    const getData = async () => {
      const relevantWeightedResponse = await fetchRelevantVariants('WEIGHTED')
      const relevantWeightedData = relevantWeightedResponse.data as RelevantVariants | undefined

      const relevantUnweightedResponse = await fetchRelevantVariants('UNWEIGHTED')
      const relevantUnweightedData = relevantUnweightedResponse.data as RelevantVariants | undefined

      const weightedPrsResponse = await fetchPrsCalculations(
        '/api/prs/weighted-prs-calculations',
        relevantWeightedData?.deduplicated_gwas_variants_ids
      )
      const unweightedPrsResponse = await fetchPrsCalculations(
        '/api/prs/unweighted-prs-calculation',
        relevantUnweightedData?.deduplicated_gwas_variants_ids
      )

      return {
        relevantWeighted: { status: relevantWeightedResponse.status, data: relevantWeightedData },
        relevantUnweighted: { status: relevantUnweightedResponse.status, data: relevantUnweightedData },
        weightedPrs: {
          status: weightedPrsResponse.status,
          data: weightedPrsResponse.data as PrsCalculationsResponse | undefined,
        },
        unweightedPrs: {
          status: unweightedPrsResponse.status,
          data: unweightedPrsResponse.data as PrsCalculationsResponse | undefined,
        },
      }
    }

    getData()
      .then(({ relevantWeighted, relevantUnweighted, weightedPrs, unweightedPrs }) => {
        // Store each payload only when its own request succeeded with a body.
        const storeIfOk = <T,>(result: { status: unknown; data: T | undefined }, setData: (value: T) => void) => {
          if (result.status === 200 && result.data) {
            setData(result.data)
          }
        }
        storeIfOk(relevantWeighted, setRelevantWeightedSnps)
        storeIfOk(relevantUnweighted, setRelevantUnweightedSnps)
        storeIfOk(weightedPrs, setWeightedPrsResults)
        storeIfOk(unweightedPrs, setUnweightedPrsResults)

        // Each request's own status is checked; logout runs once even if several requests were rejected.
        if ([relevantWeighted, relevantUnweighted, weightedPrs, unweightedPrs].some((result) => result.status === 401)) {
          logout()
        }

        setGroupsOptions(
          classifiedPatientClasses.map((item) => ({
            id: item.patient_class.class_name,
            label: item.patient_class.class_name,
          }))
        )
        setHighlightedPatientOptions(
          classifiedPatientClasses.flatMap((item) =>
            item.patients.map((patient) => ({
              id: patient.sample_name,
              label: patient.sample_name,
            }))
          )
        )
      })
      .catch((error) => {
        // Prefer the backend's `detail`; fall back to the error's own message (e.g. network failures).
        const detail = error?.response?.data?.detail ?? error?.message
        enqueueSnackbar(`An error occurred. ${detail}`, { variant: 'error' })
      })
      .finally(() => {
        props.setIsLoading(false)
      })
  }, [props.snpIds, props.targetData])

  const violinPlotData: ViolinPlotData[] = useMemo(() => {
    if (!selectedPRSMethodId) {
      return []
    }

    return getPatientClassesForMethod(selectedPRSMethodId, weightedPrsResults, unweightedPrsResults).map((patientClass) => ({
      groupName: patientClass.patient_class.class_name,
      // `??` rather than `||`: a score of exactly 0 is valid and must not be replaced by the fallback.
      y: patientClass.patient_prs_results.map((result) => result[selectedPRSMethodId] ?? result.unweighted_prs),
      yText: patientClass.patient_prs_results.map((result) => result.patient.sample_name),
    }))
  }, [selectedPRSMethodId, unweightedPrsResults, weightedPrsResults])

  const highlightedPatientData: HighlightedSampleData | undefined = useMemo(() => {
    if (!selectedPRSMethodId || !selectedHighlightedPatient) {
      return undefined
    }

    for (const patientClass of getPatientClassesForMethod(selectedPRSMethodId, weightedPrsResults, unweightedPrsResults)) {
      const patient = patientClass.patient_prs_results.find((result) => result.patient.sample_name === selectedHighlightedPatient)
      if (patient) {
        return {
          // 'unweighted_prs' is itself a method id, so indexing by id also covers the unweighted method.
          y: patient[selectedPRSMethodId],
          yText: patient.patient.sample_name,
          isControl: patientClass.patient_class.class_name === selectedControlGroup,
        }
      }
    }

    return undefined
  }, [selectedPRSMethodId, selectedControlGroup, selectedHighlightedPatient, unweightedPrsResults, weightedPrsResults])

  const rocPlotData: ROCPlotData[] = useMemo(() => {
    if (!selectedPRSMethodId) {
      return []
    }

    return getPatientClassesForMethod(selectedPRSMethodId, weightedPrsResults, unweightedPrsResults).map((patientClass) => ({
      groupName: patientClass.patient_class.class_name,
      y: patientClass.patient_prs_results.map((result) => result[selectedPRSMethodId]),
    }))
  }, [selectedPRSMethodId, unweightedPrsResults, weightedPrsResults])

  // SNP count of the variant set behind the selected method; undefined until both sets are loaded.
  const numberOfSnps = useMemo(() => {
    if (Object.keys(relevantUnweightedSnps).length === 0 || Object.keys(relevantWeightedSnps).length === 0) {
      return undefined
    }

    const relevantSnps = selectedPRSMethodId === UNWEIGHTED_PRS_METHOD_ID ? relevantUnweightedSnps : relevantWeightedSnps
    return relevantSnps.deduplicated_gwas_variants_ids.length
  }, [selectedPRSMethodId, relevantUnweightedSnps, relevantWeightedSnps])

  return (
    <Grid container direction='column' justifyContent='space-between' alignItems='stretch' height={1}>
      <Grid container item direction='row' justifyContent='space-between' alignItems='center'>
        <Grid item md={6} xs={8}>
          <Box display='flex' alignItems='center' gap={2}>
            <PageHeading icon={iconsObj.ANALYSES} id='prs-results-heading'>{props.pageHeading}</PageHeading>
            <Typography variant='h6' textAlign='left'>
              Num. of SNPs: {numberOfSnps}
            </Typography>
          </Box>
        </Grid>
        <Grid item xs={4} md={6} lg={6} xl={4}>
          <Stack direction='row'>
            <FormInputSelect
              name='highlightedPatient'
              label='Select Highlighted Patient'
              control={control}
              options={highlightedPatientOptions}
              isLoading={props.isLoading}
            />
            <FormInputSelect name='prsMethod' label='Select PRS Method' control={control} options={prsMethodOptions} isLoading={props.isLoading} />
          </Stack>
        </Grid>
      </Grid>
      <Grid container direction='row' justifyContent='space-between' alignItems='center' item height={1}>
        <Grid item xs={12} sm={12} md={6} lg={5} xl={3}>
          <Stack direction='row'>
            <ListItemWithAvatar icon={iconsObj.TARGET_DATA} primaryLabel='Target Data' secondaryLabel={props.targetDataName} />
            <ListItemWithAvatar icon={iconsObj.BASE_DATA} primaryLabel='Base Data' secondaryLabel={props.baseDataName} />
          </Stack>
        </Grid>
        <Grid item xs={4} md={6} lg={6} xl={4}>
          <Stack direction='row'>
            <FormInputSelect name='controlGroup' label='Select Control Group' control={control} options={groupsOptions} isLoading={props.isLoading} />
            <FormInputSelect name='experimentalGroup' label='Select Experimental Group' control={control} options={groupsOptions} isLoading={props.isLoading} />
          </Stack>
        </Grid>
      </Grid>
      {/* PRS charts Grid */}
      <Grid container item direction='row' mt={3} justifyContent='space-between' alignItems='stretch' height={1}>
        <Grid item xs={12} md={12} xl={6} height='calc(40vh)'>
          <Stack direction='row' spacing={1}>
            <Typography variant='h6' textAlign='left'>
              Violin Plot
            </Typography>
            <NoMaxWidthTooltip
              placement='right'
              title={
                <ViolinPlotTooltip
                  selectedPRSMethod={selectedPRSMethodId as string}
                  unweightedPrsResults={unweightedPrsResults}
                  weightedPrsResults={weightedPrsResults}
                  columns={groupsOptions.map((option) => option.label)}
                  numberOfGroups={groupsOptions.length}
                  isLoading={props.isLoading}
                />
              }
            >
              <IconButton size='small' sx={{ px: 1 }}>
                {iconsObj.INFO}
              </IconButton>
            </NoMaxWidthTooltip>
          </Stack>
          <ViolinPlot data={violinPlotData} isLoading={props.isLoading} />
        </Grid>
        <Grid item xs={12} md={12} xl={6} height='calc(40vh)'>
          <Typography variant='h6' textAlign='left'>
            Logistic Regression
          </Typography>
          <LogisticRegressionPlot
            isLoading={props.isLoading}
            selectedPRSMethod={selectedPRSMethodId as string}
            controlClass={selectedControlGroup}
            experimentalClass={selectedExperimentalGroup}
            unweightedPrsResults={unweightedPrsResults}
            weightedPrsResults={weightedPrsResults}
            highlightedSampleData={highlightedPatientData}
          />
        </Grid>
        <Grid item xs={12} md={12} xl={6} height='calc(30vh)'>
          <Typography variant='h6' textAlign='left'>
            Z-Score
          </Typography>
          <ZScorePlot data={violinPlotData} isLoading={props.isLoading} controlClass={selectedControlGroup} highlightedSampleData={highlightedPatientData} />
        </Grid>
        <Grid item xs={12} md={12} xl={6} height='calc(30vh)'>
          <Stack direction='row' spacing={1}>
            <Typography variant='h6' textAlign='left'>
              ROC Curve
            </Typography>
            <NoMaxWidthTooltip
              placement='right'
              title={
                <RocCurveTooltip
                  selectedPRSMethod={selectedPRSMethodId as string}
                  unweightedPrsResults={unweightedPrsResults}
                  weightedPrsResults={weightedPrsResults}
                  controlClass={selectedControlGroup}
                  experimentalClass={selectedExperimentalGroup}
                  isLoading={props.isLoading}
                />
              }
            >
              <IconButton size='small' sx={{ px: 1 }}>
                {iconsObj.INFO}
              </IconButton>
            </NoMaxWidthTooltip>
          </Stack>
          <ROCPlot data={rocPlotData} isLoading={props.isLoading} controlClass={selectedControlGroup} experimentalClass={selectedExperimentalGroup} />
        </Grid>
      </Grid>
    </Grid>
  )
}

export default PrsResults
