Commit d957c74a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

make sure densenet layer names are consistent

parent fa765388
...@@ -5,10 +5,6 @@ Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993) ...@@ -5,10 +5,6 @@ Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
""" """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from bob.extension import rc from bob.extension import rc
...@@ -27,9 +23,15 @@ class ConvBlock(tf.keras.Model): ...@@ -27,9 +23,15 @@ class ConvBlock(tf.keras.Model):
""" """
def __init__( def __init__(
self, num_filters, data_format, bottleneck, weight_decay=1e-4, dropout_rate=0 self,
num_filters,
data_format,
bottleneck,
weight_decay=1e-4,
dropout_rate=0,
**kwargs,
): ):
super().__init__() super().__init__(**kwargs)
self.bottleneck = bottleneck self.bottleneck = bottleneck
axis = -1 if data_format == "channels_last" else 1 axis = -1 if data_format == "channels_last" else 1
...@@ -47,7 +49,7 @@ class ConvBlock(tf.keras.Model): ...@@ -47,7 +49,7 @@ class ConvBlock(tf.keras.Model):
kernel_regularizer=l2(weight_decay), kernel_regularizer=l2(weight_decay),
name="conv1", name="conv1",
) )
self.norm2 = tf.keras.layers.BatchNormalization(axis=axis) self.norm2 = tf.keras.layers.BatchNormalization(axis=axis, name="norm2")
self.relu2 = tf.keras.layers.Activation("relu", name="relu2") self.relu2 = tf.keras.layers.Activation("relu", name="relu2")
# don't forget to set use_bias=False when using batchnorm # don't forget to set use_bias=False when using batchnorm
...@@ -103,16 +105,22 @@ class DenseBlock(tf.keras.Model): ...@@ -103,16 +105,22 @@ class DenseBlock(tf.keras.Model):
bottleneck, bottleneck,
weight_decay=1e-4, weight_decay=1e-4,
dropout_rate=0, dropout_rate=0,
**kwargs,
): ):
super(DenseBlock, self).__init__() super().__init__(**kwargs)
self.num_layers = num_layers self.num_layers = num_layers
self.axis = -1 if data_format == "channels_last" else 1 self.axis = -1 if data_format == "channels_last" else 1
self.blocks = [] self.blocks = []
for _ in range(int(self.num_layers)): for i in range(int(self.num_layers)):
self.blocks.append( self.blocks.append(
ConvBlock( ConvBlock(
growth_rate, data_format, bottleneck, weight_decay, dropout_rate growth_rate,
data_format,
bottleneck,
weight_decay,
dropout_rate,
name=f"conv_block_{i+1}",
) )
) )
...@@ -134,8 +142,10 @@ class TransitionBlock(tf.keras.Model): ...@@ -134,8 +142,10 @@ class TransitionBlock(tf.keras.Model):
dropout_rate: dropout rate. dropout_rate: dropout rate.
""" """
def __init__(self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0): def __init__(
super(TransitionBlock, self).__init__() self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0, **kwargs
):
super().__init__(**kwargs)
axis = -1 if data_format == "channels_last" else 1 axis = -1 if data_format == "channels_last" else 1
self.norm = tf.keras.layers.BatchNormalization(axis=axis, name="norm") self.norm = tf.keras.layers.BatchNormalization(axis=axis, name="norm")
...@@ -200,8 +210,10 @@ class DenseNet(tf.keras.Model): ...@@ -200,8 +210,10 @@ class DenseNet(tf.keras.Model):
dropout_rate=0, dropout_rate=0,
pool_initial=False, pool_initial=False,
include_top=True, include_top=True,
name="DenseNet",
**kwargs,
): ):
super(DenseNet, self).__init__() super().__init__(name=name, **kwargs)
self.depth_of_model = depth_of_model self.depth_of_model = depth_of_model
self.growth_rate = growth_rate self.growth_rate = growth_rate
self.num_of_blocks = num_of_blocks self.num_of_blocks = num_of_blocks
...@@ -302,6 +314,7 @@ class DenseNet(tf.keras.Model): ...@@ -302,6 +314,7 @@ class DenseNet(tf.keras.Model):
self.bottleneck, self.bottleneck,
self.weight_decay, self.weight_decay,
self.dropout_rate, self.dropout_rate,
name=f"dense_block_{i+1}",
) )
) )
if i + 1 < self.num_of_blocks: if i + 1 < self.num_of_blocks:
...@@ -311,6 +324,7 @@ class DenseNet(tf.keras.Model): ...@@ -311,6 +324,7 @@ class DenseNet(tf.keras.Model):
self.data_format, self.data_format,
self.weight_decay, self.weight_decay,
self.dropout_rate, self.dropout_rate,
name=f"transition_block_{i+1}",
) )
) )
...@@ -408,15 +422,15 @@ class DeepPixBiS(tf.keras.Model): ...@@ -408,15 +422,15 @@ class DeepPixBiS(tf.keras.Model):
tf.keras.layers.Conv2D( tf.keras.layers.Conv2D(
filters=1, filters=1,
kernel_size=1, kernel_size=1,
name="dec",
kernel_initializer="he_normal", kernel_initializer="he_normal",
kernel_regularizer=l2(weight_decay), kernel_regularizer=l2(weight_decay),
data_format=data_format, data_format=data_format,
name="dec",
), ),
tf.keras.layers.Flatten( tf.keras.layers.Flatten(
data_format=data_format, name="Pixel_Logits_Flatten" data_format=data_format, name="Pixel_Logits_Flatten"
), ),
tf.keras.layers.Activation("sigmoid"), tf.keras.layers.Activation("sigmoid", name="activation"),
] ]
def call(self, x, training=None): def call(self, x, training=None):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment