Patched the train.py script to accept tensors as input

parent 9359b23d
Pipeline #12834 passed with stages
in 13 minutes and 11 seconds
......@@ -87,22 +87,34 @@ def main():
trainer.create_network_from_file(output_dir)
else:
# Either bootstrap from scratch or take the pointer directly from the config script
train_graph = None
validation_graph = None
if hasattr(config, 'train_graph'):
train_graph = config.train_graph
if hasattr(config, 'validation_graph'):
validation_graph = config.validation_graph
else:
# Preparing the architecture
input_pl = config.train_data_shuffler("data", from_queue=False)
if isinstance(trainer, bob.learn.tensorflow.trainers.SiameseTrainer):
graph = dict()
graph['left'] = config.architecture(input_pl['left'])
graph['right'] = config.architecture(input_pl['right'], reuse=True)
train_graph = dict()
train_graph['left'] = config.architecture(input_pl['left'])
train_graph['right'] = config.architecture(input_pl['right'], reuse=True)
elif isinstance(trainer, bob.learn.tensorflow.trainers.TripletTrainer):
graph = dict()
graph['anchor'] = config.architecture(input_pl['anchor'])
graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
train_graph = dict()
train_graph['anchor'] = config.architecture(input_pl['anchor'])
train_graph['positive'] = config.architecture(input_pl['positive'], reuse=True)
train_graph['negative'] = config.architecture(input_pl['negative'], reuse=True)
else:
graph = config.architecture(input_pl)
train_graph = config.architecture(input_pl)
trainer.create_network_from_scratch(graph, loss=config.loss,
trainer.create_network_from_scratch(train_graph,
validation_graph=validation_graph,
loss=config.loss,
learning_rate=config.learning_rate,
optimizer=config.optimizer)
trainer.train()
......
......@@ -121,6 +121,7 @@ class SiameseTrainer(Trainer):
def create_network_from_scratch(self,
graph,
validation_graph=None,
optimizer=tf.train.AdamOptimizer(),
loss=None,
......
......@@ -59,8 +59,8 @@ class Trainer(object):
###### training options ##########
iterations=5000,
snapshot=500,
validation_snapshot=100,
snapshot=1000,
validation_snapshot=2000,
keep_checkpoint_every_n_hours=2,
## Analizer
......
......@@ -122,6 +122,7 @@ class TripletTrainer(Trainer):
def create_network_from_scratch(self,
graph,
validation_graph=None,
optimizer=tf.train.AdamOptimizer(),
loss=None,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment