From bdbc99889af399ae94dbc0c3bba19203a89474a2 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Thu, 12 Oct 2017 14:20:12 +0200
Subject: [PATCH] Changed LightCNN to functions issue #37

---
 bob/learn/tensorflow/network/LightCNN29.py    |  1 -
 bob/learn/tensorflow/network/LightCNN9.py     | 35 +++++--------------
 bob/learn/tensorflow/network/__init__.py      |  4 +--
 .../tensorflow/test/test_cnn_other_losses.py  |  2 +-
 bob/learn/tensorflow/trainers/Trainer.py      |  4 +--
 5 files changed, 13 insertions(+), 33 deletions(-)

diff --git a/bob/learn/tensorflow/network/LightCNN29.py b/bob/learn/tensorflow/network/LightCNN29.py
index af030802..dc5bba68 100755
--- a/bob/learn/tensorflow/network/LightCNN29.py
+++ b/bob/learn/tensorflow/network/LightCNN29.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST 
 
 import tensorflow as tf
 from bob.learn.tensorflow.layers import maxout
diff --git a/bob/learn/tensorflow/network/LightCNN9.py b/bob/learn/tensorflow/network/LightCNN9.py
index 2eada46f..296e5e6e 100755
--- a/bob/learn/tensorflow/network/LightCNN9.py
+++ b/bob/learn/tensorflow/network/LightCNN9.py
@@ -1,30 +1,21 @@
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST 
 
 import tensorflow as tf
 from bob.learn.tensorflow.layers import maxout
 from .utils import append_logits
 
-class LightCNN9(object):
+def light_cnn9(inputs, seed=10, reuse=False):
     """Creates the graph for the Light CNN-9 in 
 
        Wu, Xiang, et al. "A light CNN for deep face representation with noisy labels." arXiv preprint arXiv:1511.02683 (2015).
     """
-    def __init__(self,
-                 seed=10,
-                 n_classes=10):
+    slim = tf.contrib.slim
 
-            self.seed = seed
-            self.n_classes = n_classes
+    with tf.variable_scope('LightCNN9', reuse=reuse):
 
-    def __call__(self, inputs, reuse=False, get_class_layer=True, end_point="logits"):
-        slim = tf.contrib.slim
-
-        #with tf.device(self.device):
-
-        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=self.seed)
+        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed)
         end_points = dict()
                     
         graph = slim.conv2d(inputs, 96, [5, 5], activation_fn=tf.nn.relu,
@@ -141,24 +132,14 @@ class LightCNN9(object):
         graph = slim.flatten(graph, scope='flatten1')
         end_points['flatten1'] = graph        
 
-        graph = slim.dropout(graph, keep_prob=0.3, scope='dropout1')
+        graph = slim.dropout(graph, keep_prob=0.5, scope='dropout1')
 
-        graph = slim.fully_connected(graph, 512,
+        prelogits = slim.fully_connected(graph, 512,
                                      weights_initializer=initializer,
                                      activation_fn=tf.nn.relu,
                                      scope='fc1',
                                      reuse=reuse)
-        end_points['fc1'] = graph                                     
-        #graph = maxout(graph,
-        #               num_units=256,
-        #               name='Maxoutfc1')
-        
-        graph = slim.dropout(graph, keep_prob=0.3, scope='dropout2')
-
-        if self.n_classes is not None:
-            # Appending the logits layer
-            graph = append_logits(graph, self.n_classes, reuse)
-            end_points['logits'] = graph
+        end_points['fc1'] = prelogits
 
-        return end_points[end_point]
+    return prelogits, end_points
 
diff --git a/bob/learn/tensorflow/network/__init__.py b/bob/learn/tensorflow/network/__init__.py
index 68ed993e..ec2b64e6 100755
--- a/bob/learn/tensorflow/network/__init__.py
+++ b/bob/learn/tensorflow/network/__init__.py
@@ -1,5 +1,5 @@
 from .Chopra import chopra
-from .LightCNN9 import LightCNN9
+from .LightCNN9 import light_cnn9
 from .LightCNN29 import LightCNN29
 from .Dummy import Dummy
 from .MLP import MLP
@@ -24,7 +24,7 @@ def __appropriate__(*args):
 
 __appropriate__(
     Chopra,
-    LightCNN9,
+    light_cnn9,
     Dummy,
     MLP,
     )
diff --git a/bob/learn/tensorflow/test/test_cnn_other_losses.py b/bob/learn/tensorflow/test/test_cnn_other_losses.py
index dfcbc34e..f40a6d90 100755
--- a/bob/learn/tensorflow/test/test_cnn_other_losses.py
+++ b/bob/learn/tensorflow/test/test_cnn_other_losses.py
@@ -120,7 +120,7 @@ def test_center_loss_tfrecord_embedding_validation():
                                         prelogits=prelogits
                                         )
     trainer.train()
-
+    
     assert True
     tf.reset_default_graph()
     del trainer
diff --git a/bob/learn/tensorflow/trainers/Trainer.py b/bob/learn/tensorflow/trainers/Trainer.py
index 25c660e8..733561ea 100755
--- a/bob/learn/tensorflow/trainers/Trainer.py
+++ b/bob/learn/tensorflow/trainers/Trainer.py
@@ -483,8 +483,8 @@ class Trainer(object):
 
         # Appending histograms for each trainable variables
         #for var in tf.trainable_variables():
-        for var in tf.global_variables():
-            tf.summary.histogram(var.op.name, var)
+        #for var in tf.global_variables():
+        #    tf.summary.histogram(var.op.name, var)
         
         # Train summary
         tf.summary.scalar('loss', average_loss)
-- 
GitLab