Skip to content
Snippets Groups Projects

Many changes

Merged Amir MOHAMMADI requested to merge amir into master
1 file
+ 83
35
Compare changes
  • Side-by-side
  • Inline
import tensorflow as tf
import tensorflow as tf
def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
def create_conv_layer(inputs, mode, data_format, endpoints, number, filters,
 
kernel_size, pool_size, pool_strides,
 
add_batch_norm=False):
 
bn_axis = 1 if data_format.lower() == 'channels_first' else 3
 
training = mode == tf.estimator.ModeKeys.TRAIN
 
 
if add_batch_norm:
 
activation = None
 
else:
 
activation = tf.nn.relu
 
 
name = 'conv{}'.format(number)
 
conv = tf.layers.conv2d(
 
inputs=inputs,
 
filters=filters,
 
kernel_size=kernel_size,
 
padding="same",
 
activation=activation,
 
data_format=data_format)
 
endpoints[name] = conv
 
 
if add_batch_norm:
 
name = 'bn{}'.format(number)
 
bn = tf.layers.batch_normalization(
 
conv, axis=bn_axis, training=training)
 
endpoints[name] = bn
 
 
name = 'activation{}'.format(number)
 
bn_act = tf.nn.relu(bn)
 
endpoints[name] = bn_act
 
else:
 
bn_act = conv
 
 
name = 'pool{}'.format(number)
 
pool = tf.layers.max_pooling2d(
 
inputs=bn_act, pool_size=pool_size, strides=pool_strides,
 
padding='same', data_format=data_format)
 
endpoints[name] = pool
 
 
return pool
 
 
 
def base_architecture(input_layer, mode, kernerl_size, data_format,
 
add_batch_norm=False, **kwargs):
 
training = mode == tf.estimator.ModeKeys.TRAIN
# Keep track of all the endpoints
# Keep track of all the endpoints
endpoints = {}
endpoints = {}
@@ -9,38 +53,18 @@ def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
@@ -9,38 +53,18 @@ def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
# Computes 32 features using a kernerl_size filter with ReLU
# Computes 32 features using a kernerl_size filter with ReLU
# activation.
# activation.
# Padding is added to preserve width and height.
# Padding is added to preserve width and height.
conv1 = tf.layers.conv2d(
pool1 = create_conv_layer(
inputs=input_layer,
inputs=input_layer, mode=mode, data_format=data_format,
filters=32,
endpoints=endpoints, number=1, filters=32, kernel_size=kernerl_size,
kernel_size=kernerl_size,
pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
padding="same",
activation=tf.nn.relu,
data_format=data_format)
endpoints['conv1'] = conv1
# Pooling Layer #1
# First max pooling layer with a 2x2 filter and stride of 2
pool1 = tf.layers.max_pooling2d(
inputs=conv1, pool_size=[2, 2], strides=2, data_format=data_format)
endpoints['pool1'] = pool1
# Convolutional Layer #2
# Convolutional Layer #2
# Computes 64 features using a kernerl_size filter.
# Computes 64 features using a kernerl_size filter.
# Padding is added to preserve width and height.
# Padding is added to preserve width and height.
conv2 = tf.layers.conv2d(
pool2 = create_conv_layer(
inputs=pool1,
inputs=pool1, mode=mode, data_format=data_format,
filters=64,
endpoints=endpoints, number=2, filters=64, kernel_size=kernerl_size,
kernel_size=kernerl_size,
pool_size=(2, 2), pool_strides=2, add_batch_norm=add_batch_norm)
padding="same",
activation=tf.nn.relu,
data_format=data_format)
endpoints['conv2'] = conv2
# Pooling Layer #2
# Second max pooling layer with a 2x2 filter and stride of 2
pool2 = tf.layers.max_pooling2d(
inputs=conv2, pool_size=[2, 2], strides=2, data_format=data_format)
endpoints['pool2'] = pool2
# Flatten tensor into a batch of vectors
# Flatten tensor into a batch of vectors
# TODO: use tf.layers.flatten in tensorflow 1.4 and above
# TODO: use tf.layers.flatten in tensorflow 1.4 and above
@@ -49,13 +73,30 @@ def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
@@ -49,13 +73,30 @@ def base_architecture(input_layer, mode, kernerl_size, data_format, **kwargs):
# Dense Layer
# Dense Layer
# Densely connected layer with 1024 neurons
# Densely connected layer with 1024 neurons
 
if add_batch_norm:
 
activation = None
 
else:
 
activation = tf.nn.relu
 
dense = tf.layers.dense(
dense = tf.layers.dense(
inputs=pool2_flat, units=1024, activation=tf.nn.relu)
inputs=pool2_flat, units=1024, activation=activation)
endpoints['dense'] = dense
endpoints['dense'] = dense
 
if add_batch_norm:
 
name = 'bn{}'.format(3)
 
bn = tf.layers.batch_normalization(
 
dense, axis=1, training=training)
 
endpoints[name] = bn
 
 
name = 'activation{}'.format(3)
 
bn_act = tf.nn.relu(bn)
 
endpoints[name] = bn_act
 
else:
 
bn_act = dense
 
# Add dropout operation; 0.6 probability that element will be kept
# Add dropout operation; 0.6 probability that element will be kept
dropout = tf.layers.dropout(
dropout = tf.layers.dropout(
inputs=dense, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
inputs=bn_act, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)
endpoints['dropout'] = dropout
endpoints['dropout'] = dropout
return dropout, endpoints
return dropout, endpoints
@@ -67,12 +108,14 @@ def architecture(input_layer,
@@ -67,12 +108,14 @@ def architecture(input_layer,
n_classes=2,
n_classes=2,
data_format='channels_last',
data_format='channels_last',
reuse=False,
reuse=False,
 
add_batch_norm=False,
**kwargs):
**kwargs):
with tf.variable_scope('SimpleCNN', reuse=reuse):
with tf.variable_scope('SimpleCNN', reuse=reuse):
dropout, endpoints = base_architecture(input_layer, mode, kernerl_size,
dropout, endpoints = base_architecture(
data_format)
input_layer, mode, kernerl_size, data_format,
 
add_batch_norm=add_batch_norm)
# Logits layer
# Logits layer
# Input Tensor Shape: [batch_size, 1024]
# Input Tensor Shape: [batch_size, 1024]
# Output Tensor Shape: [batch_size, n_classes]
# Output Tensor Shape: [batch_size, n_classes]
@@ -94,6 +137,7 @@ def model_fn(features, labels, mode, params=None, config=None):
@@ -94,6 +137,7 @@ def model_fn(features, labels, mode, params=None, config=None):
'kernerl_size': params.get('kernerl_size', None),
'kernerl_size': params.get('kernerl_size', None),
'n_classes': params.get('n_classes', None),
'n_classes': params.get('n_classes', None),
'data_format': params.get('data_format', None),
'data_format': params.get('data_format', None),
 
'add_batch_norm': params.get('add_batch_norm', None)
}
}
arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
@@ -118,10 +162,14 @@ def model_fn(features, labels, mode, params=None, config=None):
@@ -118,10 +162,14 @@ def model_fn(features, labels, mode, params=None, config=None):
# Configure the training op
# Configure the training op
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == tf.estimator.ModeKeys.TRAIN:
 
global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.GradientDescentOptimizer(
optimizer = tf.train.GradientDescentOptimizer(
learning_rate=learning_rate)
learning_rate=learning_rate)
train_op = optimizer.minimize(
# for batch normalization to be updated as well:
loss=loss, global_step=tf.train.get_or_create_global_step())
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 
with tf.control_dependencies(update_ops):
 
train_op = optimizer.minimize(
 
loss=loss, global_step=global_step)
# Log accuracy and loss
# Log accuracy and loss
with tf.name_scope('train_metrics'):
with tf.name_scope('train_metrics'):
tf.summary.scalar('accuracy', accuracy[1])
tf.summary.scalar('accuracy', accuracy[1])
Loading