From 38de7bdc60d2e849589cb35513b3c2b166bfe2fc Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI <amir.mohammadi@idiap.ch> Date: Fri, 7 Feb 2020 16:13:06 +0100 Subject: [PATCH] Add GAN tools --- bob/learn/tensorflow/gan/__init__.py | 2 + bob/learn/tensorflow/gan/losses.py | 171 ++++++++++ .../tensorflow/gan/spectral_normalization.py | 316 ++++++++++++++++++ bob/learn/tensorflow/models/discriminator.py | 99 +++++- 4 files changed, 583 insertions(+), 5 deletions(-) create mode 100644 bob/learn/tensorflow/gan/__init__.py create mode 100644 bob/learn/tensorflow/gan/losses.py create mode 100644 bob/learn/tensorflow/gan/spectral_normalization.py diff --git a/bob/learn/tensorflow/gan/__init__.py b/bob/learn/tensorflow/gan/__init__.py new file mode 100644 index 00000000..502898e8 --- /dev/null +++ b/bob/learn/tensorflow/gan/__init__.py @@ -0,0 +1,2 @@ +from . import spectral_normalization +from . import losses diff --git a/bob/learn/tensorflow/gan/losses.py b/bob/learn/tensorflow/gan/losses.py new file mode 100644 index 00000000..ec378245 --- /dev/null +++ b/bob/learn/tensorflow/gan/losses.py @@ -0,0 +1,171 @@ +import tensorflow as tf + + +def relativistic_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + label_smoothing=0.25, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=tf.GraphKeys.LOSSES, + reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False, +): + """Relativistic (average) loss + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + label_smoothing: The amount of smoothing for positive labels. This technique + is taken from `Improved Techniques for Training GANs` + (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `real_data`, and must be broadcastable to `real_data` (i.e., all + dimensions must be either `1`, or the same as the corresponding + dimension). + generated_weights: Same as `real_weights`, but for `generated_data`. + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. + add_summaries: Whether or not to add summaries for the loss. + + Returns: + A loss Tensor. The shape depends on `reduction`. 
+ """ + with tf.name_scope( + scope, + "discriminator_relativistic_loss", + ( + discriminator_real_outputs, + discriminator_gen_outputs, + real_weights, + generated_weights, + label_smoothing, + ), + ) as scope: + + real_logit = discriminator_real_outputs - tf.reduce_mean( + discriminator_gen_outputs + ) + fake_logit = discriminator_gen_outputs - tf.reduce_mean( + discriminator_real_outputs + ) + + loss_on_real = tf.losses.sigmoid_cross_entropy( + tf.ones_like(real_logit), + real_logit, + real_weights, + label_smoothing, + scope, + loss_collection=None, + reduction=reduction, + ) + loss_on_generated = tf.losses.sigmoid_cross_entropy( + tf.zeros_like(fake_logit), + fake_logit, + generated_weights, + scope=scope, + loss_collection=None, + reduction=reduction, + ) + + loss = loss_on_real + loss_on_generated + tf.losses.add_loss(loss, loss_collection) + + if add_summaries: + tf.summary.scalar("discriminator_gen_relativistic_loss", loss_on_generated) + tf.summary.scalar("discriminator_real_relativistic_loss", loss_on_real) + tf.summary.scalar("discriminator_relativistic_loss", loss) + + return loss + + +def relativistic_generator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + label_smoothing=0.0, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=tf.GraphKeys.LOSSES, + reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False, + confusion_labels=False, +): + """Relativistic (average) loss + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + label_smoothing: The amount of smoothing for positive labels. This technique + is taken from `Improved Techniques for Training GANs` + (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `real_data`, and must be broadcastable to `real_data` (i.e., all + dimensions must be either `1`, or the same as the corresponding + dimension). + generated_weights: Same as `real_weights`, but for `generated_data`. + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. + add_summaries: Whether or not to add summaries for the loss. + + Returns: + A loss Tensor. The shape depends on `reduction`. 
+ """ + with tf.name_scope( + scope, + "generator_relativistic_loss", + ( + discriminator_real_outputs, + discriminator_gen_outputs, + real_weights, + generated_weights, + label_smoothing, + ), + ) as scope: + + real_logit = discriminator_real_outputs - tf.reduce_mean( + discriminator_gen_outputs + ) + fake_logit = discriminator_gen_outputs - tf.reduce_mean( + discriminator_real_outputs + ) + + if confusion_labels: + real_labels = tf.ones_like(real_logit) / 2 + fake_labels = tf.ones_like(fake_logit) / 2 + else: + real_labels = tf.zeros_like(real_logit) + fake_labels = tf.ones_like(fake_logit) + + loss_on_real = tf.losses.sigmoid_cross_entropy( + real_labels, + real_logit, + real_weights, + label_smoothing, + scope, + loss_collection=None, + reduction=reduction, + ) + loss_on_generated = tf.losses.sigmoid_cross_entropy( + fake_labels, + fake_logit, + generated_weights, + scope=scope, + loss_collection=None, + reduction=reduction, + ) + + loss = loss_on_real + loss_on_generated + tf.losses.add_loss(loss, loss_collection) + + if add_summaries: + tf.summary.scalar("generator_gen_relativistic_loss", loss_on_generated) + tf.summary.scalar("generator_real_relativistic_loss", loss_on_real) + tf.summary.scalar("generator_relativistic_loss", loss) + + return loss diff --git a/bob/learn/tensorflow/gan/spectral_normalization.py b/bob/learn/tensorflow/gan/spectral_normalization.py new file mode 100644 index 00000000..ad2ecfaa --- /dev/null +++ b/bob/learn/tensorflow/gan/spectral_normalization.py @@ -0,0 +1,316 @@ +# Copied from: https://github.com/tensorflow/tensorflow/blob/c4f40aea1d4f916aa3dfeb79f024c495ac609106/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-like layers and utilities that implement Spectral Normalization. + +Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, +et al in ICLR 2018. 
https://openreview.net/pdf?id=B1QRgziT- +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import numbers +import re + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging + +__all__ = [ + 'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer', + 'spectral_normalization_custom_getter', 'keras_spectral_normalization' +] + +# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then +# can't directly be assigned back to the tf.bfloat16 variable. +_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64) +_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u' + + +def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None): + """Estimates the largest singular value in the weight tensor. + + Args: + w_tensor: The weight matrix whose spectral norm should be computed. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yields a better approximation. + name: An optional scope name. + + Returns: + The largest singular value (the spectral norm) of w. + """ + with variable_scope.variable_scope(name, 'spectral_norm'): + # The paper says to flatten convnet kernel weights from + # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D + # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to + # (KH * KW * C_in, C_out), and similarly for other layers that put output + # channels as last dimension. + # n.b. this means that w here is equivalent to w.T in the paper. + w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1])) + + # Persisted approximation of first left singular vector of matrix `w`. + u_var = variable_scope.get_variable( + _PERSISTED_U_VARIABLE_SUFFIX, + shape=(w.shape[0], 1), + dtype=w.dtype, + initializer=init_ops.random_normal_initializer(), + trainable=False) + u = u_var + + # Use power iteration method to approximate spectral norm. + for _ in range(power_iteration_rounds): + # `v` approximates the first right singular vector of matrix `w`. + v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u)) + u = nn.l2_normalize(math_ops.matmul(w, v)) + + # Update persisted approximation. + with ops.control_dependencies([u_var.assign(u, name='update_u')]): + u = array_ops.identity(u) + + u = array_ops.stop_gradient(u) + v = array_ops.stop_gradient(v) + + # Largest singular value of `w`. + spectral_norm = math_ops.matmul( + math_ops.matmul(array_ops.transpose(u), w), v) + spectral_norm.shape.assert_is_fully_defined() + spectral_norm.shape.assert_is_compatible_with([1, 1]) + + return spectral_norm[0][0] + + +def spectral_normalize(w, power_iteration_rounds=1, name=None): + """Normalizes a weight matrix by its spectral norm. + + Args: + w: The weight matrix to be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yields a better approximation. + name: An optional scope name. + + Returns: + A normalized weight matrix tensor. 
+ """ + with variable_scope.variable_scope(name, 'spectral_normalize'): + w_normalized = w / compute_spectral_norm( + w, power_iteration_rounds=power_iteration_rounds) + return array_ops.reshape(w_normalized, w.get_shape()) + + +def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None): + """Returns a functions that can be used to apply spectral norm regularization. + + Small spectral norms enforce a small Lipschitz constant, which is necessary + for Wasserstein GANs. + + Args: + scale: A scalar multiplier. 0.0 disables the regularizer. + power_iteration_rounds: The number of iterations of the power method to + perform. A higher number yields a better approximation. + scope: An optional scope name. + + Returns: + A function with the signature `sn(weights)` that applies spectral norm + regularization. + + Raises: + ValueError: If scale is negative or if scale is not a float. + """ + if isinstance(scale, numbers.Integral): + raise ValueError('scale cannot be an integer: %s' % scale) + if isinstance(scale, numbers.Real): + if scale < 0.0: + raise ValueError( + 'Setting a scale less than 0 on a regularizer: %g' % scale) + if scale == 0.0: + logging.info('Scale of 0 disables regularizer.') + return lambda _: None + + def sn(weights, name=None): + """Applies spectral norm regularization to weights.""" + with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name: + scale_t = ops.convert_to_tensor( + scale, dtype=weights.dtype.base_dtype, name='scale') + return math_ops.multiply( + scale_t, + compute_spectral_norm( + weights, power_iteration_rounds=power_iteration_rounds), + name=name) + + return sn + + +def _default_name_filter(name): + """A filter function to identify common names of weight variables. + + Args: + name: The variable name. + + Returns: + Whether `name` is a standard name for a weight/kernel variables used in the + Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries. + """ + match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name) + return match is not None + + +def spectral_normalization_custom_getter(name_filter=_default_name_filter, + power_iteration_rounds=1): + """Custom getter that performs Spectral Normalization on a weight tensor. + + Specifically it divides the weight tensor by its largest singular value. This + is intended to stabilize GAN training, by making the discriminator satisfy a + local 1-Lipschitz constraint. + + Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]. + + [sn-gan]: https://openreview.net/forum?id=B1QRgziT- + + To reproduce an SN-GAN, apply this custom_getter to every weight tensor of + your discriminator. The last dimension of the weight tensor must be the number + of output channels. + + Apply this to layers by supplying this as the `custom_getter` of a + `tf.compat.v1.variable_scope`. For example: + + with tf.compat.v1.variable_scope('discriminator', + custom_getter=spectral_norm_getter()): + net = discriminator_fn(net) + + IMPORTANT: Keras does not respect the custom_getter supplied by the + VariableScope, so Keras users should use `keras_spectral_normalization` + instead of (or in addition to) this approach. + + It is important to carefully select to which weights you want to apply + Spectral Normalization. In general you want to normalize the kernels of + convolution and dense layers, but you do not want to normalize biases. 
You + also want to avoid normalizing batch normalization (and similar) variables, + but in general such layers play poorly with Spectral Normalization, since the + gamma can cancel out the normalization in other layers. By default we supply a + filter that matches the kernel variable names of the dense and convolution + layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim + libraries. If you are using anything else you'll need a custom `name_filter`. + + This custom getter internally creates a variable used to compute the spectral + norm by power iteration. It will update every time the variable is accessed, + which means the normalized discriminator weights may change slightly whilst + training the generator. Whilst unusual, this matches how the paper's authors + implement it, and in general additional rounds of power iteration can't hurt. + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yields a better approximation of the + true spectral norm. + + Returns: + A custom getter function that applies Spectral Normalization to all + Variables whose names match `name_filter`. + + Raises: + ValueError: If name_filter is not callable. + """ + if not callable(name_filter): + raise ValueError('name_filter must be callable') + + def _internal_getter(getter, name, *args, **kwargs): + """A custom getter function that applies Spectral Normalization. + + Args: + getter: The true getter to call. + name: Name of new/existing variable, in the same format as + tf.get_variable. + *args: Other positional arguments, in the same format as tf.get_variable. + **kwargs: Keyword arguments, in the same format as tf.get_variable. + + Returns: + The return value of `getter(name, *args, **kwargs)`, spectrally + normalized. + + Raises: + ValueError: If used incorrectly, or if `dtype` is not supported. + """ + if not name_filter(name): + return getter(name, *args, **kwargs) + + if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX): + raise ValueError( + 'Cannot apply Spectral Normalization to internal variables created ' + 'for Spectral Normalization. Tried to normalized variable [%s]' % + name) + + if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM: + raise ValueError('Disallowed data type {}'.format(kwargs['dtype'])) + + # This layer's weight Variable/PartitionedVariable. + w_tensor = getter(name, *args, **kwargs) + + if len(w_tensor.get_shape()) < 2: + raise ValueError( + 'Spectral norm can only be applied to multi-dimensional tensors') + + return spectral_normalize( + w_tensor, + power_iteration_rounds=power_iteration_rounds, + name=(name + '/spectral_normalize')) + + return _internal_getter + + +@contextlib.contextmanager +def keras_spectral_normalization(name_filter=_default_name_filter, + power_iteration_rounds=1): + """A context manager that enables Spectral Normalization for Keras. + + Keras doesn't respect the `custom_getter` in the VariableScope, so this is a + bit of a hack to make things work. + + Usage: + with keras_spectral_normalization(): + net = discriminator_fn(net) + + Args: + name_filter: Optionally, a method that takes a Variable name as input and + returns whether this Variable should be normalized. + power_iteration_rounds: The number of iterations of the power method to + perform per step. A higher number yields a better approximation of the + true spectral norm. 
+ + Yields: + A context manager that wraps the standard Keras variable creation method + with the `spectral_normalization_custom_getter`. + """ + original_make_variable = keras_base_layer_utils.make_variable + sn_getter = spectral_normalization_custom_getter( + name_filter=name_filter, power_iteration_rounds=power_iteration_rounds) + + def make_variable_wrapper(name, *args, **kwargs): + return sn_getter(original_make_variable, name, *args, **kwargs) + + keras_base_layer_utils.make_variable = make_variable_wrapper + + yield + + keras_base_layer_utils.make_variable = original_make_variable diff --git a/bob/learn/tensorflow/models/discriminator.py b/bob/learn/tensorflow/models/discriminator.py index 67eb5a56..beb1fde2 100644 --- a/bob/learn/tensorflow/models/discriminator.py +++ b/bob/learn/tensorflow/models/discriminator.py @@ -1,9 +1,11 @@ import tensorflow as tf +from ..gan.spectral_normalization import spectral_norm_regularizer +from ..utils import gram_matrix class ConvDiscriminator(tf.keras.Model): """A discriminator that can sit on top of DenseNet 161's transition 1 block. - The output of that block given 224x224 inputs is 14x14x384.""" + The output of that block given 224x224x3 inputs is 14x14x384.""" def __init__(self, data_format="channels_last", n_classes=1, **kwargs): super().__init__(**kwargs) @@ -13,10 +15,10 @@ class ConvDiscriminator(tf.keras.Model): self.sequential_layers = [ tf.keras.layers.Conv2D(200, 1, data_format=data_format), tf.keras.layers.Activation("relu"), - tf.layers.AveragePooling2D(3, 2, data_format=data_format), + tf.keras.layers.AveragePooling2D(3, 2, data_format=data_format), tf.keras.layers.Conv2D(100, 1, data_format=data_format), tf.keras.layers.Activation("relu"), - tf.layers.AveragePooling2D(3, 2, data_format=data_format), + tf.keras.layers.AveragePooling2D(3, 2, data_format=data_format), tf.keras.layers.Flatten(data_format=data_format), tf.keras.layers.Dense(n_classes), tf.keras.layers.Activation(act), @@ -24,7 +26,10 @@ class ConvDiscriminator(tf.keras.Model): def call(self, x, training=None): for l in self.sequential_layers: - x = l(x) + try: + x = l(x, training=training) + except TypeError: + x = l(x) return x @@ -66,5 +71,89 @@ class ConvDiscriminator2(tf.keras.Model): def call(self, x, training=None): for l in self.sequential_layers: - x = l(x) + try: + x = l(x, training=training) + except TypeError: + x = l(x) return x + + +class ConvDiscriminator3(tf.keras.Model): + """A discriminator that takes images and tries its best. 
+ Be careful, this one returns logits.""" + + def __init__(self, data_format="channels_last", n_classes=1, **kwargs): + super().__init__(**kwargs) + self.data_format = data_format + self.n_classes = n_classes + spectral_norm = spectral_norm_regularizer(scale=1.0) + conv2d_kw = {"kernel_regularizer": spectral_norm, "data_format": data_format} + self.sequential_layers = [ + tf.keras.layers.Conv2D(64, 3, strides=1, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.Conv2D(64, 4, strides=2, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.Conv2D(128, 3, strides=1, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.Conv2D(128, 4, strides=2, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.Conv2D(256, 3, strides=1, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.Conv2D(256, 4, strides=2, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.Conv2D(512, 3, strides=1, **conv2d_kw), + tf.keras.layers.LeakyReLU(0.1), + tf.keras.layers.GlobalAveragePooling2D(data_format=data_format), + tf.keras.layers.Dense(n_classes), + ] + + def call(self, x, training=None): + for l in self.sequential_layers: + try: + x = l(x, training=training) + except TypeError: + x = l(x) + return x + + +class DenseDiscriminator(tf.keras.Model): + """A discriminator that takes vectors as input and tries its best. + Be careful, this one returns logits.""" + + def __init__(self, n_classes=1, **kwargs): + super().__init__(**kwargs) + self.n_classes = n_classes + self.sequential_layers = [ + tf.keras.layers.Dense(1000), + tf.keras.layers.Activation("relu"), + tf.keras.layers.Dense(1000), + tf.keras.layers.Activation("relu"), + tf.keras.layers.Dense(n_classes), + ] + + def call(self, x, training=None): + for l in self.sequential_layers: + try: + x = l(x, training=training) + except TypeError: + x = l(x) + return x + + +class GramComparer1(tf.keras.Model): + """A model to compare images based on their gram matrices.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.batchnorm = tf.keras.layers.BatchNormalization() + self.conv2d = tf.keras.layers.Conv2D(128, 7) + + def call(self, x_1_2, training=None): + def _call(x): + x = self.batchnorm(x, training=training) + x = self.conv2d(x) + return gram_matrix(x) + + gram1 = _call(x_1_2[..., :3]) + gram2 = _call(x_1_2[..., 3:]) + return -tf.reduce_mean((gram1 - gram2) ** 2, axis=[1, 2])[:, None] -- GitLab
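
Usage sketch (not part of the diff above): a minimal example of how the added pieces fit together in TF 1.x graph mode. The placeholder shapes and the placeholder standing in for a generator's output are assumptions for illustration only; everything else references the functions and classes introduced by this patch.

    import tensorflow as tf

    from bob.learn.tensorflow.gan import losses
    from bob.learn.tensorflow.gan.spectral_normalization import keras_spectral_normalization
    from bob.learn.tensorflow.models.discriminator import ConvDiscriminator3, DenseDiscriminator

    # Both relativistic losses expect raw logits; ConvDiscriminator3 returns logits
    # and already carries the spectral-norm kernel regularizer added in this patch.
    real_images = tf.placeholder(tf.float32, [None, 64, 64, 3])
    generated_images = tf.placeholder(tf.float32, [None, 64, 64, 3])  # e.g. a generator's output

    discriminator = ConvDiscriminator3()
    d_real = discriminator(real_images, training=True)
    d_gen = discriminator(generated_images, training=True)

    d_loss = losses.relativistic_discriminator_loss(d_real, d_gen, add_summaries=True)
    g_loss = losses.relativistic_generator_loss(d_real, d_gen, add_summaries=True)

    # Alternatively, build any Keras-based discriminator inside the
    # keras_spectral_normalization() context manager so that its kernel variables
    # are divided by their spectral norm when they are created (first call builds
    # the layers, so the call must happen inside the context):
    features = tf.placeholder(tf.float32, [None, 128])
    with keras_spectral_normalization():
        dense_disc = DenseDiscriminator(n_classes=1)
        dense_logits = dense_disc(features, training=True)

In a training loop, d_loss would then be minimized over the discriminator's variables and g_loss over the generator's variables in alternating steps, as in a standard GAN setup.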