Commit 905ffc94 authored by Amir MOHAMMADI

small changes: nit, bug fix, small features

parent 2ff7ed83
1 merge request: !75 A lot of new features
Maxout.py (the maxout layer module):

@@ -3,22 +3,19 @@
 # @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
 # @date: Fri 04 Aug 2017 14:14:22 CEST

-## MAXOUT IMPLEMENTED FOR TENSORFLOW
+# MAXOUT IMPLEMENTED FOR TENSORFLOW

-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.layers import base
+import tensorflow as tf


 def maxout(inputs, num_units, axis=-1, name=None):
-    return MaxOut(num_units=num_units, axis=axis, name=name)(inputs)
+    return Maxout(num_units=num_units, axis=axis, name=name)(inputs)


-class MaxOut(base.Layer):
+class Maxout(base.Layer):
     """
     Adds a maxout op from

       "Maxout Networks"
@@ -41,33 +38,37 @@ class MaxOut(base.Layer):
     """

     def __init__(self, num_units, axis=-1, name=None, **kwargs):
-        super(MaxOut, self).__init__(name=name, trainable=False, **kwargs)
+        super(Maxout, self).__init__(name=name, trainable=False, **kwargs)
         self.axis = axis
         self.num_units = num_units

     def call(self, inputs, training=False):
-        inputs = ops.convert_to_tensor(inputs)
+        inputs = tf.convert_to_tensor(inputs)
         shape = inputs.get_shape().as_list()
-        if self.axis is None:
-            # Assume that channel is the last dimension
-            self.axis = -1
-        num_channels = shape[self.axis]
-        if num_channels % self.num_units:
-            raise ValueError('number of features({}) is not '
-                             'a multiple of num_units({})'.format(
-                                 num_channels, self.num_units))
-        shape[self.axis] = -1
-        shape += [num_channels // self.num_units]

         # Dealing with batches with arbitrary sizes
         for i in range(len(shape)):
             if shape[i] is None:
-                shape[i] = gen_array_ops.shape(inputs)[i]
-        outputs = math_ops.reduce_max(
-            gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
-        shape = outputs.get_shape().as_list()
-        shape[self.axis] = self.num_units
-        outputs.set_shape(shape)
+                shape[i] = tf.shape(inputs)[i]
+
+        num_channels = shape[self.axis]
+        if not isinstance(num_channels, tf.Tensor) and num_channels % self.num_units:
+            raise ValueError(
+                "number of features({}) is not "
+                "a multiple of num_units({})".format(num_channels, self.num_units)
+            )
+
+        if self.axis < 0:
+            axis = self.axis + len(shape)
+        else:
+            axis = self.axis
+        assert axis >= 0, "Find invalid axis: {}".format(self.axis)
+
+        expand_shape = shape[:]
+        expand_shape[axis] = self.num_units
+        k = num_channels // self.num_units
+        expand_shape.insert(axis, k)
+
+        outputs = tf.math.reduce_max(
+            tf.reshape(inputs, expand_shape), axis, keepdims=False
+        )
         return outputs
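The rewritten `call` skips the divisibility check when the channel count is dynamic and implements maxout as a reshape followed by a reduction. A minimal numpy sketch of that reshape-and-reduce trick (shapes and values are illustrative, not from the commit):

import numpy as np

# Maxout with num_units=3 over 6 channels: insert k = 6 // 3 = 2 groups at
# the channel axis and take the elementwise max over those k groups.
x = np.arange(12, dtype="float32").reshape(2, 6)  # batch of 2, 6 channels
num_units = 3
k = x.shape[1] // num_units                 # k = 2
y = x.reshape(2, k, num_units).max(axis=1)  # max over the k groups
print(y.shape)                              # (2, 3)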
The package `__init__.py` re-exporting the layer:

-from .Maxout import maxout
+from .Maxout import Maxout, maxout

 # gets sphinx autodoc done right - don't remove it
@@ -17,5 +17,8 @@ def __appropriate__(*args):
     obj.__module__ = __name__

-__appropriate__(maxout)
+__appropriate__(
+    Maxout,
+    maxout,
+)

 __all__ = [_ for _ in dir() if not _.startswith('_')]
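With both names exported, the layer is usable through either interface; a hypothetical usage sketch (the `bob.learn.tensorflow.layers` import path is assumed from the repository layout, not shown in this diff, and `inputs` is any float tensor):

from bob.learn.tensorflow.layers import Maxout, maxout  # assumed path

# class interface: build the layer once, then apply it
# layer = Maxout(num_units=128)
# outputs = layer(inputs)

# functional interface: a single call
# outputs = maxout(inputs, num_units=128)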
The contrastive loss module:

@@ -19,9 +19,9 @@ def contrastive_loss(left_embedding,
     http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

     :math:`L = 0.5 * (1-Y) * D^2 + 0.5 * (Y) * {max(0, margin - D)}^2`

     where `0` is assigned to pairs from the same class and `1` to pairs from different classes.

     **Parameters**
@@ -65,15 +65,10 @@ def contrastive_loss(left_embedding,
     with tf.name_scope("total_loss"):
         loss = 0.5 * (within_class + between_class)
-        loss = tf.reduce_mean(loss, name="total_loss_raw")
-        tf.summary.scalar('loss_raw', loss)
+        loss = tf.reduce_mean(loss, name="contrastive_loss")
         tf.add_to_collection(tf.GraphKeys.LOSSES, loss)

-        ## Appending the regularization loss
-        #regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
-        #loss = tf.add_n([loss] + regularization_losses, name="total_loss")
-
-        tf.summary.scalar('loss', loss)
+        tf.summary.scalar('contrastive_loss', loss)
         tf.summary.scalar('between_class', between_class_loss)
         tf.summary.scalar('within_class', within_class_loss)
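As a numeric sanity check on the docstring formula, a small sketch with made-up distances (margin and values are illustrative):

import numpy as np

# L = 0.5*(1-Y)*D^2 + 0.5*Y*max(0, margin - D)^2
D = np.array([0.2, 1.5])   # distances for a genuine and an impostor pair
Y = np.array([0.0, 1.0])   # 0 = same class, 1 = different classes
margin = 1.0
L = 0.5 * (1 - Y) * D ** 2 + 0.5 * Y * np.maximum(0.0, margin - D) ** 2
print(L)  # [0.02 0.  ] -- the impostor pair is already past the margin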
The `append_logits` utility:

@@ -10,8 +10,9 @@ def append_logits(graph,
                   n_classes,
                   reuse=False,
                   l2_regularizer=5e-05,
-                  weights_std=0.1, trainable_variables=None):
-    trainable = is_trainable('Logits', trainable_variables)
+                  weights_std=0.1, trainable_variables=None,
+                  name='Logits'):
+    trainable = is_trainable(name, trainable_variables)
     return slim.fully_connected(
         graph,
         n_classes,
@@ -19,7 +20,7 @@ def append_logits(graph,
         weights_initializer=tf.truncated_normal_initializer(
             stddev=weights_std),
         weights_regularizer=slim.l2_regularizer(l2_regularizer),
-        scope='Logits',
+        scope=name,
         reuse=reuse,
         trainable=trainable,
     )
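The new `name` argument lets callers attach more than one logits head to the same graph under distinct scopes; a hypothetical sketch (the import path is assumed from the repository layout, and the embedding tensor and class count are made up):

import tensorflow as tf
# assumed import path, not shown in this diff:
# from bob.learn.tensorflow.network.utils import append_logits

embeddings = tf.placeholder(tf.float32, [None, 128])  # illustrative input
# logits_a = append_logits(embeddings, n_classes=10)                  # scope 'Logits'
# logits_b = append_logits(embeddings, n_classes=10, name='Logits2')  # separate scope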
The `compute_statistics` script:

@@ -15,7 +15,15 @@ logger = logging.getLogger(__name__)
 @click.command(
-    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand)
+    entry_point_group='bob.learn.tensorflow.config', cls=ConfigCommand,
+    epilog="""\b
+An example configuration could be::

+    # define the database:
+    from bob.bio.base.test.dummy.database import database
+    groups = ['dev']
+    biofiles = database.all_files(groups)
+"""
+)
 @click.option(
     '--database',
     '-d',
@@ -50,14 +58,6 @@ def compute_statistics(database, biofiles, load_data, multiple_samples,
     This script works with bob.bio.base databases. It will load all the samples
     and print their mean.
-
-    An example configuration could be::
-
-        # define the database:
-        from bob.bio.base.test.dummy.database import database
-        groups = ['dev']
-        biofiles = database.all_files(groups)
-
     """
     log_parameters(logger, ignore=('biofiles', ))
     logger.debug("len(biofiles): %d", len(biofiles))
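Moving the example into the `click` epilog means it now appears in the command's `--help` output rather than only in the docstring. Such a configuration file would contain exactly the lines from the epilog, e.g. saved as `config.py` (file name illustrative) and passed to the command as a config argument:

# config.py
# define the database:
from bob.bio.base.test.dummy.database import database

groups = ['dev']
biofiles = database.all_files(groups)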
The `eval` script:

@@ -264,7 +264,8 @@ def eval(estimator, eval_input_fn, hooks, run_once, eval_interval_secs, name,
             continue

         # evaluate based on the just copied checkpoint_path
-        checkpoint_path = checkpoint_path.replace(estimator.model_dir, eval_dir)
+        checkpoint_path = checkpoint_path.replace(estimator.model_dir, eval_dir + os.sep)
+        checkpoint_path = os.path.abspath(checkpoint_path)
         logger.debug("Evaluating the model from %s", checkpoint_path)

         # Evaluate
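The extra `os.sep` plus `os.path.abspath` guards against a missing or doubled path separator after the substitution; a minimal standalone sketch with illustrative paths:

import os

model_dir = "/tmp/model"            # stands in for estimator.model_dir
eval_dir = "/tmp/model/eval"
ckpt = "/tmp/model/model.ckpt-100"

fixed = ckpt.replace(model_dir, eval_dir + os.sep)  # '/tmp/model/eval//model.ckpt-100'
print(os.path.abspath(fixed))                       # '/tmp/model/eval/model.ckpt-100'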