diff --git a/bob/learn/tensorflow/loss/pixel_wise.py b/bob/learn/tensorflow/loss/pixel_wise.py
new file mode 100644
index 0000000000000000000000000000000000000000..b34695045c20273bdc0063a928ceb723324eca6d
--- /dev/null
+++ b/bob/learn/tensorflow/loss/pixel_wise.py
@@ -0,0 +1,63 @@
+from ..dataset import tf_repeat
+from .utils import (
+    balanced_softmax_cross_entropy_loss_weights,
+    balanced_sigmoid_cross_entropy_loss_weights,
+)
+import tensorflow as tf
+
+
+class PixelWise:
+    """A pixel wise loss which is just a cross entropy loss but applied to all pixels"""
+
+    def __init__(
+        self, balance_weights=True, n_one_hot_labels=None, label_smoothing=0.5, **kwargs
+    ):
+        super(PixelWise, self).__init__(**kwargs)
+        self.balance_weights = balance_weights
+        self.n_one_hot_labels = n_one_hot_labels
+        self.label_smoothing = label_smoothing
+
+    def __call__(self, labels, logits):
+        with tf.name_scope("PixelWiseLoss"):
+            flatten = tf.keras.layers.Flatten()
+            logits = flatten(logits)
+            n_pixels = logits.get_shape()[-1]
+            weights = 1.0
+            if self.balance_weights and self.n_one_hot_labels:
+                # use labels to figure out the required loss
+                weights = balanced_softmax_cross_entropy_loss_weights(
+                    labels, dtype=logits.dtype
+                )
+                # repeat weights for all pixels
+                weights = tf_repeat(weights[:, None], [1, n_pixels])
+                weights = tf.reshape(weights, (-1,))
+            elif self.balance_weights and not self.n_one_hot_labels:
+                # use labels to figure out the required loss
+                weights = balanced_sigmoid_cross_entropy_loss_weights(
+                    labels, dtype=logits.dtype
+                )
+                # repeat weights for all pixels
+                weights = tf_repeat(weights[:, None], [1, n_pixels])
+
+            if self.n_one_hot_labels:
+                labels = tf_repeat(labels, [n_pixels, 1])
+                labels = tf.reshape(labels, (-1, self.n_one_hot_labels))
+                # reshape logits too as softmax_cross_entropy is buggy and cannot really
+                # handle higher dimensions
+                logits = tf.reshape(logits, (-1, self.n_one_hot_labels))
+                loss_fn = tf.losses.softmax_cross_entropy
+            else:
+                labels = tf.reshape(labels, (-1, 1))
+                labels = tf_repeat(labels, [n_pixels, 1])
+                labels = tf.reshape(labels, (-1, n_pixels))
+                loss_fn = tf.losses.sigmoid_cross_entropy
+
+            loss_pixel_wise = loss_fn(
+                labels,
+                logits=logits,
+                weights=weights,
+                label_smoothing=self.label_smoothing,
+                reduction=tf.losses.Reduction.MEAN,
+            )
+        tf.summary.scalar("loss_pixel_wise", loss_pixel_wise)
+        return loss_pixel_wise