import React, { useEffect, useState } from 'react'
import Plot from 'react-plotly.js'
import { Box, Skeleton } from '@mui/material'
import { useAuth } from '../../context/AuthContext'
import { enqueueSnackbar } from 'notistack'
import { HighlightedSampleData, PrsCalculationsResponse } from '../prsResults/types'

/** Props for {@link LogisticRegressionPlot}. */
export type Props = {
  /** PRS results read when `selectedPRSMethod` is `'unweighted_prs'`. */
  unweightedPrsResults: PrsCalculationsResponse
  /** PRS results read for every other `selectedPRSMethod`. */
  weightedPrsResults: PrsCalculationsResponse
  /** Key read from each patient's PRS result to obtain the plotted score. */
  selectedPRSMethod: string
  /** `class_name` of the group plotted at y = 0; an empty string suppresses the regression request. */
  controlClass: string
  /** `class_name` of the group plotted at y = 1; an empty string suppresses the regression request. */
  experimentalClass: string
  /** Optional sample whose `y` value is sent to the backend and used as the x position of the highlight line. */
  highlightedSampleData?: HighlightedSampleData
  /** When true, a skeleton placeholder is rendered instead of the plot. */
  isLoading: boolean
}

/** Body of a successful `POST /api/prs-statistics/logistic-regression` response, as consumed by this component. */
interface LogisticRegressionResponse {
  /** x coordinates of the fitted curve (drawn as the "Disease probability" line). */
  x_range: number[]
  /** y coordinates of the fitted curve, paired with `x_range`. */
  y_range: number[]
  /** Backend value for the highlighted sample; this component only uses it to decide whether to draw the highlight line. */
  highlighted_sample: number
}

/**
 * Collects the `selectedPRSMethod` score of every patient in the class named `className`.
 *
 * @param prsResults - PRS results grouped by patient class.
 * @param className - `patient_class.class_name` of the group to extract.
 * @param selectedPRSMethod - Key read from each per-patient PRS result (e.g. `'unweighted_prs'`).
 * @returns The scores of the first class whose name matches, or an empty array when no class matches.
 */
const getPrsResultsByMethodAndClass = (prsResults: PrsCalculationsResponse, className: string, selectedPRSMethod: string) => {
  // `find` stops at the first matching class instead of mapping every match only to keep `[0]`.
  const matchingClass = prsResults.find((patientClass) => patientClass.patient_class.class_name === className)
  return matchingClass?.patient_prs_results.map((prsResult) => prsResult[selectedPRSMethod]) ?? []
}

/**
 * Normalises a rejected backend request into the detail text shown in the error snackbar.
 *
 * Reads `response.data.detail` first and falls back to the error's own `message`. Every level is
 * optional, so an error that carries no `response` cannot throw a second error from inside the
 * `.catch` handler.
 *
 * @returns The detail text, or `undefined` when the error carries none.
 */
const extractErrorDetail = (error: unknown): string | undefined => {
  if (typeof error !== 'object' || error === null) {
    return undefined
  }
  const { response, message } = error as { response?: { data?: { detail?: unknown } }; message?: unknown }
  const detail = response?.data?.detail ?? message
  if (detail == null || detail === '') {
    return undefined
  }
  // Non-string details (e.g. structured validation errors) would otherwise render as "[object Object]".
  return typeof detail === 'string' ? detail : JSON.stringify(detail)
}

/**
 * Plots the selected PRS score of every control (y = 0) and experimental (y = 1) patient.
 * The scatter is overlaid with the logistic-regression curve returned by
 * `/api/prs-statistics/logistic-regression`.
 *
 * The regression is requested again whenever the selected method, either PRS result set, either
 * class or the highlighted sample changes. A request is only sent when both result sets are
 * non-empty and both class names are set.
 *
 * When the backend reports a highlighted sample, a dotted vertical line is drawn at
 * `highlightedSampleData.y`. While `isLoading` is true, a skeleton placeholder is rendered instead.
 */
const LogisticRegressionPlot: React.FC<Props> = (props) => {
  // `null` until the first successful response, rather than an empty object cast to the response type.
  const [logisticRegressionData, setLogisticRegressionData] = useState<LogisticRegressionResponse | null>(null)
  const [scores, setScores] = useState<number[]>([])
  const [categories, setCategories] = useState<number[]>([])

  const { backendRequest, logout } = useAuth()

  useEffect(() => {
    // Set by the cleanup below when the inputs change or the component unmounts. It stops a slower,
    // outdated request from overwriting the curve computed for the current inputs.
    let isStale = false

    const hasRequiredInputs =
      Object.keys(props.weightedPrsResults).length > 0 &&
      Object.keys(props.unweightedPrsResults).length > 0 &&
      props.controlClass !== '' &&
      props.experimentalClass !== ''

    if (hasRequiredInputs) {
      const fetchRegression = async () => {
        const prsResults = props.selectedPRSMethod === 'unweighted_prs' ? props.unweightedPrsResults : props.weightedPrsResults

        const controlGroup = getPrsResultsByMethodAndClass(prsResults, props.controlClass, props.selectedPRSMethod)
        const experimentalGroup = getPrsResultsByMethodAndClass(prsResults, props.experimentalClass, props.selectedPRSMethod)

        // Scatter points: control patients are plotted at y = 0, experimental patients at y = 1.
        setScores([...controlGroup, ...experimentalGroup])
        setCategories([
          ...Array.from({ length: controlGroup.length }, () => 0),
          ...Array.from({ length: experimentalGroup.length }, () => 1),
        ])

        const response = await backendRequest({
          method: 'POST',
          endpoint: '/api/prs-statistics/logistic-regression',
          requiresAuth: true,
          params: {
            highlighted_sample: props.highlightedSampleData?.y,
          },
          body: {
            control_group: controlGroup,
            experimental_group: experimentalGroup,
          },
        })

        return {
          status: response.status,
          data: response.data as LogisticRegressionResponse,
        }
      }

      fetchRegression()
        .then(({ status, data }) => {
          if (status === 200 && data) {
            if (!isStale) {
              setLogisticRegressionData(data)
            }
          } else if (status === 401) {
            // The session is invalid regardless of whether this request is still current.
            logout()
          }
        })
        .catch((error: unknown) => {
          const detail = extractErrorDetail(error)
          enqueueSnackbar(detail ? `An error occurred. ${detail}` : 'An error occurred.', { variant: 'error' })
        })
    }

    return () => {
      isStale = true
    }
  }, [props.selectedPRSMethod, props.weightedPrsResults, props.unweightedPrsResults, props.highlightedSampleData, props.controlClass, props.experimentalClass])

  if (props.isLoading) {
    return (
      <Box width='100%' height='100%' p={3}>
        <Skeleton variant='rounded' height='90%' />
      </Box>
    )
  }

  const highlightedScore = props.highlightedSampleData?.y
  // Compare against null instead of relying on truthiness: 0 is a valid value for both operands,
  // and without a highlighted score there is no x position to draw the line at.
  const showHighlightedSample = logisticRegressionData?.highlighted_sample != null && highlightedScore != null

  return (
    <Plot
      data={[
        {
          hoverinfo: 'text',
          hovertext: scores.map((score, index) => `x=${score}, y=${categories[index]}`),
          marker: { color: categories.map((category) => (category === 0 ? 'green' : 'red')), symbol: 'circle' },
          mode: 'markers',
          type: 'scatter',
          x: scores,
          y: categories,
        },
        {
          name: 'Disease probability',
          type: 'scatter',
          mode: 'lines',
          x: logisticRegressionData?.x_range,
          y: logisticRegressionData?.y_range,
          line: { color: '#ef553b' },
        },
      ]}
      layout={{
        margin: { t: 60 },
        plot_bgcolor: 'white',
        showlegend: false,
        xaxis: {
          title: 'Risk Score',
          showgrid: false,
        },
        yaxis: {
          title: 'Disease Risk',
          zeroline: false,
        },
        hovermode: 'closest',
        shapes: showHighlightedSample
          ? [
              {
                type: 'line',
                xref: 'x',
                yref: 'paper',
                x0: highlightedScore,
                y0: 0,
                x1: highlightedScore,
                y1: 1,
                line: { color: 'red', width: 2, dash: 'dot' },
              },
            ]
          : [],
      }}
      style={{ width: '100%', height: '90%' }}
    />
  )
}

export default LogisticRegressionPlot
