Commit bdbc9988 authored by Tiago de Freitas Pereira

Changed LightCNN to functions issue #37

parent ad6a9bba
1 merge request: !19 Updates
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST
 import tensorflow as tf
 from bob.learn.tensorflow.layers import maxout
...
 #!/usr/bin/env python
 # vim: set fileencoding=utf-8 :
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
-# @date: Wed 11 May 2016 09:39:36 CEST
 import tensorflow as tf
 from bob.learn.tensorflow.layers import maxout
 from .utils import append_logits

-class LightCNN9(object):
+def light_cnn9(inputs, seed=10, reuse=False):
     """Creates the graph for the Light CNN-9 in
     Wu, Xiang, et al. "A light CNN for deep face representation with noisy labels." arXiv preprint arXiv:1511.02683 (2015).
     """
-    def __init__(self,
-                 seed=10,
-                 n_classes=10):
-        self.seed = seed
-        self.n_classes = n_classes
-
-    def __call__(self, inputs, reuse=False, get_class_layer=True, end_point="logits"):
-        slim = tf.contrib.slim
-
-        #with tf.device(self.device):
-        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=self.seed)
+    slim = tf.contrib.slim
+
+    with tf.variable_scope('LightCNN9', reuse=reuse):
+        initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed)

         end_points = dict()
         graph = slim.conv2d(inputs, 96, [5, 5], activation_fn=tf.nn.relu,
@@ -141,24 +132,14 @@ class LightCNN9(object):
         graph = slim.flatten(graph, scope='flatten1')
         end_points['flatten1'] = graph

-        graph = slim.dropout(graph, keep_prob=0.3, scope='dropout1')
+        graph = slim.dropout(graph, keep_prob=0.5, scope='dropout1')

-        graph = slim.fully_connected(graph, 512,
+        prelogits = slim.fully_connected(graph, 512,
                                      weights_initializer=initializer,
                                      activation_fn=tf.nn.relu,
                                      scope='fc1',
                                      reuse=reuse)
-        end_points['fc1'] = graph
-
-        #graph = maxout(graph,
-        #               num_units=256,
-        #               name='Maxoutfc1')
-
-        graph = slim.dropout(graph, keep_prob=0.3, scope='dropout2')
-
-        if self.n_classes is not None:
-            # Appending the logits layer
-            graph = append_logits(graph, self.n_classes, reuse)
-            end_points['logits'] = graph
-
-        return end_points[end_point]
+        end_points['fc1'] = prelogits
+
+        return prelogits, end_points
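
For context, a minimal sketch of how a caller might use the refactored function-style network, assuming the import paths suggested by this diff and the three-argument append_logits(prelogits, n_classes, reuse) call visible in the removed class code; the input shape and class count are illustrative only:

import tensorflow as tf
from bob.learn.tensorflow.network import light_cnn9            # assumed export path
from bob.learn.tensorflow.network.utils import append_logits   # assumed module path

# Illustrative input batch of 128x128 grayscale faces (the shape is an assumption).
inputs = tf.placeholder(tf.float32, shape=(None, 128, 128, 1))

# The function now returns the prelogits tensor plus a dict of end points,
# instead of a single end point selected by name.
prelogits, end_points = light_cnn9(inputs, seed=10, reuse=False)

# The logits layer is no longer appended inside the network; a caller that
# trains a classifier appends it explicitly (10 classes is illustrative).
n_classes = 10
logits = append_logits(prelogits, n_classes, False)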
 from .Chopra import chopra
-from .LightCNN9 import LightCNN9
+from .LightCNN9 import light_cnn9
 from .LightCNN29 import LightCNN29
 from .Dummy import Dummy
 from .MLP import MLP
@@ -24,7 +24,7 @@ def __appropriate__(*args):
 __appropriate__(
     Chopra,
-    LightCNN9,
+    light_cnn9,
     Dummy,
     MLP,
     )
...
@@ -120,7 +120,7 @@ def test_center_loss_tfrecord_embedding_validation():
         prelogits=prelogits
         )

     trainer.train()
     assert True
     tf.reset_default_graph()
     del trainer
...
@@ -483,8 +483,8 @@ class Trainer(object):

         # Appending histograms for each trainable variables
         #for var in tf.trainable_variables():
-        for var in tf.global_variables():
-            tf.summary.histogram(var.op.name, var)
+        #for var in tf.global_variables():
+        #    tf.summary.histogram(var.op.name, var)

         # Train summary
         tf.summary.scalar('loss', average_loss)
...
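
For reference, a minimal sketch (TensorFlow 1.x) of the summary wiring this hunk leaves in place; average_loss is the loss tensor named in the visible context, everything else is illustrative:

import tensorflow as tf

average_loss = tf.constant(0.0)  # stand-in for the trainer's loss tensor

# Kept by this commit: a scalar summary of the training loss.
tf.summary.scalar('loss', average_loss)

# Dropped by this commit: one histogram per global variable.
# for var in tf.global_variables():
#     tf.summary.histogram(var.op.name, var)

# Merge the remaining summaries into a single op for the summary writer.
merged_summaries = tf.summary.merge_all()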