autoencoder_face.py 3.08 KB
Newer Older
Amir MOHAMMADI's avatar
Amir MOHAMMADI committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
import tensorflow as tf
from .densenet import densenet161


def _get_l2_kw(weight_decay):
    l2_kw = {}
    if weight_decay is not None:
        l2_kw = {"kernel_regularizer": tf.keras.regularizers.l2(weight_decay)}
    return l2_kw


class ConvDecoder(tf.keras.Sequential):
    """Stack of transposed convolutions that decodes a latent vector to an image.

    The layout follows the generator in
    https://github.com/google/compare_gan/blob/master/compare_gan/architectures/sndcgan.py

    The latent vector is first reshaped to a ``1 x 1 x z_dim`` channels-last
    feature map. Each entry of ``decoder_layers`` is a
    ``(filters, kernel_size, strides, cropping)`` tuple describing one stage:

    * hidden stages: ``Conv2DTranspose`` (no bias) -> ``Cropping2D`` ->
      ``BatchNormalization`` (no scale) -> ReLU
    * final stage: ``Conv2DTranspose`` (with bias) -> ``Cropping2D`` -> tanh

    With the default ``decoder_layers`` (and Keras' default ``"valid"``
    padding) the output is a ``224 x 224 x 3`` image.

    Parameters
    ----------
    z_dim : int
        Length of the latent vector this decoder consumes.
    decoder_layers : sequence of tuple
        Per-stage ``(filters, kernel_size, strides, cropping)`` specification.
    weight_decay : float or None
        L2 factor for every transposed-convolution kernel; ``None`` means no
        kernel regularizer.
    name : str
        Model name, also used as the surrounding ``tf.name_scope``.
    **kwargs
        Extra keyword arguments passed on to ``tf.keras.Sequential``.
    """

    def __init__(
        self,
        z_dim,
        decoder_layers=(
            (512, 7, 7, 0),
            (256, 4, 2, 1),
            (128, 4, 2, 1),
            (64, 4, 2, 1),
            (32, 4, 2, 1),
            (16, 4, 2, 1),
            (3, 1, 1, 0),
        ),
        weight_decay=1e-5,
        name="Decoder",
        **kwargs,
    ):
        # Exposed for callers; this decoder is hard-wired to channels_last.
        self.z_dim = z_dim
        self.data_format = data_format = "channels_last"
        reg_kwargs = _get_l2_kw(weight_decay)
        final_stage = len(decoder_layers) - 1

        stages = [
            tf.keras.layers.Reshape((1, 1, z_dim), input_shape=(z_dim,), name="reshape")
        ]
        for stage, (n_filters, kernel, stride, crop_px) in enumerate(decoder_layers):
            is_final = stage == final_stage
            upsample = tf.keras.layers.Conv2DTranspose(
                n_filters,
                kernel,
                strides=stride,
                # Hidden stages feed BatchNormalization, whose learned offset
                # (center=True by default) makes a conv bias redundant; only
                # the output stage keeps its own bias.
                use_bias=is_final,
                data_format=data_format,
                name=f"dconv_{stage}",
                **reg_kwargs,
            )
            trim = tf.keras.layers.Cropping2D(
                cropping=crop_px, data_format=data_format, name=f"crop_{stage}"
            )
            stages += [upsample, trim]

            if is_final:
                squash = tf.keras.layers.Activation("tanh", name=f"tanh_{stage}")
                stages.append(squash)
                continue

            relu = tf.keras.layers.Activation("relu", name=f"relu_{stage}")
            # scale=False: the next layer is a ReLU, so (per the Keras docs)
            # the scaling can be left to the layer that follows.
            norm = tf.keras.layers.BatchNormalization(
                scale=False, fused=False, name=f"bn_{stage}"
            )
            stages += [norm, relu]

        with tf.name_scope(name):
            super().__init__(layers=stages, name=name, **kwargs)


class Autoencoder(tf.keras.Model):
    """A simple convolutional autoencoder built from an encoder and a decoder.

    Calling the model encodes the input and then decodes the resulting code.
    Both the code and the reconstruction are returned.

    Parameters
    ----------
    encoder : object
        Callable invoked as ``encoder(x, training=...)``; maps the input batch
        to latent codes ``z``.
    decoder : object
        Callable invoked as ``decoder(z, training=...)``; maps latent codes
        back to reconstructions ``x_hat``.
    name : str
        Name of the model.
    **kwargs
        Extra keyword arguments passed on to ``tf.keras.Model``.

    Attributes
    ----------
    encoder : object
        The encoder part.
    decoder : object
        The decoder part.
    """

    def __init__(self, encoder, decoder, name="Autoencoder", **kwargs):
        # Keras requires Model.__init__ before sub-layers are assigned as attributes.
        super().__init__(name=name, **kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, x, training=None):
        """Encode ``x`` and decode the result.

        Parameters
        ----------
        x : tensor
            Input batch for the encoder.
        training : bool or None
            Forwarded unchanged to both the encoder and the decoder.

        Returns
        -------
        tuple
            ``(z, x_hat)``: the encoder output and the decoder's output for it.
        """
        z = self.encoder(x, training=training)
        x_hat = self.decoder(z, training=training)
        return z, x_hat

def autoencoder_face(z_dim=256, weight_decay=1e-9):
    """Build the face autoencoder: a DenseNet-161 encoder plus a ConvDecoder.

    Parameters
    ----------
    z_dim : int
        Latent size; used as the encoder's ``output_classes`` and as the
        decoder's ``z_dim``.
    weight_decay : float
        L2 factor forwarded to both the encoder and the decoder.

    Returns
    -------
    Autoencoder
        Model named ``"Autoencoder"`` whose call returns ``(z, x_hat)``. The
        encoder is built with ``weights=None`` (presumably no pretrained
        weights; see ``densenet161``).
    """
    return Autoencoder(
        densenet161(
            output_classes=z_dim,
            weight_decay=weight_decay,
            weights=None,
            name="DenseNet",
        ),
        ConvDecoder(z_dim=z_dim, weight_decay=weight_decay, name="Decoder"),
        name="Autoencoder",
    )