Commit 96a6087f authored by Francois Marelli's avatar Francois Marelli


parent eebe3e0c
import torch
from torch.nn import MSELoss
class LogMSELoss(MSELoss):
r"""Creates a criterion that measures the logarithmic mean squared error between
`n` elements in the input `x` and target `y`:
:math:`{loss}(x, y) = \log( 1/n \sum |x_i - y_i|^2 + epsilon)`
`x` and `y` arbitrary shapes with a total of `n` elements each.
The sum operation still operates over all the elements, and divides by `n`.
The division by `n` can be avoided if one sets the internal variable
`size_average` to ``False``.
To get a batch of losses, a loss per batch element, set `reduce` to
``False``. These losses are not averaged and are not affected by
The epsilon is a positive float used to avoid log(0) leading to NaN.
size_average (bool, optional): By default, the losses are averaged
over observations for each minibatch. However, if the field
size_average is set to ``False``, the losses are instead summed for
each minibatch. Only applies when reduce is ``True``. Default: ``True``
reduce (bool, optional): By default, the losses are averaged
over observations for each minibatch, or summed, depending on
size_average. When reduce is ``False``, returns a loss per batch
element instead and ignores size_average. Default: ``True``
epsilon (float, optional): add a small positive term to the MSE before
taking the log to avoid NaN with log(0). Default: ``0.05``
- Input: :math:`(N, *)` where `*` means, any number of additional
- Target: :math:`(N, *)`, same shape as the input
>>> loss = neural_filters.LogMSELoss()
>>> input = autograd.Variable(torch.randn(3, 5), requires_grad=True)
>>> target = autograd.Variable(torch.randn(3, 5))
>>> output = loss(input, target)
>>> output.backward()
def __init__(self, size_average=True, reduce=True, epsilon=0.05):
super(LogMSELoss, self).__init__(size_average, reduce)
self.epsilon = epsilon
def forward(self, input, target):
loss = super(LogMSELoss, self).forward(input, target)
return torch.log(loss + self.epsilon)
\ No newline at end of file
......@@ -3,3 +3,4 @@ from .NeuralFilter1P import *
from .NeuralFilter2R import *
from .NeuralFilter2CC import *
from .NeuralFilter2CD import *
from .LogMSELoss import *
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment