Commit 8235f252 authored by Tiago de Freitas Pereira

Created layer with maxout

parent 7f70d520
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Fri 04 Aug 2017 14:14:22 CEST
## MAXOUT IMPLEMENTED FOR TENSORFLOW
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.layers import base
def maxout(inputs, num_units, axis=-1, name=None):
    """Adds a maxout op from https://arxiv.org/abs/1302.4389

    "Maxout Networks"
    Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua Bengio

    Usually the operation is performed in the filter/channel dimension. It can
    also be used after fully-connected layers to reduce the number of features.

    Args:
      inputs: A `Tensor` input.
      num_units: The number of features that remain after maxout along the
        `axis` dimension (usually the channel dimension). The number of
        features along `axis` must be a multiple of `num_units`.
      axis: The dimension where max pooling will be performed. Defaults to the
        last dimension.
      name: Optional scope for name_scope.

    Returns:
      A `Tensor` representing the results of the pooling operation.

    Raises:
      ValueError: if the number of features along `axis` is not a multiple of
        `num_units`.
    """
    return MaxOut(num_units=num_units, axis=axis, name=name)(inputs)
class MaxOut(base.Layer):
    """Adds a maxout op from https://arxiv.org/abs/1302.4389

    "Maxout Networks"
    Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua Bengio

    Usually the operation is performed in the filter/channel dimension. It can
    also be used after fully-connected layers to reduce the number of features.

    Args:
      num_units: The number of features that remain after maxout along the
        `axis` dimension (usually the channel dimension). The number of
        features along `axis` must be a multiple of `num_units`.
      axis: The dimension where max pooling will be performed. Defaults to the
        last dimension.
      name: Optional scope for name_scope.

    Returns:
      A `Tensor` representing the results of the pooling operation.

    Raises:
      ValueError: if the number of features along `axis` is not a multiple of
        `num_units`.
    """
    def __init__(self, num_units, axis=-1, name=None, **kwargs):
        super(MaxOut, self).__init__(name=name, trainable=False, **kwargs)
        self.axis = axis
        self.num_units = num_units
    def call(self, inputs, training=False):
        inputs = ops.convert_to_tensor(inputs)
        shape = inputs.get_shape().as_list()
        if self.axis is None:
            # Assume that channel is the last dimension
            self.axis = -1
        num_channels = shape[self.axis]
        if num_channels % self.num_units:
            raise ValueError('number of features({}) is not '
                             'a multiple of num_units({})'
                             .format(num_channels, self.num_units))
        # Split the `axis` dimension into (num_units, num_channels // num_units)
        # and take the maximum over the new trailing dimension.
        shape[self.axis] = -1
        shape += [num_channels // self.num_units]
        # Dealing with batches with arbitrary sizes
        for i in range(len(shape)):
            if shape[i] is None:
                shape[i] = gen_array_ops.shape(inputs)[i]
        outputs = math_ops.reduce_max(
            gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
        return outputs
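
As a usage illustration (not part of the commit), the docstring's note about applying maxout after a fully-connected layer could look like the sketch below. It assumes TensorFlow 1.x, `tf.layers.dense`, and the `bob.learn.tensorflow.layers.maxout` import that the tests below use:

import tensorflow as tf

from bob.learn.tensorflow.layers import maxout

# Hypothetical example: reduce 128 dense features to 32 maxout units,
# i.e. take the maximum over groups of 4 consecutive features.
x = tf.placeholder(tf.float32, shape=[64, 50])
h = tf.layers.dense(x, 128, activation=None)  # linear projection, no non-linearity
h = maxout(h, num_units=32)
print(h.get_shape().as_list())  # [64, 32]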
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import tensorflow as tf
import numpy as np
from bob.learn.tensorflow.layers import maxout
from nose.tools import assert_raises_regexp
slim = tf.contrib.slim
def test_simple():
    x = np.zeros([64, 10, 36])
    graph = maxout(x, num_units=3)
    assert graph.get_shape().as_list() == [64, 10, 3]


def test_fully_connected():
    x = np.zeros([64, 50])
    graph = slim.fully_connected(x, 50, activation_fn=None)
    graph = maxout(graph, num_units=10)
    assert graph.get_shape().as_list() == [64, 10]


def test_nchw():
    x = np.random.uniform(size=(10, 100, 100, 3)).astype(np.float32)
    graph = slim.conv2d(x, 10, [3, 3])
    graph = maxout(graph, num_units=1)
    assert graph.get_shape().as_list() == [10, 100, 100, 1]


def test_invalid_shape():
    x = np.random.uniform(size=(10, 100, 100, 3)).astype(np.float32)
    graph = slim.conv2d(x, 3, [3, 3])
    with assert_raises_regexp(ValueError, 'number of features'):
        graph = maxout(graph, num_units=2)
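
Beyond the shape checks above, a numerical comparison against a pure-NumPy reference could confirm the values themselves. The sketch below is not part of the commit, covers only the default last-axis case, and assumes a TensorFlow 1.x session plus the `tf`, `np`, and `maxout` imports already present in this test file:

def numpy_maxout(x, num_units):
    # Reference for axis=-1: group consecutive features and take the max per group.
    shape = list(x.shape[:-1]) + [num_units, x.shape[-1] // num_units]
    return np.reshape(x, shape).max(axis=-1)


def test_against_numpy_reference():
    x = np.random.uniform(size=(8, 12)).astype(np.float32)
    graph = maxout(x, num_units=4)
    with tf.Session() as session:
        assert np.allclose(session.run(graph), numpy_maxout(x, num_units=4))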