Commit 171ede05 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Patched the train.py script to accept tensors as input

parent 9359b23d
Pipeline #12834 passed with stages
in 13 minutes and 11 seconds
...@@ -87,22 +87,34 @@ def main(): ...@@ -87,22 +87,34 @@ def main():
trainer.create_network_from_file(output_dir) trainer.create_network_from_file(output_dir)
else: else:
# Preparing the architecture # Either bootstrap from scratch or take the pointer directly from the config script
input_pl = config.train_data_shuffler("data", from_queue=False) train_graph = None
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer): validation_graph = None
graph = dict()
graph['left'] = config.architecture(input_pl['left']) if hasattr(config, 'train_graph'):
graph['right'] = config.architecture(input_pl['right'], reuse=True) train_graph = config.train_graph
if hasattr(config, 'validation_graph'):
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer): validation_graph = config.validation_graph
graph = dict()
graph['anchor'] = config.architecture(input_pl['anchor'])
graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
else: else:
graph = config.architecture(input_pl) # Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
trainer.create_network_from_scratch(graph, loss=config.loss, if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
train_graph = dict()
train_graph['left'] = config.architecture(input_pl['left'])
train_graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
train_graph = dict()
train_graph['anchor'] = config.architecture(input_pl['anchor'])
train_graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
train_graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
else:
train_graph = config.architecture(input_pl)
trainer.create_network_from_scratch(train_graph,
validation_graph=validation_graph,
loss=config.loss,
learning_rate=config.learning_rate, learning_rate=config.learning_rate,
optimizer=config.optimizer) optimizer=config.optimizer)
trainer.train() trainer.train()
......
...@@ -121,6 +121,7 @@ class SiameseTrainer(Trainer): ...@@ -121,6 +121,7 @@ class SiameseTrainer(Trainer):
def create_network_from_scratch(self, def create_network_from_scratch(self,
graph, graph,
validation_graph=None,
optimizer=tf.train.AdamOptimizer(), optimizer=tf.train.AdamOptimizer(),
loss=None, loss=None,
......
...@@ -59,8 +59,8 @@ class Trainer(object): ...@@ -59,8 +59,8 @@ class Trainer(object):
###### training options ########## ###### training options ##########
iterations=5000, iterations=5000,
snapshot=500, snapshot=1000,
validation_snapshot=100, validation_snapshot=2000,
keep_checkpoint_every_n_hours=2, keep_checkpoint_every_n_hours=2,
## Analizer ## Analizer
......
...@@ -122,6 +122,7 @@ class TripletTrainer(Trainer): ...@@ -122,6 +122,7 @@ class TripletTrainer(Trainer):
def create_network_from_scratch(self, def create_network_from_scratch(self,
graph, graph,
validation_graph=None,
optimizer=tf.train.AdamOptimizer(), optimizer=tf.train.AdamOptimizer(),
loss=None, loss=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment