Fixed issue #13

parent 21229e01
...@@ -36,7 +36,6 @@ class Base(object): ...@@ -36,7 +36,6 @@ class Base(object):
self.scale_value = 0.00390625 self.scale_value = 0.00390625
self.input_dtype = input_dtype self.input_dtype = input_dtype
# TODO: Check if the batch size is higher than the input data # TODO: Check if the batch size is higher than the input data
self.batch_size = batch_size self.batch_size = batch_size
...@@ -60,6 +59,10 @@ class Base(object): ...@@ -60,6 +59,10 @@ class Base(object):
self.data_augmentation = data_augmentation self.data_augmentation = data_augmentation
self.deployment_shape = [None] + list(input_shape) self.deployment_shape = [None] + list(input_shape)
def set_placeholders(self, data, label):
    """
    Stores the given placeholders on this object

    ** Parameters **

      data: Object assigned to `self.data_placeholder`
      label: Object assigned to `self.label_placeholder`
    """
    self.data_placeholder = data
    self.label_placeholder = label
def get_placeholders(self, name=""): def get_placeholders(self, name=""):
""" """
Returns a place holder with the size of your batch Returns a place holder with the size of your batch
...@@ -156,3 +159,4 @@ class Base(object): ...@@ -156,3 +159,4 @@ class Base(object):
...@@ -21,6 +21,11 @@ class Siamese(Base): ...@@ -21,6 +21,11 @@ class Siamese(Base):
super(Siamese, self).__init__(**kwargs) super(Siamese, self).__init__(**kwargs)
self.data2_placeholder = None self.data2_placeholder = None
def set_placeholders(self, data, data2, label):
    """
    Stores the given placeholders on this object

    ** Parameters **

      data: Object assigned to `self.data_placeholder`
      data2: Object assigned to `self.data2_placeholder`
      label: Object assigned to `self.label_placeholder`
    """
    self.data_placeholder = data
    self.data2_placeholder = data2
    self.label_placeholder = label
def get_placeholders(self, name=""): def get_placeholders(self, name=""):
""" """
Returns a place holder with the size of your batch Returns a place holder with the size of your batch
......
...@@ -121,6 +121,7 @@ def test_cnn_trainer(): ...@@ -121,6 +121,7 @@ def test_cnn_trainer():
# At least 80% of accuracy # At least 80% of accuracy
assert accuracy > 80. assert accuracy > 80.
shutil.rmtree(directory) shutil.rmtree(directory)
del chopra
def test_siamesecnn_trainer(): def test_siamesecnn_trainer():
...@@ -166,6 +167,7 @@ def test_siamesecnn_trainer(): ...@@ -166,6 +167,7 @@ def test_siamesecnn_trainer():
# At least 80% of accuracy # At least 80% of accuracy
assert eer < 0.25 assert eer < 0.25
shutil.rmtree(directory) shutil.rmtree(directory)
del chopra
def test_tripletcnn_trainer(): def test_tripletcnn_trainer():
...@@ -212,3 +214,4 @@ def test_tripletcnn_trainer(): ...@@ -212,3 +214,4 @@ def test_tripletcnn_trainer():
# At least 80% of accuracy # At least 80% of accuracy
assert eer < 0.25 assert eer < 0.25
shutil.rmtree(directory) shutil.rmtree(directory)
del chopra
...@@ -68,7 +68,7 @@ def test_cnn_trainer_scratch(): ...@@ -68,7 +68,7 @@ def test_cnn_trainer_scratch():
scratch = scratch_network() scratch = scratch_network()
trainer = Trainer(architecture=scratch, trainer = Trainer(architecture=scratch,
loss=loss, loss=loss,
iterations=iterations, iterations=iterations+200,
analizer=None, analizer=None,
prefetch=False, prefetch=False,
learning_rate=constant(0.05, name="lr2"), learning_rate=constant(0.05, name="lr2"),
...@@ -76,5 +76,8 @@ def test_cnn_trainer_scratch(): ...@@ -76,5 +76,8 @@ def test_cnn_trainer_scratch():
model_from_file=os.path.join(directory, "model.ckp")) model_from_file=os.path.join(directory, "model.ckp"))
trainer.train(train_data_shuffler) trainer.train(train_data_shuffler)
accuracy = validate_network(validation_data, validation_labels, directory)
assert accuracy > 90 accuracy = validate_network(validation_data, validation_labels, directory2)
assert accuracy > 85
shutil.rmtree(directory)
shutil.rmtree(directory2)
...@@ -102,6 +102,42 @@ class SiameseTrainer(Trainer): ...@@ -102,6 +102,42 @@ class SiameseTrainer(Trainer):
self.between_class_graph_validation = None self.between_class_graph_validation = None
self.within_class_graph_validation = None self.within_class_graph_validation = None
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
    """
    Persist the placeholders

    Adds the placeholders returned by the data shufflers to named graph
    collections ("train_placeholder_*" and "validation_placeholder_*").

    ** Parameters **

      train_data_shuffler: Source of the training placeholders
      validation_data_shuffler: Source of the validation placeholders.
                                Nothing is registered for validation if it is None
    """
    # The training placeholders come from a different getter when prefetching
    if self.prefetch:
        batch, batch2, label = train_data_shuffler.get_placeholders_forprefetch()
    else:
        batch, batch2, label = train_data_shuffler.get_placeholders()

    for key, placeholder in (("train_placeholder_data", batch),
                             ("train_placeholder_data2", batch2),
                             ("train_placeholder_label", label)):
        tf.add_to_collection(key, placeholder)

    # Persisting the validation placeholders
    if validation_data_shuffler is not None:
        batch, batch2, label = validation_data_shuffler.get_placeholders()

        for key, placeholder in (("validation_placeholder_data", batch),
                                 ("validation_placeholder_data2", batch2),
                                 ("validation_placeholder_label", label)):
            tf.add_to_collection(key, placeholder)
def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
    """
    Load placeholders from file

    Reads the first entry of each placeholder collection and hands it to the
    matching data shuffler through `set_placeholders`.

    ** Parameters **

      train_data_shuffler: Receives the "train_placeholder_*" entries
      validation_data_shuffler: Receives the "validation_placeholder_*" entries.
                                Skipped if it is None
    """
    train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                         tf.get_collection("train_placeholder_data2")[0],
                                         tf.get_collection("train_placeholder_label")[0])

    if validation_data_shuffler is not None:
        # The validation placeholders go to the validation shuffler; setting them on the
        # train shuffler would overwrite the training placeholders assigned just above
        validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                  tf.get_collection("validation_placeholder_data2")[0],
                                                  tf.get_collection("validation_placeholder_label")[0])
def compute_graph(self, data_shuffler, prefetch=False, name="", train=True): def compute_graph(self, data_shuffler, prefetch=False, name="", train=True):
""" """
Computes the graph for the trainer. Computes the graph for the trainer.
......
...@@ -75,10 +75,7 @@ class Trainer(object): ...@@ -75,10 +75,7 @@ class Trainer(object):
self.loss = loss self.loss = loss
self.temp_dir = temp_dir self.temp_dir = temp_dir
#self.base_learning_rate = base_learning_rate
self.learning_rate = learning_rate self.learning_rate = learning_rate
#self.weight_decay = weight_decay
#self.decay_steps = decay_steps
self.iterations = iterations self.iterations = iterations
self.snapshot = snapshot self.snapshot = snapshot
...@@ -116,12 +113,12 @@ class Trainer(object): ...@@ -116,12 +113,12 @@ class Trainer(object):
""" """
Computes the graph for the trainer. Computes the graph for the trainer.
** Parameters ** ** Parameters **
data_shuffler: Data shuffler data_shuffler: Data shuffler
prefetch: prefetch: Uses prefetch
name: Name of the graph name: Name of the graph
training: Is it a training graph?
""" """
# Defining place holders # Defining place holders
...@@ -203,8 +200,7 @@ class Trainer(object): ...@@ -203,8 +200,7 @@ class Trainer(object):
if self.validation_summary_writter is None: if self.validation_summary_writter is None:
self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph) self.validation_summary_writter = tf.train.SummaryWriter(os.path.join(self.temp_dir, 'validation'), session.graph)
summaries = [] summaries = [summary_pb2.Summary.Value(tag="loss", simple_value=float(l))]
summaries.append(summary_pb2.Summary.Value(tag="loss", simple_value=float(l)))
self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step) self.validation_summary_writter.add_summary(summary_pb2.Summary(value=summaries), step)
logger.info("Loss VALIDATION set step={0} = {1}".format(step, l)) logger.info("Loss VALIDATION set step={0} = {1}".format(step, l))
...@@ -212,7 +208,6 @@ class Trainer(object): ...@@ -212,7 +208,6 @@ class Trainer(object):
""" """
Creates a simple tensorboard summary with the value of the loss and learning rate Creates a simple tensorboard summary with the value of the loss and learning rate
""" """
# Train summary # Train summary
tf.scalar_summary('loss', self.training_graph, name="train") tf.scalar_summary('loss', self.training_graph, name="train")
tf.scalar_summary('lr', self.learning_rate, name="train") tf.scalar_summary('lr', self.learning_rate, name="train")
...@@ -251,12 +246,9 @@ class Trainer(object): ...@@ -251,12 +246,9 @@ class Trainer(object):
session.run(self.enqueue_op, feed_dict=feed_dict) session.run(self.enqueue_op, feed_dict=feed_dict)
def create_graphs(self, train_data_shuffler, validation_data_shuffler): def bootstrap_graphs(self, train_data_shuffler, validation_data_shuffler):
""" """
Create all the necessary graphs for training, validation and inference graphs
:param train_data_shuffler:
:param validation_data_shuffler:
:return:
""" """
# Creating train graph # Creating train graph
...@@ -269,11 +261,70 @@ class Trainer(object): ...@@ -269,11 +261,70 @@ class Trainer(object):
tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder) tf.add_to_collection("inference_placeholder", self.architecture.inference_placeholder)
tf.add_to_collection("inference_graph", self.architecture.inference_graph) tf.add_to_collection("inference_graph", self.architecture.inference_graph)
# Creating validation graph
if validation_data_shuffler is not None: if validation_data_shuffler is not None:
# Creating validation graph
self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False) self.validation_graph = self.compute_graph(validation_data_shuffler, name="validation", training=False)
tf.add_to_collection("validation_graph", self.validation_graph) tf.add_to_collection("validation_graph", self.validation_graph)
batch, label = validation_data_shuffler.get_placeholders()
tf.add_to_collection("validation_placeholder_data", batch)
tf.add_to_collection("validation_placeholder_label", label)
self.bootstrap_placeholders(train_data_shuffler, validation_data_shuffler)
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
    """
    Persist the placeholders

    Adds the placeholders returned by the data shufflers to named graph
    collections ("train_placeholder_*" and "validation_placeholder_*").

    ** Parameters **

      train_data_shuffler: Source of the training placeholders
      validation_data_shuffler: Source of the validation placeholders.
                                Nothing is registered for validation if it is None
    """
    # The training placeholders come from a different getter when prefetching
    if self.prefetch:
        batch, label = train_data_shuffler.get_placeholders_forprefetch()
    else:
        batch, label = train_data_shuffler.get_placeholders()

    for key, placeholder in (("train_placeholder_data", batch),
                             ("train_placeholder_label", label)):
        tf.add_to_collection(key, placeholder)

    # Persisting the validation placeholders
    if validation_data_shuffler is not None:
        batch, label = validation_data_shuffler.get_placeholders()

        for key, placeholder in (("validation_placeholder_data", batch),
                                 ("validation_placeholder_label", label)):
            tf.add_to_collection(key, placeholder)
def bootstrap_graphs_fromfile(self, session, train_data_shuffler, validation_data_shuffler):
    """
    Bootstrap all the necessary data from file

    Loads the model pointed to by `self.model_from_file` and then restores the
    trainer attributes from the graph collections.

    ** Parameters **

      session: Session handed to `self.architecture.load`
      train_data_shuffler: Forwarded to `bootstrap_placeholders_fromfile`
      validation_data_shuffler: Forwarded to `bootstrap_placeholders_fromfile`.
                                The validation graph is only restored if it is not None

    ** Returns **

      The value returned by `self.architecture.load`
    """
    # The model has to be loaded before the collections are queried
    saver = self.architecture.load(session, self.model_from_file)

    # Loading training graph
    self.training_graph = tf.get_collection("training_graph")[0]

    # Loading other elements
    self.optimizer = tf.get_collection("optimizer")[0]
    self.learning_rate = tf.get_collection("learning_rate")[0]
    self.summaries_train = tf.get_collection("summaries_train")[0]

    if validation_data_shuffler is not None:
        self.validation_graph = tf.get_collection("validation_graph")[0]

    self.bootstrap_placeholders_fromfile(train_data_shuffler, validation_data_shuffler)

    return saver
def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
    """
    Load placeholders from file

    Reads the first entry of each placeholder collection and hands it to the
    matching data shuffler through `set_placeholders`.

    ** Parameters **

      train_data_shuffler: Receives the "train_placeholder_*" entries
      validation_data_shuffler: Receives the "validation_placeholder_*" entries.
                                Skipped if it is None
    """
    train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                         tf.get_collection("train_placeholder_label")[0])

    if validation_data_shuffler is not None:
        # The validation placeholders go to the validation shuffler; setting them on the
        # train shuffler would overwrite the training placeholders assigned just above
        validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                  tf.get_collection("validation_placeholder_label")[0])
def train(self, train_data_shuffler, validation_data_shuffler=None): def train(self, train_data_shuffler, validation_data_shuffler=None):
""" """
Train the network Train the network
...@@ -296,12 +347,10 @@ class Trainer(object): ...@@ -296,12 +347,10 @@ class Trainer(object):
# Loading a pretrained model # Loading a pretrained model
if self.model_from_file != "": if self.model_from_file != "":
logger.info("Loading pretrained model from {0}".format(self.model_from_file)) logger.info("Loading pretrained model from {0}".format(self.model_from_file))
saver = self.architecture.load(session, self.model_from_file) saver = self.bootstrap_graphs_fromfile(session, train_data_shuffler, validation_data_shuffler)
self.training_graph = tf.get_collection("training_graph")[0]
self.optimizer = tf.get_collection("optimizer")[0]
self.learning_rate = tf.get_collection("learning_rate")[0]
else: else:
self.create_graphs(train_data_shuffler, validation_data_shuffler) # Bootstraping all the graphs
self.bootstrap_graphs(train_data_shuffler, validation_data_shuffler)
# TODO: find an elegant way to provide this as a parameter of the trainer # TODO: find an elegant way to provide this as a parameter of the trainer
self.global_step = tf.Variable(0, trainable=False) self.global_step = tf.Variable(0, trainable=False)
...@@ -314,6 +363,7 @@ class Trainer(object): ...@@ -314,6 +363,7 @@ class Trainer(object):
# Train summary # Train summary
self.summaries_train = self.create_general_summary() self.summaries_train = self.create_general_summary()
tf.add_to_collection("summaries_train", self.summaries_train)
tf.initialize_all_variables().run() tf.initialize_all_variables().run()
......
...@@ -101,6 +101,42 @@ class TripletTrainer(Trainer): ...@@ -101,6 +101,42 @@ class TripletTrainer(Trainer):
self.between_class_graph_validation = None self.between_class_graph_validation = None
self.within_class_graph_validation = None self.within_class_graph_validation = None
def bootstrap_placeholders(self, train_data_shuffler, validation_data_shuffler):
    """
    Persist the placeholders

    Adds the three placeholders returned by each data shuffler to named graph
    collections ("train_placeholder_data*" and "validation_placeholder_data*").

    ** Parameters **

      train_data_shuffler: Source of the training placeholders
      validation_data_shuffler: Source of the validation placeholders.
                                Nothing is registered for validation if it is None
    """
    # The training placeholders come from a different getter when prefetching
    if self.prefetch:
        batch, batch2, batch3 = train_data_shuffler.get_placeholders_forprefetch()
    else:
        batch, batch2, batch3 = train_data_shuffler.get_placeholders()

    for key, placeholder in (("train_placeholder_data", batch),
                             ("train_placeholder_data2", batch2),
                             ("train_placeholder_data3", batch3)):
        tf.add_to_collection(key, placeholder)

    # Persisting the validation placeholders
    if validation_data_shuffler is not None:
        # Unpack the third placeholder from the validation shuffler as well, otherwise the
        # training `batch3` would be registered under "validation_placeholder_data3"
        batch, batch2, batch3 = validation_data_shuffler.get_placeholders()

        for key, placeholder in (("validation_placeholder_data", batch),
                                 ("validation_placeholder_data2", batch2),
                                 ("validation_placeholder_data3", batch3)):
            tf.add_to_collection(key, placeholder)
def bootstrap_placeholders_fromfile(self, train_data_shuffler, validation_data_shuffler):
    """
    Load placeholders from file

    Reads the first entry of each placeholder collection and hands it to the
    matching data shuffler through `set_placeholders`.

    ** Parameters **

      train_data_shuffler: Receives the "train_placeholder_data*" entries
      validation_data_shuffler: Receives the "validation_placeholder_data*" entries.
                                Skipped if it is None
    """
    train_data_shuffler.set_placeholders(tf.get_collection("train_placeholder_data")[0],
                                         tf.get_collection("train_placeholder_data2")[0],
                                         tf.get_collection("train_placeholder_data3")[0])

    if validation_data_shuffler is not None:
        # The validation placeholders go to the validation shuffler; setting them on the
        # train shuffler would overwrite the training placeholders assigned just above
        validation_data_shuffler.set_placeholders(tf.get_collection("validation_placeholder_data")[0],
                                                  tf.get_collection("validation_placeholder_data2")[0],
                                                  tf.get_collection("validation_placeholder_data3")[0])
def compute_graph(self, data_shuffler, prefetch=False, name="", train=True): def compute_graph(self, data_shuffler, prefetch=False, name="", train=True):
""" """
Computes the graph for the trainer. Computes the graph for the trainer.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment