Skip to content
Snippets Groups Projects
Commit 171ede05 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Patched the train.py script to accept tensors as input

parent 9359b23d
No related branches found
No related tags found
1 merge request!15Updates
Pipeline #
......@@ -87,22 +87,34 @@ def main():
trainer.create_network_from_file(output_dir)
else:
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
graph = dict()
graph['left'] = config.architecture(input_pl['left'])
graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
graph = dict()
graph['anchor'] = config.architecture(input_pl['anchor'])
graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
# Either bootstrap from scratch or take the pointer directly from the config script
train_graph = None
validation_graph = None
if hasattr(config, 'train_graph'):
train_graph = config.train_graph
if hasattr(config, 'validation_graph'):
validation_graph = config.validation_graph
else:
graph = config.architecture(input_pl)
trainer.create_network_from_scratch(graph, loss=config.loss,
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
train_graph = dict()
train_graph['left'] = config.architecture(input_pl['left'])
train_graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
train_graph = dict()
train_graph['anchor'] = config.architecture(input_pl['anchor'])
train_graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
train_graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
else:
train_graph = config.architecture(input_pl)
trainer.create_network_from_scratch(train_graph,
validation_graph=validation_graph,
loss=config.loss,
learning_rate=config.learning_rate,
optimizer=config.optimizer)
trainer.train()
......
......@@ -121,6 +121,7 @@ class SiameseTrainer(Trainer):
def create_network_from_scratch(self,
graph,
validation_graph=None,
optimizer=tf.train.AdamOptimizer(),
loss=None,
......
......@@ -59,8 +59,8 @@ class Trainer(object):
###### training options ##########
iterations=5000,
snapshot=500,
validation_snapshot=100,
snapshot=1000,
validation_snapshot=2000,
keep_checkpoint_every_n_hours=2,
## Analizer
......
......@@ -122,6 +122,7 @@ class TripletTrainer(Trainer):
def create_network_from_scratch(self,
graph,
validation_graph=None,
optimizer=tf.train.AdamOptimizer(),
loss=None,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment