Skip to content
Snippets Groups Projects

Updates

Merged Tiago de Freitas Pereira requested to merge updates into master
4 files
+ 31
17
Compare changes
  • Side-by-side
  • Inline

Files

+ 27
15
@@ -87,22 +87,34 @@ def main():
trainer.create_network_from_file(output_dir)
else:
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
graph = dict()
graph['left'] = config.architecture(input_pl['left'])
graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
graph = dict()
graph['anchor'] = config.architecture(input_pl['anchor'])
graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
# Either bootstrap from scratch or take the pointer directly from the config script
train_graph = None
validation_graph = None
if hasattr(config, 'train_graph'):
train_graph = config.train_graph
if hasattr(config, 'validation_graph'):
validation_graph = config.validation_graph
else:
graph = config.architecture(input_pl)
trainer.create_network_from_scratch(graph, loss=config.loss,
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
train_graph = dict()
train_graph['left'] = config.architecture(input_pl['left'])
train_graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
train_graph = dict()
train_graph['anchor'] = config.architecture(input_pl['anchor'])
train_graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
train_graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
else:
train_graph = config.architecture(input_pl)
trainer.create_network_from_scratch(train_graph,
validation_graph=validation_graph,
loss=config.loss,
learning_rate=config.learning_rate,
optimizer=config.optimizer)
trainer.train()
Loading