Commit 8235f252 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Created layer with maxout

parent 7f70d520
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Fri 04 Aug 2017 14:14:22 CEST
## MAXOUT IMPLEMENTED FOR TENSORFLOW
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.layers import base
def maxout(inputs, num_units, axis=-1, name=None):
"""Adds a maxout op from https://arxiv.org/abs/1302.4389
"Maxout Networks"
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua
Bengio
Usually the operation is performed in the filter/channel dimension. This can also be
used after fully-connected layers to reduce number of features.
Args:
inputs: Tensor input
num_units: Specifies how many features will remain after maxout in the `axis` dimension (usually channel).
This must be multiple of number of `axis`.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
name: Optional scope for name_scope.
Returns:
A `Tensor` representing the results of the pooling operation.
Raises:
ValueError: if num_units is not multiple of number of features.
"""
return MaxOut(num_units=num_units, axis=axis, name=name)(inputs)
class MaxOut(base.Layer):
"""Adds a maxout op from https://arxiv.org/abs/1302.4389
"Maxout Networks"
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, Yoshua
Bengio
Usually the operation is performed in the filter/channel dimension. This can also be
used after fully-connected layers to reduce number of features.
Args:
inputs: Tensor input
num_units: Specifies how many features will remain after maxout in the `axis` dimension (usually channel).
This must be multiple of number of `axis`.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
name: Optional scope for name_scope.
Returns:
A `Tensor` representing the results of the pooling operation.
Raises:
ValueError: if num_units is not multiple of number of features.
"""
def __init__(self,
num_units,
axis=-1,
name=None,
**kwargs):
super(MaxOut, self).__init__(
name=name, trainable=False, **kwargs)
self.axis = axis
self.num_units = num_units
def call(self, inputs, training=False):
inputs = ops.convert_to_tensor(inputs)
shape = inputs.get_shape().as_list()
if self.axis is None:
# Assume that channel is the last dimension
self.axis = -1
num_channels = shape[self.axis]
if num_channels % self.num_units:
raise ValueError('number of features({}) is not '
'a multiple of num_units({})'
.format(num_channels, self.num_units))
shape[self.axis] = -1
shape += [num_channels // self.num_units]
# Dealing with batches with arbitrary sizes
for i in range(len(shape)):
if shape[i] is None:
shape[i] = gen_array_ops.shape(inputs)[i]
outputs = math_ops.reduce_max(gen_array_ops.reshape(inputs, shape), -1, keep_dims=False)
return outputs
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Thu 13 Oct 2016 13:35 CEST
import tensorflow as tf
import numpy as np
from bob.learn.tensorflow.layers import maxout
from nose.tools import assert_raises_regexp
slim = tf.contrib.slim
def test_simple():
x = np.zeros([64, 10, 36])
graph = maxout(x, num_units=3)
assert graph.get_shape().as_list() == [64, 10, 3]
def test_fully_connected():
x = np.zeros([64, 50])
graph = slim.fully_connected(x, 50, activation_fn=None)
graph = maxout(graph, num_units=10)
assert graph.get_shape().as_list() == [64, 10]
def test_nchw():
x = np.random.uniform(size=(10, 100, 100, 3)).astype(np.float32)
graph = slim.conv2d(x, 10, [3, 3])
graph = maxout(graph, num_units=1)
assert graph.get_shape().as_list() == [10, 100, 100, 1]
def test_invalid_shape():
x = np.random.uniform(size=(10, 100, 100, 3)).astype(np.float32)
graph = slim.conv2d(x, 3, [3, 3])
with assert_raises_regexp(ValueError, 'number of features'):
graph = maxout(graph, num_units=2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment