Skip to content
Snippets Groups Projects
Commit 047de394 authored by Amir MOHAMMADI
Browse files

Bug fixes in the arch

parent 57851864
No related branches found
No related tags found
1 merge request: !47 "Many changes"
...@@ -44,7 +44,8 @@ from __future__ import print_function ...@@ -44,7 +44,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
def base_architecture(input_layer, mode, data_format, **kwargs): def base_architecture(input_layer, mode, data_format,
skip_first_two_pool=False, **kwargs):
# Keep track of all the endpoints # Keep track of all the endpoints
endpoints = {} endpoints = {}
bn_axis = 1 if data_format.lower() == 'channels_first' else 3 bn_axis = 1 if data_format.lower() == 'channels_first' else 3
...@@ -69,8 +70,12 @@ def base_architecture(input_layer, mode, data_format, **kwargs): ...@@ -69,8 +70,12 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
endpoints['BN-1-activation'] = bn1_act endpoints['BN-1-activation'] = bn1_act
# Pooling Layer #1 # Pooling Layer #1
pool1 = tf.layers.max_pooling2d( if skip_first_two_pool:
inputs=bn1_act, pool_size=[2, 2], strides=2, data_format=data_format) pool1 = bn1_act
else:
pool1 = tf.layers.max_pooling2d(
inputs=bn1_act, pool_size=[2, 2], strides=2,
data_format=data_format)
endpoints['MaxPooling-1'] = pool1 endpoints['MaxPooling-1'] = pool1
# ====================== # ======================
...@@ -92,8 +97,12 @@ def base_architecture(input_layer, mode, data_format, **kwargs): ...@@ -92,8 +97,12 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
endpoints['BN-2-activation'] = bn2_act endpoints['BN-2-activation'] = bn2_act
# Pooling Layer #2 # Pooling Layer #2
pool2 = tf.layers.max_pooling2d( if skip_first_two_pool:
inputs=bn2_act, pool_size=[2, 2], strides=2, data_format=data_format) pool2 = bn2_act
else:
pool2 = tf.layers.max_pooling2d(
inputs=bn2_act, pool_size=[2, 2], strides=2,
data_format=data_format)
endpoints['MaxPooling-2'] = pool2 endpoints['MaxPooling-2'] = pool2
# ====================== # ======================
...@@ -177,7 +186,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs): ...@@ -177,7 +186,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
# Batch Normalization #6 # Batch Normalization #6
bn6 = tf.layers.batch_normalization( bn6 = tf.layers.batch_normalization(
fc_1, axis=bn_axis, training=training, fused=True) fc_1, axis=1, training=training, fused=True)
endpoints['BN-6'] = bn6 endpoints['BN-6'] = bn6
bn6_act = tf.nn.relu(bn6) bn6_act = tf.nn.relu(bn6)
endpoints['BN-6-activation'] = bn6_act endpoints['BN-6-activation'] = bn6_act
...@@ -193,7 +202,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs): ...@@ -193,7 +202,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
# Batch Normalization #7 # Batch Normalization #7
bn7 = tf.layers.batch_normalization( bn7 = tf.layers.batch_normalization(
fc_2, axis=bn_axis, training=training, fused=True) fc_2, axis=1, training=training, fused=True)
endpoints['BN-7'] = bn7 endpoints['BN-7'] = bn7
bn7_act = tf.nn.relu(bn7) bn7_act = tf.nn.relu(bn7)
endpoints['BN-7-activation'] = bn7_act endpoints['BN-7-activation'] = bn7_act
...@@ -203,6 +212,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs): ...@@ -203,6 +212,7 @@ def base_architecture(input_layer, mode, data_format, **kwargs):
def architecture(input_layer, def architecture(input_layer,
mode=tf.estimator.ModeKeys.TRAIN, mode=tf.estimator.ModeKeys.TRAIN,
skip_first_two_pool=False,
n_classes=2, n_classes=2,
data_format='channels_last', data_format='channels_last',
reuse=False, reuse=False,
...@@ -210,7 +220,8 @@ def architecture(input_layer, ...@@ -210,7 +220,8 @@ def architecture(input_layer,
with tf.variable_scope('PatchCNN', reuse=reuse): with tf.variable_scope('PatchCNN', reuse=reuse):
bn7_act, endpoints = base_architecture(input_layer, mode, data_format) bn7_act, endpoints = base_architecture(
input_layer, mode, data_format, skip_first_two_pool)
# Logits layer # Logits layer
logits = tf.layers.dense(inputs=bn7_act, units=n_classes) logits = tf.layers.dense(inputs=bn7_act, units=n_classes)
endpoints['FC-3'] = logits endpoints['FC-3'] = logits
...@@ -229,8 +240,9 @@ def model_fn(features, labels, mode, params=None, config=None): ...@@ -229,8 +240,9 @@ def model_fn(features, labels, mode, params=None, config=None):
momentum = params.get('momentum', 0.99) momentum = params.get('momentum', 0.99)
arch_kwargs = { arch_kwargs = {
'n_classes': params.get('n_classes', None), 'skip_first_two_pool': params.get('skip_first_two_pool'),
'data_format': params.get('data_format', None), 'n_classes': params.get('n_classes'),
'data_format': params.get('data_format'),
} }
arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None} arch_kwargs = {k: v for k, v in arch_kwargs.items() if v is not None}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment