Skip to content
Snippets Groups Projects
Commit 17cd04f8 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Still scratching

parent 60788ef5
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 17:38 CEST
import tensoflow as tf
from bob.learn.tensorflow.util import *
from .Layer import Layer
class Conv2D(Layer):
"""
2D Convolution
"""
def __init__(self, input, activation=None,
kernel_size=3,
filters=8,
initialization='xavier',
use_gpu=False,
seed=10
):
"""
Base constructor
**Parameters**
input: Layer input
activation: Tensor Flow activation
kernel_size: Size of the convolutional kernel
filters: Number of filters
initialization: Initialization type
use_gpu: Store data in the GPU
seed: Seed for the Random number generation
"""
super(Conv2D, self).__init__(input, initialization='xavier', use_gpu=False, seed=10)
self.activation = activation
self.W = create_weight_variables([kernel_size, kernel_size, 1, filters],
seed=seed, name="conv", use_gpu=use_gpu)
if activation is not None:
self.b = create_bias_variables([filters], name="bias", use_gpu=self.use_gpu)
def get_graph(self):
with tf.name_scope('conv'):
conv = tf.nn.conv2d(self.input, self.W, strides=[1, 1, 1, 1], padding='SAME')
with tf.name_scope('activation'):
non_linearity = tf.nn.tanh(tf.nn.bias_add(conv, self.b))
return non_linearity
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 17:38 CEST
import tensoflow as tf
class Layer(object):
"""
Layer base class
"""
def __init__(self, input, initialization='xavier', use_gpu=False, seed=10):
"""
Base constructor
**Parameters**
input: Layer input
"""
self.input = input
self.initialization = initialization
self.use_gpu = use_gpu
self.seed = seed
def get_graph(self):
NotImplementedError("Please implement this function in derived classes")
# see https://docs.python.org/3/library/pkgutil.html
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
from DataShuffler import *
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 17:38 CEST
import tensoflow as tf
from bob.learn.tensorflow.util import *
from .Layer import Layer
class MaxPooling(Layer):
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment