Commit d957c74a authored by Amir MOHAMMADI

make sure densenet layer names are consistent

parent fa765388
1 merge request: !79 Add keras-based models, add pixel-wise loss, other improvements
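
In the diff below, every sub-model forwards **kwargs (notably name) to tf.keras.Model.__init__, and every child layer gets an explicit, deterministic name. The practical payoff is that variable names, and therefore name-based weight restoration, no longer depend on layer creation order. A minimal standalone sketch of the pattern (plain Keras, not this repository's classes; names here are illustrative):

    import tensorflow as tf

    class Block(tf.keras.Model):
        # Mirrors the pattern in this commit: forward **kwargs (including
        # ``name``) to the parent and name every child layer explicitly.
        def __init__(self, filters, **kwargs):
            super().__init__(**kwargs)
            self.conv = tf.keras.layers.Conv2D(filters, 1, use_bias=False, name="conv1")
            self.norm = tf.keras.layers.BatchNormalization(name="norm1")

        def call(self, x, training=None):
            return self.norm(self.conv(x), training=training)

    block = Block(8, name="conv_block_1")
    block(tf.zeros((1, 8, 8, 3)))  # build the variables
    print([v.name for v in block.variables])
    # e.g. ['conv_block_1/conv1/kernel:0', 'conv_block_1/norm1/gamma:0', ...]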
@@ -5,10 +5,6 @@ Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
 """
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import tensorflow as tf
 from bob.extension import rc
@@ -27,9 +23,15 @@ class ConvBlock(tf.keras.Model):
     """

     def __init__(
-        self, num_filters, data_format, bottleneck, weight_decay=1e-4, dropout_rate=0
+        self,
+        num_filters,
+        data_format,
+        bottleneck,
+        weight_decay=1e-4,
+        dropout_rate=0,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.bottleneck = bottleneck
         axis = -1 if data_format == "channels_last" else 1
@@ -47,7 +49,7 @@ class ConvBlock(tf.keras.Model):
             kernel_regularizer=l2(weight_decay),
             name="conv1",
         )
-        self.norm2 = tf.keras.layers.BatchNormalization(axis=axis)
+        self.norm2 = tf.keras.layers.BatchNormalization(axis=axis, name="norm2")
         self.relu2 = tf.keras.layers.Activation("relu", name="relu2")
         # don't forget to set use_bias=False when using batchnorm
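
The use_bias=False reminder in the context above is worth spelling out: a bias added right before BatchNormalization is redundant, because BN subtracts the batch mean (cancelling any constant offset) and then adds its own learned beta. A hedged illustration of the pairing:

    import tensorflow as tf

    # The conv bias would be cancelled by the following BatchNormalization,
    # so it is disabled; BN's learned beta plays the role of the bias instead.
    conv = tf.keras.layers.Conv2D(64, 3, use_bias=False, name="conv2")
    norm = tf.keras.layers.BatchNormalization(name="norm2")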
@@ -103,16 +105,22 @@ class DenseBlock(tf.keras.Model):
         bottleneck,
         weight_decay=1e-4,
         dropout_rate=0,
+        **kwargs,
     ):
-        super(DenseBlock, self).__init__()
+        super().__init__(**kwargs)
         self.num_layers = num_layers
         self.axis = -1 if data_format == "channels_last" else 1

         self.blocks = []
-        for _ in range(int(self.num_layers)):
+        for i in range(int(self.num_layers)):
             self.blocks.append(
                 ConvBlock(
-                    growth_rate, data_format, bottleneck, weight_decay, dropout_rate
+                    growth_rate,
+                    data_format,
+                    bottleneck,
+                    weight_decay,
+                    dropout_rate,
+                    name=f"conv_block_{i+1}",
                 )
             )
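
Switching the loop variable from _ to i and passing name=f"conv_block_{i+1}" removes a subtle source of instability: anonymous Keras layers are auto-named from a global counter, so their names depend on how many layers were created earlier in the process. A standalone illustration (plain Keras, illustrative names):

    import tensorflow as tf

    # Anonymous layers get counter-based names that depend on creation order:
    first = tf.keras.layers.Conv2D(4, 1)
    second = tf.keras.layers.Conv2D(4, 1)
    print(first.name, second.name)  # e.g. "conv2d", "conv2d_1"

    # Explicit names are stable regardless of what was created before:
    named = [tf.keras.layers.Conv2D(4, 1, name=f"conv_block_{i+1}") for i in range(3)]
    print([layer.name for layer in named])
    # ['conv_block_1', 'conv_block_2', 'conv_block_3']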
@@ -134,8 +142,10 @@ class TransitionBlock(tf.keras.Model):
       dropout_rate: dropout rate.
     """

-    def __init__(self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0):
-        super(TransitionBlock, self).__init__()
+    def __init__(
+        self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0, **kwargs
+    ):
+        super().__init__(**kwargs)
         axis = -1 if data_format == "channels_last" else 1
         self.norm = tf.keras.layers.BatchNormalization(axis=axis, name="norm")
@@ -200,8 +210,10 @@ class DenseNet(tf.keras.Model):
         dropout_rate=0,
         pool_initial=False,
         include_top=True,
+        name="DenseNet",
+        **kwargs,
     ):
-        super(DenseNet, self).__init__()
+        super().__init__(name=name, **kwargs)
         self.depth_of_model = depth_of_model
         self.growth_rate = growth_rate
         self.num_of_blocks = num_of_blocks
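
Making name="DenseNet" an ordinary keyword argument (rather than hard-coding it inside super().__init__) keeps the default while letting callers override it. The same pattern in a self-contained toy model (hypothetical class, for illustration only):

    import tensorflow as tf

    class MyModel(tf.keras.Model):
        # Default model name that callers may override; remaining kwargs
        # are forwarded untouched to tf.keras.Model.
        def __init__(self, units, name="MyModel", **kwargs):
            super().__init__(name=name, **kwargs)
            self.dense = tf.keras.layers.Dense(units, name="dense")

        def call(self, x):
            return self.dense(x)

    print(MyModel(4).name)                   # "MyModel"
    print(MyModel(4, name="backbone").name)  # "backbone"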
@@ -302,6 +314,7 @@ class DenseNet(tf.keras.Model):
                     self.bottleneck,
                     self.weight_decay,
                     self.dropout_rate,
+                    name=f"dense_block_{i+1}",
                 )
             )
             if i + 1 < self.num_of_blocks:
@@ -311,6 +324,7 @@ class DenseNet(tf.keras.Model):
                         self.data_format,
                         self.weight_decay,
                         self.dropout_rate,
+                        name=f"transition_block_{i+1}",
                     )
                 )
@@ -408,15 +422,15 @@ class DeepPixBiS(tf.keras.Model):
             tf.keras.layers.Conv2D(
                 filters=1,
                 kernel_size=1,
-                name="dec",
                 kernel_initializer="he_normal",
                 kernel_regularizer=l2(weight_decay),
                 data_format=data_format,
+                name="dec",
             ),
             tf.keras.layers.Flatten(
                 data_format=data_format, name="Pixel_Logits_Flatten"
             ),
-            tf.keras.layers.Activation("sigmoid"),
+            tf.keras.layers.Activation("sigmoid", name="activation"),
         ]

     def call(self, x, training=None):
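
With the DeepPixBiS head layers explicitly named (dec, Pixel_Logits_Flatten, activation), downstream code can fetch them by name instead of by position. A hedged sketch, assuming a built DeepPixBiS instance is passed in:

    import tensorflow as tf

    def pixel_logits_layers(model: tf.keras.Model):
        # Look up the named layers from this commit; relies only on
        # tf.keras.Model.get_layer, which searches tracked sub-layers by name.
        return model.get_layer("dec"), model.get_layer("Pixel_Logits_Flatten")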