Skip to content
Snippets Groups Projects
Commit cd3f1fee authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

make sure densenet layer names are consistent

parent 4523531b
Branches
No related tags found
No related merge requests found
......@@ -5,10 +5,6 @@ Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from bob.extension import rc
......@@ -27,9 +23,15 @@ class ConvBlock(tf.keras.Model):
"""
def __init__(
self, num_filters, data_format, bottleneck, weight_decay=1e-4, dropout_rate=0
self,
num_filters,
data_format,
bottleneck,
weight_decay=1e-4,
dropout_rate=0,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.bottleneck = bottleneck
axis = -1 if data_format == "channels_last" else 1
......@@ -47,7 +49,7 @@ class ConvBlock(tf.keras.Model):
kernel_regularizer=l2(weight_decay),
name="conv1",
)
self.norm2 = tf.keras.layers.BatchNormalization(axis=axis)
self.norm2 = tf.keras.layers.BatchNormalization(axis=axis, name="norm2")
self.relu2 = tf.keras.layers.Activation("relu", name="relu2")
# don't forget to set use_bias=False when using batchnorm
......@@ -103,16 +105,22 @@ class DenseBlock(tf.keras.Model):
bottleneck,
weight_decay=1e-4,
dropout_rate=0,
**kwargs,
):
super(DenseBlock, self).__init__()
super().__init__(**kwargs)
self.num_layers = num_layers
self.axis = -1 if data_format == "channels_last" else 1
self.blocks = []
for _ in range(int(self.num_layers)):
for i in range(int(self.num_layers)):
self.blocks.append(
ConvBlock(
growth_rate, data_format, bottleneck, weight_decay, dropout_rate
growth_rate,
data_format,
bottleneck,
weight_decay,
dropout_rate,
name=f"conv_block_{i+1}",
)
)
......@@ -134,8 +142,10 @@ class TransitionBlock(tf.keras.Model):
dropout_rate: dropout rate.
"""
def __init__(self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0):
super(TransitionBlock, self).__init__()
def __init__(
self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0, **kwargs
):
super().__init__(**kwargs)
axis = -1 if data_format == "channels_last" else 1
self.norm = tf.keras.layers.BatchNormalization(axis=axis, name="norm")
......@@ -200,8 +210,10 @@ class DenseNet(tf.keras.Model):
dropout_rate=0,
pool_initial=False,
include_top=True,
name="DenseNet",
**kwargs,
):
super(DenseNet, self).__init__()
super().__init__(name=name, **kwargs)
self.depth_of_model = depth_of_model
self.growth_rate = growth_rate
self.num_of_blocks = num_of_blocks
......@@ -302,6 +314,7 @@ class DenseNet(tf.keras.Model):
self.bottleneck,
self.weight_decay,
self.dropout_rate,
name=f"dense_block_{i+1}",
)
)
if i + 1 < self.num_of_blocks:
......@@ -311,6 +324,7 @@ class DenseNet(tf.keras.Model):
self.data_format,
self.weight_decay,
self.dropout_rate,
name=f"transition_block_{i+1}",
)
)
......@@ -408,15 +422,15 @@ class DeepPixBiS(tf.keras.Model):
tf.keras.layers.Conv2D(
filters=1,
kernel_size=1,
name="dec",
kernel_initializer="he_normal",
kernel_regularizer=l2(weight_decay),
data_format=data_format,
name="dec",
),
tf.keras.layers.Flatten(
data_format=data_format, name="Pixel_Logits_Flatten"
),
tf.keras.layers.Activation("sigmoid"),
tf.keras.layers.Activation("sigmoid", name="activation"),
]
def call(self, x, training=None):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment