From cb9744bc7597fe5e9319ca7ac7bf41fc7c77a0b4 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 16:12:19 +0100
Subject: [PATCH] new keras models

---
 .../tensorflow/models/autoencoder_face.py     |  99 ++++++
 bob/learn/tensorflow/models/autoencoder_yz.py | 305 ++++++++++++++++++
 bob/learn/tensorflow/models/mlp.py            | 111 +++++++
 3 files changed, 515 insertions(+)
 create mode 100644 bob/learn/tensorflow/models/autoencoder_face.py
 create mode 100644 bob/learn/tensorflow/models/autoencoder_yz.py
 create mode 100644 bob/learn/tensorflow/models/mlp.py

diff --git a/bob/learn/tensorflow/models/autoencoder_face.py b/bob/learn/tensorflow/models/autoencoder_face.py
new file mode 100644
index 00000000..318dd65c
--- /dev/null
+++ b/bob/learn/tensorflow/models/autoencoder_face.py
@@ -0,0 +1,99 @@
+import tensorflow as tf
+from .densenet import densenet161
+
+
+def _get_l2_kw(weight_decay):
+    l2_kw = {}
+    if weight_decay is not None:
+        l2_kw = {"kernel_regularizer": tf.keras.regularizers.l2(weight_decay)}
+    return l2_kw
+
+
+class ConvDecoder(tf.keras.Sequential):
+    """The decoder similar to the one in
+    https://github.com/google/compare_gan/blob/master/compare_gan/architectures/sndcgan.py
+    """
+
+    def __init__(
+        self,
+        z_dim,
+        decoder_layers=(
+            (512, 7, 7, 0),
+            (256, 4, 2, 1),
+            (128, 4, 2, 1),
+            (64, 4, 2, 1),
+            (32, 4, 2, 1),
+            (16, 4, 2, 1),
+            (3, 1, 1, 0),
+        ),
+        weight_decay=1e-5,
+        name="Decoder",
+        **kwargs,
+    ):
+        self.z_dim = z_dim
+        self.data_format = data_format = "channels_last"
+        l2_kw = _get_l2_kw(weight_decay)
+        layers = [
+            tf.keras.layers.Reshape((1, 1, z_dim), input_shape=(z_dim,), name="reshape")
+        ]
+        for i, (filters, kernel_size, strides, cropping) in enumerate(decoder_layers):
+            dconv = tf.keras.layers.Conv2DTranspose(
+                filters,
+                kernel_size,
+                strides=strides,
+                use_bias=i == len(decoder_layers) - 1,
+                data_format=data_format,
+                name=f"dconv_{i}",
+                **l2_kw,
+            )
+            crop = tf.keras.layers.Cropping2D(
+                cropping=cropping, data_format=data_format, name=f"crop_{i}"
+            )
+
+            if i == len(decoder_layers) - 1:
+                act = tf.keras.layers.Activation("tanh", name=f"tanh_{i}")
+                bn = None
+            else:
+                act = tf.keras.layers.Activation("relu", name=f"relu_{i}")
+                bn = tf.keras.layers.BatchNormalization(
+                    scale=False, fused=False, name=f"bn_{i}"
+                )
+            if bn is not None:
+                layers.extend([dconv, crop, bn, act])
+            else:
+                layers.extend([dconv, crop, act])
+        with tf.name_scope(name):
+            super().__init__(layers=layers, name=name, **kwargs)
+
+
+class Autoencoder(tf.keras.Model):
+    """
+    A class defining a simple convolutional autoencoder.
+
+    Attributes
+    ----------
+    data_format : str
+        channels_last is only supported
+    decoder : object
+        The encoder part
+    encoder : object
+        The decoder part
+    """
+
+    def __init__(self, encoder, decoder, name="Autoencoder", **kwargs):
+        super().__init__(name=name, **kwargs)
+        self.encoder = encoder
+        self.decoder = decoder
+
+    def call(self, x, training=None):
+        z = self.encoder(x, training=training)
+        x_hat = self.decoder(z, training=training)
+        return z, x_hat
+
+def autoencoder_face(z_dim=256, weight_decay=1e-9):
+    encoder = densenet161(
+        output_classes=z_dim, weight_decay=weight_decay, weights=None, name="DenseNet"
+    )
+    decoder = ConvDecoder(z_dim=z_dim, weight_decay=weight_decay, name="Decoder")
+    autoencoder = Autoencoder(encoder, decoder, name="Autoencoder")
+    return autoencoder
diff --git a/bob/learn/tensorflow/models/autoencoder_yz.py b/bob/learn/tensorflow/models/autoencoder_yz.py
new file mode 100644
index 00000000..6acad5c2
--- /dev/null
+++ b/bob/learn/tensorflow/models/autoencoder_yz.py
@@ -0,0 +1,305 @@
+import tensorflow as tf
+from .densenet import densenet161, ConvBlock
+
+
+def _get_l2_kw(weight_decay):
+    l2_kw = {}
+    if weight_decay is not None:
+        l2_kw = {"kernel_regularizer": tf.keras.regularizers.l2(weight_decay)}
+    return l2_kw
+
+
+class ConvEncoder(tf.keras.Model):
+    """The encoder part"""
+
+    def __init__(
+        self,
+        encoder_layers,
+        data_format="channels_last",
+        weight_decay=1e-5,
+        name="Encoder",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
+        self.data_format = data_format
+        l2_kw = _get_l2_kw(weight_decay)
+        layers = []
+        for i, (filters, kernel_size, strides, padding) in enumerate(encoder_layers):
+            bn = tf.keras.layers.BatchNormalization(
+                scale=False, fused=False, name=f"bn_{i}"
+            )
+            if i == 0:
+                act = tf.keras.layers.Activation("linear", name=f"linear_{i}")
+            else:
+                act = tf.keras.layers.Activation("relu", name=f"relu_{i}")
+            pad = tf.keras.layers.ZeroPadding2D(
+                padding=padding, data_format=data_format, name=f"pad_{i}"
+            )
+            conv = tf.keras.layers.Conv2D(
+                filters,
+                kernel_size,
+                strides=strides,
+                use_bias=(i == len(encoder_layers) - 1),
+                data_format=data_format,
+                name=f"conv_{i}",
+                **l2_kw,
+            )
+            if i == len(encoder_layers) - 1:
+                pool = tf.keras.layers.AvgPool2D(
+                    data_format=data_format, name=f"pool_{i}"
+                )
+            else:
+                pool = tf.keras.layers.MaxPooling2D(
+                    data_format=data_format, name=f"pool_{i}"
+                )
+            layers.extend([bn, act, pad, conv, pool])
+        self.sequential_layers = layers
+
+    def call(self, x, training=None):
+        for l in self.sequential_layers:
+            try:
+                x = l(x, training=training)
+            except TypeError:
+                x = l(x)
+        return x
+
+
+class ConvDecoder(tf.keras.Model):
+    """The encoder part"""
+
+    def __init__(
+        self, decoder_layers, y_dim, weight_decay=1e-5, name="Decoder", **kwargs
+    ):
+        super().__init__(name=name, **kwargs)
+        self.data_format = data_format = "channels_last"
+        self.y_dim = y_dim
+        l2_kw = _get_l2_kw(weight_decay)
+        layers = []
+        for i, (filters, kernel_size, strides, cropping) in enumerate(decoder_layers):
+            dconv = tf.keras.layers.Conv2DTranspose(
+                filters,
+                kernel_size,
+                strides=strides,
+                use_bias=False,
+                data_format=data_format,
+                name=f"dconv_{i}",
+                **l2_kw,
+            )
+            crop = tf.keras.layers.Cropping2D(
+                cropping=cropping, data_format=data_format, name=f"crop_{i}"
+            )
+            bn = tf.keras.layers.BatchNormalization(
+                scale=(i == len(decoder_layers) - 1), fused=False, name=f"bn_{i}"
+            )
+            if i == len(decoder_layers) - 1:
+                act = tf.keras.layers.Activation("tanh", name=f"tanh_{i}")
+            else:
+                act = tf.keras.layers.Activation("relu", name=f"relu_{i}")
+            layers.extend([dconv, crop, bn, act])
+        self.sequential_layers = layers
+
+    def call(self, x, y, training=None):
+        y = tf.reshape(tf.cast(y, x.dtype), (-1, 1, 1, self.y_dim))
+        x = tf.concat([x, y], axis=-1)
+        for l in self.sequential_layers:
+            try:
+                x = l(x, training=training)
+            except TypeError:
+                x = l(x)
+        return x
+
+
+class Autoencoder(tf.keras.Model):
+    """
+    A class defining a simple convolutional autoencoder.
+
+    Attributes
+    ----------
+    data_format : str
+        channels_last is only supported
+    decoder : object
+        The encoder part
+    encoder : object
+        The decoder part
+    """
+
+    def __init__(
+        self, encoder, decoder, z_dim, weight_decay=1e-5, name="Autoencoder", **kwargs
+    ):
+        super().__init__(name=name, **kwargs)
+        data_format = "channels_last"
+        self.data_format = data_format
+        self.weight_decay = weight_decay
+        self.encoder = encoder
+        self.decoder = decoder
+        self.z_dim = z_dim
+
+    def call(self, x, y, training=None):
+        self.encoder_output = tf.reshape(
+            self.encoder(x, training=training), (-1, 1, 1, self.z_dim)
+        )
+        self.decoder_output = self.decoder(self.encoder_output, y, training=training)
+        return self.decoder_output
+
+
+def densenet161_autoencoder(z_dim=256, y_dim=3, weight_decay=1e-10):
+
+    encoder = densenet161(output_classes=z_dim, weight_decay=weight_decay, weights=None)
+    decoder_layers = (
+        (128, 7, 7, 0),
+        (64, 4, 2, 1),
+        (32, 4, 2, 1),
+        (16, 4, 2, 1),
+        (8, 4, 2, 1),
+        (4, 4, 2, 1),
+        (3, 1, 1, 0),
+    )
+    decoder = ConvDecoder(
+        decoder_layers, y_dim=y_dim, weight_decay=weight_decay, name="Decoder"
+    )
+    autoencoder = Autoencoder(encoder, decoder, z_dim=z_dim, weight_decay=weight_decay)
+    return autoencoder
+
+
+class ConvDecoderSupervised(tf.keras.Model):
+    """The encoder part"""
+
+    def __init__(
+        self,
+        decoder_layers,
+        weight_decay=1e-5,
+        data_format="channels_last",
+        name="Decoder",
+        y_dim=None,
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
+        self.data_format = data_format
+        self.y_dim = y_dim
+        l2_kw = _get_l2_kw(weight_decay)
+        layers = []
+        for i, (filters, kernel_size, strides, cropping) in enumerate(decoder_layers):
+            dconv = tf.keras.layers.Conv2DTranspose(
+                filters,
+                kernel_size,
+                strides=strides,
+                use_bias=False,
+                data_format=data_format,
+                name=f"dconv_{i}",
+                **l2_kw,
+            )
+            crop = tf.keras.layers.Cropping2D(
+                cropping=cropping, data_format=data_format, name=f"crop_{i}"
+            )
+            bn = tf.keras.layers.BatchNormalization(
+                scale=(i == len(decoder_layers) - 1), fused=False, name=f"bn_{i}"
+            )
+            if i == len(decoder_layers) - 1:
+                act = tf.keras.layers.Activation("tanh", name=f"tanh_{i}")
+            else:
+                act = tf.keras.layers.Activation("relu", name=f"relu_{i}")
+            layers.extend([dconv, crop, bn, act])
+        self.sequential_layers = layers
+
+    def call(self, x, training=None):
+        x = tf.reshape(x, (-1, 1, 1, x.get_shape().as_list()[-1]))
+        if self.y_dim is not None:
+            y_fixed = tf.one_hot([[[0]]], self.y_dim, dtype=x.dtype)
+            y_fixed = tf.tile(y_fixed, multiples=[tf.shape(x)[0], 1, 1, 1])
+            x = tf.concat([x, y_fixed], axis=-1)
+        x = tf.keras.Input(tensor=x)
+        for l in self.sequential_layers:
+            try:
+                x = l(x, training=training)
+            except TypeError:
+                x = l(x)
+        return x
+
+
+def densenet161_autoencoder_supervised(
+    x,
+    training,
+    weight_decay=1e-10,
+    z_dim=256,
+    y_dim=1,
+    deeppixbis_add_one_more_layer=False,
+    start_from_face_autoencoder=False,
+):
+    data_format = "channels_last"
+    with tf.name_scope("Autoencoder"):
+        densenet = densenet161(
+            output_classes=z_dim,
+            weight_decay=weight_decay,
+            weights=None,
+            data_format=data_format,
+        )
+        z = densenet(x, training=training)
+        transition = tf.keras.Input(tensor=densenet.transition_blocks[1].output)
+
+        layers = [
+            tf.keras.layers.Conv2D(
+                filters=1,
+                kernel_size=1,
+                kernel_initializer="he_normal",
+                kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
+                data_format=data_format,
+                name="dec",
+            ),
+            tf.keras.layers.Flatten(
+                data_format=data_format, name="Pixel_Logits_Flatten"
+            ),
+        ]
+
+        if deeppixbis_add_one_more_layer:
+            layers.insert(
+                0,
+                ConvBlock(
+                    num_filters=32,
+                    data_format=data_format,
+                    bottleneck=True,
+                    weight_decay=weight_decay,
+                    name="prelogits",
+                ),
+            )
+
+        y = transition
+        with tf.name_scope("DeepPixBiS"):
+            for l in layers:
+                try:
+                    y = l(y, training=training)
+                except TypeError:
+                    y = l(y)
+
+        deep_pix_bis_final_layers = tf.keras.Model(
+            inputs=transition, outputs=y, name="DeepPixBiS"
+        )
+        encoder = tf.keras.Model(inputs=[x, transition], outputs=[y, z], name="Encoder")
+        encoder.densenet = densenet
+        if deeppixbis_add_one_more_layer:
+            encoder.prelogits = deep_pix_bis_final_layers.layers[-3].output
+        else:
+            encoder.prelogits = transition
+        encoder.deep_pix_bis = deep_pix_bis_final_layers
+        decoder_layers = (
+            (128, 7, 7, 0),
+            (64, 4, 2, 1),
+            (32, 4, 2, 1),
+            (16, 4, 2, 1),
+            (8, 4, 2, 1),
+            (4, 4, 2, 1),
+            (3, 1, 1, 0),
+        )
+        decoder = ConvDecoderSupervised(
+            decoder_layers,
+            weight_decay=weight_decay,
+            name="Decoder",
+            data_format=data_format,
+            y_dim=3 if start_from_face_autoencoder else None,
+        )
+        x_hat = decoder(z, training=training)
+        autoencoder = tf.keras.Model(
+            inputs=[x, transition], outputs=[y, z, x_hat], name="Autoencoder"
+        )
+        autoencoder.encoder = encoder
+        autoencoder.decoder = decoder
+    return autoencoder, y, z, x_hat
diff --git a/bob/learn/tensorflow/models/mlp.py b/bob/learn/tensorflow/models/mlp.py
new file mode 100644
index 00000000..3804c4e3
--- /dev/null
+++ b/bob/learn/tensorflow/models/mlp.py
@@ -0,0 +1,111 @@
+import tensorflow as tf
+
+
+class MLP(tf.keras.Model):
+    """An MLP that can be trained with center loss and cross entropy."""
+
+    def __init__(
+        self,
+        n_classes=1,
+        hidden_layers=(256, 128, 64, 32),
+        weight_decay=1e-5,
+        name="MLP",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
+
+        dense_kw = {}
+        if weight_decay is not None:
+            dense_kw["kernel_regularizer"] = tf.keras.regularizers.l2(weight_decay)
+
+        sequential_layers = []
+        for i, n in enumerate(hidden_layers, start=1):
+            sequential_layers.extend(
+                [
+                    tf.keras.layers.Dense(n, use_bias=False, name=f"dense_{i}", **dense_kw),
+                    tf.keras.layers.BatchNormalization(scale=False, name=f"bn_{i}"),
+                    tf.keras.layers.Activation("relu", name=f"relu_{i}"),
+                ]
+            )
+
+        sequential_layers.append(
+            tf.keras.layers.Dense(n_classes, name="logits", **dense_kw)
+        )
+
+        self.hidden_layers = hidden_layers
+        self.n_classes = n_classes
+        self.sequential_layers = sequential_layers
+        self.prelogits_shape = hidden_layers[-1]
+
+    def call(self, x, training=None):
+        assert hasattr(
+            x, "_keras_history"
+        ), "The input must be wrapped inside a keras Input layer."
+
+        for i, layer in enumerate(self.sequential_layers):
+            try:
+                x = layer(x, training=training)
+            except TypeError:
+                x = layer(x)
+
+        return x
+
+    @property
+    def prelogits(self):
+        return self.layers[-2].output
+
+
+class MLPDropout(tf.keras.Model):
+    """An MLP that can be trained with center loss and cross entropy."""
+
+    def __init__(
+        self,
+        n_classes=1,
+        hidden_layers=(256, 128, 64, 32),
+        weight_decay=1e-5,
+        drop_rate=0.5,
+        name="MLP",
+        **kwargs,
+    ):
+        super().__init__(name=name, **kwargs)
+
+        dense_kw = {}
+        if weight_decay is not None:
+            dense_kw["kernel_regularizer"] = tf.keras.regularizers.l2(weight_decay)
+
+        sequential_layers = []
+        for i, n in enumerate(hidden_layers, start=1):
+            sequential_layers.extend(
+                [
+                    tf.keras.layers.Dense(n, use_bias=False, name=f"dense_{i}", **dense_kw),
+                    tf.keras.layers.Activation("relu", name=f"relu_{i}"),
+                    tf.keras.layers.Dropout(rate=drop_rate, name=f"drop_{i}"),
+                ]
+            )
+
+        sequential_layers.append(
+            tf.keras.layers.Dense(n_classes, name="logits", **dense_kw)
+        )
+
+        self.hidden_layers = hidden_layers
+        self.n_classes = n_classes
+        self.drop_rate = drop_rate
+        self.sequential_layers = sequential_layers
+        self.prelogits_shape = hidden_layers[-1]
+
+    def call(self, x, training=None):
+        assert hasattr(
+            x, "_keras_history"
+        ), "The input must be wrapped inside a keras Input layer."
+
+        for i, layer in enumerate(self.sequential_layers):
+            try:
+                x = layer(x, training=training)
+            except TypeError:
+                x = layer(x)
+
+        return x
+
+    @property
+    def prelogits(self):
+        return self.layers[-2].output
-- 
GitLab