import React, { useCallback, useMemo } from 'react'

import { Box, Grid, IconButton, Stack, Typography } from '@mui/material'
import { useForm } from 'react-hook-form'

import { useAuth } from '../../context/AuthContext'
import PageHeading from '../PageHeading'
import { iconsObj } from '../../icons/Icons'
import { ListItemWithAvatar } from '../buttons/TargetAndBaseDataButton'
import ViolinPlot, { ViolinPlotData } from '../plots/ViolinPlot'
import ZScorePlot from '../plots/ZScorePlot'
import { Option } from '../forms/FormProps'
import FormInputSelect from '../forms/FormInputSelect'
import ROCPlot, { ROCPlotData } from '../plots/RocCurvePlot'
import ViolinPlotTooltip from '../ViolinPlotTooltip'
import NoMaxWidthTooltip from '../NoMaxWidthTooltip'
import RocCurveTooltip from '../RocCurveTooltip'
import LogisticRegressionPlot from '../plots/LogisticRegressionPlot'
import OpenInFullRoundedIcon from '@mui/icons-material/OpenInFullRounded'
import ViolinPlotDialog from '../plots/ViolinPlotDialog'
import { useIsMutating, useQuery } from '@tanstack/react-query'
import type {
  FilteredPatientClassesFunction,
  FormInput,
  HighlightedSampleData,
  PrsResultsProps,
  RelevantVariants,
} from './types'
import { useUnweightedPrsCalculations, useWeightedPrsResults } from './hooks'

/**
 * PRS scoring methods offered by the "Select PRS Method" dropdown.
 *
 * The form's `prsMethod` value is matched against an option's `label` (the form
 * default is 'PLINK'); the matching `id` is the key read from each patient's PRS
 * result.
 */
const PRS_METHOD_OPTIONS: Option[] = [
  { id: 'plink_prs', label: 'PLINK' },
  { id: 'unweighted_prs', label: 'Unweighted' },
  { id: 'linear_weighted_prs', label: 'Linear Weighted' },
  { id: 'linear_weighted_log_or_prs', label: 'Linear Weighted Log(OR)' },
  { id: 'prsice_avg_prs', label: 'PRSice Avg' },
  { id: 'prsice_std_prs', label: 'PRSice Std' },
]

/** Method id whose scores/variants come from the unweighted results; every other method uses the weighted ones. */
const UNWEIGHTED_PRS_METHOD_ID = 'unweighted_prs'

/** How long (ms) React Query treats a relevant-variants response as fresh. */
const RELEVANT_VARIANTS_STALE_TIME_MS = 300_000

/**
 * Results view for a polygenic risk score (PRS) analysis.
 *
 * Fetches the GWAS variants relevant to weighted and unweighted PRS methods for the
 * given base/target data, computes per-patient-class PRS results through
 * `useWeightedPrsResults` / `useUnweightedPrsCalculations`, and renders them as
 * violin, logistic-regression, Z-score and ROC plots. Form controls select the PRS
 * method, the control/experimental groups and an optional patient to highlight.
 */
const PrsResults: React.FC<PrsResultsProps> = (props) => {
  const { backendRequest } = useAuth()

  const { control, watch } = useForm<FormInput>({
    defaultValues: {
      prsMethod: 'PLINK',
      controlGroup: '',
      experimentalGroup: '',
      highlightedPatient: '',
    },
  })

  // Count of in-flight mutations in the query client; any of them marks this view as loading.
  const isMutating = useIsMutating()

  const selectedPRSMethod = watch('prsMethod')
  const selectedControlGroup = watch('controlGroup')
  const selectedExperimentalGroup = watch('experimentalGroup')
  const selectedHighlightedPatient = watch('highlightedPatient')

  // Resolve the selected dropdown label to its method id once; every plot and tooltip needs it.
  const selectedPRSMethodId = useMemo(
    () =>
      PRS_METHOD_OPTIONS.find((option) => option.label === selectedPRSMethod)
        ?.id,
    [selectedPRSMethod]
  )
  const isUnweighted = selectedPRSMethodId === UNWEIGHTED_PRS_METHOD_ID

  /**
   * Builds the React Query options that fetch the GWAS variants relevant to one
   * PRS method type for the current base/target data.
   *
   * @param methodType - Sent as the `prs_method_type` query parameter; its
   *   lower-cased form is the last element of the query key.
   */
  const relevantVariantsQuery = (methodType: 'WEIGHTED' | 'UNWEIGHTED') => ({
    // NOTE(review): `props.snpIds` is sent in the request body but is not part of the
    // key, so a change to `snpIds` alone (same base/target data) is served from cache
    // until stale — confirm that is intended.
    queryKey: [
      'analysis',
      { baseDataId: props.baseDataId, targetDataId: props.targetData.id },
      'gwas-variants',
      methodType.toLowerCase(),
    ],
    queryFn: async () => {
      const response = await backendRequest({
        method: 'POST',
        endpoint: `/api/gwas-variants/get-relevant-variants?prs_method_type=${methodType}`,
        requiresAuth: true,
        body: {
          gwas_variants_ids: props.snpIds,
        },
      })
      return response.data
    },
    // Only query once there are SNPs to look up and both datasets are identified.
    enabled:
      !!props.snpIds &&
      props.snpIds.length > 0 &&
      !!props.baseDataId &&
      !!props.targetData.id,
    staleTime: RELEVANT_VARIANTS_STALE_TIME_MS,
  })

  const relevantWeightedSnps = useQuery<RelevantVariants>(
    relevantVariantsQuery('WEIGHTED')
  )
  const relevantUnweightedSnps = useQuery<RelevantVariants>(
    relevantVariantsQuery('UNWEIGHTED')
  )

  // Patient ids grouped by patient class id, as passed to the PRS result hooks.
  const filteredPatientClasses: FilteredPatientClassesFunction =
    useCallback(() => {
      return props.targetData.patient_classes.map((item) => ({
        patient_class_id: item.patient_class.id,
        patient_ids: item.patients_data.map((patient) => patient.id),
      }))
    }, [props.targetData.patient_classes])

  const weightedPrsResults = useWeightedPrsResults(
    props.baseDataId,
    props.targetData.id,
    relevantWeightedSnps.data,
    filteredPatientClasses
  )

  const unweightedPrsResults = useUnweightedPrsCalculations(
    props.baseDataId,
    props.targetData.id,
    relevantUnweightedSnps.data,
    filteredPatientClasses
  )

  // Includes `props.isLoading`, so children only need this single flag.
  const isLoading = [
    props.isLoading,
    relevantWeightedSnps.isLoading,
    relevantUnweightedSnps.isLoading,

    weightedPrsResults.isLoading,
    weightedPrsResults.isFetching,

    unweightedPrsResults.isLoading,
    unweightedPrsResults.isFetching,

    isMutating > 0,
  ].some((loading) => loading)

  // One option per patient class, keyed and labelled by class name.
  const groupOptions = useMemo(
    () =>
      props.targetData.patient_classes.map((item) => ({
        id: item.patient_class.class_name,
        label: item.patient_class.class_name,
      })),
    [props.targetData.patient_classes]
  )

  // One option per patient sample name, across all patient classes.
  const highlightedPatientOptions = useMemo(
    () =>
      props.targetData.patient_classes
        .flatMap((item) =>
          item.patients_data.map((patient) => patient.sample_name)
        )
        .map((sampleName) => ({
          id: sampleName,
          label: sampleName,
        })),
    [props.targetData.patient_classes]
  )

  /**
   * Per-class PRS scores for the selected method, shared by the violin, Z-score and
   * ROC plots. `undefined` while either result set is missing; `[]` when no method
   * is selected.
   */
  const scoredPatientClasses = useMemo(() => {
    if (
      unweightedPrsResults.data == undefined ||
      weightedPrsResults.data == undefined
    ) {
      return undefined
    }
    if (!selectedPRSMethodId) {
      return []
    }

    const currentResults = isUnweighted
      ? unweightedPrsResults.data
      : weightedPrsResults.data

    return currentResults.map((patientClass) => ({
      groupName: patientClass.patient_class.class_name,
      // Pick the score strictly by method (same rule as the highlighted patient). A
      // `score || unweighted_prs` fallback would turn legitimate 0 scores into
      // unweighted scores and mix methods within one plot.
      y: patientClass.patient_prs_results.map((patientPrsResult) =>
        isUnweighted
          ? patientPrsResult.unweighted_prs
          : patientPrsResult[selectedPRSMethodId]
      ),
      yText: patientClass.patient_prs_results.map(
        (patientPrsResult) => patientPrsResult.patient.sample_name
      ),
    }))
  }, [
    selectedPRSMethodId,
    isUnweighted,
    unweightedPrsResults.data,
    weightedPrsResults.data,
  ])

  const violinPlotData: ViolinPlotData[] | undefined = scoredPatientClasses

  // The ROC curve needs only the group name and the scores.
  const rocPlotData: ROCPlotData[] | undefined = useMemo(
    () => scoredPatientClasses?.map(({ groupName, y }) => ({ groupName, y })),
    [scoredPatientClasses]
  )

  /**
   * Score and class of the patient chosen in "Select Highlighted Patient", for the
   * selected method; `undefined` when results, method or patient are missing or the
   * patient is not found.
   */
  const highlightedPatientData: HighlightedSampleData | undefined =
    useMemo(() => {
      if (
        unweightedPrsResults.data == undefined ||
        weightedPrsResults.data == undefined
      ) {
        return undefined
      }
      if (!selectedPRSMethodId || !selectedHighlightedPatient) {
        return undefined
      }

      const currentResults = isUnweighted
        ? unweightedPrsResults.data
        : weightedPrsResults.data

      for (const patientClass of currentResults) {
        const patient = patientClass.patient_prs_results.find(
          (prsResult) =>
            prsResult.patient.sample_name === selectedHighlightedPatient
        )
        if (patient) {
          return {
            y: isUnweighted
              ? patient.unweighted_prs
              : patient[selectedPRSMethodId],
            yText: patient.patient.sample_name,
            // NOTE(review): compares the PRS result's own `class_name`, while the group
            // label comes from `patientClass.patient_class.class_name` — confirm they agree.
            isControl: patient.class_name === selectedControlGroup,
            patientClass: patientClass.patient_class.class_name,
          }
        }
      }

      return undefined
    }, [
      selectedPRSMethodId,
      isUnweighted,
      selectedControlGroup,
      selectedHighlightedPatient,
      unweightedPrsResults.data,
      weightedPrsResults.data,
    ])

  // Deduplicated variant count for the selected method type; hidden until both lookups returned data.
  const numberOfSnps = useMemo(() => {
    const unweightedVariants = relevantUnweightedSnps.data
    const weightedVariants = relevantWeightedSnps.data
    if (unweightedVariants == undefined || weightedVariants == undefined) {
      return undefined
    }
    if (
      Object.keys(unweightedVariants).length === 0 ||
      Object.keys(weightedVariants).length === 0
    ) {
      return undefined
    }

    const variants = isUnweighted ? unweightedVariants : weightedVariants
    return variants.deduplicated_gwas_variants_ids.length
  }, [isUnweighted, relevantUnweightedSnps.data, relevantWeightedSnps.data])

  // Whether the full-size violin plot dialog is shown.
  const [open, setOpen] = React.useState(false)

  const handleClickOpen = () => {
    setOpen(true)
  }

  const handleClose = () => {
    setOpen(false)
  }

  return (
    <Grid
      container
      direction='column'
      justifyContent='space-between'
      alignItems='stretch'
      height={1}
    >
      <Grid
        container
        item
        direction='row'
        justifyContent='space-between'
        alignItems='center'
      >
        <Grid item md={6} xs={8}>
          <Box display='flex' alignItems='center' gap={2}>
            <PageHeading icon={iconsObj.ANALYSES} id='prs-results-heading'>
              {props.pageHeading}
            </PageHeading>
            <Typography variant='h6' textAlign='left'>
              Num. of SNPs: {numberOfSnps}
            </Typography>
          </Box>
        </Grid>
        <Grid item xs={4} md={6} lg={6} xl={4}>
          <Stack direction='row'>
            <FormInputSelect
              name='highlightedPatient'
              label='Select Highlighted Patient'
              control={control}
              options={highlightedPatientOptions}
              isLoading={isLoading}
              placeholder='Unselected'
            />
            <FormInputSelect
              name='prsMethod'
              label='Select PRS Method'
              control={control}
              options={PRS_METHOD_OPTIONS}
              isLoading={isLoading}
            />
          </Stack>
        </Grid>
      </Grid>
      <Grid
        container
        direction='row'
        justifyContent='space-between'
        alignItems='center'
        item
        height={1}
      >
        <Grid item xs={12} sm={12} md={6} lg={5} xl={3}>
          <Stack direction='row'>
            <ListItemWithAvatar
              icon={iconsObj.TARGET_DATA}
              primaryLabel='Target Data'
              secondaryLabel={props.targetDataName}
            />
            <ListItemWithAvatar
              icon={iconsObj.BASE_DATA}
              primaryLabel='Base Data'
              secondaryLabel={props.baseDataName}
            />
          </Stack>
        </Grid>
        <Grid item xs={4} md={6} lg={6} xl={4}>
          <Stack direction='row'>
            <FormInputSelect
              name='controlGroup'
              label='Select Control Group'
              control={control}
              options={groupOptions}
              isLoading={isLoading}
            />
            <FormInputSelect
              name='experimentalGroup'
              label='Select Experimental Group'
              control={control}
              options={groupOptions}
              isLoading={isLoading}
            />
          </Stack>
        </Grid>
      </Grid>
      {/* PRS charts Grid */}
      <Grid
        container
        item
        direction='row'
        mt={3}
        justifyContent='space-between'
        alignItems='stretch'
        height={1}
      >
        <Grid item xs={12} md={12} xl={6} height='calc(40vh)'>
          <Stack direction='row' spacing={1} sx={{ alignItems: 'center' }}>
            <Typography variant='h6' textAlign='left'>
              Violin Plot
            </Typography>
            <NoMaxWidthTooltip
              sx={{ justifySelf: 'start' }}
              placement='right'
              title={
                <ViolinPlotTooltip
                  selectedPRSMethod={selectedPRSMethodId as string}
                  unweightedPrsResults={unweightedPrsResults.data ?? []}
                  weightedPrsResults={weightedPrsResults.data ?? []}
                  columns={groupOptions.map((option) => option.label)}
                  numberOfGroups={groupOptions.length}
                  isLoading={isLoading}
                />
              }
            >
              <IconButton size='small' sx={{ px: 1 }}>
                {iconsObj.INFO}
              </IconButton>
            </NoMaxWidthTooltip>
            <IconButton onClick={handleClickOpen} sx={{ justifySelf: 'end' }}>
              <OpenInFullRoundedIcon fontSize='small' />
            </IconButton>
            <ViolinPlotDialog
              open={open}
              handleClose={handleClose}
              data={violinPlotData ?? []}
              isLoading={isLoading}
              controlClass={selectedControlGroup}
              highlightedSampleData={highlightedPatientData}
            />
          </Stack>
          <ViolinPlot
            dialogMode={false}
            data={violinPlotData ?? []}
            isLoading={isLoading}
            controlClass={selectedControlGroup}
            highlightedSampleData={highlightedPatientData}
          />
        </Grid>
        <Grid item xs={12} md={12} xl={6} height='calc(40vh)'>
          <Typography variant='h6' textAlign='left'>
            Logistic Regression
          </Typography>
          <LogisticRegressionPlot
            isLoading={isLoading}
            selectedPRSMethod={selectedPRSMethodId as string}
            controlClass={selectedControlGroup}
            experimentalClass={selectedExperimentalGroup}
            unweightedPrsResults={unweightedPrsResults.data ?? []}
            weightedPrsResults={weightedPrsResults.data ?? []}
            highlightedSampleData={highlightedPatientData}
          />
        </Grid>
        <Grid item xs={12} md={12} xl={6} height='calc(30vh)'>
          <Typography variant='h6' textAlign='left'>
            Z-Score
          </Typography>
          <ZScorePlot
            data={violinPlotData ?? []}
            isLoading={isLoading}
            controlClass={selectedControlGroup}
            highlightedSampleData={highlightedPatientData}
          />
        </Grid>
        <Grid item xs={12} md={12} xl={6} height='calc(30vh)'>
          <Stack direction='row' spacing={1}>
            <Typography variant='h6' textAlign='left'>
              ROC Curve
            </Typography>
            <NoMaxWidthTooltip
              placement='right'
              title={
                <RocCurveTooltip
                  selectedPRSMethod={selectedPRSMethodId as string}
                  unweightedPrsResults={unweightedPrsResults.data ?? []}
                  weightedPrsResults={weightedPrsResults.data ?? []}
                  controlClass={selectedControlGroup}
                  experimentalClass={selectedExperimentalGroup}
                  isLoading={isLoading}
                />
              }
            >
              <IconButton size='small' sx={{ px: 1 }}>
                {iconsObj.INFO}
              </IconButton>
            </NoMaxWidthTooltip>
          </Stack>
          <ROCPlot
            data={rocPlotData ?? []}
            isLoading={isLoading}
            controlClass={selectedControlGroup}
            experimentalClass={selectedExperimentalGroup}
          />
        </Grid>
      </Grid>
    </Grid>
  )
}

// Default export of the PRS results view component defined above.
export default PrsResults
