#!/usr/bin/env python
# coding=utf-8

import os

import click
import torch
from torch.utils.data import DataLoader

from bob.extension.scripts.click_helper import (
    verbosity_option,
    ConfigCommand,
    ResourceOption,
)

from ..utils.checkpointer import DetectronCheckpointer

import logging
logger = logging.getLogger(__name__)


@click.command(
    entry_point_group="bob.ip.binseg.config",
    cls=ConfigCommand,
    epilog="""Examples:

\b
    1. Trains a U-Net model (VGG-16 backbone) with DRIVE (vessel segmentation),
       on a GPU (``cuda:0``):

       $ bob binseg train -vv unet drive --batch-size=4 --device="cuda:0"

    2. Trains a HED model with HRF on a GPU (``cuda:0``):

       $ bob binseg train -vv hed hrf --batch-size=8 --device="cuda:0"

    3. Trains a M2U-Net model on the COVD-DRIVE dataset on the CPU:

       $ bob binseg train -vv m2unet covd-drive --batch-size=8

    4. Trains a DRIU model with SSL on the COVD-HRF dataset on the CPU:

       $ bob binseg train -vv --ssl driu-ssl covd-drive-ssl --batch-size=1

""",
)
@click.option(
    "--output-folder",
    "-o",
    help="Path where to store the generated model (created if does not exist)",
    required=True,
    type=click.Path(),
    default="results",
    cls=ResourceOption,
)
@click.option(
    "--model",
    "-m",
    help="A torch.nn.Module instance implementing the network to be trained",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--dataset",
    "-d",
    help="A torch.utils.data.dataset.Dataset instance implementing a dataset "
    "to be used for training the model, possibly including all pre-processing "
    "pipelines required or, optionally, a dictionary mapping string keys to "
    "torch.utils.data.dataset.Dataset instances.  At least one key "
    "named ``train`` must be available.  This dataset will be used for "
    "training the network model.  The dataset description must include all "
    "required pre-processing, including eventual data augmentation.  If a "
    "dataset named ``__train__`` is available, it is used prioritarily for "
    "training instead of ``train``.  If a dataset named ``__valid__`` is "
    "available, it is used for model validation (and automatic check-pointing) "
    "at each epoch.",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--optimizer",
    help="A torch.optim.Optimizer that will be used to train the network",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--criterion",
    help="A loss function to compute the FCN error for every sample "
    "respecting the PyTorch API for loss functions (see torch.nn.modules.loss)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--scheduler",
    help="A learning rate scheduler that drives changes in the learning "
    "rate depending on the FCN state (see torch.optim.lr_scheduler)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--pretrained-backbone",
    "-t",
    help="URL of a pre-trained model file that will be used to preset "
    "FCN weights (where relevant) before training starts "
    "(e.g. vgg16, mobilenetv2)",
    required=True,
    cls=ResourceOption,
)
@click.option(
    "--batch-size",
    "-b",
    help="Number of samples in every batch (this parameter affects "
    "memory requirements for the network).  If the number of samples in "
    "the batch is larger than the total number of samples available for "
    "training, this value is truncated.  If this number is smaller, then "
    "batches of the specified size are created and fed to the network "
    "until there are no more new samples to feed (epoch is finished).  "
    "If the total number of training samples is not a multiple of the "
    "batch-size, the last batch will be smaller than the first, unless "
    "--drop-incomplete--batch is set, in which case this batch is not used.",
    required=True,
    show_default=True,
    default=2,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--drop-incomplete-batch/--no-drop-incomplete-batch",
    "-D",
    help="If set, then may drop the last batch in an epoch, in case it is "
    "incomplete.  If you set this option, you should also consider "
    "increasing the total number of epochs of training, as the total number "
    "of training steps may be reduced",
    required=True,
    show_default=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--epochs",
    "-e",
    help="Number of epochs (complete training set passes) to train for",
    show_default=True,
    required=True,
    default=1000,
    type=click.IntRange(min=1),
    cls=ResourceOption,
)
@click.option(
    "--checkpoint-period",
    "-p",
    help="Number of epochs after which a checkpoint is saved. "
    "A value of zero will disable check-pointing. If checkpointing is "
    "enabled and training stops, it is automatically resumed from the "
    "last saved checkpoint if training is restarted with the same "
    "configuration.",
    show_default=True,
    required=True,
    default=0,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--device",
    "-d",
    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
    show_default=True,
    required=True,
    default="cpu",
    cls=ResourceOption,
)
@click.option(
    "--seed",
    "-s",
    help="Seed to use for the random number generator",
    show_default=True,
    required=False,
    default=42,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@click.option(
    "--ssl/--no-ssl",
    help="Switch ON/OFF semi-supervised training mode",
    show_default=True,
    required=True,
    default=False,
    cls=ResourceOption,
)
@click.option(
    "--rampup",
    "-r",
    help="Ramp-up length in epochs (for SSL training only)",
    show_default=True,
    required=True,
    default=900,
    type=click.IntRange(min=0),
    cls=ResourceOption,
)
@verbosity_option(cls=ResourceOption)
def train(
    model,
    optimizer,
    scheduler,
    output_folder,
    epochs,
    pretrained_backbone,
    batch_size,
    drop_incomplete_batch,
    criterion,
    dataset,
    checkpoint_period,
    device,
    seed,
    ssl,
    rampup,
    verbose,
    **kwargs,
):
    """Trains an FCN to perform binary segmentation

    Training is performed for a configurable number of epochs, and generates at
    least a final_model.pth.  It may also generate a number of intermediate
    checkpoints.  Checkpoints are model files (.pth files) that are stored
    during the training and useful to resume the procedure in case it stops
    abruptly.
    """

    torch.manual_seed(seed)

    use_dataset = dataset
    validation_dataset = None
    if isinstance(dataset, dict):
        if "__train__" in dataset:
            logger.info("Found (dedicated) '__train__' set for training")
            use_dataset = dataset["__train__"]
        else:
            use_dataset = dataset["train"]

        if "__valid__" in dataset:
            logger.info("Found (dedicated) '__valid__' set for validation")
            logger.info("Will checkpoint lowest loss model on validation set")
            validation_dataset = dataset["__valid__"]

    # PyTorch dataloader
    data_loader = DataLoader(
        dataset=use_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_incomplete_batch,
        pin_memory=torch.cuda.is_available(),
    )

    valid_loader = None
    if validation_dataset is not None:
        valid_loader = DataLoader(
                dataset=validation_dataset,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                pin_memory=torch.cuda.is_available(),
                )

    # Checkpointer
    checkpointer = DetectronCheckpointer(
        model, optimizer, scheduler, save_dir=output_folder, save_to_disk=True
    )

    arguments = {}
    arguments["epoch"] = 0
    extra_checkpoint_data = checkpointer.load(pretrained_backbone)
    arguments.update(extra_checkpoint_data)
    arguments["max_epoch"] = epochs

    logger.info("Training for {} epochs".format(arguments["max_epoch"]))
    logger.info("Continuing from epoch {}".format(arguments["epoch"]))

    if not ssl:
        from ..engine.trainer import run
        run(
            model,
            data_loader,
            valid_loader,
            optimizer,
            criterion,
            scheduler,
            checkpointer,
            checkpoint_period,
            device,
            arguments,
            output_folder,
        )

    else:
        from ..engine.ssltrainer import run
        run(
            model,
            data_loader,
            valid_loader,
            optimizer,
            criterion,
            scheduler,
            checkpointer,
            checkpoint_period,
            device,
            arguments,
            output_folder,
            rampup,
        )