Commit a12deec7 authored by Pavel KORSHUNOV

added lstm, simple cnn, and simple mlp networks

parent 3c47b589
__init__.py

@@ -3,6 +3,10 @@ from .LightCNN9 import LightCNN9
 from .Dummy import Dummy
 from .MLP import MLP
 from .Embedding import Embedding
+from .lstm import simple_lstm_network
+from .lstm import RegularizedLoss
+from .simplemlp import mlp_network
+from .simplecnn import simple2Dcnn_network
 
 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
@@ -23,6 +27,7 @@ __appropriate__(
     LightCNN9,
     Dummy,
     MLP,
+    RegularizedLoss,
 )
 
 __all__ = [_ for _ in dir() if not _.startswith('_')]
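For context, here is a self-contained sketch of the __appropriate__ idiom this file relies on. The diff truncates the function body, so the implementation below is the usual Bob-package convention and an assumption: it re-tags each exported object's __module__ so that Sphinx autodoc documents it under the package rather than under the private submodule it was defined in.

# minimal, runnable sketch of the __appropriate__ idiom (body is assumed)
def __appropriate__(*args):
    """Mark objects as declared here, so sphinx autodoc lists them under this package."""
    for obj in args:
        obj.__module__ = __name__

class SomeNetwork(object):  # stand-in for a class imported from a submodule
    pass

__appropriate__(SomeNetwork)
print(SomeNetwork.__module__)  # prints '__main__' when run directly as a script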
lstm.py (new file)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Fri 15 Sep 2017 13:22 CEST
"""
LSTM network architecture.
"""
from bob.learn.tensorflow.layers import lstm
import tensorflow as tf
import bob.core
logger = bob.core.log.setup("bob.project.savi")
slim = tf.contrib.slim
class RegularizedLoss(object):
    """
    Mean softmax loss with L2 regularization
    """

    def __init__(self, name="reg_loss", regularizing_coeff=0.1):
        self.name = name
        self.regularizing_coeff = regularizing_coeff
        # collect the L2 norms of all weight variables that exist at this
        # point: LSTM cell kernels and the 'weights' of slim layers
        # (biases are deliberately not regularized)
        tv = tf.trainable_variables()
        self.regularization_cost = tf.reduce_sum(
            [tf.nn.l2_loss(v) for v in tv
             if 'basic_lstm_cell/kernel' in v.name or 'weights' in v.name])

    def __call__(self, graph, label):
        # mean softmax cross-entropy over the batch...
        loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=graph, labels=label), name=self.name)
        # ...plus the weighted L2 penalty collected at construction time
        return loss + self.regularizing_coeff * self.regularization_cost
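A minimal usage sketch (shapes and the 'toy_fc' scope are hypothetical). Note that tf.trainable_variables() is read in the constructor, so RegularizedLoss must be instantiated only after the network graph has been built:

import tensorflow as tf
slim = tf.contrib.slim

# build a toy graph first, so that its variables exist...
inputs = tf.placeholder(tf.float32, shape=(None, 20))
labels = tf.placeholder(tf.int64, shape=(None,))
logits = slim.fully_connected(inputs, 2, activation_fn=None, scope='toy_fc')

# ...then create the loss: 'toy_fc/weights' matches the 'weights' filter above
loss_fn = RegularizedLoss(regularizing_coeff=0.1)
total_loss = loss_fn(logits, labels)
train_op = tf.train.AdamOptimizer(1e-3).minimize(total_loss)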
def simple_lstm_network(train_data_shuffler, lstm_cell_size=64, batch_size=10,
                        num_time_steps=28, num_classes=10, seed=10, reuse=False,
                        dropout=False, input_dropout=1.0, output_dropout=1.0):
    """
    :param train_data_shuffler: The input is expected to have shape
        (batch_size, num_time_steps, input_vector_size); it can be given either
        as a tf.Tensor or as a data shuffler object.
    """
    if isinstance(train_data_shuffler, tf.Tensor):
        inputs = train_data_shuffler
    else:
        inputs = train_data_shuffler("data", from_queue=False)

    initializer = tf.contrib.layers.xavier_initializer(seed=seed)

    # create the LSTM part of the network
    graph = lstm(inputs, lstm_cell_size, num_time_steps=num_time_steps, batch_size=batch_size,
                 output_activation_size=num_classes, scope='lstm', name='sync_cell',
                 weights_initializer=initializer, activation=tf.nn.sigmoid, reuse=reuse,
                 dropout=dropout, input_dropout=input_dropout, output_dropout=output_dropout)

    # fully connect the LSTM output to the class logits; no per-layer weight
    # regularizer here, since L2 decay can instead come from RegularizedLoss above
    regularizer = None
    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
                                 weights_initializer=initializer,
                                 weights_regularizer=regularizer, reuse=reuse)
    return graph
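A usage sketch with hypothetical dimensions; batch_size and num_time_steps must match the first two dimensions of the input tensor, since they are passed to the lstm layer explicitly:

import tensorflow as tf

# 10 sequences of 28 time steps with 20 features each (hypothetical sizes)
inputs = tf.placeholder(tf.float32, shape=(10, 28, 20))
logits = simple_lstm_network(inputs, lstm_cell_size=64, batch_size=10,
                             num_time_steps=28, num_classes=2)
# share the same weights for a second (e.g. validation) copy of the graph
val_logits = simple_lstm_network(inputs, lstm_cell_size=64, batch_size=10,
                                 num_time_steps=28, num_classes=2, reuse=True)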
simplecnn.py (new file)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Fri 15 Sep 2017 13:22 CEST
"""
Simple 2-layered 2D CNN network architecture.
"""
import tensorflow as tf
import bob.core
logger = bob.core.log.setup("bob.project.savi")
slim = tf.contrib.slim
def simple2Dcnn_network(train_data_shuffler, num_classes=10, seed=10, reuse=False):
    """
    :param train_data_shuffler: The input is expected to be a 4-D tensor,
        e.g. (batch_size, num_time_steps, input_vector_size, 1), as required by
        2D convolutions; it can be given either as a tf.Tensor or as a data
        shuffler object.
    """
    if isinstance(train_data_shuffler, tf.Tensor):
        inputs = train_data_shuffler
    else:
        inputs = train_data_shuffler("data", from_queue=False)

    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed)
    regularizer = None

    # two convolution + max-pooling blocks; the 1x2 pooling window pools only
    # along the second spatial dimension, while stride 2 halves both dimensions
    graph = slim.conv2d(inputs, 32, [5, 5], activation_fn=tf.nn.relu,
                        stride=1,
                        weights_initializer=initializer,
                        weights_regularizer=regularizer,
                        scope='Conv1',
                        reuse=reuse)
    graph = slim.max_pool2d(graph, [1, 2], stride=2, padding="SAME", scope='Pool1')

    graph = slim.conv2d(graph, 32, [3, 3], activation_fn=tf.nn.relu,
                        stride=1,
                        weights_initializer=initializer,
                        weights_regularizer=regularizer,
                        scope='Conv2',
                        reuse=reuse)
    graph = slim.max_pool2d(graph, [1, 2], stride=2, padding="SAME", scope='Pool2')

    # flatten and classify with two fully-connected layers
    graph = slim.flatten(graph, scope='flatten1')
    graph = slim.fully_connected(graph, 80,
                                 weights_initializer=initializer,
                                 activation_fn=tf.nn.relu,
                                 scope='fc0',
                                 reuse=reuse)
    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
                                 weights_initializer=initializer,
                                 weights_regularizer=regularizer, reuse=reuse)
    return graph
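A usage sketch, assuming single-channel inputs of hypothetical size 28 time steps by 20 features:

import tensorflow as tf

# NHWC input: batches of 28x20 single-channel maps (hypothetical sizes)
inputs = tf.placeholder(tf.float32, shape=(None, 28, 20, 1))
logits = simple2Dcnn_network(inputs, num_classes=2)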
simplemlp.py (new file)

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Pavel Korshunov <pavel.korshunov@idiap.ch>
# @date: Fri 15 Sep 2017 13:22 CEST
"""
Simple MLP network architecture.
"""
import tensorflow as tf
import bob.core
logger = bob.core.log.setup("bob.project.savi")
slim = tf.contrib.slim
def get_first_frame_from_timeseries(time_series_inputs, num_time_steps):
    """
    Args:
        time_series_inputs: Expected to have shape (batch_size, num_time_steps, input_vector_size)
    Returns: Only the first frame, of shape (batch_size, input_vector_size)
    """
    # shape the inputs correctly
    inputs = tf.convert_to_tensor(time_series_inputs)
    shape = inputs.get_shape().as_list()
    logger.info("MLP: the shape of the inputs: {0}".format(shape))

    input_time_steps = shape[1]  # the second dimension must be the number of time steps
    if len(shape) == 4:  # a 4-D input is accepted only if its last dimension is 1
        if shape[-1] == 1:  # squeeze away the trailing singleton dimension
            inputs = tf.reshape(inputs, shape=(-1, shape[1], shape[2]))
            logger.info("MLP: after reshape, the shape of the inputs: {0}".format(inputs.get_shape().as_list()))
        else:
            raise ValueError('The shape of the input must be either (batch_size, num_time_steps, input_vector_size) or '
                             '(batch_size, num_time_steps, input_vector_size, 1), but it is {}'.format(shape))

    if input_time_steps != num_time_steps:
        raise ValueError('the number of rows in one batch of the input ({}) should be '
                         'the same as num_time_steps ({})'
                         .format(input_time_steps, num_time_steps))

    # split the input into a num_time_steps-long list of tensors of shape
    # (batch_size, input_vector_size) and keep only the first frame
    list_inputs = tf.unstack(inputs, num_time_steps, 1)
    return list_inputs[0]
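A quick sanity check of the helper on a toy tensor (the values are arbitrary, chosen only to make the output readable):

import tensorflow as tf

# 2 sequences of 4 time steps with 3 features each
x = tf.reshape(tf.range(24, dtype=tf.float32), (2, 4, 3))
first = get_first_frame_from_timeseries(x, num_time_steps=4)
with tf.Session() as sess:
    print(sess.run(first))  # shape (2, 3): time step 0 of each sequence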
def mlp_network(train_data_shuffler, hidden_layer_size=64, num_time_steps=28, num_classes=10, seed=10, reuse=False):
    """
    :param train_data_shuffler: The input is expected to have shape (batch_size, num_time_steps, input_vector_size),
        but only the first frame of shape (batch_size, input_vector_size) is used as the input to the MLP.
    """
    if isinstance(train_data_shuffler, tf.Tensor):
        inputs = train_data_shuffler
    else:
        inputs = train_data_shuffler("data", from_queue=False)

    # keep only the first frame of size (batch_size, input_vector_size)
    inputs = get_first_frame_from_timeseries(inputs, num_time_steps)

    initializer = tf.contrib.layers.xavier_initializer(uniform=False, dtype=tf.float32, seed=seed)
    regularizer = None

    # the MLP is two fully-connected layers mapping to the class logits
    graph = slim.fully_connected(inputs, hidden_layer_size, activation_fn=tf.nn.relu, scope='fc0',
                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
    graph = slim.fully_connected(graph, num_classes, activation_fn=None, scope='fc1',
                                 weights_initializer=initializer, weights_regularizer=regularizer, reuse=reuse)
    return graph
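A usage sketch with hypothetical dimensions; only the first of the 28 frames reaches the MLP:

import tensorflow as tf

# batches of 28-step sequences with 20 features per step (hypothetical sizes)
inputs = tf.placeholder(tf.float32, shape=(None, 28, 20))
logits = mlp_network(inputs, hidden_layer_size=64, num_time_steps=28, num_classes=2)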