"""Loss implementations"""

import torch
from torch.nn.modules.loss import _Loss

# Conditionally decorates a method if the decorator exists in PyTorch.
# This avoids an import error with PyTorch >= 1.2, where the
# decorator ``weak_script_method`` is no longer available.  See:
# https://github.com/pytorch/pytorch/commit/10c4b98ade8349d841518d22f19a653a939e260c#diff-ee07db084d958260fd24b4b02d4f078d
# from July 4th, 2019.
try:
    from torch._jit_internal import weak_script_method
except ImportError:
    # Recent PyTorch releases dropped this decorator; substitute a
    # pass-through so the ``@weak_script_method`` usages in this module
    # keep working unchanged.
    def weak_script_method(method):
        """Return ``method`` unchanged (stand-in for the removed decorator)."""
        return method


class WeightedBCELogitsLoss(_Loss):
    """
    Implements Equation 1 in [MANINIS-2016]_. Based on
    :py:class:`torch.nn.BCEWithLogitsLoss`.

    Calculates the class-balanced binary cross-entropy loss.  Each pixel is
    weighted by the fraction of pixels of the *opposite* class in its sample:
    pixels with ``target <= 0.5`` get ``num_pos / (num_pos + num_neg)`` and
    all other pixels get ``num_neg / (num_pos + num_neg)``.  The weighted
    loss is reduced according to ``reduction``.

    Parameters
    ----------
    weight : :py:class:`torch.Tensor`, optional
        Registered as a buffer, but **not** used by :py:meth:`forward`.
    size_average, reduce, reduction
        Forwarded to :py:class:`torch.nn.modules.loss._Loss`; the resulting
        ``reduction`` is passed to the functional BCE call.
    pos_weight : :py:class:`torch.Tensor`, optional
        Registered as a buffer, but **not** used by :py:meth:`forward`.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super(WeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        # Kept as buffers for state/interface compatibility; forward() does
        # not read them.
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)

    @staticmethod
    def _balancing_weight(target, masks):
        """Return per-pixel class-balancing weights shaped like ``target``."""
        n, c, h, w = target.shape
        numel = c * h * w
        num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)  # (n, 1)
        if hasattr(masks, "dtype"):
            # Pixels outside the mask are removed from the negative count
            # (assumes binary masks -- TODO confirm).
            in_mask = torch.sum(masks, dim=[1, 2, 3]).float().reshape(n, 1)
            num_mask_neg = numel - in_mask
            num_neg = numel - num_pos - num_mask_neg
        else:
            num_neg = numel - num_pos
        total = num_pos + num_neg
        # An all-zero mask makes total == 0, and 0/0 = NaN weights would turn
        # the loss (and gradients) of the whole batch into NaN.  Divide by 1
        # instead so such a sample gets zero weights.
        total = torch.where(total != 0, total, torch.ones_like(total))
        pos_frac = (num_pos / total).reshape(n, 1, 1, 1)
        neg_frac = (num_neg / total).reshape(n, 1, 1, 1)
        ones = torch.ones_like(target)
        # Negatives are weighted by the positive fraction and vice versa.
        return torch.where(target <= 0.5, ones * pos_frac, ones * neg_frac)

    @weak_script_method
    def forward(self, input, target, masks=None):
        """
        Parameters
        ----------
        input : :py:class:`torch.Tensor`
            Logits; must have the same size as ``target``.
        target : :py:class:`torch.Tensor`
            Ground truth, unpacked as ``(n, c, h, w)``.  ``num_pos`` is the
            plain per-sample sum of ``target``, so binary labels are
            presumably expected -- TODO confirm for soft labels.
        masks : :py:class:`torch.Tensor`, optional
            Region-of-interest masks, summed over dims ``1, 2, 3``.  They only
            change the class counts; out-of-mask pixels still contribute to
            the loss.

        Returns
        -------
        :py:class:`torch.Tensor`
            The weighted loss, reduced according to ``self.reduction``.
        """
        weight = self._balancing_weight(target, masks)
        return torch.nn.functional.binary_cross_entropy_with_logits(
            input, target, weight=weight, reduction=self.reduction
        )


class SoftJaccardBCELogitsLoss(_Loss):
    """
    Implements Equation 3 in [IGLOVIKOV-2018]_.  Based on
    ``torch.nn.BCEWithLogitsLoss``.

    The loss is ``alpha * BCE + (1 - alpha) * (1 - J)``, where ``J`` is the
    soft Jaccard index between ``sigmoid(input)`` and ``target``, summed over
    every element of the batch.

    Attributes
    ----------
    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.7``
    """

    def __init__(
        self,
        alpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        # ``pos_weight`` is accepted to mirror the other losses in this module
        # but is not used.
        super().__init__(size_average, reduce, reduction)
        self.alpha = alpha

    @weak_script_method
    def forward(self, input, target, masks=None):
        """
        Parameters
        ----------
        input : :py:class:`torch.Tensor`
            Logits.
        target : :py:class:`torch.Tensor`
            Ground truth, same size as ``input``.
        masks : :py:class:`torch.Tensor`, optional
            Accepted for interface compatibility; not used.

        Returns
        -------
        :py:class:`torch.Tensor`
        """
        probs = torch.sigmoid(input)
        overlap = (probs * target).sum()
        union = probs.sum() + target.sum() - overlap
        # Small constant keeps the ratio finite when both tensors are all-zero.
        jaccard = overlap / (union + 1e-8)

        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            input, target, weight=None, reduction=self.reduction
        )
        return self.alpha * bce + (1 - self.alpha) * (1 - jaccard)


class HEDWeightedBCELogitsLoss(_Loss):
    """
    Implements Equation 2 in [HE-2015]_. Based on
    ``torch.nn.modules.loss.BCEWithLogitsLoss``.

    Calculates the class-balanced binary cross-entropy loss of every HED
    side-output against the same target and averages the results.  Pixels
    with ``target <= 0.5`` are weighted by ``num_pos / (num_pos + num_neg)``
    and all other pixels by ``num_neg / (num_pos + num_neg)``, with counts
    taken per sample.

    Parameters
    ----------
    weight : :py:class:`torch.Tensor`, optional
        Registered as a buffer, but **not** used by :py:meth:`forward`.
    size_average, reduce, reduction
        Forwarded to :py:class:`torch.nn.modules.loss._Loss`; the resulting
        ``reduction`` is used for each side-output's BCE.
    pos_weight : :py:class:`torch.Tensor`, optional
        Registered as a buffer, but **not** used by :py:meth:`forward`.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super(HEDWeightedBCELogitsLoss, self).__init__(size_average, reduce, reduction)
        # Kept as buffers for state/interface compatibility; forward() does
        # not read them.
        self.register_buffer("weight", weight)
        self.register_buffer("pos_weight", pos_weight)

    @staticmethod
    def _balancing_weight(target, masks):
        """Return per-pixel class-balancing weights shaped like ``target``."""
        n, c, h, w = target.shape
        numel = c * h * w
        num_pos = torch.sum(target, dim=[1, 2, 3]).float().reshape(n, 1)  # (n, 1)
        if hasattr(masks, "dtype"):
            # Pixels outside the mask are removed from the negative count
            # (assumes binary masks -- TODO confirm).
            in_mask = torch.sum(masks, dim=[1, 2, 3]).float().reshape(n, 1)
            num_mask_neg = numel - in_mask
            num_neg = numel - num_pos - num_mask_neg
        else:
            num_neg = numel - num_pos
        total = num_pos + num_neg
        # An all-zero mask makes total == 0, and 0/0 = NaN weights would turn
        # the loss (and gradients) of the whole batch into NaN.  Divide by 1
        # instead so such a sample gets zero weights.
        total = torch.where(total != 0, total, torch.ones_like(total))
        pos_frac = (num_pos / total).reshape(n, 1, 1, 1)
        neg_frac = (num_neg / total).reshape(n, 1, 1, 1)
        ones = torch.ones_like(target)
        # Negatives are weighted by the positive fraction and vice versa.
        return torch.where(target <= 0.5, ones * pos_frac, ones * neg_frac)

    @weak_script_method
    def forward(self, inputlist, target, masks=None):
        """
        Parameters
        ----------
        inputlist : list of :py:class:`torch.Tensor`
            HED uses multiple side-output feature maps for the loss calculation;
            each must have the same size as ``target``.
        target : :py:class:`torch.Tensor`
            Ground truth, unpacked as ``(n, c, h, w)``.
        masks : :py:class:`torch.Tensor`, optional
            Region-of-interest masks; they only change the class counts.

        Returns
        -------
        :py:class:`torch.Tensor`
            Mean over side-outputs of each side-output's weighted loss.
        """
        # The weights depend only on target/masks, so compute them once and
        # share them across all side-outputs.
        weight = self._balancing_weight(target, masks)
        losses = [
            torch.nn.functional.binary_cross_entropy_with_logits(
                side_output, target, weight=weight, reduction=self.reduction
            )
            for side_output in inputlist
        ]
        return torch.stack(losses).mean()


class HEDSoftJaccardBCELogitsLoss(_Loss):
    """
    Implements  Equation 3 in [IGLOVIKOV-2018]_ for the hed network. Based on
    :py:class:`torch.nn.BCEWithLogitsLoss`.

    For every side-output the loss ``alpha * BCE + (1 - alpha) * (1 - J)`` is
    computed, where ``J`` is the soft Jaccard index over the whole batch; the
    per-side-output losses are then averaged.

    Attributes
    ----------
    alpha : float
        determines the weighting of SoftJaccard and BCE. Default: ``0.3``
    """

    def __init__(
        self,
        alpha=0.3,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        # ``pos_weight`` is accepted to mirror the other losses but not used.
        super(HEDSoftJaccardBCELogitsLoss, self).__init__(
            size_average, reduce, reduction
        )
        self.alpha = alpha

    @weak_script_method
    def forward(self, inputlist, target, masks=None):
        """
        Parameters
        ----------
        inputlist : list of :py:class:`torch.Tensor`
            HED side-output logits; each must have the same size as ``target``.
        target : :py:class:`torch.Tensor`
        masks : :py:class:`torch.Tensor`, optional
            Accepted for interface compatibility; not used.

        Returns
        -------
        :py:class:`torch.Tensor`
            Mean over side-outputs of each side-output's combined loss.
        """
        eps = 1e-8
        # Loop-invariant: the target does not change between side-outputs.
        target_sum = target.sum()
        losses = []
        for side_output in inputlist:
            probabilities = torch.sigmoid(side_output)
            intersection = (probabilities * target).sum()
            sums = probabilities.sum() + target_sum

            softjaccard = intersection / (sums - intersection + eps)

            bceloss = torch.nn.functional.binary_cross_entropy_with_logits(
                side_output, target, weight=None, reduction=self.reduction
            )
            losses.append(self.alpha * bceloss + (1 - self.alpha) * (1 - softjaccard))
        return torch.stack(losses).mean()


class MixJacLoss(_Loss):
    """
    Semi-supervised loss combining a labeled and an unlabeled term.

    The labeled term is :py:class:`SoftJaccardBCELogitsLoss` (built with
    ``alpha=jacalpha``) and the unlabeled term is
    :py:class:`torch.nn.BCEWithLogitsLoss`.  The total loss is
    ``labeled + lambda_u * ramp_up_factor * unlabeled``.

    Parameters
    ----------

    lambda_u : int or float
        weight of the unlabeled-data loss relative to the labeled-data loss.
        Default: ``100``
    jacalpha : float
        determines the weighting of SoftJaccard and BCE inside the labeled
        loss (passed as ``alpha`` to :py:class:`SoftJaccardBCELogitsLoss`).
        Default: ``0.7``
    size_average, reduce, reduction
        Forwarded to :py:class:`torch.nn.modules.loss._Loss`.  They are not
        passed on to the two inner losses, which use their own defaults.
    pos_weight
        Accepted for interface compatibility; not used.

    """

    def __init__(
        self,
        lambda_u=100,
        jacalpha=0.7,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        super(MixJacLoss, self).__init__(size_average, reduce, reduction)
        self.lambda_u = lambda_u
        # Attribute names are submodule (state_dict) keys -- keep them stable.
        self.labeled_loss = SoftJaccardBCELogitsLoss(alpha=jacalpha)
        self.unlabeled_loss = torch.nn.BCEWithLogitsLoss()

    @weak_script_method
    def forward(self, input, target, unlabeled_input, unlabeled_traget, ramp_up_factor):
        """
        Parameters
        ----------

        input : :py:class:`torch.Tensor`
            Logits for the labeled samples.
        target : :py:class:`torch.Tensor`
            Ground truth for the labeled samples.
        unlabeled_input : :py:class:`torch.Tensor`
            Logits for the unlabeled samples.
        unlabeled_traget : :py:class:`torch.Tensor`
            Target for the unlabeled samples (name misspelled, kept for
            backward compatibility with keyword callers).
        ramp_up_factor : float
            Extra multiplier on the unlabeled term (together with
            ``lambda_u``).

        Returns
        -------

        tuple
            ``(loss, ll, ul)``: the total loss, the labeled loss and the
            unlabeled loss.

        """
        ll = self.labeled_loss(input, target)
        ul = self.unlabeled_loss(unlabeled_input, unlabeled_traget)

        loss = ll + self.lambda_u * ramp_up_factor * ul
        return loss, ll, ul