import * as yup from 'yup';
import {
  LR0,
  LRF,
  SEED,
  MOMENTUM,
  WEIGHT_DECAY,
  WARMUP_STEPS,
  WARMUP_EPOCHS,
  WARMUP_MOMENTUM,
  WARMUP_BIAS_LR,
  SCHD_POWER,
  REDUCE_ZERO_LABEL,
  BOX,
  CLS,
  CLS_PW,
  OBJ,
  OBJ_PW,
  IOU_T,
  ANCHOR_T,
  FL_GAMMA,
} from '@netspresso/shared';
import { EPOCHS_PER_TRIAL } from '../constants';
import { Tasks } from '../lib';
import {
  isImageClassification,
  isSemanticSegmentation,
  parseHypPlaceholderCls,
  parseHypPlaceholderDet,
  parseHypPlaceholderSeg,
} from '../utils';

/**
 * Picks the yup validation schema for the hyperparameter form of `task`.
 *
 * Image classification maps to `ClassifiHyperScheme` and semantic
 * segmentation maps to `SegmentationHyperScheme`. Every other task falls
 * through to `DetectionHyperScheme`.
 *
 * @param task - the task whose hyperparameters are being edited
 * @returns the matching (shared, module-level) schema instance
 */
export const HyperparameterScheme = (task: Tasks): yup.Schema => {
  if (isImageClassification(task)) return ClassifiHyperScheme;

  // Only consulted when the task is not classification, matching the original check order.
  return isSemanticSegmentation(task) ? SegmentationHyperScheme : DetectionHyperScheme;
};

/**
 * Starts a number schema whose type-error message is the classification
 * placeholder text for `key`, as produced by `parseHypPlaceholderCls`.
 */
const classificationNumber = (key: Parameters<typeof parseHypPlaceholderCls>[0]) =>
  yup.number().typeError(parseHypPlaceholderCls(key));

/**
 * Validation schema for image-classification hyperparameters.
 *
 * Every field is numeric. A non-numeric entry is reported with the
 * `parseHypPlaceholderCls` message for that field. All bounds in this schema
 * are inclusive (`min` / `max`). `SEED` and `WARMUP_STEPS` must also be
 * integers.
 */
export const ClassifiHyperScheme = yup.object({
  [LR0]: classificationNumber(LR0).min(1e-6).max(0.1),
  [LRF]: classificationNumber(LRF).min(1e-6).max(0.1),
  [SEED]: classificationNumber(SEED).integer('seed must be an integer.').min(0),
  [MOMENTUM]: classificationNumber(MOMENTUM).min(0).max(1),
  [WEIGHT_DECAY]: classificationNumber(WEIGHT_DECAY).min(0).max(0.9),
  [WARMUP_STEPS]: classificationNumber(WARMUP_STEPS).integer('warmup_steps must be an integer.').min(0),
  [WARMUP_EPOCHS]: classificationNumber(WARMUP_EPOCHS).min(0),
  [WARMUP_MOMENTUM]: classificationNumber(WARMUP_MOMENTUM).min(0).max(1),
  [WARMUP_BIAS_LR]: classificationNumber(WARMUP_BIAS_LR).min(0).max(0.01),
  [SCHD_POWER]: classificationNumber(SCHD_POWER).min(0),
});

/**
 * Starts a number schema whose type-error message is the detection
 * placeholder text for `key`, as produced by `parseHypPlaceholderDet`.
 */
const detectionNumber = (key: Parameters<typeof parseHypPlaceholderDet>[0]) =>
  yup.number().typeError(parseHypPlaceholderDet(key));

/**
 * Validation schema for object-detection hyperparameters.
 *
 * The schema mixes inclusive bounds (`min` / `max`) with exclusive ones
 * (`moreThan` / `lessThan`). `WARMUP_EPOCHS` is capped by the value of the
 * sibling `EPOCHS_PER_TRIAL` field through `yup.ref`. `EPOCHS_PER_TRIAL`
 * itself is restricted to a fixed list of values and, unlike the other
 * fields, has no custom type-error message.
 */
export const DetectionHyperScheme = yup.object({
  [LR0]: detectionNumber(LR0).min(1e-6).lessThan(1),
  [LRF]: detectionNumber(LRF).min(1e-6).lessThan(1),
  [MOMENTUM]: detectionNumber(MOMENTUM).moreThan(0).lessThan(1),
  [WEIGHT_DECAY]: detectionNumber(WEIGHT_DECAY).min(0).max(1),
  [WARMUP_EPOCHS]: detectionNumber(WARMUP_EPOCHS).min(0).max(yup.ref(EPOCHS_PER_TRIAL)),
  [WARMUP_MOMENTUM]: detectionNumber(WARMUP_MOMENTUM).moreThan(0).lessThan(1),
  [WARMUP_BIAS_LR]: detectionNumber(WARMUP_BIAS_LR).moreThan(0).lessThan(1),
  [BOX]: detectionNumber(BOX).min(0).lessThan(1),
  [CLS]: detectionNumber(CLS).min(0).lessThan(1),
  [CLS_PW]: detectionNumber(CLS_PW).moreThan(0),
  [OBJ]: detectionNumber(OBJ).min(0).max(1),
  [OBJ_PW]: detectionNumber(OBJ_PW).moreThan(0),
  [IOU_T]: detectionNumber(IOU_T).moreThan(0).lessThan(1),
  [ANCHOR_T]: detectionNumber(ANCHOR_T).moreThan(0),
  [FL_GAMMA]: detectionNumber(FL_GAMMA).min(0),
  // NOTE(review): 400 is absent from this otherwise 50-step sequence (100..450) — confirm it is excluded on purpose.
  [EPOCHS_PER_TRIAL]: yup.number().oneOf([3, 100, 150, 200, 250, 300, 350, 450]),
});

/**
 * Starts a number schema whose type-error message is the segmentation
 * placeholder text for `key`, as produced by `parseHypPlaceholderSeg`.
 */
const segmentationNumber = (key: Parameters<typeof parseHypPlaceholderSeg>[0]) =>
  yup.number().typeError(parseHypPlaceholderSeg(key));

/**
 * Validation schema for semantic-segmentation hyperparameters.
 *
 * `LR0`, `LRF` and `SCHD_POWER` carry only a default: they have no bounds
 * and no custom type-error message. `REDUCE_ZERO_LABEL` is a boolean that
 * defaults to `false`. The remaining fields are bounded numbers. Unlike
 * the classification schema, the seed here must be at least 1.
 */
export const SegmentationHyperScheme = yup.object({
  [LR0]: yup.number().default(6e-5),
  [LRF]: yup.number().default(0),
  [SEED]: segmentationNumber(SEED).integer('seed must be an integer.').min(1),
  [WEIGHT_DECAY]: segmentationNumber(WEIGHT_DECAY).min(0).max(0.01),
  [WARMUP_STEPS]: segmentationNumber(WARMUP_STEPS).integer('warmup_steps must be an integer.').min(100).max(2000),
  [SCHD_POWER]: yup.number().default(1),
  [REDUCE_ZERO_LABEL]: yup.bool().default(false),
});

/**
 * Picks the default hyperparameter values for the form of `task`.
 *
 * Image classification maps to `ClassifiHyperDefaultValues` and semantic
 * segmentation maps to `SegmentationHyperDefaultValues`. Every other task
 * falls through to `DetectionHyperDefaultValues`.
 *
 * @param task - the task whose defaults are needed
 * @returns the shared module-level defaults object (not a copy), so callers
 *   should not mutate it
 */
export const HyperparameterDefaultValues = (task: Tasks): Record<string, number | boolean> => {
  if (isImageClassification(task)) return ClassifiHyperDefaultValues;

  // Only consulted when the task is not classification, matching the original check order.
  return isSemanticSegmentation(task) ? SegmentationHyperDefaultValues : DetectionHyperDefaultValues;
};

/**
 * Default image-classification hyperparameters.
 *
 * The keys are the same constants used by `ClassifiHyperScheme`, and every
 * value lies within that schema's bounds. Key order is kept as-is in case a
 * consumer iterates over the object.
 */
export const ClassifiHyperDefaultValues = {
  [LR0]: 1e-2,
  [LRF]: 1e-2,
  [SEED]: 1,
  [MOMENTUM]: 0.937,
  [WEIGHT_DECAY]: 5e-4,
  [WARMUP_STEPS]: 100,
  [WARMUP_EPOCHS]: 5,
  [WARMUP_MOMENTUM]: 0.8,
  [WARMUP_BIAS_LR]: 1e-4,
  [SCHD_POWER]: 0.1,
};

/**
 * Default object-detection hyperparameters.
 *
 * The keys are the same constants used by `DetectionHyperScheme`, and every
 * value lies within that schema's bounds. The `EPOCHS_PER_TRIAL` default (200)
 * is in the schema's allowed list and exceeds the `WARMUP_EPOCHS` default (3).
 * Key order is kept as-is in case a consumer iterates over the object.
 */
export const DetectionHyperDefaultValues = {
  [LR0]: 1e-2,
  [LRF]: 1e-2,
  [MOMENTUM]: 0.937,
  [WEIGHT_DECAY]: 5e-4,
  [WARMUP_EPOCHS]: 3,
  [WARMUP_MOMENTUM]: 0.8,
  [WARMUP_BIAS_LR]: 0.1,
  [BOX]: 0.05,
  [CLS]: 0.5,
  [CLS_PW]: 1,
  [OBJ]: 1,
  [OBJ_PW]: 1,
  [IOU_T]: 0.2,
  [ANCHOR_T]: 4,
  [FL_GAMMA]: 0,
  [EPOCHS_PER_TRIAL]: 200,
};

/**
 * Default semantic-segmentation hyperparameters.
 *
 * The `LR0`, `LRF`, `SCHD_POWER` and `REDUCE_ZERO_LABEL` values match the
 * `.default(...)` values declared in `SegmentationHyperScheme`. The bounded
 * fields (`SEED`, `WEIGHT_DECAY`, `WARMUP_STEPS`) sit at or inside that
 * schema's limits. Key order is kept as-is in case a consumer iterates
 * over the object.
 */
export const SegmentationHyperDefaultValues = {
  [LR0]: 6e-5,
  [LRF]: 0,
  [SEED]: 1,
  [WEIGHT_DECAY]: 1e-2,
  [WARMUP_STEPS]: 100,
  [SCHD_POWER]: 1,
  [REDUCE_ZERO_LABEL]: false,
};
