
Many changes

Merged Amir MOHAMMADI requested to merge amir into master
1 file  +22  −10
@@ -44,7 +44,8 @@ from __future__ import print_function
 import tensorflow as tf
-def base_architecture(input_layer, mode, data_format, **kwargs):
+def base_architecture(input_layer, mode, data_format,
+                      skip_first_two_pool=False, **kwargs):
     # Keep track of all the endpoints
     endpoints = {}
     bn_axis = 1 if data_format.lower() == 'channels_first' else 3
@@ -69,8 +70,12 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
     endpoints['BN-1-activation'] = bn1_act
     # Pooling Layer #1
-    pool1 = tf.layers.max_pooling2d(
-        inputs=bn1_act, pool_size=[2, 2], strides=2, data_format=data_format)
+    if skip_first_two_pool:
+        pool1 = bn1_act
+    else:
+        pool1 = tf.layers.max_pooling2d(
+            inputs=bn1_act, pool_size=[2, 2], strides=2,
+            data_format=data_format)
     endpoints['MaxPooling-1'] = pool1
     # ======================
@@ -92,8 +97,12 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
     endpoints['BN-2-activation'] = bn2_act
     # Pooling Layer #2
-    pool2 = tf.layers.max_pooling2d(
-        inputs=bn2_act, pool_size=[2, 2], strides=2, data_format=data_format)
+    if skip_first_two_pool:
+        pool2 = bn2_act
+    else:
+        pool2 = tf.layers.max_pooling2d(
+            inputs=bn2_act, pool_size=[2, 2], strides=2,
+            data_format=data_format)
     endpoints['MaxPooling-2'] = pool2
     # ======================
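The two pooling hunks above only toggle whether the 2x2, stride-2 downsampling is applied. A small, hypothetical sketch of the effect on spatial size (the pool count and input size below are illustrative assumptions, not taken from this MR):

# Hypothetical illustration, not part of the MR: every 2x2/stride-2 pooling that
# is kept halves each spatial dimension, so skipping the first two keeps the
# feature maps 4x larger per side when they reach the later layers.
def spatial_size(input_size, n_pool_layers, skip_first_two_pool=False):
    kept = n_pool_layers - 2 if skip_first_two_pool else n_pool_layers
    return input_size // (2 ** kept)

print(spatial_size(64, 3))                            # 8
print(spatial_size(64, 3, skip_first_two_pool=True))  # 32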
@@ -177,7 +186,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
     # Batch Normalization #6
     bn6 = tf.layers.batch_normalization(
-        fc_1, axis=bn_axis, training=training, fused=True)
+        fc_1, axis=1, training=training, fused=True)
     endpoints['BN-6'] = bn6
     bn6_act = tf.nn.relu(bn6)
     endpoints['BN-6-activation'] = bn6_act
@@ -193,7 +202,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
     # Batch Normalization #7
     bn7 = tf.layers.batch_normalization(
-        fc_2, axis=bn_axis, training=training, fused=True)
+        fc_2, axis=1, training=training, fused=True)
     endpoints['BN-7'] = bn7
     bn7_act = tf.nn.relu(bn7)
     endpoints['BN-7-activation'] = bn7_act
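The two batch-norm hunks above replace the layout-dependent `bn_axis` with a hard-coded `axis=1`: `fc_1` and `fc_2` are fully connected outputs, presumably rank-2 `(batch, features)`, so the feature axis is always 1, whereas `bn_axis` would be 3 under `channels_last` and out of range for a rank-2 tensor. A hypothetical helper (not from this MR) spelling out the distinction:

# Hypothetical helper, not part of the MR: which axis batch norm should normalize over.
def bn_axis_for(data_format, tensor_rank):
    if tensor_rank == 2:
        # Dense/FC outputs are (batch, features): features always sit on axis 1.
        return 1
    # Conv feature maps: channels_first -> axis 1, channels_last -> axis 3.
    return 1 if data_format.lower() == 'channels_first' else 3

assert bn_axis_for('channels_last', 4) == 3
assert bn_axis_for('channels_last', 2) == 1  # why fc_1/fc_2 now use axis=1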
@@ -203,6 +212,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
 def architecture(input_layer,
                  mode=tf.estimator.ModeKeys.TRAIN,
+                 skip_first_two_pool=False,
                  n_classes=2,
                  data_format='channels_last',
                  reuse=False,
@@ -210,7 +220,8 @@ def architecture(input_layer,
     with tf.variable_scope('PatchCNN', reuse=reuse):
-        bn7_act, endpoints = base_architecture(input_layer, mode, data_format)
+        bn7_act, endpoints = base_architecture(
+            input_layer, mode, data_format, skip_first_two_pool)
         # Logits layer
         logits = tf.layers.dense(inputs=bn7_act, units=n_classes)
         endpoints['FC-3'] = logits
@@ -229,8 +240,9 @@ def model_fn(features, labels, mode, params=None, config=None):
     momentum = params.get('momentum', 0.99)
     arch_kwargs = {
-        'n_classes': params.get('n_classes', None),
-        'data_format': params.get('data_format', None),
+        'skip_first_two_pool': params.get('skip_first_two_pool'),
+        'n_classes': params.get('n_classes'),
+        'data_format': params.get('data_format'),
     }
     arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
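A minimal usage sketch of how the new option would be switched on through the estimator `params` dict (the concrete values below are illustrative assumptions, not part of the MR):

# Hypothetical example, not part of the MR: enabling the new option via params.
params = {
    'n_classes': 2,
    'data_format': 'channels_last',
    'skip_first_two_pool': True,   # new key read by model_fn above
}

# Mirrors the filtering in model_fn: keys left unset fall back to the
# architecture defaults (skip_first_two_pool=False, n_classes=2, ...).
arch_kwargs = {
    'skip_first_two_pool': params.get('skip_first_two_pool'),
    'n_classes': params.get('n_classes'),
    'data_format': params.get('data_format'),
}
arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
print(arch_kwargs)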