Commit d957c74a authored by Amir MOHAMMADI's avatar Amir MOHAMMADI

make sure densenet layer names are consistent

parent fa765388
......@@ -5,10 +5,6 @@ Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from bob.extension import rc
......@@ -27,9 +23,15 @@ class ConvBlock(tf.keras.Model):
"""
def __init__(
self, num_filters, data_format, bottleneck, weight_decay=1e-4, dropout_rate=0
self,
num_filters,
data_format,
bottleneck,
weight_decay=1e-4,
dropout_rate=0,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.bottleneck = bottleneck
axis = -1 if data_format == "channels_last" else 1
......@@ -47,7 +49,7 @@ class ConvBlock(tf.keras.Model):
kernel_regularizer=l2(weight_decay),
name="conv1",
)
self.norm2 = tf.keras.layers.BatchNormalization(axis=axis)
self.norm2 = tf.keras.layers.BatchNormalization(axis=axis, name="norm2")
self.relu2 = tf.keras.layers.Activation("relu", name="relu2")
# don't forget to set use_bias=False when using batchnorm
......@@ -103,16 +105,22 @@ class DenseBlock(tf.keras.Model):
bottleneck,
weight_decay=1e-4,
dropout_rate=0,
**kwargs,
):
super(DenseBlock, self).__init__()
super().__init__(**kwargs)
self.num_layers = num_layers
self.axis = -1 if data_format == "channels_last" else 1
self.blocks = []
for _ in range(int(self.num_layers)):
for i in range(int(self.num_layers)):
self.blocks.append(
ConvBlock(
growth_rate, data_format, bottleneck, weight_decay, dropout_rate
growth_rate,
data_format,
bottleneck,
weight_decay,
dropout_rate,
name=f"conv_block_{i+1}",
)
)
......@@ -134,8 +142,10 @@ class TransitionBlock(tf.keras.Model):
dropout_rate: dropout rate.
"""
def __init__(self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0):
super(TransitionBlock, self).__init__()
def __init__(
self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0, **kwargs
):
super().__init__(**kwargs)
axis = -1 if data_format == "channels_last" else 1
self.norm = tf.keras.layers.BatchNormalization(axis=axis, name="norm")
......@@ -200,8 +210,10 @@ class DenseNet(tf.keras.Model):
dropout_rate=0,
pool_initial=False,
include_top=True,
name="DenseNet",
**kwargs,
):
super(DenseNet, self).__init__()
super().__init__(name=name, **kwargs)
self.depth_of_model = depth_of_model
self.growth_rate = growth_rate
self.num_of_blocks = num_of_blocks
......@@ -302,6 +314,7 @@ class DenseNet(tf.keras.Model):
self.bottleneck,
self.weight_decay,
self.dropout_rate,
name=f"dense_block_{i+1}",
)
)
if i + 1 < self.num_of_blocks:
......@@ -311,6 +324,7 @@ class DenseNet(tf.keras.Model):
self.data_format,
self.weight_decay,
self.dropout_rate,
name=f"transition_block_{i+1}",
)
)
......@@ -408,15 +422,15 @@ class DeepPixBiS(tf.keras.Model):
tf.keras.layers.Conv2D(
filters=1,
kernel_size=1,
name="dec",
kernel_initializer="he_normal",
kernel_regularizer=l2(weight_decay),
data_format=data_format,
name="dec",
),
tf.keras.layers.Flatten(
data_format=data_format, name="Pixel_Logits_Flatten"
),
tf.keras.layers.Activation("sigmoid"),
tf.keras.layers.Activation("sigmoid", name="activation"),
]
def call(self, x, training=None):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment