Commit 4f80284c authored by Tiago de Freitas Pereira

Merge branch 'resnet101' into 'master'

Properly implemented resnet50 and resnet101

See merge request !94
parents 5a8c7061 324a941d
Pipeline #50181 passed
@@ -5,7 +5,6 @@ import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import GlobalAvgPool2D
def _check_input(
@@ -260,7 +259,12 @@ class ModifiedSoftMaxLayer(tf.keras.layers.Layer):
return logits
def add_bottleneck(model, bottleneck_size=128, dropout_rate=0.2):
from tensorflow.keras.layers import Flatten
def add_bottleneck(
model, bottleneck_size=128, dropout_rate=0.2, w_decay=5e-4, use_bias=True
):
"""
Append a bottleneck layer to a Keras model
@@ -276,15 +280,31 @@ def add_bottleneck(model, bottleneck_size=128, dropout_rate=0.2):
dropout_rate: float
Dropout rate
"""
if not isinstance(model, tf.keras.models.Sequential):
new_model = tf.keras.models.Sequential(model, name="bottleneck")
else:
new_model = model
new_model.add(GlobalAvgPool2D())
new_model.add(BatchNormalization())
new_model.add(Dropout(dropout_rate, name="Dropout"))
new_model.add(Dense(bottleneck_size, use_bias=False, name="embeddings"))
new_model.add(BatchNormalization(axis=-1, scale=False, name="embeddings/BatchNorm"))
new_model.add(Flatten())
if w_decay is None:
regularizer = None
else:
regularizer = tf.keras.regularizers.l2(w_decay)
new_model.add(
Dense(
bottleneck_size,
use_bias=use_bias,
kernel_regularizer=regularizer,
)
)
new_model.add(BatchNormalization(axis=-1, name="embeddings"))
# new_model.add(BatchNormalization())
return new_model
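
A hedged usage sketch (not part of this diff): attaching the updated bottleneck head to a small stand-in backbone. Any Keras model that outputs a 4D feature map should work the same way; the backbone below is only illustrative.

import tensorflow as tf

backbone = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=(112, 112, 3)),
        tf.keras.layers.Conv2D(32, 3, strides=2, use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation("relu"),
    ]
)
# The updated helper flattens the backbone output, adds an (optionally
# L2-regularized) Dense bottleneck, and batch-normalizes it into the
# "embeddings" output.
embedding_model = add_bottleneck(
    backbone, bottleneck_size=128, dropout_rate=0.2, w_decay=5e-4
)
embedding_model.summary()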
@@ -7,6 +7,8 @@ from .densenet import DenseNet
from .densenet import densenet161 # noqa: F401
from .embedding_validation import EmbeddingValidation
from .mine import MineModel
from .resnet50_modified import resnet50_modified # noqa: F401
from .resnet50_modified import resnet101_modified # noqa: F401
# gets sphinx autodoc done right - don't remove it
@@ -12,6 +12,7 @@ class EmbeddingValidation(tf.keras.Model):
def compile(
self,
single_precision=False,
**kwargs,
):
"""
@@ -27,14 +28,20 @@ class EmbeddingValidation(tf.keras.Model):
"""
X, y = data
with tf.GradientTape() as tape:
logits, _ = self(X, training=True)
loss = self.loss(y, logits)
# trainable_vars = self.trainable_variables
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
self.compiled_metrics.update_state(y, logits, sample_weight=None)
self.train_loss(loss)
tf.summary.scalar("training_loss", data=loss, step=self._train_counter)
return {m.name: m.result() for m in self.metrics + [self.train_loss]}
# self.optimizer.apply_gradients(zip(gradients, trainable_vars))
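
For reference, a minimal sketch (not part of the commit, written as if it were the body of the same train_step) of what the `optimizer.minimize(..., tape=tape)` call above expands to; the commented-out `apply_gradients` line is the tail end of this manual variant:

with tf.GradientTape() as tape:
    logits, _ = self(X, training=True)
    loss = self.loss(y, logits)
# minimize() computes the gradients from the recorded tape and applies them
# with the optimizer in a single call; written out by hand it is roughly:
gradients = tape.gradient(loss, self.trainable_variables)
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))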
# -*- coding: utf-8 -*-
"""
The resnet50 from `tf.keras.applications.ResNet50` has a problem with its convolutional layers.
It adds bias terms to convolutions that are immediately followed by batch normalizations, which is redundant
(see https://github.com/tensorflow/tensorflow/issues/37365).
This ResNet-50 implementation provides a cleaner, bias-free version.
"""
import tensorflow as tf
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.regularizers import l2
global weight_decay
weight_decay = 1e-4
class IdentityBlock(tf.keras.layers.Layer):
def __init__(
self, kernel_size, filters, stage, block, weight_decay=1e-4, name=None, **kwargs
):
"""Block that has no convolutianal layer as skip connection
Parameters
----------
kernel_size:
The kernel size of the middle conv layer in the main path
filters:
list of integers, the filters of the 3 conv layers in the main path
stage:
Current stage label, used for generating layer names
block:
'a','b'..., current block label, used for generating layer names
"""
super().__init__(name=name, **kwargs)
filters1, filters2, filters3 = filters
bn_axis = 3
conv_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce"
bn_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce/bn"
layers = [
Conv2D(
filters1,
(1, 1),
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_1,
)
]
layers += [BatchNormalization(axis=bn_axis, name=bn_name_1)]
layers += [Activation("relu")]
conv_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3"
bn_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3/bn"
layers += [
Conv2D(
filters2,
kernel_size,
padding="same",
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_2,
)
]
layers += [BatchNormalization(axis=bn_axis, name=bn_name_2)]
layers += [Activation("relu")]
conv_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase"
bn_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase/bn"
layers += [
Conv2D(
filters3,
(1, 1),
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_3,
)
]
layers += [BatchNormalization(axis=bn_axis, name=bn_name_3)]
self.layers = layers
def call(self, input_tensor, training=None):
x = input_tensor
for lay in self.layers:
x = lay(x, training=training)
x = tf.keras.layers.add([x, input_tensor])
x = Activation("relu")(x)
return x
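
A quick shape check, as an illustrative sketch only (the sizes are assumptions matching stage 2 of the networks below): an IdentityBlock preserves the spatial size and channel count of its input, which is what allows the residual add with `input_tensor`.

block = IdentityBlock(3, [64, 64, 256], stage=2, block=2)
y = block(tf.zeros([1, 56, 56, 256]), training=False)
print(y.shape)  # (1, 56, 56, 256) -- same shape as the input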
class ConvBlock(tf.keras.layers.Layer):
def __init__(
self,
kernel_size,
filters,
stage,
block,
strides=(2, 2),
weight_decay=1e-4,
name=None,
**kwargs,
):
"""Block that has a conv layer AS shortcut.
Parameters
----------
kernel_size:
The kernel size of the middle conv layer in the main path
filters:
list of integers, the filters of the 3 conv layers in the main path
stage:
Current stage label, used for generating layer names
block:
'a','b'..., current block label, used for generating layer names
"""
super().__init__(name=name, **kwargs)
filters1, filters2, filters3 = filters
bn_axis = 3
conv_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce"
bn_name_1 = "conv" + str(stage) + "_" + str(block) + "_1x1_reduce/bn"
layers = [
Conv2D(
filters1,
(1, 1),
strides=strides,
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_1,
)
]
layers += [BatchNormalization(axis=bn_axis, name=bn_name_1)]
layers += [Activation("relu")]
conv_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3"
bn_name_2 = "conv" + str(stage) + "_" + str(block) + "_3x3/bn"
layers += [
Conv2D(
filters2,
kernel_size,
padding="same",
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_2,
)
]
layers += [BatchNormalization(axis=bn_axis, name=bn_name_2)]
layers += [Activation("relu")]
conv_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase"
bn_name_3 = "conv" + str(stage) + "_" + str(block) + "_1x1_increase/bn"
layers += [
Conv2D(
filters3,
(1, 1),
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_3,
)
]
layers += [BatchNormalization(axis=bn_axis, name=bn_name_3)]
conv_name_4 = "conv" + str(stage) + "_" + str(block) + "_1x1_proj"
bn_name_4 = "conv" + str(stage) + "_" + str(block) + "_1x1_proj/bn"
shortcut = [
Conv2D(
filters3,
(1, 1),
strides=strides,
kernel_initializer="orthogonal",
use_bias=False,
kernel_regularizer=l2(weight_decay),
name=conv_name_4,
)
]
shortcut += [BatchNormalization(axis=bn_axis, name=bn_name_4)]
self.layers = layers
self.shortcut = shortcut
def call(self, input_tensor, training=None):
x = input_tensor
for lay in self.layers:
x = lay(x, training=training)
x_s = input_tensor
for lay in self.shortcut:
x_s = lay(x_s, training=training)
x = tf.keras.layers.add([x, x_s])
x = Activation("relu")(x)
return x
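
The analogous check for ConvBlock (again just a sketch with assumed sizes): with the default strides=(2, 2) it halves the spatial resolution and projects the shortcut to the new channel count, so the output shape differs from the input.

block = ConvBlock(3, [128, 128, 512], stage=3, block=1)
y = block(tf.zeros([1, 56, 56, 256]), training=False)
print(y.shape)  # (1, 28, 28, 512) -- downsampled and widened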
def resnet50_modified(input_tensor=None, input_shape=None, **kwargs):
"""
The resnet50 from `tf.keras.applications.ResNet50` has a problem with its convolutional layers.
It adds bias terms to convolutions that are immediately followed by batch normalizations, which is redundant
(see https://github.com/tensorflow/tensorflow/issues/37365).
This ResNet-50 implementation provides a cleaner, bias-free version.
"""
if input_tensor is None:
input_tensor = tf.keras.Input(shape=input_shape)
else:
if not tf.keras.backend.is_keras_tensor(input_tensor):
input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape)
bn_axis = 3
# inputs are of size 224 x 224 x 3
layers = [input_tensor]
layers += [
Conv2D(
64,
(7, 7),
strides=(2, 2),
kernel_initializer="orthogonal",
use_bias=False,
trainable=True,
kernel_regularizer=l2(weight_decay),
padding="same",
name="conv1/7x7_s2",
)
]
# inputs are of size 112 x 112 x 64
layers += [BatchNormalization(axis=bn_axis, name="conv1/7x7_s2/bn")]
layers += [Activation("relu")]
layers += [MaxPooling2D((3, 3), strides=(2, 2))]
# inputs are of size 56 x 56
layers += [ConvBlock(3, [64, 64, 256], stage=2, block=1, strides=(1, 1))]
layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=2)]
layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=3)]
# inputs are of size 28 x 28
layers += [ConvBlock(3, [128, 128, 512], stage=3, block=1)]
layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=2)]
layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=3)]
layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=4)]
# inputs are of size 14 x 14
layers += [ConvBlock(3, [256, 256, 1024], stage=4, block=1)]
layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=2)]
layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=3)]
layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=4)]
layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=5)]
layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=6)]
# inputs are of size 7 x 7
layers += [ConvBlock(3, [512, 512, 2048], stage=5, block=1)]
layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=2)]
layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=3)]
return tf.keras.Sequential(layers)
def resnet101_modified(input_tensor=None, input_shape=None, **kwargs):
"""
The resnet101 from `tf.keras.applications.ResNet101` has a problem with its convolutional layers.
It adds bias terms to convolutions that are immediately followed by batch normalizations, which is redundant
(see https://github.com/tensorflow/tensorflow/issues/37365).
This ResNet-101 implementation provides a cleaner, bias-free version.
"""
if input_tensor is None:
input_tensor = tf.keras.Input(shape=input_shape)
else:
if not tf.keras.backend.is_keras_tensor(input_tensor):
input_tensor = tf.keras.Input(tensor=input_tensor, shape=input_shape)
bn_axis = 3
# inputs are of size 224 x 224 x 3
layers = [input_tensor]
layers += [
Conv2D(
64,
(7, 7),
strides=(2, 2),
kernel_initializer="orthogonal",
use_bias=False,
trainable=True,
kernel_regularizer=l2(weight_decay),
padding="same",
name="conv1/7x7_s2",
)
]
# inputs are of size 112 x 112 x 64
layers += [BatchNormalization(axis=bn_axis, name="conv1/7x7_s2/bn")]
layers += [Activation("relu")]
layers += [MaxPooling2D((3, 3), strides=(2, 2))]
# inputs are of size 56 x 56
layers += [ConvBlock(3, [64, 64, 256], stage=2, block=1, strides=(1, 1))]
layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=2)]
layers += [IdentityBlock(3, [64, 64, 256], stage=2, block=3)]
# inputs are of size 28 x 28
layers += [ConvBlock(3, [128, 128, 512], stage=3, block=1)]
layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=2)]
layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=3)]
layers += [IdentityBlock(3, [128, 128, 512], stage=3, block=4)]
# inputs are of size 14 x 14
# 23 blocks here (1 ConvBlock + 22 IdentityBlocks). That's the only
# difference from resnet-50.
layers += [ConvBlock(3, [256, 256, 1024], stage=4, block=1)]
for i in range(2, 24):
layers += [IdentityBlock(3, [256, 256, 1024], stage=4, block=i)]
# inputs are of size 7 x 7
layers += [ConvBlock(3, [512, 512, 2048], stage=5, block=1)]
layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=2)]
layers += [IdentityBlock(3, [512, 512, 2048], stage=5, block=3)]
return tf.keras.Sequential(layers)
if __name__ == "__main__":
    # Quick smoke test: build the model from a plain input shape. Passing a
    # tf.keras.layers.InputLayer here would fail the is_keras_tensor check above.
    model = resnet50_modified(input_shape=(112, 112, 3))
    print(len(model.variables))
    model.summary()
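
A hedged end-to-end sketch tying the two changed files together (the availability of add_bottleneck in scope is an assumption; its exact module path is not shown in this diff): the backbone built here can be capped with the bottleneck helper updated above to obtain a fixed-size embedding model.

backbone = resnet50_modified(input_shape=(112, 112, 3))
# add_bottleneck is the helper modified earlier in this merge request; adjust
# the import to wherever it lives in the package.
embedding_model = add_bottleneck(backbone, bottleneck_size=128, w_decay=5e-4)
embedding_model.summary()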