/// <reference types="styled-components/cssprop" />
import React, { useState, useReducer, useRef, useCallback } from 'react'
import styled from 'styled-components/macro'
import * as tf from '@tensorflow/tfjs'
import * as SIP from 'sip.js'
import { store } from '../../index'
import { ReactComponent as CallIcon } from '../../icons/call.svg'
import { ReactComponent as CallingIcon } from '../../icons/calling.svg'
import { Input, Button, Checkbox } from '../controls'
import withStateLifter from '../StateLifter'
import { Train, Config } from '../../api'
import { FormHelper } from '../helpers'
import makeCall from './makeCall'
import { Sequential } from '@tensorflow/tfjs'
import DataTable from '../DataTable'

/**
 * Georgian display texts for the call-outcome codes stored on train labels
 * (`label.result`): 1 – unreachable, 2 – not registered, 3 – call goes through.
 */
const trainResults: Record<number, string> = {
  1: 'მიუწვდომელია',
  2: 'არარეგისტრირებულია',
  3: 'ზარი გადის',
}

/**
 * In-memory training set consumed by "Train Model". Samples are appended by
 * "Load data" and by calls made in training mode. `xList[i]` is one spectrogram
 * with a channel axis ([frame][bin][1]); `yList[i]` is its numeric label value.
 */
const xList: number[][][][] = []
const yList: number[] = []

/** The model currently in use; created, trained, saved and loaded by TrainForm. */
let trainModel: tf.Sequential

/**
 * Column definitions for the recognition test table (DataTable).
 * Only the phone number column can be edited by the user.
 */
const columns = [
  { name: 'phoneNumber', title: 'ტელ. ნომერი', editable: true },
  { name: 'result', title: 'შედეგი' },
  { name: 'label', title: 'ლეიბლი' },
  { name: 'odds', title: 'ალბათობები' },
]

/**
 * Starting value of TrainForm's reducer state; later updates are shallow-merged
 * into it.
 */
const initialState = {
  // SIP registration and call status
  registered: false,
  calling: false,
  time: '',
  status: '',
  // call counters and the phone number shown next to them
  totalCalls: 0,
  currentPhoneCalls: 0,
  phoneNumber: '',
  // train labels from the API and rows of the recognition test table
  labels: [],
  testNumbers: [],
}

/**
 * Paints a spectrogram on the canvas with id `canvasId`.
 *
 * Each value `data[x][y]` is drawn as a 10x1 px rectangle at (x, y) and used as
 * the red channel of the fill colour. The canvas is cleared first, so a shorter
 * spectrogram does not leave the previous one's pixels behind. Nothing is drawn
 * when the canvas or its 2D context is unavailable (e.g. not mounted yet).
 *
 * @param data spectrogram as [frame][bin]; presumably values in 0..255 — TODO confirm
 * @param canvasId id of the target canvas element (defaults to `myCanvas`)
 */
function drawSpectrogram(data: number[][], canvasId = 'myCanvas') {
  const canvas = document.getElementById(canvasId) as HTMLCanvasElement | null
  const ctx = canvas ? canvas.getContext('2d') : null
  if (!canvas || !ctx) return

  ctx.clearRect(0, 0, canvas.width, canvas.height)
  data.forEach((frame, x) => {
    frame.forEach((value, y) => {
      ctx.fillStyle = `rgb(${value}, 0, 0)`
      ctx.fillRect(x, y, 10, 1)
    })
  })
}

/** Number of spectrogram frames (time steps) in one model input sample. */
const FRAMES = 156

/** Number of values per spectrogram frame; second axis of the model input shape. */
const FREQUENCY_BINS = 256

/** How many leading frame offsets recognition tries (to skip missing frames at the start). */
const RECOGNITION_OFFSETS = 30

/** Minimum softmax probability for a prediction to be accepted as a label. */
const MIN_CONFIDENCE = 0.99

/** Text written into the test table when a sample cannot be matched to a label. */
const UNRECOGNIZED_TEXT = 'ამოცნობა შეუძლებელია'

/**
 * Training / recognition screen.
 *
 * - "Collect" (`onCall(true)`) hands the train labels to `makeCall`. Samples
 *   delivered to `onDataReceive` in training mode are stored via `Train.create`
 *   and queued in `xList`/`yList`.
 * - "Recognize" (`onCall(false)`) runs `makeCall` over the test table. Samples
 *   are classified with `trainModel` and the result is written into the row at
 *   `current.testIndex`.
 *
 * `makeCall` (not shown here) shares the mutable `refs.current` object with this
 * component. Presumably it reads state/callbacks from it and invokes
 * `current.onDataReceive` — verify against makeCall.
 */
const TrainForm = ({ lifter }) => {
  const [state, setState] = useReducer(
    (state, newState) => ({ ...state, ...newState }),
    initialState
  )

  const {
    registered,
    calling,
    time,
    totalCalls,
    currentPhoneCalls,
    phoneNumber,
  } = state

  const buttonRef = useRef<HTMLButtonElement>(null)
  const remoteAudioRef = useRef<HTMLVideoElement>(null)

  // Mutable bag shared with makeCall; refreshed on every render so it always
  // exposes the latest state, setter and refs.
  const refs = useRef({ testNumbers: initialState.testNumbers } as any)
  const { current } = refs

  current.state = state
  current.setState = setState
  current.buttonRef = buttonRef
  current.remoteAudioRef = remoteAudioRef

  // Load SIP config and train labels, then create a registering SIP user agent
  // whose remote audio goes to the hidden <video> element. The agent is stopped
  // when the component unmounts.
  React.useEffect(() => {
    let disposed = false

    Promise.all([Config.getSip(), Train.getLabels()])
      .then(([conf, labels]) => {
        // The component may have unmounted while the requests were in flight.
        if (disposed) return

        current.config = conf
        setState({ labels })

        const { config } = current
        const ua = new SIP.UA({
          register: true,
          uri: `${config.username}@${config.registrar}`,
          password: config.password,
          transportOptions: {
            wsServers: config.wsServer,
          },
          // @ts-ignore
          media: { remote: { audio: remoteAudioRef.current } },
        })

        current.userAgent = ua

        ua.on('registered', () => {
          setState({ registered: true })
        })

        ua.on('unregistered', () => {
          setState({ registered: false })
        })
      })
      .catch((error) => {
        // Without a config there is nothing to register; report instead of
        // crashing on an undefined config.
        console.error('Failed to load SIP configuration or train labels', error)
      })

    return () => {
      disposed = true
      const ua = current.userAgent
      current.userAgent = undefined
      if (ua) ua.stop()
    }
  }, [current])

  /**
   * Handles one received spectrogram ([frame][bin]).
   *
   * Training mode: stores the raw sample on the server, then queues its first
   * FRAMES frames for fitting. Recognition mode (when a model exists): classifies
   * the sample and writes the outcome into the current test row. In both cases
   * the spectrogram is drawn on the canvas.
   */
  const onDataReceive = useCallback(
    async (data, train) => {
      // Add a channel axis: [frame][bin] -> [frame][bin][1].
      const tensorData = data.map((frame) => frame.map((bin) => [bin]))

      if (train) {
        const { label } = current
        await Train.create({
          data,
          phoneNumber: label.phoneNumber,
          trainLabelId: label.trainLabelId,
        })

        // The model input is exactly FRAMES frames long. Queue the same slice
        // onLoadData uses so live and loaded samples stack into one tensor.
        if (tensorData.length >= FRAMES) {
          xList.push(tensorData.slice(0, FRAMES))
          yList.push(label.labelValue)
        } else {
          console.warn(
            `Sample has ${tensorData.length} frames, expected at least ${FRAMES}; not queued for training`
          )
        }
      } else if (trainModel) {
        const model = trainModel

        // Slide the FRAMES-long window over the first offsets to skip missing
        // leading frames; keep the most confident prediction. Only offsets with
        // a complete window are tried.
        const offsetCount = Math.min(
          RECOGNITION_OFFSETS,
          tensorData.length - FRAMES + 1
        )
        let bestOdd = 0
        let bestOdds: number[] | undefined
        let bestOffset = -1
        for (let offset = 0; offset < offsetCount; offset++) {
          const frames = tensorData.slice(offset, offset + FRAMES)
          // tidy() frees the input/output tensors, which otherwise leak per window.
          const odds = tf.tidy(() => {
            const input = tf.tensor4d([frames], [1, FRAMES, FREQUENCY_BINS, 1])
            const output = model.predict(input) as tf.Tensor2D
            return output.arraySync()[0]
          })
          const odd = Math.max(...odds)
          if (odd > bestOdd) {
            bestOdd = odd
            bestOdds = odds
            bestOffset = offset
          }
          if (bestOdd === 1) break
        }

        const testNumber = current.testNumbers[current.testIndex]
        if (testNumber) {
          const bestIndex = bestOdds ? bestOdds.indexOf(bestOdd) : -1
          const label =
            bestOdd >= MIN_CONFIDENCE
              ? state.labels.find(
                  (candidate) => candidate.labelValue === bestIndex
                )
              : undefined

          if (label) {
            testNumber.result = trainResults[label.result]
            testNumber.label = label.labelText
          } else {
            testNumber.result = UNRECOGNIZED_TEXT
            testNumber.label = UNRECOGNIZED_TEXT
          }
          testNumber.odds = `Max: ${bestOdd}, Value: ${bestIndex}, Frame: ${bestOffset}`

          // New array identity so the table re-renders with the mutated row.
          current.testNumbers = [...current.testNumbers]
          setState({ testNumbers: current.testNumbers })
        }
      }

      drawSpectrogram(data)
    },
    [current, state.labels]
  )

  current.onDataReceive = onDataReceive

  /** Starts collecting (`train === true`) or recognising via makeCall. */
  const onCall = useCallback(
    (train) => {
      setState({ train })
      if (train) {
        const callsPerPhone = parseInt(
          FormHelper.getValue(lifter, 'callsPerPhone'),
          10
        )

        current.trainLabels = [...state.labels] // TODO: filter by train and phone number
        if (!current.trainLabels.length) return

        current.callsPerPhone = callsPerPhone
        current.currentPhoneCalls = 0
        setState({ currentPhoneCalls: 0 })
      } else {
        current.testIndex = 0
      }
      makeCall(current, train)
    },
    [current, lifter, state.labels]
  )

  /** Builds a fresh, untrained CNN with one softmax output per train label. */
  const onNewModel = useCallback(() => {
    const labelCount = state.labels.length
    // A softmax layer with zero units cannot be built.
    if (!labelCount) {
      console.warn('Cannot create a model before the train labels are loaded')
      return
    }

    trainModel = tf.sequential({
      layers: [
        tf.layers.depthwiseConv2d({
          inputShape: [FRAMES, FREQUENCY_BINS, 1],
          kernelSize: 3,
          activation: 'relu',
        }),
        tf.layers.maxPooling2d({ poolSize: [5, 5], strides: [5, 5] }),
        tf.layers.dropout({ rate: 0.1 }),
        tf.layers.depthwiseConv2d({
          kernelSize: 3,
          activation: 'relu',
        }),
        tf.layers.maxPooling2d({ poolSize: [5, 5], strides: [5, 5] }),
        tf.layers.flatten(),
        tf.layers.dense({
          units: labelCount,
          activation: 'softmax',
        }),
      ],
    })

    setState({ trainModel })
  }, [state.labels.length])

  /**
   * (Re)compiles the model and fits it on xList/yList. Training stops early once
   * accuracy and loss (train and validation) are all good enough.
   */
  const onTrainModel = useCallback(async () => {
    if (!trainModel) onNewModel()
    if (!trainModel) return
    if (!xList.length) {
      console.warn('No training data; load or collect samples first')
      return
    }

    const model = trainModel
    const labelCount = state.labels.length

    model.compile({
      optimizer: tf.train.adam(0.01),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    })

    let tensorX: tf.Tensor | undefined
    let tensorY: tf.Tensor | undefined
    try {
      tensorX = tf.tensor4d(xList)
      // tidy() frees the intermediate label tensors; only the one-hot result is kept.
      tensorY = tf.tidy(() =>
        tf.oneHot(tf.tensor1d(yList).toInt(), labelCount)
      )

      await model.fit(tensorX, tensorY, {
        epochs: 10,
        batchSize: 10,
        validationSplit: 0.1,
        callbacks: {
          onEpochEnd: (epoch, logs) => {
            console.log('Epoch: ' + epoch + ', logs: ', logs)
            if (
              logs &&
              logs.acc > 0.99 &&
              logs.val_acc > 0.99 &&
              logs.loss < 0.01 &&
              logs.val_loss < 0.01
            ) {
              model.stopTraining = true
            }
          },
        },
      })
    } catch (error) {
      console.error('Model training failed', error)
    } finally {
      if (tensorX) tensorX.dispose()
      if (tensorY) tensorY.dispose()
    }
  }, [state.labels.length, onNewModel])

  /** Uploads the current model to the API, authenticated with the OIDC token. */
  const onSaveModel = useCallback(() => {
    if (!trainModel) return
    const { user } = store.getState().oidc
    if (!user) return

    trainModel
      .save(
        tf.io.http('api/train/model', {
          requestInit: {
            method: 'POST',
            headers: { authorization: `Bearer ${user.access_token}` },
          },
        })
      )
      .catch((error) => console.error('Failed to save model', error))
  }, [])

  /** Replaces the current model with the one stored by the API. */
  const onLoadModel = useCallback(async () => {
    try {
      trainModel = (await tf.loadLayersModel(
        'api/train/model/model.json'
      )) as Sequential
    } catch (error) {
      console.error('Failed to load model', error)
      return
    }

    console.log(trainModel)

    setState({ trainModel })
  }, [])

  /**
   * Pages through the stored training samples and queues their first FRAMES
   * frames for training. `Train.getOne(lastId)` presumably returns the items that
   * follow `lastId` — verify against the API.
   */
  const onLoadData = useCallback(async () => {
    let lastId = 0
    let skipped = 0
    try {
      while (true) {
        const response = await Train.getOne(lastId)
        // An empty page must also end the loop; otherwise lastId never
        // advances and the same request repeats forever.
        if (!response || !response.length) break

        for (const item of response) {
          lastId = item.id
          // Shorter samples cannot fill the model input and would make the
          // stacked training tensor ragged.
          if (item.data.length < FRAMES) {
            skipped++
            continue
          }
          xList.push(
            item.data.slice(0, FRAMES).map((frame) => frame.map((bin) => [bin]))
          )
          yList.push(item.labelValue)
        }
      }
    } catch (error) {
      console.error('Failed to load training data', error)
    }

    if (skipped) {
      console.warn(
        `Skipped ${skipped} training samples shorter than ${FRAMES} frames`
      )
    }
  }, [])

  /** Asks an in-progress fit() to stop after the current batch. */
  const onStopTraining = useCallback(() => {
    if (trainModel) trainModel.stopTraining = true
  }, [])

  return (
    <div className="flex-fill overflow-auto flex-column">
      <video id="remoteAudio" css="display: none" ref={remoteAudioRef}></video>
      <div className="flex-row-center-space">
        <div>
          დარეკვების რაოდენობა თითოეულ ნომერზე
          <Input
            className="width-10 gap-left"
            required
            disabled={calling}
            number
            integer
            initialValue="10"
            name="callsPerPhone"
            placeholder="დარეკვების რაოდენობა"
          />
        </div>

        <div>
          <Button onClick={onLoadData} disabled={!registered || calling}>
            Load data
          </Button>

          <Button
            className="gap-left"
            onClick={onNewModel}
            disabled={!registered || calling}
          >
            New Model
          </Button>

          <Button
            className="gap-left"
            onClick={onTrainModel}
            disabled={!registered || calling}
          >
            Train Model
          </Button>

          <Button
            className="gap-left"
            onClick={onStopTraining}
            disabled={!registered || calling}
          >
            Stop Training
          </Button>

          <Button
            className="gap-left"
            onClick={onSaveModel}
            disabled={!registered || calling}
          >
            Save Model
          </Button>

          <Button
            className="gap-left"
            onClick={onLoadModel}
            disabled={!registered || calling}
          >
            Load Model
          </Button>
        </div>
      </div>
      <div className="flex-row flex-fill">
        <div className="flex-column gap-top">
          {state.labels &&
            state.labels.map((label) => (
              <div
                key={label.trainLabelId}
                className="flex-row-center gap-bottom"
                css={state.label === label ? 'background-color: gainsboro' : ''}
              >
                <div css="width: 30px">{label.labelValue}</div>
                <Checkbox
                  initialValue={true}
                  name={`labelTrain${label.trainLabelId}`}
                  className="gap-right"
                  disabled
                />
                <Input
                  placeholder="ტელ. ნომერი"
                  initialValue={label.phoneNumber}
                  name={`labelPhoneNumber${label.trainLabelId}`}
                  css="width: 120px"
                  disabled
                />
                <div css="width: 170px" className="gap-left">
                  {trainResults[label.result]}
                </div>
                <div css="width: 410px" className="gap-left">
                  {label.labelText}
                </div>
              </div>
            ))}
        </div>

        <div className="flex-column gap-left flex-fill">
          <div className="flex-row">
            <div className="flex-column">
              <div className="flex-row-center gap-bottom gap-top">
                <Button
                  css="padding: 5px;"
                  className="gap-right"
                  disabled={!registered || calling}
                  // @ts-ignore
                  danger={calling}
                  onClick={() => onCall(true)}
                  ref={buttonRef}
                >
                  {calling ? (
                    <CallingIcon css="width: 26px; height: 24px" />
                  ) : (
                    <CallIcon css="width: 26px; height: 24px" />
                  )}
                  მოგროვება
                </Button>

                {state.train && (
                  <>
                    {!!phoneNumber && (
                      <span className="gap-left">{phoneNumber}</span>
                    )}
                    {!!currentPhoneCalls && (
                      <span className="gap-left">{currentPhoneCalls}</span>
                    )}
                    {!!totalCalls && (
                      <span className="gap-left">{totalCalls}</span>
                    )}
                    {calling && <span className="gap-left">{time}</span>}
                  </>
                )}
              </div>

              <canvas
                id="myCanvas"
                width="312"
                height="256"
                css="width: 312px; height: 256px"
              ></canvas>
            </div>

            {state.trainModel && (
              <div className="flex-column gap-top gap-left">
                <div>Model name: {state.trainModel.name}</div>
                {state.trainModel.layers.map((layer, index) => (
                  <div key={layer.name}>
                    Layer {index}: {layer.name}
                  </div>
                ))}
              </div>
            )}
          </div>

          <div className="flex-row-center gap-top">
            <Button
              css="padding: 5px;"
              disabled={!registered || calling || !state.trainModel}
              // @ts-ignore
              danger={calling}
              onClick={() => onCall(false)}
            >
              {calling ? (
                <CallingIcon css="width: 26px; height: 24px" />
              ) : (
                <CallIcon css="width: 26px; height: 24px" />
              )}
              ამოცნობა
            </Button>

            {!state.train && (
              <>
                {!!phoneNumber && (
                  <span className="gap-left">{phoneNumber}</span>
                )}
                {!!currentPhoneCalls && (
                  <span className="gap-left">{currentPhoneCalls}</span>
                )}
                {!!totalCalls && <span className="gap-left">{totalCalls}</span>}
                {calling && <span className="gap-left">{time}</span>}
              </>
            )}
          </div>

          <div
            className="flex-row-stretch gap-top flex-fill"
            css="padding-bottom: 1px;"
          >
            <DataTable
              rows={state.testNumbers}
              columns={columns}
              appendable={!calling}
              clipboardImport={!calling}
            />
          </div>

          <div>{state.predictionResult}</div>
          <div>{state.predictionLabel}</div>
          <div>{state.odds}</div>
        </div>
      </div>
    </div>
  )
}

export default withStateLifter(TrainForm)
