#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import logging
import time
import datetime
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np

from bob.ip.binseg.utils.metric import SmoothedValue
from bob.ip.binseg.utils.plot import loss_curve

def sharpen(x, T):
    """Sharpens probabilities with temperature ``T``, as described in [MIXMATCH_19]"""
    temp = x ** (1 / T)
    return temp / temp.sum(dim=1, keepdim=True)
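
# Quick check (kept as a comment so importing this module has no side
# effects); ``probs`` is a hypothetical, already-normalized input. With
# T < 1 the distribution is pushed towards one-hot:
#
#   probs = torch.tensor([[0.6, 0.4]])
#   sharpen(probs, T=0.5)  # -> tensor([[0.6923, 0.3077]])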

def mix_up(alpha, input, target, unlabeled_input, unlabeled_target):
    """Applies mix up as described in [MIXMATCH_19].

    Parameters
    ----------
    alpha : float
        parameter of the Beta distribution the mixing coefficient is
        sampled from

    input : :py:class:`torch.Tensor`
        labeled images

    target : :py:class:`torch.Tensor`
        ground truths for ``input``

    unlabeled_input : :py:class:`torch.Tensor`
        unlabeled images

    unlabeled_target : :py:class:`torch.Tensor`
        guessed labels for ``unlabeled_input``


    Returns
    -------

    tuple
        mixed-up versions of ``input``, ``target``, ``unlabeled_input``
        and ``unlabeled_target``, in this order

    """
    with torch.no_grad():
        lam = np.random.beta(alpha, alpha)  # Eq (8)
        lam = max(lam, 1 - lam)  # Eq (9)
        # Shuffle and concatenate. Alg. 1 Line: 12
        w_inputs = torch.cat([input, unlabeled_input], 0)
        w_targets = torch.cat([target, unlabeled_target], 0)
        idx = torch.randperm(w_inputs.size(0))  # random permutation of W

        # Apply MixUp to labeled data and entries from W. Alg. 1 Line: 13
        input_mixedup = lam * input + (1 - lam) * w_inputs[idx[len(input):]]
        target_mixedup = lam * target + (1 - lam) * w_targets[idx[len(target):]]

        # Apply MixUp to unlabeled data and entries from W. Alg. 1 Line: 14
        unlabeled_input_mixedup = lam * unlabeled_input + (1 - lam) * w_inputs[idx[:len(unlabeled_input)]]
        unlabeled_target_mixedup = lam * unlabeled_target + (1 - lam) * w_targets[idx[:len(unlabeled_target)]]
        return input_mixedup, target_mixedup, unlabeled_input_mixedup, unlabeled_target_mixedup
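
# Illustrative sketch (kept as a comment; shapes are hypothetical): with 4
# labeled and 4 unlabeled samples, W holds 8 shuffled entries and every
# output keeps the shape of the corresponding input:
#
#   x, y = torch.rand(4, 3, 64, 64), torch.rand(4, 1, 64, 64)
#   u = torch.rand(4, 3, 64, 64)
#   q = torch.rand(4, 1, 64, 64)  # e.g. from guess_labels()
#   x2, y2, u2, q2 = mix_up(0.75, x, y, u, q)  # alpha=0.75 as in the
#                                              # commented-out call below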


def square_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u``

    Parameters
    ----------

    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor : float
        ramp up factor
    """

    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip((current / float(rampup_length)) ** 2, 0.0, 1.0)
    return float(current)
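
# Worked example: with the default rampup_length=16 the factor grows
# quadratically, e.g. square_rampup(4) == (4/16)**2 == 0.0625 and
# square_rampup(8) == 0.25; from epoch 16 onwards it saturates at 1.0.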

def linear_rampup(current, rampup_length=16):
    """slowly ramp-up ``lambda_u``

    Parameters
    ----------
    current : int
        current epoch

    rampup_length : :obj:`int`, optional
        how long to ramp up, by default 16

    Returns
    -------

    factor: float
        ramp up factor

    """
    if rampup_length == 0:
        return 1.0
    else:
        current = np.clip(current / rampup_length, 0.0, 1.0)
    return float(current)
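
# Worked example: the linear variant reaches the same endpoints but rises
# faster early on, e.g. linear_rampup(4) == 4/16 == 0.25 versus
# square_rampup(4) == 0.0625; both saturate at 1.0 from epoch 16 onwards.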

def guess_labels(unlabeled_images, model):
    """
    Calculates the average prediction over the original images and two
    augmentations (vertical and horizontal flips), as in [MIXMATCH_19].

    Parameters
    ----------

    unlabeled_images : :py:class:`torch.Tensor`
        ``[n,c,h,w]``

    model : :py:class:`torch.nn.Module`
        network used to predict the labels

    Returns
    -------

    avg_guess : :py:class:`torch.Tensor`
        averaged predictions, ``[n,c,h,w]``

    """
    with torch.no_grad():
        guess1 = torch.sigmoid(model(unlabeled_images)).unsqueeze(0)
        # Flip along the height axis (vertical flip), predict, then undo the
        # flip on the prediction (dimensions shifted by 1 by the unsqueeze)
        vflip = torch.sigmoid(model(unlabeled_images.flip(2))).unsqueeze(0)
        guess2 = vflip.flip(3)
        # Flip along the width axis (horizontal flip), predict, then undo the
        # flip on the prediction
        hflip = torch.sigmoid(model(unlabeled_images.flip(3))).unsqueeze(0)
        guess3 = hflip.flip(4)
        # Average the original and the two flip-consistent guesses
        concat = torch.cat([guess1, guess2, guess3], 0)
        avg_guess = torch.mean(concat, 0)
        return avg_guess
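
# Illustrative sketch (kept as a comment; ``net`` is any hypothetical
# segmentation model emitting [n,1,h,w] logits):
#
#   u = torch.rand(2, 3, 64, 64)
#   q = guess_labels(u, net)  # q.shape == (2, 1, 64, 64), values in [0, 1]
#
# Averaging over flipped views makes the pseudo-labels approximately
# flip-invariant before they are optionally sharpened and mixed up.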

def do_ssltrain(
    model,
    data_loader,
    optimizer,
    criterion,
    scheduler,
    checkpointer,
    checkpoint_period,
    device,
    arguments,
    output_folder,
    rampup_length
):
    """
    Train model and save to disk.

    Parameters
    ----------

    model : :py:class:`torch.nn.Module`
        Network (e.g. DRIU, HED, UNet)

    data_loader : :py:class:`torch.utils.data.DataLoader`

    optimizer : :py:mod:`torch.optim`

    criterion : :py:class:`torch.nn.modules.loss._Loss`
        loss function

    scheduler : :py:mod:`torch.optim`
        learning rate scheduler

    checkpointer : :py:class:`bob.ip.binseg.utils.checkpointer.DetectronCheckpointer`
        checkpointer

    checkpoint_period : int
        save a checkpoint every n epochs

    device : str
        device to use ``'cpu'`` or ``'cuda'``

    arguments : dict
        start and end epochs

    output_folder : str
        output path

    rampup_length : int
        rampup epochs

    """
    logger = logging.getLogger("bob.ip.binseg.engine.trainer")
    logger.info("Start training")
    start_epoch = arguments["epoch"]
    max_epoch = arguments["max_epoch"]

    # Log training progress to a line-buffered CSV file
    with open(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)), "a+", 1) as outfile:
        # Move any optimizer state tensors (e.g. momentum buffers restored
        # from a checkpoint) to the target device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        model.train().to(device)
        # Total training timer
        start_training_time = time.time()
        for epoch in range(start_epoch + 1, max_epoch + 1):
            losses = SmoothedValue(len(data_loader))
            labeled_loss = SmoothedValue(len(data_loader))
            unlabeled_loss = SmoothedValue(len(data_loader))
            arguments["epoch"] = epoch

            # Epoch time
            start_epoch_time = time.time()

            for samples in tqdm(data_loader):
                # labeled batch
                images = samples[1].to(device)
                ground_truths = samples[2].to(device)
                # unlabeled batch
                unlabeled_images = samples[4].to(device)
                # labeled outputs
                outputs = model(images)
                unlabeled_outputs = model(unlabeled_images)
                # guessed labels for the unlabeled batch (flip-averaged)
                unlabeled_ground_truths = guess_labels(unlabeled_images, model)
                #unlabeled_ground_truths = sharpen(unlabeled_ground_truths,0.5)
                #images, ground_truths, unlabeled_images, unlabeled_ground_truths = mix_up(0.75, images, ground_truths, unlabeled_images, unlabeled_ground_truths)
                ramp_up_factor = square_rampup(epoch, rampup_length=rampup_length)

                # the criterion returns the total loss together with its
                # labeled and unlabeled components; the ramp-up factor scales
                # the unlabeled term
                loss, ll, ul = criterion(outputs, ground_truths, unlabeled_outputs, unlabeled_ground_truths, ramp_up_factor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                losses.update(loss)
                labeled_loss.update(ll)
                unlabeled_loss.update(ul)
                logger.debug("batch loss: {}".format(loss.item()))

            # Step the learning-rate scheduler once per epoch, after the
            # optimizer updates (the ordering recommended since PyTorch 1.1)
            scheduler.step()

            if epoch % checkpoint_period == 0:
                checkpointer.save("model_{:03d}".format(epoch), **arguments)

            if epoch == max_epoch:
                checkpointer.save("model_final", **arguments)

            epoch_time = time.time() - start_epoch_time


            eta_seconds = epoch_time * (max_epoch - epoch)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            outfile.write(("{epoch}, "
                        "{avg_loss:.6f}, "
                        "{median_loss:.6f}, "
                        "{median_labeled_loss:.6f}, "
                        "{median_unlabeled_loss:.6f}, "
                        "{lr:.6f}, "
                        "{memory:.0f}"
                        "\n"
                        ).format(
                    epoch=epoch,
                    avg_loss=losses.avg,
                    median_loss=losses.median,
                    median_labeled_loss=labeled_loss.median,
                    median_unlabeled_loss=unlabeled_loss.median,
                    lr=optimizer.param_groups[0]["lr"],
                    memory=(torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else 0.0,
                    )
                )
            logger.info(("eta: {eta}, "
                        "epoch: {epoch}, "
                        "avg. loss: {avg_loss:.6f}, "
                        "median loss: {median_loss:.6f}, "
                        "labeled loss: {median_labeled_loss}, "
                        "unlabeled loss: {median_unlabeled_loss}, "
                        "lr: {lr:.6f}, "
                        "max mem: {memory:.0f}"
                        ).format(
                    eta=eta_string,
                    epoch=epoch,
                    avg_loss=losses.avg,
                    median_loss=losses.median,
                    median_labeled_loss = labeled_loss.median,
                    median_unlabeled_loss = unlabeled_loss.median,
                    lr=optimizer.param_groups[0]["lr"],
                    memory = (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) if torch.cuda.is_available() else .0
                    )
                )


        total_training_time = time.time() - start_training_time
        total_time_str = str(datetime.timedelta(seconds=total_training_time))
        logger.info(
            "Total training time: {} ({:.4f} s / epoch)".format(
                total_time_str, total_training_time / (max_epoch)
            ))

    log_plot_file = os.path.join(output_folder, "{}_trainlog.pdf".format(model.name))
    # the first CSV column (the epoch) becomes the index, since fewer names
    # than columns are given; the remaining six columns are named explicitly
    logdf = pd.read_csv(os.path.join(output_folder, "{}_trainlog.csv".format(model.name)),
                        header=None,
                        names=["avg. loss", "median loss", "labeled loss", "unlabeled loss", "lr", "max memory"])
    fig = loss_curve(logdf,output_folder)
    logger.info("saving {}".format(log_plot_file))
    fig.savefig(log_plot_file)
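

# ---------------------------------------------------------------------------
# Usage sketch (kept as a comment so importing this module has no side
# effects). The model, data loader, scheduler and checkpointer are set up by
# the command-line interface elsewhere; the only hard contract assumed here
# is the one used above: the criterion must return the total loss together
# with its labeled and unlabeled components. A minimal, hypothetical example:
#
#   class SSLLoss(torch.nn.Module):
#       def __init__(self, lambda_u=1.0):
#           super().__init__()
#           self.lambda_u = lambda_u
#           self.bce = torch.nn.BCEWithLogitsLoss()
#           self.mse = torch.nn.MSELoss()
#
#       def forward(self, out, gt, u_out, u_gt, ramp):
#           ll = self.bce(out, gt)                     # supervised term
#           ul = self.mse(torch.sigmoid(u_out), u_gt)  # consistency term
#           return ll + self.lambda_u * ramp * ul, ll, ul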