diff --git a/bob/learn/tensorflow/models/__init__.py b/bob/learn/tensorflow/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49e41eb2ce115a1b5c709b3066157b0389ef4a0d
--- /dev/null
+++ b/bob/learn/tensorflow/models/__init__.py
@@ -0,0 +1 @@
+from .inception_resnet_v1 import InceptionResNetV1
diff --git a/bob/learn/tensorflow/models/inception_resnet_v1.py b/bob/learn/tensorflow/models/inception_resnet_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..285d7b9f38816045cf659d06d0aab29564b216b1
--- /dev/null
+++ b/bob/learn/tensorflow/models/inception_resnet_v1.py
@@ -0,0 +1,620 @@
+# -*- coding: utf-8 -*-
+"""Inception-ResNet V1 model for Keras.
+# Reference
+http://arxiv.org/abs/1602.07261
+https://github.com/davidsandberg/facenet/blob/master/src/models/inception_resnet_v1.py
+https://github.com/myutwo150/keras-inception-resnet-v2/blob/master/inception_resnet_v2.py
+"""
+from functools import partial
+
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import Activation
+from tensorflow.keras.layers import BatchNormalization
+from tensorflow.keras.layers import Concatenate
+from tensorflow.keras.layers import Conv2D
+from tensorflow.keras.layers import Dense
+from tensorflow.keras.layers import Dropout
+from tensorflow.keras.layers import GlobalAveragePooling2D
+from tensorflow.keras.layers import Input
+from tensorflow.keras.layers import Lambda
+from tensorflow.keras.layers import MaxPooling2D
+from tensorflow.keras.layers import Add
+from tensorflow.keras.layers import TimeDistributed
+from tensorflow.keras import backend as K
+import tensorflow as tf
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def scaling(x, scale):
+    return x * scale
+
+
+def conv2d_bn(
+    x,
+    filters,
+    kernel_size,
+    strides=1,
+    padding="same",
+    activation="relu",
+    use_bias=False,
+    name=None,
+    timedistributed=False,
+    kernel_regularizer=None,
+    training=False,
+):
+    if not timedistributed:
+
+        def MyTimeDistributed(x):
+            return x
+
+    else:
+        MyTimeDistributed = TimeDistributed
+
+    x = MyTimeDistributed(
+        Conv2D(
+            filters,
+            kernel_size,
+            strides=strides,
+            padding=padding,
+            use_bias=use_bias,
+            kernel_regularizer=kernel_regularizer,
+            name=name,
+        )
+    )(x)
+    if not use_bias:
+        bn_axis = 1 if K.image_data_format() == "channels_first" else 3
+        bn_name = _generate_layer_name("BatchNorm", prefix=name)
+        x = BatchNormalization(
+            axis=bn_axis + 1 if timedistributed else bn_axis,
+            momentum=0.995,
+            epsilon=0.001,
+            scale=False,
+            name=bn_name,
+        )(x, training=training)
+    if activation is not None:
+        ac_name = _generate_layer_name("Activation", prefix=name)
+        x = MyTimeDistributed(Activation(activation, name=ac_name))(x)
+    return x
+
+
+def _generate_layer_name(name, branch_idx=None, prefix=None):
+    if prefix is None:
+        return None
+    if branch_idx is None:
+        return "_".join((prefix, name))
+    return "_".join((prefix, "Branch", str(branch_idx), name))
+
+
+def _inception_resnet_block(
+    x,
+    scale,
+    block_type,
+    block_idx,
+    activation="relu",
+    timedistributed=False,
+    kernel_regularizer=None,
+    training=False,
+):
+    if not timedistributed:
+
+        def MyTimeDistributed(x):
+            return x
+
+    else:
+        MyTimeDistributed = TimeDistributed
+
+    channel_axis = 1 if K.image_data_format() == "channels_first" else 3
+    if block_idx is None:
+        prefix = None
+    else:
+        prefix = "_".join((block_type, str(block_idx)))
+    name_fmt = partial(_generate_layer_name, prefix=prefix)
+
+    if block_type == "Block35":
+        branch_0 = conv2d_bn(
+            x,
+            32,
+            1,
+            name=name_fmt("Conv2d_1x1", 0),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            x,
+            32,
+            1,
+            name=name_fmt("Conv2d_0a_1x1", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            branch_1,
+            32,
+            3,
+            name=name_fmt("Conv2d_0b_3x3", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_2 = conv2d_bn(
+            x,
+            32,
+            1,
+            name=name_fmt("Conv2d_0a_1x1", 2),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_2 = conv2d_bn(
+            branch_2,
+            32,
+            3,
+            name=name_fmt("Conv2d_0b_3x3", 2),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_2 = conv2d_bn(
+            branch_2,
+            32,
+            3,
+            name=name_fmt("Conv2d_0c_3x3", 2),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branches = [branch_0, branch_1, branch_2]
+    elif block_type == "Block17":
+        branch_0 = conv2d_bn(
+            x,
+            128,
+            1,
+            name=name_fmt("Conv2d_1x1", 0),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            x,
+            128,
+            1,
+            name=name_fmt("Conv2d_0a_1x1", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            branch_1,
+            128,
+            [1, 7],
+            name=name_fmt("Conv2d_0b_1x7", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            branch_1,
+            128,
+            [7, 1],
+            name=name_fmt("Conv2d_0c_7x1", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branches = [branch_0, branch_1]
+    elif block_type == "Block8":
+        branch_0 = conv2d_bn(
+            x,
+            192,
+            1,
+            name=name_fmt("Conv2d_1x1", 0),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            x,
+            192,
+            1,
+            name=name_fmt("Conv2d_0a_1x1", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            branch_1,
+            192,
+            [1, 3],
+            name=name_fmt("Conv2d_0b_1x3", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branch_1 = conv2d_bn(
+            branch_1,
+            192,
+            [3, 1],
+            name=name_fmt("Conv2d_0c_3x1", 1),
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+        branches = [branch_0, branch_1]
+    else:
+        raise ValueError(
+            "Unknown Inception-ResNet block type. "
+            'Expects "Block35", "Block17" or "Block8", '
+            "but got: " + str(block_type)
+        )
+
+    if timedistributed:
+        channel_axis += 1
+
+    mixed = Concatenate(axis=channel_axis, name=name_fmt("Concatenate"))(branches)
+    up = conv2d_bn(
+        mixed,
+        K.int_shape(x)[channel_axis],
+        1,
+        activation=None,
+        use_bias=True,
+        name=name_fmt("Conv2d_1x1"),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    up = MyTimeDistributed(
+        Lambda(
+            scaling,
+            output_shape=K.int_shape(up)[2 if timedistributed else 1 :],
+            arguments={"scale": scale},
+        )
+    )(up)
+    x = Add()([x, up])
+    if activation is not None:
+        x = MyTimeDistributed(Activation(activation, name=name_fmt("Activation")))(x)
+    return x
+
+
+def InceptionResNetV1(
+    input_shape=(160, 160, 3),
+    inputs=None,
+    classes=128,
+    dropout_keep_prob=0.8,
+    weight_decay=1e-5,
+    weights_path=None,
+    timedistributed=False,
+    training=False,
+):
+    if not timedistributed:
+
+        def MyTimeDistributed(x):
+            return x
+
+    else:
+        MyTimeDistributed = TimeDistributed
+        input_shape = [None] + list(input_shape)
+
+    if weight_decay is None:
+        kernel_regularizer = None
+    else:
+        kernel_regularizer = tf.keras.regularizers.l2(l=weight_decay)
+
+    if inputs is None:
+        inputs = Input(shape=input_shape)
+    x = conv2d_bn(
+        inputs,
+        32,
+        3,
+        strides=2,
+        padding="valid",
+        name="Conv2d_1a_3x3",
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    x = conv2d_bn(
+        x,
+        32,
+        3,
+        padding="valid",
+        name="Conv2d_2a_3x3",
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    x = conv2d_bn(
+        x,
+        64,
+        3,
+        name="Conv2d_2b_3x3",
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    x = MyTimeDistributed(MaxPooling2D(3, strides=2, name="MaxPool_3a_3x3"))(x)
+    x = conv2d_bn(
+        x,
+        80,
+        1,
+        padding="valid",
+        name="Conv2d_3b_1x1",
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    x = conv2d_bn(
+        x,
+        192,
+        3,
+        padding="valid",
+        name="Conv2d_4a_3x3",
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    x = conv2d_bn(
+        x,
+        256,
+        3,
+        strides=2,
+        padding="valid",
+        name="Conv2d_4b_3x3",
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+
+    # 5x Block35 (Inception-ResNet-A block):
+    for block_idx in range(1, 6):
+        x = _inception_resnet_block(
+            x,
+            scale=0.17,
+            block_type="Block35",
+            block_idx=block_idx,
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+
+    # Mixed 6a (Reduction-A block):
+    channel_axis = 1 if K.image_data_format() == "channels_first" else 3
+    name_fmt = partial(_generate_layer_name, prefix="Mixed_6a")
+    branch_0 = conv2d_bn(
+        x,
+        384,
+        3,
+        strides=2,
+        padding="valid",
+        name=name_fmt("Conv2d_1a_3x3", 0),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_1 = conv2d_bn(
+        x,
+        192,
+        1,
+        name=name_fmt("Conv2d_0a_1x1", 1),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_1 = conv2d_bn(
+        branch_1,
+        192,
+        3,
+        name=name_fmt("Conv2d_0b_3x3", 1),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_1 = conv2d_bn(
+        branch_1,
+        256,
+        3,
+        strides=2,
+        padding="valid",
+        name=name_fmt("Conv2d_1a_3x3", 1),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_pool = MyTimeDistributed(
+        MaxPooling2D(3, strides=2, padding="valid", name=name_fmt("MaxPool_1a_3x3", 2))
+    )(x)
+    branches = [branch_0, branch_1, branch_pool]
+    x = Concatenate(
+        axis=channel_axis + 1 if timedistributed else channel_axis, name="Mixed_6a"
+    )(branches)
+
+    # 10x Block17 (Inception-ResNet-B block):
+    for block_idx in range(1, 11):
+        x = _inception_resnet_block(
+            x,
+            scale=0.1,
+            block_type="Block17",
+            block_idx=block_idx,
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+
+    # Mixed 7a (Reduction-B block): 8 x 8 x 2080
+    name_fmt = partial(_generate_layer_name, prefix="Mixed_7a")
+    branch_0 = conv2d_bn(
+        x,
+        256,
+        1,
+        name=name_fmt("Conv2d_0a_1x1", 0),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_0 = conv2d_bn(
+        branch_0,
+        384,
+        3,
+        strides=2,
+        padding="valid",
+        name=name_fmt("Conv2d_1a_3x3", 0),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_1 = conv2d_bn(
+        x,
+        256,
+        1,
+        name=name_fmt("Conv2d_0a_1x1", 1),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_1 = conv2d_bn(
+        branch_1,
+        256,
+        3,
+        strides=2,
+        padding="valid",
+        name=name_fmt("Conv2d_1a_3x3", 1),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_2 = conv2d_bn(
+        x,
+        256,
+        1,
+        name=name_fmt("Conv2d_0a_1x1", 2),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_2 = conv2d_bn(
+        branch_2,
+        256,
+        3,
+        name=name_fmt("Conv2d_0b_3x3", 2),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_2 = conv2d_bn(
+        branch_2,
+        256,
+        3,
+        strides=2,
+        padding="valid",
+        name=name_fmt("Conv2d_1a_3x3", 2),
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+    branch_pool = MyTimeDistributed(
+        MaxPooling2D(3, strides=2, padding="valid", name=name_fmt("MaxPool_1a_3x3", 3))
+    )(x)
+    branches = [branch_0, branch_1, branch_2, branch_pool]
+    x = Concatenate(
+        axis=channel_axis + 1 if timedistributed else channel_axis, name="Mixed_7a"
+    )(branches)
+
+    # 5x Block8 (Inception-ResNet-C block):
+    for block_idx in range(1, 6):
+        x = _inception_resnet_block(
+            x,
+            scale=0.2,
+            block_type="Block8",
+            block_idx=block_idx,
+            kernel_regularizer=kernel_regularizer,
+            timedistributed=timedistributed,
+            training=training,
+        )
+    x = _inception_resnet_block(
+        x,
+        scale=1.0,
+        activation=None,
+        block_type="Block8",
+        block_idx=6,
+        kernel_regularizer=kernel_regularizer,
+        timedistributed=timedistributed,
+        training=training,
+    )
+
+    # Classification block
+    x = MyTimeDistributed(GlobalAveragePooling2D(name="AvgPool"))(x)
+    x = MyTimeDistributed(Dropout(1.0 - dropout_keep_prob, name="Dropout"))(
+        x, training=training
+    )
+    # Bottleneck
+    x = MyTimeDistributed(
+        Dense(
+            classes,
+            use_bias=False,
+            name="Bottleneck",
+            kernel_regularizer=kernel_regularizer,
+        )
+    )(x)
+    bn_name = _generate_layer_name("BatchNorm", prefix="Bottleneck")
+    x = BatchNormalization(momentum=0.995, epsilon=0.001, scale=False, name=bn_name)(
+        x, training=training
+    )
+
+    if timedistributed:
+        return x
+    # Create model
+    model = Model(inputs, x, name="inception_resnet_v1")
+    if weights_path is not None:
+        logger.info("restoring model weights from %s", weights_path)
+        model.load_weights(weights_path)
+
+    return model
+
+
+if __name__ == "__main__":
+    import pkg_resources
+
+    tf.enable_eager_execution()
+    import numpy as np
+    from bob.extension import rc
+
+    def input_fn():
+        features = {
+            "data": np.empty((100, 160, 160, 3), dtype="float32"),
+            "key": "path",
+        }
+        labels = {"bio": 10, "pad": int(True)}
+        dataset = tf.data.Dataset.from_tensors((features, labels))
+        return dataset.repeat(2).batch(1)
+
+    input_shape = (None, 160, 160, 3)
+    inputs = tf.keras.layers.Input(input_shape, name="data")
+    key = tf.keras.layers.Input((None,), name="key")
+    embedding = InceptionResNetV1(timedistributed=True, inputs=inputs)
+    tf.keras.Model(inputs, embedding, name="inception_resnet_v1").load_weights(
+        rc["bob.learn.tensorflow.facenet_keras_weights"]
+    )
+    model = tf.keras.layers.LSTM(128)(embedding)
+    bio = tf.keras.layers.Dense(10, activation="softmax", name="bio")(model)
+    pad = tf.keras.layers.Dense(2, activation="softmax", name="pad")(model)
+    model = tf.keras.Model(inputs=[inputs, key], outputs=[bio, pad])
+    # model.build(input_shape=[None] + list(input_shape))
+    model.compile(
+        optimizer=tf.train.AdamOptimizer(),
+        loss="categorical_crossentropy",
+        loss_weights=[0.5, 0.5],
+        metrics=["accuracy"],
+    )
+
+    # model.fit(input_fn(), steps_per_epoch=1)
+
+    estimator = tf.keras.estimator.model_to_estimator(
+        model,
+        model_dir="/scratch/amohammadi/tmp/keras_model",
+        # config=run_config,
+    )
+
+    estimator.train(input_fn)
diff --git a/bob/learn/tensorflow/models/simple_cnn.py b/bob/learn/tensorflow/models/simple_cnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..64a37aa5a236b00037a933669a46997ce07966d3
--- /dev/null
+++ b/bob/learn/tensorflow/models/simple_cnn.py
@@ -0,0 +1,98 @@
+"""
+The network using keras (same as new_architecture function below)::
+
+    from tensorflow.python.keras import *
+    from tensorflow.python.keras.layers import *
+    simplecnn = Sequential([
+        Conv2D(32,(3,3),padding='same',use_bias=False, input_shape=(28,28,3)),
+        BatchNormalization(scale=False),
+        Activation('relu'),
+        MaxPool2D(padding='same'),
+        Conv2D(64,(3,3),padding='same',use_bias=False),
+        BatchNormalization(scale=False),
+        Activation('relu'),
+        MaxPool2D(padding='same'),
+        Flatten(),
+        Dense(1024, use_bias=False),
+        BatchNormalization(scale=False),
+        Activation('relu'),
+        Dropout(rate=0.4),
+        Dense(2),
+    ])
+    simplecnn.summary()
+    _________________________________________________________________
+    Layer (type)                 Output Shape              Param #
+    =================================================================
+    conv2d_1 (Conv2D)            (None, 28, 28, 32)        864
+    _________________________________________________________________
+    batch_normalization_1 (Batch (None, 28, 28, 32)        96
+    _________________________________________________________________
+    activation_1 (Activation)    (None, 28, 28, 32)        0
+    _________________________________________________________________
+    max_pooling2d_1 (MaxPooling2 (None, 14, 14, 32)        0
+    _________________________________________________________________
+    conv2d_2 (Conv2D)            (None, 14, 14, 64)        18432
+    _________________________________________________________________
+    batch_normalization_2 (Batch (None, 14, 14, 64)        192
+    _________________________________________________________________
+    activation_2 (Activation)    (None, 14, 14, 64)        0
+    _________________________________________________________________
+    max_pooling2d_2 (MaxPooling2 (None, 7, 7, 64)          0
+    _________________________________________________________________
+    flatten_1 (Flatten)          (None, 3136)              0
+    _________________________________________________________________
+    dense_1 (Dense)              (None, 1024)              3211264
+    _________________________________________________________________
+    batch_normalization_3 (Batch (None, 1024)              3072
+    _________________________________________________________________
+    activation_3 (Activation)    (None, 1024)              0
+    _________________________________________________________________
+    dropout_1 (Dropout)          (None, 1024)              0
+    _________________________________________________________________
+    dense_2 (Dense)              (None, 2)                 2050
+    =================================================================
+    Total params: 3,235,970
+    Trainable params: 3,233,730
+    Non-trainable params: 2,240
+    _________________________________________________________________
+"""
+
+from tensorflow.python.keras import Sequential, Input
+from tensorflow.python.keras.layers import (
+    Conv2D,
+    BatchNormalization,
+    Activation,
+    MaxPool2D,
+    Flatten,
+    Dense,
+    Dropout,
+    TimeDistributed,
+)
+
+
+def SimpleCNN(input_shape=(28, 28, 3), inputs=None, timedistributed=False):
+
+    if inputs is None:
+        inputs = Input(input_shape)
+    layers = [
+        Conv2D(32, (3, 3), padding="same", use_bias=False),
+        BatchNormalization(scale=False),
+        Activation("relu"),
+        MaxPool2D(padding="same"),
+        Conv2D(64, (3, 3), padding="same", use_bias=False),
+        BatchNormalization(scale=False),
+        Activation("relu"),
+        MaxPool2D(padding="same"),
+        Flatten(),
+        Dense(1024, use_bias=False),
+        BatchNormalization(scale=False),
+        Activation("relu"),
+        Dropout(rate=0.4),
+        # Dense(2, activation='softmax'),
+    ]
+    if timedistributed:
+        for i, layer in enumerate(layers):
+            layers[i] = TimeDistributed(layer)
+        return layers
+    simplecnn = Sequential([inputs] + layers)
+    return simplecnn
diff --git a/bob/learn/tensorflow/network/MultiScale.py b/bob/learn/tensorflow/network/MultiScale.py
new file mode 100644
index 0000000000000000000000000000000000000000..9029a17ba244cacdff827a6d5ad01e892b092d1e
--- /dev/null
+++ b/bob/learn/tensorflow/network/MultiScale.py
@@ -0,0 +1,138 @@
+"""Contains a model definition for AlexNet.
+
+This work was first described in:
+  ImageNet Classification with Deep Convolutional Neural Networks
+  Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
+
+and later refined in:
+  One weird trick for parallelizing convolutional neural networks
+  Alex Krizhevsky, 2014
+
+Here we provide the implementation proposed in "One weird trick" and not
+"ImageNet Classification", as per the paper, the LRN layers have been removed.
+
+Usage:
+  with slim.arg_scope(alexnet.multiscalecnn_arg_scope()):
+    outputs, end_points = alexnet.multiscalecnn(inputs)
+
+"""
+
+
+def multiscalecnn(
+    inputs,
+    mode,
+    num_classes=2,
+    dropout_keep_prob=0.5,
+    spatial_squeeze=True,
+    scope="multiscalecnn",
+    reuse=False,
+):
+    """AlexNet version 2.
+
+    Described in: http://arxiv.org/pdf/1404.5997v2.pdf
+    Parameters from:
+    github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
+    layers-imagenet-1gpu.cfg
+
+    Note: All the fully_connected layers have been transformed to conv2d slim.
+          To use in classification mode, resize input to 224x224. To use in fully
+          convolutional mode, set spatial_squeeze to false.
+          The LRN layers have been removed and change the initializers from
+          random_normal_initializer to xavier_initializer.
+
+    Args:
+      inputs: a tensor of size [batch_size, height, width, channels].
+      num_classes: number of predicted classes.
+      is_training: whether or not the model is being trained.
+      dropout_keep_prob: the probability that activations are kept in the dropout
+        layers during training.
+      spatial_squeeze: whether or not should squeeze the spatial dimensions of the
+        outputs. Useful to remove unnecessary dimensions for classification.
+      scope: Optional scope for the variables.
+
+    Returns:
+      the last op containing the log predictions and end_points dict.
+    """
+    with tf.variable_scope(scope, "multiscalecnn", [inputs], reuse=reuse) as sc:
+        end_points_collection = sc.original_name_scope + "_end_points"
+        # Collect outputs for conv2d, fully_connected, max_pool2d, and batch_norm.
+        with slim.arg_scope(
+            [slim.conv2d, slim.fully_connected, slim.max_pool2d, slim.batch_norm],
+            outputs_collections=[end_points_collection],
+        ):
+            net = slim.conv2d(inputs, 3, 1, scope="conv1")
+            net = slim.conv2d(net, 32, 3, scope="conv2a")
+            net = slim.conv2d(net, 32, 3, scope="conv2b")
+            net = slim.max_pool2d(net, 3, scope="pool1")
+            first_scale = slim.dropout(net, dropout_keep_prob, scope="dropout1")
+
+            net = slim.conv2d(first_scale, 64, 3, scope="conv3a")
+            net = slim.conv2d(net, 64, 3, scope="conv3b")
+            net = slim.max_pool2d(net, 3, scope="pool2")
+            second_scale = slim.dropout(net, dropout_keep_prob, scope="dropout2")
+
+            net = slim.conv2d(second_scale, 64, 3, scope="conv4a")
+            net = slim.conv2d(net, 64, 3, scope="conv4b")
+            net = slim.max_pool2d(net, 3, scope="pool3")
+            third_scale = slim.dropout(net, dropout_keep_prob, scope="dropout3")
+
+            # add 1x1 convs
+            first_scale = slim.conv2d(
+                first_scale, 1, 1, scope="conv5", activation_fn=None
+            )
+            second_scale = slim.conv2d(
+                second_scale, 1, 1, scope="conv6", activation_fn=None
+            )
+            third_scale = slim.conv2d(
+                third_scale, 1, 1, scope="conv7", activation_fn=None
+            )
+
+            # AlexNet
+            net = slim.conv2d(inputs, 64, [11, 11], 4, padding="VALID", scope="conv1")
+            net = slim.conv2d(net, 192, [5, 5], scope="conv2")
+            net = slim.max_pool2d(net, [3, 3], 2, scope="pool2")
+            net = slim.conv2d(net, 384, [3, 3], scope="conv3")
+            net = slim.conv2d(net, 384, [3, 3], scope="conv4")
+            net = slim.conv2d(net, 256, [3, 3], scope="conv5")
+            net = slim.max_pool2d(net, [3, 3], 2, scope="pool5")
+
+            # Use conv2d instead of fully_connected layers.
+            with slim.arg_scope(
+                [slim.conv2d],
+                weights_initializer=trunc_normal(0.005),
+                biases_initializer=init_ops.constant_initializer(0.1),
+            ):
+                net = slim.conv2d(net, 4096, [5, 5], padding="VALID", scope="fc6")
+                net = slim.dropout(
+                    net, dropout_keep_prob, is_training=is_training, scope="dropout6"
+                )
+                net = slim.conv2d(net, 4096, [1, 1], scope="fc7")
+                net = slim.dropout(
+                    net, dropout_keep_prob, is_training=is_training, scope="dropout7"
+                )
+                net = slim.conv2d(
+                    net,
+                    num_classes,
+                    [1, 1],
+                    activation_fn=None,
+                    normalizer_fn=None,
+                    biases_initializer=init_ops.zeros_initializer(),
+                    scope="fc8",
+                )
+
+            # Convert end_points_collection into a end_point dict.
+            end_points = utils.convert_collection_to_dict(end_points_collection)
+            if spatial_squeeze:
+                net = array_ops.squeeze(net, [1, 2], name="fc8/squeezed")
+                end_points[sc.name + "/fc8"] = net
+            return net, end_points
+
+
+multiscalecnn.default_image_size = 224
+
+
+def multiscalecnn_architecture(inputs, mode, reuse=False):
+    with slim.arg_scope(
+        multiscalecnn_arg_scope(is_training=mode == tf.estimator.ModeKeys.TRAIN)
+    ):
+        outputs, end_points = multiscalecnn(inputs)
diff --git a/bob/learn/tensorflow/network/densenet.py b/bob/learn/tensorflow/network/densenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..04ec16f65dc7f35fd7a60088599c89a04356183c
--- /dev/null
+++ b/bob/learn/tensorflow/network/densenet.py
@@ -0,0 +1,211 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+DenseNet from  arXiv:1608.06993v3
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from . import densenet_utils
+
+slim = tf.contrib.slim
+densenet_arg_scope = densenet_utils.densenet_arg_scope
+DenseBlock = densenet_utils.DenseBlock
+TransitionLayer = densenet_utils.TransitionLayer
+
+
+def densenet(
+    inputs,
+    input_batch_norm=True,
+    n_filters_first_conv=16,
+    n_dense=4,
+    growth_rate=12,
+    n_layers_per_block=[6, 12, 24, 16],
+    dropout_p=0.2,
+    bottleneck=False,
+    compression=1.0,
+    is_training=False,
+    dense_prediction=False,
+    reuse=None,
+    scope=None,
+):
+    """
+    DenseNet as described for ImageNet use. Supports B (bottleneck) and
+    C (compression) variants.
+    Args:
+      n_classes: number of classes
+      n_filters_first_conv: number of filters for the first convolution applied
+      n_dense: number of dense_blocks
+      growth_rate: number of new feature maps created by each layer in a dense block
+      n_layers_per_block: number of layers per block. Can be an int or a list of size 2 * n_dense + 1
+      dropout_p: dropout rate applied after each convolution (0. for not using)
+          is_training: whether is training or not.
+      dense_prediction: Bool, defaults to False
+      reuse: whether or not the network and its variables should be reused. To be
+        able to reuse 'scope' must be given.
+      scope: Optional variable_scope.
+    Returns:
+    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+      if
+    end_points: A dictionary from components of the network to the corresponding
+      activation.
+    """
+    # check n_layers_per_block argument
+    if type(n_layers_per_block) == list:
+        assert len(n_layers_per_block) == n_dense
+    elif type(n_layers_per_block) == int:
+        n_layers_per_block = [n_layers_per_block] * n_dense
+    else:
+        raise ValueError
+
+    with tf.variable_scope(scope, "densenet", [inputs], reuse=reuse) as sc:
+        end_points_collection = sc.name + "_end_points"
+        with slim.arg_scope(
+            [slim.conv2d, DenseBlock, TransitionLayer],
+            outputs_collections=end_points_collection,
+        ):
+            with slim.arg_scope(
+                [slim.batch_norm, slim.dropout], is_training=is_training
+            ):
+
+                if input_batch_norm:
+                    inputs = slim.batch_norm(
+                        inputs,
+                        decay=0.997,
+                        epsilon=1e-5,
+                        scale=True,
+                        activation_fn=None,
+                        updates_collections=tf.GraphKeys.UPDATE_OPS,
+                    )
+
+                #####################
+                # First Convolution #
+                #####################
+                # We perform a first convolution.
+                # If DenseNet BC, first convolution has 2*growth_rate output channels
+                if bottleneck and compression < 1.0:
+                    n_filters_first_conv = 2 * growth_rate
+                net = slim.conv2d(
+                    inputs,
+                    n_filters_first_conv,
+                    [7, 7],
+                    stride=[2, 2],
+                    scope="first_conv",
+                )
+                net = slim.pool(net, [2, 2], stride=[2, 2], pooling_type="MAX")
+                n_filters = n_filters_first_conv
+
+                #####################
+                #    Dense blocks   #
+                #####################
+
+                for i in range(n_dense - 1):
+                    # Dense Block
+                    net, _ = DenseBlock(
+                        net,
+                        n_layers_per_block[i],
+                        growth_rate,
+                        dropout_p,
+                        bottleneck=bottleneck,
+                        scope="denseblock%d" % (i + 1),
+                    )
+                    n_filters += n_layers_per_block[i] * growth_rate
+
+                    # Transition layer
+                    net = TransitionLayer(
+                        net,
+                        n_filters,
+                        dropout_p,
+                        compression=compression,
+                        scope="transition%d" % (i + 1),
+                    )
+
+                # Final dense block (no transition layer afterwards)
+                net, _ = DenseBlock(
+                    net,
+                    n_layers_per_block[n_dense - 1],
+                    growth_rate,
+                    dropout_p,
+                    scope="denseblock%d" % (n_dense),
+                )
+
+                #####################
+                #      Outputs      #
+                #####################
+                pool_name = "pool%d" % (n_dense + 1)
+                if dense_prediction:
+                    net = slim.pool(net, [7, 7], pooling_type="AVG", scope=pool_name)
+                    net = slim.flatten(net, scope="prelogits")
+                    # net = slim.conv2d(
+                    #     net,
+                    #     num_classes,
+                    #     [1, 1],
+                    #     activation_fn=None,
+                    #     normalizer_fn=None,
+                    #     scope="logits",
+                    # )
+
+                else:
+                    net = tf.reduce_mean(net, [1, 2], name=pool_name, keepdims=True)
+                    net = slim.flatten(net, scope="prelogits")
+                    # net = slim.conv2d(
+                    #     net,
+                    #     num_classes,
+                    #     [1, 1],
+                    #     activation_fn=None,
+                    #     normalizer_fn=None,
+                    #     scope="4Dlogits",
+                    # )
+                    # net = tf.squeeze(net, [1, 2], name="logits")
+
+                # Convert end_points_collection into a dictionary of end_points.
+                end_points = slim.utils.convert_collection_to_dict(
+                    end_points_collection
+                )
+
+                # end_points["predictions"] = slim.softmax(net, scope="predictions")
+                return net, end_points
+
+
+def densenet_161(inputs, mode, trainable_variables=None, reuse=None, scope=None):
+    with tf.contrib.slim.arg_scope(
+        densenet_arg_scope(
+            weight_decay=1e-4,
+            batch_norm_decay=0.997,
+            batch_norm_epsilon=1e-5,
+            batch_norm_scale=True,
+            activation_fn=tf.nn.relu,
+            use_batch_norm=True,
+        )
+    ):
+        return densenet(
+            inputs,
+            input_batch_norm=True,
+            n_filters_first_conv=96,
+            n_dense=4,
+            growth_rate=48,
+            n_layers_per_block=[6, 12, 36, 24],
+            dropout_p=0.0,
+            bottleneck=True,
+            compression=0.5,
+            is_training=mode == tf.estimator.ModeKeys.TRAIN,
+            dense_prediction=False,
+            reuse=reuse,
+            scope=scope,
+        )
diff --git a/bob/learn/tensorflow/network/densenet_utils.py b/bob/learn/tensorflow/network/densenet_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c986f5acfbbe61a472e40ee17941bdef5481abc
--- /dev/null
+++ b/bob/learn/tensorflow/network/densenet_utils.py
@@ -0,0 +1,188 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Contains blocks for building DenseNet-based models
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+import numpy as np
+
+
+slim = tf.contrib.slim
+
+
+def densenet_arg_scope(
+    weight_decay=0.0001,
+    batch_norm_decay=0.997,
+    batch_norm_epsilon=1e-5,
+    batch_norm_scale=True,
+    activation_fn=tf.nn.relu,
+    use_batch_norm=True,
+):
+    """
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+
+    batch_norm_decay: The moving average decay when estimating layer activation
+    statistics in batch normalization.
+
+    batch_norm_epsilon: Small constant to prevent division by zero when
+    normalizing activations by their variance in batch normalization.
+
+    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+    activations in the batch normalization layer.
+
+    activation_fn: The activation function which is used in ResNet.
+
+    use_batch_norm: Whether or not to use batch normalization.
+
+  Returns:
+    An `arg_scope` to use for the densenet models.
+  """
+    batch_norm_params = {
+        "decay": batch_norm_decay,
+        "epsilon": batch_norm_epsilon,
+        "scale": batch_norm_scale,
+        "activation_fn": activation_fn,
+        "updates_collections": tf.GraphKeys.UPDATE_OPS,
+    }
+
+    with slim.arg_scope(
+        [slim.conv2d],
+        padding="SAME",
+        weights_regularizer=slim.l2_regularizer(weight_decay),
+        weights_initializer=slim.variance_scaling_initializer(),
+        activation_fn=None,
+        normalizer_fn=slim.batch_norm if use_batch_norm else None,
+        normalizer_params=batch_norm_params,
+    ):
+        with slim.arg_scope([slim.batch_norm], **batch_norm_params) as arg_sc:
+            return arg_sc
+
+
+def preact_conv(inputs, n_filters, filter_size=[3, 3], dropout_p=0.2):
+    """
+    Basic pre-activation layer for DenseNets
+    Apply successivly BatchNormalization, ReLU nonlinearity, Convolution and
+    Dropout (if dropout_p > 0) on the inputs
+    """
+    preact = slim.batch_norm(inputs)
+    conv = slim.conv2d(preact, n_filters, filter_size, normalizer_fn=None)
+    if dropout_p != 0.0:
+        conv = slim.dropout(conv, keep_prob=(1.0 - dropout_p))
+    return conv
+
+
+@slim.add_arg_scope
+def DenseBlock(
+    stack,
+    n_layers,
+    growth_rate,
+    dropout_p,
+    bottleneck=False,
+    scope=None,
+    outputs_collections=None,
+):
+    """
+  DenseBlock for DenseNet and FC-DenseNet
+
+  Args:
+    stack: input 4D tensor
+    n_layers: number of internal layers
+    growth_rate: number of feature maps per internal layer
+
+  Returns:
+    stack: current stack of feature maps (4D tensor)
+    new_features: 4D tensor containing only the new feature maps generated
+      in this block
+  """
+    with tf.name_scope(scope) as sc:
+        new_features = []
+        for j in range(n_layers):
+            # Compute new feature maps
+            # if bottleneck, do a 1x1 conv before the 3x3
+            if bottleneck:
+                stack = preact_conv(
+                    stack, 4 * growth_rate, filter_size=[1, 1], dropout_p=0.0
+                )
+            layer = preact_conv(stack, growth_rate, dropout_p=dropout_p)
+            new_features.append(layer)
+            # stack new layer
+            stack = tf.concat([stack, layer], axis=-1)
+        new_features = tf.concat(new_features, axis=-1)
+        return stack, new_features
+
+
+@slim.add_arg_scope
+def TransitionLayer(
+    inputs,
+    n_filters,
+    dropout_p=0.2,
+    compression=1.0,
+    scope=None,
+    outputs_collections=None,
+):
+    """
+  Transition layer for DenseNet
+  Apply 1x1 BN  + conv then 2x2 max pooling
+  """
+    with tf.name_scope(scope) as sc:
+        if compression < 1.0:
+            n_filters = int(np.floor(n_filters * compression))
+        l = preact_conv(inputs, n_filters, filter_size=[1, 1], dropout_p=dropout_p)
+        l = slim.pool(l, [2, 2], stride=[2, 2], pooling_type="AVG")
+
+        return l
+
+
+@slim.add_arg_scope
+def TransitionDown(
+    inputs, n_filters, dropout_p=0.2, scope=None, outputs_collections=None
+):
+    """
+  Transition Down (TD) for FC-DenseNet
+  Apply 1x1 BN + ReLU + conv then 2x2 max pooling
+  """
+    with tf.name_scope(scope) as sc:
+        l = preact_conv(inputs, n_filters, filter_size=[1, 1], dropout_p=dropout_p)
+        l = slim.pool(l, [2, 2], stride=[2, 2], pooling_type="MAX")
+        return l
+
+
+@slim.add_arg_scope
+def TransitionUp(
+    block_to_upsample,
+    skip_connection,
+    n_filters_keep,
+    scope=None,
+    outputs_collections=None,
+):
+    """
+  Transition Up for FC-DenseNet
+  Performs upsampling on block_to_upsample by a factor 2 and concatenates it
+  with the skip_connection
+  """
+    with tf.name_scope(scope) as sc:
+        # Upsample
+        l = slim.conv2d_transpose(
+            block_to_upsample, n_filters_keep, kernel_size=[3, 3], stride=[2, 2]
+        )
+        # Concatenate with skip connection
+        l = tf.concat([l, skip_connection], axis=-1)
+        return l
diff --git a/bob/learn/tensorflow/network/inception_resnet_v2.py b/bob/learn/tensorflow/network/inception_resnet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c358a5bf9a007803ce7c814553ba76f0809ba0b9
--- /dev/null
+++ b/bob/learn/tensorflow/network/inception_resnet_v2.py
@@ -0,0 +1,504 @@
+# Copied from https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the definition of the Inception Resnet V2 architecture.
+
+As described in http://arxiv.org/abs/1602.07261.
+
+  Inception-v4, Inception-ResNet and the Impact of Residual Connections
+    on Learning
+  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
+    """Builds the 35x35 resnet block."""
+    with tf.variable_scope(scope, "Block35", [net], reuse=reuse):
+        with tf.variable_scope("Branch_0"):
+            tower_conv = slim.conv2d(net, 32, 1, scope="Conv2d_1x1")
+        with tf.variable_scope("Branch_1"):
+            tower_conv1_0 = slim.conv2d(net, 32, 1, scope="Conv2d_0a_1x1")
+            tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope="Conv2d_0b_3x3")
+        with tf.variable_scope("Branch_2"):
+            tower_conv2_0 = slim.conv2d(net, 32, 1, scope="Conv2d_0a_1x1")
+            tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope="Conv2d_0b_3x3")
+            tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope="Conv2d_0c_3x3")
+        mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_1, tower_conv2_2])
+        up = slim.conv2d(
+            mixed,
+            net.get_shape()[3],
+            1,
+            normalizer_fn=None,
+            activation_fn=None,
+            scope="Conv2d_1x1",
+        )
+        scaled_up = up * scale
+        if activation_fn == tf.nn.relu6:
+            # Use clip_by_value to simulate bandpass activation.
+            scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)
+
+        net += scaled_up
+        if activation_fn:
+            net = activation_fn(net)
+    return net
+
+
+def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
+    """Builds the 17x17 resnet block."""
+    with tf.variable_scope(scope, "Block17", [net], reuse=reuse):
+        with tf.variable_scope("Branch_0"):
+            tower_conv = slim.conv2d(net, 192, 1, scope="Conv2d_1x1")
+        with tf.variable_scope("Branch_1"):
+            tower_conv1_0 = slim.conv2d(net, 128, 1, scope="Conv2d_0a_1x1")
+            tower_conv1_1 = slim.conv2d(
+                tower_conv1_0, 160, [1, 7], scope="Conv2d_0b_1x7"
+            )
+            tower_conv1_2 = slim.conv2d(
+                tower_conv1_1, 192, [7, 1], scope="Conv2d_0c_7x1"
+            )
+        mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
+        up = slim.conv2d(
+            mixed,
+            net.get_shape()[3],
+            1,
+            normalizer_fn=None,
+            activation_fn=None,
+            scope="Conv2d_1x1",
+        )
+
+        scaled_up = up * scale
+        if activation_fn == tf.nn.relu6:
+            # Use clip_by_value to simulate bandpass activation.
+            scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)
+
+        net += scaled_up
+        if activation_fn:
+            net = activation_fn(net)
+    return net
+
+
+def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
+    """Builds the 8x8 resnet block."""
+    with tf.variable_scope(scope, "Block8", [net], reuse=reuse):
+        with tf.variable_scope("Branch_0"):
+            tower_conv = slim.conv2d(net, 192, 1, scope="Conv2d_1x1")
+        with tf.variable_scope("Branch_1"):
+            tower_conv1_0 = slim.conv2d(net, 192, 1, scope="Conv2d_0a_1x1")
+            tower_conv1_1 = slim.conv2d(
+                tower_conv1_0, 224, [1, 3], scope="Conv2d_0b_1x3"
+            )
+            tower_conv1_2 = slim.conv2d(
+                tower_conv1_1, 256, [3, 1], scope="Conv2d_0c_3x1"
+            )
+        mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
+        up = slim.conv2d(
+            mixed,
+            net.get_shape()[3],
+            1,
+            normalizer_fn=None,
+            activation_fn=None,
+            scope="Conv2d_1x1",
+        )
+
+        scaled_up = up * scale
+        if activation_fn == tf.nn.relu6:
+            # Use clip_by_value to simulate bandpass activation.
+            scaled_up = tf.clip_by_value(scaled_up, -6.0, 6.0)
+
+        net += scaled_up
+        if activation_fn:
+            net = activation_fn(net)
+    return net
+
+
+def inception_resnet_v2_base(
+    inputs,
+    final_endpoint="Conv2d_7b_1x1",
+    output_stride=16,
+    align_feature_maps=False,
+    scope=None,
+    activation_fn=tf.nn.relu,
+):
+    """Inception model from  http://arxiv.org/abs/1602.07261.
+
+  Constructs an Inception Resnet v2 network from inputs to the given final
+  endpoint. This method can construct the network up to the final inception
+  block Conv2d_7b_1x1.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    final_endpoint: specifies the endpoint to construct the network up to. It
+      can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
+      'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
+      'Mixed_5b', 'Mixed_6a', 'PreAuxLogits', 'Mixed_7a', 'Conv2d_7b_1x1']
+    output_stride: A scalar that specifies the requested ratio of input to
+      output spatial resolution. Only supports 8 and 16.
+    align_feature_maps: When true, changes all the VALID paddings in the network
+      to SAME padding so that the feature maps are aligned.
+    scope: Optional variable_scope.
+    activation_fn: Activation function for block scopes.
+
+  Returns:
+    tensor_out: output tensor corresponding to the final_endpoint.
+    end_points: a set of activations for external use, for example summaries or
+                losses.
+
+  Raises:
+    ValueError: if final_endpoint is not set to one of the predefined values,
+      or if the output_stride is not 8 or 16, or if the output_stride is 8 and
+      we request an end point after 'PreAuxLogits'.
+  """
+    if output_stride != 8 and output_stride != 16:
+        raise ValueError("output_stride must be 8 or 16.")
+
+    padding = "SAME" if align_feature_maps else "VALID"
+
+    end_points = {}
+
+    def add_and_check_final(name, net):
+        end_points[name] = net
+        return name == final_endpoint
+
+    with tf.variable_scope(scope, "InceptionResnetV2", [inputs]):
+        with slim.arg_scope(
+            [slim.conv2d, slim.max_pool2d, slim.avg_pool2d], stride=1, padding="SAME"
+        ):
+            # 149 x 149 x 32
+            net = slim.conv2d(
+                inputs, 32, 3, stride=2, padding=padding, scope="Conv2d_1a_3x3"
+            )
+            if add_and_check_final("Conv2d_1a_3x3", net):
+                return net, end_points
+
+            # 147 x 147 x 32
+            net = slim.conv2d(net, 32, 3, padding=padding, scope="Conv2d_2a_3x3")
+            if add_and_check_final("Conv2d_2a_3x3", net):
+                return net, end_points
+            # 147 x 147 x 64
+            net = slim.conv2d(net, 64, 3, scope="Conv2d_2b_3x3")
+            if add_and_check_final("Conv2d_2b_3x3", net):
+                return net, end_points
+            # 73 x 73 x 64
+            net = slim.max_pool2d(
+                net, 3, stride=2, padding=padding, scope="MaxPool_3a_3x3"
+            )
+            if add_and_check_final("MaxPool_3a_3x3", net):
+                return net, end_points
+            # 73 x 73 x 80
+            net = slim.conv2d(net, 80, 1, padding=padding, scope="Conv2d_3b_1x1")
+            if add_and_check_final("Conv2d_3b_1x1", net):
+                return net, end_points
+            # 71 x 71 x 192
+            net = slim.conv2d(net, 192, 3, padding=padding, scope="Conv2d_4a_3x3")
+            if add_and_check_final("Conv2d_4a_3x3", net):
+                return net, end_points
+            # 35 x 35 x 192
+            net = slim.max_pool2d(
+                net, 3, stride=2, padding=padding, scope="MaxPool_5a_3x3"
+            )
+            if add_and_check_final("MaxPool_5a_3x3", net):
+                return net, end_points
+
+            # 35 x 35 x 320
+            with tf.variable_scope("Mixed_5b"):
+                with tf.variable_scope("Branch_0"):
+                    tower_conv = slim.conv2d(net, 96, 1, scope="Conv2d_1x1")
+                with tf.variable_scope("Branch_1"):
+                    tower_conv1_0 = slim.conv2d(net, 48, 1, scope="Conv2d_0a_1x1")
+                    tower_conv1_1 = slim.conv2d(
+                        tower_conv1_0, 64, 5, scope="Conv2d_0b_5x5"
+                    )
+                with tf.variable_scope("Branch_2"):
+                    tower_conv2_0 = slim.conv2d(net, 64, 1, scope="Conv2d_0a_1x1")
+                    tower_conv2_1 = slim.conv2d(
+                        tower_conv2_0, 96, 3, scope="Conv2d_0b_3x3"
+                    )
+                    tower_conv2_2 = slim.conv2d(
+                        tower_conv2_1, 96, 3, scope="Conv2d_0c_3x3"
+                    )
+                with tf.variable_scope("Branch_3"):
+                    tower_pool = slim.avg_pool2d(
+                        net, 3, stride=1, padding="SAME", scope="AvgPool_0a_3x3"
+                    )
+                    tower_pool_1 = slim.conv2d(tower_pool, 64, 1, scope="Conv2d_0b_1x1")
+                net = tf.concat(
+                    [tower_conv, tower_conv1_1, tower_conv2_2, tower_pool_1], 3
+                )
+
+            if add_and_check_final("Mixed_5b", net):
+                return net, end_points
+            # TODO(alemi): Register intermediate endpoints
+            net = slim.repeat(net, 10, block35, scale=0.17, activation_fn=activation_fn)
+
+            # 17 x 17 x 1088 if output_stride == 8,
+            # 33 x 33 x 1088 if output_stride == 16
+            use_atrous = output_stride == 8
+
+            with tf.variable_scope("Mixed_6a"):
+                with tf.variable_scope("Branch_0"):
+                    tower_conv = slim.conv2d(
+                        net,
+                        384,
+                        3,
+                        stride=1 if use_atrous else 2,
+                        padding=padding,
+                        scope="Conv2d_1a_3x3",
+                    )
+                with tf.variable_scope("Branch_1"):
+                    tower_conv1_0 = slim.conv2d(net, 256, 1, scope="Conv2d_0a_1x1")
+                    tower_conv1_1 = slim.conv2d(
+                        tower_conv1_0, 256, 3, scope="Conv2d_0b_3x3"
+                    )
+                    tower_conv1_2 = slim.conv2d(
+                        tower_conv1_1,
+                        384,
+                        3,
+                        stride=1 if use_atrous else 2,
+                        padding=padding,
+                        scope="Conv2d_1a_3x3",
+                    )
+                with tf.variable_scope("Branch_2"):
+                    tower_pool = slim.max_pool2d(
+                        net,
+                        3,
+                        stride=1 if use_atrous else 2,
+                        padding=padding,
+                        scope="MaxPool_1a_3x3",
+                    )
+                net = tf.concat([tower_conv, tower_conv1_2, tower_pool], 3)
+
+            if add_and_check_final("Mixed_6a", net):
+                return net, end_points
+
+            # TODO(alemi): register intermediate endpoints
+            with slim.arg_scope([slim.conv2d], rate=2 if use_atrous else 1):
+                net = slim.repeat(
+                    net, 20, block17, scale=0.10, activation_fn=activation_fn
+                )
+            if add_and_check_final("PreAuxLogits", net):
+                return net, end_points
+
+            if output_stride == 8:
+                # TODO(gpapan): Properly support output_stride for the rest of the net.
+                raise ValueError(
+                    "output_stride==8 is only supported up to the "
+                    "PreAuxlogits end_point for now."
+                )
+
+            # 8 x 8 x 2080
+            with tf.variable_scope("Mixed_7a"):
+                with tf.variable_scope("Branch_0"):
+                    tower_conv = slim.conv2d(net, 256, 1, scope="Conv2d_0a_1x1")
+                    tower_conv_1 = slim.conv2d(
+                        tower_conv,
+                        384,
+                        3,
+                        stride=2,
+                        padding=padding,
+                        scope="Conv2d_1a_3x3",
+                    )
+                with tf.variable_scope("Branch_1"):
+                    tower_conv1 = slim.conv2d(net, 256, 1, scope="Conv2d_0a_1x1")
+                    tower_conv1_1 = slim.conv2d(
+                        tower_conv1,
+                        288,
+                        3,
+                        stride=2,
+                        padding=padding,
+                        scope="Conv2d_1a_3x3",
+                    )
+                with tf.variable_scope("Branch_2"):
+                    tower_conv2 = slim.conv2d(net, 256, 1, scope="Conv2d_0a_1x1")
+                    tower_conv2_1 = slim.conv2d(
+                        tower_conv2, 288, 3, scope="Conv2d_0b_3x3"
+                    )
+                    tower_conv2_2 = slim.conv2d(
+                        tower_conv2_1,
+                        320,
+                        3,
+                        stride=2,
+                        padding=padding,
+                        scope="Conv2d_1a_3x3",
+                    )
+                with tf.variable_scope("Branch_3"):
+                    tower_pool = slim.max_pool2d(
+                        net, 3, stride=2, padding=padding, scope="MaxPool_1a_3x3"
+                    )
+                net = tf.concat(
+                    [tower_conv_1, tower_conv1_1, tower_conv2_2, tower_pool], 3
+                )
+
+            if add_and_check_final("Mixed_7a", net):
+                return net, end_points
+
+            # TODO(alemi): register intermediate endpoints
+            net = slim.repeat(net, 9, block8, scale=0.20, activation_fn=activation_fn)
+            net = block8(net, activation_fn=None)
+
+            # 8 x 8 x 1536
+            net = slim.conv2d(net, 1536, 1, scope="Conv2d_7b_1x1")
+            if add_and_check_final("Conv2d_7b_1x1", net):
+                return net, end_points
+
+        raise ValueError("final_endpoint (%s) not recognized", final_endpoint)
+
+
+def inception_resnet_v2(
+    inputs,
+    num_classes=1001,
+    is_training=True,
+    dropout_keep_prob=0.8,
+    reuse=None,
+    scope="InceptionResnetV2",
+    create_aux_logits=True,
+    activation_fn=tf.nn.relu,
+):
+    """Creates the Inception Resnet V2 model.
+
+  Args:
+    inputs: a 4-D tensor of size [batch_size, height, width, 3].
+      Dimension batch_size may be undefined. If create_aux_logits is false,
+      also height and width may be undefined.
+    num_classes: number of predicted classes. If 0 or None, the logits layer
+      is omitted and the input features to the logits layer (before  dropout)
+      are returned instead.
+    is_training: whether is training or not.
+    dropout_keep_prob: float, the fraction to keep before final layer.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse 'scope' must be given.
+    scope: Optional variable_scope.
+    create_aux_logits: Whether to include the auxilliary logits.
+    activation_fn: Activation function for conv2d.
+
+  Returns:
+    net: the output of the logits layer (if num_classes is a non-zero integer),
+      or the non-dropped-out input to the logits layer (if num_classes is 0 or
+      None).
+    end_points: the set of end_points from the inception model.
+  """
+    end_points = {}
+
+    with tf.variable_scope(scope, "InceptionResnetV2", [inputs], reuse=reuse) as scope:
+        with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
+
+            net, end_points = inception_resnet_v2_base(
+                inputs, scope=scope, activation_fn=activation_fn
+            )
+
+            if create_aux_logits and num_classes:
+                with tf.variable_scope("AuxLogits"):
+                    aux = end_points["PreAuxLogits"]
+                    aux = slim.avg_pool2d(
+                        aux, 5, stride=3, padding="VALID", scope="Conv2d_1a_3x3"
+                    )
+                    aux = slim.conv2d(aux, 128, 1, scope="Conv2d_1b_1x1")
+                    aux = slim.conv2d(
+                        aux,
+                        768,
+                        aux.get_shape()[1:3],
+                        padding="VALID",
+                        scope="Conv2d_2a_5x5",
+                    )
+                    aux = slim.flatten(aux)
+                    aux = slim.fully_connected(
+                        aux, num_classes, activation_fn=None, scope="Logits"
+                    )
+                    end_points["AuxLogits"] = aux
+
+            with tf.variable_scope("Logits"):
+                # TODO(sguada,arnoegw): Consider adding a parameter global_pool which
+                # can be set to False to disable pooling here (as in resnet_*()).
+                kernel_size = net.get_shape()[1:3]
+                if kernel_size.is_fully_defined():
+                    net = slim.avg_pool2d(
+                        net, kernel_size, padding="VALID", scope="AvgPool_1a_8x8"
+                    )
+                else:
+                    net = tf.reduce_mean(
+                        net, [1, 2], keep_dims=True, name="global_pool"
+                    )
+                end_points["global_pool"] = net
+                if not num_classes:
+                    return net, end_points
+                net = slim.flatten(net)
+                net = slim.dropout(
+                    net, dropout_keep_prob, is_training=is_training, scope="Dropout"
+                )
+                end_points["PreLogitsFlatten"] = net
+                logits = slim.fully_connected(
+                    net, num_classes, activation_fn=None, scope="Logits"
+                )
+                end_points["Logits"] = logits
+                end_points["Predictions"] = tf.nn.softmax(logits, name="Predictions")
+
+        return logits, end_points
+
+
+inception_resnet_v2.default_image_size = 299
+
+
+def inception_resnet_v2_arg_scope(
+    weight_decay=0.00004,
+    batch_norm_decay=0.9997,
+    batch_norm_epsilon=0.001,
+    activation_fn=tf.nn.relu,
+    batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
+    batch_norm_scale=False,
+):
+    """Returns the scope with the default parameters for inception_resnet_v2.
+
+  Args:
+    weight_decay: the weight decay for weights variables.
+    batch_norm_decay: decay for the moving average of batch_norm momentums.
+    batch_norm_epsilon: small float added to variance to avoid dividing by zero.
+    activation_fn: Activation function for conv2d.
+    batch_norm_updates_collections: Collection for the update ops for
+      batch norm.
+    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+      activations in the batch normalization layer.
+
+  Returns:
+    a arg_scope with the parameters needed for inception_resnet_v2.
+  """
+    # Set weight_decay for weights in conv2d and fully_connected layers.
+    with slim.arg_scope(
+        [slim.conv2d, slim.fully_connected],
+        weights_regularizer=slim.l2_regularizer(weight_decay),
+        biases_regularizer=slim.l2_regularizer(weight_decay),
+    ):
+
+        batch_norm_params = {
+            "decay": batch_norm_decay,
+            "epsilon": batch_norm_epsilon,
+            "updates_collections": batch_norm_updates_collections,
+            "fused": None,  # Use fused batch norm if possible.
+            "scale": batch_norm_scale,
+        }
+        # Set activation_fn and parameters for batch_norm.
+        with slim.arg_scope(
+            [slim.conv2d],
+            activation_fn=activation_fn,
+            normalizer_fn=slim.batch_norm,
+            normalizer_params=batch_norm_params,
+        ) as scope:
+            return scope