Commit 586efd56 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fix issue with VGG16 from slim. The slim model adds the one-hot encoding inside the architecture function

Organizing name scopes
parent 52622ba7
Pipeline #30172 passed with stage
in 168 minutes and 3 seconds
......@@ -9,7 +9,11 @@ VGG16 and VGG19 wrappers
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import vgg
from tensorflow.contrib.layers.python.layers import layers as layers_lib
from tensorflow.contrib import layers
import tensorflow.contrib.slim as slim
from tensorflow.python.ops import variable_scope
from .utils import is_trainable
def vgg_19(inputs,
......@@ -40,7 +44,10 @@ def vgg_19(inputs,
def vgg_16(inputs,
           reuse=None,
           mode=tf.estimator.ModeKeys.TRAIN,
           trainable_variables=None,
           scope="vgg_16",
           **kwargs):
    """
    Oxford Net VGG 16-Layers (configuration D) adapted from tf-slim.

    Unlike ``vgg.vgg_16`` from slim, this variant stops after ``fc7`` and does
    NOT append the final class-logits layer, so the one-hot encoding is left to
    the caller (the slim implementation bakes it into the architecture).

    Parameters
    ----------
    inputs:
        4-D input tensor; presumably ``(batch, height, width, channels)`` as
        expected by ``layers.conv2d`` — confirm against callers.
    reuse:
        Kept for interface compatibility; not consulted by this body.
    mode:
        One of ``tf.estimator.ModeKeys``. Layers are trainable (and dropout is
        active) only in TRAIN mode.
    trainable_variables:
        Collection of scope names whose variables stay trainable; interpreted
        by ``is_trainable`` (``None`` presumably means "train everything" —
        verify against ``.utils.is_trainable``).
    scope:
        Variable-scope name for the whole network.

    Returns
    -------
    (net, end_points):
        ``net`` is the ``fc7`` activation (after dropout in TRAIN mode);
        ``end_points`` maps each stage name to its output. NOTE: for the conv
        stages the stored tensor is the POST-pooling output, and for the fc
        stages the post-dropout output, matching the original unrolled code.
    """
    dropout_keep_prob = 0.5
    end_points = {}
    with variable_scope.variable_scope(scope, 'vgg_16', [inputs]) as sc:
        end_points_collection = sc.original_name_scope + '_end_points'
        # Collect outputs for conv2d, fully_connected and max_pool2d.
        with slim.arg_scope(
                [layers.conv2d, layers_lib.fully_connected,
                 layers_lib.max_pool2d],
                outputs_collections=end_points_collection):
            with slim.arg_scope(
                    [slim.conv2d],
                    trainable=mode == tf.estimator.ModeKeys.TRAIN):

                net = inputs
                # VGG-16 conv stages: (scope name, stacked 3x3 convs, depth).
                conv_stages = (("conv1", 2, 64),
                               ("conv2", 2, 128),
                               ("conv3", 3, 256),
                               ("conv4", 3, 512),
                               ("conv5", 3, 512))
                for stage_idx, (name, repeats, depth) in enumerate(
                        conv_stages, start=1):
                    trainable = is_trainable(name, trainable_variables,
                                             mode=mode)
                    net = layers_lib.repeat(
                        net, repeats, layers.conv2d, depth, [3, 3],
                        scope=name, trainable=trainable)
                    net = layers_lib.max_pool2d(
                        net, [2, 2], scope='pool%d' % stage_idx)
                    # Stored AFTER pooling, as in the original unrolled code.
                    end_points[name] = net

                net = layers.flatten(net)

                # Fully-connected head (fc6/fc7), each followed by dropout
                # that is only active in TRAIN mode.
                for fc_idx, name in ((6, "fc6"), (7, "fc7")):
                    trainable = is_trainable(name, trainable_variables,
                                             mode=mode)
                    net = layers.fully_connected(
                        net, 4096, scope=name, trainable=trainable)
                    net = layers_lib.dropout(
                        net, dropout_keep_prob,
                        is_training=mode == tf.estimator.ModeKeys.TRAIN,
                        scope='dropout%d' % fc_idx)
                    end_points[name] = net

                return net, end_points
......@@ -100,6 +100,7 @@ def test_inceptionv1_adaptation():
def test_vgg():
tf.reset_default_graph()
# Testing VGG19 Training mode
inputs = tf.placeholder(tf.float32, shape=(1, 224, 224, 3))
graph, _ = vgg_19(inputs)
......@@ -116,10 +117,11 @@ def test_vgg():
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
# Testing VGG 16 training mode
# Testing VGG 16 training mode
inputs = tf.placeholder(tf.float32, shape=(1, 224, 224, 3))
graph, _ = vgg_16(inputs)
assert len(tf.trainable_variables()) == 32
assert len(tf.trainable_variables()) == 30
tf.reset_default_graph()
assert len(tf.global_variables()) == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment