From d957c74aceb787c9198d0549a0992ccdbec76e31 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Fri, 7 Feb 2020 16:07:40 +0100
Subject: [PATCH] make sure densenet layer names are consistent

---
 bob/learn/tensorflow/models/densenet.py | 44 ++++++++++++++++---------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/bob/learn/tensorflow/models/densenet.py b/bob/learn/tensorflow/models/densenet.py
index 0cda6666..e1f70d07 100644
--- a/bob/learn/tensorflow/models/densenet.py
+++ b/bob/learn/tensorflow/models/densenet.py
@@ -5,10 +5,6 @@ Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
 
 """
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import tensorflow as tf
 from bob.extension import rc
 
@@ -27,9 +23,15 @@ class ConvBlock(tf.keras.Model):
     """
 
     def __init__(
-        self, num_filters, data_format, bottleneck, weight_decay=1e-4, dropout_rate=0
+        self,
+        num_filters,
+        data_format,
+        bottleneck,
+        weight_decay=1e-4,
+        dropout_rate=0,
+        **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.bottleneck = bottleneck
 
         axis = -1 if data_format == "channels_last" else 1
@@ -47,7 +49,7 @@ class ConvBlock(tf.keras.Model):
                 kernel_regularizer=l2(weight_decay),
                 name="conv1",
             )
-            self.norm2 = tf.keras.layers.BatchNormalization(axis=axis)
+            self.norm2 = tf.keras.layers.BatchNormalization(axis=axis, name="norm2")
 
         self.relu2 = tf.keras.layers.Activation("relu", name="relu2")
         # don't forget to set use_bias=False when using batchnorm
@@ -103,16 +105,22 @@ class DenseBlock(tf.keras.Model):
         bottleneck,
         weight_decay=1e-4,
         dropout_rate=0,
+        **kwargs,
     ):
-        super(DenseBlock, self).__init__()
+        super().__init__(**kwargs)
         self.num_layers = num_layers
         self.axis = -1 if data_format == "channels_last" else 1
 
         self.blocks = []
-        for _ in range(int(self.num_layers)):
+        for i in range(int(self.num_layers)):
             self.blocks.append(
                 ConvBlock(
-                    growth_rate, data_format, bottleneck, weight_decay, dropout_rate
+                    growth_rate,
+                    data_format,
+                    bottleneck,
+                    weight_decay,
+                    dropout_rate,
+                    name=f"conv_block_{i+1}",
                 )
             )
 
@@ -134,8 +142,10 @@ class TransitionBlock(tf.keras.Model):
         dropout_rate: dropout rate.
     """
 
-    def __init__(self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0):
-        super(TransitionBlock, self).__init__()
+    def __init__(
+        self, num_filters, data_format, weight_decay=1e-4, dropout_rate=0, **kwargs
+    ):
+        super().__init__(**kwargs)
         axis = -1 if data_format == "channels_last" else 1
 
         self.norm = tf.keras.layers.BatchNormalization(axis=axis, name="norm")
@@ -200,8 +210,10 @@ class DenseNet(tf.keras.Model):
         dropout_rate=0,
         pool_initial=False,
         include_top=True,
+        name="DenseNet",
+        **kwargs,
     ):
-        super(DenseNet, self).__init__()
+        super().__init__(name=name, **kwargs)
         self.depth_of_model = depth_of_model
         self.growth_rate = growth_rate
         self.num_of_blocks = num_of_blocks
@@ -302,6 +314,7 @@ class DenseNet(tf.keras.Model):
                     self.bottleneck,
                     self.weight_decay,
                     self.dropout_rate,
+                    name=f"dense_block_{i+1}",
                 )
             )
             if i + 1 < self.num_of_blocks:
@@ -311,6 +324,7 @@ class DenseNet(tf.keras.Model):
                         self.data_format,
                         self.weight_decay,
                         self.dropout_rate,
+                        name=f"transition_block_{i+1}",
                     )
                 )
 
@@ -408,15 +422,15 @@ class DeepPixBiS(tf.keras.Model):
             tf.keras.layers.Conv2D(
                 filters=1,
                 kernel_size=1,
-                name="dec",
                 kernel_initializer="he_normal",
                 kernel_regularizer=l2(weight_decay),
                 data_format=data_format,
+                name="dec",
             ),
             tf.keras.layers.Flatten(
                 data_format=data_format, name="Pixel_Logits_Flatten"
             ),
-            tf.keras.layers.Activation("sigmoid"),
+            tf.keras.layers.Activation("sigmoid", name="activation"),
         ]
 
     def call(self, x, training=None):
-- 
GitLab